about summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--miasm2/expression/simplifications.py31
-rw-r--r--miasm2/expression/simplifications_common.py132
-rw-r--r--miasm2/expression/simplifications_cond.py54
-rw-r--r--miasm2/ir/translators/z3_ir.py30
-rw-r--r--test/expression/simplifications.py63
5 files changed, 218 insertions, 92 deletions
diff --git a/miasm2/expression/simplifications.py b/miasm2/expression/simplifications.py
index 9114cbbe..a237a57e 100644
--- a/miasm2/expression/simplifications.py
+++ b/miasm2/expression/simplifications.py
@@ -45,16 +45,23 @@ class ExpressionSimplifier(object):
             simplifications_common.simp_double_zeroext,
             simplifications_common.simp_double_signext,
             simplifications_common.simp_zeroext_eq_cst,
+            simplifications_common.simp_ext_eq_ext,
+
+            simplifications_common.simp_cmp_int,
+            simplifications_common.simp_cmp_int_int,
+            simplifications_common.simp_ext_cst,
 
         ],
 
-        m2_expr.ExprSlice: [simplifications_common.simp_slice],
+        m2_expr.ExprSlice: [
+            simplifications_common.simp_slice,
+            simplifications_common.simp_slice_of_ext,
+        ],
         m2_expr.ExprCompose: [simplifications_common.simp_compose],
         m2_expr.ExprCond: [
             simplifications_common.simp_cond,
             # CC op
             simplifications_common.simp_cond_flag,
-            simplifications_common.simp_cond_int,
             simplifications_common.simp_cmp_int_arg,
 
             simplifications_common.simp_cond_eq_zero,
@@ -68,14 +75,18 @@ class ExpressionSimplifier(object):
     PASS_HEAVY = {}
 
     # Cond passes
-    PASS_COND = {m2_expr.ExprSlice: [simplifications_cond.expr_simp_inf_signed,
-                                     simplifications_cond.expr_simp_inf_unsigned_inversed],
-                 m2_expr.ExprOp: [simplifications_cond.exec_inf_unsigned,
-                                  simplifications_cond.exec_inf_signed,
-                                  simplifications_cond.expr_simp_inverse,
-                                  simplifications_cond.exec_equal],
-                 m2_expr.ExprCond: [simplifications_cond.expr_simp_equal]
-                 }
+    PASS_COND = {
+        m2_expr.ExprSlice: [
+            simplifications_cond.expr_simp_inf_signed,
+            simplifications_cond.expr_simp_inf_unsigned_inversed
+        ],
+        m2_expr.ExprOp: [
+            simplifications_cond.expr_simp_inverse,
+        ],
+        m2_expr.ExprCond: [
+            simplifications_cond.expr_simp_equal
+        ]
+    }
 
 
     # Available passes lists are:
diff --git a/miasm2/expression/simplifications_common.py b/miasm2/expression/simplifications_common.py
index 9a59fbd4..7db4e819 100644
--- a/miasm2/expression/simplifications_common.py
+++ b/miasm2/expression/simplifications_common.py
@@ -884,35 +884,34 @@ def simp_cond_flag(expr_simp, expr):
     return expr
 
 
-def simp_cond_int(expr_simp, expr):
-    if (expr.cond.is_op(TOK_EQUAL) and
-          expr.cond.args[1].is_int() and
-          expr.cond.args[0].is_compose() and
-          len(expr.cond.args[0].args) == 2 and
-          expr.cond.args[0].args[1].is_int(0)):
+def simp_cmp_int(expr_simp, expr):
+    # ({X, 0} == int) => X == int[:]
+    # X + int1 == int2 => X == int2-int1
+    if (expr.is_op(TOK_EQUAL) and
+          expr.args[1].is_int() and
+          expr.args[0].is_compose() and
+          len(expr.args[0].args) == 2 and
+          expr.args[0].args[1].is_int(0)):
         # ({X, 0} == int) => X == int[:]
-        src = expr.cond.args[0].args[0]
-        int_val = int(expr.cond.args[1])
+        src = expr.args[0].args[0]
+        int_val = int(expr.args[1])
         new_int = ExprInt(int_val, src.size)
         expr = expr_simp(
-            ExprCond(
-                ExprOp(TOK_EQUAL, src, new_int),
-                expr.src1,
-                expr.src2)
+            ExprOp(TOK_EQUAL, src, new_int)
         )
-    elif (expr.cond.is_op() and
-          expr.cond.op in [
+    elif (expr.is_op() and
+          expr.op in [
               TOK_EQUAL,
-              TOK_INF_SIGNED,
-              TOK_INF_EQUAL_SIGNED,
-              TOK_INF_UNSIGNED,
-              TOK_INF_EQUAL_UNSIGNED
           ] and
-          expr.cond.args[1].is_int() and
-          expr.cond.args[0].is_op("+") and
-          expr.cond.args[0].args[-1].is_int()):
+          expr.args[1].is_int() and
+          expr.args[0].is_op("+") and
+          expr.args[0].args[-1].is_int()):
         # X + int1 == int2 => X == int2-int1
-        left, right = expr.cond.args
+        # WARNING:
+        # X - 0x10 <=u 0x20 gives X in [0x10 0x30]
+        # which is not equivalet to A <=u 0x10
+
+        left, right = expr.args
         left, int_diff = left.args[:-1], left.args[-1]
         if len(left) == 1:
             left = left[0]
@@ -920,10 +919,7 @@ def simp_cond_int(expr_simp, expr):
             left = ExprOp('+', *left)
         new_int = expr_simp(right - int_diff)
         expr = expr_simp(
-            ExprCond(
-                ExprOp(expr.cond.op, left, new_int),
-                expr.src1,
-                expr.src2)
+            ExprOp(expr.op, left, new_int),
         )
     return expr
 
@@ -1047,6 +1043,20 @@ def simp_zeroext_eq_cst(expr_s, expr):
         return ExprInt(0, 1)
     return ExprOp(TOK_EQUAL, src, ExprInt(int(arg2), src.size))
 
+def simp_ext_eq_ext(expr_s, expr):
+    # A.zeroExt(X) == B.zeroExt(X) => A == B
+    # A.signExt(X) == B.signExt(X) => A == B
+    if not expr.is_op(TOK_EQUAL):
+        return expr
+    arg1, arg2 = expr.args
+    if (not ((arg1.is_op() and arg1.op.startswith("zeroExt") and
+              arg2.is_op() and arg2.op.startswith("zeroExt")) or
+             (arg1.is_op() and arg1.op.startswith("signExt") and
+               arg2.is_op() and arg2.op.startswith("signExt")))):
+        return expr
+    if arg1.args[0].size != arg2.args[0].size:
+        return expr
+    return ExprOp(TOK_EQUAL, arg1.args[0], arg2.args[0])
 
 def simp_cond_eq_zero(expr_s, expr):
     # (X == 0)?(A:B) => X?(B:A)
@@ -1058,3 +1068,73 @@ def simp_cond_eq_zero(expr_s, expr):
         return expr
     new_expr = ExprCond(arg1, expr.src2, expr.src1)
     return new_expr
+
+
+def simp_cmp_int_int(expr_s, expr):
+    # IntA <s IntB => int
+    # IntA <u IntB => int
+    # IntA <=s IntB => int
+    # IntA <=u IntB => int
+    # IntA == IntB => int
+    if expr.op not in [
+            TOK_EQUAL,
+            TOK_INF_SIGNED, TOK_INF_UNSIGNED,
+            TOK_INF_EQUAL_SIGNED, TOK_INF_EQUAL_UNSIGNED,
+    ]:
+        return expr
+    if not all(arg.is_int() for arg in expr.args):
+        return expr
+    int_a, int_b = expr.args
+    if expr.is_op(TOK_EQUAL):
+        if int_a == int_b:
+            return ExprInt(1, 1)
+        else:
+            return ExprInt(0, 1)
+
+    if expr.op in [TOK_INF_SIGNED, TOK_INF_EQUAL_SIGNED]:
+        int_a = int(mod_size2int[int_a.size](int(int_a)))
+        int_b = int(mod_size2int[int_b.size](int(int_b)))
+    else:
+        int_a = int(mod_size2uint[int_a.size](int(int_a)))
+        int_b = int(mod_size2uint[int_b.size](int(int_b)))
+
+    if expr.op in [TOK_INF_SIGNED, TOK_INF_UNSIGNED]:
+        ret = int_a < int_b
+    else:
+        ret = int_a <= int_b
+
+    if ret:
+        ret = 1
+    else:
+        ret = 0
+    return ExprInt(ret, 1)
+
+
+def simp_ext_cst(expr_s, expr):
+    # Int.zeroExt(X) => Int
+    # Int.signExt(X) => Int
+    if not (expr.op.startswith("zeroExt") or expr.op.startswith("signExt")):
+        return expr
+    arg = expr.args[0]
+    if not arg.is_int():
+        return expr
+    if expr.op.startswith("zeroExt"):
+        ret = int(arg)
+    else:
+        ret = int(mod_size2int[arg.size](int(arg)))
+    ret = ExprInt(ret, expr.size)
+    return ret
+
+
+def simp_slice_of_ext(expr_s, expr):
+    # zeroExt(X)[0:size(X)] => X
+    if expr.start != 0:
+        return expr
+    if not expr.arg.is_op():
+        return expr
+    if not expr.arg.op.startswith("zeroExt"):
+        return expr
+    arg = expr.arg.args[0]
+    if arg.size != expr.size:
+        return expr
+    return arg
diff --git a/miasm2/expression/simplifications_cond.py b/miasm2/expression/simplifications_cond.py
index 6bdc810f..f6b1ea8b 100644
--- a/miasm2/expression/simplifications_cond.py
+++ b/miasm2/expression/simplifications_cond.py
@@ -176,57 +176,3 @@ def expr_simp_equal(expr_simp, e):
         return e
 
     return ExprOp_equal(r[jok1], expr_simp(-r[jok2]))
-
-# Compute conditions
-
-def exec_inf_unsigned(expr_simp, e):
-    "Compute x <u y"
-    if e.op != m2_expr.TOK_INF_UNSIGNED:
-        return e
-
-    arg1, arg2 = e.args
-
-    if isinstance(arg1, m2_expr.ExprInt) and isinstance(arg2, m2_expr.ExprInt):
-        return m2_expr.ExprInt(1, 1) if (arg1.arg < arg2.arg) else m2_expr.ExprInt(0, 1)
-    else:
-        return e
-
-
-def __comp_signed(arg1, arg2):
-    """Return ExprInt(1, 1) if arg1 <s arg2 else ExprInt(0, 1)
-    @arg1, @arg2: ExprInt"""
-
-    val1 = int(arg1)
-    if val1 >> (arg1.size - 1) == 1:
-        val1 = - ((int(arg1.mask) ^ val1) + 1)
-
-    val2 = int(arg2)
-    if val2 >> (arg2.size - 1) == 1:
-        val2 = - ((int(arg2.mask) ^ val2) + 1)
-
-    return m2_expr.ExprInt(1, 1) if (val1 < val2) else m2_expr.ExprInt(0, 1)
-
-def exec_inf_signed(expr_simp, e):
-    "Compute x <s y"
-
-    if e.op != m2_expr.TOK_INF_SIGNED:
-        return e
-
-    arg1, arg2 = e.args
-
-    if isinstance(arg1, m2_expr.ExprInt) and isinstance(arg2, m2_expr.ExprInt):
-        return __comp_signed(arg1, arg2)
-    else:
-        return e
-
-def exec_equal(expr_simp, e):
-    "Compute x == y"
-
-    if e.op != m2_expr.TOK_EQUAL:
-        return e
-
-    arg1, arg2 = e.args
-    if isinstance(arg1, m2_expr.ExprInt) and isinstance(arg2, m2_expr.ExprInt):
-        return m2_expr.ExprInt(1, 1) if (arg1.arg == arg2.arg) else m2_expr.ExprInt(0, 1)
-    else:
-        return e
diff --git a/miasm2/ir/translators/z3_ir.py b/miasm2/ir/translators/z3_ir.py
index 55952180..1cc8c29d 100644
--- a/miasm2/ir/translators/z3_ir.py
+++ b/miasm2/ir/translators/z3_ir.py
@@ -205,6 +205,36 @@ class TranslatorZ3(Translator):
                     res = res - (arg * (self._idivC(res, arg)))
                 elif expr.op == "umod":
                     res = z3.URem(res, arg)
+                elif expr.op == "==":
+                    res = z3.If(
+                        args[0] == args[1],
+                        z3.BitVecVal(1, 1),
+                        z3.BitVecVal(0, 1)
+                    )
+                elif expr.op == "<u":
+                    res = z3.If(
+                        z3.ULT(args[0], args[1]),
+                        z3.BitVecVal(1, 1),
+                        z3.BitVecVal(0, 1)
+                    )
+                elif expr.op == "<s":
+                    res = z3.If(
+                        z3.SLT(args[0], args[1]),
+                        z3.BitVecVal(1, 1),
+                        z3.BitVecVal(0, 1)
+                    )
+                elif expr.op == "<=u":
+                    res = z3.If(
+                        z3.ULE(args[0], args[1]),
+                        z3.BitVecVal(1, 1),
+                        z3.BitVecVal(0, 1)
+                    )
+                elif expr.op == "<=s":
+                    res = z3.If(
+                        z3.SLE(args[0], args[1]),
+                        z3.BitVecVal(1, 1),
+                        z3.BitVecVal(0, 1)
+                    )
                 else:
                     raise NotImplementedError("Unsupported OP yet: %s" % expr.op)
         elif expr.op == 'parity':
diff --git a/test/expression/simplifications.py b/test/expression/simplifications.py
index 68dc0437..741d6adb 100644
--- a/test/expression/simplifications.py
+++ b/test/expression/simplifications.py
@@ -97,6 +97,7 @@ s = a[:8]
 i0 = ExprInt(0, 32)
 i1 = ExprInt(1, 32)
 i2 = ExprInt(2, 32)
+im1 = ExprInt(-1, 32)
 icustom = ExprInt(0x12345678, 32)
 cc = ExprCond(a, b, c)
 
@@ -452,9 +453,65 @@ for e_input, e_check in to_test:
     rez = e_new == e_check
     if not rez:
         raise ValueError(
-            'bug in expr_simp_explicit simp(%s) is %s and should be %s' % (e_input, e_new, e_check))
+            'bug in expr_simp_explicit simp(%s) is %s and should be %s' % (e_input, e_new, e_check)
+        )
     check(e_input, e_check)
 
+
+# Test high level op
+to_test = [
+    (ExprOp(TOK_EQUAL, a+i2, i1), ExprOp(TOK_EQUAL, a+i1, i0)),
+    (ExprOp(TOK_INF_SIGNED, a+i2, i1), ExprOp(TOK_INF_SIGNED, a+i2, i1)),
+    (ExprOp(TOK_INF_UNSIGNED, a+i2, i1), ExprOp(TOK_INF_UNSIGNED, a+i2, i1)),
+
+    (
+        ExprOp(TOK_EQUAL, ExprCompose(a8, ExprInt(0, 24)), im1),
+        ExprOp(TOK_EQUAL, a8, ExprInt(0xFF, 8))
+    ),
+
+    (ExprOp(TOK_INF_SIGNED, i1, i2), ExprInt(1, 1)),
+    (ExprOp(TOK_INF_UNSIGNED, i1, i2), ExprInt(1, 1)),
+    (ExprOp(TOK_INF_EQUAL_SIGNED, i1, i2), ExprInt(1, 1)),
+    (ExprOp(TOK_INF_EQUAL_UNSIGNED, i1, i2), ExprInt(1, 1)),
+
+    (ExprOp(TOK_INF_SIGNED, i2, i1), ExprInt(0, 1)),
+    (ExprOp(TOK_INF_UNSIGNED, i2, i1), ExprInt(0, 1)),
+    (ExprOp(TOK_INF_EQUAL_SIGNED, i2, i1), ExprInt(0, 1)),
+    (ExprOp(TOK_INF_EQUAL_UNSIGNED, i2, i1), ExprInt(0, 1)),
+
+    (ExprOp(TOK_INF_SIGNED, i1, i1), ExprInt(0, 1)),
+    (ExprOp(TOK_INF_UNSIGNED, i1, i1), ExprInt(0, 1)),
+    (ExprOp(TOK_INF_EQUAL_SIGNED, i1, i1), ExprInt(1, 1)),
+    (ExprOp(TOK_INF_EQUAL_UNSIGNED, i1, i1), ExprInt(1, 1)),
+
+
+    (ExprOp(TOK_INF_SIGNED, im1, i1), ExprInt(1, 1)),
+    (ExprOp(TOK_INF_UNSIGNED, im1, i1), ExprInt(0, 1)),
+    (ExprOp(TOK_INF_EQUAL_SIGNED, im1, i1), ExprInt(1, 1)),
+    (ExprOp(TOK_INF_EQUAL_UNSIGNED, im1, i1), ExprInt(0, 1)),
+
+    (ExprOp(TOK_INF_SIGNED, i1, im1), ExprInt(0, 1)),
+    (ExprOp(TOK_INF_UNSIGNED, i1, im1), ExprInt(1, 1)),
+    (ExprOp(TOK_INF_EQUAL_SIGNED, i1, im1), ExprInt(0, 1)),
+    (ExprOp(TOK_INF_EQUAL_UNSIGNED, i1, im1), ExprInt(1, 1)),
+
+    (ExprOp(TOK_EQUAL, a8.zeroExtend(32), b8.zeroExtend(32)), ExprOp(TOK_EQUAL, a8, b8)),
+    (ExprOp(TOK_EQUAL, a8.signExtend(32), b8.signExtend(32)), ExprOp(TOK_EQUAL, a8, b8)),
+
+]
+
+for e_input, e_check in to_test:
+    print "#" * 80
+    e_check = expr_simp(e_check)
+    e_new = expr_simp(e_input)
+    print "original: ", str(e_input), "new: ", str(e_new)
+    rez = e_new == e_check
+    if not rez:
+        raise ValueError(
+            'bug in expr_simp simp(%s) is %s and should be %s' % (e_input, e_new, e_check)
+        )
+
+
 # Test conds
 
 to_test = [
@@ -490,7 +547,9 @@ for e_input, e_check in to_test:
     rez = e_new == e_check
     if not rez:
         raise ValueError(
-            'bug in expr_simp simp(%s) is %s and should be %s' % (e_input, e_new, e_check))
+            'bug in expr_simp simp(%s) is %s and should be %s' % (e_input, e_new, e_check)
+        )
+
 
 if args.z3:
     # This check is done on 32 bits, but the size is not use by Miasm formulas, so