diff options
| -rw-r--r-- | miasm2/expression/expression.py | 88 | ||||
| -rw-r--r-- | miasm2/expression/simplifications_common.py | 5 | ||||
| -rw-r--r-- | test/expression/simplifications.py | 40 |
3 files changed, 65 insertions, 68 deletions
diff --git a/miasm2/expression/expression.py b/miasm2/expression/expression.py index e134e503..94c3825d 100644 --- a/miasm2/expression/expression.py +++ b/miasm2/expression/expression.py @@ -30,7 +30,7 @@ import itertools from operator import itemgetter -from miasm2.expression.modint import * +from miasm2.expression.modint import mod_size2uint, is_modint, size2mask from miasm2.core.graph import DiGraph import warnings @@ -143,10 +143,6 @@ class Expr(object): Expr.args2expr[(cls, args)] = expr return expr - def __new__(cls, *args, **kwargs): - expr = object.__new__(cls, *args, **kwargs) - return expr - def get_is_canon(self): return self in Expr.canon_exprs @@ -187,23 +183,18 @@ class Expr(object): self.__hash = self._exprhash() return self.__hash - def pre_eq(self, other): - """Return True if ids are equal; - False if instances are obviously not equal - None if we cannot simply decide""" - - if id(self) == id(other): + def __eq__(self, other): + if self is other: return True + elif self.use_singleton: + # In case of Singleton, pointer comparison is sufficient + # Avoid computation of hash and repr + return False + if self.__class__ is not other.__class__: return False if hash(self) != hash(other): return False - return None - - def __eq__(self, other): - res = self.pre_eq(other) - if res is not None: - return res return repr(self) == repr(other) def __ne__(self, a): @@ -246,8 +237,7 @@ class Expr(object): return ExprOp("**",self, a) def __invert__(self): - s = self.size - return ExprOp('^', self, ExprInt(mod_size2uint[s](size2mask(s)))) + return ExprOp('^', self, self.mask) def copy(self): "Deep copy of the expression" @@ -398,25 +388,12 @@ class ExprInt(Expr): __slots__ = Expr.__slots__ + ["__arg"] - def __init__(self, num, size=None): + def __init__(self, arg, size): """Create an ExprInt from a modint or num/size - @arg: modint or num - @size: (optionnal) int size""" - + @arg: 'intable' number + @size: int size""" super(ExprInt, self).__init__() - - if is_modint(num): - self.__arg = num - self.__size = self.arg.size - if size is not None and num.size != size: - raise RuntimeError("size must match modint size") - elif size is not None: - if size not in mod_size2uint: - define_uint(size) - self.__arg = mod_size2uint[size](num) - self.__size = self.arg.size - else: - raise ValueError('arg must by modint or (int,size)! %s' % num) + # Work is done in __new__ size = property(lambda self: self.__size) arg = property(lambda self: self.__arg) @@ -427,10 +404,29 @@ class ExprInt(Expr): def __setstate__(self, state): self.__init__(*state) - def __new__(cls, arg, size=None): - if size is None: - size = arg.size - return Expr.get_object(cls, (arg, size)) + def __new__(cls, arg, size): + """Create an ExprInt from a modint or num/size + @arg: 'intable' number + @size: int size""" + + if is_modint(arg): + assert size == arg.size + # Avoid a common blunder + assert not isinstance(arg, ExprInt) + + # Ensure arg is always a moduint + arg = int(arg) + if size not in mod_size2uint: + define_uint(size) + arg = mod_size2uint[size](arg) + + # Get the Singleton instance + expr = Expr.get_object(cls, (arg, size)) + + # Save parameters (__init__ is called with parameters unchanged) + expr.__arg = arg + expr.__size = expr.__arg.size + return expr def __get_int(self): "Return self integer representation" @@ -1321,28 +1317,28 @@ def canonize_expr_list_compose(l): def ExprInt1(i): - return ExprInt(uint1(i)) + return ExprInt(i, 1) def ExprInt8(i): - return ExprInt(uint8(i)) + return ExprInt(i, 8) def ExprInt16(i): - return ExprInt(uint16(i)) + return ExprInt(i, 16) def ExprInt32(i): - return ExprInt(uint32(i)) + return ExprInt(i, 32) def ExprInt64(i): - return ExprInt(uint64(i)) + return ExprInt(i, 64) def ExprInt_from(e, i): "Generate ExprInt with size equal to expression" - return ExprInt(mod_size2uint[e.size](i)) + return ExprInt(i, e.size) def get_expr_ids_visit(e, ids): diff --git a/miasm2/expression/simplifications_common.py b/miasm2/expression/simplifications_common.py index 503a0e77..c9b7932a 100644 --- a/miasm2/expression/simplifications_common.py +++ b/miasm2/expression/simplifications_common.py @@ -3,6 +3,7 @@ # ----------------------------- # +from miasm2.expression.modint import mod_size2int, mod_size2uint from miasm2.expression.expression import * from miasm2.expression.expression_helper import * @@ -103,7 +104,7 @@ def simp_cst_propagation(e_s, e): # -(int) => -int if op == '-' and len(args) == 1 and args[0].is_int(): - return ExprInt(-args[0].arg) + return ExprInt(-int(args[0]), e.size) # A op 0 =>A if op in ['+', '|', "^", "<<", ">>", "<<<", ">>>"] and len(args) > 1: if args[-1].is_int(0): @@ -237,7 +238,7 @@ def simp_cst_propagation(e_s, e): # parity(int) => int if op == 'parity' and args[0].is_int(): - return ExprInt1(parity(args[0].arg)) + return ExprInt1(parity(int(args[0]))) # (-a) * b * (-c) * (-d) => (-a) * b * c * d if op == "*" and len(args) > 1: diff --git a/test/expression/simplifications.py b/test/expression/simplifications.py index bf658a30..5391fbee 100644 --- a/test/expression/simplifications.py +++ b/test/expression/simplifications.py @@ -18,10 +18,10 @@ f = ExprId('f', size=64) m = ExprMem(a) s = a[:8] -i0 = ExprInt(uint32(0x0)) -i1 = ExprInt(uint32(0x1)) -i2 = ExprInt(uint32(0x2)) -icustom = ExprInt(uint32(0x12345678)) +i0 = ExprInt(0, 32) +i1 = ExprInt(1, 32) +i2 = ExprInt(2, 32) +icustom = ExprInt(0x12345678, 32) cc = ExprCond(a, b, c) o = ExprCompose(a[8:16], a[:8]) @@ -133,7 +133,7 @@ to_test = [(ExprInt32(1) - ExprInt32(1), ExprInt32(0)), ExprCond(a, ExprInt32(-0x1), ExprInt32(-0x2))), (ExprOp('*', a, b, c, ExprInt32(0x12))[0:17], ExprOp( - '*', a[0:17], b[0:17], c[0:17], ExprInt(mod_size2uint[17](0x12)))), + '*', a[0:17], b[0:17], c[0:17], ExprInt(0x12, 17))), (ExprOp('*', a, ExprInt32(0xffffffff)), -a), (ExprOp('*', -a, -b, c, ExprInt32(0x12)), @@ -227,32 +227,32 @@ to_test = [(ExprInt32(1) - ExprInt32(1), ExprInt32(0)), (ExprCompose(a, b, c)[48:80], ExprCompose(b[16:], c[:16])), - (ExprCompose(a[0:8], b[8:16], ExprInt(uint48(0x0L)))[12:32], - ExprCompose(b[12:16], ExprInt(uint16(0))) + (ExprCompose(a[0:8], b[8:16], ExprInt(0x0L, 48))[12:32], + ExprCompose(b[12:16], ExprInt(0, 16)) ), - (ExprCompose(ExprCompose(a[:8], ExprInt(uint56(0x0L)))[8:32] + (ExprCompose(ExprCompose(a[:8], ExprInt(0x0L, 56))[8:32] & - ExprInt(uint24(0x1L)), - ExprInt(uint40(0x0L))), + ExprInt(0x1L, 24), + ExprInt(0x0L, 40)), ExprInt64(0)), - (ExprCompose(ExprCompose(a[:8], ExprInt(uint56(0x0L)))[:8] + (ExprCompose(ExprCompose(a[:8], ExprInt(0x0L, 56))[:8] & - ExprInt(uint8(0x1L)), - (ExprInt(uint56(0x0L)))), - ExprCompose(a[:8]&ExprInt8(1), ExprInt(uint56(0)))), + ExprInt(0x1L, 8), + (ExprInt(0x0L, 56))), + ExprCompose(a[:8]&ExprInt8(1), ExprInt(0, 56))), (ExprCompose(ExprCompose(a[:8], - ExprInt(uint56(0x0L)))[:32] + ExprInt(0x0L, 56))[:32] & - ExprInt(uint32(0x1L)), - ExprInt(uint32(0x0L))), + ExprInt(0x1L, 32), + ExprInt(0x0L, 32)), ExprCompose(ExprCompose(ExprSlice(a, 0, 8), - ExprInt(uint24(0x0L))) + ExprInt(0x0L, 24)) & - ExprInt(uint32(0x1L)), - ExprInt(uint32(0x0L))) + ExprInt(0x1L, 32), + ExprInt(0x0L, 32)) ), (ExprCompose(a[:16], b[:16])[8:32], ExprCompose(a[8:16], b[:16])), |