about summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--miasm2/expression/expression.py88
-rw-r--r--miasm2/expression/simplifications_common.py5
-rw-r--r--test/expression/simplifications.py40
3 files changed, 65 insertions, 68 deletions
diff --git a/miasm2/expression/expression.py b/miasm2/expression/expression.py
index e134e503..94c3825d 100644
--- a/miasm2/expression/expression.py
+++ b/miasm2/expression/expression.py
@@ -30,7 +30,7 @@
 
 import itertools
 from operator import itemgetter
-from miasm2.expression.modint import *
+from miasm2.expression.modint import mod_size2uint, is_modint, size2mask
 from miasm2.core.graph import DiGraph
 import warnings
 
@@ -143,10 +143,6 @@ class Expr(object):
             Expr.args2expr[(cls, args)] = expr
         return expr
 
-    def __new__(cls, *args, **kwargs):
-        expr = object.__new__(cls, *args, **kwargs)
-        return expr
-
     def get_is_canon(self):
         return self in Expr.canon_exprs
 
@@ -187,23 +183,18 @@ class Expr(object):
             self.__hash = self._exprhash()
         return self.__hash
 
-    def pre_eq(self, other):
-        """Return True if ids are equal;
-        False if instances are obviously not equal
-        None if we cannot simply decide"""
-
-        if id(self) == id(other):
+    def __eq__(self, other):
+        if self is other:
             return True
+        elif self.use_singleton:
+            # In case of Singleton, pointer comparison is sufficient
+            # Avoid computation of hash and repr
+            return False
+
         if self.__class__ is not other.__class__:
             return False
         if hash(self) != hash(other):
             return False
-        return None
-
-    def __eq__(self, other):
-        res = self.pre_eq(other)
-        if res is not None:
-            return res
         return repr(self) == repr(other)
 
     def __ne__(self, a):
@@ -246,8 +237,7 @@ class Expr(object):
         return ExprOp("**",self, a)
 
     def __invert__(self):
-        s = self.size
-        return ExprOp('^', self, ExprInt(mod_size2uint[s](size2mask(s))))
+        return ExprOp('^', self, self.mask)
 
     def copy(self):
         "Deep copy of the expression"
@@ -398,25 +388,12 @@ class ExprInt(Expr):
     __slots__ = Expr.__slots__ + ["__arg"]
 
 
