about summary refs log tree commit diff stats
path: root/miasm2/expression/expression.py
diff options
context:
space:
mode:
Diffstat (limited to 'miasm2/expression/expression.py')
-rw-r--r--miasm2/expression/expression.py268
1 files changed, 164 insertions, 104 deletions
diff --git a/miasm2/expression/expression.py b/miasm2/expression/expression.py
index d04530c3..ad76f01c 100644
--- a/miasm2/expression/expression.py
+++ b/miasm2/expression/expression.py
@@ -115,27 +115,65 @@ class Expr(object):
     "Parent class for Miasm Expressions"
 
     __slots__ = ["is_term", "is_simp", "is_canon",
-                 "is_eval", "_hash", "_repr", "_size",
+                 "is_eval", "__hash", "__repr", "__size",
                  "is_var_ident"]
 
+    all_exprs = set()
+    args2expr = {}
+    simp_exprs = set()
+    canon_exprs = set()
+    use_singleton = True
+
+    is_term = False   # Terminal expression
 
     def set_size(self, value):
         raise ValueError('size is not mutable')
 
     def __init__(self):
-        self.is_term = False   # Terminal expression
-        self.is_simp = False   # Expression already simplified
-        self.is_canon = False  # Expression already canonised
-        self.is_eval = False   # Expression already evalued
-        self.is_var_ident = False # Expression not identifier
-
-        self._hash = None
-        self._repr = None
+        self.__hash = None
+        self.__repr = None
+        self.__size = None
 
     size = property(lambda self: self._size)
 
+    @staticmethod
+    def get_object(cls, args):
+        if not cls.use_singleton:
+            return object.__new__(cls, args)
+
+        expr = Expr.args2expr.get((cls, args))
+        if expr is None:
+            expr = object.__new__(cls, args)
+            Expr.args2expr[(cls, args)] = expr
+        return expr
+
+    def __new__(cls, *args, **kwargs):
+        expr = object.__new__(cls, *args, **kwargs)
+        return expr
+
+    def get_is_simp(self):
+        return self in Expr.simp_exprs
+
+    def set_is_simp(self, value):
+        assert(value is True)
+        Expr.simp_exprs.add(self)
+
+    is_simp = property(get_is_simp, set_is_simp)
+
+    def get_is_canon(self):
+        return self in Expr.canon_exprs
+
+    def set_is_canon(self, value):
+        assert(value is True)
+        Expr.canon_exprs.add(self)
+
+    is_canon = property(get_is_canon, set_is_canon)
+
     # Common operations
 
+    def __str__(self):
+        raise NotImplementedError("Abstract Method")
+
     def __getitem__(self, i):
         if not isinstance(i, slice):
             raise TypeError("Expression: Bad slice: %s" % i)
@@ -158,9 +196,9 @@ class Expr(object):
         return self._repr
 
     def __hash__(self):
-        if self._hash is None:
-            self._hash = self._exprhash()
-        return self._hash
+        if self.__hash is None:
+            self.__hash = self._exprhash()
+        return self.__hash
 
     def pre_eq(self, other):
         """Return True if ids are equal;
@@ -341,7 +379,7 @@ class ExprInt(Expr):
      - Constant 0x12345678 on 32bits
      """
 
-    __slots__ = ["_arg"]
+    __slots__ = ["__arg"]
 
     def __init__(self, num, size=None):
         """Create an ExprInt from a modint or num/size
@@ -361,14 +399,14 @@ class ExprInt(Expr):
         else:
             raise ValueError('arg must by modint or (int,size)! %s' % num)
 
-    arg = property(lambda self: self._arg)
+        self.__arg = arg
+        self.__size = self.arg.size
 
-    def __eq__(self, other):
-        res = self.pre_eq(other)
-        if res is not None:
-            return res
-        return (self._arg == other._arg and
-                self._size == other._size)
+    size = property(lambda self: self.__size)
+    arg = property(lambda self: self.__arg)
+
+    def __new__(cls, arg):
+        return Expr.get_object(cls, (arg, arg.size))
 
     def __get_int(self):
         "Return self integer representation"
@@ -393,7 +431,7 @@ class ExprInt(Expr):
         return hash((EXPRINT, self._arg, self._size))
 
     def _exprrepr(self):
-        return "%s(%r)" % (self.__class__.__name__, self._arg)
+        return "%s(0x%X)" % (self.__class__.__name__, self.__get_int())
 
     def __contains__(self, e):
         return self == e
@@ -437,16 +475,14 @@ class ExprId(Expr):
         """
         super(ExprId, self).__init__()
 
-        self._name, self._size = name, size
+        self.__name, self.__size = name, size
+        self.is_term = is_term
 
