diff options
| -rw-r--r-- | miasm2/expression/simplifications.py | 1 | ||||
| -rw-r--r-- | miasm2/expression/simplifications_common.py | 75 | ||||
| -rw-r--r-- | test/expression/simplifications.py | 20 |
3 files changed, 95 insertions, 1 deletions
diff --git a/miasm2/expression/simplifications.py b/miasm2/expression/simplifications.py index a237a57e..3f50fc1a 100644 --- a/miasm2/expression/simplifications.py +++ b/miasm2/expression/simplifications.py @@ -37,6 +37,7 @@ class ExpressionSimplifier(object): simplifications_common.simp_cst_propagation, simplifications_common.simp_cond_op_int, simplifications_common.simp_cond_factor, + simplifications_common.simp_add_multiple, # CC op simplifications_common.simp_cc_conds, simplifications_common.simp_subwc_cf, diff --git a/miasm2/expression/simplifications_common.py b/miasm2/expression/simplifications_common.py index 7db4e819..7bdfd33b 100644 --- a/miasm2/expression/simplifications_common.py +++ b/miasm2/expression/simplifications_common.py @@ -308,6 +308,12 @@ def simp_cst_propagation(e_s, expr): return -ExprOp(op_name, *new_args) args = new_args + # -(a * b * int) => a * b * (-int) + if op_name == "-" and args[0].is_op('*') and args[0].args[-1].is_int(): + args = args[0].args + return ExprOp('*', *(list(args[:-1]) + [ExprInt(-int(args[-1]), expr.size)])) + + # A << int with A ExprCompose => move index if (op_name == "<<" and args[0].is_compose() and args[1].is_int() and int(args[1]) != 0): @@ -1138,3 +1144,72 @@ def simp_slice_of_ext(expr_s, expr): if arg.size != expr.size: return expr return arg + +def simp_add_multiple(expr_s, expr): + # X + X => 2 * X + # X + X * int1 => X * (1 + int1) + # X * int1 + (- X) => X * (int1 - 1) + # X + (X << int1) => X * (1 + 2 ** int1) + # Correct even if addition overflow/underflow + if not expr.is_op('+'): + return expr + + # Extract each argument and its counter + operands = {} + for i, arg in enumerate(expr.args): + if arg.is_op('*') and arg.args[1].is_int(): + base_expr, factor = arg.args + operands[base_expr] = operands.get(base_expr, 0) + int(factor) + elif arg.is_op('<<') and arg.args[1].is_int(): + base_expr, factor = arg.args + operands[base_expr] = operands.get(base_expr, 0) + 2 ** int(factor) + elif arg.is_op("-"): + arg = arg.args[0] + if arg.is_op('<<') and arg.args[1].is_int(): + base_expr, factor = arg.args + operands[base_expr] = operands.get(base_expr, 0) - (2 ** int(factor)) + else: + operands[arg] = operands.get(arg, 0) - 1 + else: + operands[arg] = operands.get(arg, 0) + 1 + out = [] + + # Best effort to factor common args: + # (a + b) * 3 + a + b => (a + b) * 4 + # Does not factor: + # (a + b) * 3 + 2 * a + b => (a + b) * 4 + a + modified = True + while modified: + modified = False + for arg, count in operands.iteritems(): + if not arg.is_op('+'): + continue + components = arg.args + if not all(component in operands for component in components): + continue + counters = set(operands[component] for component in components) + if len(counters) != 1: + continue + counter = counters.pop() + for component in components: + del operands[component] + operands[arg] += counter + modified = True + break + + for arg, count in operands.iteritems(): + if count == 0: + continue + if count == 1: + out.append(arg) + continue + out.append(arg * ExprInt(count, expr.size)) + + if len(out) == len(expr.args): + # No reductions + return expr + if not out: + return ExprInt(0, expr.size) + if len(out) == 1: + return out[0] + return ExprOp('+', *out) diff --git a/test/expression/simplifications.py b/test/expression/simplifications.py index 741d6adb..57895510 100644 --- a/test/expression/simplifications.py +++ b/test/expression/simplifications.py @@ -97,7 +97,10 @@ s = a[:8] i0 = ExprInt(0, 32) i1 = ExprInt(1, 32) i2 = ExprInt(2, 32) +i3 = ExprInt(3, 32) im1 = ExprInt(-1, 32) +im2 = ExprInt(-2, 32) + icustom = ExprInt(0x12345678, 32) cc = ExprCond(a, b, c) @@ -242,7 +245,7 @@ to_test = [(ExprInt(1, 32) - ExprInt(1, 32), ExprInt(0, 32)), (ExprOp('*', -a, -b, c, ExprInt(0x12, 32)), ExprOp('*', a, b, c, ExprInt(0x12, 32))), (ExprOp('*', -a, -b, -c, ExprInt(0x12, 32)), - - ExprOp('*', a, b, c, ExprInt(0x12, 32))), + ExprOp('*', a, b, c, ExprInt(-0x12, 32))), (a | ExprInt(0xffffffff, 32), ExprInt(0xffffffff, 32)), (ExprCond(a, ExprInt(1, 32), ExprInt(2, 32)) * ExprInt(4, 32), @@ -443,6 +446,21 @@ to_test = [(ExprInt(1, 32) - ExprInt(1, 32), ExprInt(0, 32)), (ExprOp("signExt_16", ExprInt(0x8, 8)), ExprInt(0x8, 16)), (ExprOp("signExt_16", ExprInt(-0x8, 8)), ExprInt(-0x8, 16)), + (- (i2*a), a * im2), + (a + a, a * i2), + (ExprOp('+', a, a), a * i2), + (ExprOp('+', a, a, a), a * i3), + ((a<<i1) - a, a), + ((a<<i1) - (a<<i2), a*im2), + ((a<<i1) - a - a, i0), + ((a<<i2) - (a<<i1) - (a<<i1), i0), + ((a<<i2) - a*i3, a), + (((a+b) * i3) - (a + b), (a+b) * i2), + (((a+b) * i2) + a + b, (a+b) * i3), + (((a+b) * i3) - a - b, (a+b) * i2), + (((a+b) * i2) - a - b, a+b), + (((a+b) * i2) - i2 * a - i2 * b, i0), + ] |