about summary refs log tree commit diff stats
path: root/miasm/expression
diff options
context:
space:
mode:
Diffstat (limited to 'miasm/expression')
-rw-r--r--miasm/expression/expression.py534
-rw-r--r--miasm/expression/simplifications.py54
-rw-r--r--miasm/expression/simplifications_common.py138
3 files changed, 452 insertions, 274 deletions
diff --git a/miasm/expression/expression.py b/miasm/expression/expression.py
index 93094979..c2bf5b8b 100644
--- a/miasm/expression/expression.py
+++ b/miasm/expression/expression.py
@@ -98,18 +98,6 @@ def should_parenthesize_child(child, parent):
 def str_protected_child(child, parent):
     return ("(%s)" % child) if should_parenthesize_child(child, parent) else str(child)
 
-def visit_chk(visitor):
-    "Function decorator launching callback on Expression visit"
-    def wrapped(expr, callback, test_visit=lambda x: True):
-        if (test_visit is not None) and (not test_visit(expr)):
-            return expr
-        expr_new = visitor(expr, callback, test_visit)
-        if expr_new is None:
-            return None
-        expr_new2 = callback(expr_new)
-        return expr_new2
-    return wrapped
-
 
 # Expression display
 
@@ -152,6 +140,49 @@ class DiGraphExpr(DiGraph):
 
         return ""
 
+def is_expr(expr):
+    return isinstance(
+        expr,
+        (
+            ExprInt, ExprId, ExprMem,
+            ExprSlice, ExprCompose, ExprCond,
+            ExprLoc, ExprOp
+        )
+    )
+
+def is_associative(expr):
+    "Return True iff current operation is associative"
+    return (expr.op in ['+', '*', '^', '&', '|'])
+
+def is_commutative(expr):
+    "Return True iff current operation is commutative"
+    return (expr.op in ['+', '*', '^', '&', '|'])
+
+def is_op_segm(expr):
+    """Returns True if is ExprOp and op == 'segm'"""
+    return expr.is_op('segm')
+
+def is_mem_segm(expr):
+    """Returns True if is ExprMem and ptr is_op_segm"""
+    return expr.is_mem() and is_op_segm(expr.ptr)
+
+def canonize_to_exprloc(locdb, expr):
+    """
+    If expr is ExprInt, return ExprLoc with corresponding loc_key
+    Else, return expr
+
+    @expr: Expr instance
+    """
+    if expr.is_int():
+        loc_key = locdb.get_or_create_offset_location(int(expr))
+        ret = ExprLoc(loc_key, expr.size)
+        return ret
+    return expr
+
+def is_function_call(expr):
+    """Returns true if the considered Expr is a function call
+    """
+    return expr.is_op() and expr.op.startswith('call')
 
 @total_ordering
 class LocKey(object):
@@ -183,6 +214,263 @@ class LocKey(object):
     def __str__(self):
         return "loc_key_%d" % self.key
 
