about summary refs log tree commit diff stats
diff options
context:
space:
mode:
authorFabrice Desclaux <fabrice.desclaux@cea.fr>2020-03-10 16:35:51 +0100
committerFabrice Desclaux <fabrice.desclaux@cea.fr>2020-03-29 00:08:03 +0100
commitbf9babdcd886d51666c04e4fc39a4b03e281974a (patch)
tree021484913b08e5c72366da78e6a6c5d68e31893f
parent83196a14885467a043666882db7d8120bb127b61 (diff)
downloadmiasm-bf9babdcd886d51666c04e4fc39a4b03e281974a.tar.gz
miasm-bf9babdcd886d51666c04e4fc39a4b03e281974a.zip
Add Expression visitor
-rw-r--r--miasm/analysis/data_flow.py20
-rw-r--r--miasm/analysis/depgraph.py117
-rw-r--r--miasm/analysis/dse.py4
-rw-r--r--miasm/expression/expression.py478
-rw-r--r--miasm/expression/simplifications.py52
-rw-r--r--miasm/expression/simplifications_common.py8
-rw-r--r--test/expression/expression.py45
7 files changed, 420 insertions, 304 deletions
diff --git a/miasm/analysis/data_flow.py b/miasm/analysis/data_flow.py
index 7bd6d72f..6395fa8c 100644
--- a/miasm/analysis/data_flow.py
+++ b/miasm/analysis/data_flow.py
@@ -6,8 +6,8 @@ from future.utils import viewitems, viewvalues
 from miasm.core.utils import encode_hex
 from miasm.core.graph import DiGraph
 from miasm.ir.ir import AssignBlock, IRBlock
-from miasm.expression.expression import ExprLoc, ExprMem, ExprId, ExprInt,\
-    ExprAssign, ExprOp
+from miasm.expression.expression import ExprLoc, ExprMem, ExprSlice, ExprId, \
+    ExprInt, ExprAssign, ExprOp, ExprCompose, ExprCond, ExprWalk
 from miasm.expression.simplifications import expr_simp
 from miasm.core.interval import interval
 from miasm.expression.expression_helper import possible_values
@@ -736,22 +736,16 @@ def expr_test_visit(expr, test):
         return False
 
 
-def expr_has_mem_test(expr, result):
-    if result:
-        # Don't analyse if we already found a candidate
-        return False
-    if expr.is_mem():
-        result.add(expr)
-        return False
-    return True
-
-
 def expr_has_mem(expr):
     """
     Return True if expr contains at least one memory access
     @expr: Expr instance
     """
-    return expr_test_visit(expr, expr_has_mem_test)
+
+    def has_mem(self):
+        return self.is_mem()
+    visitor = ExprWalk(has_mem)
+    return visitor.visit(expr)
 
 
 class PropagateThroughExprId(object):
diff --git a/miasm/analysis/depgraph.py b/miasm/analysis/depgraph.py
index 7113dd51..b0d13318 100644
--- a/miasm/analysis/depgraph.py
+++ b/miasm/analysis/depgraph.py
@@ -4,7 +4,8 @@ from functools import total_ordering
 
 from future.utils import viewitems
 
-from miasm.expression.expression import ExprInt, ExprLoc, ExprAssign
+from miasm.expression.expression import ExprInt, ExprLoc, ExprAssign, \
+    ExprWalk
 from miasm.core.graph import DiGraph
 from miasm.core.locationdb import LocationDB
 from miasm.expression.simplifications import expr_simp_explicit
@@ -449,6 +450,50 @@ class FollowExpr(object):
                    if not(only_follow) or follow_expr.follow)
 
 
