diff options
Diffstat (limited to 'miasm2/expression/simplifications_common.py')
| -rw-r--r-- | miasm2/expression/simplifications_common.py | 284 |
1 files changed, 258 insertions, 26 deletions
diff --git a/miasm2/expression/simplifications_common.py b/miasm2/expression/simplifications_common.py index 6f0eb34a..726b7577 100644 --- a/miasm2/expression/simplifications_common.py +++ b/miasm2/expression/simplifications_common.py @@ -702,8 +702,8 @@ def simp_cc_conds(expr_simp, expr): )): expr = ExprCond( ExprOp(TOK_INF_UNSIGNED, *expr.args[0].args), - ExprInt(0, 1), - ExprInt(1, 1)) + ExprInt(0, expr.size), + ExprInt(1, expr.size)) elif (expr.is_op("CC_U<") and test_cc_eq_args( @@ -726,8 +726,8 @@ def simp_cc_conds(expr_simp, expr): )): expr = ExprCond( ExprOp(TOK_INF_SIGNED, *expr.args[0].args), - ExprInt(0, 1), - ExprInt(1, 1) + ExprInt(0, expr.size), + ExprInt(1, expr.size) ) elif (expr.is_op("CC_EQ") and @@ -746,8 +746,8 @@ def simp_cc_conds(expr_simp, expr): arg = expr.args[0].args[0] expr = ExprCond( ExprOp(TOK_EQUAL,arg, ExprInt(0, arg.size)), - ExprInt(0, 1), - ExprInt(1, 1) + ExprInt(0, expr.size), + ExprInt(1, expr.size) ) elif (expr.is_op("CC_NE") and test_cc_eq_args( @@ -756,8 +756,8 @@ def simp_cc_conds(expr_simp, expr): )): expr = ExprCond( ExprOp(TOK_EQUAL, *expr.args[0].args), - ExprInt(0, 1), - ExprInt(1, 1) + ExprInt(0, expr.size), + ExprInt(1, expr.size) ) elif (expr.is_op("CC_EQ") and @@ -781,8 +781,8 @@ def simp_cc_conds(expr_simp, expr): )): expr = ExprCond( ExprOp("&", *expr.args[0].args), - ExprInt(0, 1), - ExprInt(1, 1) + ExprInt(0, expr.size), + ExprInt(1, expr.size) ) elif (expr.is_op("CC_S>") and @@ -794,8 +794,8 @@ def simp_cc_conds(expr_simp, expr): )): expr = ExprCond( ExprOp(TOK_INF_EQUAL_SIGNED, *expr.args[0].args), - ExprInt(0, 1), - ExprInt(1, 1) + ExprInt(0, expr.size), + ExprInt(1, expr.size) ) elif (expr.is_op("CC_S>") and @@ -806,8 +806,8 @@ def simp_cc_conds(expr_simp, expr): expr.args[1].is_int(0)): expr = ExprCond( ExprOp(TOK_INF_EQUAL_SIGNED, *expr.args[0].args), - ExprInt(0, 1), - ExprInt(1, 1) + ExprInt(0, expr.size), + ExprInt(1, expr.size) ) @@ -820,8 +820,8 @@ def simp_cc_conds(expr_simp, expr): )): expr = ExprCond( ExprOp(TOK_INF_SIGNED, *expr.args[0].args), - ExprInt(0, 1), - ExprInt(1, 1) + ExprInt(0, expr.size), + ExprInt(1, expr.size) ) elif (expr.is_op("CC_S<") and @@ -865,8 +865,8 @@ def simp_cc_conds(expr_simp, expr): )): expr = ExprCond( ExprOp(TOK_INF_EQUAL_UNSIGNED, *expr.args[0].args), - ExprInt(0, 1), - ExprInt(1, 1) + ExprInt(0, expr.size), + ExprInt(1, expr.size) ) elif (expr.is_op("CC_S<") and @@ -1046,9 +1046,26 @@ def simp_zeroext_eq_cst(expr_s, expr): src = arg1.args[0] if int(arg2) > (1 << src.size): # Always false - return ExprInt(0, 1) + return ExprInt(0, expr.size) return ExprOp(TOK_EQUAL, src, ExprInt(int(arg2), src.size)) +def simp_cond_zeroext(expr_s, expr): + """ + X.zeroExt()?(A:B) => X ? A:B + X.signExt()?(A:B) => X ? A:B + """ + if not ( + expr.cond.is_op() and + ( + expr.cond.op.startswith("zeroExt") or + expr.cond.op.startswith("signExt") + ) + ): + return expr + + ret = ExprCond(expr.cond.args[0], expr.src1, expr.src2) + return ret + def simp_ext_eq_ext(expr_s, expr): # A.zeroExt(X) == B.zeroExt(X) => A == B # A.signExt(X) == B.signExt(X) => A == B @@ -1075,6 +1092,106 @@ def simp_cond_eq_zero(expr_s, expr): new_expr = ExprCond(arg1, expr.src2, expr.src1) return new_expr +def simp_sign_inf_zeroext(expr_s, expr): + """ + /!\ Ensure before: X.zeroExt(X.size) => X + + X.zeroExt() <s 0 => 0 + X.zeroExt() <=s 0 => X == 0 + + X.zeroExt() <s cst => X.zeroExt() <u cst (cst positive) + X.zeroExt() <=s cst => X.zeroExt() <=u cst (cst positive) + + X.zeroExt() <s cst => 0 (cst negative) + X.zeroExt() <=s cst => 0 (cst negative) + + """ + if not (expr.is_op(TOK_INF_SIGNED) or expr.is_op(TOK_INF_EQUAL_SIGNED)): + return expr + arg1, arg2 = expr.args + if not arg2.is_int(): + return expr + if not (arg1.is_op() and arg1.op.startswith("zeroExt")): + return expr + src = arg1.args[0] + assert src.size < arg1.size + + # If cst is zero + if arg2.is_int(0): + if expr.is_op(TOK_INF_SIGNED): + # X.zeroExt() <s 0 => 0 + return ExprInt(0, expr.size) + else: + # X.zeroExt() <=s 0 => X == 0 + return ExprOp(TOK_EQUAL, src, ExprInt(0, src.size)) + + # cst is not zero + cst = int(arg2) + if cst & (1 << (arg2.size - 1)): + # cst is negative + return ExprInt(0, expr.size) + # cst is positive + if expr.is_op(TOK_INF_SIGNED): + # X.zeroExt() <s cst => X.zeroExt() <u cst (cst positive) + return ExprOp(TOK_INF_UNSIGNED, src, expr_s(arg2[:src.size])) + # X.zeroExt() <=s cst => X.zeroExt() <=u cst (cst positive) + return ExprOp(TOK_INF_EQUAL_UNSIGNED, src, expr_s(arg2[:src.size])) + + +def simp_zeroext_and_cst_eq_cst(expr_s, expr): + # A.zeroExt(X) & ... & int == int => A & ... & int[:A.size] == int[:A.size] + if not expr.is_op(TOK_EQUAL): + return expr + arg1, arg2 = expr.args + if not arg2.is_int(): + return expr + if not arg1.is_op('&'): + return expr + is_ok = True + sizes = set() + for arg in arg1.args: + if arg.is_int(): + continue + if (arg.is_op() and + arg.op.startswith("zeroExt")): + sizes.add(arg.args[0].size) + continue + is_ok = False + break + if not is_ok: + return expr + if len(sizes) != 1: + return expr + size = list(sizes)[0] + if int(arg2) > ((1 << size) - 1): + return expr + args = [expr_s(arg[:size]) for arg in arg1.args] + left = ExprOp('&', *args) + right = expr_s(arg2[:size]) + ret = ExprOp(TOK_EQUAL, left, right) + return ret + + +def test_one_bit_set(arg): + return arg != 0 and ((arg & (arg - 1)) == 0) + +def simp_x_and_cst_eq_cst(expr_s, expr): + # (x & ... & onebitmask == onebitmask) ? A:B => (x & ... & onebitmask) ? A:B + cond = expr.cond + if not cond.is_op(TOK_EQUAL): + return expr + arg1, mask2 = cond.args + if not mask2.is_int(): + return expr + if not test_one_bit_set(int(mask2)): + return expr + if not arg1.is_op('&'): + return expr + mask1 = arg1.args[-1] + if mask1 != mask2: + return expr + cond = ExprOp('&', *arg1.args) + return ExprCond(cond, expr.src1, expr.src2) def simp_cmp_int_int(expr_s, expr): # IntA <s IntB => int @@ -1094,8 +1211,7 @@ def simp_cmp_int_int(expr_s, expr): if expr.is_op(TOK_EQUAL): if int_a == int_b: return ExprInt(1, 1) - else: - return ExprInt(0, 1) + return ExprInt(0, expr.size) if expr.op in [TOK_INF_SIGNED, TOK_INF_EQUAL_SIGNED]: int_a = int(mod_size2int[int_a.size](int(int_a))) @@ -1133,17 +1249,133 @@ def simp_ext_cst(expr_s, expr): def simp_slice_of_ext(expr_s, expr): - # zeroExt(X)[0:size(X)] => X - if expr.start != 0: - return expr + """ + C.zeroExt(X)[A:B] => 0 if A >= size(C) + C.zeroExt(X)[A:B] => C[A:B] if B <= size(C) + A.zeroExt(X)[0:Y] => A.zeroExt(Y) + """ if not expr.arg.is_op(): return expr if not expr.arg.op.startswith("zeroExt"): return expr arg = expr.arg.args[0] - if arg.size != expr.size: + + if expr.start >= arg.size: + # C.zeroExt(X)[A:B] => 0 if A >= size(C) + return ExprInt(0, expr.size) + if expr.stop <= arg.size: + # C.zeroExt(X)[A:B] => C[A:B] if B <= size(C) + return arg[expr.start:expr.stop] + if expr.start == 0: + # A.zeroExt(X)[0:Y] => A.zeroExt(Y) + return arg.zeroExtend(expr.stop) + return expr + +def simp_slice_of_op_ext(expr_s, expr): + # (X.zeroExt() + ... + Int)[0:8] => X + ... + int[:] + if expr.start != 0: + return expr + src = expr.arg + if not src.is_op("+"): + return expr + is_ok = True + for arg in src.args: + if arg.is_int(): + continue + if (arg.is_op() and + arg.op.startswith("zeroExt") and + arg.args[0].size == expr.stop): + continue + is_ok = False + break + if not is_ok: + return expr + args = [expr_s(arg[:expr.stop]) for arg in src.args] + return ExprOp("+", *args) + + +def simp_cond_logic_ext(expr_s, expr): + # (X.zeroExt() + ... + Int) ? A:B => X + ... + int[:] ? A:B + cond = expr.cond + if not cond.is_op(): + return expr + if cond.op not in ["&", "^", "|"]: return expr - return arg + is_ok = True + sizes = set() + for arg in cond.args: + if arg.is_int(): + continue + if (arg.is_op() and + arg.op.startswith("zeroExt")): + sizes.add(arg.args[0].size) + continue + is_ok = False + break + if not is_ok: + return expr + if len(sizes) != 1: + return expr + size = list(sizes)[0] + args = [expr_s(arg[:size]) for arg in cond.args] + cond = ExprOp(cond.op, *args) + return ExprCond(cond, expr.src1, expr.src2) + + +def simp_cond_sign_bit(expr_s, expr): + """(a & .. & 0x80000000) ? A:B => (a & ...) <s 0 ? A:B""" + cond = expr.cond + if not cond.is_op('&'): + return expr + last = cond.args[-1] + if not last.is_int(1 << (last.size - 1)): + return expr + zero = ExprInt(0, expr.cond.size) + if len(cond.args) == 2: + args = [cond.args[0], zero] + else: + args = [ExprOp('&', *list(cond.args[:-1])), zero] + cond = ExprOp(TOK_INF_SIGNED, *args) + return ExprCond(cond, expr.src1, expr.src2) + +def simp_test_signext_inf(expr_s, expr): + # A.signExt() <s int => A <s int[:] + if not (expr.is_op(TOK_INF_SIGNED) or expr.is_op(TOK_INF_EQUAL_SIGNED)): + return expr + arg, cst = expr.args + if not (arg.is_op() and arg.op.startswith("signExt")): + return expr + if not cst.is_int(): + return expr + base = arg.args[0] + tmp = int(mod_size2int[cst.size](int(cst))) + if -(1 << (base.size - 1)) <= tmp < (1 << (base.size - 1)): + # Can trunc integer + return ExprOp(expr.op, base, expr_s(cst[:base.size])) + if (tmp >= (1 << (base.size - 1)) or + tmp < -(1 << (base.size - 1)) ): + return ExprInt(1, 1) + return expr + + +def simp_test_zeroext_inf(expr_s, expr): + # A.zeroExt() <u int => A <u int[:] + if not (expr.is_op(TOK_INF_UNSIGNED) or expr.is_op(TOK_INF_EQUAL_UNSIGNED)): + return expr + arg, cst = expr.args + if not (arg.is_op() and arg.op.startswith("zeroExt")): + return expr + if not cst.is_int(): + return expr + base = arg.args[0] + tmp = int(mod_size2uint[cst.size](int(cst))) + if 0 <= tmp < (1 << base.size): + # Can trunc integer + return ExprOp(expr.op, base, expr_s(cst[:base.size])) + if tmp >= (1 << base.size): + return ExprInt(1, 1) + return expr + def simp_add_multiple(expr_s, expr): # X + X => 2 * X |