about summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--miasm2/arch/x86/sem.py13
-rw-r--r--miasm2/expression/expression_helper.py10
-rw-r--r--miasm2/expression/simplifications.py14
-rw-r--r--miasm2/expression/simplifications_common.py58
-rw-r--r--miasm2/ir/translators/z3_ir.py17
-rw-r--r--test/expression/simplifications.py219
-rwxr-xr-xtest/test_all.py3
7 files changed, 257 insertions, 77 deletions
diff --git a/miasm2/arch/x86/sem.py b/miasm2/arch/x86/sem.py
index 93c4910a..9f438b71 100644
--- a/miasm2/arch/x86/sem.py
+++ b/miasm2/arch/x86/sem.py
@@ -21,7 +21,6 @@ import miasm2.expression.expression as m2_expr
 from miasm2.expression.simplifications import expr_simp
 from miasm2.arch.x86.regs import *
 from miasm2.arch.x86.arch import mn_x86, repeat_mn, replace_regs
-from miasm2.expression.expression_helper import expr_cmps, expr_cmpu
 from miasm2.ir.ir import IntermediateRepresentation, IRBlock, AssignBlock
 from miasm2.core.sembuilder import SemBuilder
 from miasm2.jitter.csts import EXCEPT_DIV_BY_ZERO, EXCEPT_ILLEGAL_INSN, \
@@ -2741,11 +2740,11 @@ def daa(_, instr):
     e = []
     r_al = mRAX[instr.mode][:8]
 
-    cond1 = expr_cmpu(r_al[:4], m2_expr.ExprInt(0x9, 4)) | af
+    cond1 = m2_expr.expr_is_unsigned_greater(r_al[:4], m2_expr.ExprInt(0x9, 4)) | af
     e.append(m2_expr.ExprAff(af, cond1))
 
