about summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--miasm2/expression/simplifications_common.py152
1 files changed, 65 insertions, 87 deletions
diff --git a/miasm2/expression/simplifications_common.py b/miasm2/expression/simplifications_common.py
index a070fb81..22994d4e 100644
--- a/miasm2/expression/simplifications_common.py
+++ b/miasm2/expression/simplifications_common.py
@@ -22,8 +22,8 @@ def simp_cst_propagation(e_s, e):
     # TODO: <<< >>> << >> are architecture dependant
     if op in op_propag_cst:
         while (len(args) >= 2 and
-            isinstance(args[-1], ExprInt) and
-            isinstance(args[-2], ExprInt)):
+            args[-1].is_int() and
+            args[-2].is_int()):
             i2 = args.pop()
             i1 = args.pop()
             if op == '+':
@@ -83,48 +83,45 @@ def simp_cst_propagation(e_s, e):
             args.append(o)
 
     # bsf(int) => int
-    if op == "bsf" and isinstance(args[0], ExprInt) and args[0].arg != 0:
+    if op == "bsf" and args[0].is_int() and args[0].arg != 0:
         i = 0
         while args[0].arg & (1 << i) == 0:
             i += 1
         return ExprInt_from(args[0], i)
 
     # bsr(int) => int
-    if op == "bsr" and isinstance(args[0], ExprInt) and args[0].arg != 0:
+    if op == "bsr" and args[0].is_int() and args[0].arg != 0:
         i = args[0].size - 1
         while args[0].arg & (1 << i) == 0:
             i -= 1
         return ExprInt_from(args[0], i)
 
     # -(-(A)) => A
-    if op == '-' and len(args) == 1 and isinstance(args[0], ExprOp) and \
-            args[0].op == '-' and len(args[0].args) == 1:
+    if (op == '-' and len(args) == 1 and args[0].is_op('-') and
+        len(args[0].args) == 1):
         return args[0].args[0]
 
     # -(int) => -int
-    if op == '-' and len(args) == 1 and isinstance(args[0], ExprInt):
+    if op == '-' and len(args) == 1 and args[0].is_int():
         return ExprInt(-args[0].arg)
     # A op 0 =>A
     if op in ['+', '|', "^", "<<", ">>", "<<<", ">>>"] and len(args) > 1:
-        if isinstance(args[-1], ExprInt) and args[-1].arg == 0:
+        if args[-1].is_int(0):
             args.pop()
     # A - 0 =>A
-    if op == '-' and len(args) > 1 and args[-1].arg == 0:
+    if op == '-' and len(args) > 1 and args[-1].is_int(0):
         assert(len(args) == 2) # Op '-' with more than 2 args: SantityCheckError
         return args[0]
 
     # A * 1 =>A
-    if op == "*" and len(args) > 1:
-        if isinstance(args[-1], ExprInt) and args[-1].arg == 1:
-            args.pop()
+    if op == "*" and len(args) > 1 and args[-1].is_int(1):
+        args.pop()
 
     # for cannon form
     # A * -1 => - A
