about summary refs log tree commit diff stats
path: root/test/expression/simplifications.py
diff options
context:
space:
mode:
authorserpilliere <serpilliere@users.noreply.github.com>2018-02-14 15:13:20 +0100
committerGitHub <noreply@github.com>2018-02-14 15:13:20 +0100
commit9dd075f09e4f31ec7fe12e50709d9e58c65ed5f4 (patch)
tree65d6c4f1c613822d0441bd296cc4c7e7f1136522 /test/expression/simplifications.py
parentdcfadb31685d428618b88f19fcc96dd70cecfc8f (diff)
parent0f55f0779555c38cd907143527d4ddbf26c18157 (diff)
downloadfocaccia-miasm-9dd075f09e4f31ec7fe12e50709d9e58c65ed5f4.tar.gz
focaccia-miasm-9dd075f09e4f31ec7fe12e50709d9e58c65ed5f4.zip
Merge pull request #679 from commial/refactor-expr-comp
Refactor expr simplifications tests
Diffstat (limited to 'test/expression/simplifications.py')
-rw-r--r--test/expression/simplifications.py219
1 files changed, 171 insertions, 48 deletions
diff --git a/test/expression/simplifications.py b/test/expression/simplifications.py
index e4c3f2e9..0c516a8e 100644
--- a/test/expression/simplifications.py
+++ b/test/expression/simplifications.py
@@ -2,11 +2,69 @@
 # Expression simplification regression tests  #
 #
 from pdb import pm
+from argparse import ArgumentParser
+import logging
+
 from miasm2.expression.expression import *
-from miasm2.expression.expression_helper import expr_cmpu, expr_cmps
-from miasm2.expression.simplifications import expr_simp, ExpressionSimplifier
+from miasm2.expression.simplifications import expr_simp, ExpressionSimplifier, log_exprsimp
 from miasm2.expression.simplifications_cond import ExprOp_inf_signed, ExprOp_inf_unsigned, ExprOp_equal
 
+parser = ArgumentParser("Expression simplification regression tests")
+parser.add_argument("--z3", action="store_true", help="Enable check against z3")
+parser.add_argument("-v", "--verbose", action="store_true",
+                    help="Verbose simplify")
+args = parser.parse_args()
+
+if args.verbose:
+    log_exprsimp.setLevel(logging.DEBUG)
+
+# Additionnal imports and definitions
+if args.z3:
+    import z3
+    from miasm2.ir.translators import Translator
+    trans = Translator.to_language("z3")
+
+    def check(expr_in, expr_out):
+        """Check that expr_in is always equals to expr_out"""
+        print "Ensure %s = %s" % (expr_in, expr_out)
+        solver = z3.Solver()
+        solver.add(trans.from_expr(expr_in) != trans.from_expr(expr_out))
+
+        result = solver.check()
+
+        if result != z3.unsat:
+            print "ERROR: a counter-example has been founded:"
+            model = solver.model()
+            print model
+
+            print "Reinjecting in the simplifier:"
+            to_rep = {}
+            expressions = expr_in.get_r().union(expr_out.get_r())
+            for expr in expressions:
+                value = model.eval(trans.from_expr(expr))
+                if hasattr(value, "as_long"):
+                    new_val = ExprInt(value.as_long(), expr.size)
+                else:
+                    raise RuntimeError("Unable to reinject %r" % value)
+
+                to_rep[expr] = new_val
+
+            new_expr_in = expr_in.replace_expr(to_rep)
+            new_expr_out = expr_out.replace_expr(to_rep)
+
+            print "Check %s = %s" % (new_expr_in, new_expr_out)
+            simp_in = expr_simp(new_expr_in)
+            simp_out =  expr_simp(new_expr_out)
+            print "[%s] %s = %s" % (simp_in == simp_out, simp_in, simp_out)
+
+            # Either the simplification does not stand, either the test is wrong
+            raise RuntimeError("Bad simplification")
+
+else:
+    # Dummy 'check' method to avoid checking the '--z3' argument each time
+    check = lambda expr_in, expr_out: None
+
+
 # Define example objects
 a = ExprId('a', 32)
 b = ExprId('b', 32)
@@ -15,6 +73,22 @@ d = ExprId('d', 32)
 e = ExprId('e', 32)
 f = ExprId('f', size=64)
 
+b_msb_null = b[:31].zeroExtend(32)
+c_msb_null = c[:31].zeroExtend(32)
+
+a31 = ExprId('a31', 31)
+b31 = ExprId('b31', 31)
+c31 = ExprId('c31', 31)
+b31_msb_null = ExprId('b31', 31)[:30].zeroExtend(31)
+c31_msb_null = ExprId('c31', 31)[:30].zeroExtend(31)
+
+a8 = ExprId('a8', 8)
+b8 = ExprId('b8', 8)
+c8 = ExprId('c8', 8)
+d8 = ExprId('d8', 8)
+e8 = ExprId('e8', 8)
+
+
 m = ExprMem(a)
 s = a[:8]
 