+
+class ExprWalkBase(object):
+    """
+    Walk through sub-expressions, call @callback on them.
+    If @callback returns a non None value, stop walk and return this value
+    """
+
+    def __init__(self, callback):
+        self.callback = callback
+
+    def visit(self, expr, *args, **kwargs):
+        if expr.is_int() or expr.is_id() or expr.is_loc():
+            pass
+        elif expr.is_assign():
+            ret = self.visit(expr.dst, *args, **kwargs)
+            if ret:
+                return ret
+            src = self.visit(expr.src, *args, **kwargs)
+            if ret:
+                return ret
+        elif expr.is_cond():
+            ret = self.visit(expr.cond, *args, **kwargs)
+            if ret:
+                return ret
+            ret = self.visit(expr.src1, *args, **kwargs)
+            if ret:
+                return ret
+            ret = self.visit(expr.src2, *args, **kwargs)
+            if ret:
+                return ret
+        elif expr.is_mem():
+            ret = self.visit(expr.ptr, *args, **kwargs)
+            if ret:
+                return ret
+        elif expr.is_slice():
+            ret = self.visit(expr.arg, *args, **kwargs)
+            if ret:
+                return ret
+        elif expr.is_op():
+            for arg in expr.args:
+                ret = self.visit(arg, *args, **kwargs)
+                if ret:
+                    return ret
+        elif expr.is_compose():
+            for arg in expr.args:
+                ret = self.visit(arg, *args, **kwargs)
+                if ret:
+                    return ret
+        else:
+            raise TypeError("Visitor can only take Expr")
+
+        ret = self.callback(expr, *args, **kwargs)
+        return ret
+
+
+class ExprWalk(ExprWalkBase):
+    """
+    Walk through sub-expressions, call @callback on them.
+    If @callback returns a non None value, stop walk and return this value
+    Use cache mechanism.
+    """
+    def __init__(self, callback):
+        self.cache = set()
+        self.callback = callback
+
+    def visit(self, expr, *args, **kwargs):
+        if expr in self.cache:
+            return None
+        ret = super(ExprWalk, self).visit(expr, *args, **kwargs)
+        if ret:
+            return ret
+        self.cache.add(expr)
+        return None
+
+
+class ExprGetR(ExprWalkBase):
+    """
+    Return ExprId/ExprMem used by a given expression
+    """
+    def __init__(self, mem_read=False, cst_read=False):
+        super(ExprGetR, self).__init__(lambda x:None)
+        self.mem_read = mem_read
+        self.cst_read = cst_read
+        self.elements = set()
+        self.cache = dict()
+
+    def get_r_leaves(self, expr):
+        if (expr.is_int() or expr.is_loc()) and self.cst_read:
+            self.elements.add(expr)
+        elif expr.is_mem():
+            self.elements.add(expr)
+        elif expr.is_id():
+            self.elements.add(expr)
+
+    def visit(self, expr, *args, **kwargs):
+        cache_key = (expr, self.mem_read, self.cst_read)
+        if cache_key in self.cache:
+            return self.cache[cache_key]
+        ret = self.visit_inner(expr, *args, **kwargs)
+        self.cache[cache_key] = ret
+        return ret
+
+    def visit_inner(self, expr, *args, **kwargs):
+        self.get_r_leaves(expr)
+        if expr.is_mem() and not self.mem_read:
+            # Don't visit memory sons
+            return None
+
+        if expr.is_assign():
+            if expr.dst.is_mem() and self.mem_read:
+                ret = super(ExprGetR, self).visit(expr.dst, *args, **kwargs)
+            if expr.src.is_mem():
+                self.elements.add(expr.src)
+            self.get_r_leaves(expr.src)
+            if expr.src.is_mem() and not self.mem_read:
+                return None
+            ret = super(ExprGetR, self).visit(expr.src, *args, **kwargs)
+            return ret
+        ret = super(ExprGetR, self).visit(expr, *args, **kwargs)
+        return ret
+
+
+class ExprVisitorBase(object):
+    """
+    Rebuild expression by visiting sub-expressions
+    """
+    def visit(self, expr, *args, **kwargs):
+        if expr.is_int() or expr.is_id() or expr.is_loc():
+            ret = expr
+        elif expr.is_assign():
+            dst = self.visit(expr.dst, *args, **kwargs)
+            src = self.visit(expr.src, *args, **kwargs)
+            ret = ExprAssign(dst, src)
+        elif expr.is_cond():
+            cond = self.visit(expr.cond, *args, **kwargs)
+            src1 = self.visit(expr.src1, *args, **kwargs)
+            src2 = self.visit(expr.src2, *args, **kwargs)
+            ret = ExprCond(cond, src1, src2)
+        elif expr.is_mem():
+            ptr = self.visit(expr.ptr, *args, **kwargs)
+            ret = ExprMem(ptr, expr.size)
+        elif expr.is_slice():
+            arg = self.visit(expr.arg, *args, **kwargs)
+            ret = ExprSlice(arg, expr.start, expr.stop)
+        elif expr.is_op():
+            args = [self.visit(arg, *args, **kwargs) for arg in expr.args]
+            ret = ExprOp(expr.op, *args)
+        elif expr.is_compose():
+            args = [self.visit(arg, *args, **kwargs) for arg in expr.args]
+            ret = ExprCompose(*args)
+        else:
+            raise TypeError("Visitor can only take Expr")
+        return ret
+
+
+class ExprVisitorCallbackTopToBottom(ExprVisitorBase):
+    """
+    Rebuild expression by visiting sub-expressions
+    Call @callback on each sub-expression
+    if @callback return non None value, replace current node with this value
+    Else, continue visit of sub-expressions
+    """
+    def __init__(self, callback):
+        super(ExprVisitorCallbackTopToBottom, self).__init__()
+        self.cache = dict()
+        self.callback = callback
+
+    def visit(self, expr, *args, **kwargs):
+        if expr in self.cache:
+            return self.cache[expr]
+        ret = self.visit_inner(expr, *args, **kwargs)
+        self.cache[expr] = ret
+        return ret
+
+    def visit_inner(self, expr, *args, **kwargs):
+        ret = self.callback(expr)
+        if ret:
+            return ret
+        ret = super(ExprVisitorCallbackTopToBottom, self).visit(expr, *args, **kwargs)
+        return ret
+
+
+class ExprVisitorCallbackBottomToTop(ExprVisitorBase):
+    """
+    Rebuild expression by visiting sub-expressions
+    Call @callback from leaves to root expressions
+    """
+    def __init__(self, callback):
+        super(ExprVisitorCallbackBottomToTop, self).__init__()
+        self.cache = dict()
+        self.callback = callback
+
+    def visit(self, expr, *args, **kwargs):
+        if expr in self.cache:
+            return self.cache[expr]
+        ret = self.visit_inner(expr, *args, **kwargs)
+        self.cache[expr] = ret
+        return ret
+
+    def visit_inner(self, expr, *args, **kwargs):
+        ret = super(ExprVisitorCallbackBottomToTop, self).visit(expr, *args, **kwargs)
+        ret = self.callback(ret)
+        return ret
+
+
+class ExprVisitorCanonize(ExprVisitorCallbackBottomToTop):
+    def __init__(self):
+        super(ExprVisitorCanonize, self).__init__(self.canonize)
+
+    def canonize(self, expr):
+        if not expr.is_op():
+            return expr
+        if not expr.is_associative():
+            return expr
+
+        # ((a+b) + c) => (a + b + c)
+        args = []
+        for arg in expr.args:
+            if isinstance(arg, ExprOp) and expr.op == arg.op:
+                args += arg.args
+            else:
+                args.append(arg)
+        args = canonize_expr_list(args)
+        new_expr = ExprOp(expr.op, *args)
+        return new_expr
+
+
+class ExprVisitorContains(ExprWalkBase):
+    """
+    Visitor to test if a needle is in an Expression
+    Cache results
+    """
+    def __init__(self):
+        self.cache = set()
+        super(ExprVisitorContains, self).__init__(self.eq_expr)
+
+    def eq_expr(self, expr, needle, *args, **kwargs):
+        if expr == needle:
+            return True
+        return None
+
+    def visit(self, expr, needle,  *args, **kwargs):
+        if (expr, needle) in self.cache:
+            return None
+        ret = super(ExprVisitorContains, self).visit(expr, needle, *args, **kwargs)
+        if ret:
+            return ret
+        self.cache.add((expr, needle))
+        return None
+
+
+    def contains(self, expr, needle):
+        return self.visit(expr, needle)
+
+contains_visitor = ExprVisitorContains()
+canonize_visitor = ExprVisitorCanonize()
+
 # IR definitions
 
 class Expr(object):
