about summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--test/expression/simplifications.py84
-rwxr-xr-xtest/test_all.py3
2 files changed, 75 insertions, 12 deletions
diff --git a/test/expression/simplifications.py b/test/expression/simplifications.py
index 53283bef..949aa6ff 100644
--- a/test/expression/simplifications.py
+++ b/test/expression/simplifications.py
@@ -2,10 +2,73 @@
 # Expression simplification regression tests  #
 #
 from pdb import pm
+from argparse import ArgumentParser
+
 from miasm2.expression.expression import *
 from miasm2.expression.simplifications import expr_simp, ExpressionSimplifier
 from miasm2.expression.simplifications_cond import ExprOp_inf_signed, ExprOp_inf_unsigned, ExprOp_equal
 
+parser = ArgumentParser("Expression simplification regression tests")
+parser.add_argument("--z3", action="store_true", help="Enable check against z3")
+parser.add_argument("--z3-timeout", type=int, help="z3 timeout (in seconds)",
+                    default=20)
+args = parser.parse_args()
+
+# Additionnal imports and definitions
+if args.z3:
+    import z3
+    from miasm2.ir.translators import Translator
+    trans = Translator.to_language("z3")
+
+    def check(expr_in, expr_out):
+        """Check that expr_in is always equals to expr_out"""
+        print "Ensure %s = %s" % (expr_in, expr_out)
+        solver = z3.Solver()
+        solver.set("timeout", args.z3_timeout * 1000)
+        try:
+            solver.add(trans.from_expr(expr_in) != trans.from_expr(expr_out))
+        except NotImplementedError as error:
+            print "Unable to translate in z3", error
+            return
+
+        result = solver.check()
+        if result == z3.unknown:
+            print "-> Timeout!"
+            return
+
+        if result != z3.unsat:
+            print "ERROR: a counter-example has been founded:"
+            model = solver.model()
+            print model
+
+            print "Reinjecting in the simplifier:"
+            to_rep = {}
+            expressions = expr_in.get_r().union(expr_out.get_r())
+            for expr in expressions:
+                value = model.eval(trans.from_expr(expr))
+                if hasattr(value, "as_long"):
+                    new_val = ExprInt(value.as_long(), expr.size)
+                else:
+                    raise RuntimeError("Unable to reinject %r" % value)
+
+                to_rep[expr] = new_val
+
+            new_expr_in = expr_in.replace_expr(to_rep)
+            new_expr_out = expr_out.replace_expr(to_rep)
+
+            print "Check %s = %s" % (new_expr_in, new_expr_out)
+            simp_in = expr_simp(new_expr_in)
+            simp_out =  expr_simp(new_expr_out)
+            print "[%s] %s = %s" % (simp_in == simp_out, simp_in, simp_out)
+
+            # Either the simplification does not stand, either the test is wrong
+            raise RuntimeError("Bad simplification")
+
+else:
+    # Dummy 'check' method to avoid checking the '--z3' argument each time
+    check = lambda expr_in, expr_out: None
+
+
 # Define example objects
 a = ExprId('a', 32)
 b = ExprId('b', 32)
@@ -316,16 +379,15 @@ to_test = [(ExprInt(1, 32) - ExprInt(1, 32), ExprInt(0, 32)),
 
 ]
 
-for e, e_check in to_test[:]:
-    #
+for e_input, e_check in to_test:
     print "#" * 80
-    # print str(e), str(e_check)
-    e_new = expr_simp(e)
-    print "original: ", str(e), "new: ", str(e_new)
+    e_new = expr_simp(e_input)
+    print "original: ", str(e_input), "new: ", str(e_new)
     rez = e_new == e_check
     if not rez:
         raise ValueError(
-            'bug in expr_simp simp(%s) is %s and should be %s' % (e, e_new, e_check))
+            'bug in expr_simp simp(%s) is %s and should be %s' % (e_input, e_new, e_check))
+    check(e_input, e_check)
 
 # Test conds
 
@@ -355,17 +417,15 @@ expr_simp_cond = ExpressionSimplifier()
 expr_simp.enable_passes(ExpressionSimplifier.PASS_COND)
 
 
-for e, e_check in to_test[:]:
-    #
+for e_input, e_check in to_test:
     print "#" * 80
     e_check = expr_simp(e_check)
-    # print str(e), str(e_check)
-    e_new = expr_simp(e)
-    print "original: ", str(e), "new: ", str(e_new)
+    e_new = expr_simp(e_input)
+    print "original: ", str(e_input), "new: ", str(e_new)
     rez = e_new == e_check
     if not rez:
         raise ValueError(
-            'bug in expr_simp simp(%s) is %s and should be %s' % (e, e_new, e_check))
+            'bug in expr_simp simp(%s) is %s and should be %s' % (e_input, e_new, e_check))
 
 
 
diff --git a/test/test_all.py b/test/test_all.py
index 04aca62e..6aa2a97e 100755
--- a/test/test_all.py
+++ b/test/test_all.py
@@ -249,6 +249,9 @@ for script in ["modint.py",
                "expr_cmp.py",
                ]:
     testset += RegressionTest([script], base_dir="expression")
+testset += RegressionTest(["simplifications.py", "--z3"],
+                          base_dir="expression",
+                          tags=[TAGS["z3"]])
 
 ## ObjC/CHandler
 testset += RegressionTest(["test_chandler.py"], base_dir="expr_type",