+class FilterExprSources(ExprWalk):
+    """
+    Walk Expression to find sources to track
+    @follow_mem: (optional) Track memory syntactically
+    @follow_call: (optional) Track through "call"
+    """
+    def __init__(self, follow_mem, follow_call):
+        super(FilterExprSources, self).__init__(lambda x:None)
+        self.follow_mem = follow_mem
+        self.follow_call = follow_call
+        self.nofollow = set()
+        self.follow = set()
+
+    def visit(self, expr, *args, **kwargs):
+        if expr in self.cache:
+            return None
+        ret = self.visit_inner(expr, *args, **kwargs)
+        self.cache.add(expr)
+        return ret
+
+    def visit_inner(self, expr, *args, **kwargs):
+        if expr.is_id():
+            self.follow.add(expr)
+        elif expr.is_int():
+            self.nofollow.add(expr)
+        elif expr.is_loc():
+            self.nofollow.add(expr)
+        elif expr.is_mem():
+            if self.follow_mem:
+                self.follow.add(expr)
+            else:
+                self.nofollow.add(expr)
+                return None
+        elif expr.is_function_call():
+            if self.follow_call:
+                self.follow.add(expr)
+            else:
+                self.nofollow.add(expr)
+                return None
+
+        ret = super(FilterExprSources, self).visit(expr, *args, **kwargs)
+        return ret
+
+
 class DependencyGraph(object):
 
     """Implementation of a dependency graph
@@ -480,10 +525,14 @@ class DependencyGraph(object):
         self._cb_follow = []
         if apply_simp:
             self._cb_follow.append(self._follow_simp_expr)
-        self._cb_follow.append(lambda exprs: self._follow_exprs(exprs,
-                                                                follow_mem,
-                                                                follow_call))
-        self._cb_follow.append(self._follow_no_loc_key)
+        self._cb_follow.append(lambda exprs: self.do_follow(exprs, follow_mem, follow_call))
+
+    @staticmethod
+    def do_follow(exprs, follow_mem, follow_call):
+        visitor = FilterExprSources(follow_mem, follow_call)
+        for expr in exprs:
+            visitor.visit(expr)
+        return visitor.follow, visitor.nofollow
 
     @staticmethod
     def _follow_simp_expr(exprs):
@@ -495,64 +544,6 @@ class DependencyGraph(object):
             follow.add(expr_simp_explicit(expr))
         return follow, set()
 
-    @staticmethod
-    def get_expr(expr, follow, nofollow):
-        """Update @follow/@nofollow according to insteresting nodes
-        Returns same expression (non modifier visitor).
-
-        @expr: expression to handle
-        @follow: set of nodes to follow
-        @nofollow: set of nodes not to follow
-        """
-        if expr.is_id():
-            follow.add(expr)
-        elif expr.is_int():
-            nofollow.add(expr)
-        elif expr.is_mem():
-            follow.add(expr)
-        return expr
-
-    @staticmethod
-    def follow_expr(expr, _, nofollow, follow_mem=False, follow_call=False):
-        """Returns True if we must visit sub expressions.
-        @expr: expression to browse
-        @follow: set of nodes to follow
-        @nofollow: set of nodes not to follow
-        @follow_mem: force the visit of memory sub expressions
-        @follow_call: force the visit of call sub expressions
-        """
-        if not follow_mem and expr.is_mem():
-            nofollow.add(expr)
-            return False
-        if not follow_call and expr.is_function_call():
-            nofollow.add(expr)
-            return False
-        return True
-
-    @classmethod
-    def _follow_exprs(cls, exprs, follow_mem=False, follow_call=False):
-        """Extracts subnodes from exprs and returns followed/non followed
-        expressions according to @follow_mem/@follow_call
-
-        """
-        follow, nofollow = set(), set()
-        for expr in exprs:
-            expr.visit(lambda x: cls.get_expr(x, follow, nofollow),
-                       lambda x: cls.follow_expr(x, follow, nofollow,
-                                                 follow_mem, follow_call))
-        return follow, nofollow
-
-    @staticmethod
-    def _follow_no_loc_key(exprs):
-        """Do not follow loc_keys"""
-        follow = set()
-        for expr in exprs:
-            if expr.is_int() or expr.is_loc():
-                continue
-            follow.add(expr)
-
-        return follow, set()
-
     def _follow_apply_cb(self, expr):
         """Apply callback functions to @expr
         @expr : FollowExpr instance"""
diff --git a/miasm/analysis/dse.py b/miasm/analysis/dse.py
index 3a0482a3..ada3c4bd 100644
--- a/miasm/analysis/dse.py
+++ b/miasm/analysis/dse.py
@@ -333,8 +333,8 @@ class DSEEngine(object):
         self.handle(ExprInt(cur_addr, self.ir_arch.IRDst.size))
 
         # Avoid memory issue in ExpressionSimplifier
-        if len(self.symb.expr_simp.simplified_exprs) > 100000:
-            self.symb.expr_simp.simplified_exprs.clear()
+        if len(self.symb.expr_simp.cache) > 100000:
+            self.symb.expr_simp.cache.clear()
 
         # Get IR blocks
         if cur_addr in self.addr_to_cacheblocks:
diff --git a/miasm/expression/expression.py b/miasm/expression/expression.py
index d0e57b46..9ac631fa 100644
--- a/miasm/expression/expression.py
+++ b/miasm/expression/expression.py
@@ -98,18 +98,6 @@ def should_parenthesize_child(child, parent):
 def str_protected_child(child, parent):
     return ("(%s)" % child) if should_parenthesize_child(child, parent) else str(child)
 
-def visit_chk(visitor):
-    "Function decorator launching callback on Expression visit"
-    def wrapped(expr, callback, test_visit=lambda x: True):
-        if (test_visit is not None) and (not test_visit(expr)):
-            return expr
-        expr_new = visitor(expr, callback, test_visit)
-        if expr_new is None:
-            return None
-        expr_new2 = callback(expr_new)
-        return expr_new2
-    return wrapped
-
 
 # Expression display
 
@@ -183,6 +171,263 @@ class LocKey(object):
     def __str__(self):
         return "loc_key_%d" % self.key
 
+
+class ExprWalkBase(object):
+    """
+    Walk through sub-expressions, call @callback on them.
+    If @callback returns a non None value, stop walk and return this value
+    """
+
+    def __init__(self, callback):
+        self.callback = callback
+
+    def visit(self, expr, *args, **kwargs):
+        if expr.is_int() or expr.is_id() or expr.is_loc():
+            pass
+        elif expr.is_assign():
+            ret = self.visit(expr.dst, *args, **kwargs)
+            if ret:
+                return ret
+            src = self.visit(expr.src, *args, **kwargs)
+            if ret:
+                return ret
+        elif expr.is_cond():
+            ret = self.visit(expr.cond, *args, **kwargs)
+            if ret:
+                return ret
+            ret = self.visit(expr.src1, *args, **kwargs)
+            if ret:
+                return ret
+            ret = self.visit(expr.src2, *args, **kwargs)
+            if ret:
+                return ret
+        elif expr.is_mem():
+            ret = self.visit(expr.ptr, *args, **kwargs)
+            if ret:
+                return ret
+        elif expr.is_slice():
+            ret = self.visit(expr.arg, *args, **kwargs)
+            if ret:
+                return ret
+        elif expr.is_op():
+            for arg in expr.args:
+                ret = self.visit(arg, *args, **kwargs)
+                if ret:
+                    return ret
+        elif expr.is_compose():
+            for arg in expr.args:
+                ret = self.visit(arg, *args, **kwargs)
+                if ret:
+                    return ret
+        else:
+            raise TypeError("Visitor can only take Expr")
+
+        ret = self.callback(expr, *args, **kwargs)
+        return ret
+
+
+class ExprWalk(ExprWalkBase):
+    """
+    Walk through sub-expressions, call @callback on them.
+    If @callback returns a non None value, stop walk and return this value
+    Use cache mechanism.
+    """
+    def __init__(self, callback):
+        self.cache = set()
+        self.callback = callback
+
+    def visit(self, expr, *args, **kwargs):
+        if expr in self.cache:
+            return None
+        ret = super(ExprWalk, self).visit(expr, *args, **kwargs)
+        if ret:
+            return ret
+        self.cache.add(expr)
+        return None
+
+
+class ExprGetR(ExprWalkBase):
+    """
+    Return ExprId/ExprMem used by a given expression
+    """
+    def __init__(self, mem_read=False, cst_read=False):
+        super(ExprGetR, self).__init__(lambda x:None)
+        self.mem_read = mem_read
+        self.cst_read = cst_read
+        self.elements = set()
+        self.cache = dict()
+
+    def get_r_leaves(self, expr):
+        if (expr.is_int() or expr.is_loc()) and self.cst_read:
+            self.elements.add(expr)
+        elif expr.is_mem():
+            self.elements.add(expr)
+        elif expr.is_id():
+            self.elements.add(expr)
+
+    def visit(self, expr, *args, **kwargs):
+        cache_key = (expr, self.mem_read, self.cst_read)
+        if cache_key in self.cache:
+            return self.cache[cache_key]
+        ret = self.visit_inner(expr, *args, **kwargs)
+        self.cache[cache_key] = ret
+        return ret
+
+    def visit_inner(self, expr, *args, **kwargs):
+        self.get_r_leaves(expr)
+        if expr.is_mem() and not self.mem_read:
+            # Don't visit memory sons
+            return None
+
+        if expr.is_assign():
+            if expr.dst.is_mem() and self.mem_read:
+                ret = super(ExprGetR, self).visit(expr.dst, *args, **kwargs)
+            if expr.src.is_mem():
+                self.elements.add(expr.src)
+            self.get_r_leaves(expr.src)
+            if expr.src.is_mem() and not self.mem_read:
+                return None
+            ret = super(ExprGetR, self).visit(expr.src, *args, **kwargs)
+            return ret
+        ret = super(ExprGetR, self).visit(expr, *args, **kwargs)
+        return ret
+
+
+class ExprVisitorBase(object):
+    """
+    Rebuild expression by visiting sub-expressions
+    """
+    def visit(self, expr, *args, **kwargs):
+        if expr.is_int() or expr.is_id() or expr.is_loc():
+            ret = expr
+        elif expr.is_assign():
+            dst = self.visit(expr.dst, *args, **kwargs)
+            src = self.visit(expr.src, *args, **kwargs)
+            ret = ExprAssign(dst, src)
+        elif expr.is_cond():
+            cond = self.visit(expr.cond, *args, **kwargs)
+            src1 = self.visit(expr.src1, *args, **kwargs)
+            src2 = self.visit(expr.src2, *args, **kwargs)
+            ret = ExprCond(cond, src1, src2)
+        elif expr.is_mem():
+            ptr = self.visit(expr.ptr, *args, **kwargs)
+            ret = ExprMem(ptr, expr.size)
+        elif expr.is_slice():
+            arg = self.visit(expr.arg, *args, **kwargs)
+            ret = ExprSlice(arg, expr.start, expr.stop)
+        elif expr.is_op():
+            args = [self.visit(arg, *args, **kwargs) for arg in expr.args]
+            ret = ExprOp(expr.op, *args)
+        elif expr.is_compose():
+            args = [self.visit(arg, *args, **kwargs) for arg in expr.args]
+            ret = ExprCompose(*args)
+        else:
+            raise TypeError("Visitor can only take Expr")
+        return ret
+
+
+class ExprVisitorCallbackTopToBottom(ExprVisitorBase):
+    """
+    Rebuild expression by visiting sub-expressions
+    Call @callback on each sub-expression
+    if @Ā¢allback return non None value, replace current node with this value
+    Else, continue visit of sub-expressions
+    """
+    def __init__(self, callback):
+        super(ExprVisitorCallbackTopToBottom, self).__init__()
+        self.cache = dict()
+        self.callback = callback
+
+    def visit(self, expr, *args, **kwargs):
+        if expr in self.cache:
+            return self.cache[expr]
+        ret = self.visit_inner(expr, *args, **kwargs)
+        self.cache[expr] = ret
+        return ret
+
+    def visit_inner(self, expr, *args, **kwargs):
+        ret = self.callback(expr)
+        if ret:
+            return ret
+        ret = super(ExprVisitorCallbackTopToBottom, self).visit(expr, *args, **kwargs)
+        return ret
+
+
+class ExprVisitorCallbackBottomToTop(ExprVisitorBase):
+    """
+    Rebuild expression by visiting sub-expressions
+    Call @callback from leaves to root expressions
+    """
+    def __init__(self, callback):
+        super(ExprVisitorCallbackBottomToTop, self).__init__()
+        self.cache = dict()
+        self.callback = callback
+
+    def visit(self, expr, *args, **kwargs):
+        if expr in self.cache:
+            return self.cache[expr]
+        ret = self.visit_inner(expr, *args, **kwargs)
+        self.cache[expr] = ret
+        return ret
+
+    def visit_inner(self, expr, *args, **kwargs):
+        ret = super(ExprVisitorCallbackBottomToTop, self).visit(expr, *args, **kwargs)
+        ret = self.callback(ret)
+        return ret
+
+
+class ExprVisitorCanonize(ExprVisitorCallbackBottomToTop):
+    def __init__(self):
+        super(ExprVisitorCanonize, self).__init__(self.canonize)
+
+    def canonize(self, expr):
+        if not expr.is_op():
+            return expr
+        if not expr.is_associative():
+            return expr
+
+        # ((a+b) + c) => (a + b + c)
+        args = []
+        for arg in expr.args:
+            if isinstance(arg, ExprOp) and expr.op == arg.op:
+                args += arg.args
+            else:
+                args.append(arg)
+        args = canonize_expr_list(args)
+        new_expr = ExprOp(expr.op, *args)
+        return new_expr
+
+
+class ExprVisitorContains(ExprWalkBase):
+    """
+    Visitor to test if a needle is in an Expression
+    Cache results
+    """
+    def __init__(self):
+        self.cache = set()
+        super(ExprVisitorContains, self).__init__(self.eq_expr)
+
+    def eq_expr(self, expr, needle, *args, **kwargs):
+        if expr == needle:
+            return True
+        return None
+
+    def visit(self, expr, needle,  *args, **kwargs):
+        if (expr, needle) in self.cache:
+            return None
+        ret = super(ExprVisitorContains, self).visit(expr, needle, *args, **kwargs)
+        if ret:
+            return ret
+        self.cache.add((expr, needle))
+        return None
+
+
+    def contains(self, expr, needle):
+        return self.visit(expr, needle)
+
+contains_visitor = ExprVisitorContains()
+canonize_visitor = ExprVisitorCanonize()
+
 # IR definitions
 
 class Expr(object):
@@ -337,36 +582,16 @@ class Expr(object):
         """Find and replace sub expression using dct
         @dct: dictionary associating replaced Expr to its new Expr value
         """
-        return self.visit(lambda expr: dct.get(expr, expr))
+        def replace(expr):
+            if expr in dct:
+                return dct[expr]
+            return None
+        visitor = ExprVisitorCallbackTopToBottom(lambda expr:replace(expr))
+        return visitor.visit(self)
 
     def canonize(self):
         "Canonize the Expression"
-
-        def must_canon(expr):
-            return not expr.is_canon
-
-        def canonize_visitor(expr):
-            if expr.is_canon:
-                return expr
-            if isinstance(expr, ExprOp):
-                if expr.is_associative():
-                    # ((a+b) + c) => (a + b + c)
-                    args = []
-                    for arg in expr.args:
-                        if isinstance(arg, ExprOp) and expr.op == arg.op:
-                            args += arg.args
-                        else:
-                            args.append(arg)
-                    args = canonize_expr_list(args)
-                    new_e = ExprOp(expr.op, *args)
-                else:
-                    new_e = expr
-            else:
-                new_e = expr
-            new_e.is_canon = True
-            return new_e
-
-        return self.visit(canonize_visitor, must_canon)
+        return canonize_visitor.visit(self)
 
     def msb(self):
         "Return the Most Significant Bit"
@@ -453,6 +678,32 @@ class Expr(object):
         """Returns True if is ExprMem and ptr is_op_segm"""
         return False
 
+    def __contains__(self, expr):
+        ret = contains_visitor.contains(self, expr)
+        return ret
+
+    def visit(self, callback):
+        """
+        Apply callbak to all sub expression of @self
+        This function keeps a cache to avoid rerunning @callbak on common sub
+        expressions.
+
+        @callback: fn(Expr) -> Expr
+        """
+        visitor = ExprVisitorCallbackBottomToTop(callback)
+        return visitor.visit(self)
+
+    def get_r(self, mem_read=False, cst_read=False):
+        visitor = ExprGetR(mem_read, cst_read)
+        visitor.visit(self)
+        return visitor.elements
+
+
+    def get_w(self, mem_read=False, cst_read=False):
+        if self.is_assign():
+            return set([self.dst])
+        return set()
+
 class ExprInt(Expr):
 
     """An ExprInt represent a constant in Miasm IR.
