diff options
| author | Fabrice Desclaux <fabrice.desclaux@cea.fr> | 2020-03-10 16:35:51 +0100 |
|---|---|---|
| committer | Fabrice Desclaux <fabrice.desclaux@cea.fr> | 2020-03-29 00:08:03 +0100 |
| commit | bf9babdcd886d51666c04e4fc39a4b03e281974a (patch) | |
| tree | 021484913b08e5c72366da78e6a6c5d68e31893f | |
| parent | 83196a14885467a043666882db7d8120bb127b61 (diff) | |
| download | miasm-bf9babdcd886d51666c04e4fc39a4b03e281974a.tar.gz miasm-bf9babdcd886d51666c04e4fc39a4b03e281974a.zip | |
Add Expression visitor
| -rw-r--r-- | miasm/analysis/data_flow.py | 20 | ||||
| -rw-r--r-- | miasm/analysis/depgraph.py | 117 | ||||
| -rw-r--r-- | miasm/analysis/dse.py | 4 | ||||
| -rw-r--r-- | miasm/expression/expression.py | 478 | ||||
| -rw-r--r-- | miasm/expression/simplifications.py | 52 | ||||
| -rw-r--r-- | miasm/expression/simplifications_common.py | 8 | ||||
| -rw-r--r-- | test/expression/expression.py | 45 |
7 files changed, 420 insertions, 304 deletions
diff --git a/miasm/analysis/data_flow.py b/miasm/analysis/data_flow.py index 7bd6d72f..6395fa8c 100644 --- a/miasm/analysis/data_flow.py +++ b/miasm/analysis/data_flow.py @@ -6,8 +6,8 @@ from future.utils import viewitems, viewvalues from miasm.core.utils import encode_hex from miasm.core.graph import DiGraph from miasm.ir.ir import AssignBlock, IRBlock -from miasm.expression.expression import ExprLoc, ExprMem, ExprId, ExprInt,\ - ExprAssign, ExprOp +from miasm.expression.expression import ExprLoc, ExprMem, ExprSlice, ExprId, \ + ExprInt, ExprAssign, ExprOp, ExprCompose, ExprCond, ExprWalk from miasm.expression.simplifications import expr_simp from miasm.core.interval import interval from miasm.expression.expression_helper import possible_values @@ -736,22 +736,16 @@ def expr_test_visit(expr, test): return False -def expr_has_mem_test(expr, result): - if result: - # Don't analyse if we already found a candidate - return False - if expr.is_mem(): - result.add(expr) - return False - return True - - def expr_has_mem(expr): """ Return True if expr contains at least one memory access @expr: Expr instance """ - return expr_test_visit(expr, expr_has_mem_test) + + def has_mem(self): + return self.is_mem() + visitor = ExprWalk(has_mem) + return visitor.visit(expr) class PropagateThroughExprId(object): diff --git a/miasm/analysis/depgraph.py b/miasm/analysis/depgraph.py index 7113dd51..b0d13318 100644 --- a/miasm/analysis/depgraph.py +++ b/miasm/analysis/depgraph.py @@ -4,7 +4,8 @@ from functools import total_ordering from future.utils import viewitems -from miasm.expression.expression import ExprInt, ExprLoc, ExprAssign +from miasm.expression.expression import ExprInt, ExprLoc, ExprAssign, \ + ExprWalk from miasm.core.graph import DiGraph from miasm.core.locationdb import LocationDB from miasm.expression.simplifications import expr_simp_explicit @@ -449,6 +450,50 @@ class FollowExpr(object): if not(only_follow) or follow_expr.follow) +class FilterExprSources(ExprWalk): + 
""" + Walk Expression to find sources to track + @follow_mem: (optional) Track memory syntactically + @follow_call: (optional) Track through "call" + """ + def __init__(self, follow_mem, follow_call): + super(FilterExprSources, self).__init__(lambda x:None) + self.follow_mem = follow_mem + self.follow_call = follow_call + self.nofollow = set() + self.follow = set() + + def visit(self, expr, *args, **kwargs): + if expr in self.cache: + return None + ret = self.visit_inner(expr, *args, **kwargs) + self.cache.add(expr) + return ret + + def visit_inner(self, expr, *args, **kwargs): + if expr.is_id(): + self.follow.add(expr) + elif expr.is_int(): + self.nofollow.add(expr) + elif expr.is_loc(): + self.nofollow.add(expr) + elif expr.is_mem(): + if self.follow_mem: + self.follow.add(expr) + else: + self.nofollow.add(expr) + return None + elif expr.is_function_call(): + if self.follow_call: + self.follow.add(expr) + else: + self.nofollow.add(expr) + return None + + ret = super(FilterExprSources, self).visit(expr, *args, **kwargs) + return ret + + class DependencyGraph(object): """Implementation of a dependency graph @@ -480,10 +525,14 @@ class DependencyGraph(object): self._cb_follow = [] if apply_simp: self._cb_follow.append(self._follow_simp_expr) - self._cb_follow.append(lambda exprs: self._follow_exprs(exprs, - follow_mem, - follow_call)) - self._cb_follow.append(self._follow_no_loc_key) + self._cb_follow.append(lambda exprs: self.do_follow(exprs, follow_mem, follow_call)) + + @staticmethod + def do_follow(exprs, follow_mem, follow_call): + visitor = FilterExprSources(follow_mem, follow_call) + for expr in exprs: + visitor.visit(expr) + return visitor.follow, visitor.nofollow @staticmethod def _follow_simp_expr(exprs): @@ -495,64 +544,6 @@ class DependencyGraph(object): follow.add(expr_simp_explicit(expr)) return follow, set() - @staticmethod - def get_expr(expr, follow, nofollow): - """Update @follow/@nofollow according to insteresting nodes - Returns same expression 
(non modifier visitor). - - @expr: expression to handle - @follow: set of nodes to follow - @nofollow: set of nodes not to follow - """ - if expr.is_id(): - follow.add(expr) - elif expr.is_int(): - nofollow.add(expr) - elif expr.is_mem(): - follow.add(expr) - return expr - - @staticmethod - def follow_expr(expr, _, nofollow, follow_mem=False, follow_call=False): - """Returns True if we must visit sub expressions. - @expr: expression to browse - @follow: set of nodes to follow - @nofollow: set of nodes not to follow - @follow_mem: force the visit of memory sub expressions - @follow_call: force the visit of call sub expressions - """ - if not follow_mem and expr.is_mem(): - nofollow.add(expr) - return False - if not follow_call and expr.is_function_call(): - nofollow.add(expr) - return False - return True - - @classmethod - def _follow_exprs(cls, exprs, follow_mem=False, follow_call=False): - """Extracts subnodes from exprs and returns followed/non followed - expressions according to @follow_mem/@follow_call - - """ - follow, nofollow = set(), set() - for expr in exprs: - expr.visit(lambda x: cls.get_expr(x, follow, nofollow), - lambda x: cls.follow_expr(x, follow, nofollow, - follow_mem, follow_call)) - return follow, nofollow - - @staticmethod - def _follow_no_loc_key(exprs): - """Do not follow loc_keys""" - follow = set() - for expr in exprs: - if expr.is_int() or expr.is_loc(): - continue - follow.add(expr) - - return follow, set() - def _follow_apply_cb(self, expr): """Apply callback functions to @expr @expr : FollowExpr instance""" diff --git a/miasm/analysis/dse.py b/miasm/analysis/dse.py index 3a0482a3..ada3c4bd 100644 --- a/miasm/analysis/dse.py +++ b/miasm/analysis/dse.py @@ -333,8 +333,8 @@ class DSEEngine(object): self.handle(ExprInt(cur_addr, self.ir_arch.IRDst.size)) # Avoid memory issue in ExpressionSimplifier - if len(self.symb.expr_simp.simplified_exprs) > 100000: - self.symb.expr_simp.simplified_exprs.clear() + if len(self.symb.expr_simp.cache) > 
100000: + self.symb.expr_simp.cache.clear() # Get IR blocks if cur_addr in self.addr_to_cacheblocks: diff --git a/miasm/expression/expression.py b/miasm/expression/expression.py index d0e57b46..9ac631fa 100644 --- a/miasm/expression/expression.py +++ b/miasm/expression/expression.py @@ -98,18 +98,6 @@ def should_parenthesize_child(child, parent): def str_protected_child(child, parent): return ("(%s)" % child) if should_parenthesize_child(child, parent) else str(child) -def visit_chk(visitor): - "Function decorator launching callback on Expression visit" - def wrapped(expr, callback, test_visit=lambda x: True): - if (test_visit is not None) and (not test_visit(expr)): - return expr - expr_new = visitor(expr, callback, test_visit) - if expr_new is None: - return None - expr_new2 = callback(expr_new) - return expr_new2 - return wrapped - # Expression display @@ -183,6 +171,263 @@ class LocKey(object): def __str__(self): return "loc_key_%d" % self.key + +class ExprWalkBase(object): + """ + Walk through sub-expressions, call @callback on them. 
+ If @callback returns a non None value, stop walk and return this value + """ + + def __init__(self, callback): + self.callback = callback + + def visit(self, expr, *args, **kwargs): + if expr.is_int() or expr.is_id() or expr.is_loc(): + pass + elif expr.is_assign(): + ret = self.visit(expr.dst, *args, **kwargs) + if ret: + return ret + src = self.visit(expr.src, *args, **kwargs) + if ret: + return ret + elif expr.is_cond(): + ret = self.visit(expr.cond, *args, **kwargs) + if ret: + return ret + ret = self.visit(expr.src1, *args, **kwargs) + if ret: + return ret + ret = self.visit(expr.src2, *args, **kwargs) + if ret: + return ret + elif expr.is_mem(): + ret = self.visit(expr.ptr, *args, **kwargs) + if ret: + return ret + elif expr.is_slice(): + ret = self.visit(expr.arg, *args, **kwargs) + if ret: + return ret + elif expr.is_op(): + for arg in expr.args: + ret = self.visit(arg, *args, **kwargs) + if ret: + return ret + elif expr.is_compose(): + for arg in expr.args: + ret = self.visit(arg, *args, **kwargs) + if ret: + return ret + else: + raise TypeError("Visitor can only take Expr") + + ret = self.callback(expr, *args, **kwargs) + return ret + + +class ExprWalk(ExprWalkBase): + """ + Walk through sub-expressions, call @callback on them. + If @callback returns a non None value, stop walk and return this value + Use cache mechanism. 
+ """ + def __init__(self, callback): + self.cache = set() + self.callback = callback + + def visit(self, expr, *args, **kwargs): + if expr in self.cache: + return None + ret = super(ExprWalk, self).visit(expr, *args, **kwargs) + if ret: + return ret + self.cache.add(expr) + return None + + +class ExprGetR(ExprWalkBase): + """ + Return ExprId/ExprMem used by a given expression + """ + def __init__(self, mem_read=False, cst_read=False): + super(ExprGetR, self).__init__(lambda x:None) + self.mem_read = mem_read + self.cst_read = cst_read + self.elements = set() + self.cache = dict() + + def get_r_leaves(self, expr): + if (expr.is_int() or expr.is_loc()) and self.cst_read: + self.elements.add(expr) + elif expr.is_mem(): + self.elements.add(expr) + elif expr.is_id(): + self.elements.add(expr) + + def visit(self, expr, *args, **kwargs): + cache_key = (expr, self.mem_read, self.cst_read) + if cache_key in self.cache: + return self.cache[cache_key] + ret = self.visit_inner(expr, *args, **kwargs) + self.cache[cache_key] = ret + return ret + + def visit_inner(self, expr, *args, **kwargs): + self.get_r_leaves(expr) + if expr.is_mem() and not self.mem_read: + # Don't visit memory sons + return None + + if expr.is_assign(): + if expr.dst.is_mem() and self.mem_read: + ret = super(ExprGetR, self).visit(expr.dst, *args, **kwargs) + if expr.src.is_mem(): + self.elements.add(expr.src) + self.get_r_leaves(expr.src) + if expr.src.is_mem() and not self.mem_read: + return None + ret = super(ExprGetR, self).visit(expr.src, *args, **kwargs) + return ret + ret = super(ExprGetR, self).visit(expr, *args, **kwargs) + return ret + + +class ExprVisitorBase(object): + """ + Rebuild expression by visiting sub-expressions + """ + def visit(self, expr, *args, **kwargs): + if expr.is_int() or expr.is_id() or expr.is_loc(): + ret = expr + elif expr.is_assign(): + dst = self.visit(expr.dst, *args, **kwargs) + src = self.visit(expr.src, *args, **kwargs) + ret = ExprAssign(dst, src) + elif 
expr.is_cond(): + cond = self.visit(expr.cond, *args, **kwargs) + src1 = self.visit(expr.src1, *args, **kwargs) + src2 = self.visit(expr.src2, *args, **kwargs) + ret = ExprCond(cond, src1, src2) + elif expr.is_mem(): + ptr = self.visit(expr.ptr, *args, **kwargs) + ret = ExprMem(ptr, expr.size) + elif expr.is_slice(): + arg = self.visit(expr.arg, *args, **kwargs) + ret = ExprSlice(arg, expr.start, expr.stop) + elif expr.is_op(): + args = [self.visit(arg, *args, **kwargs) for arg in expr.args] + ret = ExprOp(expr.op, *args) + elif expr.is_compose(): + args = [self.visit(arg, *args, **kwargs) for arg in expr.args] + ret = ExprCompose(*args) + else: + raise TypeError("Visitor can only take Expr") + return ret + + +class ExprVisitorCallbackTopToBottom(ExprVisitorBase): + """ + Rebuild expression by visiting sub-expressions + Call @callback on each sub-expression + if @callback return non None value, replace current node with this value + Else, continue visit of sub-expressions + """ + def __init__(self, callback): + super(ExprVisitorCallbackTopToBottom, self).__init__() + self.cache = dict() + self.callback = callback + + def visit(self, expr, *args, **kwargs): + if expr in self.cache: + return self.cache[expr] + ret = self.visit_inner(expr, *args, **kwargs) + self.cache[expr] = ret + return ret + + def visit_inner(self, expr, *args, **kwargs): + ret = self.callback(expr) + if ret: + return ret + ret = super(ExprVisitorCallbackTopToBottom, self).visit(expr, *args, **kwargs) + return ret + + +class ExprVisitorCallbackBottomToTop(ExprVisitorBase): + """ + Rebuild expression by visiting sub-expressions + Call @callback from leaves to root expressions + """ + def __init__(self, callback): + super(ExprVisitorCallbackBottomToTop, self).__init__() + self.cache = dict() + self.callback = callback + + def visit(self, expr, *args, **kwargs): + if expr in self.cache: + return self.cache[expr] + ret = self.visit_inner(expr, *args, **kwargs) + self.cache[expr] = ret + return ret + 
+ def visit_inner(self, expr, *args, **kwargs): + ret = super(ExprVisitorCallbackBottomToTop, self).visit(expr, *args, **kwargs) + ret = self.callback(ret) + return ret + + +class ExprVisitorCanonize(ExprVisitorCallbackBottomToTop): + def __init__(self): + super(ExprVisitorCanonize, self).__init__(self.canonize) + + def canonize(self, expr): + if not expr.is_op(): + return expr + if not expr.is_associative(): + return expr + + # ((a+b) + c) => (a + b + c) + args = [] + for arg in expr.args: + if isinstance(arg, ExprOp) and expr.op == arg.op: + args += arg.args + else: + args.append(arg) + args = canonize_expr_list(args) + new_expr = ExprOp(expr.op, *args) + return new_expr + + +class ExprVisitorContains(ExprWalkBase): + """ + Visitor to test if a needle is in an Expression + Cache results + """ + def __init__(self): + self.cache = set() + super(ExprVisitorContains, self).__init__(self.eq_expr) + + def eq_expr(self, expr, needle, *args, **kwargs): + if expr == needle: + return True + return None + + def visit(self, expr, needle, *args, **kwargs): + if (expr, needle) in self.cache: + return None + ret = super(ExprVisitorContains, self).visit(expr, needle, *args, **kwargs) + if ret: + return ret + self.cache.add((expr, needle)) + return None + + + def contains(self, expr, needle): + return self.visit(expr, needle) + +contains_visitor = ExprVisitorContains() +canonize_visitor = ExprVisitorCanonize() + # IR definitions class Expr(object): @@ -337,36 +582,16 @@ class Expr(object): """Find and replace sub expression using dct @dct: dictionary associating replaced Expr to its new Expr value """ - return self.visit(lambda expr: dct.get(expr, expr)) + def replace(expr): + if expr in dct: + return dct[expr] + return None + visitor = ExprVisitorCallbackTopToBottom(lambda expr:replace(expr)) + return visitor.visit(self) def canonize(self): "Canonize the Expression" - - def must_canon(expr): - return not expr.is_canon - - def canonize_visitor(expr): - if expr.is_canon: - return 
expr - if isinstance(expr, ExprOp): - if expr.is_associative(): - # ((a+b) + c) => (a + b + c) - args = [] - for arg in expr.args: - if isinstance(arg, ExprOp) and expr.op == arg.op: - args += arg.args - else: - args.append(arg) - args = canonize_expr_list(args) - new_e = ExprOp(expr.op, *args) - else: - new_e = expr - else: - new_e = expr - new_e.is_canon = True - return new_e - - return self.visit(canonize_visitor, must_canon) + return canonize_visitor.visit(self) def msb(self): "Return the Most Significant Bit" @@ -453,6 +678,32 @@ class Expr(object): """Returns True if is ExprMem and ptr is_op_segm""" return False + def __contains__(self, expr): + ret = contains_visitor.contains(self, expr) + return ret + + def visit(self, callback): + """ + Apply callbak to all sub expression of @self + This function keeps a cache to avoid rerunning @callbak on common sub + expressions. + + @callback: fn(Expr) -> Expr + """ + visitor = ExprVisitorCallbackBottomToTop(callback) + return visitor.visit(self) + + def get_r(self, mem_read=False, cst_read=False): + visitor = ExprGetR(mem_read, cst_read) + visitor.visit(self) + return visitor.elements + + + def get_w(self, mem_read=False, cst_read=False): + if self.is_assign(): + return set([self.dst]) + return set() + class ExprInt(Expr): """An ExprInt represent a constant in Miasm IR. 
@@ -512,12 +763,6 @@ class ExprInt(Expr): else: return str("0x%X" % self._get_int()) - def get_r(self, mem_read=False, cst_read=False): - if cst_read: - return set([self]) - else: - return set() - def get_w(self): return set() @@ -528,13 +773,6 @@ class ExprInt(Expr): return "%s(0x%X, %d)" % (self.__class__.__name__, self._get_int(), self._size) - def __contains__(self, expr): - return self == expr - - @visit_chk - def visit(self, callback, test_visit=None): - return self - def copy(self): return ExprInt(self._arg, self._size) @@ -595,9 +833,6 @@ class ExprId(Expr): def __str__(self): return str(self._name) - def get_r(self, mem_read=False, cst_read=False): - return set([self]) - def get_w(self): return set([self]) @@ -607,13 +842,6 @@ class ExprId(Expr): def _exprrepr(self): return "%s(%r, %d)" % (self.__class__.__name__, self._name, self._size) - def __contains__(self, expr): - return self == expr - - @visit_chk - def visit(self, callback, test_visit=None): - return self - def copy(self): return ExprId(self._name, self._size) @@ -657,12 +885,6 @@ class ExprLoc(Expr): def __str__(self): return str(self._loc_key) - def get_r(self, mem_read=False, cst_read=False): - if cst_read: - return set([self]) - else: - return set() - def get_w(self): return set() @@ -672,13 +894,6 @@ class ExprLoc(Expr): def _exprrepr(self): return "%s(%r, %d)" % (self.__class__.__name__, self._loc_key, self._size) - def __contains__(self, expr): - return self == expr - - @visit_chk - def visit(self, callback, test_visit=None): - return self - def copy(self): return ExprLoc(self._loc_key, self._size) @@ -749,12 +964,6 @@ class ExprAssign(Expr): def __str__(self): return "%s = %s" % (str(self._dst), str(self._src)) - def get_r(self, mem_read=False, cst_read=False): - elements = self._src.get_r(mem_read, cst_read) - if isinstance(self._dst, ExprMem) and mem_read: - elements.update(self._dst.ptr.get_r(mem_read, cst_read)) - return elements - def get_w(self): if isinstance(self._dst, ExprMem): 
return set([self._dst]) # [memreg] @@ -767,19 +976,6 @@ class ExprAssign(Expr): def _exprrepr(self): return "%s(%r, %r)" % (self.__class__.__name__, self._dst, self._src) - def __contains__(self, expr): - return (self == expr or - self._src.__contains__(expr) or - self._dst.__contains__(expr)) - - @visit_chk - def visit(self, callback, test_visit=None): - dst, src = self._dst.visit(callback, test_visit), self._src.visit(callback, test_visit) - if dst == self._dst and src == self._src: - return self - else: - return ExprAssign(dst, src) - def copy(self): return ExprAssign(self._dst.copy(), self._src.copy()) @@ -854,12 +1050,6 @@ class ExprCond(Expr): def __str__(self): return "%s?(%s,%s)" % (str_protected_child(self._cond, self), str(self._src1), str(self._src2)) - def get_r(self, mem_read=False, cst_read=False): - out_src1 = self.src1.get_r(mem_read, cst_read) - out_src2 = self.src2.get_r(mem_read, cst_read) - return self.cond.get_r(mem_read, - cst_read).union(out_src1).union(out_src2) - def get_w(self): return set() @@ -871,21 +1061,6 @@ class ExprCond(Expr): return "%s(%r, %r, %r)" % (self.__class__.__name__, self._cond, self._src1, self._src2) - def __contains__(self, expr): - return (self == expr or - self.cond.__contains__(expr) or - self.src1.__contains__(expr) or - self.src2.__contains__(expr)) - - @visit_chk - def visit(self, callback, test_visit=None): - cond = self._cond.visit(callback, test_visit) - src1 = self._src1.visit(callback, test_visit) - src2 = self._src2.visit(callback, test_visit) - if cond == self._cond and src1 == self._src1 and src2 == self._src2: - return self - return ExprCond(cond, src1, src2) - def copy(self): return ExprCond(self._cond.copy(), self._src1.copy(), @@ -962,12 +1137,6 @@ class ExprMem(Expr): def __str__(self): return "@%d[%s]" % (self.size, str(self.ptr)) - def get_r(self, mem_read=False, cst_read=False): - if mem_read: - return set(self._ptr.get_r(mem_read, cst_read).union(set([self]))) - else: - return set([self]) - def 
get_w(self): return set([self]) # [memreg] @@ -978,16 +1147,6 @@ class ExprMem(Expr): return "%s(%r, %r)" % (self.__class__.__name__, self._ptr, self._size) - def __contains__(self, expr): - return self == expr or self._ptr.__contains__(expr) - - @visit_chk - def visit(self, callback, test_visit=None): - ptr = self._ptr.visit(callback, test_visit) - if ptr == self._ptr: - return self - return ExprMem(ptr, self.size) - def copy(self): ptr = self.ptr.copy() return ExprMem(ptr, size=self.size) @@ -1117,10 +1276,6 @@ class ExprOp(Expr): return (self._op + '(' + ', '.join([str(arg) for arg in self._args]) + ')') - def get_r(self, mem_read=False, cst_read=False): - return reduce(lambda elements, arg: - elements.union(arg.get_r(mem_read, cst_read)), self._args, set()) - def get_w(self): raise ValueError('op cannot be written!', self) @@ -1132,14 +1287,6 @@ class ExprOp(Expr): return "%s(%r, %s)" % (self.__class__.__name__, self._op, ', '.join(repr(arg) for arg in self._args)) - def __contains__(self, expr): - if self == expr: - return True - for arg in self._args: - if arg.__contains__(expr): - return True - return False - def is_function_call(self): return self._op.startswith('call') @@ -1162,14 +1309,6 @@ class ExprOp(Expr): "Return True iff current operation is commutative" return (self._op in ['+', '*', '^', '&', '|']) - @visit_chk - def visit(self, callback, test_visit=None): - args = [arg.visit(callback, test_visit) for arg in self._args] - modified = any([arg[0] != arg[1] for arg in zip(self._args, args)]) - if modified: - return ExprOp(self._op, *args) - return self - def copy(self): args = [arg.copy() for arg in self._args] return ExprOp(self._op, *args) @@ -1222,9 +1361,6 @@ class ExprSlice(Expr): def __str__(self): return "%s[%d:%d]" % (str_protected_child(self._arg, self), self._start, self._stop) - def get_r(self, mem_read=False, cst_read=False): - return self._arg.get_r(mem_read, cst_read) - def get_w(self): return self._arg.get_w() @@ -1235,18 +1371,6 @@ 
class ExprSlice(Expr): return "%s(%r, %d, %d)" % (self.__class__.__name__, self._arg, self._start, self._stop) - def __contains__(self, expr): - if self == expr: - return True - return self._arg.__contains__(expr) - - @visit_chk - def visit(self, callback, test_visit=None): - arg = self._arg.visit(callback, test_visit) - if arg == self._arg: - return self - return ExprSlice(arg, self._start, self._stop) - def copy(self): return ExprSlice(self._arg.copy(), self._start, self._stop) @@ -1319,10 +1443,6 @@ class ExprCompose(Expr): def __str__(self): return '{' + ', '.join(["%s %s %s" % (arg, idx, idx + arg.size) for idx, arg in self.iter_args()]) + '}' - def get_r(self, mem_read=False, cst_read=False): - return reduce(lambda elements, arg: - elements.union(arg.get_r(mem_read, cst_read)), self._args, set()) - def get_w(self): return reduce(lambda elements, arg: elements.union(arg.get_w()), self._args, set()) @@ -1334,24 +1454,6 @@ class ExprCompose(Expr): def _exprrepr(self): return "%s%r" % (self.__class__.__name__, self._args) - def __contains__(self, expr): - if self == expr: - return True - for arg in self._args: - if arg == expr: - return True - if arg.__contains__(expr): - return True - return False - - @visit_chk - def visit(self, callback, test_visit=None): - args = [arg.visit(callback, test_visit) for arg in self._args] - modified = any([arg != arg_new for arg, arg_new in zip(self._args, args)]) - if modified: - return ExprCompose(*args) - return self - def copy(self): args = [arg.copy() for arg in self._args] return ExprCompose(*args) diff --git a/miasm/expression/simplifications.py b/miasm/expression/simplifications.py index 7b1d3629..a56aa0f8 100644 --- a/miasm/expression/simplifications.py +++ b/miasm/expression/simplifications.py @@ -11,6 +11,7 @@ from miasm.expression import simplifications_cond from miasm.expression import simplifications_explicit from miasm.expression.expression_helper import fast_unify import miasm.expression.expression as m2_expr 
+from miasm.expression.expression import ExprVisitorCallbackBottomToTop # Expression Simplifier # --------------------- @@ -22,7 +23,7 @@ log_exprsimp.addHandler(console_handler) log_exprsimp.setLevel(logging.WARNING) -class ExpressionSimplifier(object): +class ExpressionSimplifier(ExprVisitorCallbackBottomToTop): """Wrapper on expression simplification passes. @@ -118,8 +119,8 @@ class ExpressionSimplifier(object): def __init__(self): + super(ExpressionSimplifier, self).__init__(self.expr_simp_inner) self.expr_simp_cb = {} - self.simplified_exprs = set() def enable_passes(self, passes): """Add passes from @passes @@ -129,7 +130,7 @@ class ExpressionSimplifier(object): """ # Clear cache of simplifiied expressions when adding a new pass - self.simplified_exprs.clear() + self.cache.clear() for k, v in viewitems(passes): self.expr_simp_cb[k] = fast_unify(self.expr_simp_cb.get(k, []) + v) @@ -156,46 +157,29 @@ class ExpressionSimplifier(object): return expression - def expr_simp(self, expression): + def expr_simp_inner(self, expression): """Apply enabled simplifications on expression and find a stable state @expression: Expr instance Return an Expr instance""" - if expression in self.simplified_exprs: - return expression - # Find a stable state while True: # Canonize and simplify - e_new = self.apply_simp(expression.canonize()) - if e_new == expression: - break - - # Launch recursivity - expression = self.expr_simp_wrapper(e_new) - self.simplified_exprs.add(expression) - # Mark expression as simplified - self.simplified_exprs.add(e_new) - - return e_new - - def expr_simp_wrapper(self, expression, callback=None): - """Apply enabled simplifications on expression - @expression: Expr instance - @manual_callback: If set, call this function instead of normal one - Return an Expr instance""" + new_expr = self.apply_simp(expression.canonize()) + if new_expr == expression: + return new_expr + # Run recursively simplification on fresh new expression + new_expr = 
self.visit(new_expr) + expression = new_expr + return new_expr - if expression in self.simplified_exprs: - return expression - - if callback is None: - callback = self.expr_simp - - return expression.visit(callback, lambda e: e not in self.simplified_exprs) + def expr_simp(self, expression): + "Call simplification recursively" + return self.visit(expression) - def __call__(self, expression, callback=None): - "Wrapper on expr_simp_wrapper" - return self.expr_simp_wrapper(expression, callback) + def __call__(self, expression): + "Call simplification recursively" + return self.visit(expression) # Public ExprSimplificationPass instance with commons passes diff --git a/miasm/expression/simplifications_common.py b/miasm/expression/simplifications_common.py index 90f8945b..38859f3a 100644 --- a/miasm/expression/simplifications_common.py +++ b/miasm/expression/simplifications_common.py @@ -450,8 +450,8 @@ def simp_cond_factor(e_s, expr): for cond, vals in viewitems(conds): new_src1 = [x.src1 for x in vals] new_src2 = [x.src2 for x in vals] - src1 = e_s.expr_simp_wrapper(ExprOp(expr.op, *new_src1)) - src2 = e_s.expr_simp_wrapper(ExprOp(expr.op, *new_src2)) + src1 = e_s.expr_simp(ExprOp(expr.op, *new_src1)) + src2 = e_s.expr_simp(ExprOp(expr.op, *new_src2)) c_out.append(ExprCond(cond, src1, src2)) if len(c_out) == 1: @@ -521,7 +521,7 @@ def simp_slice(e_s, expr): # distributivity of slice and & # (a & int)[x:y] => 0 if int[x:y] == 0 if expr.arg.is_op("&") and expr.arg.args[-1].is_int(): - tmp = e_s.expr_simp_wrapper(expr.arg.args[-1][expr.start:expr.stop]) + tmp = e_s.expr_simp(expr.arg.args[-1][expr.start:expr.stop]) if tmp.is_int(0): return tmp # distributivity of slice and exprcond @@ -536,7 +536,7 @@ def simp_slice(e_s, expr): # (a * int)[0:y] => (a[0:y] * int[0:y]) if expr.start == 0 and expr.arg.is_op("*") and expr.arg.args[-1].is_int(): - args = [e_s.expr_simp_wrapper(a[expr.start:expr.stop]) for a in expr.arg.args] + args = [e_s.expr_simp(a[expr.start:expr.stop]) for 
a in expr.arg.args] return ExprOp(expr.arg.op, *args) # (a >> int)[x:y] => a[x+int:y+int] with int+y <= a.size diff --git a/test/expression/expression.py b/test/expression/expression.py index 3597eae8..9b0c2807 100644 --- a/test/expression/expression.py +++ b/test/expression/expression.py @@ -17,6 +17,7 @@ assert big_cst.size == 0x1000 # Possible values #- Common constants A = ExprId("A", 32) +B = ExprId("B", 32) cond1 = ExprId("cond1", 1) cond2 = ExprId("cond2", 16) cst1 = ExprInt(1, 32) @@ -71,3 +72,47 @@ for expr in [ aff = ExprAssign(A[0:32], cst1) assert aff.dst == A and aff.src == cst1 + + +mem = ExprMem(A, 32) +assert mem.get_r() == set([mem]) +assert mem.get_r(mem_read=True) == set([mem, A]) + +C = A+B +D = C + A + +assert A in A +assert A in C +assert B in C +assert C in C + +assert A in D +assert B in D +assert C in D +assert D in D + +assert C not in A +assert C not in B + +assert D not in A +assert D not in B +assert D not in C + + +assert cst1.get_r(cst_read=True) == set([cst1]) +mem1 = ExprMem(A, 32) +mem2 = ExprMem(mem1 + B, 32) +assert mem2.get_r() == set([mem2]) + +assign1 = ExprAssign(A, cst1) +assert assign1.get_r() == set([]) + +assign2 = ExprAssign(mem1, D) +assert assign2.get_r() == set([A, B]) +assert assign2.get_r(mem_read=True) == set([A, B]) +assert assign2.get_w() == set([mem1]) + +assign3 = ExprAssign(mem1, mem2) +assert assign3.get_r() == set([mem2]) +assert assign3.get_r(mem_read=True) == set([mem1, mem2, A, B]) +assert assign3.get_w() == set([mem1]) |