about summary refs log tree commit diff stats
path: root/miasm2/expression/simplifications_common.py
diff options
context:
space:
mode:
Diffstat (limited to 'miasm2/expression/simplifications_common.py')
-rw-r--r--miasm2/expression/simplifications_common.py75
1 files changed, 75 insertions, 0 deletions
diff --git a/miasm2/expression/simplifications_common.py b/miasm2/expression/simplifications_common.py
index 7db4e819..7bdfd33b 100644
--- a/miasm2/expression/simplifications_common.py
+++ b/miasm2/expression/simplifications_common.py
@@ -308,6 +308,12 @@ def simp_cst_propagation(e_s, expr):
             return -ExprOp(op_name, *new_args)
         args = new_args
 
+    # -(a * b * int) => a * b * (-int)
+    if op_name == "-" and args[0].is_op('*') and args[0].args[-1].is_int():
+        args = args[0].args
+        return ExprOp('*', *(list(args[:-1]) + [ExprInt(-int(args[-1]), expr.size)]))
+
+
     # A << int with A ExprCompose => move index
     if (op_name == "<<" and args[0].is_compose() and
         args[1].is_int() and int(args[1]) != 0):
@@ -1138,3 +1144,72 @@ def simp_slice_of_ext(expr_s, expr):
     if arg.size != expr.size:
         return expr
     return arg
+
+def simp_add_multiple(expr_s, expr):
+    # X + X => 2 * X
+    # X + X * int1 => X * (1 + int1)
+    # X * int1 + (- X) => X * (int1 - 1)
+    # X + (X << int1) => X * (1 + 2 ** int1)
+    # Correct even if addition overflow/underflow
+    if not expr.is_op('+'):
+        return expr
+
+    # Extract each argument and its counter
+    operands = {}
+    for i, arg in enumerate(expr.args):
+        if arg.is_op('*') and arg.args[1].is_int():
+            base_expr, factor = arg.args
+            operands[base_expr] = operands.get(base_expr, 0) + int(factor)
+        elif arg.is_op('<<') and arg.args[1].is_int():
+            base_expr, factor = arg.args
+            operands[base_expr] = operands.get(base_expr, 0) + 2 ** int(factor)
+        elif arg.is_op("-"):
+            arg = arg.args[0]
+            if arg.is_op('<<') and arg.args[1].is_int():
+                base_expr, factor = arg.args
+                operands[base_expr] = operands.get(base_expr, 0) - (2 ** int(factor))
+            else:
+                operands[arg] = operands.get(arg, 0) - 1
+        else:
+            operands[arg] = operands.get(arg, 0) + 1
+    out = []
+
+    # Best effort to factor common args:
+    # (a + b) * 3 + a + b => (a + b) * 4
+    # Does not factor:
+    # (a + b) * 3 + 2 * a + b => (a + b) * 4 + a
+    modified = True
+    while modified:
+        modified = False
+        for arg, count in operands.iteritems():
+            if not arg.is_op('+'):
+                continue
+            components = arg.args
+            if not all(component in operands for component in components):
+                continue
+            counters = set(operands[component] for component in components)
+            if len(counters) != 1:
+                continue
+            counter = counters.pop()
+            for component in components:
+                del operands[component]
+            operands[arg] += counter
+            modified = True
+            break
+
+    for arg, count in operands.iteritems():
+        if count == 0:
+            continue
+        if count == 1:
+            out.append(arg)
+            continue
+        out.append(arg * ExprInt(count, expr.size))
+
+    if len(out) == len(expr.args):
+        # No reductions
+        return expr
+    if not out:
+        return ExprInt(0, expr.size)
+    if len(out) == 1:
+        return out[0]
+    return ExprOp('+', *out)