about summary refs log tree commit diff stats
path: root/miasm/expression/simplifications.py
diff options
context:
space:
mode:
Diffstat (limited to 'miasm/expression/simplifications.py')
-rw-r--r--miasm/expression/simplifications.py54
1 files changed, 20 insertions, 34 deletions
diff --git a/miasm/expression/simplifications.py b/miasm/expression/simplifications.py
index 03a779a6..3f54b158 100644
--- a/miasm/expression/simplifications.py
+++ b/miasm/expression/simplifications.py
@@ -11,6 +11,7 @@ from miasm.expression import simplifications_cond
 from miasm.expression import simplifications_explicit
 from miasm.expression.expression_helper import fast_unify
 import miasm.expression.expression as m2_expr
+from miasm.expression.expression import ExprVisitorCallbackBottomToTop
 
 # Expression Simplifier
 # ---------------------
@@ -22,7 +23,7 @@ log_exprsimp.addHandler(console_handler)
 log_exprsimp.setLevel(logging.WARNING)
 
 
-class ExpressionSimplifier(object):
+class ExpressionSimplifier(ExprVisitorCallbackBottomToTop):
 
     """Wrapper on expression simplification passes.
 
@@ -49,6 +50,8 @@ class ExpressionSimplifier(object):
             simplifications_common.simp_double_signext,
             simplifications_common.simp_zeroext_eq_cst,
             simplifications_common.simp_ext_eq_ext,
+            simplifications_common.simp_ext_cond_int,
+            simplifications_common.simp_sub_cf_zero,
 
             simplifications_common.simp_cmp_int,
             simplifications_common.simp_cmp_bijective_op,
@@ -118,8 +121,8 @@ class ExpressionSimplifier(object):
 
 
     def __init__(self):
+        super(ExpressionSimplifier, self).__init__(self.expr_simp_inner)
         self.expr_simp_cb = {}
-        self.simplified_exprs = set()
 
     def enable_passes(self, passes):
         """Add passes from @passes
@@ -129,7 +132,7 @@ class ExpressionSimplifier(object):
         """
 
         # Clear cache of simplifiied expressions when adding a new pass
-        self.simplified_exprs.clear()
+        self.cache.clear()
 
         for k, v in viewitems(passes):
             self.expr_simp_cb[k] = fast_unify(self.expr_simp_cb.get(k, []) + v)
@@ -156,46 +159,29 @@ class ExpressionSimplifier(object):
 
         return expression
 
-    def expr_simp(self, expression):
+    def expr_simp_inner(self, expression):
         """Apply enabled simplifications on expression and find a stable state
         @expression: Expr instance
         Return an Expr instance"""
 
-        if expression in self.simplified_exprs:
-            return expression
-
         # Find a stable state
         while True:
             # Canonize and simplify
-            e_new = self.apply_simp(expression.canonize())
-            if e_new == expression:
-                break
-
-            # Launch recursivity
-            expression = self.expr_simp_wrapper(e_new)
-            self.simplified_exprs.add(expression)
-        # Mark expression as simplified
-        self.simplified_exprs.add(e_new)
-
-        return e_new
-
-    def expr_simp_wrapper(self, expression, callback=None):
-        """Apply enabled simplifications on expression
-        @expression: Expr instance
-        @manual_callback: If set, call this function instead of normal one
-        Return an Expr instance"""
+            new_expr = self.apply_simp(expression.canonize())
+            if new_expr == expression:
+                return new_expr
+            # Run recursively simplification on fresh new expression
+            new_expr = self.visit(new_expr)
+            expression = new_expr
+        return new_expr
 
-        if expression in self.simplified_exprs:
-            return expression
-
-        if callback is None:
-            callback = self.expr_simp
-
-        return expression.visit(callback, lambda e: e not in self.simplified_exprs)
+    def expr_simp(self, expression):
+        "Call simplification recursively"
+        return self.visit(expression)
 
-    def __call__(self, expression, callback=None):
-        "Wrapper on expr_simp_wrapper"
-        return self.expr_simp_wrapper(expression, callback)
+    def __call__(self, expression):
+        "Call simplification recursively"
+        return self.visit(expression)
 
 
 # Public ExprSimplificationPass instance with commons passes