@@ -337,36 +625,16 @@ class Expr(object):
         """Find and replace sub expression using dct
         @dct: dictionary associating replaced Expr to its new Expr value
         """
-        return self.visit(lambda expr: dct.get(expr, expr))
+        def replace(expr):
+            if expr in dct:
+                return dct[expr]
+            return None
+        visitor = ExprVisitorCallbackTopToBottom(lambda expr:replace(expr))
+        return visitor.visit(self)
 
     def canonize(self):
         "Canonize the Expression"
-
-        def must_canon(expr):
-            return not expr.is_canon
-
-        def canonize_visitor(expr):
-            if expr.is_canon:
-                return expr
-            if isinstance(expr, ExprOp):
-                if expr.is_associative():
-                    # ((a+b) + c) => (a + b + c)
-                    args = []
-                    for arg in expr.args:
-                        if isinstance(arg, ExprOp) and expr.op == arg.op:
-                            args += arg.args
-                        else:
-                            args.append(arg)
-                    args = canonize_expr_list(args)
-                    new_e = ExprOp(expr.op, *args)
-                else:
-                    new_e = expr
-            else:
-                new_e = expr
-            new_e.is_canon = True
-            return new_e
-
-        return self.visit(canonize_visitor, must_canon)
+        return canonize_visitor.visit(self)
 
     def msb(self):
         "Return the Most Significant Bit"
@@ -424,6 +692,10 @@ class Expr(object):
         return False
 
     def is_aff(self):
+        warnings.warn('DEPRECATION WARNING: use is_assign()')
+        return False
+
+    def is_assign(self):
         return False
 
     def is_cond(self):
@@ -449,6 +721,32 @@ class Expr(object):
         """Returns True if is ExprMem and ptr is_op_segm"""
         return False
 
