about summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--miasm2/expression/simplifications.py10
-rw-r--r--miasm2/expression/simplifications_common.py284
-rw-r--r--test/expression/simplifications.py176
3 files changed, 443 insertions, 27 deletions
diff --git a/miasm2/expression/simplifications.py b/miasm2/expression/simplifications.py
index 3f50fc1a..8ea9c41f 100644
--- a/miasm2/expression/simplifications.py
+++ b/miasm2/expression/simplifications.py
@@ -49,24 +49,32 @@ class ExpressionSimplifier(object):
             simplifications_common.simp_ext_eq_ext,
 
             simplifications_common.simp_cmp_int,
+            simplifications_common.simp_sign_inf_zeroext,
             simplifications_common.simp_cmp_int_int,
             simplifications_common.simp_ext_cst,
+            simplifications_common.simp_zeroext_and_cst_eq_cst,
+            simplifications_common.simp_test_signext_inf,
+            simplifications_common.simp_test_zeroext_inf,
 
         ],
 
         m2_expr.ExprSlice: [
             simplifications_common.simp_slice,
             simplifications_common.simp_slice_of_ext,
+            simplifications_common.simp_slice_of_op_ext,
         ],
         m2_expr.ExprCompose: [simplifications_common.simp_compose],
         m2_expr.ExprCond: [
             simplifications_common.simp_cond,
+            simplifications_common.simp_cond_zeroext,
             # CC op
             simplifications_common.simp_cond_flag,
             simplifications_common.simp_cmp_int_arg,
 
             simplifications_common.simp_cond_eq_zero,
-
+            simplifications_common.simp_x_and_cst_eq_cst,
+            simplifications_common.simp_cond_logic_ext,
+            simplifications_common.simp_cond_sign_bit,
         ],
         m2_expr.ExprMem: [simplifications_common.simp_mem],
 
diff --git a/miasm2/expression/simplifications_common.py b/miasm2/expression/simplifications_common.py
index 6f0eb34a..726b7577 100644
--- a/miasm2/expression/simplifications_common.py
+++ b/miasm2/expression/simplifications_common.py
@@ -702,8 +702,8 @@ def simp_cc_conds(expr_simp, expr):
           )):
         expr = ExprCond(
             ExprOp(TOK_INF_UNSIGNED, *expr.args[0].args),
-            ExprInt(0, 1),
-            ExprInt(1, 1))
+            ExprInt(0, expr.size),
+            ExprInt(1, expr.size))
 
     elif (expr.is_op("CC_U<") and
           test_cc_eq_args(
@@ -726,8 +726,8 @@ def simp_cc_conds(expr_simp, expr):
           )):
         expr = ExprCond(
             ExprOp(TOK_INF_SIGNED, *expr.args[0].args),
-            ExprInt(0, 1),
-            ExprInt(1, 1)
+            ExprInt(0, expr.size),
+            ExprInt(1, expr.size)
         )
 
     elif (expr.is_op("CC_EQ") and
@@ -746,8 +746,8 @@ def simp_cc_conds(expr_simp, expr):
         arg = expr.args[0].args[0]
         expr = ExprCond(
             ExprOp(TOK_EQUAL,arg, ExprInt(0, arg.size)),
-            ExprInt(0, 1),
-            ExprInt(1, 1)
+            ExprInt(0, expr.size),
+            ExprInt(1, expr.size)
         )
     elif (expr.is_op("CC_NE") and
           test_cc_eq_args(
@@ -756,8 +756,8 @@ def simp_cc_conds(expr_simp, expr):
           )):
         expr = ExprCond(
             ExprOp(TOK_EQUAL, *expr.args[0].args),
-            ExprInt(0, 1),
-            ExprInt(1, 1)
+            ExprInt(0, expr.size),
+            ExprInt(1, expr.size)
         )
 
     elif (expr.is_op("CC_EQ") and
@@ -781,8 +781,8 @@ def simp_cc_conds(expr_simp, expr):
           )):
         expr = ExprCond(
             ExprOp("&", *expr.args[0].args),
-            ExprInt(0, 1),
-            ExprInt(1, 1)
+            ExprInt(0, expr.size),
+            ExprInt(1, expr.size)
         )
 
     elif (expr.is_op("CC_S>") and
@@ -794,8 +794,8 @@ def simp_cc_conds(expr_simp, expr):
           )):
         expr = ExprCond(
             ExprOp(TOK_INF_EQUAL_SIGNED, *expr.args[0].args),
-            ExprInt(0, 1),
-            ExprInt(1, 1)
+            ExprInt(0, expr.size),
+            ExprInt(1, expr.size)
         )
 
     elif (expr.is_op("CC_S>") and
@@ -806,8 +806,8 @@ def simp_cc_conds(expr_simp, expr):
           expr.args[1].is_int(0)):
         expr = ExprCond(
             ExprOp(TOK_INF_EQUAL_SIGNED, *expr.args[0].args),
-            ExprInt(0, 1),
-            ExprInt(1, 1)
+            ExprInt(0, expr.size),
+            ExprInt(1, expr.size)
         )
 
 
