about summary refs log tree commit diff stats
path: root/miasm2/expression/simplifications_common.py
diff options
context:
space:
mode:
authorFabrice Desclaux <fabrice.desclaux@cea.fr>2018-12-22 19:43:04 +0100
committerFabrice Desclaux <fabrice.desclaux@cea.fr>2019-01-15 15:14:42 +0100
commit94184d227ed4e524b4622dd3f99141003be348bd (patch)
tree8dc514f6950c8c1e618d192028308edd002cc78f /miasm2/expression/simplifications_common.py
parenteb9b59dd4b4805dee549b69f024019f9d25b2fa5 (diff)
downloadmiasm-94184d227ed4e524b4622dd3f99141003be348bd.tar.gz
miasm-94184d227ed4e524b4622dd3f99141003be348bd.zip
Expression: add simplifications
Diffstat (limited to 'miasm2/expression/simplifications_common.py')
-rw-r--r--miasm2/expression/simplifications_common.py284
1 files changed, 258 insertions, 26 deletions
diff --git a/miasm2/expression/simplifications_common.py b/miasm2/expression/simplifications_common.py
index 6f0eb34a..726b7577 100644
--- a/miasm2/expression/simplifications_common.py
+++ b/miasm2/expression/simplifications_common.py
@@ -702,8 +702,8 @@ def simp_cc_conds(expr_simp, expr):
           )):
         expr = ExprCond(
             ExprOp(TOK_INF_UNSIGNED, *expr.args[0].args),
-            ExprInt(0, 1),
-            ExprInt(1, 1))
+            ExprInt(0, expr.size),
+            ExprInt(1, expr.size))
 
     elif (expr.is_op("CC_U<") and
           test_cc_eq_args(
@@ -726,8 +726,8 @@ def simp_cc_conds(expr_simp, expr):
           )):
         expr = ExprCond(
             ExprOp(TOK_INF_SIGNED, *expr.args[0].args),
-            ExprInt(0, 1),
-            ExprInt(1, 1)
+            ExprInt(0, expr.size),
+            ExprInt(1, expr.size)
         )
 
     elif (expr.is_op("CC_EQ") and
@@ -746,8 +746,8 @@ def simp_cc_conds(expr_simp, expr):
         arg = expr.args[0].args[0]
         expr = ExprCond(
             ExprOp(TOK_EQUAL,arg, ExprInt(0, arg.size)),
-            ExprInt(0, 1),
-            ExprInt(1, 1)
+            ExprInt(0, expr.size),
+            ExprInt(1, expr.size)
         )
     elif (expr.is_op("CC_NE") and
           test_cc_eq_args(
@@ -756,8 +756,8 @@ def simp_cc_conds(expr_simp, expr):
           )):
         expr = ExprCond(
             ExprOp(TOK_EQUAL, *expr.args[0].args),
-            ExprInt(0, 1),
-            ExprInt(1, 1)
+            ExprInt(0, expr.size),
+            ExprInt(1, expr.size)
         )
 
     elif (expr.is_op("CC_EQ") and
@@ -781,8 +781,8 @@ def simp_cc_conds(expr_simp, expr):
           )):
         expr = ExprCond(
             ExprOp("&", *expr.args[0].args),
-            ExprInt(0, 1),
-            ExprInt(1, 1)
+            ExprInt(0, expr.size),
+            ExprInt(1, expr.size)
         )
 
     elif (expr.is_op("CC_S>") and
@@ -794,8 +794,8 @@ def simp_cc_conds(expr_simp, expr):
           )):
         expr = ExprCond(
             ExprOp(TOK_INF_EQUAL_SIGNED, *expr.args[0].args),
-            ExprInt(0, 1),
-            ExprInt(1, 1)
+            ExprInt(0, expr.size),
+            ExprInt(1, expr.size)
         )
 
     elif (expr.is_op("CC_S>") and
@@ -806,8 +806,8 @@ def simp_cc_conds(expr_simp, expr):
           expr.args[1].is_int(0)):
         expr = ExprCond(
             ExprOp(TOK_INF_EQUAL_SIGNED, *expr.args[0].args),
-            ExprInt(0, 1),
-            ExprInt(1, 1)
+            ExprInt(0, expr.size),
+            ExprInt(1, expr.size)
         )
 
 