@@ -512,12 +763,6 @@ class ExprInt(Expr):
         else:
             return str("0x%X" % self._get_int())
 
-    def get_r(self, mem_read=False, cst_read=False):
-        if cst_read:
-            return set([self])
-        else:
-            return set()
-
     def get_w(self):
         return set()
 
@@ -528,13 +773,6 @@ class ExprInt(Expr):
         return "%s(0x%X, %d)" % (self.__class__.__name__, self._get_int(),
                                  self._size)
 
-    def __contains__(self, expr):
-        return self == expr
-
-    @visit_chk
-    def visit(self, callback, test_visit=None):
-        return self
-
     def copy(self):
         return ExprInt(self._arg, self._size)
 
@@ -595,9 +833,6 @@ class ExprId(Expr):
     def __str__(self):
         return str(self._name)
 
-    def get_r(self, mem_read=False, cst_read=False):
-        return set([self])
-
     def get_w(self):
         return set([self])
 
@@ -607,13 +842,6 @@ class ExprId(Expr):
     def _exprrepr(self):
         return "%s(%r, %d)" % (self.__class__.__name__, self._name, self._size)
 
-    def __contains__(self, expr):
-        return self == expr
-
-    @visit_chk
-    def visit(self, callback, test_visit=None):
-        return self
-
     def copy(self):
         return ExprId(self._name, self._size)
 
