diff options
| -rw-r--r-- | miasm2/expression/simplifications.py | 31 | ||||
| -rw-r--r-- | miasm2/expression/simplifications_common.py | 132 | ||||
| -rw-r--r-- | miasm2/expression/simplifications_cond.py | 54 | ||||
| -rw-r--r-- | miasm2/ir/translators/z3_ir.py | 30 | ||||
| -rw-r--r-- | test/expression/simplifications.py | 63 |
5 files changed, 218 insertions, 92 deletions
diff --git a/miasm2/expression/simplifications.py b/miasm2/expression/simplifications.py index 9114cbbe..a237a57e 100644 --- a/miasm2/expression/simplifications.py +++ b/miasm2/expression/simplifications.py @@ -45,16 +45,23 @@ class ExpressionSimplifier(object): simplifications_common.simp_double_zeroext, simplifications_common.simp_double_signext, simplifications_common.simp_zeroext_eq_cst, + simplifications_common.simp_ext_eq_ext, + + simplifications_common.simp_cmp_int, + simplifications_common.simp_cmp_int_int, + simplifications_common.simp_ext_cst, ], - m2_expr.ExprSlice: [simplifications_common.simp_slice], + m2_expr.ExprSlice: [ + simplifications_common.simp_slice, + simplifications_common.simp_slice_of_ext, + ], m2_expr.ExprCompose: [simplifications_common.simp_compose], m2_expr.ExprCond: [ simplifications_common.simp_cond, # CC op simplifications_common.simp_cond_flag, - simplifications_common.simp_cond_int, simplifications_common.simp_cmp_int_arg, simplifications_common.simp_cond_eq_zero, @@ -68,14 +75,18 @@ class ExpressionSimplifier(object): PASS_HEAVY = {} # Cond passes - PASS_COND = {m2_expr.ExprSlice: [simplifications_cond.expr_simp_inf_signed, - simplifications_cond.expr_simp_inf_unsigned_inversed], - m2_expr.ExprOp: [simplifications_cond.exec_inf_unsigned, - simplifications_cond.exec_inf_signed, - simplifications_cond.expr_simp_inverse, - simplifications_cond.exec_equal], - m2_expr.ExprCond: [simplifications_cond.expr_simp_equal] - } + PASS_COND = { + m2_expr.ExprSlice: [ + simplifications_cond.expr_simp_inf_signed, + simplifications_cond.expr_simp_inf_unsigned_inversed + ], + m2_expr.ExprOp: [ + simplifications_cond.expr_simp_inverse, + ], + m2_expr.ExprCond: [ + simplifications_cond.expr_simp_equal + ] + } # Available passes lists are: diff --git a/miasm2/expression/simplifications_common.py b/miasm2/expression/simplifications_common.py index 9a59fbd4..7db4e819 100644 --- a/miasm2/expression/simplifications_common.py +++ b/miasm2/expression/simplifications_common.py @@ -884,35 +884,34 @@ def simp_cond_flag(expr_simp, expr): return expr -def simp_cond_int(expr_simp, expr): - if (expr.cond.is_op(TOK_EQUAL) and - expr.cond.args[1].is_int() and - expr.cond.args[0].is_compose() and - len(expr.cond.args[0].args) == 2 and - expr.cond.args[0].args[1].is_int(0)): +def simp_cmp_int(expr_simp, expr): + # ({X, 0} == int) => X == int[:] + # X + int1 == int2 => X == int2-int1 + if (expr.is_op(TOK_EQUAL) and + expr.args[1].is_int() and + expr.args[0].is_compose() and + len(expr.args[0].args) == 2 and + expr.args[0].args[1].is_int(0)): # ({X, 0} == int) => X == int[:] - src = expr.cond.args[0].args[0] - int_val = int(expr.cond.args[1]) + src = expr.args[0].args[0] + int_val = int(expr.args[1]) new_int = ExprInt(int_val, src.size) expr = expr_simp( - ExprCond( - ExprOp(TOK_EQUAL, src, new_int), - expr.src1, - expr.src2) + ExprOp(TOK_EQUAL, src, new_int) ) - elif (expr.cond.is_op() and - expr.cond.op in [ + elif (expr.is_op() and + expr.op in [ TOK_EQUAL, - TOK_INF_SIGNED, - TOK_INF_EQUAL_SIGNED, - TOK_INF_UNSIGNED, - TOK_INF_EQUAL_UNSIGNED ] and - expr.cond.args[1].is_int() and - expr.cond.args[0].is_op("+") and - expr.cond.args[0].args[-1].is_int()): + expr.args[1].is_int() and + expr.args[0].is_op("+") and + expr.args[0].args[-1].is_int()): # X + int1 == int2 => X == int2-int1 - left, right = expr.cond.args + # WARNING: + # X - 0x10 <=u 0x20 gives X in [0x10 0x30] + # which is not equivalet to A <=u 0x10 + + left, right = expr.args left, int_diff = left.args[:-1], left.args[-1] if len(left) == 1: left = left[0] @@ -920,10 +919,7 @@ def simp_cond_int(expr_simp, expr): left = ExprOp('+', *left) new_int = expr_simp(right - int_diff) expr = expr_simp( - ExprCond( - ExprOp(expr.cond.op, left, new_int), - expr.src1, - expr.src2) + ExprOp(expr.op, left, new_int), ) return expr @@ -1047,6 +1043,20 @@ def simp_zeroext_eq_cst(expr_s, expr): return ExprInt(0, 1) return ExprOp(TOK_EQUAL, src, ExprInt(int(arg2), src.size)) +def simp_ext_eq_ext(expr_s, expr): + # A.zeroExt(X) == B.zeroExt(X) => A == B + # A.signExt(X) == B.signExt(X) => A == B + if not expr.is_op(TOK_EQUAL): + return expr + arg1, arg2 = expr.args + if (not ((arg1.is_op() and arg1.op.startswith("zeroExt") and + arg2.is_op() and arg2.op.startswith("zeroExt")) or + (arg1.is_op() and arg1.op.startswith("signExt") and + arg2.is_op() and arg2.op.startswith("signExt")))): + return expr + if arg1.args[0].size != arg2.args[0].size: + return expr + return ExprOp(TOK_EQUAL, arg1.args[0], arg2.args[0]) def simp_cond_eq_zero(expr_s, expr): # (X == 0)?(A:B) => X?(B:A) @@ -1058,3 +1068,73 @@ def simp_cond_eq_zero(expr_s, expr): return expr new_expr = ExprCond(arg1, expr.src2, expr.src1) return new_expr + + +def simp_cmp_int_int(expr_s, expr): + # IntA <s IntB => int + # IntA <u IntB => int + # IntA <=s IntB => int + # IntA <=u IntB => int + # IntA == IntB => int + if expr.op not in [ + TOK_EQUAL, + TOK_INF_SIGNED, TOK_INF_UNSIGNED, + TOK_INF_EQUAL_SIGNED, TOK_INF_EQUAL_UNSIGNED, + ]: + return expr + if not all(arg.is_int() for arg in expr.args): + return expr + int_a, int_b = expr.args + if expr.is_op(TOK_EQUAL): + if int_a == int_b: + return ExprInt(1, 1) + else: + return ExprInt(0, 1) + + if expr.op in [TOK_INF_SIGNED, TOK_INF_EQUAL_SIGNED]: + int_a = int(mod_size2int[int_a.size](int(int_a))) + int_b = int(mod_size2int[int_b.size](int(int_b))) + else: + int_a = int(mod_size2uint[int_a.size](int(int_a))) + int_b = int(mod_size2uint[int_b.size](int(int_b))) + + if expr.op in [TOK_INF_SIGNED, TOK_INF_UNSIGNED]: + ret = int_a < int_b + else: + ret = int_a <= int_b + + if ret: + ret = 1 + else: + ret = 0 + return ExprInt(ret, 1) + + +def simp_ext_cst(expr_s, expr): + # Int.zeroExt(X) => Int + # Int.signExt(X) => Int + if not (expr.op.startswith("zeroExt") or expr.op.startswith("signExt")): + return expr + arg = expr.args[0] + if not arg.is_int(): + return expr + if expr.op.startswith("zeroExt"): + ret = int(arg) + else: + ret = int(mod_size2int[arg.size](int(arg))) + ret = ExprInt(ret, expr.size) + return ret + + +def simp_slice_of_ext(expr_s, expr): + # zeroExt(X)[0:size(X)] => X + if expr.start != 0: + return expr + if not expr.arg.is_op(): + return expr + if not expr.arg.op.startswith("zeroExt"): + return expr + arg = expr.arg.args[0] + if arg.size != expr.size: + return expr + return arg diff --git a/miasm2/expression/simplifications_cond.py b/miasm2/expression/simplifications_cond.py index 6bdc810f..f6b1ea8b 100644 --- a/miasm2/expression/simplifications_cond.py +++ b/miasm2/expression/simplifications_cond.py @@ -176,57 +176,3 @@ def expr_simp_equal(expr_simp, e): return e return ExprOp_equal(r[jok1], expr_simp(-r[jok2])) - -# Compute conditions - -def exec_inf_unsigned(expr_simp, e): - "Compute x <u y" - if e.op != m2_expr.TOK_INF_UNSIGNED: - return e - - arg1, arg2 = e.args - - if isinstance(arg1, m2_expr.ExprInt) and isinstance(arg2, m2_expr.ExprInt): - return m2_expr.ExprInt(1, 1) if (arg1.arg < arg2.arg) else m2_expr.ExprInt(0, 1) - else: - return e - - -def __comp_signed(arg1, arg2): - """Return ExprInt(1, 1) if arg1 <s arg2 else ExprInt(0, 1) - @arg1, @arg2: ExprInt""" - - val1 = int(arg1) - if val1 >> (arg1.size - 1) == 1: - val1 = - ((int(arg1.mask) ^ val1) + 1) - - val2 = int(arg2) - if val2 >> (arg2.size - 1) == 1: - val2 = - ((int(arg2.mask) ^ val2) + 1) - - return m2_expr.ExprInt(1, 1) if (val1 < val2) else m2_expr.ExprInt(0, 1) - -def exec_inf_signed(expr_simp, e): - "Compute x <s y" - - if e.op != m2_expr.TOK_INF_SIGNED: - return e - - arg1, arg2 = e.args - - if isinstance(arg1, m2_expr.ExprInt) and isinstance(arg2, m2_expr.ExprInt): - return __comp_signed(arg1, arg2) - else: - return e - -def exec_equal(expr_simp, e): - "Compute x == y" - - if e.op != m2_expr.TOK_EQUAL: - return e - - arg1, arg2 = e.args - if isinstance(arg1, m2_expr.ExprInt) and isinstance(arg2, m2_expr.ExprInt): - return m2_expr.ExprInt(1, 1) if (arg1.arg == arg2.arg) else m2_expr.ExprInt(0, 1) - else: - return e diff --git a/miasm2/ir/translators/z3_ir.py b/miasm2/ir/translators/z3_ir.py index 55952180..1cc8c29d 100644 --- a/miasm2/ir/translators/z3_ir.py +++ b/miasm2/ir/translators/z3_ir.py @@ -205,6 +205,36 @@ class TranslatorZ3(Translator): res = res - (arg * (self._idivC(res, arg))) elif expr.op == "umod": res = z3.URem(res, arg) + elif expr.op == "==": + res = z3.If( + args[0] == args[1], + z3.BitVecVal(1, 1), + z3.BitVecVal(0, 1) + ) + elif expr.op == "<u": + res = z3.If( + z3.ULT(args[0], args[1]), + z3.BitVecVal(1, 1), + z3.BitVecVal(0, 1) + ) + elif expr.op == "<s": + res = z3.If( + z3.SLT(args[0], args[1]), + z3.BitVecVal(1, 1), + z3.BitVecVal(0, 1) + ) + elif expr.op == "<=u": + res = z3.If( + z3.ULE(args[0], args[1]), + z3.BitVecVal(1, 1), + z3.BitVecVal(0, 1) + ) + elif expr.op == "<=s": + res = z3.If( + z3.SLE(args[0], args[1]), + z3.BitVecVal(1, 1), + z3.BitVecVal(0, 1) + ) else: raise NotImplementedError("Unsupported OP yet: %s" % expr.op) elif expr.op == 'parity': diff --git a/test/expression/simplifications.py b/test/expression/simplifications.py index 68dc0437..741d6adb 100644 --- a/test/expression/simplifications.py +++ b/test/expression/simplifications.py @@ -97,6 +97,7 @@ s = a[:8] i0 = ExprInt(0, 32) i1 = ExprInt(1, 32) i2 = ExprInt(2, 32) +im1 = ExprInt(-1, 32) icustom = ExprInt(0x12345678, 32) cc = ExprCond(a, b, c) @@ -452,9 +453,65 @@ for e_input, e_check in to_test: rez = e_new == e_check if not rez: raise ValueError( - 'bug in expr_simp_explicit simp(%s) is %s and should be %s' % (e_input, e_new, e_check)) + 'bug in expr_simp_explicit simp(%s) is %s and should be %s' % (e_input, e_new, e_check) + ) check(e_input, e_check) + +# Test high level op +to_test = [ + (ExprOp(TOK_EQUAL, a+i2, i1), ExprOp(TOK_EQUAL, a+i1, i0)), + (ExprOp(TOK_INF_SIGNED, a+i2, i1), ExprOp(TOK_INF_SIGNED, a+i2, i1)), + (ExprOp(TOK_INF_UNSIGNED, a+i2, i1), ExprOp(TOK_INF_UNSIGNED, a+i2, i1)), + + ( + ExprOp(TOK_EQUAL, ExprCompose(a8, ExprInt(0, 24)), im1), + ExprOp(TOK_EQUAL, a8, ExprInt(0xFF, 8)) + ), + + (ExprOp(TOK_INF_SIGNED, i1, i2), ExprInt(1, 1)), + (ExprOp(TOK_INF_UNSIGNED, i1, i2), ExprInt(1, 1)), + (ExprOp(TOK_INF_EQUAL_SIGNED, i1, i2), ExprInt(1, 1)), + (ExprOp(TOK_INF_EQUAL_UNSIGNED, i1, i2), ExprInt(1, 1)), + + (ExprOp(TOK_INF_SIGNED, i2, i1), ExprInt(0, 1)), + (ExprOp(TOK_INF_UNSIGNED, i2, i1), ExprInt(0, 1)), + (ExprOp(TOK_INF_EQUAL_SIGNED, i2, i1), ExprInt(0, 1)), + (ExprOp(TOK_INF_EQUAL_UNSIGNED, i2, i1), ExprInt(0, 1)), + + (ExprOp(TOK_INF_SIGNED, i1, i1), ExprInt(0, 1)), + (ExprOp(TOK_INF_UNSIGNED, i1, i1), ExprInt(0, 1)), + (ExprOp(TOK_INF_EQUAL_SIGNED, i1, i1), ExprInt(1, 1)), + (ExprOp(TOK_INF_EQUAL_UNSIGNED, i1, i1), ExprInt(1, 1)), + + + (ExprOp(TOK_INF_SIGNED, im1, i1), ExprInt(1, 1)), + (ExprOp(TOK_INF_UNSIGNED, im1, i1), ExprInt(0, 1)), + (ExprOp(TOK_INF_EQUAL_SIGNED, im1, i1), ExprInt(1, 1)), + (ExprOp(TOK_INF_EQUAL_UNSIGNED, im1, i1), ExprInt(0, 1)), + + (ExprOp(TOK_INF_SIGNED, i1, im1), ExprInt(0, 1)), + (ExprOp(TOK_INF_UNSIGNED, i1, im1), ExprInt(1, 1)), + (ExprOp(TOK_INF_EQUAL_SIGNED, i1, im1), ExprInt(0, 1)), + (ExprOp(TOK_INF_EQUAL_UNSIGNED, i1, im1), ExprInt(1, 1)), + + (ExprOp(TOK_EQUAL, a8.zeroExtend(32), b8.zeroExtend(32)), ExprOp(TOK_EQUAL, a8, b8)), + (ExprOp(TOK_EQUAL, a8.signExtend(32), b8.signExtend(32)), ExprOp(TOK_EQUAL, a8, b8)), + +] + +for e_input, e_check in to_test: + print "#" * 80 + e_check = expr_simp(e_check) + e_new = expr_simp(e_input) + print "original: ", str(e_input), "new: ", str(e_new) + rez = e_new == e_check + if not rez: + raise ValueError( + 'bug in expr_simp simp(%s) is %s and should be %s' % (e_input, e_new, e_check) + ) + + # Test conds to_test = [ @@ -490,7 +547,9 @@ for e_input, e_check in to_test: rez = e_new == e_check if not rez: raise ValueError( - 'bug in expr_simp simp(%s) is %s and should be %s' % (e_input, e_new, e_check)) + 'bug in expr_simp simp(%s) is %s and should be %s' % (e_input, e_new, e_check) + ) + if args.z3: # This check is done on 32 bits, but the size is not use by Miasm formulas, so |