-    name = property(lambda self: self._name)
+    size = property(lambda self: self.__size)
+    name = property(lambda self: self.__name)
 
-    def __eq__(self, other):
-        res = self.pre_eq(other)
-        if res is not None:
-            return res
-        return (self._name == other._name and
-                self._size == other._size)
+    def __new__(cls, name, size=32):
+        return Expr.get_object(cls, (name, size))
 
     def __str__(self):
         return str(self._name)
@@ -459,10 +495,10 @@ class ExprId(Expr):
 
     def _exprhash(self):
         # TODO XXX: hash size ??
-        return hash((EXPRID, self._name, self._size))
+        return hash((EXPRID, self.__name, self.__size))
 
     def _exprrepr(self):
-        return "%s(%r, %d)" % (self.__class__.__name__, self._name, self._size)
+        return "%s(%r, %d)" % (self.__class__.__name__, self.__name, self.__size)
 
     def __contains__(self, e):
         return self == e
@@ -472,7 +508,7 @@ class ExprId(Expr):
         return self
 
     def copy(self):
-        return ExprId(self._name, self._size)
+        return ExprId(self.__name, self.__size)
 
     def depth(self):
         return 1
@@ -506,20 +542,24 @@ class ExprAff(Expr):
 
         if isinstance(dst, ExprSlice):
             # Complete the source with missing slice parts
-            self._dst = dst.arg
+            self.__dst = dst.arg
             rest = [(ExprSlice(dst.arg, r[0], r[1]), r[0], r[1])
                     for r in dst.slice_rest()]
             all_a = [(src, dst.start, dst.stop)] + rest
             all_a.sort(key=lambda x: x[1])
-            self._src = ExprCompose(all_a)
+            self.__src = ExprCompose(all_a)
 
         else:
-            self._dst, self._src = dst, src
+            self.__dst, self.__src = dst, src
+
+        self.__size = self.dst.size
 
-        self._size = self.dst.size
+    size = property(lambda self: self.__size)
+    dst = property(lambda self: self.__dst)
+    src = property(lambda self: self.__src)
 
-    dst = property(lambda self: self._dst)
-    src = property(lambda self: self._src)
+    def __new__(cls, dst, src):
+        return Expr.get_object(cls, (dst, src))
 
     def __str__(self):
         return "%s = %s" % (str(self._dst), str(self._src))
@@ -542,8 +582,10 @@ class ExprAff(Expr):
     def _exprrepr(self):
         return "%s(%r, %r)" % (self.__class__.__name__, self._dst, self._src)
 
-    def __contains__(self, e):
-        return self == e or self._src.__contains__(e) or self._dst.__contains__(e)
+    def __contains__(self, expr):
+        return (self == expr or
+                self._src.__contains__(expr) or
+                self._dst.__contains__(expr))
 
     # XXX /!\ for hackish expraff to slice
     def get_modified_slice(self):
@@ -605,23 +647,26 @@ class ExprCond(Expr):
 
         super(ExprCond, self).__init__()
 
+        self.__cond, self.__src1, self.__src2 = cond, src1, src2
         assert(src1.size == src2.size)
+        self.__size = self.src1.size
 
-        self._cond, self._src1, self._src2 = cond, src1, src2
-        self._size = self.src1.size
+    size = property(lambda self: self.__size)
+    cond = property(lambda self: self.__cond)
+    src1 = property(lambda self: self.__src1)
+    src2 = property(lambda self: self.__src2)
 
-    cond = property(lambda self: self._cond)
-    src1 = property(lambda self: self._src1)
-    src2 = property(lambda self: self._src2)
+    def __new__(cls, cond, src1, src2):
+        return Expr.get_object(cls, (cond, src1, src2))
 
     def __str__(self):
         return "(%s?(%s,%s))" % (str(self._cond), str(self._src1), str(self._src2))
 
     def get_r(self, mem_read=False, cst_read=False):
-        out_src1 = self._src1.get_r(mem_read, cst_read)
-        out_src2 = self._src2.get_r(mem_read, cst_read)
-        return self._cond.get_r(mem_read,
-                                cst_read).union(out_src1).union(out_src2)
+        out_src1 = self.src1.get_r(mem_read, cst_read)
+        out_src2 = self.src2.get_r(mem_read, cst_read)
+        return self.cond.get_r(mem_read,
+                               cst_read).union(out_src1).union(out_src2)
 
     def get_w(self):
         return set()
@@ -636,9 +681,9 @@ class ExprCond(Expr):
 
     def __contains__(self, e):
         return (self == e or
-                self._cond.__contains__(e) or
-                self._src1.__contains__(e) or
-                self._src2.__contains__(e))
+                self.cond.__contains__(e) or
+                self.src1.__contains__(e) or
+                self.src2.__contains__(e))
 
     @visit_chk
     def visit(self, cb, tv=None):