+    def __contains__(self, expr):
+        ret = contains_visitor.contains(self, expr)
+        return ret
+
+    def visit(self, callback):
+        """
+        Apply callback to all sub expression of @self
+        This function keeps a cache to avoid rerunning @callback on common sub
+        expressions.
+
+        @callback: fn(Expr) -> Expr
+        """
+        visitor = ExprVisitorCallbackBottomToTop(callback)
+        return visitor.visit(self)
+
+    def get_r(self, mem_read=False, cst_read=False):
+        visitor = ExprGetR(mem_read, cst_read)
+        visitor.visit(self)
+        return visitor.elements
+
+
+    def get_w(self, mem_read=False, cst_read=False):
+        if self.is_assign():
+            return set([self.dst])
+        return set()
+
 class ExprInt(Expr):
 
     """An ExprInt represent a constant in Miasm IR.
@@ -508,12 +806,6 @@ class ExprInt(Expr):
         else:
             return str("0x%X" % self._get_int())
 
-    def get_r(self, mem_read=False, cst_read=False):
-        if cst_read:
-            return set([self])
-        else:
-            return set()
-
     def get_w(self):
         return set()
 
@@ -524,13 +816,6 @@ class ExprInt(Expr):
         return "%s(0x%X, %d)" % (self.__class__.__name__, self._get_int(),
                                  self._size)
 
-    def __contains__(self, expr):
-        return self == expr
-
-    @visit_chk
-    def visit(self, callback, test_visit=None):
-        return self
-
     def copy(self):
         return ExprInt(self._arg, self._size)
 
@@ -591,9 +876,6 @@ class ExprId(Expr):
     def __str__(self):
         return str(self._name)
 
-    def get_r(self, mem_read=False, cst_read=False):
-        return set([self])
-
     def get_w(self):
         return set([self])
 
@@ -603,13 +885,6 @@ class ExprId(Expr):
     def _exprrepr(self):
         return "%s(%r, %d)" % (self.__class__.__name__, self._name, self._size)
 
-    def __contains__(self, expr):
-        return self == expr
-
-    @visit_chk
-    def visit(self, callback, test_visit=None):
-        return self
-
     def copy(self):
         return ExprId(self._name, self._size)
 
@@ -653,12 +928,6 @@ class ExprLoc(Expr):
     def __str__(self):
         return str(self._loc_key)
 
-    def get_r(self, mem_read=False, cst_read=False):
-        if cst_read:
-            return set([self])
-        else:
-            return set()
-
     def get_w(self):
         return set()
 
@@ -668,13 +937,6 @@ class ExprLoc(Expr):
     def _exprrepr(self):
         return "%s(%r, %d)" % (self.__class__.__name__, self._loc_key, self._size)
 
-    def __contains__(self, expr):
-        return self == expr
-
-    @visit_chk
-    def visit(self, callback, test_visit=None):
-        return self
-
     def copy(self):
         return ExprLoc(self._loc_key, self._size)
 
@@ -745,12 +1007,6 @@ class ExprAssign(Expr):
     def __str__(self):
         return "%s = %s" % (str(self._dst), str(self._src))
 
-    def get_r(self, mem_read=False, cst_read=False):
-        elements = self._src.get_r(mem_read, cst_read)
-        if isinstance(self._dst, ExprMem) and mem_read:
-            elements.update(self._dst.ptr.get_r(mem_read, cst_read))
-        return elements
-
     def get_w(self):
         if isinstance(self._dst, ExprMem):
             return set([self._dst])  # [memreg]
@@ -763,19 +1019,6 @@ class ExprAssign(Expr):
     def _exprrepr(self):
         return "%s(%r, %r)" % (self.__class__.__name__, self._dst, self._src)
 
-    def __contains__(self, expr):
-        return (self == expr or
-                self._src.__contains__(expr) or
-                self._dst.__contains__(expr))
-
-    @visit_chk
-    def visit(self, callback, test_visit=None):
-        dst, src = self._dst.visit(callback, test_visit), self._src.visit(callback, test_visit)
-        if dst == self._dst and src == self._src:
-            return self
-        else:
-            return ExprAssign(dst, src)
-
     def copy(self):
         return ExprAssign(self._dst.copy(), self._src.copy())
 
@@ -788,7 +1031,12 @@ class ExprAssign(Expr):
             arg.graph_recursive(graph)
             graph.add_uniq_edge(self, arg)
 
+
     def is_aff(self):
+        warnings.warn('DEPRECATION WARNING: use is_assign()')
+        return True
+
+    def is_assign(self):
         return True
 
 
@@ -845,12 +1093,6 @@ class ExprCond(Expr):
     def __str__(self):
         return "%s?(%s,%s)" % (str_protected_child(self._cond, self), str(self._src1), str(self._src2))
 
-    def get_r(self, mem_read=False, cst_read=False):
-        out_src1 = self.src1.get_r(mem_read, cst_read)
-        out_src2 = self.src2.get_r(mem_read, cst_read)
-        return self.cond.get_r(mem_read,
-                               cst_read).union(out_src1).union(out_src2)
-
     def get_w(self):
         return set()
 
@@ -862,21 +1104,6 @@ class ExprCond(Expr):
         return "%s(%r, %r, %r)" % (self.__class__.__name__,
                                    self._cond, self._src1, self._src2)
 
-    def __contains__(self, expr):
-        return (self == expr or
-                self.cond.__contains__(expr) or
-                self.src1.__contains__(expr) or
-                self.src2.__contains__(expr))
-
-    @visit_chk
-    def visit(self, callback, test_visit=None):
-        cond = self._cond.visit(callback, test_visit)
-        src1 = self._src1.visit(callback, test_visit)
-        src2 = self._src2.visit(callback, test_visit)
-        if cond == self._cond and src1 == self._src1 and src2 == self._src2:
-            return self
-        return ExprCond(cond, src1, src2)
-
     def copy(self):
         return ExprCond(self._cond.copy(),
                         self._src1.copy(),
@@ -953,12 +1180,6 @@ class ExprMem(Expr):
     def __str__(self):
         return "@%d[%s]" % (self.size, str(self.ptr))
 
-    def get_r(self, mem_read=False, cst_read=False):
-        if mem_read:
-            return set(self._ptr.get_r(mem_read, cst_read).union(set([self])))
-        else:
-            return set([self])
-
     def get_w(self):
         return set([self])  # [memreg]
 
@@ -969,16 +1190,6 @@ class ExprMem(Expr):
         return "%s(%r, %r)" % (self.__class__.__name__,
                                self._ptr, self._size)
 
-    def __contains__(self, expr):
-        return self == expr or self._ptr.__contains__(expr)
-
-    @visit_chk
-    def visit(self, callback, test_visit=None):
-        ptr = self._ptr.visit(callback, test_visit)
-        if ptr == self._ptr:
-            return self
-        return ExprMem(ptr, self.size)
-
     def copy(self):
         ptr = self.ptr.copy()
         return ExprMem(ptr, size=self.size)
@@ -1108,10 +1319,6 @@ class ExprOp(Expr):
         return (self._op + '(' +
                 ', '.join([str(arg) for arg in self._args]) + ')')
 
-    def get_r(self, mem_read=False, cst_read=False):
-        return reduce(lambda elements, arg:
-                      elements.union(arg.get_r(mem_read, cst_read)), self._args, set())
-
     def get_w(self):
         raise ValueError('op cannot be written!', self)
 
@@ -1123,14 +1330,6 @@ class ExprOp(Expr):
         return "%s(%r, %s)" % (self.__class__.__name__, self._op,
                                ', '.join(repr(arg) for arg in self._args))
 
-    def __contains__(self, expr):
-        if self == expr:
-            return True
-        for arg in self._args:
-            if arg.__contains__(expr):
-                return True
-        return False
-
     def is_function_call(self):
         return self._op.startswith('call')
 
@@ -1153,14 +1352,6 @@ class ExprOp(Expr):
         "Return True iff current operation is commutative"
         return (self._op in ['+', '*', '^', '&', '|'])
 
-    @visit_chk
-    def visit(self, callback, test_visit=None):
-        args = [arg.visit(callback, test_visit) for arg in self._args]
-        modified = any([arg[0] != arg[1] for arg in zip(self._args, args)])
-        if modified:
-            return ExprOp(self._op, *args)
-        return self
-
     def copy(self):
         args = [arg.copy() for arg in self._args]
         return ExprOp(self._op, *args)
@@ -1213,9 +1404,6 @@ class ExprSlice(Expr):
     def __str__(self):
         return "%s[%d:%d]" % (str_protected_child(self._arg, self), self._start, self._stop)
 
-    def get_r(self, mem_read=False, cst_read=False):
-        return self._arg.get_r(mem_read, cst_read)
-
     def get_w(self):
         return self._arg.get_w()
 
@@ -1226,18 +1414,6 @@ class ExprSlice(Expr):
         return "%s(%r, %d, %d)" % (self.__class__.__name__, self._arg,
                                    self._start, self._stop)
 
-    def __contains__(self, expr):
-        if self == expr:
-            return True
-        return self._arg.__contains__(expr)
-
-    @visit_chk
-    def visit(self, callback, test_visit=None):
-        arg = self._arg.visit(callback, test_visit)
-        if arg == self._arg:
-            return self
-        return ExprSlice(arg, self._start, self._stop)
-
     def copy(self):
         return ExprSlice(self._arg.copy(), self._start, self._stop)
 
@@ -1310,10 +1486,6 @@ class ExprCompose(Expr):
     def __str__(self):
         return '{' + ', '.join(["%s %s %s" % (arg, idx, idx + arg.size) for idx, arg in self.iter_args()]) + '}'
 
-    def get_r(self, mem_read=False, cst_read=False):
-        return reduce(lambda elements, arg:
-                      elements.union(arg.get_r(mem_read, cst_read)), self._args, set())
-
     def get_w(self):
         return reduce(lambda elements, arg:
                       elements.union(arg.get_w()), self._args, set())
@@ -1325,24 +1497,6 @@ class ExprCompose(Expr):
     def _exprrepr(self):
         return "%s%r" % (self.__class__.__name__, self._args)
 
-    def __contains__(self, expr):
-        if self == expr:
-            return True
-        for arg in self._args:
-            if arg == expr:
-                return True
-            if arg.__contains__(expr):
-                return True
-        return False
-
-    @visit_chk
-    def visit(self, callback, test_visit=None):
-        args = [arg.visit(callback, test_visit) for arg in self._args]
-        modified = any([arg != arg_new for arg, arg_new in zip(self._args, args)])
-        if modified:
-            return ExprCompose(*args)
-        return self
-
     def copy(self):
         args = [arg.copy() for arg in self._args]
         return ExprCompose(*args)
@@ -1669,8 +1823,8 @@ def match_expr(expr, pattern, tks, result=None):
                 return False
         return result
 
-    elif expr.is_aff():
-        if not pattern.is_aff():
+    elif expr.is_assign():
+        if not pattern.is_assign():
             return False
         if match_expr(expr.src, pattern.src, tks, result) is False:
             return False
diff --git a/miasm/expression/simplifications.py b/miasm/expression/simplifications.py
index 03a779a6..3f54b158 100644
--- a/miasm/expression/simplifications.py
+++ b/miasm/expression/simplifications.py
@@ -11,6 +11,7 @@ from miasm.expression import simplifications_cond
 from miasm.expression import simplifications_explicit
 from miasm.expression.expression_helper import fast_unify
 import miasm.expression.expression as m2_expr
+from miasm.expression.expression import ExprVisitorCallbackBottomToTop
 
 # Expression Simplifier
 # ---------------------
@@ -22,7 +23,7 @@ log_exprsimp.addHandler(console_handler)
 log_exprsimp.setLevel(logging.WARNING)
 
 
-class ExpressionSimplifier(object):
+class ExpressionSimplifier(ExprVisitorCallbackBottomToTop):
 
     """Wrapper on expression simplification passes.
 
