about summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--miasm2/expression/simplifications_common.py20
-rw-r--r--test/expression/simplifications.py4
2 files changed, 24 insertions, 0 deletions
diff --git a/miasm2/expression/simplifications_common.py b/miasm2/expression/simplifications_common.py
index 13b25ce2..149c5b8d 100644
--- a/miasm2/expression/simplifications_common.py
+++ b/miasm2/expression/simplifications_common.py
@@ -250,6 +250,26 @@ def simp_cst_propagation(e_s, expr):
             e_s(Y.msb()) == ExprInt(0, 1)):
             args = [args[0].args[0], X + Y]
 
+    # ((var >> int1) << int1) => var & mask
+    # ((var << int1) >> int1) => var & mask
+    if (op_name in ['<<', '>>'] and
+        args[0].is_op() and
+        args[0].op in ['<<', '>>'] and
+        op_name != args[0]):
+        var = args[0].args[0]
+        int1 = args[0].args[1]
+        int2 = args[1]
+        if int1 == int2 and int1.is_int() and int(int1) < expr.size:
+            if op_name == '>>':
+                mask = ExprInt((1 << (expr.size - int(int1))) - 1, expr.size)
+            else:
+                mask = ExprInt(
+                    ((1 << int(int1)) - 1) ^ ((1 << expr.size) - 1),
+                    expr.size
+                )
+            ret = var & mask
+            return ret
+
     # ((A & A.mask)
     if op_name == "&" and args[-1] == expr.mask:
         return ExprOp('&', *args[:-1])
diff --git a/test/expression/simplifications.py b/test/expression/simplifications.py
index a4e839cf..b2591a83 100644
--- a/test/expression/simplifications.py
+++ b/test/expression/simplifications.py
@@ -177,6 +177,10 @@ to_test = [(ExprInt(1, 32) - ExprInt(1, 32), ExprInt(0, 32)),
            (ExprInt(0x4142, 32)[:32], ExprInt(0x4142, 32)),
            (ExprInt(0x4142, 32)[:8], ExprInt(0x42, 8)),
            (ExprInt(0x4142, 32)[8:16], ExprInt(0x41, 8)),
+           (ExprOp('>>', ExprOp('<<', a, ExprInt(0x4, 32)), ExprInt(0x4, 32)),
+            ExprOp('&', a, ExprInt(0x0FFFFFFF, 32))),
+           (ExprOp('<<', ExprOp('>>', a, ExprInt(0x4, 32)), ExprInt(0x4, 32)),
+            ExprOp('&', a, ExprInt(0xFFFFFFF0, 32))),
            (a[:32], a),
            (a[:8][:8], a[:8]),
            (a[:16][:8], a[:8]),