@@ -657,12 +885,6 @@ class ExprLoc(Expr):
     def __str__(self):
         return str(self._loc_key)
 
-    def get_r(self, mem_read=False, cst_read=False):
-        if cst_read:
-            return set([self])
-        else:
-            return set()
-
     def get_w(self):
         return set()
 
@@ -672,13 +894,6 @@ class ExprLoc(Expr):
     def _exprrepr(self):
         return "%s(%r, %d)" % (self.__class__.__name__, self._loc_key, self._size)
 
-    def __contains__(self, expr):
-        return self == expr
-
-    @visit_chk
-    def visit(self, callback, test_visit=None):
-        return self
-
     def copy(self):
         return ExprLoc(self._loc_key, self._size)
 
@@ -749,12 +964,6 @@ class ExprAssign(Expr):
     def __str__(self):
         return "%s = %s" % (str(self._dst), str(self._src))
 
-    def get_r(self, mem_read=False, cst_read=False):
-        elements = self._src.get_r(mem_read, cst_read)
-        if isinstance(self._dst, ExprMem) and mem_read:
-            elements.update(self._dst.ptr.get_r(mem_read, cst_read))
-        return elements
-
     def get_w(self):
         if isinstance(self._dst, ExprMem):
             return set([self._dst])  # [memreg]