@@ -49,6 +50,8 @@ class ExpressionSimplifier(object):
             simplifications_common.simp_double_signext,
             simplifications_common.simp_zeroext_eq_cst,
             simplifications_common.simp_ext_eq_ext,
+            simplifications_common.simp_ext_cond_int,
+            simplifications_common.simp_sub_cf_zero,
 
             simplifications_common.simp_cmp_int,
             simplifications_common.simp_cmp_bijective_op,
@@ -118,8 +121,8 @@ class ExpressionSimplifier(object):
 
 
     def __init__(self):
+        super(ExpressionSimplifier, self).__init__(self.expr_simp_inner)
         self.expr_simp_cb = {}
-        self.simplified_exprs = set()
 
     def enable_passes(self, passes):
         """Add passes from @passes
@@ -129,7 +132,7 @@ class ExpressionSimplifier(object):
         """
 
         # Clear cache of simplifiied expressions when adding a new pass
-        self.simplified_exprs.clear()
+        self.cache.clear()
 
         for k, v in viewitems(passes):
             self.expr_simp_cb[k] = fast_unify(self.expr_simp_cb.get(k, []) + v)
@@ -156,46 +159,29 @@ class ExpressionSimplifier(object):
 
         return expression
 
-    def expr_simp(self, expression):
+    def expr_simp_inner(self, expression):
         """Apply enabled simplifications on expression and find a stable state
         @expression: Expr instance
         Return an Expr instance"""
 