@@ -691,12 +736,16 @@ class ExprMem(Expr):
             raise ValueError(
                 'ExprMem: arg must be an Expr (not %s)' % type(arg))
 
-        self._arg, self._size = arg, size
+        self.__arg, self.__size = arg, size
+
+    size = property(lambda self: self.__size)
+    arg = property(lambda self: self.__arg)
 
-    arg = property(lambda self: self._arg)
+    def __new__(cls, arg, size=32):
+        return Expr.get_object(cls, (arg, size))
 
     def __str__(self):
-        return "@%d[%s]" % (self._size, str(self._arg))
+        return "@%d[%s]" % (self.size, str(self.arg))
 
     def get_r(self, mem_read=False, cst_read=False):
         if mem_read:
@@ -714,19 +763,19 @@ class ExprMem(Expr):
         return "%s(%r, %r)" % (self.__class__.__name__,
                                self._arg, self._size)
 
-    def __contains__(self, e):
-        return self == e or self._arg.__contains__(e)
+    def __contains__(self, expr):
+        return self == expr or self._arg.__contains__(expr)
 
     @visit_chk
     def visit(self, cb, tv=None):
         arg = self._arg.visit(cb, tv)
         if arg == self._arg:
             return self
-        return ExprMem(arg, self._size)
+        return ExprMem(arg, self.size)
 
     def copy(self):
-        arg = self._arg.copy()
-        return ExprMem(arg, size=self._size)
+        arg = self.arg.copy()
+        return ExprMem(arg, size=self.size)
 
     def is_op_segm(self):
         return isinstance(self._arg, ExprOp) and self._arg.op == 'segm'
@@ -772,43 +821,43 @@ class ExprOp(Expr):
         if not isinstance(op, str):
             raise ValueError("ExprOp: 'op' argument must be a string")
 
-        self._op, self._args = op, tuple(args)
+        self.__op, self._args = op, tuple(args)
 
         # Set size for special cases
-        if self._op in [
+        if self.__op in [
                 '==', 'parity', 'fcom_c0', 'fcom_c1', 'fcom_c2', 'fcom_c3',
                 'fxam_c0', 'fxam_c1', 'fxam_c2', 'fxam_c3',
                 "access_segment_ok", "load_segment_limit_ok", "bcdadd_cf",
                 "ucomiss_zf", "ucomiss_pf", "ucomiss_cf"]:
             sz = 1
-        elif self._op in [TOK_INF, TOK_INF_SIGNED,
-                          TOK_INF_UNSIGNED, TOK_INF_EQUAL,
-                          TOK_INF_EQUAL_SIGNED, TOK_INF_EQUAL_UNSIGNED,
-                          TOK_EQUAL, TOK_POS,
-                          TOK_POS_STRICT,
-                          ]:
+        elif self.__op in [TOK_INF, TOK_INF_SIGNED,
+                           TOK_INF_UNSIGNED, TOK_INF_EQUAL,
+                           TOK_INF_EQUAL_SIGNED, TOK_INF_EQUAL_UNSIGNED,
+                           TOK_EQUAL, TOK_POS,
+                           TOK_POS_STRICT,
+                           ]:
             sz = 1
-        elif self._op in ['mem_16_to_double', 'mem_32_to_double',
-                          'mem_64_to_double', 'mem_80_to_double',
-                          'int_16_to_double', 'int_32_to_double',
-                          'int_64_to_double', 'int_80_to_double']:
+        elif self.__op in ['mem_16_to_double', 'mem_32_to_double',
+                           'mem_64_to_double', 'mem_80_to_double',
+                           'int_16_to_double', 'int_32_to_double',
+                           'int_64_to_double', 'int_80_to_double']:
             sz = 64
-        elif self._op in ['double_to_mem_16', 'double_to_int_16',
-                          'float_trunc_to_int_16', 'double_trunc_to_int_16']:
+        elif self.__op in ['double_to_mem_16', 'double_to_int_16',
+                           'float_trunc_to_int_16', 'double_trunc_to_int_16']:
             sz = 16
-        elif self._op in ['double_to_mem_32', 'double_to_int_32',
-                          'float_trunc_to_int_32', 'double_trunc_to_int_32',
-                          'double_to_float']:
+        elif self.__op in ['double_to_mem_32', 'double_to_int_32',
+                           'float_trunc_to_int_32', 'double_trunc_to_int_32',
+                           'double_to_float']:
             sz = 32
-        elif self._op in ['double_to_mem_64', 'double_to_int_64',
-                          'float_trunc_to_int_64', 'double_trunc_to_int_64',
-                          'float_to_double']:
+        elif self.__op in ['double_to_mem_64', 'double_to_int_64',
+                           'float_trunc_to_int_64', 'double_trunc_to_int_64',
+                           'float_to_double']:
             sz = 64
-        elif self._op in ['double_to_mem_80', 'double_to_int_80',
-                          'float_trunc_to_int_80',
-                          'double_trunc_to_int_80']:
+        elif self.__op in ['double_to_mem_80', 'double_to_int_80',
+                           'float_trunc_to_int_80',
+                           'double_trunc_to_int_80']:
             sz = 80
-        elif self._op in ['segm']:
+        elif self.__op in ['segm']:
             sz = self._args[1].size
         else:
             if None in sizes:
@@ -817,10 +866,14 @@ class ExprOp(Expr):
                 # All arguments have the same size
                 sz = list(sizes)[0]
 
-        self._size = sz
+        self.__size = sz
+
+    size = property(lambda self: self.__size)
+    op = property(lambda self: self.__op)
+    args = property(lambda self: self.__args)
 
-    op = property(lambda self: self._op)
-    args = property(lambda self: self._args)
+    def __new__(cls, op, *args):
+        return Expr.get_object(cls, (op, args))
 
     def __str__(self):
         if self.is_associative():
@@ -840,7 +893,7 @@ class ExprOp(Expr):
 
     def get_r(self, mem_read=False, cst_read=False):
         return reduce(lambda elements, arg:
-                      elements.union(arg.get_r(mem_read, cst_read)), self._args, set())
+                      elements.union(arg.get_r(mem_read, cst_read)), self.__args, set())
 
     def get_w(self):
         raise ValueError('op cannot be written!', self)