-    if op == "*" and len(args) > 1:
-        if (isinstance(args[-1], ExprInt) and
-            args[-1].arg == (1 << args[-1].size) - 1):
-            args.pop()
-            args[-1] = - args[-1]
+    if op == "*" and len(args) > 1 and args[-1].is_int((1 << args[-1].size) - 1):
+        args.pop()
+        args[-1] = - args[-1]
 
     # op A => A
     if op in ['+', '*', '^', '&', '|', '>>', '<<',
@@ -140,24 +137,19 @@ def simp_cst_propagation(e_s, e):
         return ExprOp('+', args[0], -args[1])
 
     # A op 0 => 0
-    if op in ['&', "*"] and isinstance(args[1], ExprInt) and args[1].arg == 0:
+    if op in ['&', "*"] and args[1].is_int(0):
         return ExprInt_from(e, 0)
 
     # - (A + B +...) => -A + -B + -C
-    if (op == '-' and
-        len(args) == 1 and
-        isinstance(args[0], ExprOp) and
-        args[0].op == '+'):
+    if op == '-' and len(args) == 1 and args[0].is_op('+'):
         args = [-a for a in args[0].args]
         e = ExprOp('+', *args)
         return e
 
     # -(a?int1:int2) => (a?-int1:-int2)
-    if (op == '-' and
-        len(args) == 1 and
-        isinstance(args[0], ExprCond) and
-        isinstance(args[0].src1, ExprInt) and
-        isinstance(args[0].src2, ExprInt)):
+    if (op == '-' and len(args) == 1 and
+        args[0].is_cond() and
+        args[0].src1.is_int() and args[0].src2.is_int()):
         i1 = args[0].src1
         i2 = args[0].src2
         i1 = ExprInt_from(i1, -i1.arg)
@@ -174,13 +166,13 @@ def simp_cst_propagation(e_s, e):
                 del(args[j])
                 continue
             # A + (- A) => 0
-            if op == '+' and isinstance(args[j], ExprOp) and args[j].op == "-":
+            if op == '+' and args[j].is_op("-"):
                 if len(args[j].args) == 1 and args[i] == args[j].args[0]:
                     args[i] = ExprInt_from(args[i], 0)
                     del(args[j])
                     continue
             # (- A) + A => 0
-            if op == '+' and isinstance(args[i], ExprOp) and args[i].op == "-":
+            if op == '+' and args[i].is_op("-"):
                 if len(args[i].args) == 1 and args[j] == args[i].args[0]:
                     args[i] = ExprInt_from(args[i], 0)
                     del(args[j])
@@ -201,13 +193,13 @@ def simp_cst_propagation(e_s, e):
 
     # A <<< A.size => A
     if (op in ['<<<', '>>>'] and
-        isinstance(args[1], ExprInt) and
+        args[1].is_int() and
         args[1].arg == args[0].size):
         return args[0]
 
     # A <<< X <<< Y => A <<< (X+Y) (ou <<< >>>)
     if (op in ['<<<', '>>>'] and
-        isinstance(args[0], ExprOp) and
+        args[0].is_op() and
         args[0].op in ['<<<', '>>>']):
         op1 = op
         op2 = args[0].op
@@ -223,8 +215,7 @@ def simp_cst_propagation(e_s, e):
 
     # A >> X >> Y  =>  A >> (X+Y)
     if (op in ['<<', '>>'] and
-        isinstance(args[0], ExprOp) and
-        args[0].op == op):
+        args[0].is_op(op)):
         args = [args[0].args[0], args[0].args[1] + args[1]]
 
     # ((A & A.mask)
@@ -239,15 +230,13 @@ def simp_cst_propagation(e_s, e):
     # TODO
 
     # ((A & mask) >> shift) whith mask < 2**shift => 0