-        if expression in self.simplified_exprs:
-            return expression
-
         # Find a stable state
         while True:
             # Canonize and simplify
-            e_new = self.apply_simp(expression.canonize())
-            if e_new == expression:
-                break
-
-            # Launch recursivity
-            expression = self.expr_simp_wrapper(e_new)
-            self.simplified_exprs.add(expression)
-        # Mark expression as simplified
-        self.simplified_exprs.add(e_new)
-
-        return e_new
-
-    def expr_simp_wrapper(self, expression, callback=None):
-        """Apply enabled simplifications on expression
-        @expression: Expr instance
-        @manual_callback: If set, call this function instead of normal one
-        Return an Expr instance"""
+            new_expr = self.apply_simp(expression.canonize())
+            if new_expr == expression:
+                return new_expr
+            # Run recursively simplification on fresh new expression
+            new_expr = self.visit(new_expr)
+            expression = new_expr
+        return new_expr
 
-        if expression in self.simplified_exprs:
-            return expression
-
-        if callback is None:
-            callback = self.expr_simp
-
-        return expression.visit(callback, lambda e: e not in self.simplified_exprs)
+    def expr_simp(self, expression):
+        "Call simplification recursively"
+        return self.visit(expression)
 
-    def __call__(self, expression, callback=None):
-        "Wrapper on expr_simp_wrapper"
-        return self.expr_simp_wrapper(expression, callback)
+    def __call__(self, expression):
+        "Call simplification recursively"
+        return self.visit(expression)
 
 
 # Public ExprSimplificationPass instance with commons passes
diff --git a/miasm/expression/simplifications_common.py b/miasm/expression/simplifications_common.py
index 1c0bb04c..932db49a 100644
--- a/miasm/expression/simplifications_common.py
+++ b/miasm/expression/simplifications_common.py
@@ -32,30 +32,30 @@ def simp_cst_propagation(e_s, expr):
             int2 = args.pop()
             int1 = args.pop()
             if op_name == '+':
-                out = int1.arg + int2.arg
+                out = mod_size2uint[int1.size](int(int1) + int(int2))
             elif op_name == '*':
-                out = int1.arg * int2.arg
+                out = mod_size2uint[int1.size](int(int1) * int(int2))
             elif op_name == '**':
-                out =int1.arg ** int2.arg
+                out = mod_size2uint[int1.size](int(int1) ** int(int2))
             elif op_name == '^':
-                out = int1.arg ^ int2.arg
+                out = mod_size2uint[int1.size](int(int1) ^ int(int2))
             elif op_name == '&':
-                out = int1.arg & int2.arg
+                out = mod_size2uint[int1.size](int(int1) & int(int2))
             elif op_name == '|':
-                out = int1.arg | int2.arg
+                out = mod_size2uint[int1.size](int(int1) | int(int2))
             elif op_name == '>>':
                 if int(int2) > int1.size:
                     out = 0
                 else:
-                    out = int1.arg >> int2.arg
+                    out = mod_size2uint[int1.size](int(int1) >> int(int2))
             elif op_name == '<<':
                 if int(int2) > int1.size:
                     out = 0
                 else:
-                    out = int1.arg << int2.arg
+                    out = mod_size2uint[int1.size](int(int1) << int(int2))
             elif op_name == 'a>>':
-                tmp1 = mod_size2int[int1.arg.size](int1.arg)
-                tmp2 = mod_size2uint[int2.arg.size](int2.arg)
+                tmp1 = mod_size2int[int1.size](int(int1))
+                tmp2 = mod_size2uint[int2.size](int(int2))
                 if tmp2 > int1.size:
                     is_signed = int(int1) & (1 << (int1.size - 1))
                     if is_signed:
@@ -63,55 +63,57 @@ def simp_cst_propagation(e_s, expr):
                     else:
                         out = 0
                 else:
-                    out = mod_size2uint[int1.arg.size](tmp1 >> tmp2)
+                    out = mod_size2uint[int1.size](tmp1 >> tmp2)
             elif op_name == '>>>':
-                shifter = int2.arg % int2.size
-                out = (int1.arg >> shifter) | (int1.arg << (int2.size - shifter))
+                shifter = int(int2) % int2.size
+                out = (int(int1) >> shifter) | (int(int1) << (int2.size - shifter))
             elif op_name == '<<<':
-                shifter = int2.arg % int2.size
-                out = (int1.arg << shifter) | (int1.arg >> (int2.size - shifter))
+                shifter = int(int2) % int2.size
+                out = (int(int1) << shifter) | (int(int1) >> (int2.size - shifter))
             elif op_name == '/':
-                out = int1.arg // int2.arg
+                assert int(int2), "division by 0"
+                out = int(int1) // int(int2)
             elif op_name == '%':
-                out = int1.arg % int2.arg
+                assert int(int2), "division by 0"
+                out = int(int1) % int(int2)
             elif op_name == 'sdiv':