@@ -49,17 +123,35 @@ to_test = [(ExprInt(1, 32) - ExprInt(1, 32), ExprInt(0, 32)),
            (ExprOp('>>>', a, ExprInt(32, 32)), a),
            (ExprOp('>>>', a, ExprInt(0, 32)), a),
            (ExprOp('<<', a, ExprInt(0, 32)), a),
+           (ExprOp('<<<', a31, ExprInt(31, 31)), a31),
+           (ExprOp('>>>', a31, ExprInt(31, 31)), a31),
+           (ExprOp('>>>', a31, ExprInt(0, 31)), a31),
+           (ExprOp('<<', a31, ExprInt(0, 31)), a31),
+
+           (ExprOp('<<<', a31, ExprOp('<<<', b31, c31)),
+            ExprOp('<<<', a31, ExprOp('<<<', b31, c31))),
+           (ExprOp('<<<', ExprOp('>>>', a31, b31), c31),
+            ExprOp('<<<', ExprOp('>>>', a31, b31), c31)),
+           (ExprOp('>>>', ExprOp('<<<', a31, b31), c31),
+            ExprOp('>>>', ExprOp('<<<', a31, b31), c31)),
+           (ExprOp('>>>', ExprOp('<<<', a31, b31), b31),
+            a31),
+           (ExprOp('<<<', ExprOp('>>>', a31, b31), b31),
+            a31),
+           (ExprOp('>>>', ExprOp('>>>', a31, b31), b31),
+            ExprOp('>>>', ExprOp('>>>', a31, b31), b31)),
+           (ExprOp('<<<', ExprOp('<<<', a31, b31), b31),
+            ExprOp('<<<', ExprOp('<<<', a31, b31), b31)),
+
+           (ExprOp('>>>', ExprOp('<<<', a31, ExprInt(0x1234, 31)), ExprInt(0x1111, 31)),
+            ExprOp('>>>', a31, ExprInt(0x13, 31))),
+           (ExprOp('<<<', ExprOp('>>>', a31, ExprInt(0x1234, 31)), ExprInt(0x1111, 31)),
+            ExprOp('<<<', a31, ExprInt(0x13, 31))),
+           (ExprOp('>>>', ExprOp('<<<', a31, ExprInt(-1, 31)), ExprInt(0x1111, 31)),
+            ExprOp('>>>', a31, ExprInt(0x1c, 31))),
+           (ExprOp('<<<', ExprOp('>>>', a31, ExprInt(-1, 31)), ExprInt(0x1111, 31)),
+            ExprOp('<<<', a31, ExprInt(0x1c, 31))),
 
-           (ExprOp('<<<', a, ExprOp('<<<', b, c)),
-            ExprOp('<<<', a, ExprOp('<<<', b, c))),
-           (ExprOp('<<<', ExprOp('<<<', a, b), c),
-            ExprOp('<<<', a, (b+c))),
-           (ExprOp('<<<', ExprOp('>>>', a, b), c),
-            ExprOp('>>>', a, (b-c))),
-           (ExprOp('>>>', ExprOp('<<<', a, b), c),
-            ExprOp('<<<', a, (b-c))),
-           (ExprOp('>>>', ExprOp('<<<', a, b), b),
-            a),
            (ExprOp(">>>", ExprInt(0x1000, 16), ExprInt(0x11, 16)),
             ExprInt(0x800, 16)),
            (ExprOp("<<<", ExprInt(0x1000, 16), ExprInt(0x11, 16)),
@@ -134,9 +226,9 @@ to_test = [(ExprInt(1, 32) - ExprInt(1, 32), ExprInt(0, 32)),
     (ExprOp('*', a, b, c, ExprInt(0x12, 32))[0:17],
      ExprOp(
      '*', a[0:17], b[0:17], c[0:17], ExprInt(0x12, 17))),
-    (ExprOp('*', a, ExprInt(0x0, 32)),
+    (ExprOp('*', a, b, ExprInt(0x0, 32)),
      ExprInt(0x0, 32)),
-    (ExprOp('&', a, ExprInt(0x0, 32)),
+    (ExprOp('&', a, b, ExprInt(0x0, 32)),
      ExprInt(0x0, 32)),
     (ExprOp('*', a, ExprInt(0xffffffff, 32)),
      -a),
@@ -144,16 +236,14 @@ to_test = [(ExprInt(1, 32) - ExprInt(1, 32), ExprInt(0, 32)),
      ExprOp('*', a, b, c, ExprInt(0x12, 32))),
     (ExprOp('*', -a, -b, -c, ExprInt(0x12, 32)),
      - ExprOp('*', a, b, c, ExprInt(0x12, 32))),
-     (ExprOp('**', ExprInt(2, 32), ExprInt(8, 32)), ExprInt(0x100, 32)),
-     (ExprInt(2, 32)**ExprInt(8, 32), ExprInt(256, 32)),
     (a | ExprInt(0xffffffff, 32),
      ExprInt(0xffffffff, 32)),
     (ExprCond(a, ExprInt(1, 32), ExprInt(2, 32)) * ExprInt(4, 32),
      ExprCond(a, ExprInt(4, 32), ExprInt(8, 32))),
     (ExprCond(a, b, c) + ExprCond(a, d, e),
      ExprCond(a, b + d, c + e)),
-    (ExprCond(a, b, c) * ExprCond(a, d, e),
-     ExprCond(a, b * d, c * e)),
+    (ExprCond(a8, b8, c8) * ExprCond(a8, d8, e8),
+     ExprCond(a8, b8 * d8, c8 * e8)),
 
     (ExprCond(a, ExprInt(8, 32), ExprInt(4, 32)) >> ExprInt(1, 32),
      ExprCond(a, ExprInt(4, 32), ExprInt(2, 32))),
@@ -268,37 +358,37 @@ to_test = [(ExprInt(1, 32) - ExprInt(1, 32), ExprInt(0, 32)),
      a[:16]),
     ((a << ExprInt(16, 32))[24:32],
      a[8:16]),
-    (expr_cmpu(ExprInt(0, 32), ExprInt(0, 32)),
+    (expr_is_unsigned_greater(ExprInt(0, 32), ExprInt(0, 32)),
      ExprInt(0, 1)),
-    (expr_cmpu(ExprInt(10, 32), ExprInt(0, 32)),
+    (expr_is_unsigned_greater(ExprInt(10, 32), ExprInt(0, 32)),
      ExprInt(1, 1)),
-    (expr_cmpu(ExprInt(10, 32), ExprInt(5, 32)),
+    (expr_is_unsigned_greater(ExprInt(10, 32), ExprInt(5, 32)),
      ExprInt(1, 1)),
-    (expr_cmpu(ExprInt(5, 32), ExprInt(10, 32)),
+    (expr_is_unsigned_greater(ExprInt(5, 32), ExprInt(10, 32)),
      ExprInt(0, 1)),
-    (expr_cmpu(ExprInt(-1, 32), ExprInt(0, 32)),
+    (expr_is_unsigned_greater(ExprInt(-1, 32), ExprInt(0, 32)),
      ExprInt(1, 1)),
-    (expr_cmpu(ExprInt(-1, 32), ExprInt(-1, 32)),
+    (expr_is_unsigned_greater(ExprInt(-1, 32), ExprInt(-1, 32)),
      ExprInt(0, 1)),
-    (expr_cmpu(ExprInt(0, 32), ExprInt(-1, 32)),
+    (expr_is_unsigned_greater(ExprInt(0, 32), ExprInt(-1, 32)),
      ExprInt(0, 1)),
-    (expr_cmps(ExprInt(0, 32), ExprInt(0, 32)),
+    (expr_is_signed_greater(ExprInt(0, 32), ExprInt(0, 32)),
      ExprInt(0, 1)),
-    (expr_cmps(ExprInt(10, 32), ExprInt(0, 32)),
+    (expr_is_signed_greater(ExprInt(10, 32), ExprInt(0, 32)),
      ExprInt(1, 1)),
-    (expr_cmps(ExprInt(10, 32), ExprInt(5, 32)),
+    (expr_is_signed_greater(ExprInt(10, 32), ExprInt(5, 32)),
      ExprInt(1, 1)),
-    (expr_cmps(ExprInt(5, 32), ExprInt(10, 32)),
+    (expr_is_signed_greater(ExprInt(5, 32), ExprInt(10, 32)),
      ExprInt(0, 1)),
-    (expr_cmps(ExprInt(-1, 32), ExprInt(0, 32)),
+    (expr_is_signed_greater(ExprInt(-1, 32), ExprInt(0, 32)),
      ExprInt(0, 1)),
-    (expr_cmps(ExprInt(-1, 32), ExprInt(-1, 32)),
+    (expr_is_signed_greater(ExprInt(-1, 32), ExprInt(-1, 32)),
      ExprInt(0, 1)),
-    (expr_cmps(ExprInt(0, 32), ExprInt(-1, 32)),
+    (expr_is_signed_greater(ExprInt(0, 32), ExprInt(-1, 32)),
      ExprInt(1, 1)),
-    (expr_cmps(ExprInt(-5, 32), ExprInt(-10, 32)),
+    (expr_is_signed_greater(ExprInt(-5, 32), ExprInt(-10, 32)),
      ExprInt(1, 1)),
-    (expr_cmps(ExprInt(-10, 32), ExprInt(-5, 32)),
+    (expr_is_signed_greater(ExprInt(-10, 32), ExprInt(-5, 32)),
      ExprInt(0, 1)),
 
     (ExprOp("idiv", ExprInt(0x0123, 16), ExprInt(0xfffb, 16))[:8],
@@ -312,21 +402,26 @@ to_test = [(ExprInt(1, 32) - ExprInt(1, 32), ExprInt(0, 32)),
                  ExprInt(0x0321, 16))),
     (ExprCompose(ExprCond(a, i1, i0), ExprCond(a, i1, i2)),
      ExprCond(a, ExprInt(0x100000001, 64), ExprInt(0x200000000, 64))),
-    ((ExprMem(ExprCond(a, b, c)),ExprCond(a, ExprMem(b), ExprMem(c)))),
+    ((ExprMem(ExprCond(a, b, c), 4),ExprCond(a, ExprMem(b, 4), ExprMem(c, 4)))),
     (ExprCond(a, i0, i1) + ExprCond(a, i0, i1), ExprCond(a, i0, i2)),
 
+    (a << b << c, a << b << c), # Left unmodified
+    (a << b_msb_null << c_msb_null,
+     a << (b_msb_null + c_msb_null)),
+    (a >> b >> c, a >> b >> c), # Left unmodified
+    (a >> b_msb_null >> c_msb_null,
+     a >> (b_msb_null + c_msb_null)),
 ]
 
-for e, e_check in to_test[:]:
-    #
+for e_input, e_check in to_test:
     print "#" * 80
-    # print str(e), str(e_check)
-    e_new = expr_simp(e)
-    print "original: ", str(e), "new: ", str(e_new)
+    e_new = expr_simp(e_input)
+    print "original: ", str(e_input), "new: ", str(e_new)
     rez = e_new == e_check
     if not rez:
         raise ValueError(
-            'bug in expr_simp simp(%s) is %s and should be %s' % (e, e_new, e_check))
+            'bug in expr_simp simp(%s) is %s and should be %s' % (e_input, e_new, e_check))
+    check(e_input, e_check)
 
 # Test conds
 
@@ -356,18 +451,46 @@ expr_simp_cond = ExpressionSimplifier()
 expr_simp.enable_passes(ExpressionSimplifier.PASS_COND)
 
 
-for e, e_check in to_test[:]:
-    #
+for e_input, e_check in to_test:
     print "#" * 80
     e_check = expr_simp(e_check)
-    # print str(e), str(e_check)
-    e_new = expr_simp(e)
-    print "original: ", str(e), "new: ", str(e_new)
+    e_new = expr_simp(e_input)
+    print "original: ", str(e_input), "new: ", str(e_new)
     rez = e_new == e_check
     if not rez:
         raise ValueError(
-            'bug in expr_simp simp(%s) is %s and should be %s' % (e, e_new, e_check))
-
+            'bug in expr_simp simp(%s) is %s and should be %s' % (e_input, e_new, e_check))
+
+if args.z3:
+    # This check is done on 32 bits, but the size is not use by Miasm formulas, so
+    # it should be OK for any size > 0
+    x1 = ExprId("x1", 32)
+    x2 = ExprId("x2", 32)
+    i1_tmp = ExprInt(1, 1)
+
+    x1_z3 = trans.from_expr(x1)
+    x2_z3 = trans.from_expr(x2)
+    i1_z3 = trans.from_expr(i1_tmp)
+
+    # (Assumptions, function(arg1, arg2) -> True/False (= i1/i0) to check)
+    tests = [
+        (x1_z3 == x2_z3, expr_is_equal),
+        (x1_z3 != x2_z3, expr_is_not_equal),
+        (z3.UGT(x1_z3, x2_z3), expr_is_unsigned_greater),
+        (z3.UGE(x1_z3, x2_z3), expr_is_unsigned_greater_or_equal),
+        (z3.ULT(x1_z3, x2_z3), expr_is_unsigned_lower),
+        (z3.ULE(x1_z3, x2_z3), expr_is_unsigned_lower_or_equal),
+        (x1_z3 > x2_z3, expr_is_signed_greater),
+        (x1_z3 >= x2_z3, expr_is_signed_greater_or_equal),
+        (x1_z3 < x2_z3, expr_is_signed_lower),
+        (x1_z3 <= x2_z3, expr_is_signed_lower_or_equal),
+    ]
+
+    for assumption, func in tests:
+        solver = z3.Solver()
+        solver.add(assumption)
+        solver.add(trans.from_expr(func(x1, x2)) != i1_z3)
+        assert solver.check() == z3.unsat
 
 
 x = ExprId('x', 32)