diff options
Diffstat (limited to 'miasm/expression')
| -rw-r--r-- | miasm/expression/expression.py | 534 | ||||
| -rw-r--r-- | miasm/expression/simplifications.py | 54 | ||||
| -rw-r--r-- | miasm/expression/simplifications_common.py | 138 |
3 files changed, 452 insertions, 274 deletions
diff --git a/miasm/expression/expression.py b/miasm/expression/expression.py index 93094979..c2bf5b8b 100644 --- a/miasm/expression/expression.py +++ b/miasm/expression/expression.py @@ -98,18 +98,6 @@ def should_parenthesize_child(child, parent): def str_protected_child(child, parent): return ("(%s)" % child) if should_parenthesize_child(child, parent) else str(child) -def visit_chk(visitor): - "Function decorator launching callback on Expression visit" - def wrapped(expr, callback, test_visit=lambda x: True): - if (test_visit is not None) and (not test_visit(expr)): - return expr - expr_new = visitor(expr, callback, test_visit) - if expr_new is None: - return None - expr_new2 = callback(expr_new) - return expr_new2 - return wrapped - # Expression display @@ -152,6 +140,49 @@ class DiGraphExpr(DiGraph): return "" +def is_expr(expr): + return isinstance( + expr, + ( + ExprInt, ExprId, ExprMem, + ExprSlice, ExprCompose, ExprCond, + ExprLoc, ExprOp + ) + ) + +def is_associative(expr): + "Return True iff current operation is associative" + return (expr.op in ['+', '*', '^', '&', '|']) + +def is_commutative(expr): + "Return True iff current operation is commutative" + return (expr.op in ['+', '*', '^', '&', '|']) + +def is_op_segm(expr): + """Returns True if is ExprOp and op == 'segm'""" + return expr.is_op('segm') + +def is_mem_segm(expr): + """Returns True if is ExprMem and ptr is_op_segm""" + return expr.is_mem() and is_op_segm(expr.ptr) + +def canonize_to_exprloc(locdb, expr): + """ + If expr is ExprInt, return ExprLoc with corresponding loc_key + Else, return expr + + @expr: Expr instance + """ + if expr.is_int(): + loc_key = locdb.get_or_create_offset_location(int(expr)) + ret = ExprLoc(loc_key, expr.size) + return ret + return expr + +def is_function_call(expr): + """Returns true if the considered Expr is a function call + """ + return expr.is_op() and expr.op.startswith('call') @total_ordering class LocKey(object): @@ -183,6 +214,263 @@ class LocKey(object): def __str__(self): return "loc_key_%d" % self.key + +class ExprWalkBase(object): + """ + Walk through sub-expressions, call @callback on them. + If @callback returns a non None value, stop walk and return this value + """ + + def __init__(self, callback): + self.callback = callback + + def visit(self, expr, *args, **kwargs): + if expr.is_int() or expr.is_id() or expr.is_loc(): + pass + elif expr.is_assign(): + ret = self.visit(expr.dst, *args, **kwargs) + if ret: + return ret + src = self.visit(expr.src, *args, **kwargs) + if ret: + return ret + elif expr.is_cond(): + ret = self.visit(expr.cond, *args, **kwargs) + if ret: + return ret + ret = self.visit(expr.src1, *args, **kwargs) + if ret: + return ret + ret = self.visit(expr.src2, *args, **kwargs) + if ret: + return ret + elif expr.is_mem(): + ret = self.visit(expr.ptr, *args, **kwargs) + if ret: + return ret + elif expr.is_slice(): + ret = self.visit(expr.arg, *args, **kwargs) + if ret: + return ret + elif expr.is_op(): + for arg in expr.args: + ret = self.visit(arg, *args, **kwargs) + if ret: + return ret + elif expr.is_compose(): + for arg in expr.args: + ret = self.visit(arg, *args, **kwargs) + if ret: + return ret + else: + raise TypeError("Visitor can only take Expr") + + ret = self.callback(expr, *args, **kwargs) + return ret + + +class ExprWalk(ExprWalkBase): + """ + Walk through sub-expressions, call @callback on them. + If @callback returns a non None value, stop walk and return this value + Use cache mechanism. + """ + def __init__(self, callback): + self.cache = set() + self.callback = callback + + def visit(self, expr, *args, **kwargs): + if expr in self.cache: + return None + ret = super(ExprWalk, self).visit(expr, *args, **kwargs) + if ret: + return ret + self.cache.add(expr) + return None + + +class ExprGetR(ExprWalkBase): + """ + Return ExprId/ExprMem used by a given expression + """ + def __init__(self, mem_read=False, cst_read=False): + super(ExprGetR, self).__init__(lambda x:None) + self.mem_read = mem_read + self.cst_read = cst_read + self.elements = set() + self.cache = dict() + + def get_r_leaves(self, expr): + if (expr.is_int() or expr.is_loc()) and self.cst_read: + self.elements.add(expr) + elif expr.is_mem(): + self.elements.add(expr) + elif expr.is_id(): + self.elements.add(expr) + + def visit(self, expr, *args, **kwargs): + cache_key = (expr, self.mem_read, self.cst_read) + if cache_key in self.cache: + return self.cache[cache_key] + ret = self.visit_inner(expr, *args, **kwargs) + self.cache[cache_key] = ret + return ret + + def visit_inner(self, expr, *args, **kwargs): + self.get_r_leaves(expr) + if expr.is_mem() and not self.mem_read: + # Don't visit memory sons + return None + + if expr.is_assign(): + if expr.dst.is_mem() and self.mem_read: + ret = super(ExprGetR, self).visit(expr.dst, *args, **kwargs) + if expr.src.is_mem(): + self.elements.add(expr.src) + self.get_r_leaves(expr.src) + if expr.src.is_mem() and not self.mem_read: + return None + ret = super(ExprGetR, self).visit(expr.src, *args, **kwargs) + return ret + ret = super(ExprGetR, self).visit(expr, *args, **kwargs) + return ret + + +class ExprVisitorBase(object): + """ + Rebuild expression by visiting sub-expressions + """ + def visit(self, expr, *args, **kwargs): + if expr.is_int() or expr.is_id() or expr.is_loc(): + ret = expr + elif expr.is_assign(): + dst = self.visit(expr.dst, *args, **kwargs) + src = self.visit(expr.src, *args, **kwargs) + ret = ExprAssign(dst, src) + elif expr.is_cond(): + cond = self.visit(expr.cond, *args, **kwargs) + src1 = self.visit(expr.src1, *args, **kwargs) + src2 = self.visit(expr.src2, *args, **kwargs) + ret = ExprCond(cond, src1, src2) + elif expr.is_mem(): + ptr = self.visit(expr.ptr, *args, **kwargs) + ret = ExprMem(ptr, expr.size) + elif expr.is_slice(): + arg = self.visit(expr.arg, *args, **kwargs) + ret = ExprSlice(arg, expr.start, expr.stop) + elif expr.is_op(): + args = [self.visit(arg, *args, **kwargs) for arg in expr.args] + ret = ExprOp(expr.op, *args) + elif expr.is_compose(): + args = [self.visit(arg, *args, **kwargs) for arg in expr.args] + ret = ExprCompose(*args) + else: + raise TypeError("Visitor can only take Expr") + return ret + + +class ExprVisitorCallbackTopToBottom(ExprVisitorBase): + """ + Rebuild expression by visiting sub-expressions + Call @callback on each sub-expression + if @callback return non None value, replace current node with this value + Else, continue visit of sub-expressions + """ + def __init__(self, callback): + super(ExprVisitorCallbackTopToBottom, self).__init__() + self.cache = dict() + self.callback = callback + + def visit(self, expr, *args, **kwargs): + if expr in self.cache: + return self.cache[expr] + ret = self.visit_inner(expr, *args, **kwargs) + self.cache[expr] = ret + return ret + + def visit_inner(self, expr, *args, **kwargs): + ret = self.callback(expr) + if ret: + return ret + ret = super(ExprVisitorCallbackTopToBottom, self).visit(expr, *args, **kwargs) + return ret + + +class ExprVisitorCallbackBottomToTop(ExprVisitorBase): + """ + Rebuild expression by visiting sub-expressions + Call @callback from leaves to root expressions + """ + def __init__(self, callback): + super(ExprVisitorCallbackBottomToTop, self).__init__() + self.cache = dict() + self.callback = callback + + def visit(self, expr, *args, **kwargs): + if expr in self.cache: + return self.cache[expr] + ret = self.visit_inner(expr, *args, **kwargs) + self.cache[expr] = ret + return ret + + def visit_inner(self, expr, *args, **kwargs): + ret = super(ExprVisitorCallbackBottomToTop, self).visit(expr, *args, **kwargs) + ret = self.callback(ret) + return ret + + +class ExprVisitorCanonize(ExprVisitorCallbackBottomToTop): + def __init__(self): + super(ExprVisitorCanonize, self).__init__(self.canonize) + + def canonize(self, expr): + if not expr.is_op(): + return expr + if not expr.is_associative(): + return expr + + # ((a+b) + c) => (a + b + c) + args = [] + for arg in expr.args: + if isinstance(arg, ExprOp) and expr.op == arg.op: + args += arg.args + else: + args.append(arg) + args = canonize_expr_list(args) + new_expr = ExprOp(expr.op, *args) + return new_expr + + +class ExprVisitorContains(ExprWalkBase): + """ + Visitor to test if a needle is in an Expression + Cache results + """ + def __init__(self): + self.cache = set() + super(ExprVisitorContains, self).__init__(self.eq_expr) + + def eq_expr(self, expr, needle, *args, **kwargs): + if expr == needle: + return True + return None + + def visit(self, expr, needle, *args, **kwargs): + if (expr, needle) in self.cache: + return None + ret = super(ExprVisitorContains, self).visit(expr, needle, *args, **kwargs) + if ret: + return ret + self.cache.add((expr, needle)) + return None + + + def contains(self, expr, needle): + return self.visit(expr, needle) + +contains_visitor = ExprVisitorContains() +canonize_visitor = ExprVisitorCanonize() + # IR definitions class Expr(object): @@ -337,36 +625,16 @@ class Expr(object): """Find and replace sub expression using dct @dct: dictionary associating replaced Expr to its new Expr value """ - return self.visit(lambda expr: dct.get(expr, expr)) + def replace(expr): + if expr in dct: + return dct[expr] + return None + visitor = ExprVisitorCallbackTopToBottom(lambda expr:replace(expr)) + return visitor.visit(self) def canonize(self): "Canonize the Expression" - - def must_canon(expr): - return not expr.is_canon - - def canonize_visitor(expr): - if expr.is_canon: - return expr - if isinstance(expr, ExprOp): - if expr.is_associative(): - # ((a+b) + c) => (a + b + c) - args = [] - for arg in expr.args: - if isinstance(arg, ExprOp) and expr.op == arg.op: - args += arg.args - else: - args.append(arg) - args = canonize_expr_list(args) - new_e = ExprOp(expr.op, *args) - else: - new_e = expr - else: - new_e = expr - new_e.is_canon = True - return new_e - - return self.visit(canonize_visitor, must_canon) + return canonize_visitor.visit(self) def msb(self): "Return the Most Significant Bit" @@ -424,6 +692,10 @@ class Expr(object): return False def is_aff(self): + warnings.warn('DEPRECATION WARNING: use is_assign()') + return False + + def is_assign(self): return False def is_cond(self): @@ -449,6 +721,32 @@ class Expr(object): """Returns True if is ExprMem and ptr is_op_segm""" return False + def __contains__(self, expr): + ret = contains_visitor.contains(self, expr) + return ret + + def visit(self, callback): + """ + Apply callback to all sub expression of @self + This function keeps a cache to avoid rerunning @callback on common sub + expressions. + + @callback: fn(Expr) -> Expr + """ + visitor = ExprVisitorCallbackBottomToTop(callback) + return visitor.visit(self) + + def get_r(self, mem_read=False, cst_read=False): + visitor = ExprGetR(mem_read, cst_read) + visitor.visit(self) + return visitor.elements + + + def get_w(self, mem_read=False, cst_read=False): + if self.is_assign(): + return set([self.dst]) + return set() + class ExprInt(Expr): """An ExprInt represent a constant in Miasm IR. @@ -508,12 +806,6 @@ class ExprInt(Expr): else: return str("0x%X" % self._get_int()) - def get_r(self, mem_read=False, cst_read=False): - if cst_read: - return set([self]) - else: - return set() - def get_w(self): return set() @@ -524,13 +816,6 @@ class ExprInt(Expr): return "%s(0x%X, %d)" % (self.__class__.__name__, self._get_int(), self._size) - def __contains__(self, expr): - return self == expr - - @visit_chk - def visit(self, callback, test_visit=None): - return self - def copy(self): return ExprInt(self._arg, self._size) @@ -591,9 +876,6 @@ class ExprId(Expr): def __str__(self): return str(self._name) - def get_r(self, mem_read=False, cst_read=False): - return set([self]) - def get_w(self): return set([self]) @@ -603,13 +885,6 @@ class ExprId(Expr): def _exprrepr(self): return "%s(%r, %d)" % (self.__class__.__name__, self._name, self._size) - def __contains__(self, expr): - return self == expr - - @visit_chk - def visit(self, callback, test_visit=None): - return self - def copy(self): return ExprId(self._name, self._size) @@ -653,12 +928,6 @@ class ExprLoc(Expr): def __str__(self): return str(self._loc_key) - def get_r(self, mem_read=False, cst_read=False): - if cst_read: - return set([self]) - else: - return set() - def get_w(self): return set() @@ -668,13 +937,6 @@ class ExprLoc(Expr): def _exprrepr(self): return "%s(%r, %d)" % (self.__class__.__name__, self._loc_key, self._size) - def __contains__(self, expr): - return self == expr - - @visit_chk - def visit(self, callback, test_visit=None): - return self - def copy(self): return ExprLoc(self._loc_key, self._size) @@ -745,12 +1007,6 @@ class ExprAssign(Expr): def __str__(self): return "%s = %s" % (str(self._dst), str(self._src)) - def get_r(self, mem_read=False, cst_read=False): - elements = self._src.get_r(mem_read, cst_read) - if isinstance(self._dst, ExprMem) and mem_read: - elements.update(self._dst.ptr.get_r(mem_read, cst_read)) - return elements - def get_w(self): if isinstance(self._dst, ExprMem): return set([self._dst]) # [memreg] @@ -763,19 +1019,6 @@ class ExprAssign(Expr): def _exprrepr(self): return "%s(%r, %r)" % (self.__class__.__name__, self._dst, self._src) - def __contains__(self, expr): - return (self == expr or - self._src.__contains__(expr) or - self._dst.__contains__(expr)) - - @visit_chk - def visit(self, callback, test_visit=None): - dst, src = self._dst.visit(callback, test_visit), self._src.visit(callback, test_visit) - if dst == self._dst and src == self._src: - return self - else: - return ExprAssign(dst, src) - def copy(self): return ExprAssign(self._dst.copy(), self._src.copy()) @@ -788,7 +1031,12 @@ class ExprAssign(Expr): arg.graph_recursive(graph) graph.add_uniq_edge(self, arg) + def is_aff(self): + warnings.warn('DEPRECATION WARNING: use is_assign()') + return True + + def is_assign(self): return True @@ -845,12 +1093,6 @@ class ExprCond(Expr): def __str__(self): return "%s?(%s,%s)" % (str_protected_child(self._cond, self), str(self._src1), str(self._src2)) - def get_r(self, mem_read=False, cst_read=False): - out_src1 = self.src1.get_r(mem_read, cst_read) - out_src2 = self.src2.get_r(mem_read, cst_read) - return self.cond.get_r(mem_read, - cst_read).union(out_src1).union(out_src2) - def get_w(self): return set() @@ -862,21 +1104,6 @@ class ExprCond(Expr): return "%s(%r, %r, %r)" % (self.__class__.__name__, self._cond, self._src1, self._src2) - def __contains__(self, expr): - return (self == expr or - self.cond.__contains__(expr) or - self.src1.__contains__(expr) or - self.src2.__contains__(expr)) - - @visit_chk - def visit(self, callback, test_visit=None): - cond = self._cond.visit(callback, test_visit) - src1 = self._src1.visit(callback, test_visit) - src2 = self._src2.visit(callback, test_visit) - if cond == self._cond and src1 == self._src1 and src2 == self._src2: - return self - return ExprCond(cond, src1, src2) - def copy(self): return ExprCond(self._cond.copy(), self._src1.copy(), @@ -953,12 +1180,6 @@ class ExprMem(Expr): def __str__(self): return "@%d[%s]" % (self.size, str(self.ptr)) - def get_r(self, mem_read=False, cst_read=False): - if mem_read: - return set(self._ptr.get_r(mem_read, cst_read).union(set([self]))) - else: - return set([self]) - def get_w(self): return set([self]) # [memreg] @@ -969,16 +1190,6 @@ class ExprMem(Expr): return "%s(%r, %r)" % (self.__class__.__name__, self._ptr, self._size) - def __contains__(self, expr): - return self == expr or self._ptr.__contains__(expr) - - @visit_chk - def visit(self, callback, test_visit=None): - ptr = self._ptr.visit(callback, test_visit) - if ptr == self._ptr: - return self - return ExprMem(ptr, self.size) - def copy(self): ptr = self.ptr.copy() return ExprMem(ptr, size=self.size) @@ -1108,10 +1319,6 @@ class ExprOp(Expr): return (self._op + '(' + ', '.join([str(arg) for arg in self._args]) + ')') - def get_r(self, mem_read=False, cst_read=False): - return reduce(lambda elements, arg: - elements.union(arg.get_r(mem_read, cst_read)), self._args, set()) - def get_w(self): raise ValueError('op cannot be written!', self) @@ -1123,14 +1330,6 @@ class ExprOp(Expr): return "%s(%r, %s)" % (self.__class__.__name__, self._op, ', '.join(repr(arg) for arg in self._args)) - def __contains__(self, expr): - if self == expr: - return True - for arg in self._args: - if arg.__contains__(expr): - return True - return False - def is_function_call(self): return self._op.startswith('call') @@ -1153,14 +1352,6 @@ class ExprOp(Expr): "Return True iff current operation is commutative" return (self._op in ['+', '*', '^', '&', '|']) - @visit_chk - def visit(self, callback, test_visit=None): - args = [arg.visit(callback, test_visit) for arg in self._args] - modified = any([arg[0] != arg[1] for arg in zip(self._args, args)]) - if modified: - return ExprOp(self._op, *args) - return self - def copy(self): args = [arg.copy() for arg in self._args] return ExprOp(self._op, *args) @@ -1213,9 +1404,6 @@ class ExprSlice(Expr): def __str__(self): return "%s[%d:%d]" % (str_protected_child(self._arg, self), self._start, self._stop) - def get_r(self, mem_read=False, cst_read=False): - return self._arg.get_r(mem_read, cst_read) - def get_w(self): return self._arg.get_w() @@ -1226,18 +1414,6 @@ class ExprSlice(Expr): return "%s(%r, %d, %d)" % (self.__class__.__name__, self._arg, self._start, self._stop) - def __contains__(self, expr): - if self == expr: - return True - return self._arg.__contains__(expr) - - @visit_chk - def visit(self, callback, test_visit=None): - arg = self._arg.visit(callback, test_visit) - if arg == self._arg: - return self - return ExprSlice(arg, self._start, self._stop) - def copy(self): return ExprSlice(self._arg.copy(), self._start, self._stop) @@ -1310,10 +1486,6 @@ class ExprCompose(Expr): def __str__(self): return '{' + ', '.join(["%s %s %s" % (arg, idx, idx + arg.size) for idx, arg in self.iter_args()]) + '}' - def get_r(self, mem_read=False, cst_read=False): - return reduce(lambda elements, arg: - elements.union(arg.get_r(mem_read, cst_read)), self._args, set()) - def get_w(self): return reduce(lambda elements, arg: elements.union(arg.get_w()), self._args, set()) @@ -1325,24 +1497,6 @@ class ExprCompose(Expr): def _exprrepr(self): return "%s%r" % (self.__class__.__name__, self._args) - def __contains__(self, expr): - if self == expr: - return True - for arg in self._args: - if arg == expr: - return True - if arg.__contains__(expr): - return True - return False - - @visit_chk - def visit(self, callback, test_visit=None): - args = [arg.visit(callback, test_visit) for arg in self._args] - modified = any([arg != arg_new for arg, arg_new in zip(self._args, args)]) - if modified: - return ExprCompose(*args) - return self - def copy(self): args = [arg.copy() for arg in self._args] return ExprCompose(*args) @@ -1669,8 +1823,8 @@ def match_expr(expr, pattern, tks, result=None): return False return result - elif expr.is_aff(): - if not pattern.is_aff(): + elif expr.is_assign(): + if not pattern.is_assign(): return False if match_expr(expr.src, pattern.src, tks, result) is False: return False diff --git a/miasm/expression/simplifications.py b/miasm/expression/simplifications.py index 03a779a6..3f54b158 100644 --- a/miasm/expression/simplifications.py +++ b/miasm/expression/simplifications.py @@ -11,6 +11,7 @@ from miasm.expression import simplifications_cond from miasm.expression import simplifications_explicit from miasm.expression.expression_helper import fast_unify import miasm.expression.expression as m2_expr +from miasm.expression.expression import ExprVisitorCallbackBottomToTop # Expression Simplifier # --------------------- @@ -22,7 +23,7 @@ log_exprsimp.addHandler(console_handler) log_exprsimp.setLevel(logging.WARNING) -class ExpressionSimplifier(object): +class ExpressionSimplifier(ExprVisitorCallbackBottomToTop): """Wrapper on expression simplification passes. @@ -49,6 +50,8 @@ class ExpressionSimplifier(object): simplifications_common.simp_double_signext, simplifications_common.simp_zeroext_eq_cst, simplifications_common.simp_ext_eq_ext, + simplifications_common.simp_ext_cond_int, + simplifications_common.simp_sub_cf_zero, simplifications_common.simp_cmp_int, simplifications_common.simp_cmp_bijective_op, @@ -118,8 +121,8 @@ class ExpressionSimplifier(object): def __init__(self): + super(ExpressionSimplifier, self).__init__(self.expr_simp_inner) self.expr_simp_cb = {} - self.simplified_exprs = set() def enable_passes(self, passes): """Add passes from @passes @@ -129,7 +132,7 @@ class ExpressionSimplifier(object): """ # Clear cache of simplifiied expressions when adding a new pass - self.simplified_exprs.clear() + self.cache.clear() for k, v in viewitems(passes): self.expr_simp_cb[k] = fast_unify(self.expr_simp_cb.get(k, []) + v) @@ -156,46 +159,29 @@ class ExpressionSimplifier(object): return expression - def expr_simp(self, expression): + def expr_simp_inner(self, expression): """Apply enabled simplifications on expression and find a stable state @expression: Expr instance Return an Expr instance""" - if expression in self.simplified_exprs: - return expression - # Find a stable state while True: # Canonize and simplify - e_new = self.apply_simp(expression.canonize()) - if e_new == expression: - break - - # Launch recursivity - expression = self.expr_simp_wrapper(e_new) - self.simplified_exprs.add(expression) - # Mark expression as simplified - self.simplified_exprs.add(e_new) - - return e_new - - def expr_simp_wrapper(self, expression, callback=None): - """Apply enabled simplifications on expression - @expression: Expr instance - @manual_callback: If set, call this function instead of normal one - Return an Expr instance""" + new_expr = self.apply_simp(expression.canonize()) + if new_expr == expression: + return new_expr + # Run recursively simplification on fresh new expression + new_expr = self.visit(new_expr) + expression = new_expr + return new_expr - if expression in self.simplified_exprs: - return expression - - if callback is None: - callback = self.expr_simp - - return expression.visit(callback, lambda e: e not in self.simplified_exprs) + def expr_simp(self, expression): + "Call simplification recursively" + return self.visit(expression) - def __call__(self, expression, callback=None): - "Wrapper on expr_simp_wrapper" - return self.expr_simp_wrapper(expression, callback) + def __call__(self, expression): + "Call simplification recursively" + return self.visit(expression) # Public ExprSimplificationPass instance with commons passes diff --git a/miasm/expression/simplifications_common.py b/miasm/expression/simplifications_common.py index 1c0bb04c..932db49a 100644 --- a/miasm/expression/simplifications_common.py +++ b/miasm/expression/simplifications_common.py @@ -32,30 +32,30 @@ def simp_cst_propagation(e_s, expr): int2 = args.pop() int1 = args.pop() if op_name == '+': - out = int1.arg + int2.arg + out = mod_size2uint[int1.size](int(int1) + int(int2)) elif op_name == '*': - out = int1.arg * int2.arg + out = mod_size2uint[int1.size](int(int1) * int(int2)) elif op_name == '**': - out =int1.arg ** int2.arg + out = mod_size2uint[int1.size](int(int1) ** int(int2)) elif op_name == '^': - out = int1.arg ^ int2.arg + out = mod_size2uint[int1.size](int(int1) ^ int(int2)) elif op_name == '&': - out = int1.arg & int2.arg + out = mod_size2uint[int1.size](int(int1) & int(int2)) elif op_name == '|': - out = int1.arg | int2.arg + out = mod_size2uint[int1.size](int(int1) | int(int2)) elif op_name == '>>': if int(int2) > int1.size: out = 0 else: - out = int1.arg >> int2.arg + out = mod_size2uint[int1.size](int(int1) >> int(int2)) elif op_name == '<<': if int(int2) > int1.size: out = 0 else: - out = int1.arg << int2.arg + out = mod_size2uint[int1.size](int(int1) << int(int2)) elif op_name == 'a>>': - tmp1 = mod_size2int[int1.arg.size](int1.arg) - tmp2 = mod_size2uint[int2.arg.size](int2.arg) + tmp1 = mod_size2int[int1.size](int(int1)) + tmp2 = mod_size2uint[int2.size](int(int2)) if tmp2 > int1.size: is_signed = int(int1) & (1 << (int1.size - 1)) if is_signed: @@ -63,55 +63,57 @@ def simp_cst_propagation(e_s, expr): else: out = 0 else: - out = mod_size2uint[int1.arg.size](tmp1 >> tmp2) + out = mod_size2uint[int1.size](tmp1 >> tmp2) elif op_name == '>>>': - shifter = int2.arg % int2.size - out = (int1.arg >> shifter) | (int1.arg << (int2.size - shifter)) + shifter = int(int2) % int2.size + out = (int(int1) >> shifter) | (int(int1) << (int2.size - shifter)) elif op_name == '<<<': - shifter = int2.arg % int2.size - out = (int1.arg << shifter) | (int1.arg >> (int2.size - shifter)) + shifter = int(int2) % int2.size + out = (int(int1) << shifter) | (int(int1) >> (int2.size - shifter)) elif op_name == '/': - out = int1.arg // int2.arg + assert int(int2), "division by 0" + out = int(int1) // int(int2) elif op_name == '%': - out = int1.arg % int2.arg + assert int(int2), "division by 0" + out = int(int1) % int(int2) elif op_name == 'sdiv': - assert int2.arg.arg - tmp1 = mod_size2int[int1.arg.size](int1.arg) - tmp2 = mod_size2int[int2.arg.size](int2.arg) - out = mod_size2uint[int1.arg.size](tmp1 // tmp2) + assert int(int2), "division by 0" + tmp1 = mod_size2int[int1.size](int(int1)) + tmp2 = mod_size2int[int2.size](int(int2)) + out = mod_size2uint[int1.size](tmp1 // tmp2) elif op_name == 'smod': - assert int2.arg.arg - tmp1 = mod_size2int[int1.arg.size](int1.arg) - tmp2 = mod_size2int[int2.arg.size](int2.arg) - out = mod_size2uint[int1.arg.size](tmp1 % tmp2) + assert int(int2), "division by 0" + tmp1 = mod_size2int[int1.size](int(int1)) + tmp2 = mod_size2int[int2.size](int(int2)) + out = mod_size2uint[int1.size](tmp1 % tmp2) elif op_name == 'umod': - assert int2.arg.arg - tmp1 = mod_size2uint[int1.arg.size](int1.arg) - tmp2 = mod_size2uint[int2.arg.size](int2.arg) - out = mod_size2uint[int1.arg.size](tmp1 % tmp2) + assert int(int2), "division by 0" + tmp1 = mod_size2uint[int1.size](int(int1)) + tmp2 = mod_size2uint[int2.size](int(int2)) + out = mod_size2uint[int1.size](tmp1 % tmp2) elif op_name == 'udiv': - assert int2.arg.arg - tmp1 = mod_size2uint[int1.arg.size](int1.arg) - tmp2 = mod_size2uint[int2.arg.size](int2.arg) - out = mod_size2uint[int1.arg.size](tmp1 // tmp2) + assert int(int2), "division by 0" + tmp1 = mod_size2uint[int1.size](int(int1)) + tmp2 = mod_size2uint[int2.size](int(int2)) + out = mod_size2uint[int1.size](tmp1 // tmp2) - args.append(ExprInt(out, int1.size)) + args.append(ExprInt(int(out), int1.size)) # cnttrailzeros(int) => int if op_name == "cnttrailzeros" and args[0].is_int(): i = 0 - while args[0].arg & (1 << i) == 0 and i < args[0].size: + while int(args[0]) & (1 << i) == 0 and i < args[0].size: i += 1 return ExprInt(i, args[0].size) # cntleadzeros(int) => int if op_name == "cntleadzeros" and args[0].is_int(): - if args[0].arg == 0: + if int(args[0]) == 0: return ExprInt(args[0].size, args[0].size) i = args[0].size - 1 - while args[0].arg & (1 << i) == 0: + while int(args[0]) & (1 << i) == 0: i -= 1 return ExprInt(expr.size - (i + 1), args[0].size) @@ -120,6 +122,7 @@ def simp_cst_propagation(e_s, expr): len(args[0].args) == 1): return args[0].args[0] + # -(int) => -int if op_name == '-' and len(args) == 1 and args[0].is_int(): return ExprInt(-int(args[0]), expr.size) @@ -207,13 +210,13 @@ def simp_cst_propagation(e_s, expr): j += 1 i += 1 - if op_name in ['|', '&', '%', '/', '**'] and len(args) == 1: + if op_name in ['+', '^', '|', '&', '%', '/', '**'] and len(args) == 1: return args[0] # A <<< A.size => A if (op_name in ['<<<', '>>>'] and args[1].is_int() and - args[1].arg == args[0].size): + int(args[1]) == args[0].size): return args[0] # (A <<< X) <<< Y => A <<< (X+Y) (or <<< >>>) if X + Y does not overflow @@ -277,7 +280,10 @@ def simp_cst_propagation(e_s, expr): # ((A & A.mask) if op_name == "&" and args[-1] == expr.mask: - return ExprOp('&', *args[:-1]) + args = args[:-1] + if len(args) == 1: + return args[0] + return ExprOp('&', *args) # ((A | A.mask) if op_name == "|" and args[-1] == expr.mask: @@ -289,7 +295,7 @@ def simp_cst_propagation(e_s, expr): # ((A & mask) >> shift) with mask < 2**shift => 0 if op_name == ">>" and args[1].is_int() and args[0].is_op("&"): if (args[0].args[1].is_int() and - 2 ** args[1].arg > args[0].args[1].arg): + 2 ** int(args[1]) > int(args[0].args[1])): return ExprInt(0, args[0].size) # parity(int) => int @@ -315,7 +321,6 @@ def simp_cst_propagation(e_s, expr): args = args[0].args return ExprOp('*', *(list(args[:-1]) + [ExprInt(-int(args[-1]), expr.size)])) - # A << int with A ExprCompose => move index if (op_name == "<<" and args[0].is_compose() and args[1].is_int() and int(args[1]) != 0): @@ -450,8 +455,8 @@ def simp_cond_factor(e_s, expr): for cond, vals in viewitems(conds): new_src1 = [x.src1 for x in vals] new_src2 = [x.src2 for x in vals] - src1 = e_s.expr_simp_wrapper(ExprOp(expr.op, *new_src1)) - src2 = e_s.expr_simp_wrapper(ExprOp(expr.op, *new_src2)) + src1 = e_s.expr_simp(ExprOp(expr.op, *new_src1)) + src2 = e_s.expr_simp(ExprOp(expr.op, *new_src2)) c_out.append(ExprCond(cond, src1, src2)) if len(c_out) == 1: @@ -471,7 +476,7 @@ def simp_slice(e_s, expr): if expr.arg.is_int(): total_bit = expr.stop - expr.start mask = (1 << (expr.stop - expr.start)) - 1 - return ExprInt(int((expr.arg.arg >> expr.start) & mask), total_bit) + return ExprInt(int((int(expr.arg) >> expr.start) & mask), total_bit) # Slice(Slice(A, x), y) => Slice(A, z) if expr.arg.is_slice(): if expr.stop - expr.start > expr.arg.stop - expr.arg.start: @@ -521,7 +526,7 @@ def simp_slice(e_s, expr): # distributivity of slice and & # (a & int)[x:y] => 0 if int[x:y] == 0 if expr.arg.is_op("&") and expr.arg.args[-1].is_int(): - tmp = e_s.expr_simp_wrapper(expr.arg.args[-1][expr.start:expr.stop]) + tmp = e_s.expr_simp(expr.arg.args[-1][expr.start:expr.stop]) if tmp.is_int(0): return tmp # distributivity of slice and exprcond @@ -536,7 +541,7 @@ def simp_slice(e_s, expr): # (a * int)[0:y] => (a[0:y] * int[0:y]) if expr.start == 0 and expr.arg.is_op("*") and expr.arg.args[-1].is_int(): - args = [e_s.expr_simp_wrapper(a[expr.start:expr.stop]) for a in expr.arg.args] + args = [e_s.expr_simp(a[expr.start:expr.stop]) for a in expr.arg.args] return ExprOp(expr.arg.op, *args) # (a >> int)[x:y] => a[x+int:y+int] with int+y <= a.size @@ -626,7 +631,7 @@ def simp_cond(_, expr): expr = expr.src1 # int ? A:B => A or B elif expr.cond.is_int(): - if expr.cond.arg == 0: + if int(expr.cond) == 0: expr = expr.src2 else: expr = expr.src1 @@ -646,8 +651,8 @@ def simp_cond(_, expr): elif (expr.cond.is_cond() and expr.cond.src1.is_int() and expr.cond.src2.is_int()): - int1 = expr.cond.src1.arg.arg - int2 = expr.cond.src2.arg.arg + int1 = int(expr.cond.src1) + int2 = int(expr.cond.src2) if int1 and int2: expr = expr.src1 elif int1 == 0 and int2 == 0: @@ -906,6 +911,15 @@ def simp_cond_flag(_, expr): return expr +def simp_sub_cf_zero(_, expr): + """FLAG_SUB_CF(0, X) => (X)?1:0""" + if not expr.is_op("FLAG_SUB_CF"): + return expr + if not expr.args[0].is_int(0): + return expr + return ExprCond(expr.args[1], ExprInt(1, 1), ExprInt(0, 1)) + + def simp_cmp_int(expr_simp, expr): """ ({X, 0} == int) => X == int[:] @@ -1069,6 +1083,13 @@ def simp_cmp_bijective_op(expr_simp, expr): args_a.remove(value) args_b.remove(value) + # a + b == a + b + c + if not args_a: + return ExprOp(TOK_EQUAL, ExprOp(op, *args_b), ExprInt(0, args_b[0].size)) + # a + b + c == a + b + if not args_b: + return ExprOp(TOK_EQUAL, ExprOp(op, *args_a), ExprInt(0, args_a[0].size)) + arg_a = ExprOp(op, *args_a) arg_b = ExprOp(op, *args_b) return ExprOp(TOK_EQUAL, arg_a, arg_b) @@ -1362,6 +1383,23 @@ def simp_ext_cst(_, expr): return ret + +def simp_ext_cond_int(e_s, expr): + """ + zeroExt(ExprCond(X, Int, Int)) => ExprCond(X, Int, Int) + """ + if not (expr.op.startswith("zeroExt") or expr.op.startswith("signExt")): + return expr + arg = expr.args[0] + if not arg.is_cond(): + return expr + if not (arg.src1.is_int() and arg.src2.is_int()): + return expr + src1 = ExprOp(expr.op, arg.src1) + src2 = ExprOp(expr.op, arg.src2) + return e_s(ExprCond(arg.cond, src1, src2)) + + def simp_slice_of_ext(_, expr): """ C.zeroExt(X)[A:B] => 0 if A >= size(C) |