-    def __init__(self, num, size=None):
+    def __init__(self, arg, size):
         """Create an ExprInt from a modint or num/size
-        @arg: modint or num
-        @size: (optionnal) int size"""
-
+        @arg: 'intable' number
+        @size: int size"""
         super(ExprInt, self).__init__()
-
-        if is_modint(num):
-            self.__arg = num
-            self.__size = self.arg.size
-            if size is not None and num.size != size:
-                raise RuntimeError("size must match modint size")
-        elif size is not None:
-            if size not in mod_size2uint:
-                define_uint(size)
-            self.__arg = mod_size2uint[size](num)
-            self.__size = self.arg.size
-        else:
-            raise ValueError('arg must by modint or (int,size)! %s' % num)
+        # Work is done in __new__
 
     size = property(lambda self: self.__size)
     arg = property(lambda self: self.__arg)
@@ -427,10 +404,29 @@ class ExprInt(Expr):
     def __setstate__(self, state):
         self.__init__(*state)
 
-    def __new__(cls, arg, size=None):
-        if size is None:
-            size = arg.size
-        return Expr.get_object(cls, (arg, size))
+    def __new__(cls, arg, size):
+        """Create an ExprInt from a modint or num/size
+        @arg: 'intable' number
+        @size: int size"""
+
+        if is_modint(arg):
+            assert size == arg.size
+        # Avoid a common blunder
+        assert not isinstance(arg, ExprInt)
+
+        # Ensure arg is always a moduint
+        arg = int(arg)
+        if size not in mod_size2uint:
+            define_uint(size)
+        arg = mod_size2uint[size](arg)
+
+        # Get the Singleton instance
+        expr = Expr.get_object(cls, (arg, size))
+
+        # Save parameters (__init__ is called with parameters unchanged)
+        expr.__arg = arg
+        expr.__size = expr.__arg.size
+        return expr
 
     def __get_int(self):
         "Return self integer representation"
@@ -1321,28 +1317,28 @@ def canonize_expr_list_compose(l):
 
 
 def ExprInt1(i):
-    return ExprInt(uint1(i))
+    return ExprInt(i, 1)
 
 
 def ExprInt8(i):
-    return ExprInt(uint8(i))
+    return ExprInt(i, 8)
 
 
 def ExprInt16(i):
-    return ExprInt(uint16(i))
+    return ExprInt(i, 16)
 
 
 def ExprInt32(i):
-    return ExprInt(uint32(i))
+    return ExprInt(i, 32)
 
 
 def ExprInt64(i):
-    return ExprInt(uint64(i))
+    return ExprInt(i, 64)
 
 
 def ExprInt_from(e, i):
     "Generate ExprInt with size equal to expression"
-    return ExprInt(mod_size2uint[e.size](i))
+    return ExprInt(i, e.size)
 
 
 def get_expr_ids_visit(e, ids):
diff --git a/miasm2/expression/simplifications_common.py b/miasm2/expression/simplifications_common.py
index 503a0e77..c9b7932a 100644
--- a/miasm2/expression/simplifications_common.py
+++ b/miasm2/expression/simplifications_common.py
@@ -3,6 +3,7 @@
 # ----------------------------- #
 
 
+from miasm2.expression.modint import mod_size2int, mod_size2uint
 from miasm2.expression.expression import *
 from miasm2.expression.expression_helper import *
 
@@ -103,7 +104,7 @@ def simp_cst_propagation(e_s, e):
 
     # -(int) => -int
     if op == '-' and len(args) == 1 and args[0].is_int():
-        return ExprInt(-args[0].arg)
+        return ExprInt(-int(args[0]), e.size)
     # A op 0 =>A
     if op in ['+', '|', "^", "<<", ">>", "<<<", ">>>"] and len(args) > 1:
         if args[-1].is_int(0):
@@ -237,7 +238,7 @@ def simp_cst_propagation(e_s, e):
 
     # parity(int) => int
     if op == 'parity' and args[0].is_int():
-        return ExprInt1(parity(args[0].arg))
+        return ExprInt1(parity(int(args[0])))
 
     # (-a) * b * (-c) * (-d) => (-a) * b * c * d
     if op == "*" and len(args) > 1:
diff --git a/test/expression/simplifications.py b/test/expression/simplifications.py
index bf658a30..5391fbee 100644
--- a/test/expression/simplifications.py
+++ b/test/expression/simplifications.py
@@ -18,10 +18,10 @@ f = ExprId('f', size=64)
 m = ExprMem(a)
 s = a[:8]
 
-i0 = ExprInt(uint32(0x0))
-i1 = ExprInt(uint32(0x1))
-i2 = ExprInt(uint32(0x2))
-icustom = ExprInt(uint32(0x12345678))
+i0 = ExprInt(0, 32)
+i1 = ExprInt(1, 32)
+i2 = ExprInt(2, 32)
+icustom = ExprInt(0x12345678, 32)
 cc = ExprCond(a, b, c)
 
 o = ExprCompose(a[8:16], a[:8])
@@ -133,7 +133,7 @@ to_test = [(ExprInt32(1) - ExprInt32(1), ExprInt32(0)),
      ExprCond(a, ExprInt32(-0x1), ExprInt32(-0x2))),
     (ExprOp('*', a, b, c, ExprInt32(0x12))[0:17],
      ExprOp(
-     '*', a[0:17], b[0:17], c[0:17], ExprInt(mod_size2uint[17](0x12)))),
+     '*', a[0:17], b[0:17], c[0:17], ExprInt(0x12, 17))),
     (ExprOp('*', a, ExprInt32(0xffffffff)),
      -a),
     (ExprOp('*', -a, -b, c, ExprInt32(0x12)),
@@ -227,32 +227,32 @@ to_test = [(ExprInt32(1) - ExprInt32(1), ExprInt32(0)),
     (ExprCompose(a, b, c)[48:80],
      ExprCompose(b[16:], c[:16])),
 
-    (ExprCompose(a[0:8], b[8:16], ExprInt(uint48(0x0L)))[12:32],
-     ExprCompose(b[12:16], ExprInt(uint16(0)))
+    (ExprCompose(a[0:8], b[8:16], ExprInt(0x0L, 48))[12:32],
+     ExprCompose(b[12:16], ExprInt(0, 16))
        ),
 
-    (ExprCompose(ExprCompose(a[:8], ExprInt(uint56(0x0L)))[8:32]
+    (ExprCompose(ExprCompose(a[:8], ExprInt(0x0L, 56))[8:32]
                   &
-                  ExprInt(uint24(0x1L)),
-                  ExprInt(uint40(0x0L))),
+                  ExprInt(0x1L, 24),
+                  ExprInt(0x0L, 40)),
      ExprInt64(0)),
 
-    (ExprCompose(ExprCompose(a[:8], ExprInt(uint56(0x0L)))[:8]
+    (ExprCompose(ExprCompose(a[:8], ExprInt(0x0L, 56))[:8]
                  &
-                 ExprInt(uint8(0x1L)),
-                 (ExprInt(uint56(0x0L)))),
-     ExprCompose(a[:8]&ExprInt8(1), ExprInt(uint56(0)))),
+                 ExprInt(0x1L, 8),
+                 (ExprInt(0x0L, 56))),
+     ExprCompose(a[:8]&ExprInt8(1), ExprInt(0, 56))),
 
     (ExprCompose(ExprCompose(a[:8],
-                             ExprInt(uint56(0x0L)))[:32]
+                             ExprInt(0x0L, 56))[:32]
                  &
-                 ExprInt(uint32(0x1L)),
-                 ExprInt(uint32(0x0L))),
+                 ExprInt(0x1L, 32),
+                 ExprInt(0x0L, 32)),
      ExprCompose(ExprCompose(ExprSlice(a, 0, 8),
-                             ExprInt(uint24(0x0L)))
+                             ExprInt(0x0L, 24))
                  &
-                 ExprInt(uint32(0x1L)),
-                 ExprInt(uint32(0x0L)))
+                 ExprInt(0x1L, 32),
+                 ExprInt(0x0L, 32))
        ),
     (ExprCompose(a[:16], b[:16])[8:32],
      ExprCompose(a[8:16], b[:16])),