about summary refs log tree commit diff stats
path: root/miasm/expression/expression.py
diff options
context:
space:
mode:
Diffstat (limited to 'miasm/expression/expression.py')
-rw-r--r-- miasm/expression/expression.py 534
1 files changed, 344 insertions, 190 deletions
diff --git a/miasm/expression/expression.py b/miasm/expression/expression.py
index 93094979..c2bf5b8b 100644
--- a/miasm/expression/expression.py
+++ b/miasm/expression/expression.py
@@ -98,18 +98,6 @@ def should_parenthesize_child(child, parent):
 def str_protected_child(child, parent):
     return ("(%s)" % child) if should_parenthesize_child(child, parent) else str(child)
 
-def visit_chk(visitor):
-    "Function decorator launching callback on Expression visit"
-    def wrapped(expr, callback, test_visit=lambda x: True):
-        if (test_visit is not None) and (not test_visit(expr)):
-            return expr
-        expr_new = visitor(expr, callback, test_visit)
-        if expr_new is None:
-            return None
-        expr_new2 = callback(expr_new)
-        return expr_new2
-    return wrapped
-
 
 # Expression display
 
@@ -152,6 +140,49 @@ class DiGraphExpr(DiGraph):
 
         return ""
 
+def is_expr(expr):  # True iff @expr is an instance of one of the classes below
+    return isinstance(
+        expr,
+        (
+            ExprInt, ExprId, ExprMem,
+            ExprSlice, ExprCompose, ExprCond,
+            ExprLoc, ExprOp  # NOTE(review): ExprAssign is not listed - confirm intended
+        )
+    )
+
+def is_associative(expr):
+    "Return True iff current operation is associative"
+    return (expr.op in ['+', '*', '^', '&', '|'])
+
+def is_commutative(expr):
+    "Return True iff current operation is commutative"
+    return (expr.op in ['+', '*', '^', '&', '|'])
+
+def is_op_segm(expr):
+    """Returns True if is ExprOp and op == 'segm'"""
+    return expr.is_op('segm')
+
+def is_mem_segm(expr):
+    """Returns True if is ExprMem and ptr is_op_segm"""
+    return expr.is_mem() and is_op_segm(expr.ptr)
+
+def canonize_to_exprloc(locdb, expr):
+    """
+    If expr is ExprInt, return ExprLoc with corresponding loc_key
+    Else, return expr
+
+    @expr: Expr instance
+    """
+    if expr.is_int():
+        loc_key = locdb.get_or_create_offset_location(int(expr))
+        ret = ExprLoc(loc_key, expr.size)
+        return ret
+    return expr
+
+def is_function_call(expr):
+    """Returns true if the considered Expr is a function call
+    """
+    return expr.is_op() and expr.op.startswith('call')
 
 @total_ordering
 class LocKey(object):
@@ -183,6 +214,263 @@ class LocKey(object):
     def __str__(self):
         return "loc_key_%d" % self.key
 