@@ -820,8 +820,8 @@ def simp_cc_conds(expr_simp, expr):
           )):
         expr = ExprCond(
             ExprOp(TOK_INF_SIGNED, *expr.args[0].args),
-            ExprInt(0, 1),
-            ExprInt(1, 1)
+            ExprInt(0, expr.size),
+            ExprInt(1, expr.size)
         )
 
     elif (expr.is_op("CC_S<") and
@@ -865,8 +865,8 @@ def simp_cc_conds(expr_simp, expr):
           )):
         expr = ExprCond(
             ExprOp(TOK_INF_EQUAL_UNSIGNED, *expr.args[0].args),
-            ExprInt(0, 1),
-            ExprInt(1, 1)
+            ExprInt(0, expr.size),
+            ExprInt(1, expr.size)
         )
 
     elif (expr.is_op("CC_S<") and
@@ -1046,9 +1046,26 @@ def simp_zeroext_eq_cst(expr_s, expr):
     src = arg1.args[0]
     if int(arg2) > (1 << src.size):
         # Always false
-        return ExprInt(0, 1)
+        return ExprInt(0, expr.size)
     return ExprOp(TOK_EQUAL, src, ExprInt(int(arg2), src.size))
 
+def simp_cond_zeroext(expr_s, expr):
+    """
+    X.zeroExt()?(A:B) => X ? A:B
+    X.signExt()?(A:B) => X ? A:B
+    """
+    if not (
+            expr.cond.is_op() and
+            (
+                expr.cond.op.startswith("zeroExt") or
+                expr.cond.op.startswith("signExt")
+            )
+    ):
+        return expr
+
+    ret = ExprCond(expr.cond.args[0], expr.src1, expr.src2)
+    return ret
+
 def simp_ext_eq_ext(expr_s, expr):
     # A.zeroExt(X) == B.zeroExt(X) => A == B
     # A.signExt(X) == B.signExt(X) => A == B
@@ -1075,6 +1092,106 @@ def simp_cond_eq_zero(expr_s, expr):
     new_expr = ExprCond(arg1, expr.src2, expr.src1)
     return new_expr
 
+def simp_sign_inf_zeroext(expr_s, expr):
+    """
+    /!\ Ensure before: X.zeroExt(X.size) => X
+
+    X.zeroExt() <s 0 => 0
+    X.zeroExt() <=s 0 => X == 0
+
+    X.zeroExt() <s cst => X.zeroExt() <u cst (cst positive)
+    X.zeroExt() <=s cst => X.zeroExt() <=u cst (cst positive)
+
+    X.zeroExt() <s cst => 0 (cst negative)
+    X.zeroExt() <=s cst => 0 (cst negative)
+
+    """
+    if not (expr.is_op(TOK_INF_SIGNED) or expr.is_op(TOK_INF_EQUAL_SIGNED)):
+        return expr
+    arg1, arg2 = expr.args
+    if not arg2.is_int():
+        return expr
+    if not (arg1.is_op() and arg1.op.startswith("zeroExt")):
+        return expr
+    src = arg1.args[0]
+    assert src.size < arg1.size
+
+    # If cst is zero
+    if arg2.is_int(0):
+        if expr.is_op(TOK_INF_SIGNED):
+            # X.zeroExt() <s 0 => 0
+            return ExprInt(0, expr.size)
+        else:
+            # X.zeroExt() <=s 0 => X == 0
+            return ExprOp(TOK_EQUAL, src, ExprInt(0, src.size))
+
+    # cst is not zero
+    cst = int(arg2)
+    if cst & (1 << (arg2.size - 1)):
+        # cst is negative
+        return ExprInt(0, expr.size)
+    # cst is positive
+    if expr.is_op(TOK_INF_SIGNED):
+        # X.zeroExt() <s cst => X.zeroExt() <u cst (cst positive)
+        return ExprOp(TOK_INF_UNSIGNED, src, expr_s(arg2[:src.size]))
+    # X.zeroExt() <=s cst => X.zeroExt() <=u cst (cst positive)
+    return ExprOp(TOK_INF_EQUAL_UNSIGNED, src, expr_s(arg2[:src.size]))
+
+
+def simp_zeroext_and_cst_eq_cst(expr_s, expr):
+    # A.zeroExt(X) & ... & int == int => A & ... & int[:A.size] == int[:A.size]
+    if not expr.is_op(TOK_EQUAL):
+        return expr
+    arg1, arg2 = expr.args
+    if not arg2.is_int():
+        return expr
+    if not arg1.is_op('&'):
+        return expr
+    is_ok = True
+    sizes = set()
+    for arg in arg1.args:
+        if arg.is_int():
+            continue
+        if (arg.is_op() and
+            arg.op.startswith("zeroExt")):
+            sizes.add(arg.args[0].size)
+            continue
+        is_ok = False
+        break
+    if not is_ok:
+        return expr
+    if len(sizes) != 1:
+        return expr
+    size = list(sizes)[0]
+    if int(arg2) > ((1 << size) - 1):
+        return expr
+    args = [expr_s(arg[:size]) for arg in arg1.args]
+    left = ExprOp('&', *args)
+    right = expr_s(arg2[:size])
+    ret = ExprOp(TOK_EQUAL, left, right)
+    return ret
+
+
+def test_one_bit_set(arg):
+    return arg != 0  and ((arg & (arg - 1)) == 0)
+
+def simp_x_and_cst_eq_cst(expr_s, expr):
+    # (x & ... & onebitmask == onebitmask) ? A:B => (x & ... & onebitmask) ? A:B
+    cond = expr.cond
+    if not cond.is_op(TOK_EQUAL):
+        return expr
+    arg1, mask2 = cond.args
+    if not mask2.is_int():
+        return expr
+    if not test_one_bit_set(int(mask2)):
+        return expr
+    if not arg1.is_op('&'):
+        return expr
+    mask1 = arg1.args[-1]
+    if mask1 != mask2:
+        return expr
+    cond = ExprOp('&', *arg1.args)
+    return ExprCond(cond, expr.src1, expr.src2)
 
 def simp_cmp_int_int(expr_s, expr):
     # IntA <s IntB => int
@@ -1094,8 +1211,7 @@ def simp_cmp_int_int(expr_s, expr):
     if expr.is_op(TOK_EQUAL):
         if int_a == int_b:
             return ExprInt(1, 1)
-        else:
-            return ExprInt(0, 1)
+        return ExprInt(0, expr.size)
 
     if expr.op in [TOK_INF_SIGNED, TOK_INF_EQUAL_SIGNED]:
         int_a = int(mod_size2int[int_a.size](int(int_a)))
@@ -1133,17 +1249,133 @@ def simp_ext_cst(expr_s, expr):
 
 
 def simp_slice_of_ext(expr_s, expr):
-    # zeroExt(X)[0:size(X)] => X
-    if expr.start != 0:
-        return expr
+    """
+    C.zeroExt(X)[A:B] => 0 if A >= size(C)
+    C.zeroExt(X)[A:B] => C[A:B] if B <= size(C)
+    A.zeroExt(X)[0:Y] => A.zeroExt(Y)
+    """
     if not expr.arg.is_op():
         return expr
     if not expr.arg.op.startswith("zeroExt"):
         return expr
     arg = expr.arg.args[0]
-    if arg.size != expr.size:
+
+    if expr.start >= arg.size:
+        # C.zeroExt(X)[A:B] => 0 if A >= size(C)
+        return ExprInt(0, expr.size)
+    if expr.stop <= arg.size:
+        # C.zeroExt(X)[A:B] => C[A:B] if B <= size(C)
+        return arg[expr.start:expr.stop]
+    if expr.start == 0:
+        # A.zeroExt(X)[0:Y] => A.zeroExt(Y)
+        return arg.zeroExtend(expr.stop)
+    return expr
+
+def simp_slice_of_op_ext(expr_s, expr):
+    # (X.zeroExt() + ... + Int)[0:8] => X + ... + int[:]
+    if expr.start != 0:
+        return expr
+    src = expr.arg
+    if not src.is_op("+"):
+        return expr
+    is_ok = True
+    for arg in src.args:
+        if arg.is_int():
+            continue
+        if (arg.is_op() and
+            arg.op.startswith("zeroExt") and
+            arg.args[0].size == expr.stop):
+            continue
+        is_ok = False
+        break
+    if not is_ok:
+        return expr
+    args = [expr_s(arg[:expr.stop]) for arg in src.args]
+    return ExprOp("+", *args)
+
+
+def simp_cond_logic_ext(expr_s, expr):
+    # (X.zeroExt() + ... + Int) ? A:B => X + ... + int[:] ? A:B
+    cond = expr.cond
+    if not cond.is_op():
+        return expr
+    if cond.op not in ["&", "^", "|"]:
         return expr
-    return arg
+    is_ok = True
+    sizes = set()
+    for arg in cond.args:
+        if arg.is_int():
+            continue
+        if (arg.is_op() and
+            arg.op.startswith("zeroExt")):
+            sizes.add(arg.args[0].size)
+            continue
+        is_ok = False
+        break
+    if not is_ok:
+        return expr
+    if len(sizes) != 1:
+        return expr
+    size = list(sizes)[0]
+    args = [expr_s(arg[:size]) for arg in cond.args]
+    cond = ExprOp(cond.op, *args)
+    return ExprCond(cond, expr.src1, expr.src2)
+
+
+def simp_cond_sign_bit(expr_s, expr):
+    """(a & .. & 0x80000000) ? A:B => (a & ...) <s 0 ? A:B"""
+    cond = expr.cond
+    if not cond.is_op('&'):
+        return expr
+    last = cond.args[-1]
+    if not last.is_int(1 << (last.size - 1)):
+        return expr
+    zero = ExprInt(0, expr.cond.size)
+    if len(cond.args) == 2:
+        args = [cond.args[0], zero]
+    else:
+        args = [ExprOp('&', *list(cond.args[:-1])), zero]
+    cond = ExprOp(TOK_INF_SIGNED, *args)
+    return ExprCond(cond, expr.src1, expr.src2)
+
+def simp_test_signext_inf(expr_s, expr):
+    # A.signExt() <s int => A <s int[:]
+    if not (expr.is_op(TOK_INF_SIGNED) or expr.is_op(TOK_INF_EQUAL_SIGNED)):
+        return expr
+    arg, cst = expr.args
+    if not (arg.is_op() and arg.op.startswith("signExt")):
+        return expr
+    if not cst.is_int():
+        return expr
+    base = arg.args[0]
+    tmp = int(mod_size2int[cst.size](int(cst)))
+    if -(1 << (base.size - 1)) <= tmp < (1 << (base.size - 1)):
+        # Can trunc integer
+        return ExprOp(expr.op, base, expr_s(cst[:base.size]))
+    if (tmp >= (1 << (base.size - 1)) or
+        tmp < -(1 << (base.size - 1)) ):
+        return ExprInt(1, 1)
+    return expr
+
+
+def simp_test_zeroext_inf(expr_s, expr):
+    # A.zeroExt() <u int => A <u int[:]
+    if not (expr.is_op(TOK_INF_UNSIGNED) or expr.is_op(TOK_INF_EQUAL_UNSIGNED)):
+        return expr
+    arg, cst = expr.args
+    if not (arg.is_op() and arg.op.startswith("zeroExt")):
+        return expr
+    if not cst.is_int():
+        return expr
+    base = arg.args[0]
+    tmp = int(mod_size2uint[cst.size](int(cst)))
+    if 0 <= tmp < (1 << base.size):
+        # Can trunc integer
+        return ExprOp(expr.op, base, expr_s(cst[:base.size]))
+    if tmp >= (1 << base.size):
+        return ExprInt(1, 1)
+    return expr
+
 
 def simp_add_multiple(expr_s, expr):
     # X + X => 2 * X