diff options
| author | Camille Mougey <commial@gmail.com> | 2017-04-21 17:19:42 +0200 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2017-04-21 17:19:42 +0200 |
| commit | 7a9ba958c66c80b843bdd571f6989a8bb3e98dce (patch) | |
| tree | 423cf39b965539436061a568530a84c11d749957 | |
| parent | 31109b86989e2e0d3bc09a0283d7518979545011 (diff) | |
| parent | 51d26ba82c6b74e371027cbecd5fd2025fbc6618 (diff) | |
| download | miasm-7a9ba958c66c80b843bdd571f6989a8bb3e98dce.tar.gz miasm-7a9ba958c66c80b843bdd571f6989a8bb3e98dce.zip | |
Merge pull request #533 from serpilliere/fix_exprmatch
Fix exprmatch
| -rw-r--r-- | miasm2/expression/expression.py | 566 |
1 files changed, 287 insertions, 279 deletions
diff --git a/miasm2/expression/expression.py b/miasm2/expression/expression.py index ee233abe..bf27218b 100644 --- a/miasm2/expression/expression.py +++ b/miasm2/expression/expression.py @@ -28,12 +28,11 @@ # +import warnings import itertools -from operator import itemgetter from miasm2.expression.modint import mod_size2uint, is_modint, size2mask, \ define_uint from miasm2.core.graph import DiGraph -import warnings # Define tokens TOK_INF = "<" @@ -59,14 +58,14 @@ EXPRCOMPOSE = 5 def visit_chk(visitor): "Function decorator launching callback on Expression visit" - def wrapped(e, cb, test_visit=lambda x: True): - if (test_visit is not None) and (not test_visit(e)): - return e - e_new = visitor(e, cb, test_visit) - if e_new is None: + def wrapped(expr, callback, test_visit=lambda x: True): + if (test_visit is not None) and (not test_visit(expr)): + return expr + expr_new = visitor(expr, callback, test_visit) + if expr_new is None: return None - e_new2 = cb(e_new) - return e_new2 + expr_new2 = callback(expr_new) + return expr_new2 return wrapped @@ -123,7 +122,7 @@ class Expr(object): canon_exprs = set() use_singleton = True - def set_size(self, value): + def set_size(self, _): raise ValueError('size is not mutable') def __init__(self): @@ -134,21 +133,21 @@ class Expr(object): size = property(lambda self: self.__size) @staticmethod - def get_object(cls, args): - if not cls.use_singleton: - return object.__new__(cls, args) + def get_object(expr_cls, args): + if not expr_cls.use_singleton: + return object.__new__(expr_cls, args) - expr = Expr.args2expr.get((cls, args)) + expr = Expr.args2expr.get((expr_cls, args)) if expr is None: - expr = object.__new__(cls, args) - Expr.args2expr[(cls, args)] = expr + expr = object.__new__(expr_cls, args) + Expr.args2expr[(expr_cls, args)] = expr return expr def get_is_canon(self): return self in Expr.canon_exprs def set_is_canon(self, value): - assert(value is True) + assert value is True Expr.canon_exprs.add(self) is_canon = property(get_is_canon, set_is_canon) @@ -198,44 +197,44 @@ class Expr(object): return False return repr(self) == repr(other) - def __ne__(self, a): - return not self.__eq__(a) + def __ne__(self, other): + return not self.__eq__(other) - def __add__(self, a): - return ExprOp('+', self, a) + def __add__(self, other): + return ExprOp('+', self, other) - def __sub__(self, a): - return ExprOp('+', self, ExprOp('-', a)) + def __sub__(self, other): + return ExprOp('+', self, ExprOp('-', other)) - def __div__(self, a): - return ExprOp('/', self, a) + def __div__(self, other): + return ExprOp('/', self, other) - def __mod__(self, a): - return ExprOp('%', self, a) + def __mod__(self, other): + return ExprOp('%', self, other) - def __mul__(self, a): - return ExprOp('*', self, a) + def __mul__(self, other): + return ExprOp('*', self, other) - def __lshift__(self, a): - return ExprOp('<<', self, a) + def __lshift__(self, other): + return ExprOp('<<', self, other) - def __rshift__(self, a): - return ExprOp('>>', self, a) + def __rshift__(self, other): + return ExprOp('>>', self, other) - def __xor__(self, a): - return ExprOp('^', self, a) + def __xor__(self, other): + return ExprOp('^', self, other) - def __or__(self, a): - return ExprOp('|', self, a) + def __or__(self, other): + return ExprOp('|', self, other) - def __and__(self, a): - return ExprOp('&', self, a) + def __and__(self, other): + return ExprOp('&', self, other) def __neg__(self): return ExprOp('-', self) - def __pow__(self, a): - return ExprOp("**",self, a) + def __pow__(self, other): + return ExprOp("**", self, other) def __invert__(self): return ExprOp('^', self, self.mask) @@ -254,37 +253,37 @@ class Expr(object): if dct is None: dct = {} - def my_replace(e, dct): - if e in dct: - return dct[e] - return e + def my_replace(expr, dct): + if expr in dct: + return dct[expr] + return expr - return self.visit(lambda e: my_replace(e, dct)) + return self.visit(lambda expr: my_replace(expr, dct)) def canonize(self): "Canonize the Expression" - def must_canon(e): - return not e.is_canon + def must_canon(expr): + return not expr.is_canon - def canonize_visitor(e): - if e.is_canon: - return e - if isinstance(e, ExprOp): - if e.is_associative(): + def canonize_visitor(expr): + if expr.is_canon: + return expr + if isinstance(expr, ExprOp): + if expr.is_associative(): # ((a+b) + c) => (a + b + c) args = [] - for arg in e.args: - if isinstance(arg, ExprOp) and e.op == arg.op: + for arg in expr.args: + if isinstance(arg, ExprOp) and expr.op == arg.op: args += arg.args else: args.append(arg) args = canonize_expr_list(args) - new_e = ExprOp(e.op, *args) + new_e = ExprOp(expr.op, *args) else: - new_e = e + new_e = expr else: - new_e = e + new_e = expr new_e.is_canon = True return new_e @@ -292,33 +291,30 @@ class Expr(object): def msb(self): "Return the Most Significant Bit" - s = self.size - return self[s - 1:s] + return self[self.size - 1:self.size] def zeroExtend(self, size): """Zero extend to size @size: int """ - assert(self.size <= size) + assert self.size <= size if self.size == size: return self ad_size = size - self.size - n = ExprInt(0, ad_size) - return ExprCompose(self, n) + return ExprCompose(self, ExprInt(0, ad_size)) def signExtend(self, size): """Sign extend to size @size: int """ - assert(self.size <= size) + assert self.size <= size if self.size == size: return self ad_size = size - self.size - c = ExprCompose(self, - ExprCond(self.msb(), - ExprInt(size2mask(ad_size), ad_size), - ExprInt(0, ad_size))) - return c + return ExprCompose(self, + ExprCond(self.msb(), + ExprInt(size2mask(ad_size), ad_size), + ExprInt(0, ad_size))) def graph_recursive(self, graph): """Recursive method used by graph @@ -453,11 +449,11 @@ class ExprInt(Expr): return "%s(0x%X, %d)" % (self.__class__.__name__, self.__get_int(), self.__size) - def __contains__(self, e): - return self == e + def __contains__(self, expr): + return self == expr @visit_chk - def visit(self, cb, tv=None): + def visit(self, callback, test_visit=None): return self def copy(self): @@ -522,17 +518,16 @@ class ExprId(Expr): return set([self]) def _exprhash(self): - # TODO XXX: hash size ?? return hash((EXPRID, self.__name, self.__size)) def _exprrepr(self): return "%s(%r, %d)" % (self.__class__.__name__, self.__name, self.__size) - def __contains__(self, e): - return self == e + def __contains__(self, expr): + return self == expr @visit_chk - def visit(self, cb, tv=None): + def visit(self, callback, test_visit=None): return self def copy(self): @@ -571,7 +566,7 @@ class ExprAff(Expr): if dst.size != src.size: raise ValueError( "sanitycheck: ExprAff args must have same size! %s" % - ([(str(arg), arg.size) for arg in [dst, src]])) + ([(str(arg), arg.size) for arg in [dst, src]])) self.__size = self.dst.size @@ -627,8 +622,8 @@ class ExprAff(Expr): self.__dst.__contains__(expr)) @visit_chk - def visit(self, cb, tv=None): - dst, src = self.__dst.visit(cb, tv), self.__src.visit(cb, tv) + def visit(self, callback, test_visit=None): + dst, src = self.__dst.visit(callback, test_visit), self.__src.visit(callback, test_visit) if dst == self.__dst and src == self.__src: return self else: @@ -672,7 +667,7 @@ class ExprCond(Expr): super(ExprCond, self).__init__() self.__cond, self.__src1, self.__src2 = cond, src1, src2 - assert(src1.size == src2.size) + assert src1.size == src2.size self.__size = self.src1.size size = property(lambda self: self.__size) @@ -707,20 +702,18 @@ class ExprCond(Expr): return "%s(%r, %r, %r)" % (self.__class__.__name__, self.__cond, self.__src1, self.__src2) - def __contains__(self, e): - return (self == e or - self.cond.__contains__(e) or - self.src1.__contains__(e) or - self.src2.__contains__(e)) + def __contains__(self, expr): + return (self == expr or + self.cond.__contains__(expr) or + self.src1.__contains__(expr) or + self.src2.__contains__(expr)) @visit_chk - def visit(self, cb, tv=None): - cond = self.__cond.visit(cb, tv) - src1 = self.__src1.visit(cb, tv) - src2 = self.__src2.visit(cb, tv) - if (cond == self.__cond and - src1 == self.__src1 and - src2 == self.__src2): + def visit(self, callback, test_visit=None): + cond = self.__cond.visit(callback, test_visit) + src1 = self.__src1.visit(callback, test_visit) + src2 = self.__src2.visit(callback, test_visit) + if cond == self.__cond and src1 == self.__src1 and src2 == self.__src2: return self return ExprCond(cond, src1, src2) @@ -802,8 +795,8 @@ class ExprMem(Expr): return self == expr or self.__arg.__contains__(expr) @visit_chk - def visit(self, cb, tv=None): - arg = self.__arg.visit(cb, tv) + def visit(self, callback, test_visit=None): + arg = self.__arg.visit(callback, test_visit) if arg == self.__arg: return self return ExprMem(arg, self.size) @@ -855,7 +848,7 @@ class ExprOp(Expr): if op not in ["segm"]: raise ValueError( "sanitycheck: ExprOp args must have same size! %s" % - ([(str(arg), arg.size) for arg in args])) + ([(str(arg), arg.size) for arg in args])) if not isinstance(op, str): raise ValueError("ExprOp: 'op' argument must be a string") @@ -869,44 +862,44 @@ class ExprOp(Expr): 'fxam_c0', 'fxam_c1', 'fxam_c2', 'fxam_c3', "access_segment_ok", "load_segment_limit_ok", "bcdadd_cf", "ucomiss_zf", "ucomiss_pf", "ucomiss_cf"]: - sz = 1 + size = 1 elif self.__op in [TOK_INF, TOK_INF_SIGNED, TOK_INF_UNSIGNED, TOK_INF_EQUAL, TOK_INF_EQUAL_SIGNED, TOK_INF_EQUAL_UNSIGNED, TOK_EQUAL, TOK_POS, TOK_POS_STRICT, - ]: - sz = 1 + ]: + size = 1 elif self.__op in ['mem_16_to_double', 'mem_32_to_double', 'mem_64_to_double', 'mem_80_to_double', 'int_16_to_double', 'int_32_to_double', 'int_64_to_double', 'int_80_to_double']: - sz = 64 + size = 64 elif self.__op in ['double_to_mem_16', 'double_to_int_16', 'float_trunc_to_int_16', 'double_trunc_to_int_16']: - sz = 16 + size = 16 elif self.__op in ['double_to_mem_32', 'double_to_int_32', 'float_trunc_to_int_32', 'double_trunc_to_int_32', 'double_to_float']: - sz = 32 + size = 32 elif self.__op in ['double_to_mem_64', 'double_to_int_64', 'float_trunc_to_int_64', 'double_trunc_to_int_64', 'float_to_double']: - sz = 64 + size = 64 elif self.__op in ['double_to_mem_80', 'double_to_int_80', 'float_trunc_to_int_80', 'double_trunc_to_int_80']: - sz = 80 + size = 80 elif self.__op in ['segm']: - sz = self.__args[1].size + size = self.__args[1].size else: if None in sizes: - sz = None + size = None else: # All arguments have the same size - sz = list(sizes)[0] + size = list(sizes)[0] - self.__size = sz + self.__size = size size = property(lambda self: self.__size) op = property(lambda self: self.__op) @@ -950,11 +943,11 @@ class ExprOp(Expr): return "%s(%r, %s)" % (self.__class__.__name__, self.__op, ', '.join(repr(arg) for arg in self.__args)) - def __contains__(self, e): - if self == e: + def __contains__(self, expr): + if self == expr: return True for arg in self.__args: - if arg.__contains__(e): + if arg.__contains__(expr): return True return False @@ -970,8 +963,8 @@ class ExprOp(Expr): return (self.__op in ['+', '*', '^', '&', '|']) @visit_chk - def visit(self, cb, tv=None): - args = [arg.visit(cb, tv) for arg in self.__args] + def visit(self, callback, test_visit=None): + args = [arg.visit(callback, test_visit) for arg in self.__args] modified = any([arg[0] != arg[1] for arg in zip(self.__args, args)]) if modified: return ExprOp(self.__op, *args) @@ -1007,7 +1000,7 @@ class ExprSlice(Expr): def __init__(self, arg, start, stop): super(ExprSlice, self).__init__() - assert(start < stop) + assert start < stop self.__arg, self.__start, self.__stop = arg, start, stop self.__size = self.__stop - self.__start @@ -1045,8 +1038,8 @@ class ExprSlice(Expr): return self.__arg.__contains__(expr) @visit_chk - def visit(self, cb, tv=None): - arg = self.__arg.visit(cb, tv) + def visit(self, callback, test_visit=None): + arg = self.__arg.visit(callback, test_visit) if arg == self.__arg: return self return ExprSlice(arg, self.__start, self.__stop) @@ -1137,19 +1130,19 @@ class ExprCompose(Expr): def _exprrepr(self): return "%s%r" % (self.__class__.__name__, self.__args) - def __contains__(self, e): - if self == e: + def __contains__(self, expr): + if self == expr: return True for arg in self.__args: - if arg == e: + if arg == expr: return True - if arg.__contains__(e): + if arg.__contains__(expr): return True return False @visit_chk - def visit(self, cb, tv=None): - args = [arg.visit(cb, tv) for arg in self.__args] + def visit(self, callback, test_visit=None): + args = [arg.visit(callback, test_visit) for arg in self.__args] modified = any([arg != arg_new for arg, arg_new in zip(self.__args, args)]) if modified: return ExprCompose(*args) @@ -1179,116 +1172,116 @@ class ExprCompose(Expr): return True # Expression order for comparaison -expr_order_dict = {ExprId: 1, +EXPR_ORDER_DICT = {ExprId: 1, ExprCond: 2, ExprMem: 3, ExprOp: 4, ExprSlice: 5, ExprCompose: 7, ExprInt: 8, - } + } -def compare_exprs_compose(e1, e2): +def compare_exprs_compose(expr1, expr2): # Sort by start bit address, then expr, then stop but address - x = cmp(e1[1], e2[1]) - if x: - return x - x = compare_exprs(e1[0], e2[0]) - if x: - return x - x = cmp(e1[2], e2[2]) - return x + ret = cmp(expr1[1], expr2[1]) + if ret: + return ret + ret = compare_exprs(expr1[0], expr2[0]) + if ret: + return ret + ret = cmp(expr1[2], expr2[2]) + return ret def compare_expr_list_compose(l1_e, l2_e): # Sort by list elements in incremental order, then by list size for i in xrange(min(len(l1_e), len(l2_e))): - x = compare_exprs(l1_e[i], l2_e[i]) - if x: - return x + ret = compare_exprs(l1_e[i], l2_e[i]) + if ret: + return ret return cmp(len(l1_e), len(l2_e)) def compare_expr_list(l1_e, l2_e): # Sort by list elements in incremental order, then by list size for i in xrange(min(len(l1_e), len(l2_e))): - x = compare_exprs(l1_e[i], l2_e[i]) - if x: - return x + ret = compare_exprs(l1_e[i], l2_e[i]) + if ret: + return ret return cmp(len(l1_e), len(l2_e)) -def compare_exprs(e1, e2): +def compare_exprs(expr1, expr2): """Compare 2 expressions for canonization - @e1: Expr - @e2: Expr + @expr1: Expr + @expr2: Expr 0 => == - 1 => e1 > e2 - -1 => e1 < e2 + 1 => expr1 > expr2 + -1 => expr1 < expr2 """ - c1 = e1.__class__ - c2 = e2.__class__ - if c1 != c2: - return cmp(expr_order_dict[c1], expr_order_dict[c2]) - if e1 == e2: + cls1 = expr1.__class__ + cls2 = expr2.__class__ + if cls1 != cls2: + return cmp(EXPR_ORDER_DICT[cls1], EXPR_ORDER_DICT[cls2]) + if expr1 == expr2: return 0 - if c1 == ExprInt: - ret = cmp(e1.size, e2.size) + if cls1 == ExprInt: + ret = cmp(expr1.size, expr2.size) if ret != 0: return ret - return cmp(e1.arg, e2.arg) - elif c1 == ExprId: - x = cmp(e1.name, e2.name) - if x: - return x - return cmp(e1.size, e2.size) - elif c1 == ExprAff: + return cmp(expr1.arg, expr2.arg) + elif cls1 == ExprId: + ret = cmp(expr1.name, expr2.name) + if ret: + return ret + return cmp(expr1.size, expr2.size) + elif cls1 == ExprAff: raise NotImplementedError( "Comparaison from an ExprAff not yet implemented") - elif c2 == ExprCond: - x = compare_exprs(e1.cond, e2.cond) - if x: - return x - x = compare_exprs(e1.src1, e2.src1) - if x: - return x - x = compare_exprs(e1.src2, e2.src2) - return x - elif c1 == ExprMem: - x = compare_exprs(e1.arg, e2.arg) - if x: - return x - return cmp(e1.size, e2.size) - elif c1 == ExprOp: - if e1.op != e2.op: - return cmp(e1.op, e2.op) - return compare_expr_list(e1.args, e2.args) - elif c1 == ExprSlice: - x = compare_exprs(e1.arg, e2.arg) - if x: - return x - x = cmp(e1.start, e2.start) - if x: - return x - x = cmp(e1.stop, e2.stop) - return x - elif c1 == ExprCompose: - return compare_expr_list_compose(e1.args, e2.args) + elif cls2 == ExprCond: + ret = compare_exprs(expr1.cond, expr2.cond) + if ret: + return ret + ret = compare_exprs(expr1.src1, expr2.src1) + if ret: + return ret + ret = compare_exprs(expr1.src2, expr2.src2) + return ret + elif cls1 == ExprMem: + ret = compare_exprs(expr1.arg, expr2.arg) + if ret: + return ret + return cmp(expr1.size, expr2.size) + elif cls1 == ExprOp: + if expr1.op != expr2.op: + return cmp(expr1.op, expr2.op) + return compare_expr_list(expr1.args, expr2.args) + elif cls1 == ExprSlice: + ret = compare_exprs(expr1.arg, expr2.arg) + if ret: + return ret + ret = cmp(expr1.start, expr2.start) + if ret: + return ret + ret = cmp(expr1.stop, expr2.stop) + return ret + elif cls1 == ExprCompose: + return compare_expr_list_compose(expr1.args, expr2.args) raise NotImplementedError( - "Comparaison between %r %r not implemented" % (e1, e2)) + "Comparaison between %r %r not implemented" % (expr1, expr2)) -def canonize_expr_list(l): - l = list(l) - l.sort(cmp=compare_exprs) - return l +def canonize_expr_list(expr_list): + expr_list = list(expr_list) + expr_list.sort(cmp=compare_exprs) + return expr_list -def canonize_expr_list_compose(l): - l = list(l) - l.sort(cmp=compare_exprs_compose) - return l +def canonize_expr_list_compose(expr_list): + expr_list = list(expr_list) + expr_list.sort(cmp=compare_exprs_compose) + return expr_list # Generate ExprInt with common size @@ -1323,47 +1316,51 @@ def ExprInt64(i): return ExprInt(i, 64) -def ExprInt_from(e, i): +def ExprInt_from(expr, i): "Generate ExprInt with size equal to expression" warnings.warn('DEPRECATION WARNING: use ExprInt(i, expr.size) instead of'\ 'ExprInt_from(expr, i))') - return ExprInt(i, e.size) + return ExprInt(i, expr.size) -def get_expr_ids_visit(e, ids): - if isinstance(e, ExprId): - ids.add(e) - return e +def get_expr_ids_visit(expr, ids): + """Visitor to retrieve ExprId in @expr + @expr: Expr""" + if isinstance(expr, ExprId): + ids.add(expr) + return expr -def get_expr_ids(e): +def get_expr_ids(expr): + """Retrieve ExprId in @expr + @expr: Expr""" ids = set() - e.visit(lambda x: get_expr_ids_visit(x, ids)) + expr.visit(lambda x: get_expr_ids_visit(x, ids)) return ids -def test_set(e, v, tks, result): +def test_set(expr, pattern, tks, result): """Test if v can correspond to e. If so, update the context in result. Otherwise, return False - @e : Expr - @v : Expr + @expr : Expr to match + @pattern : pattern Expr @tks : list of ExprId, available jokers @result : dictionary of ExprId -> Expr, current context """ - if not v in tks: - return e == v - if v in result and result[v] != e: + if not pattern in tks: + return expr == pattern + if pattern in result and result[pattern] != expr: return False - result[v] = e + result[pattern] = expr return result -def MatchExpr(pattern, expr, tks, result=None): - """Try to match the @expr expression with the pattern @pattern with @tks jokers. +def match_expr(expr, pattern, tks, result=None): + """Try to match the @pattern expression with the pattern @expr with @tks jokers. Result is output dictionary with matching joker values. - @pattern : Expr pattern - @expr : Targetted Expr to match + @expr : Expr pattern + @pattern : Targetted Expr to match @tks : list of ExprId, available jokers @result : dictionary of ExprId -> Expr, output matching context """ @@ -1371,41 +1368,41 @@ def MatchExpr(pattern, expr, tks, result=None): if result is None: result = {} - if expr in tks: - # expr is a Joker - return test_set(pattern, expr, tks, result) + if pattern in tks: + # pattern is a Joker + return test_set(expr, pattern, tks, result) - if pattern.is_int(): - return test_set(pattern, expr, tks, result) + if expr.is_int(): + return test_set(expr, pattern, tks, result) - elif pattern.is_id(): - return test_set(pattern, expr, tks, result) + elif expr.is_id(): + return test_set(expr, pattern, tks, result) - elif pattern.is_op(): + elif expr.is_op(): - # e need to be the same operation than expr - if not expr.is_op(): + # expr need to be the same operation than pattern + if not pattern.is_op(): return False - if pattern.op != expr.op: + if expr.op != pattern.op: return False - if len(pattern.args) != len(expr.args): + if len(expr.args) != len(pattern.args): return False # Perform permutation only if the current operation is commutative - if pattern.is_commutative(): - permutations = itertools.permutations(pattern.args) + if expr.is_commutative(): + permutations = itertools.permutations(expr.args) else: - permutations = [pattern.args] + permutations = [expr.args] # For each permutations of arguments for permut in permutations: good = True # We need to use a copy of result to not override it myresult = dict(result) - for sub_pattern, sub_expr in zip(permut, expr.args): - r = MatchExpr(sub_pattern, sub_expr, tks, myresult) + for sub_expr, sub_pattern in zip(permut, pattern.args): + ret = MatchExpr(sub_expr, sub_pattern, tks, myresult) # If the current permutation do not match EVERY terms - if r is False: + if ret is False: good = False break if good is True: @@ -1418,105 +1415,116 @@ def MatchExpr(pattern, expr, tks, result=None): # Recursive tests - elif pattern.is_mem(): - if not expr.is_mem(): + elif expr.is_mem(): + if not pattern.is_mem(): return False - if pattern.size != expr.size: + if expr.size != pattern.size: return False - return MatchExpr(pattern.arg, expr.arg, tks, result) + return MatchExpr(expr.arg, pattern.arg, tks, result) - elif pattern.is_slice(): - if not expr.is_slice(): + elif expr.is_slice(): + if not pattern.is_slice(): return False - if pattern.start != expr.start or pattern.stop != expr.stop: + if expr.start != pattern.start or expr.stop != pattern.stop: return False - return MatchExpr(pattern.arg, expr.arg, tks, result) + return MatchExpr(expr.arg, pattern.arg, tks, result) - elif pattern.is_cond(): - if not expr.is_cond(): + elif expr.is_cond(): + if not pattern.is_cond(): return False - if MatchExpr(pattern.cond, expr.cond, tks, result) is False: + if MatchExpr(expr.cond, pattern.cond, tks, result) is False: return False - if MatchExpr(pattern.src1, expr.src1, tks, result) is False: + if MatchExpr(expr.src1, pattern.src1, tks, result) is False: return False - if MatchExpr(pattern.src2, expr.src2, tks, result) is False: + if MatchExpr(expr.src2, pattern.src2, tks, result) is False: return False return result - elif pattern.is_compose(): - if not expr.is_compose(): + elif expr.is_compose(): + if not pattern.is_compose(): return False - for sub_pattern, sub_expr in zip(pattern.args, expr.args): - if MatchExpr(sub_pattern, sub_expr, tks, result) is False: + for sub_expr, sub_pattern in zip(expr.args, pattern.args): + if MatchExpr(sub_expr, sub_pattern, tks, result) is False: return False return result - elif pattern.is_aff(): - if not expr.is_aff(): + elif expr.is_aff(): + if not pattern.is_aff(): return False - if MatchExpr(pattern.src, expr.src, tks, result) is False: + if MatchExpr(expr.src, pattern.src, tks, result) is False: return False - if MatchExpr(pattern.dst, expr.dst, tks, result) is False: + if MatchExpr(expr.dst, pattern.dst, tks, result) is False: return False return result else: - raise NotImplementedError("MatchExpr: Unknown type: %s" % type(pattern)) + raise NotImplementedError("MatchExpr: Unknown type: %s" % type(expr)) + + +def MatchExpr(expr, pattern, tks, result=None): + warnings.warn('DEPRECATION WARNING: use match_expr instead of MatchExpr') + return match_expr(expr, pattern, tks, result) def get_rw(exprs): o_r = set() o_w = set() - for e in exprs: - o_r.update(e.get_r(mem_read=True)) - for e in exprs: - o_w.update(e.get_w()) + for expr in exprs: + o_r.update(expr.get_r(mem_read=True)) + for expr in exprs: + o_w.update(expr.get_w()) return o_r, o_w def get_list_rw(exprs, mem_read=False, cst_read=True): - """ - return list of read/write reg/cst/mem for each expressions + """Return list of read/write reg/cst/mem for each @exprs + @exprs: list of expressions + @mem_read: walk though memory accesses + @cst_read: retrieve constants """ list_rw = [] # cst_num = 0 - for e in exprs: + for expr in exprs: o_r = set() o_w = set() # get r/w - o_r.update(e.get_r(mem_read=mem_read, cst_read=cst_read)) - if isinstance(e.dst, ExprMem): - o_r.update(e.dst.arg.get_r(mem_read=mem_read, cst_read=cst_read)) - o_w.update(e.get_w()) + o_r.update(expr.get_r(mem_read=mem_read, cst_read=cst_read)) + if isinstance(expr.dst, ExprMem): + o_r.update(expr.dst.arg.get_r(mem_read=mem_read, cst_read=cst_read)) + o_w.update(expr.get_w()) # each cst is indexed o_r_rw = set() - for r in o_r: - o_r_rw.add(r) + for read in o_r: + o_r_rw.add(read) o_r = o_r_rw list_rw.append((o_r, o_w)) return list_rw -def get_expr_ops(e): - def visit_getops(e, out=None): +def get_expr_ops(expr): + """Retrieve operators of an @expr + @expr: Expr""" + def visit_getops(expr, out=None): if out is None: out = set() - if isinstance(e, ExprOp): - out.add(e.op) - return e + if isinstance(expr, ExprOp): + out.add(expr.op) + return expr ops = set() - e.visit(lambda x: visit_getops(x, ops)) + expr.visit(lambda x: visit_getops(x, ops)) return ops -def get_expr_mem(e): - def visit_getmem(e, out=None): +def get_expr_mem(expr): + """Retrieve memory accesses of an @expr + @expr: Expr""" + def visit_getmem(expr, out=None): if out is None: out = set() - if isinstance(e, ExprMem): - out.add(e) - return e + if isinstance(expr, ExprMem): + out.add(expr) + return expr ops = set() - e.visit(lambda x: visit_getmem(x, ops)) + expr.visit(lambda x: visit_getmem(x, ops)) return ops |