about summary refs log tree commit diff stats
path: root/miasm/expression/simplifications_common.py
diff options
context:
space:
mode:
Diffstat (limited to 'miasm/expression/simplifications_common.py')
-rw-r--r--miasm/expression/simplifications_common.py138
1 files changed, 88 insertions, 50 deletions
diff --git a/miasm/expression/simplifications_common.py b/miasm/expression/simplifications_common.py
index 1c0bb04c..932db49a 100644
--- a/miasm/expression/simplifications_common.py
+++ b/miasm/expression/simplifications_common.py
@@ -32,30 +32,30 @@ def simp_cst_propagation(e_s, expr):
             int2 = args.pop()
             int1 = args.pop()
             if op_name == '+':
-                out = int1.arg + int2.arg
+                out = mod_size2uint[int1.size](int(int1) + int(int2))
             elif op_name == '*':
-                out = int1.arg * int2.arg
+                out = mod_size2uint[int1.size](int(int1) * int(int2))
             elif op_name == '**':
-                out =int1.arg ** int2.arg
+                out = mod_size2uint[int1.size](int(int1) ** int(int2))
             elif op_name == '^':
-                out = int1.arg ^ int2.arg
+                out = mod_size2uint[int1.size](int(int1) ^ int(int2))
             elif op_name == '&':
-                out = int1.arg & int2.arg
+                out = mod_size2uint[int1.size](int(int1) & int(int2))
             elif op_name == '|':
-                out = int1.arg | int2.arg
+                out = mod_size2uint[int1.size](int(int1) | int(int2))
             elif op_name == '>>':
                 if int(int2) > int1.size:
                     out = 0
                 else:
-                    out = int1.arg >> int2.arg
+                    out = mod_size2uint[int1.size](int(int1) >> int(int2))
             elif op_name == '<<':
                 if int(int2) > int1.size:
                     out = 0
                 else:
-                    out = int1.arg << int2.arg
+                    out = mod_size2uint[int1.size](int(int1) << int(int2))
             elif op_name == 'a>>':
-                tmp1 = mod_size2int[int1.arg.size](int1.arg)
-                tmp2 = mod_size2uint[int2.arg.size](int2.arg)
+                tmp1 = mod_size2int[int1.size](int(int1))
+                tmp2 = mod_size2uint[int2.size](int(int2))
                 if tmp2 > int1.size:
                     is_signed = int(int1) & (1 << (int1.size - 1))
                     if is_signed:
@@ -63,55 +63,57 @@ def simp_cst_propagation(e_s, expr):
                     else:
                         out = 0
                 else:
-                    out = mod_size2uint[int1.arg.size](tmp1 >> tmp2)
+                    out = mod_size2uint[int1.size](tmp1 >> tmp2)
             elif op_name == '>>>':
-                shifter = int2.arg % int2.size
-                out = (int1.arg >> shifter) | (int1.arg << (int2.size - shifter))
+                shifter = int(int2) % int2.size
+                out = (int(int1) >> shifter) | (int(int1) << (int2.size - shifter))
             elif op_name == '<<<':
-                shifter = int2.arg % int2.size
-                out = (int1.arg << shifter) | (int1.arg >> (int2.size - shifter))
+                shifter = int(int2) % int2.size
+                out = (int(int1) << shifter) | (int(int1) >> (int2.size - shifter))
             elif op_name == '/':
-                out = int1.arg // int2.arg
+                assert int(int2), "division by 0"
+                out = int(int1) // int(int2)
             elif op_name == '%':
-                out = int1.arg % int2.arg
+                assert int(int2), "division by 0"
+                out = int(int1) % int(int2)
             elif op_name == 'sdiv':