-                assert int2.arg.arg
-                tmp1 = mod_size2int[int1.arg.size](int1.arg)
-                tmp2 = mod_size2int[int2.arg.size](int2.arg)
-                out = mod_size2uint[int1.arg.size](tmp1 // tmp2)
+                assert int(int2), "division by 0"
+                tmp1 = mod_size2int[int1.size](int(int1))
+                tmp2 = mod_size2int[int2.size](int(int2))
+                out = mod_size2uint[int1.size](tmp1 // tmp2)
             elif op_name == 'smod':
-                assert int2.arg.arg
-                tmp1 = mod_size2int[int1.arg.size](int1.arg)
-                tmp2 = mod_size2int[int2.arg.size](int2.arg)
-                out = mod_size2uint[int1.arg.size](tmp1 % tmp2)
+                assert int(int2), "division by 0"
+                tmp1 = mod_size2int[int1.size](int(int1))
+                tmp2 = mod_size2int[int2.size](int(int2))
+                out = mod_size2uint[int1.size](tmp1 % tmp2)
             elif op_name == 'umod':
-                assert int2.arg.arg
-                tmp1 = mod_size2uint[int1.arg.size](int1.arg)
-                tmp2 = mod_size2uint[int2.arg.size](int2.arg)
-                out = mod_size2uint[int1.arg.size](tmp1 % tmp2)
+                assert int(int2), "division by 0"
+                tmp1 = mod_size2uint[int1.size](int(int1))
+                tmp2 = mod_size2uint[int2.size](int(int2))
+                out = mod_size2uint[int1.size](tmp1 % tmp2)
             elif op_name == 'udiv':
-                assert int2.arg.arg
-                tmp1 = mod_size2uint[int1.arg.size](int1.arg)
-                tmp2 = mod_size2uint[int2.arg.size](int2.arg)
-                out = mod_size2uint[int1.arg.size](tmp1 // tmp2)
+                assert int(int2), "division by 0"
+                tmp1 = mod_size2uint[int1.size](int(int1))
+                tmp2 = mod_size2uint[int2.size](int(int2))
+                out = mod_size2uint[int1.size](tmp1 // tmp2)
 
 
 
-            args.append(ExprInt(out, int1.size))
+            args.append(ExprInt(int(out), int1.size))
 
     # cnttrailzeros(int) => int
     if op_name == "cnttrailzeros" and args[0].is_int():
         i = 0
-        while args[0].arg & (1 << i) == 0 and i < args[0].size:
+        while int(args[0]) & (1 << i) == 0 and i < args[0].size:
             i += 1
         return ExprInt(i, args[0].size)
 
     # cntleadzeros(int) => int
     if op_name == "cntleadzeros" and args[0].is_int():
-        if args[0].arg == 0:
+        if int(args[0]) == 0:
             return ExprInt(args[0].size, args[0].size)
         i = args[0].size - 1
-        while args[0].arg & (1 << i) == 0:
+        while int(args[0]) & (1 << i) == 0:
             i -= 1
         return ExprInt(expr.size - (i + 1), args[0].size)
 
@@ -120,6 +122,7 @@ def simp_cst_propagation(e_s, expr):
         len(args[0].args) == 1):
         return args[0].args[0]
 
+
     # -(int) => -int
     if op_name == '-' and len(args) == 1 and args[0].is_int():
         return ExprInt(-int(args[0]), expr.size)
@@ -207,13 +210,13 @@ def simp_cst_propagation(e_s, expr):
             j += 1
         i += 1
 
-    if op_name in ['|', '&', '%', '/', '**'] and len(args) == 1:
+    if op_name in ['+', '^', '|', '&', '%', '/', '**'] and len(args) == 1:
         return args[0]
 
     # A <<< A.size => A
     if (op_name in ['<<<', '>>>'] and
         args[1].is_int() and
-        args[1].arg == args[0].size):
+        int(args[1]) == args[0].size):
         return args[0]
 
     # (A <<< X) <<< Y => A <<< (X+Y) (or <<< >>>) if X + Y does not overflow
@@ -277,7 +280,10 @@ def simp_cst_propagation(e_s, expr):
 
     # ((A & A.mask)
     if op_name == "&" and args[-1] == expr.mask:
-        return ExprOp('&', *args[:-1])
+        args = args[:-1]
+        if len(args) == 1:
+            return args[0]
+        return ExprOp('&', *args)
 
     # ((A | A.mask)
     if op_name == "|" and args[-1] == expr.mask:
@@ -289,7 +295,7 @@ def simp_cst_propagation(e_s, expr):
     # ((A & mask) >> shift) with mask < 2**shift => 0
     if op_name == ">>" and args[1].is_int() and args[0].is_op("&"):
         if (args[0].args[1].is_int() and
-            2 ** args[1].arg > args[0].args[1].arg):
+            2 ** int(args[1]) > int(args[0].args[1])):
             return ExprInt(0, args[0].size)
 
     # parity(int) => int
@@ -315,7 +321,6 @@ def simp_cst_propagation(e_s, expr):
         args = args[0].args
         return ExprOp('*', *(list(args[:-1]) + [ExprInt(-int(args[-1]), expr.size)]))
 
-
     # A << int with A ExprCompose => move index
     if (op_name == "<<" and args[0].is_compose() and
         args[1].is_int() and int(args[1]) != 0):
@@ -450,8 +455,8 @@ def simp_cond_factor(e_s, expr):
     for cond, vals in viewitems(conds):
         new_src1 = [x.src1 for x in vals]
         new_src2 = [x.src2 for x in vals]
-        src1 = e_s.expr_simp_wrapper(ExprOp(expr.op, *new_src1))
-        src2 = e_s.expr_simp_wrapper(ExprOp(expr.op, *new_src2))
+        src1 = e_s.expr_simp(ExprOp(expr.op, *new_src1))
+        src2 = e_s.expr_simp(ExprOp(expr.op, *new_src2))
         c_out.append(ExprCond(cond, src1, src2))
 
     if len(c_out) == 1:
@@ -471,7 +476,7 @@ def simp_slice(e_s, expr):
     if expr.arg.is_int():
         total_bit = expr.stop - expr.start
         mask = (1 << (expr.stop - expr.start)) - 1
-        return ExprInt(int((expr.arg.arg >> expr.start) & mask), total_bit)
+        return ExprInt(int((int(expr.arg) >> expr.start) & mask), total_bit)
     # Slice(Slice(A, x), y) => Slice(A, z)
     if expr.arg.is_slice():
         if expr.stop - expr.start > expr.arg.stop - expr.arg.start:
@@ -521,7 +526,7 @@ def simp_slice(e_s, expr):
     # distributivity of slice and &
     # (a & int)[x:y] => 0 if int[x:y] == 0
     if expr.arg.is_op("&") and expr.arg.args[-1].is_int():
-        tmp = e_s.expr_simp_wrapper(expr.arg.args[-1][expr.start:expr.stop])
+        tmp = e_s.expr_simp(expr.arg.args[-1][expr.start:expr.stop])
         if tmp.is_int(0):
             return tmp
     # distributivity of slice and exprcond
@@ -536,7 +541,7 @@ def simp_slice(e_s, expr):
 
     # (a * int)[0:y] => (a[0:y] * int[0:y])
     if expr.start == 0 and expr.arg.is_op("*") and expr.arg.args[-1].is_int():
-        args = [e_s.expr_simp_wrapper(a[expr.start:expr.stop]) for a in expr.arg.args]
+        args = [e_s.expr_simp(a[expr.start:expr.stop]) for a in expr.arg.args]
         return ExprOp(expr.arg.op, *args)
 
     # (a >> int)[x:y] => a[x+int:y+int] with int+y <= a.size
@@ -626,7 +631,7 @@ def simp_cond(_, expr):
         expr = expr.src1
     # int ? A:B => A or B
     elif expr.cond.is_int():
-        if expr.cond.arg == 0:
+        if int(expr.cond) == 0:
             expr = expr.src2
         else:
             expr = expr.src1
@@ -646,8 +651,8 @@ def simp_cond(_, expr):
     elif (expr.cond.is_cond() and
           expr.cond.src1.is_int() and
           expr.cond.src2.is_int()):
-        int1 = expr.cond.src1.arg.arg
-        int2 = expr.cond.src2.arg.arg
+        int1 = int(expr.cond.src1)
+        int2 = int(expr.cond.src2)
         if int1 and int2:
             expr = expr.src1
         elif int1 == 0 and int2 == 0:
@@ -906,6 +911,15 @@ def simp_cond_flag(_, expr):
     return expr
 
 
+def simp_sub_cf_zero(_, expr):
+    """FLAG_SUB_CF(0, X) => (X)?1:0"""
+    if not expr.is_op("FLAG_SUB_CF"):
+        return expr
+    if not expr.args[0].is_int(0):
+        return expr
+    return ExprCond(expr.args[1], ExprInt(1, 1), ExprInt(0, 1))
+
+
 def simp_cmp_int(expr_simp, expr):
     """
     ({X, 0} == int) => X == int[:]
@@ -1069,6 +1083,13 @@ def simp_cmp_bijective_op(expr_simp, expr):
             args_a.remove(value)
             args_b.remove(value)
 
+    # a + b == a + b + c
+    if not args_a:
+        return ExprOp(TOK_EQUAL, ExprOp(op, *args_b), ExprInt(0, args_b[0].size))
+    # a + b + c == a + b
+    if not args_b:
+        return ExprOp(TOK_EQUAL, ExprOp(op, *args_a), ExprInt(0, args_a[0].size))
+    
     arg_a = ExprOp(op, *args_a)
     arg_b = ExprOp(op, *args_b)
     return ExprOp(TOK_EQUAL, arg_a, arg_b)
@@ -1362,6 +1383,23 @@ def simp_ext_cst(_, expr):
     return ret
 
 
+
+def simp_ext_cond_int(e_s, expr):
+    """
+    zeroExt(ExprCond(X, Int, Int)) => ExprCond(X, Int, Int)
+    """
+    if not (expr.op.startswith("zeroExt") or expr.op.startswith("signExt")):
+        return expr
+    arg = expr.args[0]
+    if not arg.is_cond():
+        return expr
+    if not (arg.src1.is_int() and arg.src2.is_int()):
+        return expr
+    src1 = ExprOp(expr.op, arg.src1)
+    src2 = ExprOp(expr.op, arg.src2)
+    return e_s(ExprCond(arg.cond, src1, src2))
+
+
 def simp_slice_of_ext(_, expr):
     """
     C.zeroExt(X)[A:B] => 0 if A >= size(C)