diff options
Diffstat (limited to '')
| -rw-r--r-- | miasm2/expression/expression.py | 623 | ||||
| -rw-r--r-- | miasm2/expression/expression_helper.py | 196 | ||||
| -rw-r--r-- | miasm2/expression/simplifications.py | 13 | ||||
| -rw-r--r-- | miasm2/expression/simplifications_common.py | 139 |
4 files changed, 494 insertions, 477 deletions
diff --git a/miasm2/expression/expression.py b/miasm2/expression/expression.py index d04530c3..324d5fea 100644 --- a/miasm2/expression/expression.py +++ b/miasm2/expression/expression.py @@ -114,28 +114,52 @@ class Expr(object): "Parent class for Miasm Expressions" - __slots__ = ["is_term", "is_simp", "is_canon", - "is_eval", "_hash", "_repr", "_size", - "is_var_ident"] + __slots__ = ["__hash", "__repr", "__size"] + all_exprs = set() + args2expr = {} + canon_exprs = set() + use_singleton = True def set_size(self, value): raise ValueError('size is not mutable') def __init__(self): - self.is_term = False # Terminal expression - self.is_simp = False # Expression already simplified - self.is_canon = False # Expression already canonised - self.is_eval = False # Expression already evalued - self.is_var_ident = False # Expression not identifier + self.__hash = None + self.__repr = None + self.__size = None - self._hash = None - self._repr = None + size = property(lambda self: self.__size) - size = property(lambda self: self._size) + @staticmethod + def get_object(cls, args): + if not cls.use_singleton: + return object.__new__(cls, args) + + expr = Expr.args2expr.get((cls, args)) + if expr is None: + expr = object.__new__(cls, args) + Expr.args2expr[(cls, args)] = expr + return expr + + def __new__(cls, *args, **kwargs): + expr = object.__new__(cls, *args, **kwargs) + return expr + + def get_is_canon(self): + return self in Expr.canon_exprs + + def set_is_canon(self, value): + assert(value is True) + Expr.canon_exprs.add(self) + + is_canon = property(get_is_canon, set_is_canon) # Common operations + def __str__(self): + raise NotImplementedError("Abstract Method") + def __getitem__(self, i): if not isinstance(i, slice): raise TypeError("Expression: Bad slice: %s" % i) @@ -153,14 +177,14 @@ class Expr(object): return False def __repr__(self): - if self._repr is None: - self._repr = self._exprrepr() - return self._repr + if self.__repr is None: + self.__repr = self._exprrepr() + return self.__repr def __hash__(self): - if self._hash is None: - self._hash = self._exprhash() - return self._hash + if self.__hash is None: + self.__hash = self._exprhash() + return self.__hash def pre_eq(self, other): """Return True if ids are equal; @@ -264,8 +288,6 @@ class Expr(object): new_e = ExprOp(e.op, *args) else: new_e = e - elif isinstance(e, ExprCompose): - new_e = ExprCompose(canonize_expr_list_compose(e.args)) else: new_e = e new_e.is_canon = True @@ -287,8 +309,7 @@ class Expr(object): return self ad_size = size - self.size n = ExprInt(0, ad_size) - return ExprCompose([(self, 0, self.size), - (n, self.size, size)]) + return ExprCompose(self, n) def signExtend(self, size): """Sign extend to size @@ -298,12 +319,10 @@ class Expr(object): if self.size == size: return self ad_size = size - self.size - c = ExprCompose([(self, 0, self.size), - (ExprCond(self.msb(), - ExprInt(size2mask(ad_size), ad_size), - ExprInt(0, ad_size)), - self.size, size) - ]) + c = ExprCompose(self, + ExprCond(self.msb(), + ExprInt(size2mask(ad_size), ad_size), + ExprInt(0, ad_size))) return c def graph_recursive(self, graph): @@ -341,7 +360,8 @@ class ExprInt(Expr): - Constant 0x12345678 on 32bits """ - __slots__ = ["_arg"] + __slots__ = Expr.__slots__ + ["__arg"] + def __init__(self, num, size=None): """Create an ExprInt from a modint or num/size @@ -351,31 +371,36 @@ class ExprInt(Expr): super(ExprInt, self).__init__() if is_modint(num): - self._arg = num - self._size = self.arg.size + self.__arg = num + self.__size = self.arg.size if size is not None and num.size != size: raise RuntimeError("size must match modint size") elif size is not None: - self._arg = mod_size2uint[size](num) - self._size = self.arg.size + self.__arg = mod_size2uint[size](num) + self.__size = self.arg.size else: raise ValueError('arg must by modint or (int,size)! %s' % num) - arg = property(lambda self: self._arg) + size = property(lambda self: self.__size) + arg = property(lambda self: self.__arg) - def __eq__(self, other): - res = self.pre_eq(other) - if res is not None: - return res - return (self._arg == other._arg and - self._size == other._size) + def __getstate__(self): + return int(self.__arg), self.__size + + def __setstate__(self, state): + self.__init__(*state) + + def __new__(cls, arg, size=None): + if size is None: + size = arg.size + return Expr.get_object(cls, (arg, size)) def __get_int(self): "Return self integer representation" - return int(self._arg & size2mask(self._size)) + return int(self.__arg & size2mask(self.__size)) def __str__(self): - if self._arg < 0: + if self.__arg < 0: return str("-0x%X" % (- self.__get_int())) else: return str("0x%X" % self.__get_int()) @@ -390,10 +415,10 @@ class ExprInt(Expr): return set() def _exprhash(self): - return hash((EXPRINT, self._arg, self._size)) + return hash((EXPRINT, self.__arg, self.__size)) def _exprrepr(self): - return "%s(%r)" % (self.__class__.__name__, self._arg) + return "%s(0x%X)" % (self.__class__.__name__, self.__get_int()) def __contains__(self, e): return self == e @@ -403,7 +428,7 @@ class ExprInt(Expr): return self def copy(self): - return ExprInt(self._arg) + return ExprInt(self.__arg) def depth(self): return 1 @@ -428,7 +453,7 @@ class ExprId(Expr): - variable v1 """ - __slots__ = ["_name"] + __slots__ = Expr.__slots__ + ["__name"] def __init__(self, name, size=32): """Create an identifier @@ -437,19 +462,22 @@ class ExprId(Expr): """ super(ExprId, self).__init__() - self._name, self._size = name, size + self.__name, self.__size = name, size - name = property(lambda self: self._name) + size = property(lambda self: self.__size) + name = property(lambda self: self.__name) - def __eq__(self, other): - res = self.pre_eq(other) - if res is not None: - return res - return (self._name == other._name and - self._size == other._size) + def __getstate__(self): + return self.__name, self.__size + + def __setstate__(self, state): + self.__init__(*state) + + def __new__(cls, name, size=32): + return Expr.get_object(cls, (name, size)) def __str__(self): - return str(self._name) + return str(self.__name) def get_r(self, mem_read=False, cst_read=False): return set([self]) @@ -459,10 +487,10 @@ class ExprId(Expr): def _exprhash(self): # TODO XXX: hash size ?? - return hash((EXPRID, self._name, self._size)) + return hash((EXPRID, self.__name, self.__size)) def _exprrepr(self): - return "%s(%r, %d)" % (self.__class__.__name__, self._name, self._size) + return "%s(%r, %d)" % (self.__class__.__name__, self.__name, self.__size) def __contains__(self, e): return self == e @@ -472,7 +500,7 @@ class ExprId(Expr): return self def copy(self): - return ExprId(self._name, self._size) + return ExprId(self.__name, self.__size) def depth(self): return 1 @@ -489,7 +517,7 @@ class ExprAff(Expr): - var1 <- 2 """ - __slots__ = ["_src", "_dst"] + __slots__ = Expr.__slots__ + ["__dst", "__src"] def __init__(self, dst, src): """Create an ExprAff for dst <- src @@ -506,80 +534,75 @@ class ExprAff(Expr): if isinstance(dst, ExprSlice): # Complete the source with missing slice parts - self._dst = dst.arg + self.__dst = dst.arg rest = [(ExprSlice(dst.arg, r[0], r[1]), r[0], r[1]) for r in dst.slice_rest()] all_a = [(src, dst.start, dst.stop)] + rest all_a.sort(key=lambda x: x[1]) - self._src = ExprCompose(all_a) + args = [expr for (expr, _, _) in all_a] + self.__src = ExprCompose(*args) else: - self._dst, self._src = dst, src + self.__dst, self.__src = dst, src + + self.__size = self.dst.size - self._size = self.dst.size + size = property(lambda self: self.__size) + dst = property(lambda self: self.__dst) + src = property(lambda self: self.__src) - dst = property(lambda self: self._dst) - src = property(lambda self: self._src) + def __getstate__(self): + return self.__dst, self.__src + + def __setstate__(self, state): + self.__init__(*state) + + def __new__(cls, dst, src): + return Expr.get_object(cls, (dst, src)) def __str__(self): - return "%s = %s" % (str(self._dst), str(self._src)) + return "%s = %s" % (str(self.__dst), str(self.__src)) def get_r(self, mem_read=False, cst_read=False): - elements = self._src.get_r(mem_read, cst_read) - if isinstance(self._dst, ExprMem) and mem_read: - elements.update(self._dst.arg.get_r(mem_read, cst_read)) + elements = self.__src.get_r(mem_read, cst_read) + if isinstance(self.__dst, ExprMem) and mem_read: + elements.update(self.__dst.arg.get_r(mem_read, cst_read)) return elements def get_w(self): - if isinstance(self._dst, ExprMem): - return set([self._dst]) # [memreg] + if isinstance(self.__dst, ExprMem): + return set([self.__dst]) # [memreg] else: - return self._dst.get_w() + return self.__dst.get_w() def _exprhash(self): - return hash((EXPRAFF, hash(self._dst), hash(self._src))) + return hash((EXPRAFF, hash(self.__dst), hash(self.__src))) def _exprrepr(self): - return "%s(%r, %r)" % (self.__class__.__name__, self._dst, self._src) + return "%s(%r, %r)" % (self.__class__.__name__, self.__dst, self.__src) - def __contains__(self, e): - return self == e or self._src.__contains__(e) or self._dst.__contains__(e) - - # XXX /!\ for hackish expraff to slice - def get_modified_slice(self): - """Return an Expr list of extra expressions needed during the - object instanciation""" - - dst = self._dst - if not isinstance(self._src, ExprCompose): - raise ValueError("Get mod slice not on expraff slice", str(self)) - modified_s = [] - for arg in self._src.args: - if (not isinstance(arg[0], ExprSlice) or - arg[0].arg != dst or - arg[1] != arg[0].start or - arg[2] != arg[0].stop): - # If x is not the initial expression - modified_s.append(arg) - return modified_s + def __contains__(self, expr): + return (self == expr or + self.__src.__contains__(expr) or + self.__dst.__contains__(expr)) @visit_chk def visit(self, cb, tv=None): - dst, src = self._dst.visit(cb, tv), self._src.visit(cb, tv) - if dst == self._dst and src == self._src: + dst, src = self.__dst.visit(cb, tv), self.__src.visit(cb, tv) + if dst == self.__dst and src == self.__src: return self else: return ExprAff(dst, src) def copy(self): - return ExprAff(self._dst.copy(), self._src.copy()) + return ExprAff(self.__dst.copy(), self.__src.copy()) def depth(self): - return max(self._src.depth(), self._dst.depth()) + 1 + return max(self.__src.depth(), self.__dst.depth()) + 1 def graph_recursive(self, graph): graph.add_node(self) - for arg in [self._src, self._dst]: + for arg in [self.__src, self.__dst]: arg.graph_recursive(graph) graph.add_uniq_edge(self, arg) @@ -594,7 +617,7 @@ class ExprCond(Expr): - if (cond) then ... else ... """ - __slots__ = ["_cond", "_src1", "_src2"] + __slots__ = Expr.__slots__ + ["__cond", "__src1", "__src2"] def __init__(self, cond, src1, src2): """Create an ExprCond @@ -605,65 +628,74 @@ class ExprCond(Expr): super(ExprCond, self).__init__() + self.__cond, self.__src1, self.__src2 = cond, src1, src2 assert(src1.size == src2.size) + self.__size = self.src1.size - self._cond, self._src1, self._src2 = cond, src1, src2 - self._size = self.src1.size + size = property(lambda self: self.__size) + cond = property(lambda self: self.__cond) + src1 = property(lambda self: self.__src1) + src2 = property(lambda self: self.__src2) - cond = property(lambda self: self._cond) - src1 = property(lambda self: self._src1) - src2 = property(lambda self: self._src2) + def __getstate__(self): + return self.__cond, self.__src1, self.__src2 + + def __setstate__(self, state): + self.__init__(*state) + + def __new__(cls, cond, src1, src2): + return Expr.get_object(cls, (cond, src1, src2)) def __str__(self): - return "(%s?(%s,%s))" % (str(self._cond), str(self._src1), str(self._src2)) + return "(%s?(%s,%s))" % (str(self.__cond), str(self.__src1), str(self.__src2)) def get_r(self, mem_read=False, cst_read=False): - out_src1 = self._src1.get_r(mem_read, cst_read) - out_src2 = self._src2.get_r(mem_read, cst_read) - return self._cond.get_r(mem_read, - cst_read).union(out_src1).union(out_src2) + out_src1 = self.src1.get_r(mem_read, cst_read) + out_src2 = self.src2.get_r(mem_read, cst_read) + return self.cond.get_r(mem_read, + cst_read).union(out_src1).union(out_src2) def get_w(self): return set() def _exprhash(self): return hash((EXPRCOND, hash(self.cond), - hash(self._src1), hash(self._src2))) + hash(self.__src1), hash(self.__src2))) def _exprrepr(self): return "%s(%r, %r, %r)" % (self.__class__.__name__, - self._cond, self._src1, self._src2) + self.__cond, self.__src1, self.__src2) def __contains__(self, e): return (self == e or - self._cond.__contains__(e) or - self._src1.__contains__(e) or - self._src2.__contains__(e)) + self.cond.__contains__(e) or + self.src1.__contains__(e) or + self.src2.__contains__(e)) @visit_chk def visit(self, cb, tv=None): - cond = self._cond.visit(cb, tv) - src1 = self._src1.visit(cb, tv) - src2 = self._src2.visit(cb, tv) - if (cond == self._cond and - src1 == self._src1 and - src2 == self._src2): + cond = self.__cond.visit(cb, tv) + src1 = self.__src1.visit(cb, tv) + src2 = self.__src2.visit(cb, tv) + if (cond == self.__cond and + src1 == self.__src1 and + src2 == self.__src2): return self return ExprCond(cond, src1, src2) def copy(self): - return ExprCond(self._cond.copy(), - self._src1.copy(), - self._src2.copy()) + return ExprCond(self.__cond.copy(), + self.__src1.copy(), + self.__src2.copy()) def depth(self): - return max(self._cond.depth(), - self._src1.depth(), - self._src2.depth()) + 1 + return max(self.__cond.depth(), + self.__src1.depth(), + self.__src2.depth()) + 1 def graph_recursive(self, graph): graph.add_node(self) - for arg in [self._cond, self._src1, self._src2]: + for arg in [self.__cond, self.__src1, self.__src2]: arg.graph_recursive(graph) graph.add_uniq_edge(self, arg) @@ -677,7 +709,7 @@ class ExprMem(Expr): - Memory write """ - __slots__ = ["_arg", "_size"] + __slots__ = Expr.__slots__ + ["__arg"] def __init__(self, arg, size=32): """Create an ExprMem @@ -691,16 +723,26 @@ class ExprMem(Expr): raise ValueError( 'ExprMem: arg must be an Expr (not %s)' % type(arg)) - self._arg, self._size = arg, size + self.__arg, self.__size = arg, size + + size = property(lambda self: self.__size) + arg = property(lambda self: self.__arg) - arg = property(lambda self: self._arg) + def __getstate__(self): + return self.__arg, self.__size + + def __setstate__(self, state): + self.__init__(*state) + + def __new__(cls, arg, size=32): + return Expr.get_object(cls, (arg, size)) def __str__(self): - return "@%d[%s]" % (self._size, str(self._arg)) + return "@%d[%s]" % (self.size, str(self.arg)) def get_r(self, mem_read=False, cst_read=False): if mem_read: - return set(self._arg.get_r(mem_read, cst_read).union(set([self]))) + return set(self.__arg.get_r(mem_read, cst_read).union(set([self]))) else: return set([self]) @@ -708,36 +750,36 @@ class ExprMem(Expr): return set([self]) # [memreg] def _exprhash(self): - return hash((EXPRMEM, hash(self._arg), self._size)) + return hash((EXPRMEM, hash(self.__arg), self.__size)) def _exprrepr(self): return "%s(%r, %r)" % (self.__class__.__name__, - self._arg, self._size) + self.__arg, self.__size) - def __contains__(self, e): - return self == e or self._arg.__contains__(e) + def __contains__(self, expr): + return self == expr or self.__arg.__contains__(expr) @visit_chk def visit(self, cb, tv=None): - arg = self._arg.visit(cb, tv) - if arg == self._arg: + arg = self.__arg.visit(cb, tv) + if arg == self.__arg: return self - return ExprMem(arg, self._size) + return ExprMem(arg, self.size) def copy(self): - arg = self._arg.copy() - return ExprMem(arg, size=self._size) + arg = self.arg.copy() + return ExprMem(arg, size=self.size) def is_op_segm(self): - return isinstance(self._arg, ExprOp) and self._arg.op == 'segm' + return isinstance(self.__arg, ExprOp) and self.__arg.op == 'segm' def depth(self): - return self._arg.depth() + 1 + return self.__arg.depth() + 1 def graph_recursive(self, graph): graph.add_node(self) - self._arg.graph_recursive(graph) - graph.add_uniq_edge(self, self._arg) + self.__arg.graph_recursive(graph) + graph.add_uniq_edge(self, self.__arg) class ExprOp(Expr): @@ -750,7 +792,7 @@ class ExprOp(Expr): - parity bit(var1) """ - __slots__ = ["_op", "_args"] + __slots__ = Expr.__slots__ + ["__op", "__args"] def __init__(self, op, *args): """Create an ExprOp @@ -772,44 +814,44 @@ class ExprOp(Expr): if not isinstance(op, str): raise ValueError("ExprOp: 'op' argument must be a string") - self._op, self._args = op, tuple(args) + self.__op, self.__args = op, tuple(args) # Set size for special cases - if self._op in [ + if self.__op in [ '==', 'parity', 'fcom_c0', 'fcom_c1', 'fcom_c2', 'fcom_c3', 'fxam_c0', 'fxam_c1', 'fxam_c2', 'fxam_c3', "access_segment_ok", "load_segment_limit_ok", "bcdadd_cf", "ucomiss_zf", "ucomiss_pf", "ucomiss_cf"]: sz = 1 - elif self._op in [TOK_INF, TOK_INF_SIGNED, - TOK_INF_UNSIGNED, TOK_INF_EQUAL, - TOK_INF_EQUAL_SIGNED, TOK_INF_EQUAL_UNSIGNED, - TOK_EQUAL, TOK_POS, - TOK_POS_STRICT, - ]: + elif self.__op in [TOK_INF, TOK_INF_SIGNED, + TOK_INF_UNSIGNED, TOK_INF_EQUAL, + TOK_INF_EQUAL_SIGNED, TOK_INF_EQUAL_UNSIGNED, + TOK_EQUAL, TOK_POS, + TOK_POS_STRICT, + ]: sz = 1 - elif self._op in ['mem_16_to_double', 'mem_32_to_double', - 'mem_64_to_double', 'mem_80_to_double', - 'int_16_to_double', 'int_32_to_double', - 'int_64_to_double', 'int_80_to_double']: + elif self.__op in ['mem_16_to_double', 'mem_32_to_double', + 'mem_64_to_double', 'mem_80_to_double', + 'int_16_to_double', 'int_32_to_double', + 'int_64_to_double', 'int_80_to_double']: sz = 64 - elif self._op in ['double_to_mem_16', 'double_to_int_16', - 'float_trunc_to_int_16', 'double_trunc_to_int_16']: + elif self.__op in ['double_to_mem_16', 'double_to_int_16', + 'float_trunc_to_int_16', 'double_trunc_to_int_16']: sz = 16 - elif self._op in ['double_to_mem_32', 'double_to_int_32', - 'float_trunc_to_int_32', 'double_trunc_to_int_32', - 'double_to_float']: + elif self.__op in ['double_to_mem_32', 'double_to_int_32', + 'float_trunc_to_int_32', 'double_trunc_to_int_32', + 'double_to_float']: sz = 32 - elif self._op in ['double_to_mem_64', 'double_to_int_64', - 'float_trunc_to_int_64', 'double_trunc_to_int_64', - 'float_to_double']: + elif self.__op in ['double_to_mem_64', 'double_to_int_64', + 'float_trunc_to_int_64', 'double_trunc_to_int_64', + 'float_to_double']: sz = 64 - elif self._op in ['double_to_mem_80', 'double_to_int_80', - 'float_trunc_to_int_80', - 'double_trunc_to_int_80']: + elif self.__op in ['double_to_mem_80', 'double_to_int_80', + 'float_trunc_to_int_80', + 'double_trunc_to_int_80']: sz = 80 - elif self._op in ['segm']: - sz = self._args[1].size + elif self.__op in ['segm']: + sz = self.__args[1].size else: if None in sizes: sz = None @@ -817,256 +859,275 @@ class ExprOp(Expr): # All arguments have the same size sz = list(sizes)[0] - self._size = sz + self.__size = sz - op = property(lambda self: self._op) - args = property(lambda self: self._args) + size = property(lambda self: self.__size) + op = property(lambda self: self.__op) + args = property(lambda self: self.__args) + + def __getstate__(self): + return self.__op, self.__args + + def __setstate__(self, state): + op, args = state + self.__init__(op, *args) + + def __new__(cls, op, *args): + return Expr.get_object(cls, (op, args)) def __str__(self): if self.is_associative(): - return '(' + self._op.join([str(arg) for arg in self._args]) + ')' - if (self._op.startswith('call_func_') or - self._op == 'cpuid' or - len(self._args) > 2 or - self._op in ['parity', 'segm']): - return self._op + '(' + ', '.join([str(arg) for arg in self._args]) + ')' - if len(self._args) == 2: - return ('(' + str(self._args[0]) + - ' ' + self.op + ' ' + str(self._args[1]) + ')') + return '(' + self.__op.join([str(arg) for arg in self.__args]) + ')' + if (self.__op.startswith('call_func_') or + self.__op == 'cpuid' or + len(self.__args) > 2 or + self.__op in ['parity', 'segm']): + return self.__op + '(' + ', '.join([str(arg) for arg in self.__args]) + ')' + if len(self.__args) == 2: + return ('(' + str(self.__args[0]) + + ' ' + self.op + ' ' + str(self.__args[1]) + ')') else: return reduce(lambda x, y: x + ' ' + str(y), - self._args, - '(' + str(self._op)) + ')' + self.__args, + '(' + str(self.__op)) + ')' def get_r(self, mem_read=False, cst_read=False): return reduce(lambda elements, arg: - elements.union(arg.get_r(mem_read, cst_read)), self._args, set()) + elements.union(arg.get_r(mem_read, cst_read)), self.__args, set()) def get_w(self): raise ValueError('op cannot be written!', self) def _exprhash(self): - h_hargs = [hash(arg) for arg in self._args] - return hash((EXPROP, self._op, tuple(h_hargs))) + h_hargs = [hash(arg) for arg in self.__args] + return hash((EXPROP, self.__op, tuple(h_hargs))) def _exprrepr(self): - return "%s(%r, %s)" % (self.__class__.__name__, self._op, - ', '.join(repr(arg) for arg in self._args)) + return "%s(%r, %s)" % (self.__class__.__name__, self.__op, + ', '.join(repr(arg) for arg in self.__args)) def __contains__(self, e): if self == e: return True - for arg in self._args: + for arg in self.__args: if arg.__contains__(e): return True return False def is_function_call(self): - return self._op.startswith('call') + return self.__op.startswith('call') def is_associative(self): "Return True iff current operation is associative" - return (self._op in ['+', '*', '^', '&', '|']) + return (self.__op in ['+', '*', '^', '&', '|']) def is_commutative(self): "Return True iff current operation is commutative" - return (self._op in ['+', '*', '^', '&', '|']) + return (self.__op in ['+', '*', '^', '&', '|']) @visit_chk def visit(self, cb, tv=None): - args = [arg.visit(cb, tv) for arg in self._args] - modified = any([arg[0] != arg[1] for arg in zip(self._args, args)]) + args = [arg.visit(cb, tv) for arg in self.__args] + modified = any([arg[0] != arg[1] for arg in zip(self.__args, args)]) if modified: - return ExprOp(self._op, *args) + return ExprOp(self.__op, *args) return self def copy(self): - args = [arg.copy() for arg in self._args] - return ExprOp(self._op, *args) + args = [arg.copy() for arg in self.__args] + return ExprOp(self.__op, *args) def depth(self): - depth = [arg.depth() for arg in self._args] + depth = [arg.depth() for arg in self.__args] return max(depth) + 1 def graph_recursive(self, graph): graph.add_node(self) - for arg in self._args: + for arg in self.__args: arg.graph_recursive(graph) graph.add_uniq_edge(self, arg) class ExprSlice(Expr): - __slots__ = ["_arg", "_start", "_stop"] + __slots__ = Expr.__slots__ + ["__arg", "__start", "__stop"] def __init__(self, arg, start, stop): super(ExprSlice, self).__init__() assert(start < stop) + self.__arg, self.__start, self.__stop = arg, start, stop + self.__size = self.__stop - self.__start - self._arg, self._start, self._stop = arg, start, stop - self._size = self._stop - self._start + size = property(lambda self: self.__size) + arg = property(lambda self: self.__arg) + start = property(lambda self: self.__start) + stop = property(lambda self: self.__stop) - arg = property(lambda self: self._arg) - start = property(lambda self: self._start) - stop = property(lambda self: self._stop) + def __getstate__(self): + return self.__arg, self.__start, self.__stop + + def __setstate__(self, state): + self.__init__(*state) + + def __new__(cls, arg, start, stop): + return Expr.get_object(cls, (arg, start, stop)) def __str__(self): - return "%s[%d:%d]" % (str(self._arg), self._start, self._stop) + return "%s[%d:%d]" % (str(self.__arg), self.__start, self.__stop) def get_r(self, mem_read=False, cst_read=False): - return self._arg.get_r(mem_read, cst_read) + return self.__arg.get_r(mem_read, cst_read) def get_w(self): - return self._arg.get_w() + return self.__arg.get_w() def _exprhash(self): - return hash((EXPRSLICE, hash(self._arg), self._start, self._stop)) + return hash((EXPRSLICE, hash(self.__arg), self.__start, self.__stop)) def _exprrepr(self): - return "%s(%r, %d, %d)" % (self.__class__.__name__, self._arg, - self._start, self._stop) + return "%s(%r, %d, %d)" % (self.__class__.__name__, self.__arg, + self.__start, self.__stop) - def __contains__(self, e): - if self == e: + def __contains__(self, expr): + if self == expr: return True - return self._arg.__contains__(e) + return self.__arg.__contains__(expr) @visit_chk def visit(self, cb, tv=None): - arg = self._arg.visit(cb, tv) - if arg == self._arg: + arg = self.__arg.visit(cb, tv) + if arg == self.__arg: return self - return ExprSlice(arg, self._start, self._stop) + return ExprSlice(arg, self.__start, self.__stop) def copy(self): - return ExprSlice(self._arg.copy(), self._start, self._stop) + return ExprSlice(self.__arg.copy(), self.__start, self.__stop) def depth(self): - return self._arg.depth() + 1 + return self.__arg.depth() + 1 def slice_rest(self): "Return the completion of the current slice" - size = self._arg.size - if self._start >= size or self._stop > size: + size = self.__arg.size + if self.__start >= size or self.__stop > size: raise ValueError('bad slice rest %s %s %s' % - (size, self._start, self._stop)) + (size, self.__start, self.__stop)) - if self._start == self._stop: + if self.__start == self.__stop: return [(0, size)] rest = [] - if self._start != 0: - rest.append((0, self._start)) - if self._stop < size: - rest.append((self._stop, size)) + if self.__start != 0: + rest.append((0, self.__start)) + if self.__stop < size: + rest.append((self.__stop, size)) return rest def graph_recursive(self, graph): graph.add_node(self) - self._arg.graph_recursive(graph) - graph.add_uniq_edge(self, self._arg) + self.__arg.graph_recursive(graph) + graph.add_uniq_edge(self, self.__arg) class ExprCompose(Expr): """ - Compose is like a hambuger. - It's arguments are tuple of: (Expression, start, stop) - start and stop are intergers, determining Expression position in the compose. - - Burger Example: - ExprCompose([(salad, 0, 3), (cheese, 3, 10), (beacon, 10, 16)]) - In the example, salad.size == 3. + Compose is like a hambuger. It concatenate Expressions """ - __slots__ = ["_args"] + __slots__ = Expr.__slots__ + ["__args"] - def __init__(self, args): + def __init__(self, *args): """Create an ExprCompose The ExprCompose is contiguous and starts at 0 - @args: tuple(Expr, int, int) + @args: [Expr, Expr, ...] + DEPRECATED: + @args: [(Expr, int, int), (Expr, int, int), ...] """ super(ExprCompose, self).__init__() - last_stop = 0 - args = sorted(args, key=itemgetter(1)) - for e, start, stop in args: - if e.size != stop - start: - raise ValueError( - "sanitycheck: ExprCompose args must have correct size!" + - " %r %r %r" % (e, e.size, stop - start)) - if last_stop != start: - raise ValueError( - "sanitycheck: ExprCompose args must be contiguous!" + - " %r" % (args)) - last_stop = stop + is_new_style = args and isinstance(args[0], Expr) + if not is_new_style: + warnings.warn('DEPRECATION WARNING: use "ExprCompose(a, b) instead of'+ + 'ExprCemul_ir_block(self, addr, step=False)" instead of emul_ir_bloc') + + self.__args = tuple(args) + self.__size = sum([arg.size for arg in args]) + + size = property(lambda self: self.__size) + args = property(lambda self: self.__args) - # Transform args to lists - o = [] - for e, a, b in args: - assert(a >= 0 and b >= 0) - o.append(tuple([e, a, b])) - self._args = tuple(o) + def __getstate__(self): + return self.__args - self._size = self._args[-1][2] + def __setstate__(self, state): + self.__init__(state) - args = property(lambda self: self._args) + def __new__(cls, *args): + is_new_style = args and isinstance(args[0], Expr) + if not is_new_style: + assert len(args) == 1 + args = args[0] + return Expr.get_object(cls, tuple(args)) def __str__(self): - return '{' + ', '.join(['%s,%d,%d' % - (str(arg[0]), arg[1], arg[2]) for arg in self._args]) + '}' + return '{' + ', '.join([str(arg) for arg in self.__args]) + '}' def get_r(self, mem_read=False, cst_read=False): return reduce(lambda elements, arg: - elements.union(arg[0].get_r(mem_read, cst_read)), self._args, set()) + elements.union(arg.get_r(mem_read, cst_read)), self.__args, set()) def get_w(self): return reduce(lambda elements, arg: - elements.union(arg[0].get_w()), self._args, set()) + elements.union(arg.get_w()), self.__args, set()) def _exprhash(self): - h_args = [EXPRCOMPOSE] + [(hash(arg[0]), arg[1], arg[2]) - for arg in self._args] + h_args = [EXPRCOMPOSE] + [hash(arg) for arg in self.__args] return hash(tuple(h_args)) def _exprrepr(self): - return "%s(%r)" % (self.__class__.__name__, self._args) + return "%s([%r])" % (self.__class__.__name__, self.__args) def __contains__(self, e): if self == e: return True - for arg in self._args: + for arg in self.__args: if arg == e: return True - if arg[0].__contains__(e): + if arg.__contains__(e): return True return False @visit_chk def visit(self, cb, tv=None): - args = [(arg[0].visit(cb, tv), arg[1], arg[2]) for arg in self._args] - modified = any([arg[0] != arg[1] for arg in zip(self._args, args)]) + args = [arg.visit(cb, tv) for arg in self.__args] + modified = any([arg != arg_new for arg, arg_new in zip(self.__args, args)]) if modified: - return ExprCompose(args) + return ExprCompose(*args) return self def copy(self): - args = [(arg[0].copy(), arg[1], arg[2]) for arg in self._args] - return ExprCompose(args) + args = [arg.copy() for arg in self.__args] + return ExprCompose(*args) def depth(self): - depth = [arg[0].depth() for arg in self._args] + depth = [arg.depth() for arg in self.__args] return max(depth) + 1 def graph_recursive(self, graph): graph.add_node(self) for arg in self.args: - arg[0].graph_recursive(graph) - graph.add_uniq_edge(self, arg[0]) + arg.graph_recursive(graph) + graph.add_uniq_edge(self, arg) + def iter_args(self): + index = 0 + for arg in self.__args: + yield index, arg + index += arg.size # Expression order for comparaison expr_order_dict = {ExprId: 1, @@ -1094,7 +1155,7 @@ def compare_exprs_compose(e1, e2): def compare_expr_list_compose(l1_e, l2_e): # Sort by list elements in incremental order, then by list size for i in xrange(min(len(l1_e), len(l2_e))): - x = compare_exprs_compose(l1_e[i], l2_e[i]) + x = compare_exprs(l1_e[i], l2_e[i]) if x: return x return cmp(len(l1_e), len(l2_e)) @@ -1325,9 +1386,7 @@ def MatchExpr(e, m, tks, result=None): if not isinstance(m, ExprCompose): return False for a1, a2 in zip(e.args, m.args): - if a1[1] != a2[1] or a1[2] != a2[2]: - return False - r = MatchExpr(a1[0], a2[0], tks, result) + r = MatchExpr(a1, a2, tks, result) if r is False: return False return result diff --git a/miasm2/expression/expression_helper.py b/miasm2/expression/expression_helper.py index 0c661c2a..8babba70 100644 --- a/miasm2/expression/expression_helper.py +++ b/miasm2/expression/expression_helper.py @@ -34,103 +34,76 @@ def parity(a): return cpt -def merge_sliceto_slice(args): - sources = {} - non_slice = {} - sources_int = {} - for a in args: - if isinstance(a[0], m2_expr.ExprInt): - # sources_int[a.start] = a - # copy ExprInt because we will inplace modify arg just below - # /!\ TODO XXX never ever modify inplace args... - sources_int[a[1]] = (m2_expr.ExprInt(int(a[0]), - a[2] - a[1]), - a[1], - a[2]) - elif isinstance(a[0], m2_expr.ExprSlice): - if not a[0].arg in sources: - sources[a[0].arg] = [] - sources[a[0].arg].append(a) +def merge_sliceto_slice(expr): + """ + Apply basic factorisation on ExprCompose sub compoenents + @expr: ExprCompose + """ + + slices_raw = [] + other_raw = [] + integers_raw = [] + for index, arg in expr.iter_args(): + if isinstance(arg, m2_expr.ExprInt): + integers_raw.append((index, arg)) + elif isinstance(arg, m2_expr.ExprSlice): + slices_raw.append((index, arg)) else: - non_slice[a[1]] = a - # find max stop to determine size - max_size = None - for a in args: - if max_size is None or max_size < a[2]: - max_size = a[2] - - # first simplify all num slices - final_sources = [] - sorted_s = [] - for x in sources_int.values(): - x = list(x) - # mask int - v = x[0].arg & ((1 << (x[2] - x[1])) - 1) - x[0] = m2_expr.ExprInt_from(x[0], v) - x = tuple(x) - sorted_s.append((x[1], x)) - sorted_s.sort() - while sorted_s: - start, v = sorted_s.pop() - out = [m2_expr.ExprInt(v[0].arg), v[1], v[2]] - size = v[2] - v[1] - while sorted_s: - if sorted_s[-1][1][2] != start: + other_raw.append((index, arg)) + + # Find max stop to determine size + max_size = sum([arg.size for arg in expr.args]) + + integers_merged = [] + # Merge consecutive integers + while integers_raw: + index, arg = integers_raw.pop() + new_size = arg.size + value = int(arg) + while integers_raw: + prev_index, prev_value = integers_raw[-1] + # Check if intergers are consecutive + if prev_index + prev_value.size != index: break - s_start, s_stop = sorted_s[-1][1][1], sorted_s[-1][1][2] - size += s_stop - s_start - a = m2_expr.mod_size2uint[size]( - (int(out[0]) << (out[1] - s_start)) + - int(sorted_s[-1][1][0])) - out[0] = m2_expr.ExprInt(a) - sorted_s.pop() - out[1] = s_start - out[0] = m2_expr.ExprInt(int(out[0]), size) - final_sources.append((start, out)) - - final_sources_int = final_sources - # check if same sources have corresponding start/stop - # is slice AND is sliceto - simp_sources = [] - for args in sources.values(): - final_sources = [] - sorted_s = [] - for x in args: - sorted_s.append((x[1], x)) - sorted_s.sort() - while sorted_s: - start, v = sorted_s.pop() - ee = v[0].arg[v[0].start:v[0].stop] - out = ee, v[1], v[2] - while sorted_s: - if sorted_s[-1][1][2] != start: - break - if sorted_s[-1][1][0].stop != out[0].start: - break - - start = sorted_s[-1][1][1] - # out[0].start = sorted_s[-1][1][0].start - o_e, _, o_stop = out - o1, o2 = sorted_s[-1][1][0].start, o_e.stop - o_e = o_e.arg[o1:o2] - out = o_e, start, o_stop - # update _size - # out[0]._size = out[0].stop-out[0].start - sorted_s.pop() - out = out[0], start, out[2] - - final_sources.append((start, out)) + # Merge integers + index = prev_index + new_size += prev_value.size + value = value << prev_value.size + value |= int(prev_value) + integers_raw.pop() + integers_merged.append((index, m2_expr.ExprInt(value, new_size))) + + + slices_merged = [] + # Merge consecutive slices + while slices_raw: + index, arg = slices_raw.pop() + value, slice_start, slice_stop = arg.arg, arg.start, arg.stop + while slices_raw: + prev_index, prev_value = slices_raw[-1] + # Check if slices are consecutive + if prev_index + prev_value.size != index: + break + # Check if slices can ben merged + if prev_value.arg != value: + break + if prev_value.stop != slice_start: + break + # Merge slices + index = prev_index + slice_start = prev_value.start + slices_raw.pop() + slices_merged.append((index, value[slice_start:slice_stop])) - simp_sources += final_sources - simp_sources += final_sources_int + new_args = slices_merged + integers_merged + other_raw + new_args.sort() + for i, (index, arg) in enumerate(new_args[:-1]): + assert index + arg.size == new_args[i+1][0] + ret = [arg[1] for arg in new_args] - for i, v in non_slice.items(): - simp_sources.append((i, v)) + return ret - simp_sources.sort() - simp_sources = [x[1] for x in simp_sources] - return simp_sources op_propag_cst = ['+', '*', '^', '&', '|', '>>', @@ -210,9 +183,6 @@ class Variables_Identifier(object): - original expression with variables translated """ - # Attribute used to distinguish created variables from original ones - is_var_ident = "is_var_ident" - def __init__(self, expr, var_prefix="v"): """Set the expression @expr to handle and launch variable identification process @@ -287,13 +257,11 @@ class Variables_Identifier(object): for element_done in done: todo.remove(element_done) - @classmethod - def is_var_identifier(cls, expr): + def is_var_identifier(self, expr): "Return True iff @expr is a variable identifier" if not isinstance(expr, m2_expr.ExprId): return False - - return expr.is_var_ident + return expr in self._vars def find_variables_rec(self, expr): """Recursive method called by find_variable to expand @expr. @@ -310,7 +278,6 @@ class Variables_Identifier(object): identifier = m2_expr.ExprId("%s%s" % (self.var_prefix, self.var_indice.next()), size = expr.size) - identifier.is_var_ident = True self._vars[identifier] = expr # Recursion stop case @@ -333,8 +300,8 @@ class Variables_Identifier(object): self.find_variables_rec(expr.arg) elif isinstance(expr, m2_expr.ExprCompose): - for a in expr.args: - self.find_variables_rec(list(a)[0]) + for arg in expr.args: + self.find_variables_rec(arg) elif isinstance(expr, m2_expr.ExprSlice): self.find_variables_rec(expr.arg) @@ -455,21 +422,19 @@ class ExprRandom(object): """ # First layer upper_bound = random.randint(1, size) - args = [(cls._gen(size=upper_bound, depth=depth - 1), 0, upper_bound)] + args = [cls._gen(size=upper_bound, depth=depth - 1)] # Next layers while (upper_bound < size): if len(args) == (cls.compose_max_layer - 1): # We reach the maximum size - upper_bound = size + new_upper_bound = size else: - upper_bound = random.randint(args[-1][-1] + 1, size) + new_upper_bound = random.randint(upper_bound + 1, size) - args.append((cls._gen(size=upper_bound - args[-1][-1]), - args[-1][-1], - upper_bound)) - - return m2_expr.ExprCompose(args) + args.append(cls._gen(size=new_upper_bound - upper_bound)) + upper_bound = new_upper_bound + return m2_expr.ExprCompose(*args) @classmethod def memory(cls, size=32, depth=1): @@ -654,22 +619,17 @@ def possible_values(expr): elif isinstance(expr, m2_expr.ExprCompose): # Generate each possibility for sub-argument, associated with the start # and stop bit - consvals_args = [map(lambda x: (x, arg[1], arg[2]), - possible_values(arg[0])) + consvals_args = [map(lambda x: x, possible_values(arg)) for arg in expr.args] for consvals_possibility in itertools.product(*consvals_args): # Merge constraint of each sub-element - args_constraint = itertools.chain(*[consval[0].constraints + args_constraint = itertools.chain(*[consval.constraints for consval in consvals_possibility]) # Gen the corresponding constraints / ExprCompose + args = [consval.value for consval in consvals_possibility] consvals.add( ConstrainedValue(frozenset(args_constraint), - m2_expr.ExprCompose( - [(consval[0].value, - consval[1], - consval[2]) - for consval in consvals_possibility] - ))) + m2_expr.ExprCompose(*args))) else: raise RuntimeError("Unsupported type for expr: %s" % type(expr)) diff --git a/miasm2/expression/simplifications.py b/miasm2/expression/simplifications.py index cbffb219..dd4f5c04 100644 --- a/miasm2/expression/simplifications.py +++ b/miasm2/expression/simplifications.py @@ -48,6 +48,7 @@ class ExpressionSimplifier(object): def __init__(self): self.expr_simp_cb = {} + self.simplified_exprs = set() def enable_passes(self, passes): """Add passes from @passes @@ -80,7 +81,7 @@ class ExpressionSimplifier(object): @expression: Expr instance Return an Expr instance""" - if expression.is_simp: + if expression in self.simplified_exprs: return expression # Find a stable state @@ -92,10 +93,10 @@ class ExpressionSimplifier(object): # Launch recursivity expression = self.expr_simp_wrapper(e_new) - expression.is_simp = True - + self.simplified_exprs.add(expression) # Mark expression as simplified - e_new.is_simp = True + self.simplified_exprs.add(e_new) + return e_new def expr_simp_wrapper(self, expression, callback=None): @@ -104,13 +105,13 @@ class ExpressionSimplifier(object): @manual_callback: If set, call this function instead of normal one Return an Expr instance""" - if expression.is_simp: + if expression in self.simplified_exprs: return expression if callback is None: callback = self.expr_simp - return expression.visit(callback, lambda e: not(e.is_simp)) + return expression.visit(callback, lambda e: e not in self.simplified_exprs) def __call__(self, expression, callback=None): "Wrapper on expr_simp_wrapper" diff --git a/miasm2/expression/simplifications_common.py b/miasm2/expression/simplifications_common.py index 49dfbcc0..a070fb81 100644 --- a/miasm2/expression/simplifications_common.py +++ b/miasm2/expression/simplifications_common.py @@ -265,13 +265,14 @@ def simp_cst_propagation(e_s, e): args = new_args # A << int with A ExprCompose => move index - if op == "<<" and isinstance(args[0], ExprCompose) and isinstance(args[1], ExprInt): + if (op == "<<" and isinstance(args[0], ExprCompose) and + isinstance(args[1], ExprInt) and int(args[1]) != 0): final_size = args[0].size shift = int(args[1]) new_args = [] # shift indexes - for expr, start, stop in args[0].args: - new_args.append((expr, start+shift, stop+shift)) + for index, arg in args[0].iter_args(): + new_args.append((arg, index+shift, index+shift+arg.size)) # filter out expression filter_args = [] min_index = final_size @@ -281,12 +282,13 @@ def simp_cst_propagation(e_s, e): if stop > final_size: expr = expr[:expr.size - (stop - final_size)] stop = final_size - filter_args.append((expr, start, stop)) + filter_args.append(expr) min_index = min(start, min_index) # create entry 0 + assert min_index != 0 expr = ExprInt(0, min_index) - filter_args = [(expr, 0, min_index)] + filter_args - return ExprCompose(filter_args) + args = [expr] + filter_args + return ExprCompose(*args) # A >> int with A ExprCompose => move index if op == ">>" and isinstance(args[0], ExprCompose) and isinstance(args[1], ExprInt): @@ -294,8 +296,8 @@ def simp_cst_propagation(e_s, e): shift = int(args[1]) new_args = [] # shift indexes - for expr, start, stop in args[0].args: - new_args.append((expr, start-shift, stop-shift)) + for index, arg in args[0].iter_args(): + new_args.append((arg, index-shift, index+arg.size-shift)) # filter out expression filter_args = [] max_index = 0 @@ -305,29 +307,30 @@ def simp_cst_propagation(e_s, e): if start < 0: expr = expr[-start:] start = 0 - filter_args.append((expr, start, stop)) + filter_args.append(expr) max_index = max(stop, max_index) # create entry 0 expr = ExprInt(0, final_size - max_index) - filter_args += [(expr, max_index, final_size)] - return ExprCompose(filter_args) + args = filter_args + [expr] + return ExprCompose(*args) # Compose(a) OP Compose(b) with a/b same bounds => Compose(a OP b) if op in ['|', '&', '^'] and all([isinstance(arg, ExprCompose) for arg in args]): bounds = set() for arg in args: - bound = tuple([(start, stop) for (expr, start, stop) in arg.args]) + bound = tuple([expr.size for expr in arg.args]) bounds.add(bound) if len(bounds) == 1: bound = list(bounds)[0] - new_args = [[expr] for (expr, start, stop) in args[0].args] + new_args = [[expr] for expr in args[0].args] for sub_arg in args[1:]: - for i, (expr, start, stop) in enumerate(sub_arg.args): + for i, expr in enumerate(sub_arg.args): new_args[i].append(expr) + args = [] for i, arg in enumerate(new_args): - new_args[i] = ExprOp(op, *arg), bound[i][0], bound[i][1] - return ExprCompose(new_args) + args.append(ExprOp(op, *arg)) + return ExprCompose(*args) # <<<c_rez, >>>c_rez if op in [">>>c_rez", "<<<c_rez"]: @@ -448,40 +451,41 @@ def simp_slice(e_s, e): return new_e elif isinstance(e.arg, ExprCompose): # Slice(Compose(A), x) => Slice(A, y) - for a in e.arg.args: - if a[1] <= e.start and a[2] >= e.stop: - new_e = a[0][e.start - a[1]:e.stop - a[1]] + for index, arg in e.arg.iter_args(): + if index <= e.start and index+arg.size >= e.stop: + new_e = arg[e.start - index:e.stop - index] return new_e # Slice(Compose(A, B, C), x) => Compose(A, B, C) with truncated A/B/C out = [] - for arg, s_start, s_stop in e.arg.args: + for index, arg in e.arg.iter_args(): # arg is before slice start - if e.start >= s_stop: + if e.start >= index + arg.size: continue # arg is after slice stop - elif e.stop <= s_start: + elif e.stop <= index: continue # arg is fully included in slice - elif e.start <= s_start and s_stop <= e.stop: - out.append((arg, s_start - e.start, s_stop - e.start)) + elif e.start <= index and index + arg.size <= e.stop: + out.append(arg) continue # arg is truncated at start - if e.start > s_start: - slice_start = e.start - s_start + if e.start > index: + slice_start = e.start - index a_start = 0 else: # arg is not truncated at start slice_start = 0 - a_start = s_start - e.start + a_start = index - e.start # a is truncated at stop - if e.stop < s_stop: - slice_stop = arg.size + e.stop - s_stop - slice_start + if e.stop < index + arg.size: + slice_stop = arg.size + e.stop - (index + arg.size) - slice_start a_stop = e.stop - e.start else: slice_stop = arg.size - a_stop = s_stop - e.start - out.append((arg[slice_start:slice_stop], a_start, a_stop)) - return ExprCompose(out) + a_stop = index + arg.size - e.start + out.append(arg[slice_start:slice_stop]) + + return ExprCompose(*out) # ExprMem(x, size)[:A] => ExprMem(x, a) # XXXX todo hum, is it safe? @@ -533,68 +537,61 @@ def simp_slice(e_s, e): def simp_compose(e_s, e): "Commons simplification on ExprCompose" - args = merge_sliceto_slice(e.args) + args = merge_sliceto_slice(e) out = [] # compose of compose - for a in args: - if isinstance(a[0], ExprCompose): - for x, start, stop in a[0].args: - out.append((x, start + a[1], stop + a[1])) + for arg in args: + if isinstance(arg, ExprCompose): + out += arg.args else: - out.append(a) + out.append(arg) args = out # Compose(a) with a.size = compose.size => a - if len(args) == 1 and args[0][1] == 0 and args[0][2] == e.size: - return args[0][0] + if len(args) == 1 and args[0].size == e.size: + return args[0] # {(X[z:], 0, X.size-z), (0, X.size-z, X.size)} => (X >> z) if (len(args) == 2 and - isinstance(args[1][0], ExprInt) and - args[1][0].arg == 0): - a1 = args[0] - a2 = args[1] - if (isinstance(a1[0], ExprSlice) and - a1[1] == 0 and - a1[0].stop == a1[0].arg.size and - a2[1] == a1[0].size and - a2[2] == a1[0].arg.size): - new_e = a1[0].arg >> ExprInt( - a1[0].start, a1[0].arg.size) + isinstance(args[1], ExprInt) and + int(args[1]) == 0): + if (isinstance(args[0], ExprSlice) and + args[0].stop == args[0].arg.size and + args[0].size + args[1].size == args[0].arg.size): + new_e = args[0].arg >> ExprInt(args[0].start, args[0].arg.size) return new_e # Compose with ExprCond with integers for src1/src2 and intergers => # propagage integers # {XXX?(0x0,0x1)?(0x0,0x1),0,8, 0x0,8,32} => XXX?(int1, int2) - ok = True - expr_cond = None - expr_ints = [] - for i, a in enumerate(args): - if not is_int_or_cond_src_int(a[0]): + expr_cond_index = None + expr_ints_or_conds = [] + for i, arg in enumerate(args): + if not is_int_or_cond_src_int(arg): ok = False break - expr_ints.append(a) - if isinstance(a[0], ExprCond): - if expr_cond is not None: + expr_ints_or_conds.append(arg) + if isinstance(arg, ExprCond): + if expr_cond_index is not None: ok = False - expr_cond = i - cond = a[0] + expr_cond_index = i + cond = arg - if ok and expr_cond is not None: + if ok and expr_cond_index is not None: src1 = [] src2 = [] - for i, a in enumerate(expr_ints): - if i == expr_cond: - src1.append((a[0].src1, a[1], a[2])) - src2.append((a[0].src2, a[1], a[2])) + for i, arg in enumerate(expr_ints_or_conds): + if i == expr_cond_index: + src1.append(arg.src1) + src2.append(arg.src2) else: - src1.append(a) - src2.append(a) - src1 = e_s.apply_simp(ExprCompose(src1)) - src2 = e_s.apply_simp(ExprCompose(src2)) + src1.append(arg) + src2.append(arg) + src1 = e_s.apply_simp(ExprCompose(*src1)) + src2 = e_s.apply_simp(ExprCompose(*src2)) if isinstance(src1, ExprInt) and isinstance(src2, ExprInt): return ExprCond(cond.cond, src1, src2) - return ExprCompose(args) + return ExprCompose(*args) def simp_cond(e_s, e): |