@@ -903,13 +956,16 @@ class ExprSlice(Expr):
         super(ExprSlice, self).__init__()
 
         assert(start < stop)
+        self.__arg, self.__start, self.__stop = arg, start, stop
+        self.__size = self.__stop - self.__start
 
-        self._arg, self._start, self._stop = arg, start, stop
-        self._size = self._stop - self._start
+    size = property(lambda self: self.__size)
+    arg = property(lambda self: self.__arg)
+    start = property(lambda self: self.__start)
+    stop = property(lambda self: self.__stop)
 
-    arg = property(lambda self: self._arg)
-    start = property(lambda self: self._start)
-    stop = property(lambda self: self._stop)
+    def __new__(cls, arg, start, stop):
+        return Expr.get_object(cls, (arg, start, stop))
 
     def __str__(self):
         return "%s[%d:%d]" % (str(self._arg), self._start, self._stop)
@@ -927,10 +983,10 @@ class ExprSlice(Expr):
         return "%s(%r, %d, %d)" % (self.__class__.__name__, self._arg,
                                    self._start, self._stop)
 
-    def __contains__(self, e):
-        if self == e:
+    def __contains__(self, expr):
+        if self == expr:
             return True
-        return self._arg.__contains__(e)
+        return self.__arg.__contains__(expr)
 
     @visit_chk
     def visit(self, cb, tv=None):
@@ -1009,31 +1065,35 @@ class ExprCompose(Expr):
         for e, a, b in args:
             assert(a >= 0 and b >= 0)
             o.append(tuple([e, a, b]))
-        self._args = tuple(o)
+        self.__args = tuple(o)
+
+        self.__size = self.__args[-1][2]
 
-        self._size = self._args[-1][2]
+    size = property(lambda self: self.__size)
+    args = property(lambda self: self.__args)
 
-    args = property(lambda self: self._args)
+    def __new__(cls, args):
+        return Expr.get_object(cls, tuple(args))
 
     def __str__(self):
         return '{' + ', '.join(['%s,%d,%d' %
-                                (str(arg[0]), arg[1], arg[2]) for arg in self._args]) + '}'
+                                (str(arg[0]), arg[1], arg[2]) for arg in self.__args]) + '}'
 
     def get_r(self, mem_read=False, cst_read=False):
         return reduce(lambda elements, arg:
-                      elements.union(arg[0].get_r(mem_read, cst_read)), self._args, set())
+                      elements.union(arg[0].get_r(mem_read, cst_read)), self.__args, set())
 
     def get_w(self):
         return reduce(lambda elements, arg:
-                      elements.union(arg[0].get_w()), self._args, set())
+                      elements.union(arg[0].get_w()), self.__args, set())
 
     def _exprhash(self):
         h_args = [EXPRCOMPOSE] + [(hash(arg[0]), arg[1], arg[2])
-                                  for arg in self._args]
+                                  for arg in self.__args]
         return hash(tuple(h_args))
 
     def _exprrepr(self):
-        return "%s(%r)" % (self.__class__.__name__, self._args)
+        return "%s(%r)" % (self.__class__.__name__, self.__args)
 
     def __contains__(self, e):
         if self == e: