diff options
Diffstat (limited to 'test/expression/simplifications.py')
| -rw-r--r-- | test/expression/simplifications.py | 219 |
1 files changed, 171 insertions, 48 deletions
diff --git a/test/expression/simplifications.py b/test/expression/simplifications.py index e4c3f2e9..0c516a8e 100644 --- a/test/expression/simplifications.py +++ b/test/expression/simplifications.py @@ -2,11 +2,69 @@ # Expression simplification regression tests # # from pdb import pm +from argparse import ArgumentParser +import logging + from miasm2.expression.expression import * -from miasm2.expression.expression_helper import expr_cmpu, expr_cmps -from miasm2.expression.simplifications import expr_simp, ExpressionSimplifier +from miasm2.expression.simplifications import expr_simp, ExpressionSimplifier, log_exprsimp from miasm2.expression.simplifications_cond import ExprOp_inf_signed, ExprOp_inf_unsigned, ExprOp_equal +parser = ArgumentParser("Expression simplification regression tests") +parser.add_argument("--z3", action="store_true", help="Enable check against z3") +parser.add_argument("-v", "--verbose", action="store_true", + help="Verbose simplify") +args = parser.parse_args() + +if args.verbose: + log_exprsimp.setLevel(logging.DEBUG) + +# Additionnal imports and definitions +if args.z3: + import z3 + from miasm2.ir.translators import Translator + trans = Translator.to_language("z3") + + def check(expr_in, expr_out): + """Check that expr_in is always equals to expr_out""" + print "Ensure %s = %s" % (expr_in, expr_out) + solver = z3.Solver() + solver.add(trans.from_expr(expr_in) != trans.from_expr(expr_out)) + + result = solver.check() + + if result != z3.unsat: + print "ERROR: a counter-example has been founded:" + model = solver.model() + print model + + print "Reinjecting in the simplifier:" + to_rep = {} + expressions = expr_in.get_r().union(expr_out.get_r()) + for expr in expressions: + value = model.eval(trans.from_expr(expr)) + if hasattr(value, "as_long"): + new_val = ExprInt(value.as_long(), expr.size) + else: + raise RuntimeError("Unable to reinject %r" % value) + + to_rep[expr] = new_val + + new_expr_in = expr_in.replace_expr(to_rep) + new_expr_out = expr_out.replace_expr(to_rep) + + print "Check %s = %s" % (new_expr_in, new_expr_out) + simp_in = expr_simp(new_expr_in) + simp_out = expr_simp(new_expr_out) + print "[%s] %s = %s" % (simp_in == simp_out, simp_in, simp_out) + + # Either the simplification does not stand, either the test is wrong + raise RuntimeError("Bad simplification") + +else: + # Dummy 'check' method to avoid checking the '--z3' argument each time + check = lambda expr_in, expr_out: None + + # Define example objects a = ExprId('a', 32) b = ExprId('b', 32) @@ -15,6 +73,22 @@ d = ExprId('d', 32) e = ExprId('e', 32) f = ExprId('f', size=64) +b_msb_null = b[:31].zeroExtend(32) +c_msb_null = c[:31].zeroExtend(32) + +a31 = ExprId('a31', 31) +b31 = ExprId('b31', 31) +c31 = ExprId('c31', 31) +b31_msb_null = ExprId('b31', 31)[:30].zeroExtend(31) +c31_msb_null = ExprId('c31', 31)[:30].zeroExtend(31) + +a8 = ExprId('a8', 8) +b8 = ExprId('b8', 8) +c8 = ExprId('c8', 8) +d8 = ExprId('d8', 8) +e8 = ExprId('e8', 8) + + m = ExprMem(a) s = a[:8] @@ -49,17 +123,35 @@ to_test = [(ExprInt(1, 32) - ExprInt(1, 32), ExprInt(0, 32)), (ExprOp('>>>', a, ExprInt(32, 32)), a), (ExprOp('>>>', a, ExprInt(0, 32)), a), (ExprOp('<<', a, ExprInt(0, 32)), a), + (ExprOp('<<<', a31, ExprInt(31, 31)), a31), + (ExprOp('>>>', a31, ExprInt(31, 31)), a31), + (ExprOp('>>>', a31, ExprInt(0, 31)), a31), + (ExprOp('<<', a31, ExprInt(0, 31)), a31), + + (ExprOp('<<<', a31, ExprOp('<<<', b31, c31)), + ExprOp('<<<', a31, ExprOp('<<<', b31, c31))), + (ExprOp('<<<', ExprOp('>>>', a31, b31), c31), + ExprOp('<<<', ExprOp('>>>', a31, b31), c31)), + (ExprOp('>>>', ExprOp('<<<', a31, b31), c31), + ExprOp('>>>', ExprOp('<<<', a31, b31), c31)), + (ExprOp('>>>', ExprOp('<<<', a31, b31), b31), + a31), + (ExprOp('<<<', ExprOp('>>>', a31, b31), b31), + a31), + (ExprOp('>>>', ExprOp('>>>', a31, b31), b31), + ExprOp('>>>', ExprOp('>>>', a31, b31), b31)), + (ExprOp('<<<', ExprOp('<<<', a31, b31), b31), + ExprOp('<<<', ExprOp('<<<', a31, b31), b31)), + + (ExprOp('>>>', ExprOp('<<<', a31, ExprInt(0x1234, 31)), ExprInt(0x1111, 31)), + ExprOp('>>>', a31, ExprInt(0x13, 31))), + (ExprOp('<<<', ExprOp('>>>', a31, ExprInt(0x1234, 31)), ExprInt(0x1111, 31)), + ExprOp('<<<', a31, ExprInt(0x13, 31))), + (ExprOp('>>>', ExprOp('<<<', a31, ExprInt(-1, 31)), ExprInt(0x1111, 31)), + ExprOp('>>>', a31, ExprInt(0x1c, 31))), + (ExprOp('<<<', ExprOp('>>>', a31, ExprInt(-1, 31)), ExprInt(0x1111, 31)), + ExprOp('<<<', a31, ExprInt(0x1c, 31))), - (ExprOp('<<<', a, ExprOp('<<<', b, c)), - ExprOp('<<<', a, ExprOp('<<<', b, c))), - (ExprOp('<<<', ExprOp('<<<', a, b), c), - ExprOp('<<<', a, (b+c))), - (ExprOp('<<<', ExprOp('>>>', a, b), c), - ExprOp('>>>', a, (b-c))), - (ExprOp('>>>', ExprOp('<<<', a, b), c), - ExprOp('<<<', a, (b-c))), - (ExprOp('>>>', ExprOp('<<<', a, b), b), - a), (ExprOp(">>>", ExprInt(0x1000, 16), ExprInt(0x11, 16)), ExprInt(0x800, 16)), (ExprOp("<<<", ExprInt(0x1000, 16), ExprInt(0x11, 16)), @@ -134,9 +226,9 @@ to_test = [(ExprInt(1, 32) - ExprInt(1, 32), ExprInt(0, 32)), (ExprOp('*', a, b, c, ExprInt(0x12, 32))[0:17], ExprOp( '*', a[0:17], b[0:17], c[0:17], ExprInt(0x12, 17))), - (ExprOp('*', a, ExprInt(0x0, 32)), + (ExprOp('*', a, b, ExprInt(0x0, 32)), ExprInt(0x0, 32)), - (ExprOp('&', a, ExprInt(0x0, 32)), + (ExprOp('&', a, b, ExprInt(0x0, 32)), ExprInt(0x0, 32)), (ExprOp('*', a, ExprInt(0xffffffff, 32)), -a), @@ -144,16 +236,14 @@ to_test = [(ExprInt(1, 32) - ExprInt(1, 32), ExprInt(0, 32)), ExprOp('*', a, b, c, ExprInt(0x12, 32))), (ExprOp('*', -a, -b, -c, ExprInt(0x12, 32)), - ExprOp('*', a, b, c, ExprInt(0x12, 32))), - (ExprOp('**', ExprInt(2, 32), ExprInt(8, 32)), ExprInt(0x100, 32)), - (ExprInt(2, 32)**ExprInt(8, 32), ExprInt(256, 32)), (a | ExprInt(0xffffffff, 32), ExprInt(0xffffffff, 32)), (ExprCond(a, ExprInt(1, 32), ExprInt(2, 32)) * ExprInt(4, 32), ExprCond(a, ExprInt(4, 32), ExprInt(8, 32))), (ExprCond(a, b, c) + ExprCond(a, d, e), ExprCond(a, b + d, c + e)), - (ExprCond(a, b, c) * ExprCond(a, d, e), - ExprCond(a, b * d, c * e)), + (ExprCond(a8, b8, c8) * ExprCond(a8, d8, e8), + ExprCond(a8, b8 * d8, c8 * e8)), (ExprCond(a, ExprInt(8, 32), ExprInt(4, 32)) >> ExprInt(1, 32), ExprCond(a, ExprInt(4, 32), ExprInt(2, 32))), @@ -268,37 +358,37 @@ to_test = [(ExprInt(1, 32) - ExprInt(1, 32), ExprInt(0, 32)), a[:16]), ((a << ExprInt(16, 32))[24:32], a[8:16]), - (expr_cmpu(ExprInt(0, 32), ExprInt(0, 32)), + (expr_is_unsigned_greater(ExprInt(0, 32), ExprInt(0, 32)), ExprInt(0, 1)), - (expr_cmpu(ExprInt(10, 32), ExprInt(0, 32)), + (expr_is_unsigned_greater(ExprInt(10, 32), ExprInt(0, 32)), ExprInt(1, 1)), - (expr_cmpu(ExprInt(10, 32), ExprInt(5, 32)), + (expr_is_unsigned_greater(ExprInt(10, 32), ExprInt(5, 32)), ExprInt(1, 1)), - (expr_cmpu(ExprInt(5, 32), ExprInt(10, 32)), + (expr_is_unsigned_greater(ExprInt(5, 32), ExprInt(10, 32)), ExprInt(0, 1)), - (expr_cmpu(ExprInt(-1, 32), ExprInt(0, 32)), + (expr_is_unsigned_greater(ExprInt(-1, 32), ExprInt(0, 32)), ExprInt(1, 1)), - (expr_cmpu(ExprInt(-1, 32), ExprInt(-1, 32)), + (expr_is_unsigned_greater(ExprInt(-1, 32), ExprInt(-1, 32)), ExprInt(0, 1)), - (expr_cmpu(ExprInt(0, 32), ExprInt(-1, 32)), + (expr_is_unsigned_greater(ExprInt(0, 32), ExprInt(-1, 32)), ExprInt(0, 1)), - (expr_cmps(ExprInt(0, 32), ExprInt(0, 32)), + (expr_is_signed_greater(ExprInt(0, 32), ExprInt(0, 32)), ExprInt(0, 1)), - (expr_cmps(ExprInt(10, 32), ExprInt(0, 32)), + (expr_is_signed_greater(ExprInt(10, 32), ExprInt(0, 32)), ExprInt(1, 1)), - (expr_cmps(ExprInt(10, 32), ExprInt(5, 32)), + (expr_is_signed_greater(ExprInt(10, 32), ExprInt(5, 32)), ExprInt(1, 1)), - (expr_cmps(ExprInt(5, 32), ExprInt(10, 32)), + (expr_is_signed_greater(ExprInt(5, 32), ExprInt(10, 32)), ExprInt(0, 1)), - (expr_cmps(ExprInt(-1, 32), ExprInt(0, 32)), + (expr_is_signed_greater(ExprInt(-1, 32), ExprInt(0, 32)), ExprInt(0, 1)), - (expr_cmps(ExprInt(-1, 32), ExprInt(-1, 32)), + (expr_is_signed_greater(ExprInt(-1, 32), ExprInt(-1, 32)), ExprInt(0, 1)), - (expr_cmps(ExprInt(0, 32), ExprInt(-1, 32)), + (expr_is_signed_greater(ExprInt(0, 32), ExprInt(-1, 32)), ExprInt(1, 1)), - (expr_cmps(ExprInt(-5, 32), ExprInt(-10, 32)), + (expr_is_signed_greater(ExprInt(-5, 32), ExprInt(-10, 32)), ExprInt(1, 1)), - (expr_cmps(ExprInt(-10, 32), ExprInt(-5, 32)), + (expr_is_signed_greater(ExprInt(-10, 32), ExprInt(-5, 32)), ExprInt(0, 1)), (ExprOp("idiv", ExprInt(0x0123, 16), ExprInt(0xfffb, 16))[:8], @@ -312,21 +402,26 @@ to_test = [(ExprInt(1, 32) - ExprInt(1, 32), ExprInt(0, 32)), ExprInt(0x0321, 16))), (ExprCompose(ExprCond(a, i1, i0), ExprCond(a, i1, i2)), ExprCond(a, ExprInt(0x100000001, 64), ExprInt(0x200000000, 64))), - ((ExprMem(ExprCond(a, b, c)),ExprCond(a, ExprMem(b), ExprMem(c)))), + ((ExprMem(ExprCond(a, b, c), 4),ExprCond(a, ExprMem(b, 4), ExprMem(c, 4)))), (ExprCond(a, i0, i1) + ExprCond(a, i0, i1), ExprCond(a, i0, i2)), + (a << b << c, a << b << c), # Left unmodified + (a << b_msb_null << c_msb_null, + a << (b_msb_null + c_msb_null)), + (a >> b >> c, a >> b >> c), # Left unmodified + (a >> b_msb_null >> c_msb_null, + a >> (b_msb_null + c_msb_null)), ] -for e, e_check in to_test[:]: - # +for e_input, e_check in to_test: print "#" * 80 - # print str(e), str(e_check) - e_new = expr_simp(e) - print "original: ", str(e), "new: ", str(e_new) + e_new = expr_simp(e_input) + print "original: ", str(e_input), "new: ", str(e_new) rez = e_new == e_check if not rez: raise ValueError( - 'bug in expr_simp simp(%s) is %s and should be %s' % (e, e_new, e_check)) + 'bug in expr_simp simp(%s) is %s and should be %s' % (e_input, e_new, e_check)) + check(e_input, e_check) # Test conds @@ -356,18 +451,46 @@ expr_simp_cond = ExpressionSimplifier() expr_simp.enable_passes(ExpressionSimplifier.PASS_COND) -for e, e_check in to_test[:]: - # +for e_input, e_check in to_test: print "#" * 80 e_check = expr_simp(e_check) - # print str(e), str(e_check) - e_new = expr_simp(e) - print "original: ", str(e), "new: ", str(e_new) + e_new = expr_simp(e_input) + print "original: ", str(e_input), "new: ", str(e_new) rez = e_new == e_check if not rez: raise ValueError( - 'bug in expr_simp simp(%s) is %s and should be %s' % (e, e_new, e_check)) - + 'bug in expr_simp simp(%s) is %s and should be %s' % (e_input, e_new, e_check)) + +if args.z3: + # This check is done on 32 bits, but the size is not use by Miasm formulas, so + # it should be OK for any size > 0 + x1 = ExprId("x1", 32) + x2 = ExprId("x2", 32) + i1_tmp = ExprInt(1, 1) + + x1_z3 = trans.from_expr(x1) + x2_z3 = trans.from_expr(x2) + i1_z3 = trans.from_expr(i1_tmp) + + # (Assumptions, function(arg1, arg2) -> True/False (= i1/i0) to check) + tests = [ + (x1_z3 == x2_z3, expr_is_equal), + (x1_z3 != x2_z3, expr_is_not_equal), + (z3.UGT(x1_z3, x2_z3), expr_is_unsigned_greater), + (z3.UGE(x1_z3, x2_z3), expr_is_unsigned_greater_or_equal), + (z3.ULT(x1_z3, x2_z3), expr_is_unsigned_lower), + (z3.ULE(x1_z3, x2_z3), expr_is_unsigned_lower_or_equal), + (x1_z3 > x2_z3, expr_is_signed_greater), + (x1_z3 >= x2_z3, expr_is_signed_greater_or_equal), + (x1_z3 < x2_z3, expr_is_signed_lower), + (x1_z3 <= x2_z3, expr_is_signed_lower_or_equal), + ] + + for assumption, func in tests: + solver = z3.Solver() + solver.add(assumption) + solver.add(trans.from_expr(func(x1, x2)) != i1_z3) + assert solver.check() == z3.unsat x = ExprId('x', 32) |