+
+class ExprWalkBase(object):
+    """
+    Walk through sub-expressions (children first), call @callback on them.
+    If a sub-visit or @callback returns a true value, stop walk and return it
+    """
+
+    def __init__(self, callback):
+        self.callback = callback
+
+    def visit(self, expr, *args, **kwargs):
+        if expr.is_int() or expr.is_id() or expr.is_loc():
+            pass  # leaf: no sub-expression to walk
+        elif expr.is_assign():
+            ret = self.visit(expr.dst, *args, **kwargs)
+            if ret:
+                return ret
+            ret = self.visit(expr.src, *args, **kwargs)  # rebind ret for the test below
+            if ret:
+                return ret
+        elif expr.is_cond():
+            ret = self.visit(expr.cond, *args, **kwargs)
+            if ret:
+                return ret
+            ret = self.visit(expr.src1, *args, **kwargs)
+            if ret:
+                return ret
+            ret = self.visit(expr.src2, *args, **kwargs)
+            if ret:
+                return ret
+        elif expr.is_mem():
+            ret = self.visit(expr.ptr, *args, **kwargs)
+            if ret:
+                return ret
+        elif expr.is_slice():
+            ret = self.visit(expr.arg, *args, **kwargs)
+            if ret:
+                return ret
+        elif expr.is_op():
+            for arg in expr.args:
+                ret = self.visit(arg, *args, **kwargs)
+                if ret:
+                    return ret
+        elif expr.is_compose():
+            for arg in expr.args:
+                ret = self.visit(arg, *args, **kwargs)
+                if ret:
+                    return ret
+        else:
+            raise TypeError("Visitor can only take Expr")
+
+        ret = self.callback(expr, *args, **kwargs)  # extra arguments are forwarded
+        return ret  # returned as is, even when false
+
+
+class ExprWalk(ExprWalkBase):
+    """
+    Same walk as ExprWalkBase: stop on the first true value and return it.
+    Expressions whose visit returned a false value are cached and are not
+    walked again by this instance (visit returns None for them).
+    """
+    def __init__(self, callback):
+        self.cache = set()  # expressions already walked without a true result
+        self.callback = callback  # same attribute ExprWalkBase.__init__ sets
+
+    def visit(self, expr, *args, **kwargs):
+        if expr in self.cache:  # cache key is @expr only, not the extra arguments
+            return None
+        ret = super(ExprWalk, self).visit(expr, *args, **kwargs)
+        if ret:
+            return ret  # true results are not cached
+        self.cache.add(expr)
+        return None
+
+
+class ExprGetR(ExprWalkBase):
+    """
+    Collect in self.elements the ExprId/ExprMem used by a given expression
+    """
+    def __init__(self, mem_read=False, cst_read=False):
+        super(ExprGetR, self).__init__(lambda x:None)  # callback never stops the walk
+        self.mem_read = mem_read  # if True, also walk inside memory pointers
+        self.cst_read = cst_read  # if True, also collect ExprInt/ExprLoc
+        self.elements = set()
+        self.cache = dict()  # (expr, mem_read, cst_read) -> visit result
+
+    def get_r_leaves(self, expr):
+        if (expr.is_int() or expr.is_loc()) and self.cst_read:
+            self.elements.add(expr)
+        elif expr.is_mem():  # ExprMem is collected whatever mem_read is
+            self.elements.add(expr)
+        elif expr.is_id():
+            self.elements.add(expr)
+
+    def visit(self, expr, *args, **kwargs):
+        cache_key = (expr, self.mem_read, self.cst_read)
+        if cache_key in self.cache:
+            return self.cache[cache_key]
+        ret = self.visit_inner(expr, *args, **kwargs)
+        self.cache[cache_key] = ret
+        return ret
+
+    def visit_inner(self, expr, *args, **kwargs):
+        self.get_r_leaves(expr)
+        if expr.is_mem() and not self.mem_read:
+            # Don't visit memory sons
+            return None
+
+        if expr.is_assign():
+            if expr.dst.is_mem() and self.mem_read:
+                ret = super(ExprGetR, self).visit(expr.dst, *args, **kwargs)  # walks dst.ptr; ret is overwritten below
+            if expr.src.is_mem():  # NOTE(review): looks redundant, get_r_leaves below adds ExprMem too
+                self.elements.add(expr.src)
+            self.get_r_leaves(expr.src)  # the base-class walk below only reaches src's children
+            if expr.src.is_mem() and not self.mem_read:
+                return None
+            ret = super(ExprGetR, self).visit(expr.src, *args, **kwargs)
+            return ret
+        ret = super(ExprGetR, self).visit(expr, *args, **kwargs)
+        return ret
+
+
+class ExprVisitorBase(object):
+    """
+    Rebuild an expression from its (recursively visited) sub-expressions
+    """
+    def visit(self, expr, *args, **kwargs):
+        if expr.is_int() or expr.is_id() or expr.is_loc():
+            ret = expr  # leaves are returned as is
+        elif expr.is_assign():
+            dst = self.visit(expr.dst, *args, **kwargs)
+            src = self.visit(expr.src, *args, **kwargs)
+            ret = ExprAssign(dst, src)
+        elif expr.is_cond():
+            cond = self.visit(expr.cond, *args, **kwargs)
+            src1 = self.visit(expr.src1, *args, **kwargs)
+            src2 = self.visit(expr.src2, *args, **kwargs)
+            ret = ExprCond(cond, src1, src2)
+        elif expr.is_mem():
+            ptr = self.visit(expr.ptr, *args, **kwargs)
+            ret = ExprMem(ptr, expr.size)
+        elif expr.is_slice():
+            arg = self.visit(expr.arg, *args, **kwargs)
+            ret = ExprSlice(arg, expr.start, expr.stop)
+        elif expr.is_op():
+            new_args = [self.visit(arg, *args, **kwargs) for arg in expr.args]  # keep *args intact
+            ret = ExprOp(expr.op, *new_args)
+        elif expr.is_compose():
+            new_args = [self.visit(arg, *args, **kwargs) for arg in expr.args]  # keep *args intact
+            ret = ExprCompose(*new_args)
+        else:
+            raise TypeError("Visitor can only take Expr")
+        return ret
+
+
+class ExprVisitorCallbackTopToBottom(ExprVisitorBase):
+    """
+    Rebuild expression by visiting sub-expressions
+    Call @callback on each expression before visiting its sub-expressions
+    If @callback returns a true value, replace current node with this value
+    Else, continue visit of sub-expressions
+    """
+    def __init__(self, callback):
+        super(ExprVisitorCallbackTopToBottom, self).__init__()
+        self.cache = dict()  # expr -> result of its visit, kept on the instance
+        self.callback = callback
+
+    def visit(self, expr, *args, **kwargs):
+        if expr in self.cache:
+            return self.cache[expr]
+        ret = self.visit_inner(expr, *args, **kwargs)
+        self.cache[expr] = ret
+        return ret
+
+    def visit_inner(self, expr, *args, **kwargs):
+        ret = self.callback(expr)  # extra *args/**kwargs are not passed to @callback
+        if ret:  # NOTE(review): truthiness test, not "is not None" - confirm Expr is never falsy
+            return ret  # replaced node: its sub-expressions are not visited
+        ret = super(ExprVisitorCallbackTopToBottom, self).visit(expr, *args, **kwargs)
+        return ret
+
+
+class ExprVisitorCallbackBottomToTop(ExprVisitorBase):
+    """
+    Rebuild expression by visiting sub-expressions
+    Call @callback on each rebuilt expression, from leaves to root
+    """
+    def __init__(self, callback):
+        super(ExprVisitorCallbackBottomToTop, self).__init__()
+        self.cache = dict()  # expr -> result of its visit, kept on the instance
+        self.callback = callback
+
+    def visit(self, expr, *args, **kwargs):
+        if expr in self.cache:
+            return self.cache[expr]
+        ret = self.visit_inner(expr, *args, **kwargs)
+        self.cache[expr] = ret
+        return ret
+
+    def visit_inner(self, expr, *args, **kwargs):
+        ret = super(ExprVisitorCallbackBottomToTop, self).visit(expr, *args, **kwargs)
+        ret = self.callback(ret)  # callback result is used as is; extra arguments are not passed
+        return ret
+
+
+class ExprVisitorCanonize(ExprVisitorCallbackBottomToTop):  # bottom-up visitor using canonize as callback
+    def __init__(self):
+        super(ExprVisitorCanonize, self).__init__(self.canonize)
+
+    def canonize(self, expr):
+        if not expr.is_op():
+            return expr  # only ExprOp nodes are rewritten
+        if not expr.is_associative():
+            return expr
+
+        # ((a+b) + c) => (a + b + c)
+        args = []
+        for arg in expr.args:
+            if isinstance(arg, ExprOp) and expr.op == arg.op:
+                args += arg.args  # splice arguments of a nested op with the same operator
+            else:
+                args.append(arg)
+        args = canonize_expr_list(args)  # helper not shown in this diff
+        new_expr = ExprOp(expr.op, *args)
+        return new_expr
+
+
+class ExprVisitorContains(ExprWalkBase):
+    """
+    Visitor to test if a needle is in an Expression
+    Cache (expr, needle) pairs for which the walk found no match
+    """
+    def __init__(self):
+        self.cache = set()  # (expr, needle) pairs already walked without a match
+        super(ExprVisitorContains, self).__init__(self.eq_expr)
+
+    def eq_expr(self, expr, needle, *args, **kwargs):
+        if expr == needle:
+            return True  # true value: stops the walk
+        return None
+
+    def visit(self, expr, needle,  *args, **kwargs):
+        if (expr, needle) in self.cache:
+            return None
+        ret = super(ExprVisitorContains, self).visit(expr, needle, *args, **kwargs)
+        if ret:
+            return ret  # matches are not cached
+        self.cache.add((expr, needle))
+        return None
+
+
+    def contains(self, expr, needle):
+        return self.visit(expr, needle)  # True on match, None otherwise
+
+contains_visitor = ExprVisitorContains()
+canonize_visitor = ExprVisitorCanonize()
+
 # IR definitions
 
 class Expr(object):
