about summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--miasm2/expression/simplifications_common.py35
-rw-r--r--test/expression/simplifications.py45
2 files changed, 59 insertions, 21 deletions
diff --git a/miasm2/expression/simplifications_common.py b/miasm2/expression/simplifications_common.py
index 22e328e1..ccb97cb3 100644
--- a/miasm2/expression/simplifications_common.py
+++ b/miasm2/expression/simplifications_common.py
@@ -196,21 +196,34 @@ def simp_cst_propagation(e_s, expr):
         args[1].arg == args[0].size):
         return args[0]
 
-    # A <<< X <<< Y => A <<< (X+Y) (ou <<< >>>)
+    # (A <<< X) <<< Y => A <<< (X+Y) (or <<< >>>) if X + Y does not overflow
     if (op_name in ['<<<', '>>>'] and
         args[0].is_op() and
         args[0].op in ['<<<', '>>>']):
-        op1 = op_name
-        op2 = args[0].op
-        if op1 == op2:
-            op_name = op1
-            args1 = args[0].args[1] + args[1]
-        else:
-            op_name = op2
-            args1 = args[0].args[1] - args[1]
+        A = args[0].args[0]
+        X = args[0].args[1]
+        Y = args[1]
+        if op_name != args[0].op and e_s(X - Y) == ExprInt(0, X.size):
+            return args[0].args[0]
+        elif X.is_int() and Y.is_int():
+            new_X = int(X) % expr.size
+            new_Y = int(Y) % expr.size
+            if op_name == args[0].op:
+                rot = (new_X + new_Y) % expr.size
+                op = op_name
+            else:
+                rot = new_Y - new_X
+                op = op_name
+                if rot < 0:
+                    rot = - rot
+                    op = {">>>": "<<<", "<<<": ">>>"}[op_name]
+            args = [A, ExprInt(rot, expr.size)]
+            op_name = op
 
-        args0 = args[0].args[0]
-        args = [args0, args1]
+        else:
+            # Do not consider this case, too tricky (overflow on addition /
+            # substraction)
+            pass
 
     # A >> X >> Y  =>  A >> (X+Y) if X + Y does not overflow
     # To be sure, only consider the simplification when X.msb and Y.msb are 0
diff --git a/test/expression/simplifications.py b/test/expression/simplifications.py
index cb8dc4f8..2650d4d1 100644
--- a/test/expression/simplifications.py
+++ b/test/expression/simplifications.py
@@ -86,6 +86,13 @@ f = ExprId('f', size=64)
 b_msb_null = b[:31].zeroExtend(32)
 c_msb_null = c[:31].zeroExtend(32)
 
+a31 = ExprId('a31', 31)
+b31 = ExprId('b31', 31)
+c31 = ExprId('c31', 31)
+b31_msb_null = ExprId('b31', 31)[:30].zeroExtend(31)
+c31_msb_null = ExprId('c31', 31)[:30].zeroExtend(31)
+
+
 m = ExprMem(a)
 s = a[:8]
 
@@ -120,17 +127,35 @@ to_test = [(ExprInt(1, 32) - ExprInt(1, 32), ExprInt(0, 32)),
            (ExprOp('>>>', a, ExprInt(32, 32)), a),
            (ExprOp('>>>', a, ExprInt(0, 32)), a),
            (ExprOp('<<', a, ExprInt(0, 32)), a),
+           (ExprOp('<<<', a31, ExprInt(31, 31)), a31),
+           (ExprOp('>>>', a31, ExprInt(31, 31)), a31),
+           (ExprOp('>>>', a31, ExprInt(0, 31)), a31),
+           (ExprOp('<<', a31, ExprInt(0, 31)), a31),
+
+           (ExprOp('<<<', a31, ExprOp('<<<', b31, c31)),
+            ExprOp('<<<', a31, ExprOp('<<<', b31, c31))),
+           (ExprOp('<<<', ExprOp('>>>', a31, b31), c31),
+            ExprOp('<<<', ExprOp('>>>', a31, b31), c31)),
+           (ExprOp('>>>', ExprOp('<<<', a31, b31), c31),
+            ExprOp('>>>', ExprOp('<<<', a31, b31), c31)),
+           (ExprOp('>>>', ExprOp('<<<', a31, b31), b31),
+            a31),
+           (ExprOp('<<<', ExprOp('>>>', a31, b31), b31),
+            a31),
+           (ExprOp('>>>', ExprOp('>>>', a31, b31), b31),
+            ExprOp('>>>', ExprOp('>>>', a31, b31), b31)),
+           (ExprOp('<<<', ExprOp('<<<', a31, b31), b31),
+            ExprOp('<<<', ExprOp('<<<', a31, b31), b31)),
+
+           (ExprOp('>>>', ExprOp('<<<', a31, ExprInt(0x1234, 31)), ExprInt(0x1111, 31)),
+            ExprOp('>>>', a31, ExprInt(0x13, 31))),
+           (ExprOp('<<<', ExprOp('>>>', a31, ExprInt(0x1234, 31)), ExprInt(0x1111, 31)),
+            ExprOp('<<<', a31, ExprInt(0x13, 31))),
+           (ExprOp('>>>', ExprOp('<<<', a31, ExprInt(-1, 31)), ExprInt(0x1111, 31)),
+            ExprOp('>>>', a31, ExprInt(0x1c, 31))),
+           (ExprOp('<<<', ExprOp('>>>', a31, ExprInt(-1, 31)), ExprInt(0x1111, 31)),
+            ExprOp('<<<', a31, ExprInt(0x1c, 31))),
 
-           (ExprOp('<<<', a, ExprOp('<<<', b, c)),
-            ExprOp('<<<', a, ExprOp('<<<', b, c))),
-           (ExprOp('<<<', ExprOp('<<<', a, b), c),
-            ExprOp('<<<', a, (b+c))),
-           (ExprOp('<<<', ExprOp('>>>', a, b), c),
-            ExprOp('>>>', a, (b-c))),
-           (ExprOp('>>>', ExprOp('<<<', a, b), c),
-            ExprOp('<<<', a, (b-c))),
-           (ExprOp('>>>', ExprOp('<<<', a, b), b),
-            a),
            (ExprOp(">>>", ExprInt(0x1000, 16), ExprInt(0x11, 16)),
             ExprInt(0x800, 16)),
            (ExprOp("<<<", ExprInt(0x1000, 16), ExprInt(0x11, 16)),