-    if (op == ">>" and
-        isinstance(args[1], ExprInt) and
-        isinstance(args[0], ExprOp) and args[0].op == "&"):
-        if (isinstance(args[0].args[1], ExprInt) and
+    if op == ">>" and args[1].is_int() and args[0].is_op("&"):
+        if (args[0].args[1].is_int() and
             2 ** args[1].arg > args[0].args[1].arg):
             return ExprInt_from(args[0], 0)
 
     # parity(int) => int
-    if op == 'parity' and isinstance(args[0], ExprInt):
+    if op == 'parity' and args[0].is_int():
         return ExprInt1(parity(args[0].arg))
 
     # (-a) * b * (-c) * (-d) => (-a) * b * c * d
@@ -255,7 +244,7 @@ def simp_cst_propagation(e_s, e):
         new_args = []
         counter = 0
         for a in args:
-            if isinstance(a, ExprOp) and a.op == '-' and len(a.args) == 1:
+            if a.is_op('-') and len(a.args) == 1:
                 new_args.append(a.args[0])
                 counter += 1
             else:
@@ -265,8 +254,8 @@ def simp_cst_propagation(e_s, e):
         args = new_args
 
     # A << int with A ExprCompose => move index
-    if (op == "<<" and isinstance(args[0], ExprCompose) and
-        isinstance(args[1], ExprInt) and int(args[1]) != 0):
+    if (op == "<<" and args[0].is_compose() and
+        args[1].is_int() and int(args[1]) != 0):
         final_size = args[0].size
         shift = int(args[1])
         new_args = []
@@ -291,7 +280,7 @@ def simp_cst_propagation(e_s, e):
         return ExprCompose(*args)
 
     # A >> int with A ExprCompose => move index
-    if op == ">>" and isinstance(args[0], ExprCompose) and isinstance(args[1], ExprInt):
+    if op == ">>" and args[0].is_compose() and args[1].is_int():
         final_size = args[0].size
         shift = int(args[1])
         new_args = []
@@ -316,7 +305,7 @@ def simp_cst_propagation(e_s, e):
 
 
     # Compose(a) OP Compose(b) with a/b same bounds => Compose(a OP b)
-    if op in ['|', '&', '^'] and all([isinstance(arg, ExprCompose) for arg in args]):
+    if op in ['|', '&', '^'] and all([arg.is_compose() for arg in args]):
         bounds = set()
         for arg in args:
             bound = tuple([expr.size for expr in arg.args])
@@ -337,10 +326,9 @@ def simp_cst_propagation(e_s, e):
         assert len(args) == 3
         dest, rounds, cf = args
         # Skipped if rounds is 0
-        if (isinstance(rounds, ExprInt) and
-            int(rounds) == 0):
+        if rounds.is_int(0):
             return dest
-        elif all(map(lambda x: isinstance(x, ExprInt), args)):
+        elif all(map(lambda x: x.is_int(), args)):
             # The expression can be resolved
             tmp = int(dest)
             cf = int(cf)
@@ -375,12 +363,12 @@ def simp_cond_op_int(e_s, e):
         return e
     if len(e.args) < 2:
         return e
-    if not isinstance(e.args[-1], ExprInt):
+    if not e.args[-1].is_int():
         return e
     a_int = e.args[-1]
     conds = []
     for a in e.args[:-1]:
-        if not isinstance(a, ExprCond):
+        if not a.is_cond():
             return e
         conds.append(a)
     if not conds:
@@ -404,7 +392,7 @@ def simp_cond_factor(e_s, e):
     not_conds = []
     multi_cond = False
     for a in e.args:
-        if not isinstance(a, ExprCond):
+        if not a.is_cond():
             not_conds.append(a)
             continue
         c = a.cond
@@ -437,19 +425,19 @@ def simp_slice(e_s, e):
     if e.start == 0 and e.stop == e.arg.size:
         return e.arg
     # Slice(int) => int
-    elif isinstance(e.arg, ExprInt):
+    elif e.arg.is_int():
         total_bit = e.stop - e.start
         mask = (1 << (e.stop - e.start)) - 1
         return ExprInt(int((e.arg.arg >> e.start) & mask), total_bit)
     # Slice(Slice(A, x), y) => Slice(A, z)
-    elif isinstance(e.arg, ExprSlice):
+    elif e.arg.is_slice():
         if e.stop - e.start > e.arg.stop - e.arg.start:
             raise ValueError('slice in slice: getting more val', str(e))
 
         new_e = ExprSlice(e.arg.arg, e.start + e.arg.start,
                           e.start + e.arg.start + (e.stop - e.start))
         return new_e
-    elif isinstance(e.arg, ExprCompose):
+    elif e.arg.is_compose():
         # Slice(Compose(A), x) => Slice(A, y)
         for index, arg in e.arg.iter_args():
             if index <= e.start and index+arg.size >= e.stop:
@@ -489,38 +477,33 @@ def simp_slice(e_s, e):
 
     # ExprMem(x, size)[:A] => ExprMem(x, a)
     # XXXX todo hum, is it safe?
-    elif (isinstance(e.arg, ExprMem) and
-        e.start == 0 and
-        e.arg.size > e.stop and e.stop % 8 == 0):
+    elif (e.arg.is_mem() and
+          e.start == 0 and
+          e.arg.size > e.stop and e.stop % 8 == 0):
         e = ExprMem(e.arg.arg, size=e.stop)
         return e
     # distributivity of slice and &
     # (a & int)[x:y] => 0 if int[x:y] == 0
-    elif (isinstance(e.arg, ExprOp) and
-        e.arg.op == "&" and
-        isinstance(e.arg.args[-1], ExprInt)):
+    elif e.arg.is_op("&") and e.arg.args[-1].is_int():
         tmp = e_s.expr_simp_wrapper(e.arg.args[-1][e.start:e.stop])
-        if isinstance(tmp, ExprInt) and tmp.arg == 0:
+        if tmp.is_int(0):
             return tmp
     # distributivity of slice and exprcond
     # (a?int1:int2)[x:y] => (a?int1[x:y]:int2[x:y])
-    elif (isinstance(e.arg, ExprCond) and
-        isinstance(e.arg.src1, ExprInt) and
-        isinstance(e.arg.src2, ExprInt)):
+    elif e.arg.is_cond() and e.arg.src1.is_int() and e.arg.src2.is_int():
         src1 = e.arg.src1[e.start:e.stop]
         src2 = e.arg.src2[e.start:e.stop]
         e = ExprCond(e.arg.cond, src1, src2)
 
     # (a * int)[0:y] => (a[0:y] * int[0:y])
-    elif (e.start == 0 and isinstance(e.arg, ExprOp) and
-        e.arg.op == "*" and isinstance(e.arg.args[-1], ExprInt)):
+    elif e.start == 0 and e.arg.is_op("*") and e.arg.args[-1].is_int():
         args = [e_s.expr_simp_wrapper(a[e.start:e.stop]) for a in e.arg.args]
         e = ExprOp(e.arg.op, *args)
 
     # (a >> int)[x:y] => a[x+int:y+int] with int+y <= a.size
     # (a << int)[x:y] => a[x-int:y-int] with x-int >= 0
-    elif (isinstance(e.arg, ExprOp) and e.arg.op in [">>", "<<"] and
-          isinstance(e.arg.args[1], ExprInt)):
+    elif (e.arg.is_op() and e.arg.op in [">>", "<<"] and
+          e.arg.args[1].is_int()):
         arg, shift = e.arg.args
         shift = int(shift)
         if e.arg.op == ">>":
@@ -541,7 +524,7 @@ def simp_compose(e_s, e):
     out = []
     # compose of compose
     for arg in args:
-        if isinstance(arg, ExprCompose):
+        if arg.is_compose():
             out += arg.args
         else:
             out.append(arg)
@@ -551,10 +534,8 @@ def simp_compose(e_s, e):
         return args[0]
 
     # {(X[z:], 0, X.size-z), (0, X.size-z, X.size)} => (X >> z)
-    if (len(args) == 2 and
-        isinstance(args[1], ExprInt) and
-        int(args[1]) == 0):
-        if (isinstance(args[0], ExprSlice) and
+    if len(args) == 2 and args[1].is_int(0):
+        if (args[0].is_slice() and
             args[0].stop == args[0].arg.size and
             args[0].size + args[1].size == args[0].arg.size):
             new_e = args[0].arg >> ExprInt(args[0].start, args[0].arg.size)
@@ -571,7 +552,7 @@ def simp_compose(e_s, e):
             ok = False
             break
         expr_ints_or_conds.append(arg)
-        if isinstance(arg, ExprCond):
+        if arg.is_cond():
             if expr_cond_index is not None:
                 ok = False
             expr_cond_index = i
@@ -589,7 +570,7 @@ def simp_compose(e_s, e):
                 src2.append(arg)
         src1 = e_s.apply_simp(ExprCompose(*src1))
         src2 = e_s.apply_simp(ExprCompose(*src2))
-        if isinstance(src1, ExprInt) and isinstance(src2, ExprInt):
+        if src1.is_int() and src2.is_int():
             return ExprCond(cond.cond, src1, src2)
     return ExprCompose(*args)
 
@@ -598,43 +579,40 @@ def simp_cond(e_s, e):
     "Common simplifications on ExprCond"
     # eval exprcond src1/src2 with satifiable/unsatisfiable condition
     # propagation
-    if (not isinstance(e.cond, ExprInt)) and e.cond.size == 1:
+    if (not e.cond.is_int()) and e.cond.size == 1:
         src1 = e.src1.replace_expr({e.cond: ExprInt1(1)})
         src2 = e.src2.replace_expr({e.cond: ExprInt1(0)})
         if src1 != e.src1 or src2 != e.src2:
             return ExprCond(e.cond, src1, src2)
 
     # -A ? B:C => A ? B:C
-    if (isinstance(e.cond, ExprOp) and
-        e.cond.op == '-' and
-        len(e.cond.args) == 1):
+    if e.cond.is_op('-') and len(e.cond.args) == 1:
         e = ExprCond(e.cond.args[0], e.src1, e.src2)
     # a?x:x
     elif e.src1 == e.src2:
         e = e.src1
     # int ? A:B => A or B
-    elif isinstance(e.cond, ExprInt):
+    elif e.cond.is_int():
         if e.cond.arg == 0:
             e = e.src2
         else:
             e = e.src1
     # a?(a?b:c):x => a?b:x
-    elif isinstance(e.src1, ExprCond) and e.cond == e.src1.cond:
+    elif e.src1.is_cond() and e.cond == e.src1.cond:
         e = ExprCond(e.cond, e.src1.src1, e.src2)
     # a?x:(a?b:c) => a?x:c
-    elif isinstance(e.src2, ExprCond) and e.cond == e.src2.cond:
+    elif e.src2.is_cond() and e.cond == e.src2.cond:
         e = ExprCond(e.cond, e.src1, e.src2.src2)
     # a|int ? b:c => b with int != 0
-    elif (isinstance(e.cond, ExprOp) and
-        e.cond.op == '|' and
-        isinstance(e.cond.args[1], ExprInt) and
-        e.cond.args[1].arg != 0):
+    elif (e.cond.is_op('|') and
+          e.cond.args[1].is_int() and
+          e.cond.args[1].arg != 0):
         return e.src1
 
     # (C?int1:int2)?(A:B) =>
-    elif (isinstance(e.cond, ExprCond) and
-          isinstance(e.cond.src1, ExprInt) and
-          isinstance(e.cond.src2, ExprInt)):
+    elif (e.cond.is_cond() and
+          e.cond.src1.is_int() and
+          e.cond.src2.is_int()):
         int1 = e.cond.src1.arg.arg
         int2 = e.cond.src2.arg.arg
         if int1 and int2: