about summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--miasm2/expression/simplifications.py1
-rw-r--r--miasm2/expression/simplifications_common.py75
-rw-r--r--test/expression/simplifications.py20
3 files changed, 95 insertions, 1 deletions
diff --git a/miasm2/expression/simplifications.py b/miasm2/expression/simplifications.py
index a237a57e..3f50fc1a 100644
--- a/miasm2/expression/simplifications.py
+++ b/miasm2/expression/simplifications.py
@@ -37,6 +37,7 @@ class ExpressionSimplifier(object):
             simplifications_common.simp_cst_propagation,
             simplifications_common.simp_cond_op_int,
             simplifications_common.simp_cond_factor,
+            simplifications_common.simp_add_multiple,
             # CC op
             simplifications_common.simp_cc_conds,
             simplifications_common.simp_subwc_cf,
diff --git a/miasm2/expression/simplifications_common.py b/miasm2/expression/simplifications_common.py
index 7db4e819..7bdfd33b 100644
--- a/miasm2/expression/simplifications_common.py
+++ b/miasm2/expression/simplifications_common.py
@@ -308,6 +308,12 @@ def simp_cst_propagation(e_s, expr):
             return -ExprOp(op_name, *new_args)
         args = new_args
 
+    # -(a * b * int) => a * b * (-int)
+    if op_name == "-" and args[0].is_op('*') and args[0].args[-1].is_int():
+        args = args[0].args
+        return ExprOp('*', *(list(args[:-1]) + [ExprInt(-int(args[-1]), expr.size)]))
+
+
     # A << int with A ExprCompose => move index
     if (op_name == "<<" and args[0].is_compose() and
         args[1].is_int() and int(args[1]) != 0):
@@ -1138,3 +1144,72 @@ def simp_slice_of_ext(expr_s, expr):
     if arg.size != expr.size:
         return expr
     return arg
+
+def simp_add_multiple(expr_s, expr):
+    # X + X => 2 * X
+    # X + X * int1 => X * (1 + int1)
+    # X * int1 + (- X) => X * (int1 - 1)
+    # X + (X << int1) => X * (1 + 2 ** int1)
+    # Correct even if addition overflow/underflow
+    if not expr.is_op('+'):
+        return expr
+
+    # Extract each argument and its counter
+    operands = {}
+    for i, arg in enumerate(expr.args):
+        if arg.is_op('*') and arg.args[1].is_int():
+            base_expr, factor = arg.args
+            operands[base_expr] = operands.get(base_expr, 0) + int(factor)
+        elif arg.is_op('<<') and arg.args[1].is_int():
+            base_expr, factor = arg.args
+            operands[base_expr] = operands.get(base_expr, 0) + 2 ** int(factor)
+        elif arg.is_op("-"):
+            arg = arg.args[0]
+            if arg.is_op('<<') and arg.args[1].is_int():
+                base_expr, factor = arg.args
+                operands[base_expr] = operands.get(base_expr, 0) - (2 ** int(factor))
+            else:
+                operands[arg] = operands.get(arg, 0) - 1
+        else:
+            operands[arg] = operands.get(arg, 0) + 1
+    out = []
+
+    # Best effort to factor common args:
+    # (a + b) * 3 + a + b => (a + b) * 4
+    # Does not factor:
+    # (a + b) * 3 + 2 * a + b => (a + b) * 4 + a
+    modified = True
+    while modified:
+        modified = False
+        for arg, count in operands.iteritems():
+            if not arg.is_op('+'):
+                continue
+            components = arg.args
+            if not all(component in operands for component in components):
+                continue
+            counters = set(operands[component] for component in components)
+            if len(counters) != 1:
+                continue
+            counter = counters.pop()
+            for component in components:
+                del operands[component]
+            operands[arg] += counter
+            modified = True
+            break
+
+    for arg, count in operands.iteritems():
+        if count == 0:
+            continue
+        if count == 1:
+            out.append(arg)
+            continue
+        out.append(arg * ExprInt(count, expr.size))
+
+    if len(out) == len(expr.args):
+        # No reductions
+        return expr
+    if not out:
+        return ExprInt(0, expr.size)
+    if len(out) == 1:
+        return out[0]
+    return ExprOp('+', *out)
diff --git a/test/expression/simplifications.py b/test/expression/simplifications.py
index 741d6adb..57895510 100644
--- a/test/expression/simplifications.py
+++ b/test/expression/simplifications.py
@@ -97,7 +97,10 @@ s = a[:8]
 i0 = ExprInt(0, 32)
 i1 = ExprInt(1, 32)
 i2 = ExprInt(2, 32)
+i3 = ExprInt(3, 32)
 im1 = ExprInt(-1, 32)
+im2 = ExprInt(-2, 32)
+
 icustom = ExprInt(0x12345678, 32)
 cc = ExprCond(a, b, c)
 
@@ -242,7 +245,7 @@ to_test = [(ExprInt(1, 32) - ExprInt(1, 32), ExprInt(0, 32)),
     (ExprOp('*', -a, -b, c, ExprInt(0x12, 32)),
      ExprOp('*', a, b, c, ExprInt(0x12, 32))),
     (ExprOp('*', -a, -b, -c, ExprInt(0x12, 32)),
-     - ExprOp('*', a, b, c, ExprInt(0x12, 32))),
+     ExprOp('*', a, b, c, ExprInt(-0x12, 32))),
     (a | ExprInt(0xffffffff, 32),
      ExprInt(0xffffffff, 32)),
     (ExprCond(a, ExprInt(1, 32), ExprInt(2, 32)) * ExprInt(4, 32),
@@ -443,6 +446,21 @@ to_test = [(ExprInt(1, 32) - ExprInt(1, 32), ExprInt(0, 32)),
     (ExprOp("signExt_16", ExprInt(0x8, 8)), ExprInt(0x8, 16)),
     (ExprOp("signExt_16", ExprInt(-0x8, 8)), ExprInt(-0x8, 16)),
 
+    (- (i2*a), a * im2),
+    (a + a, a * i2),
+    (ExprOp('+', a, a), a * i2),
+    (ExprOp('+', a, a, a), a * i3),
+    ((a<<i1) - a, a),
+    ((a<<i1) - (a<<i2), a*im2),
+    ((a<<i1) - a - a, i0),
+    ((a<<i2) - (a<<i1) - (a<<i1), i0),
+    ((a<<i2) - a*i3, a),
+    (((a+b) * i3) - (a + b), (a+b) * i2),
+    (((a+b) * i2) + a + b, (a+b) * i3),
+    (((a+b) * i3) - a - b, (a+b) * i2),
+    (((a+b) * i2) - a - b, a+b),
+    (((a+b) * i2) - i2 * a - i2 * b, i0),
+
 
 ]