@@ -337,36 +625,16 @@ class Expr(object):
         """Find and replace sub expression using dct
         @dct: dictionary associating replaced Expr to its new Expr value
         """
-        return self.visit(lambda expr: dct.get(expr, expr))
+        def replace(expr):
+            if expr in dct:
+                return dct[expr]
+            return None
+        visitor = ExprVisitorCallbackTopToBottom(lambda expr:replace(expr))
+        return visitor.visit(self)
 
     def canonize(self):
         "Canonize the Expression"
-
-        def must_canon(expr):
-            return not expr.is_canon
-
-        def canonize_visitor(expr):
-            if expr.is_canon:
-                return expr
-            if isinstance(expr, ExprOp):
-                if expr.is_associative():
-                    # ((a+b) + c) => (a + b + c)
-                    args = []
-                    for arg in expr.args:
-                        if isinstance(arg, ExprOp) and expr.op == arg.op:
-                            args += arg.args
-                        else:
-                            args.append(arg)
-                    args = canonize_expr_list(args)
-                    new_e = ExprOp(expr.op, *args)
-                else:
-                    new_e = expr
-            else:
-                new_e = expr
-            new_e.is_canon = True
-            return new_e
-
-        return self.visit(canonize_visitor, must_canon)
+        return canonize_visitor.visit(self)
 
     def msb(self):
         "Return the Most Significant Bit"