@@ -767,19 +976,6 @@ class ExprAssign(Expr):
     def _exprrepr(self):
         return "%s(%r, %r)" % (self.__class__.__name__, self._dst, self._src)
 
-    def __contains__(self, expr):
-        return (self == expr or
-                self._src.__contains__(expr) or
-                self._dst.__contains__(expr))
-
-    @visit_chk
-    def visit(self, callback, test_visit=None):
-        dst, src = self._dst.visit(callback, test_visit), self._src.visit(callback, test_visit)
-        if dst == self._dst and src == self._src:
-            return self
-        else:
-            return ExprAssign(dst, src)
-
     def copy(self):
         return ExprAssign(self._dst.copy(), self._src.copy())
 
@@ -854,12 +1050,6 @@ class ExprCond(Expr):
     def __str__(self):
         return "%s?(%s,%s)" % (str_protected_child(self._cond, self), str(self._src1), str(self._src2))
 
-    def get_r(self, mem_read=False, cst_read=False):
-        out_src1 = self.src1.get_r(mem_read, cst_read)
-        out_src2 = self.src2.get_r(mem_read, cst_read)
-        return self.cond.get_r(mem_read,
-                               cst_read).union(out_src1).union(out_src2)
-
     def get_w(self):
         return set()
 
@@ -871,21 +1061,6 @@ class ExprCond(Expr):
         return "%s(%r, %r, %r)" % (self.__class__.__name__,
                                    self._cond, self._src1, self._src2)
 