@@ -820,8 +820,8 @@ def simp_cc_conds(expr_simp, expr):
           )):
         expr = ExprCond(
             ExprOp(TOK_INF_SIGNED, *expr.args[0].args),
-            ExprInt(0, 1),
-            ExprInt(1, 1)
+            ExprInt(0, expr.size),
+            ExprInt(1, expr.size)
         )
 
     elif (expr.is_op("CC_S<") and
@@ -865,8 +865,8 @@ def simp_cc_conds(expr_simp, expr):
           )):
         expr = ExprCond(
             ExprOp(TOK_INF_EQUAL_UNSIGNED, *expr.args[0].args),
-            ExprInt(0, 1),
-            ExprInt(1, 1)
+            ExprInt(0, expr.size),
+            ExprInt(1, expr.size)
         )
 
     elif (expr.is_op("CC_S<") and
@@ -1046,9 +1046,26 @@ def simp_zeroext_eq_cst(expr_s, expr):
     src = arg1.args[0]
     if int(arg2) > (1 << src.size):
         # Always false
-        return ExprInt(0, 1)
+        return ExprInt(0, expr.size)
     return ExprOp(TOK_EQUAL, src, ExprInt(int(arg2), src.size))
 
+def simp_cond_zeroext(expr_s, expr):
+    """
+    X.zeroExt()?(A:B) => X ? A:B
+    X.signExt()?(A:B) => X ? A:B
+    """
+    if not (
+            expr.cond.is_op() and
+            (
+                expr.cond.op.startswith("zeroExt") or
+                expr.cond.op.startswith("signExt")
+            )
+    ):
+        return expr
+
+    ret = ExprCond(expr.cond.args[0], expr.src1, expr.src2)
+    return ret
+
 def simp_ext_eq_ext(expr_s, expr):
     # A.zeroExt(X) == B.zeroExt(X) => A == B
     # A.signExt(X) == B.signExt(X) => A == B
@@ -1075,6 +1092,106 @@ def simp_cond_eq_zero(expr_s, expr):
     new_expr = ExprCond(arg1, expr.src2, expr.src1)
     return new_expr
 
+def simp_sign_inf_zeroext(expr_s, expr):
+    """
+    /!\ Ensure before: X.zeroExt(X.size) => X
+
+    X.zeroExt() <s 0 => 0
+    X.zeroExt() <=s 0 => X == 0
+
+    X.zeroExt() <s cst => X.zeroExt() <u cst (cst positive)
+    X.zeroExt() <=s cst => X.zeroExt() <=u cst (cst positive)
+
+    X.zeroExt() <s cst => 0 (cst negative)
+    X.zeroExt() <=s cst => 0 (cst negative)
+
+    """
+    if not (expr.is_op(TOK_INF_SIGNED) or expr.is_op(TOK_INF_EQUAL_SIGNED)):
+        return expr
+    arg1, arg2 = expr.args
+    if not arg2.is_int():
+        return expr
+    if not (arg1.is_op() and arg1.op.startswith("zeroExt")):
+        return expr
+    src = arg1.args[0]
+    assert src.size < arg1.size
+
+    # If cst is zero
+    if arg2.is_int(0):
+        if expr.is_op(TOK_INF_SIGNED):
+            # X.zeroExt() <s 0 => 0
+            return ExprInt(0, expr.size)
+        else:
+            # X.zeroExt() <=s 0 => X == 0
+            return ExprOp(TOK_EQUAL, src, ExprInt(0, src.size))
+
+    # cst is not zero
+    cst = int(arg2)
+    if cst & (1 << (arg2.size - 1)):
+        # cst is negative
+        return ExprInt(0, expr.size)
+    # cst is positive
+    if expr.is_op(TOK_INF_SIGNED):
+        # X.zeroExt() <s cst => X.zeroExt() <u cst (cst positive)
+        return ExprOp(TOK_INF_UNSIGNED, src, expr_s(arg2[:src.size]))
+    # X.zeroExt() <=s cst => X.zeroExt() <=u cst (cst positive)
+    return ExprOp(TOK_INF_EQUAL_UNSIGNED, src, expr_s(arg2[:src.size]))
+
+
+def simp_zeroext_and_cst_eq_cst(expr_s, expr):
+    # A.zeroExt(X) & ... & int == int => A & ... & int[:A.size] == int[:A.size]
+    if not expr.is_op(TOK_EQUAL):
+        return expr
+    arg1, arg2 = expr.args
+    if not arg2.is_int():
+        return expr
+    if not arg1.is_op('&'):
+        return expr
+    is_ok = True
+    sizes = set()
+    for arg in arg1.args:
+        if arg.is_int():
+            continue
+        if (arg.is_op() and
+            arg.op.startswith("zeroExt")):
+            sizes.add(arg.args[0].size)
+            continue
+        is_ok = False
+        break
+    if not is_ok:
+        return expr
+    if len(sizes) != 1:
+        return expr
+    size = list(sizes)[0]
+    if int(arg2) > ((1 << size) - 1):
+        return expr
+    args = [expr_s(arg[:size]) for arg in arg1.args]
+    left = ExprOp('&', *args)
+    right = expr_s(arg2[:size])
+    ret = ExprOp(TOK_EQUAL, left, right)
+    return ret
+
+
+def test_one_bit_set(arg):
+    return arg != 0  and ((arg & (arg - 1)) == 0)
+
+def simp_x_and_cst_eq_cst(expr_s, expr):
+    # (x & ... & onebitmask == onebitmask) ? A:B => (x & ... & onebitmask) ? A:B
+    cond = expr.cond
+    if not cond.is_op(TOK_EQUAL):
+        return expr
+    arg1, mask2 = cond.args
+    if not mask2.is_int():
+        return expr
+    if not test_one_bit_set(int(mask2)):
+        return expr
+    if not arg1.is_op('&'):
+        return expr
+    mask1 = arg1.args[-1]
+    if mask1 != mask2:
+        return expr
+    cond = ExprOp('&', *arg1.args)
+    return ExprCond(cond, expr.src1, expr.src2)
 
 def simp_cmp_int_int(expr_s, expr):
     # IntA <s IntB => int
@@ -1094,8 +1211,7 @@ def simp_cmp_int_int(expr_s, expr):
     if expr.is_op(TOK_EQUAL):
         if int_a == int_b:
             return ExprInt(1, 1)
-        else:
-            return ExprInt(0, 1)
+        return ExprInt(0, expr.size)
 
     if expr.op in [TOK_INF_SIGNED, TOK_INF_EQUAL_SIGNED]:
         int_a = int(mod_size2int[int_a.size](int(int_a)))
@@ -1133,17 +1249,133 @@ def simp_ext_cst(expr_s, expr):
 
 
 def simp_slice_of_ext(expr_s, expr):
-    # zeroExt(X)[0:size(X)] => X
-    if expr.start != 0:
-        return expr
+    """
+    C.zeroExt(X)[A:B] => 0 if A >= size(C)
+    C.zeroExt(X)[A:B] => C[A:B] if B <= size(C)
+    A.zeroExt(X)[0:Y] => A.zeroExt(Y)
+    """
     if not expr.arg.is_op():
         return expr
     if not expr.arg.op.startswith("zeroExt"):
         return expr
     arg = expr.arg.args[0]
-    if arg.size != expr.size:
+
+    if expr.start >= arg.size:
+        # C.zeroExt(X)[A:B] => 0 if A >= size(C)
+        return ExprInt(0, expr.size)
+    if expr.stop <= arg.size:
+        # C.zeroExt(X)[A:B] => C[A:B] if B <= size(C)
+        return arg[expr.start:expr.stop]
+    if expr.start == 0:
+        # A.zeroExt(X)[0:Y] => A.zeroExt(Y)
+        return arg.zeroExtend(expr.stop)
+    return expr
+
+def simp_slice_of_op_ext(expr_s, expr):
+    # (X.zeroExt() + ... + Int)[0:8] => X + ... + int[:]
+    if expr.start != 0:
+        return expr
+    src = expr.arg
+    if not src.is_op("+"):
+        return expr
+    is_ok = True
+    for arg in src.args:
+        if arg.is_int():
+            continue
+        if (arg.is_op() and
+            arg.op.startswith("zeroExt") and
+            arg.args[0].size == expr.stop):
+            continue
+        is_ok = False
+        break
+    if not is_ok:
+        return expr
+    args = [expr_s(arg[:expr.stop]) for arg in src.args]
+    return ExprOp("+", *args)
+
+
+def simp_cond_logic_ext(expr_s, expr):
+    # (X.zeroExt() + ... + Int) ? A:B => X + ... + int[:] ? A:B
+    cond = expr.cond
+    if not cond.is_op():
+        return expr
+    if cond.op not in ["&", "^", "|"]:
         return expr
-    return arg
+    is_ok = True
+    sizes = set()
+    for arg in cond.args:
+        if arg.is_int():
+            continue
+        if (arg.is_op() and
+            arg.op.startswith("zeroExt")):
+            sizes.add(arg.args[0].size)
+            continue
+        is_ok = False
+        break
+    if not is_ok:
+        return expr
+    if len(sizes) != 1:
+        return expr
+    size = list(sizes)[0]
+    args = [expr_s(arg[:size]) for arg in cond.args]
+    cond = ExprOp(cond.op, *args)
+    return ExprCond(cond, expr.src1, expr.src2)
+
+
+def simp_cond_sign_bit(expr_s, expr):
+    """(a & .. & 0x80000000) ? A:B => (a & ...) <s 0 ? A:B"""
+    cond = expr.cond
+    if not cond.is_op('&'):
+        return expr
+    last = cond.args[-1]
+    if not last.is_int(1 << (last.size - 1)):
+        return expr
+    zero = ExprInt(0, expr.cond.size)
+    if len(cond.args) == 2:
+        args = [cond.args[0], zero]
+    else:
+        args = [ExprOp('&', *list(cond.args[:-1])), zero]
+    cond = ExprOp(TOK_INF_SIGNED, *args)
+    return ExprCond(cond, expr.src1, expr.src2)
+
+def simp_test_signext_inf(expr_s, expr):
+    # A.signExt() <s int => A <s int[:]
+    if not (expr.is_op(TOK_INF_SIGNED) or expr.is_op(TOK_INF_EQUAL_SIGNED)):
+        return expr
+    arg, cst = expr.args
+    if not (arg.is_op() and arg.op.startswith("signExt")):
+        return expr
+    if not cst.is_int():
+        return expr
+    base = arg.args[0]
+    tmp = int(mod_size2int[cst.size](int(cst)))
+    if -(1 << (base.size - 1)) <= tmp < (1 << (base.size - 1)):
+        # Can trunc integer
+        return ExprOp(expr.op, base, expr_s(cst[:base.size]))
+    if (tmp >= (1 << (base.size - 1)) or
+        tmp < -(1 << (base.size - 1)) ):
+        return ExprInt(1, 1)
+    return expr
+
+
+def simp_test_zeroext_inf(expr_s, expr):
+    # A.zeroExt() <u int => A <u int[:]
+    if not (expr.is_op(TOK_INF_UNSIGNED) or expr.is_op(TOK_INF_EQUAL_UNSIGNED)):
+        return expr
+    arg, cst = expr.args
+    if not (arg.is_op() and arg.op.startswith("zeroExt")):
+        return expr
+    if not cst.is_int():
+        return expr
+    base = arg.args[0]
+    tmp = int(mod_size2uint[cst.size](int(cst)))
+    if 0 <= tmp < (1 << base.size):
+        # Can trunc integer
+        return ExprOp(expr.op, base, expr_s(cst[:base.size]))
+    if tmp >= (1 << base.size):
+        return ExprInt(1, 1)
+    return expr
+
 
 def simp_add_multiple(expr_s, expr):
     # X + X => 2 * X
diff --git a/test/expression/simplifications.py b/test/expression/simplifications.py
index 364456c6..4a093a98 100644
--- a/test/expression/simplifications.py
+++ b/test/expression/simplifications.py
@@ -446,6 +446,9 @@ to_test = [(ExprInt(1, 32) - ExprInt(1, 32), ExprInt(0, 32)),
     (ExprOp("signExt_16", ExprInt(0x8, 8)), ExprInt(0x8, 16)),
     (ExprOp("signExt_16", ExprInt(-0x8, 8)), ExprInt(-0x8, 16)),
 
+    (ExprCond(a8.zeroExtend(32), a, b), ExprCond(a8, a, b)),
+
+
     (- (i2*a), a * im2),
     (a + a, a * i2),
     (ExprOp('+', a, a), a * i2),
@@ -516,6 +519,179 @@ to_test = [
     (ExprOp(TOK_EQUAL, a8.zeroExtend(32), b8.zeroExtend(32)), ExprOp(TOK_EQUAL, a8, b8)),
     (ExprOp(TOK_EQUAL, a8.signExtend(32), b8.signExtend(32)), ExprOp(TOK_EQUAL, a8, b8)),
 
+    (ExprOp(TOK_INF_EQUAL_SIGNED, a8.zeroExtend(32), i0), ExprOp(TOK_EQUAL, a8, ExprInt(0, 8))),
+
+    ((a8.zeroExtend(32) + b8.zeroExtend(32) + ExprInt(1, 32))[0:8], a8 + b8 + ExprInt(1, 8)),
+
+    (ExprCond(a8.zeroExtend(32), a, b), ExprCond(a8, a, b)),
+    (ExprCond(a8.signExtend(32), a, b), ExprCond(a8, a, b)),
+
+
+    (
+        ExprOp(
+            TOK_EQUAL,
+            a8.zeroExtend(32) & b8.zeroExtend(32) & ExprInt(0x12, 32),
+            i1
+        ),
+        ExprOp(
+            TOK_EQUAL,
+            a8 & b8 & ExprInt(0x12, 8),
+            ExprInt(1, 8)
+        )
+    ),
+
+    (
+        ExprCond(
+            ExprOp(
+                TOK_EQUAL,
+                a & b & ExprInt(0x80, 32),
+                ExprInt(0x80, 32)
+            ), a, b
+        ),
+        ExprCond(a & b & ExprInt(0x80, 32), a, b)
+    ),
+
+
+
+    (
+        ExprCond(
+            a8.zeroExtend(32) & b8.zeroExtend(32) & ExprInt(0x12, 32),
+            a, b
+        ),
+        ExprCond(
+            a8 & b8 & ExprInt(0x12, 8),
+            a, b
+        ),
+    ),
+
+
+    (a8.zeroExtend(32)[:8], a8),
+    (a.zeroExtend(64)[:32], a),
+    (a.zeroExtend(64)[:8], a[:8]),
+    (a8.zeroExtend(32)[:16], a8.zeroExtend(16)),
+
+    (
+        ExprCond(
+            a & ExprInt(0x80000000, 32),
+            a, b
+        ),
+        ExprCond(
+            ExprOp(TOK_INF_SIGNED, a, ExprInt(0, 32) ),
+            a, b
+        )
+    ),
+
+
+
+    (
+        ExprCond(
+            a8.signExtend(32) & ExprInt(0x80000000, 32),
+            a, b
+        ),
+        ExprCond(
+            ExprOp(TOK_INF_SIGNED, a8, ExprInt(0, 8) ),
+            a, b
+        )
+    ),
+
+
+    (
+        ExprCond(
+            ExprOp(TOK_INF_SIGNED, a8.signExtend(32), ExprInt(0x10, 32) ),
+            a, b
+        ),
+        ExprCond(
+            ExprOp(TOK_INF_SIGNED, a8, ExprInt(0x10, 8) ),
+            a, b
+        )
+    ),
+
+    (
+        ExprCond(
+            ExprOp(TOK_INF_SIGNED, a8.signExtend(32), ExprInt(-0x10, 32) ),
+            a, b
+        ),
+        ExprCond(
+            ExprOp(TOK_INF_SIGNED, a8, ExprInt(-0x10, 8) ),
+            a, b
+        )
+    ),
+
+
+    (
+        ExprCond(
+            ExprOp(TOK_INF_UNSIGNED, a8.zeroExtend(32), ExprInt(0x10, 32) ),
+            a, b
+        ),
+        ExprCond(
+            ExprOp(TOK_INF_UNSIGNED, a8, ExprInt(0x10, 8) ),
+            a, b
+        )
+    ),
+
+
+
+    (
+        ExprCond(
+            ExprOp(TOK_INF_SIGNED, a8.signExtend(32), ExprInt(0x200, 32) ),
+            a, b
+        ),
+        a
+    ),
+
+
+    (
+        ExprCond(
+            ExprOp(TOK_INF_UNSIGNED, a8.zeroExtend(32), ExprInt(0x200, 32) ),
+            a, b
+        ),
+        a
+    ),
+
+
+
+    (
+        ExprCond(
+            ExprOp(TOK_INF_SIGNED, a8.zeroExtend(32), ExprInt(0x10, 32) ),
+            a, b
+        ),
+        ExprCond(
+            ExprOp(TOK_INF_UNSIGNED, a8, ExprInt(0x10, 8) ),
+            a, b
+        )
+    ),
+
+    (
+        ExprCond(
+            ExprOp(TOK_INF_EQUAL_SIGNED, a8.zeroExtend(32), ExprInt(0x10, 32) ),
+            a, b
+        ),
+        ExprCond(
+            ExprOp(TOK_INF_EQUAL_UNSIGNED, a8, ExprInt(0x10, 8) ),
+            a, b
+        )
+    ),
+
+
+    (
+        ExprCond(
+            ExprOp(TOK_INF_SIGNED, a8.zeroExtend(32), ExprInt(-1, 32) ),
+            a, b
+        ),
+        b
+    ),
+
+    (
+        ExprCond(
+            ExprOp(TOK_INF_EQUAL_SIGNED, a8.zeroExtend(32), ExprInt(-1, 32) ),
+            a, b
+        ),
+        b
+    ),
+
+
+    (a8.zeroExtend(32)[2:5], a8[2:5]),
+
 ]
 
 for e_input, e_check in to_test: