about summary refs log tree commit diff stats
path: root/miasm2/expression
diff options
context:
space:
mode:
Diffstat (limited to 'miasm2/expression')
-rw-r--r--miasm2/expression/expression_helper.py10
-rw-r--r--miasm2/expression/simplifications.py14
-rw-r--r--miasm2/expression/simplifications_common.py58
3 files changed, 62 insertions, 20 deletions
diff --git a/miasm2/expression/expression_helper.py b/miasm2/expression/expression_helper.py
index 1e718faa..722d169d 100644
--- a/miasm2/expression/expression_helper.py
+++ b/miasm2/expression/expression_helper.py
@@ -21,6 +21,7 @@ import itertools
 import collections
 import random
 import string
+import warnings
 
 import miasm2.expression.expression as m2_expr
 
@@ -468,16 +469,14 @@ class ExprRandom(object):
 
         return got
 
-def _expr_cmp_gen(arg1, arg2):
-    return (arg2 - arg1) ^ ((arg2 ^ arg1) & ((arg2 - arg1) ^ arg2))
-
 def expr_cmpu(arg1, arg2):
     """
     Returns a one bit long Expression:
     * 1 if @arg1 is strictly greater than @arg2 (unsigned)
     * 0 otherwise.
     """
-    return (_expr_cmp_gen(arg1, arg2) ^ arg2 ^ arg1).msb()
+    warnings.warn('DEPRECATION WARNING: use "expr_is_unsigned_greater" instead"')
+    return m2_expr.expr_is_unsigned_greater(arg1, arg2)
 
 def expr_cmps(arg1, arg2):
     """
@@ -485,7 +484,8 @@ def expr_cmps(arg1, arg2):
     * 1 if @arg1 is strictly greater than @arg2 (signed)
     * 0 otherwise.
     """
-    return _expr_cmp_gen(arg1, arg2).msb()
+    warnings.warn('DEPRECATION WARNING: use "expr_is_signed_greater" instead"')
+    return m2_expr.expr_is_signed_greater(arg1, arg2)
 
 
 class CondConstraint(object):
diff --git a/miasm2/expression/simplifications.py b/miasm2/expression/simplifications.py
index d3483d9e..e6c5dc54 100644
--- a/miasm2/expression/simplifications.py
+++ b/miasm2/expression/simplifications.py
@@ -2,6 +2,8 @@
 #                     Simplification methods library                           #
 #                                                                              #
 
+import logging
+
 from miasm2.expression import simplifications_common
 from miasm2.expression import simplifications_cond
 from miasm2.expression.expression_helper import fast_unify
@@ -10,6 +12,12 @@ import miasm2.expression.expression as m2_expr
 # Expression Simplifier
 # ---------------------
 
+log_exprsimp = logging.getLogger("exprsimp")
+console_handler = logging.StreamHandler()
+console_handler.setFormatter(logging.Formatter("%(levelname)-5s: %(message)s"))
+log_exprsimp.addHandler(console_handler)
+log_exprsimp.setLevel(logging.WARNING)
+
 
 class ExpressionSimplifier(object):
 
@@ -67,9 +75,15 @@ class ExpressionSimplifier(object):
         Return an Expr instance"""
 
         cls = expression.__class__
+        debug_level = log_exprsimp.level >= logging.DEBUG
         for simp_func in self.expr_simp_cb.get(cls, []):
             # Apply simplifications
+            before = expression
             expression = simp_func(self, expression)
+            after = expression
+
+            if debug_level and before != after:
+                log_exprsimp.debug("[%s] %s => %s", simp_func, before, after)
 
             # If class changes, stop to prevent wrong simplifications
             if expression.__class__ is not cls:
diff --git a/miasm2/expression/simplifications_common.py b/miasm2/expression/simplifications_common.py
index 02b43c4b..ccb97cb3 100644
--- a/miasm2/expression/simplifications_common.py
+++ b/miasm2/expression/simplifications_common.py
@@ -196,26 +196,44 @@ def simp_cst_propagation(e_s, expr):
         args[1].arg == args[0].size):
         return args[0]
 
-    # A <<< X <<< Y => A <<< (X+Y) (ou <<< >>>)
+    # (A <<< X) <<< Y => A <<< (X+Y) (or <<< >>>) if X + Y does not overflow
     if (op_name in ['<<<', '>>>'] and
         args[0].is_op() and
         args[0].op in ['<<<', '>>>']):
-        op1 = op_name
-        op2 = args[0].op
-        if op1 == op2:
-            op_name = op1
-            args1 = args[0].args[1] + args[1]
-        else:
-            op_name = op2
-            args1 = args[0].args[1] - args[1]
+        A = args[0].args[0]
+        X = args[0].args[1]
+        Y = args[1]
+        if op_name != args[0].op and e_s(X - Y) == ExprInt(0, X.size):
+            return args[0].args[0]
+        elif X.is_int() and Y.is_int():
+            new_X = int(X) % expr.size
+            new_Y = int(Y) % expr.size
+            if op_name == args[0].op:
+                rot = (new_X + new_Y) % expr.size
+                op = op_name
+            else:
+                rot = new_Y - new_X
+                op = op_name
+                if rot < 0:
+                    rot = - rot
+                    op = {">>>": "<<<", "<<<": ">>>"}[op_name]
+            args = [A, ExprInt(rot, expr.size)]
+            op_name = op
 
-        args0 = args[0].args[0]
-        args = [args0, args1]
+        else:
+            # Do not consider this case, too tricky (overflow on addition /
+            # substraction)
+            pass
 
-    # A >> X >> Y  =>  A >> (X+Y)
+    # A >> X >> Y  =>  A >> (X+Y) if X + Y does not overflow
+    # To be sure, only consider the simplification when X.msb and Y.msb are 0
     if (op_name in ['<<', '>>'] and
         args[0].is_op(op_name)):
-        args = [args[0].args[0], args[0].args[1] + args[1]]
+        X = args[0].args[1]
+        Y = args[1]
+        if (e_s(X.msb()) == ExprInt(0, 1) and
+            e_s(Y.msb()) == ExprInt(0, 1)):
+            args = [args[0].args[0], X + Y]
 
     # ((A & A.mask)
     if op_name == "&" and args[-1] == expr.mask:
@@ -327,7 +345,7 @@ def simp_cond_op_int(e_s, expr):
     "Extract conditions from operations"
 
 
-    # x?a:b + x?c:d + e => x?(a+b+e:c+d+e)
+    # x?a:b + x?c:d + e => x?(a+c+e:b+d+e)
     if not expr.op in ["+", "|", "^", "&", "*", '<<', '>>', 'a>>']:
         return expr
     if len(expr.args) < 2:
@@ -360,6 +378,14 @@ def simp_cond_factor(e_s, expr):
         return expr
     if len(expr.args) < 2:
         return expr
+
+    if expr.op in ['>>', '<<', 'a>>']:
+        assert len(expr.args) == 2
+
+    # Note: the following code is correct for non-commutative operation only if
+    # there is 2 arguments. Otherwise, the order is not conserved
+
+    # Regroup sub-expression by similar conditions
     conds = {}
     not_conds = []
     multi_cond = False
@@ -375,7 +401,9 @@ def simp_cond_factor(e_s, expr):
         conds[cond].append(arg)
     if not multi_cond:
         return expr
-    c_out = not_conds[:]
+
+    # Rebuild the new expression
+    c_out = not_conds
     for cond, vals in conds.items():
         new_src1 = [x.src1 for x in vals]
         new_src2 = [x.src2 for x in vals]