diff options
| author | serpilliere <serpilliere@users.noreply.github.com> | 2020-03-30 14:59:40 +0200 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2020-03-30 14:59:40 +0200 |
| commit | 1d5127036dc7e1688c102c5781a6618b5dd27f16 (patch) | |
| tree | 021484913b08e5c72366da78e6a6c5d68e31893f /miasm/expression/expression.py | |
| parent | 83196a14885467a043666882db7d8120bb127b61 (diff) | |
| parent | bf9babdcd886d51666c04e4fc39a4b03e281974a (diff) | |
| download | focaccia-miasm-1d5127036dc7e1688c102c5781a6618b5dd27f16.tar.gz focaccia-miasm-1d5127036dc7e1688c102c5781a6618b5dd27f16.zip | |
Merge pull request #1158 from serpilliere/expr_visitor
Add Expression visitor
Diffstat (limited to 'miasm/expression/expression.py')
| -rw-r--r-- | miasm/expression/expression.py | 478 |
1 files changed, 290 insertions, 188 deletions
diff --git a/miasm/expression/expression.py b/miasm/expression/expression.py index d0e57b46..9ac631fa 100644 --- a/miasm/expression/expression.py +++ b/miasm/expression/expression.py @@ -98,18 +98,6 @@ def should_parenthesize_child(child, parent): def str_protected_child(child, parent): return ("(%s)" % child) if should_parenthesize_child(child, parent) else str(child) -def visit_chk(visitor): - "Function decorator launching callback on Expression visit" - def wrapped(expr, callback, test_visit=lambda x: True): - if (test_visit is not None) and (not test_visit(expr)): - return expr - expr_new = visitor(expr, callback, test_visit) - if expr_new is None: - return None - expr_new2 = callback(expr_new) - return expr_new2 - return wrapped - # Expression display @@ -183,6 +171,263 @@ class LocKey(object): def __str__(self): return "loc_key_%d" % self.key + +class ExprWalkBase(object): + """ + Walk through sub-expressions, call @callback on them. + If @callback returns a non None value, stop walk and return this value + """ + + def __init__(self, callback): + self.callback = callback + + def visit(self, expr, *args, **kwargs): + if expr.is_int() or expr.is_id() or expr.is_loc(): + pass + elif expr.is_assign(): + ret = self.visit(expr.dst, *args, **kwargs) + if ret: + return ret + src = self.visit(expr.src, *args, **kwargs) + if ret: + return ret + elif expr.is_cond(): + ret = self.visit(expr.cond, *args, **kwargs) + if ret: + return ret + ret = self.visit(expr.src1, *args, **kwargs) + if ret: + return ret + ret = self.visit(expr.src2, *args, **kwargs) + if ret: + return ret + elif expr.is_mem(): + ret = self.visit(expr.ptr, *args, **kwargs) + if ret: + return ret + elif expr.is_slice(): + ret = self.visit(expr.arg, *args, **kwargs) + if ret: + return ret + elif expr.is_op(): + for arg in expr.args: + ret = self.visit(arg, *args, **kwargs) + if ret: + return ret + elif expr.is_compose(): + for arg in expr.args: + ret = self.visit(arg, *args, **kwargs) + if ret: + return 
ret + else: + raise TypeError("Visitor can only take Expr") + + ret = self.callback(expr, *args, **kwargs) + return ret + + +class ExprWalk(ExprWalkBase): + """ + Walk through sub-expressions, call @callback on them. + If @callback returns a non None value, stop walk and return this value + Use cache mechanism. + """ + def __init__(self, callback): + self.cache = set() + self.callback = callback + + def visit(self, expr, *args, **kwargs): + if expr in self.cache: + return None + ret = super(ExprWalk, self).visit(expr, *args, **kwargs) + if ret: + return ret + self.cache.add(expr) + return None + + +class ExprGetR(ExprWalkBase): + """ + Return ExprId/ExprMem used by a given expression + """ + def __init__(self, mem_read=False, cst_read=False): + super(ExprGetR, self).__init__(lambda x:None) + self.mem_read = mem_read + self.cst_read = cst_read + self.elements = set() + self.cache = dict() + + def get_r_leaves(self, expr): + if (expr.is_int() or expr.is_loc()) and self.cst_read: + self.elements.add(expr) + elif expr.is_mem(): + self.elements.add(expr) + elif expr.is_id(): + self.elements.add(expr) + + def visit(self, expr, *args, **kwargs): + cache_key = (expr, self.mem_read, self.cst_read) + if cache_key in self.cache: + return self.cache[cache_key] + ret = self.visit_inner(expr, *args, **kwargs) + self.cache[cache_key] = ret + return ret + + def visit_inner(self, expr, *args, **kwargs): + self.get_r_leaves(expr) + if expr.is_mem() and not self.mem_read: + # Don't visit memory sons + return None + + if expr.is_assign(): + if expr.dst.is_mem() and self.mem_read: + ret = super(ExprGetR, self).visit(expr.dst, *args, **kwargs) + if expr.src.is_mem(): + self.elements.add(expr.src) + self.get_r_leaves(expr.src) + if expr.src.is_mem() and not self.mem_read: + return None + ret = super(ExprGetR, self).visit(expr.src, *args, **kwargs) + return ret + ret = super(ExprGetR, self).visit(expr, *args, **kwargs) + return ret + + +class ExprVisitorBase(object): + """ + Rebuild 
expression by visiting sub-expressions + """ + def visit(self, expr, *args, **kwargs): + if expr.is_int() or expr.is_id() or expr.is_loc(): + ret = expr + elif expr.is_assign(): + dst = self.visit(expr.dst, *args, **kwargs) + src = self.visit(expr.src, *args, **kwargs) + ret = ExprAssign(dst, src) + elif expr.is_cond(): + cond = self.visit(expr.cond, *args, **kwargs) + src1 = self.visit(expr.src1, *args, **kwargs) + src2 = self.visit(expr.src2, *args, **kwargs) + ret = ExprCond(cond, src1, src2) + elif expr.is_mem(): + ptr = self.visit(expr.ptr, *args, **kwargs) + ret = ExprMem(ptr, expr.size) + elif expr.is_slice(): + arg = self.visit(expr.arg, *args, **kwargs) + ret = ExprSlice(arg, expr.start, expr.stop) + elif expr.is_op(): + args = [self.visit(arg, *args, **kwargs) for arg in expr.args] + ret = ExprOp(expr.op, *args) + elif expr.is_compose(): + args = [self.visit(arg, *args, **kwargs) for arg in expr.args] + ret = ExprCompose(*args) + else: + raise TypeError("Visitor can only take Expr") + return ret + + +class ExprVisitorCallbackTopToBottom(ExprVisitorBase): + """ + Rebuild expression by visiting sub-expressions + Call @callback on each sub-expression + if @callback return non None value, replace current node with this value + Else, continue visit of sub-expressions + """ + def __init__(self, callback): + super(ExprVisitorCallbackTopToBottom, self).__init__() + self.cache = dict() + self.callback = callback + + def visit(self, expr, *args, **kwargs): + if expr in self.cache: + return self.cache[expr] + ret = self.visit_inner(expr, *args, **kwargs) + self.cache[expr] = ret + return ret + + def visit_inner(self, expr, *args, **kwargs): + ret = self.callback(expr) + if ret: + return ret + ret = super(ExprVisitorCallbackTopToBottom, self).visit(expr, *args, **kwargs) + return ret + + +class ExprVisitorCallbackBottomToTop(ExprVisitorBase): + """ + Rebuild expression by visiting sub-expressions + Call @callback from leaves to root expressions + """ + def 
__init__(self, callback): + super(ExprVisitorCallbackBottomToTop, self).__init__() + self.cache = dict() + self.callback = callback + + def visit(self, expr, *args, **kwargs): + if expr in self.cache: + return self.cache[expr] + ret = self.visit_inner(expr, *args, **kwargs) + self.cache[expr] = ret + return ret + + def visit_inner(self, expr, *args, **kwargs): + ret = super(ExprVisitorCallbackBottomToTop, self).visit(expr, *args, **kwargs) + ret = self.callback(ret) + return ret + + +class ExprVisitorCanonize(ExprVisitorCallbackBottomToTop): + def __init__(self): + super(ExprVisitorCanonize, self).__init__(self.canonize) + + def canonize(self, expr): + if not expr.is_op(): + return expr + if not expr.is_associative(): + return expr + + # ((a+b) + c) => (a + b + c) + args = [] + for arg in expr.args: + if isinstance(arg, ExprOp) and expr.op == arg.op: + args += arg.args + else: + args.append(arg) + args = canonize_expr_list(args) + new_expr = ExprOp(expr.op, *args) + return new_expr + + +class ExprVisitorContains(ExprWalkBase): + """ + Visitor to test if a needle is in an Expression + Cache results + """ + def __init__(self): + self.cache = set() + super(ExprVisitorContains, self).__init__(self.eq_expr) + + def eq_expr(self, expr, needle, *args, **kwargs): + if expr == needle: + return True + return None + + def visit(self, expr, needle, *args, **kwargs): + if (expr, needle) in self.cache: + return None + ret = super(ExprVisitorContains, self).visit(expr, needle, *args, **kwargs) + if ret: + return ret + self.cache.add((expr, needle)) + return None + + + def contains(self, expr, needle): + return self.visit(expr, needle) + +contains_visitor = ExprVisitorContains() +canonize_visitor = ExprVisitorCanonize() + # IR definitions class Expr(object): @@ -337,36 +582,16 @@ class Expr(object): """Find and replace sub expression using dct @dct: dictionary associating replaced Expr to its new Expr value """ - return self.visit(lambda expr: dct.get(expr, expr)) + def 
replace(expr): + if expr in dct: + return dct[expr] + return None + visitor = ExprVisitorCallbackTopToBottom(lambda expr:replace(expr)) + return visitor.visit(self) def canonize(self): "Canonize the Expression" - - def must_canon(expr): - return not expr.is_canon - - def canonize_visitor(expr): - if expr.is_canon: - return expr - if isinstance(expr, ExprOp): - if expr.is_associative(): - # ((a+b) + c) => (a + b + c) - args = [] - for arg in expr.args: - if isinstance(arg, ExprOp) and expr.op == arg.op: - args += arg.args - else: - args.append(arg) - args = canonize_expr_list(args) - new_e = ExprOp(expr.op, *args) - else: - new_e = expr - else: - new_e = expr - new_e.is_canon = True - return new_e - - return self.visit(canonize_visitor, must_canon) + return canonize_visitor.visit(self) def msb(self): "Return the Most Significant Bit" @@ -453,6 +678,32 @@ class Expr(object): """Returns True if is ExprMem and ptr is_op_segm""" return False + def __contains__(self, expr): + ret = contains_visitor.contains(self, expr) + return ret + + def visit(self, callback): + """ + Apply callback to all sub expression of @self + This function keeps a cache to avoid rerunning @callback on common sub + expressions. + + @callback: fn(Expr) -> Expr + """ + visitor = ExprVisitorCallbackBottomToTop(callback) + return visitor.visit(self) + + def get_r(self, mem_read=False, cst_read=False): + visitor = ExprGetR(mem_read, cst_read) + visitor.visit(self) + return visitor.elements + + + def get_w(self, mem_read=False, cst_read=False): + if self.is_assign(): + return set([self.dst]) + return set() + class ExprInt(Expr): """An ExprInt represent a constant in Miasm IR. 
@@ -512,12 +763,6 @@ class ExprInt(Expr): else: return str("0x%X" % self._get_int()) - def get_r(self, mem_read=False, cst_read=False): - if cst_read: - return set([self]) - else: - return set() - def get_w(self): return set() @@ -528,13 +773,6 @@ class ExprInt(Expr): return "%s(0x%X, %d)" % (self.__class__.__name__, self._get_int(), self._size) - def __contains__(self, expr): - return self == expr - - @visit_chk - def visit(self, callback, test_visit=None): - return self - def copy(self): return ExprInt(self._arg, self._size) @@ -595,9 +833,6 @@ class ExprId(Expr): def __str__(self): return str(self._name) - def get_r(self, mem_read=False, cst_read=False): - return set([self]) - def get_w(self): return set([self]) @@ -607,13 +842,6 @@ class ExprId(Expr): def _exprrepr(self): return "%s(%r, %d)" % (self.__class__.__name__, self._name, self._size) - def __contains__(self, expr): - return self == expr - - @visit_chk - def visit(self, callback, test_visit=None): - return self - def copy(self): return ExprId(self._name, self._size) @@ -657,12 +885,6 @@ class ExprLoc(Expr): def __str__(self): return str(self._loc_key) - def get_r(self, mem_read=False, cst_read=False): - if cst_read: - return set([self]) - else: - return set() - def get_w(self): return set() @@ -672,13 +894,6 @@ class ExprLoc(Expr): def _exprrepr(self): return "%s(%r, %d)" % (self.__class__.__name__, self._loc_key, self._size) - def __contains__(self, expr): - return self == expr - - @visit_chk - def visit(self, callback, test_visit=None): - return self - def copy(self): return ExprLoc(self._loc_key, self._size) @@ -749,12 +964,6 @@ class ExprAssign(Expr): def __str__(self): return "%s = %s" % (str(self._dst), str(self._src)) - def get_r(self, mem_read=False, cst_read=False): - elements = self._src.get_r(mem_read, cst_read) - if isinstance(self._dst, ExprMem) and mem_read: - elements.update(self._dst.ptr.get_r(mem_read, cst_read)) - return elements - def get_w(self): if isinstance(self._dst, ExprMem): 
return set([self._dst]) # [memreg] @@ -767,19 +976,6 @@ class ExprAssign(Expr): def _exprrepr(self): return "%s(%r, %r)" % (self.__class__.__name__, self._dst, self._src) - def __contains__(self, expr): - return (self == expr or - self._src.__contains__(expr) or - self._dst.__contains__(expr)) - - @visit_chk - def visit(self, callback, test_visit=None): - dst, src = self._dst.visit(callback, test_visit), self._src.visit(callback, test_visit) - if dst == self._dst and src == self._src: - return self - else: - return ExprAssign(dst, src) - def copy(self): return ExprAssign(self._dst.copy(), self._src.copy()) @@ -854,12 +1050,6 @@ class ExprCond(Expr): def __str__(self): return "%s?(%s,%s)" % (str_protected_child(self._cond, self), str(self._src1), str(self._src2)) - def get_r(self, mem_read=False, cst_read=False): - out_src1 = self.src1.get_r(mem_read, cst_read) - out_src2 = self.src2.get_r(mem_read, cst_read) - return self.cond.get_r(mem_read, - cst_read).union(out_src1).union(out_src2) - def get_w(self): return set() @@ -871,21 +1061,6 @@ class ExprCond(Expr): return "%s(%r, %r, %r)" % (self.__class__.__name__, self._cond, self._src1, self._src2) - def __contains__(self, expr): - return (self == expr or - self.cond.__contains__(expr) or - self.src1.__contains__(expr) or - self.src2.__contains__(expr)) - - @visit_chk - def visit(self, callback, test_visit=None): - cond = self._cond.visit(callback, test_visit) - src1 = self._src1.visit(callback, test_visit) - src2 = self._src2.visit(callback, test_visit) - if cond == self._cond and src1 == self._src1 and src2 == self._src2: - return self - return ExprCond(cond, src1, src2) - def copy(self): return ExprCond(self._cond.copy(), self._src1.copy(), @@ -962,12 +1137,6 @@ class ExprMem(Expr): def __str__(self): return "@%d[%s]" % (self.size, str(self.ptr)) - def get_r(self, mem_read=False, cst_read=False): - if mem_read: - return set(self._ptr.get_r(mem_read, cst_read).union(set([self]))) - else: - return set([self]) - def 
get_w(self): return set([self]) # [memreg] @@ -978,16 +1147,6 @@ class ExprMem(Expr): return "%s(%r, %r)" % (self.__class__.__name__, self._ptr, self._size) - def __contains__(self, expr): - return self == expr or self._ptr.__contains__(expr) - - @visit_chk - def visit(self, callback, test_visit=None): - ptr = self._ptr.visit(callback, test_visit) - if ptr == self._ptr: - return self - return ExprMem(ptr, self.size) - def copy(self): ptr = self.ptr.copy() return ExprMem(ptr, size=self.size) @@ -1117,10 +1276,6 @@ class ExprOp(Expr): return (self._op + '(' + ', '.join([str(arg) for arg in self._args]) + ')') - def get_r(self, mem_read=False, cst_read=False): - return reduce(lambda elements, arg: - elements.union(arg.get_r(mem_read, cst_read)), self._args, set()) - def get_w(self): raise ValueError('op cannot be written!', self) @@ -1132,14 +1287,6 @@ class ExprOp(Expr): return "%s(%r, %s)" % (self.__class__.__name__, self._op, ', '.join(repr(arg) for arg in self._args)) - def __contains__(self, expr): - if self == expr: - return True - for arg in self._args: - if arg.__contains__(expr): - return True - return False - def is_function_call(self): return self._op.startswith('call') @@ -1162,14 +1309,6 @@ class ExprOp(Expr): "Return True iff current operation is commutative" return (self._op in ['+', '*', '^', '&', '|']) - @visit_chk - def visit(self, callback, test_visit=None): - args = [arg.visit(callback, test_visit) for arg in self._args] - modified = any([arg[0] != arg[1] for arg in zip(self._args, args)]) - if modified: - return ExprOp(self._op, *args) - return self - def copy(self): args = [arg.copy() for arg in self._args] return ExprOp(self._op, *args) @@ -1222,9 +1361,6 @@ class ExprSlice(Expr): def __str__(self): return "%s[%d:%d]" % (str_protected_child(self._arg, self), self._start, self._stop) - def get_r(self, mem_read=False, cst_read=False): - return self._arg.get_r(mem_read, cst_read) - def get_w(self): return self._arg.get_w() @@ -1235,18 +1371,6 @@ 
class ExprSlice(Expr): return "%s(%r, %d, %d)" % (self.__class__.__name__, self._arg, self._start, self._stop) - def __contains__(self, expr): - if self == expr: - return True - return self._arg.__contains__(expr) - - @visit_chk - def visit(self, callback, test_visit=None): - arg = self._arg.visit(callback, test_visit) - if arg == self._arg: - return self - return ExprSlice(arg, self._start, self._stop) - def copy(self): return ExprSlice(self._arg.copy(), self._start, self._stop) @@ -1319,10 +1443,6 @@ class ExprCompose(Expr): def __str__(self): return '{' + ', '.join(["%s %s %s" % (arg, idx, idx + arg.size) for idx, arg in self.iter_args()]) + '}' - def get_r(self, mem_read=False, cst_read=False): - return reduce(lambda elements, arg: - elements.union(arg.get_r(mem_read, cst_read)), self._args, set()) - def get_w(self): return reduce(lambda elements, arg: elements.union(arg.get_w()), self._args, set()) @@ -1334,24 +1454,6 @@ class ExprCompose(Expr): def _exprrepr(self): return "%s%r" % (self.__class__.__name__, self._args) - def __contains__(self, expr): - if self == expr: - return True - for arg in self._args: - if arg == expr: - return True - if arg.__contains__(expr): - return True - return False - - @visit_chk - def visit(self, callback, test_visit=None): - args = [arg.visit(callback, test_visit) for arg in self._args] - modified = any([arg != arg_new for arg, arg_new in zip(self._args, args)]) - if modified: - return ExprCompose(*args) - return self - def copy(self): args = [arg.copy() for arg in self._args] return ExprCompose(*args) |