@@ -424,6 +692,10 @@ class Expr(object):
         return False
 
     def is_aff(self):
+        warnings.warn('DEPRECATION WARNING: use is_assign()')
+        return False
+
+    def is_assign(self):
         return False
 
     def is_cond(self):
@@ -449,6 +721,32 @@ class Expr(object):
         """Returns True if is ExprMem and ptr is_op_segm"""
         return False
 
+    def __contains__(self, expr):
+        ret = contains_visitor.contains(self, expr)
+        return ret
+
+    def visit(self, callback):
+        """
+        Apply callback to all sub expression of @self
+        This function keeps a cache to avoid rerunning @callback on common sub
+        expressions.
+
+        @callback: fn(Expr) -> Expr
+        """
+        visitor = ExprVisitorCallbackBottomToTop(callback)
+        return visitor.visit(self)
+
+    def get_r(self, mem_read=False, cst_read=False):
+        visitor = ExprGetR(mem_read, cst_read)
+        visitor.visit(self)
+        return visitor.elements
+
+
+    def get_w(self, mem_read=False, cst_read=False):
+        if self.is_assign():
+            return set([self.dst])
+        return set()
+
 class ExprInt(Expr):
 
     """An ExprInt represent a constant in Miasm IR.
@@ -508,12 +806,6 @@ class ExprInt(Expr):
         else:
             return str("0x%X" % self._get_int())
 
-    def get_r(self, mem_read=False, cst_read=False):
-        if cst_read:
-            return set([self])
-        else:
-            return set()
-
     def get_w(self):
         return set()
 
@@ -524,13 +816,6 @@ class ExprInt(Expr):
         return "%s(0x%X, %d)" % (self.__class__.__name__, self._get_int(),
                                  self._size)
 
-    def __contains__(self, expr):
-        return self == expr
-
-    @visit_chk
-    def visit(self, callback, test_visit=None):
-        return self
-
     def copy(self):
         return ExprInt(self._arg, self._size)
 
@@ -591,9 +876,6 @@ class ExprId(Expr):
     def __str__(self):
         return str(self._name)
 
-    def get_r(self, mem_read=False, cst_read=False):
-        return set([self])
-
     def get_w(self):
         return set([self])
 
@@ -603,13 +885,6 @@ class ExprId(Expr):
     def _exprrepr(self):
         return "%s(%r, %d)" % (self.__class__.__name__, self._name, self._size)
 
-    def __contains__(self, expr):
-        return self == expr
-
-    @visit_chk
-    def visit(self, callback, test_visit=None):
-        return self
-
     def copy(self):
         return ExprId(self._name, self._size)
 
@@ -653,12 +928,6 @@ class ExprLoc(Expr):
     def __str__(self):
         return str(self._loc_key)
 
-    def get_r(self, mem_read=False, cst_read=False):
-        if cst_read:
-            return set([self])
-        else:
-            return set()
-
     def get_w(self):
         return set()
 
@@ -668,13 +937,6 @@ class ExprLoc(Expr):
     def _exprrepr(self):
         return "%s(%r, %d)" % (self.__class__.__name__, self._loc_key, self._size)
 
-    def __contains__(self, expr):
-        return self == expr
-
-    @visit_chk
-    def visit(self, callback, test_visit=None):
-        return self
-
     def copy(self):
         return ExprLoc(self._loc_key, self._size)
 
@@ -745,12 +1007,6 @@ class ExprAssign(Expr):
     def __str__(self):
         return "%s = %s" % (str(self._dst), str(self._src))
 
-    def get_r(self, mem_read=False, cst_read=False):
-        elements = self._src.get_r(mem_read, cst_read)
-        if isinstance(self._dst, ExprMem) and mem_read:
-            elements.update(self._dst.ptr.get_r(mem_read, cst_read))
-        return elements
-
     def get_w(self):
         if isinstance(self._dst, ExprMem):
             return set([self._dst])  # [memreg]
@@ -763,19 +1019,6 @@ class ExprAssign(Expr):
     def _exprrepr(self):
         return "%s(%r, %r)" % (self.__class__.__name__, self._dst, self._src)
 
-    def __contains__(self, expr):
-        return (self == expr or
-                self._src.__contains__(expr) or
-                self._dst.__contains__(expr))
-
-    @visit_chk
-    def visit(self, callback, test_visit=None):
-        dst, src = self._dst.visit(callback, test_visit), self._src.visit(callback, test_visit)
-        if dst == self._dst and src == self._src:
-            return self
-        else:
-            return ExprAssign(dst, src)
-
     def copy(self):
         return ExprAssign(self._dst.copy(), self._src.copy())
 
@@ -788,7 +1031,12 @@ class ExprAssign(Expr):
             arg.graph_recursive(graph)
             graph.add_uniq_edge(self, arg)
 
+
     def is_aff(self):
+        warnings.warn('DEPRECATION WARNING: use is_assign()')
+        return True
+
+    def is_assign(self):
         return True
 
 
@@ -845,12 +1093,6 @@ class ExprCond(Expr):
     def __str__(self):
         return "%s?(%s,%s)" % (str_protected_child(self._cond, self), str(self._src1), str(self._src2))
 
-    def get_r(self, mem_read=False, cst_read=False):
-        out_src1 = self.src1.get_r(mem_read, cst_read)
-        out_src2 = self.src2.get_r(mem_read, cst_read)
-        return self.cond.get_r(mem_read,
-                               cst_read).union(out_src1).union(out_src2)
-
     def get_w(self):
         return set()
 
@@ -862,21 +1104,6 @@ class ExprCond(Expr):
         return "%s(%r, %r, %r)" % (self.__class__.__name__,
                                    self._cond, self._src1, self._src2)
 
-    def __contains__(self, expr):
-        return (self == expr or
-                self.cond.__contains__(expr) or
-                self.src1.__contains__(expr) or
-                self.src2.__contains__(expr))
-
-    @visit_chk
-    def visit(self, callback, test_visit=None):
-        cond = self._cond.visit(callback, test_visit)
-        src1 = self._src1.visit(callback, test_visit)
-        src2 = self._src2.visit(callback, test_visit)
-        if cond == self._cond and src1 == self._src1 and src2 == self._src2:
-            return self
-        return ExprCond(cond, src1, src2)
-
     def copy(self):
         return ExprCond(self._cond.copy(),
                         self._src1.copy(),
@@ -953,12 +1180,6 @@ class ExprMem(Expr):
     def __str__(self):
         return "@%d[%s]" % (self.size, str(self.ptr))
 
-    def get_r(self, mem_read=False, cst_read=False):
-        if mem_read:
-            return set(self._ptr.get_r(mem_read, cst_read).union(set([self])))
-        else:
-            return set([self])
-
     def get_w(self):
         return set([self])  # [memreg]
 
@@ -969,16 +1190,6 @@ class ExprMem(Expr):
         return "%s(%r, %r)" % (self.__class__.__name__,
                                self._ptr, self._size)
 
-    def __contains__(self, expr):
-        return self == expr or self._ptr.__contains__(expr)
-
-    @visit_chk
-    def visit(self, callback, test_visit=None):
-        ptr = self._ptr.visit(callback, test_visit)
-        if ptr == self._ptr:
-            return self
-        return ExprMem(ptr, self.size)
-
     def copy(self):
         ptr = self.ptr.copy()
         return ExprMem(ptr, size=self.size)
@@ -1108,10 +1319,6 @@ class ExprOp(Expr):
         return (self._op + '(' +
                 ', '.join([str(arg) for arg in self._args]) + ')')
 
-    def get_r(self, mem_read=False, cst_read=False):
-        return reduce(lambda elements, arg:
-                      elements.union(arg.get_r(mem_read, cst_read)), self._args, set())
-
     def get_w(self):
         raise ValueError('op cannot be written!', self)
 
@@ -1123,14 +1330,6 @@ class ExprOp(Expr):
         return "%s(%r, %s)" % (self.__class__.__name__, self._op,
                                ', '.join(repr(arg) for arg in self._args))
 
-    def __contains__(self, expr):
-        if self == expr:
-            return True
-        for arg in self._args:
-            if arg.__contains__(expr):
-                return True
-        return False
-
     def is_function_call(self):
         return self._op.startswith('call')
 
@@ -1153,14 +1352,6 @@ class ExprOp(Expr):
         "Return True iff current operation is commutative"
         return (self._op in ['+', '*', '^', '&', '|'])
 
-    @visit_chk
-    def visit(self, callback, test_visit=None):
-        args = [arg.visit(callback, test_visit) for arg in self._args]
-        modified = any([arg[0] != arg[1] for arg in zip(self._args, args)])
-        if modified:
-            return ExprOp(self._op, *args)
-        return self
-
     def copy(self):
         args = [arg.copy() for arg in self._args]
         return ExprOp(self._op, *args)
@@ -1213,9 +1404,6 @@ class ExprSlice(Expr):
     def __str__(self):
         return "%s[%d:%d]" % (str_protected_child(self._arg, self), self._start, self._stop)
 
-    def get_r(self, mem_read=False, cst_read=False):
-        return self._arg.get_r(mem_read, cst_read)
-
     def get_w(self):
         return self._arg.get_w()
 
@@ -1226,18 +1414,6 @@ class ExprSlice(Expr):
         return "%s(%r, %d, %d)" % (self.__class__.__name__, self._arg,
                                    self._start, self._stop)
 
-    def __contains__(self, expr):
-        if self == expr:
-            return True
-        return self._arg.__contains__(expr)
-
-    @visit_chk
-    def visit(self, callback, test_visit=None):
-        arg = self._arg.visit(callback, test_visit)
-        if arg == self._arg:
-            return self
-        return ExprSlice(arg, self._start, self._stop)
-
     def copy(self):
         return ExprSlice(self._arg.copy(), self._start, self._stop)
 
@@ -1310,10 +1486,6 @@ class ExprCompose(Expr):
     def __str__(self):
         return '{' + ', '.join(["%s %s %s" % (arg, idx, idx + arg.size) for idx, arg in self.iter_args()]) + '}'
 
-    def get_r(self, mem_read=False, cst_read=False):
-        return reduce(lambda elements, arg:
-                      elements.union(arg.get_r(mem_read, cst_read)), self._args, set())
-
     def get_w(self):
         return reduce(lambda elements, arg:
                       elements.union(arg.get_w()), self._args, set())
@@ -1325,24 +1497,6 @@ class ExprCompose(Expr):
     def _exprrepr(self):
         return "%s%r" % (self.__class__.__name__, self._args)
 
-    def __contains__(self, expr):
-        if self == expr:
-            return True
-        for arg in self._args:
-            if arg == expr:
-                return True
-            if arg.__contains__(expr):
-                return True
-        return False
-
-    @visit_chk
-    def visit(self, callback, test_visit=None):
-        args = [arg.visit(callback, test_visit) for arg in self._args]
-        modified = any([arg != arg_new for arg, arg_new in zip(self._args, args)])
-        if modified:
-            return ExprCompose(*args)
-        return self
-
     def copy(self):
         args = [arg.copy() for arg in self._args]
         return ExprCompose(*args)
@@ -1669,8 +1823,8 @@ def match_expr(expr, pattern, tks, result=None):
                 return False
         return result
 
-    elif expr.is_aff():
-        if not pattern.is_aff():
+    elif expr.is_assign():
+        if not pattern.is_assign():
             return False
         if match_expr(expr.src, pattern.src, tks, result) is False:
             return False