-                assert int2.arg.arg
-                tmp1 = mod_size2int[int1.arg.size](int1.arg)
-                tmp2 = mod_size2int[int2.arg.size](int2.arg)
-                out = mod_size2uint[int1.arg.size](tmp1 // tmp2)
+                assert int(int2), "division by 0"
+                tmp1 = mod_size2int[int1.size](int(int1))
+                tmp2 = mod_size2int[int2.size](int(int2))
+                out = mod_size2uint[int1.size](tmp1 // tmp2)
             elif op_name == 'smod':
-                assert int2.arg.arg
-                tmp1 = mod_size2int[int1.arg.size](int1.arg)
-                tmp2 = mod_size2int[int2.arg.size](int2.arg)
-                out = mod_size2uint[int1.arg.size](tmp1 % tmp2)
+                assert int(int2), "division by 0"
+                tmp1 = mod_size2int[int1.size](int(int1))
+                tmp2 = mod_size2int[int2.size](int(int2))
+                out = mod_size2uint[int1.size](tmp1 % tmp2)
             elif op_name == 'umod':
-                assert int2.arg.arg
-                tmp1 = mod_size2uint[int1.arg.size](int1.arg)
-                tmp2 = mod_size2uint[int2.arg.size](int2.arg)
-                out = mod_size2uint[int1.arg.size](tmp1 % tmp2)
+                assert int(int2), "division by 0"
+                tmp1 = mod_size2uint[int1.size](int(int1))
+                tmp2 = mod_size2uint[int2.size](int(int2))
+                out = mod_size2uint[int1.size](tmp1 % tmp2)
             elif op_name == 'udiv':
-                assert int2.arg.arg
-                tmp1 = mod_size2uint[int1.arg.size](int1.arg)
-                tmp2 = mod_size2uint[int2.arg.size](int2.arg)
-                out = mod_size2uint[int1.arg.size](tmp1 // tmp2)
+                assert int(int2), "division by 0"
+                tmp1 = mod_size2uint[int1.size](int(int1))
+                tmp2 = mod_size2uint[int2.size](int(int2))
+                out = mod_size2uint[int1.size](tmp1 // tmp2)
 
 
 
-            args.append(ExprInt(out, int1.size))
+            args.append(ExprInt(int(out), int1.size))
 
     # cnttrailzeros(int) => int
     if op_name == "cnttrailzeros" and args[0].is_int():
         i = 0
-        while args[0].arg & (1 << i) == 0 and i < args[0].size:
+        while int(args[0]) & (1 << i) == 0 and i < args[0].size:
             i += 1
         return ExprInt(i, args[0].size)
 
     # cntleadzeros(int) => int
     if op_name == "cntleadzeros" and args[0].is_int():
-        if args[0].arg == 0:
+        if int(args[0]) == 0:
             return ExprInt(args[0].size, args[0].size)
         i = args[0].size - 1
-        while args[0].arg & (1 << i) == 0:
+        while int(args[0]) & (1 << i) == 0:
             i -= 1
         return ExprInt(expr.size - (i + 1), args[0].size)
 
@@ -120,6 +122,7 @@ def simp_cst_propagation(e_s, expr):
         len(args[0].args) == 1):
         return args[0].args[0]
 
+
     # -(int) => -int
     if op_name == '-' and len(args) == 1 and args[0].is_int():
         return ExprInt(-int(args[0]), expr.size)
@@ -207,13 +210,13 @@ def simp_cst_propagation(e_s, expr):
             j += 1
         i += 1
 
-    if op_name in ['|', '&', '%', '/', '**'] and len(args) == 1:
+    if op_name in ['+', '^', '|', '&', '%', '/', '**'] and len(args) == 1:
         return args[0]
 
     # A <<< A.size => A
     if (op_name in ['<<<', '>>>'] and
         args[1].is_int() and
-        args[1].arg == args[0].size):
+        int(args[1]) == args[0].size):
         return args[0]
 
     # (A <<< X) <<< Y => A <<< (X+Y) (or <<< >>>) if X + Y does not overflow
@@ -277,7 +280,10 @@ def simp_cst_propagation(e_s, expr):
 
     # ((A & A.mask)
     if op_name == "&" and args[-1] == expr.mask:
-        return ExprOp('&', *args[:-1])
+        args = args[:-1]
+        if len(args) == 1:
+            return args[0]
+        return ExprOp('&', *args)
 
     # ((A | A.mask)
     if op_name == "|" and args[-1] == expr.mask:
@@ -289,7 +295,7 @@ def simp_cst_propagation(e_s, expr):
     # ((A & mask) >> shift) with mask < 2**shift => 0
     if op_name == ">>" and args[1].is_int() and args[0].is_op("&"):
         if (args[0].args[1].is_int() and
-            2 ** args[1].arg > args[0].args[1].arg):
+            2 ** int(args[1]) > int(args[0].args[1])):
             return ExprInt(0, args[0].size)
 
     # parity(int) => int
@@ -315,7 +321,6 @@ def simp_cst_propagation(e_s, expr):
         args = args[0].args
         return ExprOp('*', *(list(args[:-1]) + [ExprInt(-int(args[-1]), expr.size)]))
 
-
     # A << int with A ExprCompose => move index
     if (op_name == "<<" and args[0].is_compose() and
         args[1].is_int() and int(args[1]) != 0):
@@ -450,8 +455,8 @@ def simp_cond_factor(e_s, expr):
     for cond, vals in viewitems(conds):
         new_src1 = [x.src1 for x in vals]
         new_src2 = [x.src2 for x in vals]
-        src1 = e_s.expr_simp_wrapper(ExprOp(expr.op, *new_src1))
-        src2 = e_s.expr_simp_wrapper(ExprOp(expr.op, *new_src2))
+        src1 = e_s.expr_simp(ExprOp(expr.op, *new_src1))
+        src2 = e_s.expr_simp(ExprOp(expr.op, *new_src2))
         c_out.append(ExprCond(cond, src1, src2))
 
     if len(c_out) == 1:
@@ -471,7 +476,7 @@ def simp_slice(e_s, expr):
     if expr.arg.is_int():
         total_bit = expr.stop - expr.start
         mask = (1 << (expr.stop - expr.start)) - 1
-        return ExprInt(int((expr.arg.arg >> expr.start) & mask), total_bit)
+        return ExprInt(int((int(expr.arg) >> expr.start) & mask), total_bit)
     # Slice(Slice(A, x), y) => Slice(A, z)
     if expr.arg.is_slice():
         if expr.stop - expr.start > expr.arg.stop - expr.arg.start:
@@ -521,7 +526,7 @@ def simp_slice(e_s, expr):
     # distributivity of slice and &
     # (a & int)[x:y] => 0 if int[x:y] == 0
     if expr.arg.is_op("&") and expr.arg.args[-1].is_int():
-        tmp = e_s.expr_simp_wrapper(expr.arg.args[-1][expr.start:expr.stop])
+        tmp = e_s.expr_simp(expr.arg.args[-1][expr.start:expr.stop])
         if tmp.is_int(0):
             return tmp
     # distributivity of slice and exprcond
@@ -536,7 +541,7 @@ def simp_slice(e_s, expr):
 
     # (a * int)[0:y] => (a[0:y] * int[0:y])
     if expr.start == 0 and expr.arg.is_op("*") and expr.arg.args[-1].is_int():
-        args = [e_s.expr_simp_wrapper(a[expr.start:expr.stop]) for a in expr.arg.args]
+        args = [e_s.expr_simp(a[expr.start:expr.stop]) for a in expr.arg.args]
         return ExprOp(expr.arg.op, *args)
 
     # (a >> int)[x:y] => a[x+int:y+int] with int+y <= a.size
@@ -626,7 +631,7 @@ def simp_cond(_, expr):
         expr = expr.src1
     # int ? A:B => A or B
     elif expr.cond.is_int():
-        if expr.cond.arg == 0:
+        if int(expr.cond) == 0:
             expr = expr.src2
         else:
             expr = expr.src1
@@ -646,8 +651,8 @@ def simp_cond(_, expr):
     elif (expr.cond.is_cond() and
           expr.cond.src1.is_int() and
           expr.cond.src2.is_int()):
-        int1 = expr.cond.src1.arg.arg
-        int2 = expr.cond.src2.arg.arg
+        int1 = int(expr.cond.src1)
+        int2 = int(expr.cond.src2)
         if int1 and int2:
             expr = expr.src1
         elif int1 == 0 and int2 == 0:
@@ -906,6 +911,15 @@ def simp_cond_flag(_, expr):
     return expr
 
 
+def simp_sub_cf_zero(_, expr):
+    """FLAG_SUB_CF(0, X) => (X)?1:0"""
+    if not expr.is_op("FLAG_SUB_CF"):
+        return expr
+    if not expr.args[0].is_int(0):
+        return expr
+    return ExprCond(expr.args[1], ExprInt(1, 1), ExprInt(0, 1))
+
+
 def simp_cmp_int(expr_simp, expr):
     """
     ({X, 0} == int) => X == int[:]
@@ -1069,6 +1083,13 @@ def simp_cmp_bijective_op(expr_simp, expr):
             args_a.remove(value)
             args_b.remove(value)
 
+    # a + b == a + b + c
+    if not args_a:
+        return ExprOp(TOK_EQUAL, ExprOp(op, *args_b), ExprInt(0, args_b[0].size))
+    # a + b + c == a + b
+    if not args_b:
+        return ExprOp(TOK_EQUAL, ExprOp(op, *args_a), ExprInt(0, args_a[0].size))
+    
     arg_a = ExprOp(op, *args_a)
     arg_b = ExprOp(op, *args_b)
     return ExprOp(TOK_EQUAL, arg_a, arg_b)
@@ -1362,6 +1383,23 @@ def simp_ext_cst(_, expr):
     return ret
 
 
+
+def simp_ext_cond_int(e_s, expr):
+    """
+    zeroExt(ExprCond(X, Int, Int)) => ExprCond(X, Int, Int)
+    """
+    if not (expr.op.startswith("zeroExt") or expr.op.startswith("signExt")):
+        return expr
+    arg = expr.args[0]
+    if not arg.is_cond():
+        return expr
+    if not (arg.src1.is_int() and arg.src2.is_int()):
+        return expr
+    src1 = ExprOp(expr.op, arg.src1)
+    src2 = ExprOp(expr.op, arg.src2)
+    return e_s(ExprCond(arg.cond, src1, src2))
+
+
 def simp_slice_of_ext(_, expr):
     """
     C.zeroExt(X)[A:B] => 0 if A >= size(C)