From 94184d227ed4e524b4622dd3f99141003be348bd Mon Sep 17 00:00:00 2001 From: Fabrice Desclaux Date: Sat, 22 Dec 2018 19:43:04 +0100 Subject: Expression: add simplifications --- miasm2/expression/simplifications_common.py | 284 +++++++++++++++++++++++++--- 1 file changed, 258 insertions(+), 26 deletions(-) (limited to 'miasm2/expression/simplifications_common.py') diff --git a/miasm2/expression/simplifications_common.py b/miasm2/expression/simplifications_common.py index 6f0eb34a..726b7577 100644 --- a/miasm2/expression/simplifications_common.py +++ b/miasm2/expression/simplifications_common.py @@ -702,8 +702,8 @@ def simp_cc_conds(expr_simp, expr): )): expr = ExprCond( ExprOp(TOK_INF_UNSIGNED, *expr.args[0].args), - ExprInt(0, 1), - ExprInt(1, 1)) + ExprInt(0, expr.size), + ExprInt(1, expr.size)) elif (expr.is_op("CC_U<") and test_cc_eq_args( @@ -726,8 +726,8 @@ def simp_cc_conds(expr_simp, expr): )): expr = ExprCond( ExprOp(TOK_INF_SIGNED, *expr.args[0].args), - ExprInt(0, 1), - ExprInt(1, 1) + ExprInt(0, expr.size), + ExprInt(1, expr.size) ) elif (expr.is_op("CC_EQ") and @@ -746,8 +746,8 @@ def simp_cc_conds(expr_simp, expr): arg = expr.args[0].args[0] expr = ExprCond( ExprOp(TOK_EQUAL,arg, ExprInt(0, arg.size)), - ExprInt(0, 1), - ExprInt(1, 1) + ExprInt(0, expr.size), + ExprInt(1, expr.size) ) elif (expr.is_op("CC_NE") and test_cc_eq_args( @@ -756,8 +756,8 @@ def simp_cc_conds(expr_simp, expr): )): expr = ExprCond( ExprOp(TOK_EQUAL, *expr.args[0].args), - ExprInt(0, 1), - ExprInt(1, 1) + ExprInt(0, expr.size), + ExprInt(1, expr.size) ) elif (expr.is_op("CC_EQ") and @@ -781,8 +781,8 @@ def simp_cc_conds(expr_simp, expr): )): expr = ExprCond( ExprOp("&", *expr.args[0].args), - ExprInt(0, 1), - ExprInt(1, 1) + ExprInt(0, expr.size), + ExprInt(1, expr.size) ) elif (expr.is_op("CC_S>") and @@ -794,8 +794,8 @@ def simp_cc_conds(expr_simp, expr): )): expr = ExprCond( ExprOp(TOK_INF_EQUAL_SIGNED, *expr.args[0].args), - ExprInt(0, 1), - ExprInt(1, 1) + ExprInt(0, expr.size), + ExprInt(1, expr.size) ) elif (expr.is_op("CC_S>") and @@ -806,8 +806,8 @@ def simp_cc_conds(expr_simp, expr): expr.args[1].is_int(0)): expr = ExprCond( ExprOp(TOK_INF_EQUAL_SIGNED, *expr.args[0].args), - ExprInt(0, 1), - ExprInt(1, 1) + ExprInt(0, expr.size), + ExprInt(1, expr.size) ) @@ -820,8 +820,8 @@ def simp_cc_conds(expr_simp, expr): )): expr = ExprCond( ExprOp(TOK_INF_SIGNED, *expr.args[0].args), - ExprInt(0, 1), - ExprInt(1, 1) + ExprInt(0, expr.size), + ExprInt(1, expr.size) ) elif (expr.is_op("CC_S<") and @@ -865,8 +865,8 @@ def simp_cc_conds(expr_simp, expr): )): expr = ExprCond( ExprOp(TOK_INF_EQUAL_UNSIGNED, *expr.args[0].args), - ExprInt(0, 1), - ExprInt(1, 1) + ExprInt(0, expr.size), + ExprInt(1, expr.size) ) elif (expr.is_op("CC_S<") and @@ -1046,9 +1046,26 @@ def simp_zeroext_eq_cst(expr_s, expr): src = arg1.args[0] if int(arg2) > (1 << src.size): # Always false - return ExprInt(0, 1) + return ExprInt(0, expr.size) return ExprOp(TOK_EQUAL, src, ExprInt(int(arg2), src.size)) +def simp_cond_zeroext(expr_s, expr): + """ + X.zeroExt()?(A:B) => X ? A:B + X.signExt()?(A:B) => X ? A:B + """ + if not ( + expr.cond.is_op() and + ( + expr.cond.op.startswith("zeroExt") or + expr.cond.op.startswith("signExt") + ) + ): + return expr + + ret = ExprCond(expr.cond.args[0], expr.src1, expr.src2) + return ret + def simp_ext_eq_ext(expr_s, expr): # A.zeroExt(X) == B.zeroExt(X) => A == B # A.signExt(X) == B.signExt(X) => A == B @@ -1075,6 +1092,106 @@ def simp_cond_eq_zero(expr_s, expr): new_expr = ExprCond(arg1, expr.src2, expr.src1) return new_expr +def simp_sign_inf_zeroext(expr_s, expr): + """ + /!\ Ensure before: X.zeroExt(X.size) => X + + X.zeroExt() 0 + X.zeroExt() <=s 0 => X == 0 + + X.zeroExt() X.zeroExt() X.zeroExt() <=u cst (cst positive) + + X.zeroExt() 0 (cst negative) + X.zeroExt() <=s cst => 0 (cst negative) + + """ + if not (expr.is_op(TOK_INF_SIGNED) or expr.is_op(TOK_INF_EQUAL_SIGNED)): + return expr + arg1, arg2 = expr.args + if not arg2.is_int(): + return expr + if not (arg1.is_op() and arg1.op.startswith("zeroExt")): + return expr + src = arg1.args[0] + assert src.size < arg1.size + + # If cst is zero + if arg2.is_int(0): + if expr.is_op(TOK_INF_SIGNED): + # X.zeroExt() 0 + return ExprInt(0, expr.size) + else: + # X.zeroExt() <=s 0 => X == 0 + return ExprOp(TOK_EQUAL, src, ExprInt(0, src.size)) + + # cst is not zero + cst = int(arg2) + if cst & (1 << (arg2.size - 1)): + # cst is negative + return ExprInt(0, expr.size) + # cst is positive + if expr.is_op(TOK_INF_SIGNED): + # X.zeroExt() X.zeroExt() X.zeroExt() <=u cst (cst positive) + return ExprOp(TOK_INF_EQUAL_UNSIGNED, src, expr_s(arg2[:src.size])) + + +def simp_zeroext_and_cst_eq_cst(expr_s, expr): + # A.zeroExt(X) & ... & int == int => A & ... & int[:A.size] == int[:A.size] + if not expr.is_op(TOK_EQUAL): + return expr + arg1, arg2 = expr.args + if not arg2.is_int(): + return expr + if not arg1.is_op('&'): + return expr + is_ok = True + sizes = set() + for arg in arg1.args: + if arg.is_int(): + continue + if (arg.is_op() and + arg.op.startswith("zeroExt")): + sizes.add(arg.args[0].size) + continue + is_ok = False + break + if not is_ok: + return expr + if len(sizes) != 1: + return expr + size = list(sizes)[0] + if int(arg2) > ((1 << size) - 1): + return expr + args = [expr_s(arg[:size]) for arg in arg1.args] + left = ExprOp('&', *args) + right = expr_s(arg2[:size]) + ret = ExprOp(TOK_EQUAL, left, right) + return ret + + +def test_one_bit_set(arg): + return arg != 0 and ((arg & (arg - 1)) == 0) + +def simp_x_and_cst_eq_cst(expr_s, expr): + # (x & ... & onebitmask == onebitmask) ? A:B => (x & ... & onebitmask) ? A:B + cond = expr.cond + if not cond.is_op(TOK_EQUAL): + return expr + arg1, mask2 = cond.args + if not mask2.is_int(): + return expr + if not test_one_bit_set(int(mask2)): + return expr + if not arg1.is_op('&'): + return expr + mask1 = arg1.args[-1] + if mask1 != mask2: + return expr + cond = ExprOp('&', *arg1.args) + return ExprCond(cond, expr.src1, expr.src2) def simp_cmp_int_int(expr_s, expr): # IntA int @@ -1094,8 +1211,7 @@ def simp_cmp_int_int(expr_s, expr): if expr.is_op(TOK_EQUAL): if int_a == int_b: return ExprInt(1, 1) - else: - return ExprInt(0, 1) + return ExprInt(0, expr.size) if expr.op in [TOK_INF_SIGNED, TOK_INF_EQUAL_SIGNED]: int_a = int(mod_size2int[int_a.size](int(int_a))) @@ -1133,17 +1249,133 @@ def simp_ext_cst(expr_s, expr): def simp_slice_of_ext(expr_s, expr): - # zeroExt(X)[0:size(X)] => X - if expr.start != 0: - return expr + """ + C.zeroExt(X)[A:B] => 0 if A >= size(C) + C.zeroExt(X)[A:B] => C[A:B] if B <= size(C) + A.zeroExt(X)[0:Y] => A.zeroExt(Y) + """ if not expr.arg.is_op(): return expr if not expr.arg.op.startswith("zeroExt"): return expr arg = expr.arg.args[0] - if arg.size != expr.size: + + if expr.start >= arg.size: + # C.zeroExt(X)[A:B] => 0 if A >= size(C) + return ExprInt(0, expr.size) + if expr.stop <= arg.size: + # C.zeroExt(X)[A:B] => C[A:B] if B <= size(C) + return arg[expr.start:expr.stop] + if expr.start == 0: + # A.zeroExt(X)[0:Y] => A.zeroExt(Y) + return arg.zeroExtend(expr.stop) + return expr + +def simp_slice_of_op_ext(expr_s, expr): + # (X.zeroExt() + ... + Int)[0:8] => X + ... + int[:] + if expr.start != 0: + return expr + src = expr.arg + if not src.is_op("+"): + return expr + is_ok = True + for arg in src.args: + if arg.is_int(): + continue + if (arg.is_op() and + arg.op.startswith("zeroExt") and + arg.args[0].size == expr.stop): + continue + is_ok = False + break + if not is_ok: + return expr + args = [expr_s(arg[:expr.stop]) for arg in src.args] + return ExprOp("+", *args) + + +def simp_cond_logic_ext(expr_s, expr): + # (X.zeroExt() + ... + Int) ? A:B => X + ... + int[:] ? A:B + cond = expr.cond + if not cond.is_op(): + return expr + if cond.op not in ["&", "^", "|"]: return expr - return arg + is_ok = True + sizes = set() + for arg in cond.args: + if arg.is_int(): + continue + if (arg.is_op() and + arg.op.startswith("zeroExt")): + sizes.add(arg.args[0].size) + continue + is_ok = False + break + if not is_ok: + return expr + if len(sizes) != 1: + return expr + size = list(sizes)[0] + args = [expr_s(arg[:size]) for arg in cond.args] + cond = ExprOp(cond.op, *args) + return ExprCond(cond, expr.src1, expr.src2) + + +def simp_cond_sign_bit(expr_s, expr): + """(a & .. & 0x80000000) ? A:B => (a & ...) A = (1 << (base.size - 1)) or + tmp < -(1 << (base.size - 1)) ): + return ExprInt(1, 1) + return expr + + +def simp_test_zeroext_inf(expr_s, expr): + # A.zeroExt() A = (1 << base.size): + return ExprInt(1, 1) + return expr + def simp_add_multiple(expr_s, expr): # X + X => 2 * X -- cgit 1.4.1 From 0bbef883b95887c4e0ada13f440bc2c4bc87fad5 Mon Sep 17 00:00:00 2001 From: Fabrice Desclaux Date: Sun, 13 Jan 2019 20:56:26 +0100 Subject: Expressions/Simplifications: clean code --- miasm2/expression/simplifications_common.py | 151 +++++++++++++++++----------- 1 file changed, 91 insertions(+), 60 deletions(-) (limited to 'miasm2/expression/simplifications_common.py') diff --git a/miasm2/expression/simplifications_common.py b/miasm2/expression/simplifications_common.py index 726b7577..18529d35 100644 --- a/miasm2/expression/simplifications_common.py +++ b/miasm2/expression/simplifications_common.py @@ -382,7 +382,7 @@ def simp_cst_propagation(e_s, expr): return ExprOp(op_name, *args) -def simp_cond_op_int(e_s, expr): +def simp_cond_op_int(_, expr): "Extract conditions from operations" @@ -606,10 +606,11 @@ def simp_compose(e_s, expr): return ExprCompose(*args) -def simp_cond(e_s, expr): - "Common simplifications on ExprCond" - # eval exprcond src1/src2 with satifiable/unsatisfiable condition - # propagation +def simp_cond(_, expr): + """ + Common simplifications on ExprCond. + Eval exprcond src1/src2 with satifiable/unsatisfiable condition propagation + """ if (not expr.cond.is_int()) and expr.cond.size == 1: src1 = expr.src1.replace_expr({expr.cond: ExprInt(1, 1)}) src2 = expr.src2.replace_expr({expr.cond: ExprInt(0, 1)}) @@ -666,10 +667,11 @@ def simp_cond(e_s, expr): return expr -def simp_mem(e_s, expr): - "Common simplifications on ExprMem" - - # @32[x?a:b] => x?@32[a]:@32[b] +def simp_mem(_, expr): + """ + Common simplifications on ExprMem: + @32[x?a:b] => x?@32[a]:@32[b] + """ if expr.ptr.is_cond(): cond = expr.ptr ret = ExprCond(cond.cond, @@ -682,6 +684,15 @@ def simp_mem(e_s, expr): def test_cc_eq_args(expr, *sons_op): + """ + Return True if expression's arguments match the list in sons_op, and their + sub arguments are identical. Ex: + CC_S<=( + FLAG_SIGN_SUB(A, B), + FLAG_SUB_OF(A, B), + FLAG_EQ_CMP(A, B) + ) + """ if not expr.is_op(): return False if len(expr.args) != len(sons_op): @@ -694,7 +705,11 @@ def test_cc_eq_args(expr, *sons_op): return len(all_args) == 1 -def simp_cc_conds(expr_simp, expr): +def simp_cc_conds(_, expr): + """ + High level simplifications. Example: + CC_U<(FLAG_SUB_CF(A, B) => A =") and test_cc_eq_args( expr, @@ -882,8 +897,8 @@ def simp_cc_conds(expr_simp, expr): -def simp_cond_flag(expr_simp, expr): - # FLAG_EQ_CMP(X, Y)?A:B => (X == Y)?A:B +def simp_cond_flag(_, expr): + """FLAG_EQ_CMP(X, Y)?A:B => (X == Y)?A:B""" cond = expr.cond if cond.is_op("FLAG_EQ_CMP"): return ExprCond(ExprOp(TOK_EQUAL, *cond.args), expr.src1, expr.src2) @@ -891,8 +906,10 @@ def simp_cond_flag(expr_simp, expr): def simp_cmp_int(expr_simp, expr): - # ({X, 0} == int) => X == int[:] - # X + int1 == int2 => X == int2-int1 + """ + ({X, 0} == int) => X == int[:] + X + int1 == int2 => X == int2-int1 + """ if (expr.is_op(TOK_EQUAL) and expr.args[1].is_int() and expr.args[0].is_compose() and @@ -931,7 +948,7 @@ def simp_cmp_int(expr_simp, expr): -def simp_cmp_int_arg(expr_simp, expr): +def simp_cmp_int_arg(_, expr): """ (0x10 <= R0) ? A:B => @@ -971,10 +988,8 @@ def simp_cmp_int_arg(expr_simp, expr): return ExprCond(ExprOp(op, arg1, arg2), src1, src2) - - -def simp_subwc_cf(expr_s, expr): - # SUBWC_CF(A, B, SUB_CF(C, D)) => SUB_CF({A, C}, {B, D}) +def simp_subwc_cf(_, expr): + """SUBWC_CF(A, B, SUB_CF(C, D)) => SUB_CF({A, C}, {B, D})""" if not expr.is_op('FLAG_SUBWC_CF'): return expr op3 = expr.args[2] @@ -987,8 +1002,8 @@ def simp_subwc_cf(expr_s, expr): return ExprOp("FLAG_SUB_CF", op1, op2) -def simp_subwc_of(expr_s, expr): - # SUBWC_OF(A, B, SUB_CF(C, D)) => SUB_OF({A, C}, {B, D}) +def simp_subwc_of(_, expr): + """SUBWC_OF(A, B, SUB_CF(C, D)) => SUB_OF({A, C}, {B, D})""" if not expr.is_op('FLAG_SUBWC_OF'): return expr op3 = expr.args[2] @@ -1001,8 +1016,8 @@ def simp_subwc_of(expr_s, expr): return ExprOp("FLAG_SUB_OF", op1, op2) -def simp_sign_subwc_cf(expr_s, expr): - # SIGN_SUBWC(A, B, SUB_CF(C, D)) => SIGN_SUB({A, C}, {B, D}) +def simp_sign_subwc_cf(_, expr): + """SIGN_SUBWC(A, B, SUB_CF(C, D)) => SIGN_SUB({A, C}, {B, D})""" if not expr.is_op('FLAG_SIGN_SUBWC'): return expr op3 = expr.args[2] @@ -1014,8 +1029,8 @@ def simp_sign_subwc_cf(expr_s, expr): return ExprOp("FLAG_SIGN_SUB", op1, op2) -def simp_double_zeroext(expr_s, expr): - # A.zeroExt(X).zeroExt(Y) => A.zeroExt(Y) +def simp_double_zeroext(_, expr): + """A.zeroExt(X).zeroExt(Y) => A.zeroExt(Y)""" if not (expr.is_op() and expr.op.startswith("zeroExt")): return expr arg1 = expr.args[0] @@ -1024,8 +1039,8 @@ def simp_double_zeroext(expr_s, expr): arg2 = arg1.args[0] return ExprOp(expr.op, arg2) -def simp_double_signext(expr_s, expr): - # A.signExt(X).signExt(Y) => A.signExt(Y) +def simp_double_signext(_, expr): + """A.signExt(X).signExt(Y) => A.signExt(Y)""" if not (expr.is_op() and expr.op.startswith("signExt")): return expr arg1 = expr.args[0] @@ -1034,8 +1049,8 @@ def simp_double_signext(expr_s, expr): arg2 = arg1.args[0] return ExprOp(expr.op, arg2) -def simp_zeroext_eq_cst(expr_s, expr): - # A.zeroExt(X) == int => A == int[:A.size] +def simp_zeroext_eq_cst(_, expr): + """A.zeroExt(X) == int => A == int[:A.size]""" if not expr.is_op(TOK_EQUAL): return expr arg1, arg2 = expr.args @@ -1049,7 +1064,7 @@ def simp_zeroext_eq_cst(expr_s, expr): return ExprInt(0, expr.size) return ExprOp(TOK_EQUAL, src, ExprInt(int(arg2), src.size)) -def simp_cond_zeroext(expr_s, expr): +def simp_cond_zeroext(_, expr): """ X.zeroExt()?(A:B) => X ? A:B X.signExt()?(A:B) => X ? A:B @@ -1066,9 +1081,11 @@ def simp_cond_zeroext(expr_s, expr): ret = ExprCond(expr.cond.args[0], expr.src1, expr.src2) return ret -def simp_ext_eq_ext(expr_s, expr): - # A.zeroExt(X) == B.zeroExt(X) => A == B - # A.signExt(X) == B.signExt(X) => A == B +def simp_ext_eq_ext(_, expr): + """ + A.zeroExt(X) == B.zeroExt(X) => A == B + A.signExt(X) == B.signExt(X) => A == B + """ if not expr.is_op(TOK_EQUAL): return expr arg1, arg2 = expr.args @@ -1081,8 +1098,8 @@ def simp_ext_eq_ext(expr_s, expr): return expr return ExprOp(TOK_EQUAL, arg1.args[0], arg2.args[0]) -def simp_cond_eq_zero(expr_s, expr): - # (X == 0)?(A:B) => X?(B:A) +def simp_cond_eq_zero(_, expr): + """(X == 0)?(A:B) => X?(B:A)""" cond = expr.cond if not cond.is_op(TOK_EQUAL): return expr @@ -1139,7 +1156,9 @@ def simp_sign_inf_zeroext(expr_s, expr): def simp_zeroext_and_cst_eq_cst(expr_s, expr): - # A.zeroExt(X) & ... & int == int => A & ... & int[:A.size] == int[:A.size] + """ + A.zeroExt(X) & ... & int == int => A & ... & int[:A.size] == int[:A.size] + """ if not expr.is_op(TOK_EQUAL): return expr arg1, arg2 = expr.args @@ -1173,10 +1192,15 @@ def simp_zeroext_and_cst_eq_cst(expr_s, expr): def test_one_bit_set(arg): + """ + Return True if arg has form 1 << X + """ return arg != 0 and ((arg & (arg - 1)) == 0) -def simp_x_and_cst_eq_cst(expr_s, expr): - # (x & ... & onebitmask == onebitmask) ? A:B => (x & ... & onebitmask) ? A:B +def simp_x_and_cst_eq_cst(_, expr): + """ + (x & ... & onebitmask == onebitmask) ? A:B => (x & ... & onebitmask) ? A:B + """ cond = expr.cond if not cond.is_op(TOK_EQUAL): return expr @@ -1193,12 +1217,14 @@ def simp_x_and_cst_eq_cst(expr_s, expr): cond = ExprOp('&', *arg1.args) return ExprCond(cond, expr.src1, expr.src2) -def simp_cmp_int_int(expr_s, expr): - # IntA int - # IntA int - # IntA <=s IntB => int - # IntA <=u IntB => int - # IntA == IntB => int +def simp_cmp_int_int(_, expr): + """ + IntA int + IntA int + IntA <=s IntB => int + IntA <=u IntB => int + IntA == IntB => int + """ if expr.op not in [ TOK_EQUAL, TOK_INF_SIGNED, TOK_INF_UNSIGNED, @@ -1232,9 +1258,11 @@ def simp_cmp_int_int(expr_s, expr): return ExprInt(ret, 1) -def simp_ext_cst(expr_s, expr): - # Int.zeroExt(X) => Int - # Int.signExt(X) => Int +def simp_ext_cst(_, expr): + """ + Int.zeroExt(X) => Int + Int.signExt(X) => Int + """ if not (expr.op.startswith("zeroExt") or expr.op.startswith("signExt")): return expr arg = expr.args[0] @@ -1248,7 +1276,7 @@ def simp_ext_cst(expr_s, expr): return ret -def simp_slice_of_ext(expr_s, expr): +def simp_slice_of_ext(_, expr): """ C.zeroExt(X)[A:B] => 0 if A >= size(C) C.zeroExt(X)[A:B] => C[A:B] if B <= size(C) @@ -1272,7 +1300,7 @@ def simp_slice_of_ext(expr_s, expr): return expr def simp_slice_of_op_ext(expr_s, expr): - # (X.zeroExt() + ... + Int)[0:8] => X + ... + int[:] + """(X.zeroExt() + ... + Int)[0:8] => X + ... + int[:]""" if expr.start != 0: return expr src = expr.arg @@ -1295,7 +1323,7 @@ def simp_slice_of_op_ext(expr_s, expr): def simp_cond_logic_ext(expr_s, expr): - # (X.zeroExt() + ... + Int) ? A:B => X + ... + int[:] ? A:B + """(X.zeroExt() + ... + Int) ? A:B => X + ... + int[:] ? A:B""" cond = expr.cond if not cond.is_op(): return expr @@ -1322,7 +1350,7 @@ def simp_cond_logic_ext(expr_s, expr): return ExprCond(cond, expr.src1, expr.src2) -def simp_cond_sign_bit(expr_s, expr): +def simp_cond_sign_bit(_, expr): """(a & .. & 0x80000000) ? A:B => (a & ...) A A A A 2 * X - # X + X * int1 => X * (1 + int1) - # X * int1 + (- X) => X * (int1 - 1) - # X + (X << int1) => X * (1 + 2 ** int1) - # Correct even if addition overflow/underflow +def simp_add_multiple(_, expr): + """ + X + X => 2 * X + X + X * int1 => X * (1 + int1) + X * int1 + (- X) => X * (int1 - 1) + X + (X << int1) => X * (1 + 2 ** int1) + Correct even if addition overflow/underflow + """ if not expr.is_op('+'): return expr # Extract each argument and its counter operands = {} - for i, arg in enumerate(expr.args): + for arg in expr.args: if arg.is_op('*') and arg.args[1].is_int(): base_expr, factor = arg.args operands[base_expr] = operands.get(base_expr, 0) + int(factor) -- cgit 1.4.1