-    def __contains__(self, expr):
-        return (self == expr or
-                self.cond.__contains__(expr) or
-                self.src1.__contains__(expr) or
-                self.src2.__contains__(expr))
-
-    @visit_chk
-    def visit(self, callback, test_visit=None):
-        cond = self._cond.visit(callback, test_visit)
-        src1 = self._src1.visit(callback, test_visit)
-        src2 = self._src2.visit(callback, test_visit)
-        if cond == self._cond and src1 == self._src1 and src2 == self._src2:
-            return self
-        return ExprCond(cond, src1, src2)
-
     def copy(self):
         return ExprCond(self._cond.copy(),
                         self._src1.copy(),
@@ -962,12 +1137,6 @@ class ExprMem(Expr):
     def __str__(self):
         return "@%d[%s]" % (self.size, str(self.ptr))
 
-    def get_r(self, mem_read=False, cst_read=False):
-        if mem_read:
-            return set(self._ptr.get_r(mem_read, cst_read).union(set([self])))
-        else:
-            return set([self])
-
     def get_w(self):
         return set([self])  # [memreg]
 
@@ -978,16 +1147,6 @@ class ExprMem(Expr):
         return "%s(%r, %r)" % (self.__class__.__name__,
                                self._ptr, self._size)
 
-    def __contains__(self, expr):
-        return self == expr or self._ptr.__contains__(expr)
-
-    @visit_chk
-    def visit(self, callback, test_visit=None):
-        ptr = self._ptr.visit(callback, test_visit)
-        if ptr == self._ptr:
-            return self
-        return ExprMem(ptr, self.size)
-
     def copy(self):
         ptr = self.ptr.copy()
         return ExprMem(ptr, size=self.size)
@@ -1117,10 +1276,6 @@ class ExprOp(Expr):
         return (self._op + '(' +
                 ', '.join([str(arg) for arg in self._args]) + ')')
 
-    def get_r(self, mem_read=False, cst_read=False):
-        return reduce(lambda elements, arg:
-                      elements.union(arg.get_r(mem_read, cst_read)), self._args, set())
-
     def get_w(self):
         raise ValueError('op cannot be written!', self)
 
@@ -1132,14 +1287,6 @@ class ExprOp(Expr):
         return "%s(%r, %s)" % (self.__class__.__name__, self._op,
                                ', '.join(repr(arg) for arg in self._args))
 
-    def __contains__(self, expr):
-        if self == expr:
-            return True
-        for arg in self._args:
-            if arg.__contains__(expr):
-                return True
-        return False
-
     def is_function_call(self):
         return self._op.startswith('call')
 
@@ -1162,14 +1309,6 @@ class ExprOp(Expr):
         "Return True iff current operation is commutative"
         return (self._op in ['+', '*', '^', '&', '|'])
 
-    @visit_chk
-    def visit(self, callback, test_visit=None):
-        args = [arg.visit(callback, test_visit) for arg in self._args]
-        modified = any([arg[0] != arg[1] for arg in zip(self._args, args)])
-        if modified:
-            return ExprOp(self._op, *args)
-        return self
-
     def copy(self):
         args = [arg.copy() for arg in self._args]
         return ExprOp(self._op, *args)
@@ -1222,9 +1361,6 @@ class ExprSlice(Expr):
     def __str__(self):
         return "%s[%d:%d]" % (str_protected_child(self._arg, self), self._start, self._stop)
 
-    def get_r(self, mem_read=False, cst_read=False):
-        return self._arg.get_r(mem_read, cst_read)
-
     def get_w(self):
         return self._arg.get_w()
 
@@ -1235,18 +1371,6 @@ class ExprSlice(Expr):
         return "%s(%r, %d, %d)" % (self.__class__.__name__, self._arg,
                                    self._start, self._stop)
 
-    def __contains__(self, expr):
-        if self == expr:
-            return True
-        return self._arg.__contains__(expr)
-
-    @visit_chk
-    def visit(self, callback, test_visit=None):
-        arg = self._arg.visit(callback, test_visit)
-        if arg == self._arg:
-            return self
-        return ExprSlice(arg, self._start, self._stop)
-
     def copy(self):
         return ExprSlice(self._arg.copy(), self._start, self._stop)
 
@@ -1319,10 +1443,6 @@ class ExprCompose(Expr):
     def __str__(self):
         return '{' + ', '.join(["%s %s %s" % (arg, idx, idx + arg.size) for idx, arg in self.iter_args()]) + '}'
 
-    def get_r(self, mem_read=False, cst_read=False):
-        return reduce(lambda elements, arg:
-                      elements.union(arg.get_r(mem_read, cst_read)), self._args, set())
-
     def get_w(self):
         return reduce(lambda elements, arg:
                       elements.union(arg.get_w()), self._args, set())
@@ -1334,24 +1454,6 @@ class ExprCompose(Expr):
     def _exprrepr(self):
         return "%s%r" % (self.__class__.__name__, self._args)
 
-    def __contains__(self, expr):
-        if self == expr:
-            return True
-        for arg in self._args:
-            if arg == expr:
-                return True
-            if arg.__contains__(expr):
-                return True
-        return False
-
-    @visit_chk
-    def visit(self, callback, test_visit=None):
-        args = [arg.visit(callback, test_visit) for arg in self._args]
-        modified = any([arg != arg_new for arg, arg_new in zip(self._args, args)])
-        if modified:
-            return ExprCompose(*args)
-        return self
-
     def copy(self):
         args = [arg.copy() for arg in self._args]
         return ExprCompose(*args)
diff --git a/miasm/expression/simplifications.py b/miasm/expression/simplifications.py
index 7b1d3629..a56aa0f8 100644
--- a/miasm/expression/simplifications.py
+++ b/miasm/expression/simplifications.py
@@ -11,6 +11,7 @@ from miasm.expression import simplifications_cond
 from miasm.expression import simplifications_explicit
 from miasm.expression.expression_helper import fast_unify
 import miasm.expression.expression as m2_expr
+from miasm.expression.expression import ExprVisitorCallbackBottomToTop
 
 # Expression Simplifier
 # ---------------------
@@ -22,7 +23,7 @@ log_exprsimp.addHandler(console_handler)
 log_exprsimp.setLevel(logging.WARNING)
 
 
-class ExpressionSimplifier(object):
+class ExpressionSimplifier(ExprVisitorCallbackBottomToTop):
 
     """Wrapper on expression simplification passes.
 
@@ -118,8 +119,8 @@ class ExpressionSimplifier(object):
 
 
     def __init__(self):
+        super(ExpressionSimplifier, self).__init__(self.expr_simp_inner)
         self.expr_simp_cb = {}
-        self.simplified_exprs = set()
 
     def enable_passes(self, passes):
         """Add passes from @passes
@@ -129,7 +130,7 @@ class ExpressionSimplifier(object):
         """
 
         # Clear cache of simplifiied expressions when adding a new pass
-        self.simplified_exprs.clear()
+        self.cache.clear()
 
         for k, v in viewitems(passes):
             self.expr_simp_cb[k] = fast_unify(self.expr_simp_cb.get(k, []) + v)
@@ -156,46 +157,29 @@ class ExpressionSimplifier(object):
 
         return expression
 
-    def expr_simp(self, expression):
+    def expr_simp_inner(self, expression):
         """Apply enabled simplifications on expression and find a stable state
         @expression: Expr instance
         Return an Expr instance"""
 
-        if expression in self.simplified_exprs:
-            return expression
-
         # Find a stable state
         while True:
             # Canonize and simplify
-            e_new = self.apply_simp(expression.canonize())
-            if e_new == expression:
-                break
-
-            # Launch recursivity
-            expression = self.expr_simp_wrapper(e_new)
-            self.simplified_exprs.add(expression)
-        # Mark expression as simplified
-        self.simplified_exprs.add(e_new)
-
-        return e_new
-
-    def expr_simp_wrapper(self, expression, callback=None):
-        """Apply enabled simplifications on expression
-        @expression: Expr instance
-        @manual_callback: If set, call this function instead of normal one
-        Return an Expr instance"""
+            new_expr = self.apply_simp(expression.canonize())
+            if new_expr == expression:
+                return new_expr
+            # Run recursively simplification on fresh new expression
+            new_expr = self.visit(new_expr)
+            expression = new_expr
+        return new_expr
 
-        if expression in self.simplified_exprs:
-            return expression
-
-        if callback is None:
-            callback = self.expr_simp
-
-        return expression.visit(callback, lambda e: e not in self.simplified_exprs)
+    def expr_simp(self, expression):
+        "Call simplification recursively"
+        return self.visit(expression)
 
-    def __call__(self, expression, callback=None):
-        "Wrapper on expr_simp_wrapper"
-        return self.expr_simp_wrapper(expression, callback)
+    def __call__(self, expression):
+        "Call simplification recursively"
+        return self.visit(expression)
 
 
 # Public ExprSimplificationPass instance with commons passes
diff --git a/miasm/expression/simplifications_common.py b/miasm/expression/simplifications_common.py
index 90f8945b..38859f3a 100644
--- a/miasm/expression/simplifications_common.py
+++ b/miasm/expression/simplifications_common.py
@@ -450,8 +450,8 @@ def simp_cond_factor(e_s, expr):
     for cond, vals in viewitems(conds):
         new_src1 = [x.src1 for x in vals]
         new_src2 = [x.src2 for x in vals]
-        src1 = e_s.expr_simp_wrapper(ExprOp(expr.op, *new_src1))
-        src2 = e_s.expr_simp_wrapper(ExprOp(expr.op, *new_src2))
+        src1 = e_s.expr_simp(ExprOp(expr.op, *new_src1))
+        src2 = e_s.expr_simp(ExprOp(expr.op, *new_src2))
         c_out.append(ExprCond(cond, src1, src2))
 
     if len(c_out) == 1:
@@ -521,7 +521,7 @@ def simp_slice(e_s, expr):
     # distributivity of slice and &
     # (a & int)[x:y] => 0 if int[x:y] == 0
     if expr.arg.is_op("&") and expr.arg.args[-1].is_int():
-        tmp = e_s.expr_simp_wrapper(expr.arg.args[-1][expr.start:expr.stop])
+        tmp = e_s.expr_simp(expr.arg.args[-1][expr.start:expr.stop])
         if tmp.is_int(0):
             return tmp
     # distributivity of slice and exprcond
@@ -536,7 +536,7 @@ def simp_slice(e_s, expr):
 
     # (a * int)[0:y] => (a[0:y] * int[0:y])
     if expr.start == 0 and expr.arg.is_op("*") and expr.arg.args[-1].is_int():
-        args = [e_s.expr_simp_wrapper(a[expr.start:expr.stop]) for a in expr.arg.args]
+        args = [e_s.expr_simp(a[expr.start:expr.stop]) for a in expr.arg.args]
         return ExprOp(expr.arg.op, *args)
 
     # (a >> int)[x:y] => a[x+int:y+int] with int+y <= a.size
diff --git a/test/expression/expression.py b/test/expression/expression.py
index 3597eae8..9b0c2807 100644
--- a/test/expression/expression.py
+++ b/test/expression/expression.py
@@ -17,6 +17,7 @@ assert big_cst.size == 0x1000
 # Possible values
 #- Common constants
 A = ExprId("A", 32)
+B = ExprId("B", 32)
 cond1 = ExprId("cond1", 1)
 cond2 = ExprId("cond2", 16)
 cst1 = ExprInt(1, 32)
@@ -71,3 +72,47 @@ for expr in [
 aff = ExprAssign(A[0:32], cst1)
 
 assert aff.dst == A and aff.src == cst1
+
+
+mem = ExprMem(A, 32)
+assert mem.get_r() == set([mem])
+assert mem.get_r(mem_read=True) == set([mem, A])
+
+C = A+B
+D = C + A
+
+assert A in A
+assert A in C
+assert B in C
+assert C in C
+
+assert A in D
+assert B in D
+assert C in D
+assert D in D
+
+assert C not in A
+assert C not in B
+
+assert D not in A
+assert D not in B
+assert D not in C
+
+
+assert cst1.get_r(cst_read=True) == set([cst1])
+mem1 = ExprMem(A, 32)
+mem2 = ExprMem(mem1 + B, 32)
+assert mem2.get_r() == set([mem2])
+
+assign1 = ExprAssign(A, cst1)
+assert assign1.get_r() == set([])
+
+assign2 = ExprAssign(mem1, D)
+assert assign2.get_r() == set([A, B])
+assert assign2.get_r(mem_read=True) == set([A, B])
+assert assign2.get_w() == set([mem1])
+
+assign3 = ExprAssign(mem1, mem2)
+assert assign3.get_r() == set([mem2])
+assert assign3.get_r(mem_read=True) == set([mem1, mem2, A, B])
+assert assign3.get_w() == set([mem1])