-    cond2 = expr_cmpu(m2_expr.ExprInt(6, 8), r_al)
-    cond3 = expr_cmpu(r_al, m2_expr.ExprInt(0x99, 8)) | cf
+    cond2 = m2_expr.expr_is_unsigned_greater(m2_expr.ExprInt(6, 8), r_al)
+    cond3 = m2_expr.expr_is_unsigned_greater(r_al, m2_expr.ExprInt(0x99, 8)) | cf
 
     cf_c1 = m2_expr.ExprCond(cond1,
                              cf | (cond2),
@@ -2771,11 +2770,11 @@ def das(_, instr):
     e = []
     r_al = mRAX[instr.mode][:8]
 
-    cond1 = expr_cmpu(r_al[:4], m2_expr.ExprInt(0x9, 4)) | af
+    cond1 = m2_expr.expr_is_unsigned_greater(r_al[:4], m2_expr.ExprInt(0x9, 4)) | af
     e.append(m2_expr.ExprAff(af, cond1))
 
-    cond2 = expr_cmpu(m2_expr.ExprInt(6, 8), r_al)
-    cond3 = expr_cmpu(r_al, m2_expr.ExprInt(0x99, 8)) | cf
+    cond2 = m2_expr.expr_is_unsigned_greater(m2_expr.ExprInt(6, 8), r_al)
+    cond3 = m2_expr.expr_is_unsigned_greater(r_al, m2_expr.ExprInt(0x99, 8)) | cf
 
     cf_c1 = m2_expr.ExprCond(cond1,
                              cf | (cond2),
diff --git a/miasm2/expression/expression_helper.py b/miasm2/expression/expression_helper.py
index 1e718faa..722d169d 100644
--- a/miasm2/expression/expression_helper.py
+++ b/miasm2/expression/expression_helper.py
@@ -21,6 +21,7 @@ import itertools
 import collections
 import random
 import string
+import warnings
 
 import miasm2.expression.expression as m2_expr
 
@@ -468,16 +469,14 @@ class ExprRandom(object):
 
         return got
 
-def _expr_cmp_gen(arg1, arg2):
-    return (arg2 - arg1) ^ ((arg2 ^ arg1) & ((arg2 - arg1) ^ arg2))
-
 def expr_cmpu(arg1, arg2):
     """
     Returns a one bit long Expression:
     * 1 if @arg1 is strictly greater than @arg2 (unsigned)
     * 0 otherwise.
     """
-    return (_expr_cmp_gen(arg1, arg2) ^ arg2 ^ arg1).msb()
+    warnings.warn('DEPRECATION WARNING: use "expr_is_unsigned_greater" instead"')
+    return m2_expr.expr_is_unsigned_greater(arg1, arg2)
 
 def expr_cmps(arg1, arg2):
     """
@@ -485,7 +484,8 @@ def expr_cmps(arg1, arg2):
     * 1 if @arg1 is strictly greater than @arg2 (signed)
     * 0 otherwise.
     """
-    return _expr_cmp_gen(arg1, arg2).msb()
+    warnings.warn('DEPRECATION WARNING: use "expr_is_signed_greater" instead"')
+    return m2_expr.expr_is_signed_greater(arg1, arg2)
 
 
 class CondConstraint(object):
diff --git a/miasm2/expression/simplifications.py b/miasm2/expression/simplifications.py
index d3483d9e..e6c5dc54 100644
--- a/miasm2/expression/simplifications.py
+++ b/miasm2/expression/simplifications.py
@@ -2,6 +2,8 @@
 #                     Simplification methods library                           #
 #                                                                              #
 
+import logging
+
 from miasm2.expression import simplifications_common
 from miasm2.expression import simplifications_cond
 from miasm2.expression.expression_helper import fast_unify
@@ -10,6 +12,12 @@ import miasm2.expression.expression as m2_expr
 # Expression Simplifier
 # ---------------------
 
+log_exprsimp = logging.getLogger("exprsimp")
+console_handler = logging.StreamHandler()
+console_handler.setFormatter(logging.Formatter("%(levelname)-5s: %(message)s"))
+log_exprsimp.addHandler(console_handler)
+log_exprsimp.setLevel(logging.WARNING)
+
 
 class ExpressionSimplifier(object):
 
@@ -67,9 +75,15 @@ class ExpressionSimplifier(object):
         Return an Expr instance"""
 
         cls = expression.__class__
+        debug_level = log_exprsimp.level >= logging.DEBUG
         for simp_func in self.expr_simp_cb.get(cls, []):
             # Apply simplifications
+            before = expression
             expression = simp_func(self, expression)
+            after = expression
+
+            if debug_level and before != after:
+                log_exprsimp.debug("[%s] %s => %s", simp_func, before, after)
 
             # If class changes, stop to prevent wrong simplifications
             if expression.__class__ is not cls:
diff --git a/miasm2/expression/simplifications_common.py b/miasm2/expression/simplifications_common.py
index 02b43c4b..ccb97cb3 100644
--- a/miasm2/expression/simplifications_common.py
+++ b/miasm2/expression/simplifications_common.py
@@ -196,26 +196,44 @@ def simp_cst_propagation(e_s, expr):
         args[1].arg == args[0].size):
         return args[0]
 
-    # A <<< X <<< Y => A <<< (X+Y) (ou <<< >>>)
+    # (A <<< X) <<< Y => A <<< (X+Y) (or <<< >>>) if X + Y does not overflow
     if (op_name in ['<<<', '>>>'] and
         args[0].is_op() and
         args[0].op in ['<<<', '>>>']):
-        op1 = op_name
-        op2 = args[0].op
-        if op1 == op2:
-            op_name = op1
-            args1 = args[0].args[1] + args[1]
-        else:
-            op_name = op2
-            args1 = args[0].args[1] - args[1]
+        A = args[0].args[0]
+        X = args[0].args[1]
+        Y = args[1]
+        if op_name != args[0].op and e_s(X - Y) == ExprInt(0, X.size):
+            return args[0].args[0]
+        elif X.is_int() and Y.is_int():
+            new_X = int(X) % expr.size
+            new_Y = int(Y) % expr.size
+            if op_name == args[0].op:
+                rot = (new_X + new_Y) % expr.size
+                op = op_name
+            else:
+                rot = new_Y - new_X
+                op = op_name
+                if rot < 0:
+                    rot = - rot
+                    op = {">>>": "<<<", "<<<": ">>>"}[op_name]
+            args = [A, ExprInt(rot, expr.size)]
+            op_name = op
 
-        args0 = args[0].args[0]
-        args = [args0, args1]
+        else:
+            # Do not consider this case, too tricky (overflow on addition /
+            # substraction)
+            pass
 
-    # A >> X >> Y  =>  A >> (X+Y)
+    # A >> X >> Y  =>  A >> (X+Y) if X + Y does not overflow
+    # To be sure, only consider the simplification when X.msb and Y.msb are 0
     if (op_name in ['<<', '>>'] and
         args[0].is_op(op_name)):
-        args = [args[0].args[0], args[0].args[1] + args[1]]
+        X = args[0].args[1]
+        Y = args[1]
+        if (e_s(X.msb()) == ExprInt(0, 1) and
+            e_s(Y.msb()) == ExprInt(0, 1)):
+            args = [args[0].args[0], X + Y]
 
     # ((A & A.mask)
     if op_name == "&" and args[-1] == expr.mask:
@@ -327,7 +345,7 @@ def simp_cond_op_int(e_s, expr):
     "Extract conditions from operations"
 
 
-    # x?a:b + x?c:d + e => x?(a+b+e:c+d+e)
+    # x?a:b + x?c:d + e => x?(a+c+e:b+d+e)
     if not expr.op in ["+", "|", "^", "&", "*", '<<', '>>', 'a>>']:
         return expr
     if len(expr.args) < 2:
@@ -360,6 +378,14 @@ def simp_cond_factor(e_s, expr):
         return expr
     if len(expr.args) < 2:
         return expr
+
+    if expr.op in ['>>', '<<', 'a>>']:
+        assert len(expr.args) == 2
+
+    # Note: the following code is correct for non-commutative operation only if
+    # there is 2 arguments. Otherwise, the order is not conserved
+
+    # Regroup sub-expression by similar conditions
     conds = {}
     not_conds = []
     multi_cond = False
@@ -375,7 +401,9 @@ def simp_cond_factor(e_s, expr):
         conds[cond].append(arg)
     if not multi_cond:
         return expr
-    c_out = not_conds[:]
+
+    # Rebuild the new expression
+    c_out = not_conds
     for cond, vals in conds.items():
         new_src1 = [x.src1 for x in vals]
         new_src2 = [x.src2 for x in vals]
diff --git a/miasm2/ir/translators/z3_ir.py b/miasm2/ir/translators/z3_ir.py
index d8b550d9..d33764fb 100644
--- a/miasm2/ir/translators/z3_ir.py
+++ b/miasm2/ir/translators/z3_ir.py
@@ -151,6 +151,19 @@ class TranslatorZ3(Translator):
         src2 = self.from_expr(expr.src2)
         return z3.If(cond != 0, src1, src2)
 
+    def _abs(self, z3_value):
+        return z3.If(z3_value >= 0,z3_value,-z3_value)
+
+    def _idivC(self, num, den):
+        """Divide (signed) @num by @den (z3 values) as C would
+        See modint.__div__ for implementation choice
+        """
+        result_sign = z3.If(num * den >= 0,
+                            z3.BitVecVal(1, num.size()),
+                            z3.BitVecVal(-1, num.size()),
+        )
+        return z3.UDiv(self._abs(num), self._abs(den)) * result_sign
+
     def from_ExprOp(self, expr):
         args = map(self.from_expr, expr.args)
         res = args[0]
@@ -168,11 +181,11 @@ class TranslatorZ3(Translator):
                 elif expr.op == ">>>":
                     res = z3.RotateRight(res, arg)
                 elif expr.op == "idiv":
-                    res = res / arg
+                    res = self._idivC(res, arg)
                 elif expr.op == "udiv":
                     res = z3.UDiv(res, arg)
                 elif expr.op == "imod":
-                    res = res % arg
+                    res = res - (arg * (self._idivC(res, arg)))
                 elif expr.op == "umod":
                     res = z3.URem(res, arg)
                 else:
diff --git a/test/expression/simplifications.py b/test/expression/simplifications.py
index e4c3f2e9..0c516a8e 100644
--- a/test/expression/simplifications.py
+++ b/test/expression/simplifications.py
@@ -2,11 +2,69 @@
 # Expression simplification regression tests  #
 #
 from pdb import pm
+from argparse import ArgumentParser
+import logging
+
 from miasm2.expression.expression import *
-from miasm2.expression.expression_helper import expr_cmpu, expr_cmps
-from miasm2.expression.simplifications import expr_simp, ExpressionSimplifier
+from miasm2.expression.simplifications import expr_simp, ExpressionSimplifier, log_exprsimp
 from miasm2.expression.simplifications_cond import ExprOp_inf_signed, ExprOp_inf_unsigned, ExprOp_equal
 
+parser = ArgumentParser("Expression simplification regression tests")
+parser.add_argument("--z3", action="store_true", help="Enable check against z3")
+parser.add_argument("-v", "--verbose", action="store_true",
+                    help="Verbose simplify")
+args = parser.parse_args()
+
+if args.verbose:
+    log_exprsimp.setLevel(logging.DEBUG)
+
+# Additionnal imports and definitions
+if args.z3:
+    import z3
+    from miasm2.ir.translators import Translator
+    trans = Translator.to_language("z3")
+
+    def check(expr_in, expr_out):
+        """Check that expr_in is always equals to expr_out"""
+        print "Ensure %s = %s" % (expr_in, expr_out)
+        solver = z3.Solver()
+        solver.add(trans.from_expr(expr_in) != trans.from_expr(expr_out))
+
+        result = solver.check()
+
+        if result != z3.unsat:
+            print "ERROR: a counter-example has been founded:"
+            model = solver.model()
+            print model
+
+            print "Reinjecting in the simplifier:"
+            to_rep = {}
+            expressions = expr_in.get_r().union(expr_out.get_r())
+            for expr in expressions:
+                value = model.eval(trans.from_expr(expr))
+                if hasattr(value, "as_long"):
+                    new_val = ExprInt(value.as_long(), expr.size)
+                else:
+                    raise RuntimeError("Unable to reinject %r" % value)
+
+                to_rep[expr] = new_val
+
+            new_expr_in = expr_in.replace_expr(to_rep)
+            new_expr_out = expr_out.replace_expr(to_rep)
+
+            print "Check %s = %s" % (new_expr_in, new_expr_out)
+            simp_in = expr_simp(new_expr_in)
+            simp_out =  expr_simp(new_expr_out)
+            print "[%s] %s = %s" % (simp_in == simp_out, simp_in, simp_out)
+
+            # Either the simplification does not stand, either the test is wrong
+            raise RuntimeError("Bad simplification")
+
+else:
+    # Dummy 'check' method to avoid checking the '--z3' argument each time
+    check = lambda expr_in, expr_out: None
+
+
 # Define example objects
 a = ExprId('a', 32)
 b = ExprId('b', 32)
@@ -15,6 +73,22 @@ d = ExprId('d', 32)
 e = ExprId('e', 32)
 f = ExprId('f', size=64)
 
+b_msb_null = b[:31].zeroExtend(32)
+c_msb_null = c[:31].zeroExtend(32)
+
+a31 = ExprId('a31', 31)
+b31 = ExprId('b31', 31)
+c31 = ExprId('c31', 31)
+b31_msb_null = ExprId('b31', 31)[:30].zeroExtend(31)
+c31_msb_null = ExprId('c31', 31)[:30].zeroExtend(31)
+
+a8 = ExprId('a8', 8)
+b8 = ExprId('b8', 8)
+c8 = ExprId('c8', 8)
+d8 = ExprId('d8', 8)
+e8 = ExprId('e8', 8)
+
+
 m = ExprMem(a)
 s = a[:8]
 
@@ -49,17 +123,35 @@ to_test = [(ExprInt(1, 32) - ExprInt(1, 32), ExprInt(0, 32)),
            (ExprOp('>>>', a, ExprInt(32, 32)), a),
            (ExprOp('>>>', a, ExprInt(0, 32)), a),
            (ExprOp('<<', a, ExprInt(0, 32)), a),
+           (ExprOp('<<<', a31, ExprInt(31, 31)), a31),
+           (ExprOp('>>>', a31, ExprInt(31, 31)), a31),
+           (ExprOp('>>>', a31, ExprInt(0, 31)), a31),
+           (ExprOp('<<', a31, ExprInt(0, 31)), a31),
+
+           (ExprOp('<<<', a31, ExprOp('<<<', b31, c31)),
+            ExprOp('<<<', a31, ExprOp('<<<', b31, c31))),
+           (ExprOp('<<<', ExprOp('>>>', a31, b31), c31),
+            ExprOp('<<<', ExprOp('>>>', a31, b31), c31)),
+           (ExprOp('>>>', ExprOp('<<<', a31, b31), c31),
+            ExprOp('>>>', ExprOp('<<<', a31, b31), c31)),
+           (ExprOp('>>>', ExprOp('<<<', a31, b31), b31),
+            a31),
+           (ExprOp('<<<', ExprOp('>>>', a31, b31), b31),
+            a31),
+           (ExprOp('>>>', ExprOp('>>>', a31, b31), b31),
+            ExprOp('>>>', ExprOp('>>>', a31, b31), b31)),
+           (ExprOp('<<<', ExprOp('<<<', a31, b31), b31),
+            ExprOp('<<<', ExprOp('<<<', a31, b31), b31)),
+
+           (ExprOp('>>>', ExprOp('<<<', a31, ExprInt(0x1234, 31)), ExprInt(0x1111, 31)),
+            ExprOp('>>>', a31, ExprInt(0x13, 31))),
+           (ExprOp('<<<', ExprOp('>>>', a31, ExprInt(0x1234, 31)), ExprInt(0x1111, 31)),
+            ExprOp('<<<', a31, ExprInt(0x13, 31))),
+           (ExprOp('>>>', ExprOp('<<<', a31, ExprInt(-1, 31)), ExprInt(0x1111, 31)),
+            ExprOp('>>>', a31, ExprInt(0x1c, 31))),
+           (ExprOp('<<<', ExprOp('>>>', a31, ExprInt(-1, 31)), ExprInt(0x1111, 31)),
+            ExprOp('<<<', a31, ExprInt(0x1c, 31))),
 
-           (ExprOp('<<<', a, ExprOp('<<<', b, c)),
-            ExprOp('<<<', a, ExprOp('<<<', b, c))),
-           (ExprOp('<<<', ExprOp('<<<', a, b), c),
-            ExprOp('<<<', a, (b+c))),
-           (ExprOp('<<<', ExprOp('>>>', a, b), c),
-            ExprOp('>>>', a, (b-c))),
-           (ExprOp('>>>', ExprOp('<<<', a, b), c),
-            ExprOp('<<<', a, (b-c))),
-           (ExprOp('>>>', ExprOp('<<<', a, b), b),
-            a),
            (ExprOp(">>>", ExprInt(0x1000, 16), ExprInt(0x11, 16)),
             ExprInt(0x800, 16)),
            (ExprOp("<<<", ExprInt(0x1000, 16), ExprInt(0x11, 16)),
@@ -134,9 +226,9 @@ to_test = [(ExprInt(1, 32) - ExprInt(1, 32), ExprInt(0, 32)),
     (ExprOp('*', a, b, c, ExprInt(0x12, 32))[0:17],
      ExprOp(
      '*', a[0:17], b[0:17], c[0:17], ExprInt(0x12, 17))),
-    (ExprOp('*', a, ExprInt(0x0, 32)),
+    (ExprOp('*', a, b, ExprInt(0x0, 32)),
      ExprInt(0x0, 32)),
-    (ExprOp('&', a, ExprInt(0x0, 32)),
+    (ExprOp('&', a, b, ExprInt(0x0, 32)),
      ExprInt(0x0, 32)),
     (ExprOp('*', a, ExprInt(0xffffffff, 32)),
      -a),
@@ -144,16 +236,14 @@ to_test = [(ExprInt(1, 32) - ExprInt(1, 32), ExprInt(0, 32)),
      ExprOp('*', a, b, c, ExprInt(0x12, 32))),
     (ExprOp('*', -a, -b, -c, ExprInt(0x12, 32)),
      - ExprOp('*', a, b, c, ExprInt(0x12, 32))),
-     (ExprOp('**', ExprInt(2, 32), ExprInt(8, 32)), ExprInt(0x100, 32)),
-     (ExprInt(2, 32)**ExprInt(8, 32), ExprInt(256, 32)),
     (a | ExprInt(0xffffffff, 32),
      ExprInt(0xffffffff, 32)),
     (ExprCond(a, ExprInt(1, 32), ExprInt(2, 32)) * ExprInt(4, 32),
      ExprCond(a, ExprInt(4, 32), ExprInt(8, 32))),
     (ExprCond(a, b, c) + ExprCond(a, d, e),
      ExprCond(a, b + d, c + e)),
-    (ExprCond(a, b, c) * ExprCond(a, d, e),
-     ExprCond(a, b * d, c * e)),
+    (ExprCond(a8, b8, c8) * ExprCond(a8, d8, e8),
+     ExprCond(a8, b8 * d8, c8 * e8)),
 
     (ExprCond(a, ExprInt(8, 32), ExprInt(4, 32)) >> ExprInt(1, 32),
      ExprCond(a, ExprInt(4, 32), ExprInt(2, 32))),
@@ -268,37 +358,37 @@ to_test = [(ExprInt(1, 32) - ExprInt(1, 32), ExprInt(0, 32)),
      a[:16]),
     ((a << ExprInt(16, 32))[24:32],
      a[8:16]),
-    (expr_cmpu(ExprInt(0, 32), ExprInt(0, 32)),
+    (expr_is_unsigned_greater(ExprInt(0, 32), ExprInt(0, 32)),
      ExprInt(0, 1)),
-    (expr_cmpu(ExprInt(10, 32), ExprInt(0, 32)),
+    (expr_is_unsigned_greater(ExprInt(10, 32), ExprInt(0, 32)),
      ExprInt(1, 1)),
-    (expr_cmpu(ExprInt(10, 32), ExprInt(5, 32)),
+    (expr_is_unsigned_greater(ExprInt(10, 32), ExprInt(5, 32)),
      ExprInt(1, 1)),
-    (expr_cmpu(ExprInt(5, 32), ExprInt(10, 32)),
+    (expr_is_unsigned_greater(ExprInt(5, 32), ExprInt(10, 32)),
      ExprInt(0, 1)),
-    (expr_cmpu(ExprInt(-1, 32), ExprInt(0, 32)),
+    (expr_is_unsigned_greater(ExprInt(-1, 32), ExprInt(0, 32)),
      ExprInt(1, 1)),
-    (expr_cmpu(ExprInt(-1, 32), ExprInt(-1, 32)),
+    (expr_is_unsigned_greater(ExprInt(-1, 32), ExprInt(-1, 32)),
      ExprInt(0, 1)),
-    (expr_cmpu(ExprInt(0, 32), ExprInt(-1, 32)),
+    (expr_is_unsigned_greater(ExprInt(0, 32), ExprInt(-1, 32)),
      ExprInt(0, 1)),
-    (expr_cmps(ExprInt(0, 32), ExprInt(0, 32)),
+    (expr_is_signed_greater(ExprInt(0, 32), ExprInt(0, 32)),
      ExprInt(0, 1)),
-    (expr_cmps(ExprInt(10, 32), ExprInt(0, 32)),
+    (expr_is_signed_greater(ExprInt(10, 32), ExprInt(0, 32)),
      ExprInt(1, 1)),
-    (expr_cmps(ExprInt(10, 32), ExprInt(5, 32)),
+    (expr_is_signed_greater(ExprInt(10, 32), ExprInt(5, 32)),
      ExprInt(1, 1)),
-    (expr_cmps(ExprInt(5, 32), ExprInt(10, 32)),
+    (expr_is_signed_greater(ExprInt(5, 32), ExprInt(10, 32)),
      ExprInt(0, 1)),
-    (expr_cmps(ExprInt(-1, 32), ExprInt(0, 32)),
+    (expr_is_signed_greater(ExprInt(-1, 32), ExprInt(0, 32)),
      ExprInt(0, 1)),
-    (expr_cmps(ExprInt(-1, 32), ExprInt(-1, 32)),
+    (expr_is_signed_greater(ExprInt(-1, 32), ExprInt(-1, 32)),
      ExprInt(0, 1)),
-    (expr_cmps(ExprInt(0, 32), ExprInt(-1, 32)),
+    (expr_is_signed_greater(ExprInt(0, 32), ExprInt(-1, 32)),
      ExprInt(1, 1)),
-    (expr_cmps(ExprInt(-5, 32), ExprInt(-10, 32)),
+    (expr_is_signed_greater(ExprInt(-5, 32), ExprInt(-10, 32)),
      ExprInt(1, 1)),
-    (expr_cmps(ExprInt(-10, 32), ExprInt(-5, 32)),
+    (expr_is_signed_greater(ExprInt(-10, 32), ExprInt(-5, 32)),
      ExprInt(0, 1)),
 
     (ExprOp("idiv", ExprInt(0x0123, 16), ExprInt(0xfffb, 16))[:8],
@@ -312,21 +402,26 @@ to_test = [(ExprInt(1, 32) - ExprInt(1, 32), ExprInt(0, 32)),
                  ExprInt(0x0321, 16))),
     (ExprCompose(ExprCond(a, i1, i0), ExprCond(a, i1, i2)),
      ExprCond(a, ExprInt(0x100000001, 64), ExprInt(0x200000000, 64))),
-    ((ExprMem(ExprCond(a, b, c)),ExprCond(a, ExprMem(b), ExprMem(c)))),
+    ((ExprMem(ExprCond(a, b, c), 4),ExprCond(a, ExprMem(b, 4), ExprMem(c, 4)))),
     (ExprCond(a, i0, i1) + ExprCond(a, i0, i1), ExprCond(a, i0, i2)),
 
+    (a << b << c, a << b << c), # Left unmodified
+    (a << b_msb_null << c_msb_null,
+     a << (b_msb_null + c_msb_null)),
+    (a >> b >> c, a >> b >> c), # Left unmodified
+    (a >> b_msb_null >> c_msb_null,
+     a >> (b_msb_null + c_msb_null)),
 ]
 
-for e, e_check in to_test[:]:
-    #
+for e_input, e_check in to_test:
     print "#" * 80
-    # print str(e), str(e_check)
-    e_new = expr_simp(e)
-    print "original: ", str(e), "new: ", str(e_new)
+    e_new = expr_simp(e_input)
+    print "original: ", str(e_input), "new: ", str(e_new)
     rez = e_new == e_check
     if not rez:
         raise ValueError(
-            'bug in expr_simp simp(%s) is %s and should be %s' % (e, e_new, e_check))
+            'bug in expr_simp simp(%s) is %s and should be %s' % (e_input, e_new, e_check))
+    check(e_input, e_check)
 
 # Test conds
 
@@ -356,18 +451,46 @@ expr_simp_cond = ExpressionSimplifier()
 expr_simp.enable_passes(ExpressionSimplifier.PASS_COND)
 
 
-for e, e_check in to_test[:]:
-    #
+for e_input, e_check in to_test:
     print "#" * 80
     e_check = expr_simp(e_check)
-    # print str(e), str(e_check)
-    e_new = expr_simp(e)
-    print "original: ", str(e), "new: ", str(e_new)
+    e_new = expr_simp(e_input)
+    print "original: ", str(e_input), "new: ", str(e_new)
     rez = e_new == e_check
     if not rez:
         raise ValueError(
-            'bug in expr_simp simp(%s) is %s and should be %s' % (e, e_new, e_check))
-
+            'bug in expr_simp simp(%s) is %s and should be %s' % (e_input, e_new, e_check))
+
+if args.z3:
+    # This check is done on 32 bits, but the size is not use by Miasm formulas, so
+    # it should be OK for any size > 0
+    x1 = ExprId("x1", 32)
+    x2 = ExprId("x2", 32)
+    i1_tmp = ExprInt(1, 1)
+
+    x1_z3 = trans.from_expr(x1)
+    x2_z3 = trans.from_expr(x2)
+    i1_z3 = trans.from_expr(i1_tmp)
+
+    # (Assumptions, function(arg1, arg2) -> True/False (= i1/i0) to check)
+    tests = [
+        (x1_z3 == x2_z3, expr_is_equal),
+        (x1_z3 != x2_z3, expr_is_not_equal),
+        (z3.UGT(x1_z3, x2_z3), expr_is_unsigned_greater),
+        (z3.UGE(x1_z3, x2_z3), expr_is_unsigned_greater_or_equal),
+        (z3.ULT(x1_z3, x2_z3), expr_is_unsigned_lower),
+        (z3.ULE(x1_z3, x2_z3), expr_is_unsigned_lower_or_equal),
+        (x1_z3 > x2_z3, expr_is_signed_greater),
+        (x1_z3 >= x2_z3, expr_is_signed_greater_or_equal),
+        (x1_z3 < x2_z3, expr_is_signed_lower),
+        (x1_z3 <= x2_z3, expr_is_signed_lower_or_equal),
+    ]
+
+    for assumption, func in tests:
+        solver = z3.Solver()
+        solver.add(assumption)
+        solver.add(trans.from_expr(func(x1, x2)) != i1_z3)
+        assert solver.check() == z3.unsat
 
 
 x = ExprId('x', 32)
diff --git a/test/test_all.py b/test/test_all.py
index 04aca62e..6aa2a97e 100755
--- a/test/test_all.py
+++ b/test/test_all.py
@@ -249,6 +249,9 @@ for script in ["modint.py",
                "expr_cmp.py",
                ]:
     testset += RegressionTest([script], base_dir="expression")
+testset += RegressionTest(["simplifications.py", "--z3"],
+                          base_dir="expression",
+                          tags=[TAGS["z3"]])
 
 ## ObjC/CHandler
 testset += RegressionTest(["test_chandler.py"], base_dir="expr_type",