about summary refs log tree commit diff stats
diff options
context:
space:
mode:
author serpilliere <fabrice.desclaux@cea.fr> 2015-02-21 20:51:31 +0100
committer Fabrice Desclaux <fabrice.desclaux@cea.fr> 2016-11-04 16:45:46 +0100
commit 8e8a60b5d55db6209d05577d038ed3b4dc961b60 (patch)
tree 2cb6c1b7133d073897abe2e072662369ec60ee57
parent a15e0faca425c6e2591448e510bf14f1c3f04e14 (diff)
download miasm-8e8a60b5d55db6209d05577d038ed3b4dc961b60.tar.gz
miasm-8e8a60b5d55db6209d05577d038ed3b4dc961b60.zip
Expression: Use singleton pattern for Expression
Start the transformation of Expression into an immutable object.

Multiple problems were present in the Expression class. One of them was
that comparison was done through the hash, which could generate collisions.
The attributes is_simp/is_canon were linked to the instance, and could not
survive expression simplification.
Diffstat (limited to '')
-rw-r--r-- miasm2/core/sembuilder.py 6
-rw-r--r-- miasm2/expression/expression.py 268
-rw-r--r-- miasm2/expression/expression_helper.py 2
3 files changed, 168 insertions, 108 deletions
diff --git a/miasm2/core/sembuilder.py b/miasm2/core/sembuilder.py
index ce327ce1..7f80b64e 100644
--- a/miasm2/core/sembuilder.py
+++ b/miasm2/core/sembuilder.py
@@ -16,7 +16,7 @@ class MiasmTransformer(ast.NodeTransformer):
     X if Y else Z -> ExprCond(Y, X, Z)
     'X'(Y)        -> ExprOp('X', Y)
     ('X' % Y)(Z)  -> ExprOp('X' % Y, Z)
-    {a, b}        -> ExprCompose([a, 0, a.size], [b, a.size, a.size + b.size])
+    {a, b}        -> ExprCompose([(a, 0, a.size), (b, a.size, a.size + b.size)])
     """
 
     # Parsers
@@ -95,7 +95,7 @@ class MiasmTransformer(ast.NodeTransformer):
         return call
 
     def visit_Set(self, node):
-        "{a, b} -> ExprCompose([a, 0, a.size], [b, a.size, a.size + b.size])"
+        "{a, b} -> ExprCompose([(a, 0, a.size), (b, a.size, a.size + b.size)])"
         if len(node.elts) == 0:
             return node
 
@@ -109,7 +109,7 @@ class MiasmTransformer(ast.NodeTransformer):
                                   right=ast.Attribute(value=elt,
                                                       attr='size',
                                                       ctx=ast.Load()))
-            new_elts.append(ast.List(elts=[elt, index, new_index],
+            new_elts.append(ast.Tuple(elts=[elt, index, new_index],
                                      ctx=ast.Load()))
             index = new_index
         return ast.Call(func=ast.Name(id='ExprCompose',
diff --git a/miasm2/expression/expression.py b/miasm2/expression/expression.py
index d04530c3..ad76f01c 100644
--- a/miasm2/expression/expression.py
+++ b/miasm2/expression/expression.py
@@ -115,27 +115,65 @@ class Expr(object):
     "Parent class for Miasm Expressions"
 
     __slots__ = ["is_term", "is_simp", "is_canon",
-                 "is_eval", "_hash", "_repr", "_size",
+                 "is_eval", "__hash", "__repr", "__size",
                  "is_var_ident"]
 
+    all_exprs = set()
+    args2expr = {}
+    simp_exprs = set()
+    canon_exprs = set()
+    use_singleton = True
+
+    is_term = False   # Terminal expression
 
     def set_size(self, value):
         raise ValueError('size is not mutable')
 
     def __init__(self):
-        self.is_term = False   # Terminal expression
-        self.is_simp = False   # Expression already simplified
-        self.is_canon = False  # Expression already canonised
-        self.is_eval = False   # Expression already evalued
-        self.is_var_ident = False # Expression not identifier
-
-        self._hash = None
-        self._repr = None
+        self.__hash = None
+        self.__repr = None
+        self.__size = None
 
     size = property(lambda self: self._size)
 
+    @staticmethod
+    def get_object(cls, args):
+        if not cls.use_singleton:
+            return object.__new__(cls, args)
+
+        expr = Expr.args2expr.get((cls, args))
+        if expr is None:
+            expr = object.__new__(cls, args)
+            Expr.args2expr[(cls, args)] = expr
+        return expr
+
+    def __new__(cls, *args, **kwargs):
+        expr = object.__new__(cls, *args, **kwargs)
+        return expr
+
+    def get_is_simp(self):
+        return self in Expr.simp_exprs
+
+    def set_is_simp(self, value):
+        assert(value is True)
+        Expr.simp_exprs.add(self)
+
+    is_simp = property(get_is_simp, set_is_simp)
+
+    def get_is_canon(self):
+        return self in Expr.canon_exprs
+
+    def set_is_canon(self, value):
+        assert(value is True)
+        Expr.canon_exprs.add(self)
+
+    is_canon = property(get_is_canon, set_is_canon)
+
     # Common operations
 
+    def __str__(self):
+        raise NotImplementedError("Abstract Method")
+
     def __getitem__(self, i):
         if not isinstance(i, slice):
             raise TypeError("Expression: Bad slice: %s" % i)
@@ -158,9 +196,9 @@ class Expr(object):
         return self._repr
 
     def __hash__(self):
-        if self._hash is None:
-            self._hash = self._exprhash()
-        return self._hash
+        if self.__hash is None:
+            self.__hash = self._exprhash()
+        return self.__hash
 
     def pre_eq(self, other):
         """Return True if ids are equal;
@@ -341,7 +379,7 @@ class ExprInt(Expr):
      - Constant 0x12345678 on 32bits
      """
 
-    __slots__ = ["_arg"]
+    __slots__ = ["__arg"]
 
     def __init__(self, num, size=None):
         """Create an ExprInt from a modint or num/size
@@ -361,14 +399,14 @@ class ExprInt(Expr):
         else:
             raise ValueError('arg must by modint or (int,size)! %s' % num)
 
-    arg = property(lambda self: self._arg)
+        self.__arg = arg
+        self.__size = self.arg.size
 
-    def __eq__(self, other):
-        res = self.pre_eq(other)
-        if res is not None:
-            return res
-        return (self._arg == other._arg and
-                self._size == other._size)
+    size = property(lambda self: self.__size)
+    arg = property(lambda self: self.__arg)
+
+    def __new__(cls, arg):
+        return Expr.get_object(cls, (arg, arg.size))
 
     def __get_int(self):
         "Return self integer representation"
@@ -393,7 +431,7 @@ class ExprInt(Expr):
         return hash((EXPRINT, self._arg, self._size))
 
     def _exprrepr(self):
-        return "%s(%r)" % (self.__class__.__name__, self._arg)
+        return "%s(0x%X)" % (self.__class__.__name__, self.__get_int())
 
     def __contains__(self, e):
         return self == e
@@ -437,16 +475,14 @@ class ExprId(Expr):
         """
         super(ExprId, self).__init__()
 
-        self._name, self._size = name, size
+        self.__name, self.__size = name, size
+        self.is_term = is_term
 
-    name = property(lambda self: self._name)
+    size = property(lambda self: self.__size)
+    name = property(lambda self: self.__name)
 
-    def __eq__(self, other):
-        res = self.pre_eq(other)
-        if res is not None:
-            return res
-        return (self._name == other._name and
-                self._size == other._size)
+    def __new__(cls, name, size=32):
+        return Expr.get_object(cls, (name, size))
 
     def __str__(self):
         return str(self._name)
@@ -459,10 +495,10 @@ class ExprId(Expr):
 
     def _exprhash(self):
         # TODO XXX: hash size ??
-        return hash((EXPRID, self._name, self._size))
+        return hash((EXPRID, self.__name, self.__size))
 
     def _exprrepr(self):
-        return "%s(%r, %d)" % (self.__class__.__name__, self._name, self._size)
+        return "%s(%r, %d)" % (self.__class__.__name__, self.__name, self.__size)
 
     def __contains__(self, e):
         return self == e
@@ -472,7 +508,7 @@ class ExprId(Expr):
         return self
 
     def copy(self):
-        return ExprId(self._name, self._size)
+        return ExprId(self.__name, self.__size)
 
     def depth(self):
         return 1
@@ -506,20 +542,24 @@ class ExprAff(Expr):
 
         if isinstance(dst, ExprSlice):
             # Complete the source with missing slice parts
-            self._dst = dst.arg
+            self.__dst = dst.arg
             rest = [(ExprSlice(dst.arg, r[0], r[1]), r[0], r[1])
                     for r in dst.slice_rest()]
             all_a = [(src, dst.start, dst.stop)] + rest
             all_a.sort(key=lambda x: x[1])
-            self._src = ExprCompose(all_a)
+            self.__src = ExprCompose(all_a)
 
         else:
-            self._dst, self._src = dst, src
+            self.__dst, self.__src = dst, src
+
+        self.__size = self.dst.size
 
-        self._size = self.dst.size
+    size = property(lambda self: self.__size)
+    dst = property(lambda self: self.__dst)
+    src = property(lambda self: self.__src)
 
-    dst = property(lambda self: self._dst)
-    src = property(lambda self: self._src)
+    def __new__(cls, dst, src):
+        return Expr.get_object(cls, (dst, src))
 
     def __str__(self):
         return "%s = %s" % (str(self._dst), str(self._src))
@@ -542,8 +582,10 @@ class ExprAff(Expr):
     def _exprrepr(self):
         return "%s(%r, %r)" % (self.__class__.__name__, self._dst, self._src)
 
-    def __contains__(self, e):
-        return self == e or self._src.__contains__(e) or self._dst.__contains__(e)
+    def __contains__(self, expr):
+        return (self == expr or
+                self._src.__contains__(expr) or
+                self._dst.__contains__(expr))
 
     # XXX /!\ for hackish expraff to slice
     def get_modified_slice(self):
@@ -605,23 +647,26 @@ class ExprCond(Expr):
 
         super(ExprCond, self).__init__()
 
+        self.__cond, self.__src1, self.__src2 = cond, src1, src2
         assert(src1.size == src2.size)
+        self.__size = self.src1.size
 
-        self._cond, self._src1, self._src2 = cond, src1, src2
-        self._size = self.src1.size
+    size = property(lambda self: self.__size)
+    cond = property(lambda self: self.__cond)
+    src1 = property(lambda self: self.__src1)
+    src2 = property(lambda self: self.__src2)
 
-    cond = property(lambda self: self._cond)
-    src1 = property(lambda self: self._src1)
-    src2 = property(lambda self: self._src2)
+    def __new__(cls, cond, src1, src2):
+        return Expr.get_object(cls, (cond, src1, src2))
 
     def __str__(self):
         return "(%s?(%s,%s))" % (str(self._cond), str(self._src1), str(self._src2))
 
     def get_r(self, mem_read=False, cst_read=False):
-        out_src1 = self._src1.get_r(mem_read, cst_read)
-        out_src2 = self._src2.get_r(mem_read, cst_read)
-        return self._cond.get_r(mem_read,
-                                cst_read).union(out_src1).union(out_src2)
+        out_src1 = self.src1.get_r(mem_read, cst_read)
+        out_src2 = self.src2.get_r(mem_read, cst_read)
+        return self.cond.get_r(mem_read,
+                               cst_read).union(out_src1).union(out_src2)
 
     def get_w(self):
         return set()
@@ -636,9 +681,9 @@ class ExprCond(Expr):
 
     def __contains__(self, e):
         return (self == e or
-                self._cond.__contains__(e) or
-                self._src1.__contains__(e) or
-                self._src2.__contains__(e))
+                self.cond.__contains__(e) or
+                self.src1.__contains__(e) or
+                self.src2.__contains__(e))
 
     @visit_chk
     def visit(self, cb, tv=None):
@@ -691,12 +736,16 @@ class ExprMem(Expr):
             raise ValueError(
                 'ExprMem: arg must be an Expr (not %s)' % type(arg))
 
-        self._arg, self._size = arg, size
+        self.__arg, self.__size = arg, size
+
+    size = property(lambda self: self.__size)
+    arg = property(lambda self: self.__arg)
 
-    arg = property(lambda self: self._arg)
+    def __new__(cls, arg, size=32):
+        return Expr.get_object(cls, (arg, size))
 
     def __str__(self):
-        return "@%d[%s]" % (self._size, str(self._arg))
+        return "@%d[%s]" % (self.size, str(self.arg))
 
     def get_r(self, mem_read=False, cst_read=False):
         if mem_read:
@@ -714,19 +763,19 @@ class ExprMem(Expr):
         return "%s(%r, %r)" % (self.__class__.__name__,
                                self._arg, self._size)
 
-    def __contains__(self, e):
-        return self == e or self._arg.__contains__(e)
+    def __contains__(self, expr):
+        return self == expr or self._arg.__contains__(expr)
 
     @visit_chk
     def visit(self, cb, tv=None):
         arg = self._arg.visit(cb, tv)
         if arg == self._arg:
             return self
-        return ExprMem(arg, self._size)
+        return ExprMem(arg, self.size)
 
     def copy(self):
-        arg = self._arg.copy()
-        return ExprMem(arg, size=self._size)
+        arg = self.arg.copy()
+        return ExprMem(arg, size=self.size)
 
     def is_op_segm(self):
         return isinstance(self._arg, ExprOp) and self._arg.op == 'segm'
@@ -772,43 +821,43 @@ class ExprOp(Expr):
         if not isinstance(op, str):
             raise ValueError("ExprOp: 'op' argument must be a string")
 
-        self._op, self._args = op, tuple(args)
+        self.__op, self._args = op, tuple(args)
 
         # Set size for special cases
-        if self._op in [
+        if self.__op in [
                 '==', 'parity', 'fcom_c0', 'fcom_c1', 'fcom_c2', 'fcom_c3',
                 'fxam_c0', 'fxam_c1', 'fxam_c2', 'fxam_c3',
                 "access_segment_ok", "load_segment_limit_ok", "bcdadd_cf",
                 "ucomiss_zf", "ucomiss_pf", "ucomiss_cf"]:
             sz = 1
-        elif self._op in [TOK_INF, TOK_INF_SIGNED,
-                          TOK_INF_UNSIGNED, TOK_INF_EQUAL,
-                          TOK_INF_EQUAL_SIGNED, TOK_INF_EQUAL_UNSIGNED,
-                          TOK_EQUAL, TOK_POS,
-                          TOK_POS_STRICT,
-                          ]:
+        elif self.__op in [TOK_INF, TOK_INF_SIGNED,
+                           TOK_INF_UNSIGNED, TOK_INF_EQUAL,
+                           TOK_INF_EQUAL_SIGNED, TOK_INF_EQUAL_UNSIGNED,
+                           TOK_EQUAL, TOK_POS,
+                           TOK_POS_STRICT,
+                           ]:
             sz = 1
-        elif self._op in ['mem_16_to_double', 'mem_32_to_double',
-                          'mem_64_to_double', 'mem_80_to_double',
-                          'int_16_to_double', 'int_32_to_double',
-                          'int_64_to_double', 'int_80_to_double']:
+        elif self.__op in ['mem_16_to_double', 'mem_32_to_double',
+                           'mem_64_to_double', 'mem_80_to_double',
+                           'int_16_to_double', 'int_32_to_double',
+                           'int_64_to_double', 'int_80_to_double']:
             sz = 64
-        elif self._op in ['double_to_mem_16', 'double_to_int_16',
-                          'float_trunc_to_int_16', 'double_trunc_to_int_16']:
+        elif self.__op in ['double_to_mem_16', 'double_to_int_16',
+                           'float_trunc_to_int_16', 'double_trunc_to_int_16']:
             sz = 16
-        elif self._op in ['double_to_mem_32', 'double_to_int_32',
-                          'float_trunc_to_int_32', 'double_trunc_to_int_32',
-                          'double_to_float']:
+        elif self.__op in ['double_to_mem_32', 'double_to_int_32',
+                           'float_trunc_to_int_32', 'double_trunc_to_int_32',
+                           'double_to_float']:
             sz = 32
-        elif self._op in ['double_to_mem_64', 'double_to_int_64',
-                          'float_trunc_to_int_64', 'double_trunc_to_int_64',
-                          'float_to_double']:
+        elif self.__op in ['double_to_mem_64', 'double_to_int_64',
+                           'float_trunc_to_int_64', 'double_trunc_to_int_64',
+                           'float_to_double']:
             sz = 64
-        elif self._op in ['double_to_mem_80', 'double_to_int_80',
-                          'float_trunc_to_int_80',
-                          'double_trunc_to_int_80']:
+        elif self.__op in ['double_to_mem_80', 'double_to_int_80',
+                           'float_trunc_to_int_80',
+                           'double_trunc_to_int_80']:
             sz = 80
-        elif self._op in ['segm']:
+        elif self.__op in ['segm']:
             sz = self._args[1].size
         else:
             if None in sizes:
@@ -817,10 +866,14 @@ class ExprOp(Expr):
                 # All arguments have the same size
                 sz = list(sizes)[0]
 
-        self._size = sz
+        self.__size = sz
+
+    size = property(lambda self: self.__size)
+    op = property(lambda self: self.__op)
+    args = property(lambda self: self.__args)
 
-    op = property(lambda self: self._op)
-    args = property(lambda self: self._args)
+    def __new__(cls, op, *args):
+        return Expr.get_object(cls, (op, args))
 
     def __str__(self):
         if self.is_associative():
@@ -840,7 +893,7 @@ class ExprOp(Expr):
 
     def get_r(self, mem_read=False, cst_read=False):
         return reduce(lambda elements, arg:
-                      elements.union(arg.get_r(mem_read, cst_read)), self._args, set())
+                      elements.union(arg.get_r(mem_read, cst_read)), self.__args, set())
 
     def get_w(self):
         raise ValueError('op cannot be written!', self)
@@ -903,13 +956,16 @@ class ExprSlice(Expr):
         super(ExprSlice, self).__init__()
 
         assert(start < stop)
+        self.__arg, self.__start, self.__stop = arg, start, stop
+        self.__size = self.__stop - self.__start
 
-        self._arg, self._start, self._stop = arg, start, stop
-        self._size = self._stop - self._start
+    size = property(lambda self: self.__size)
+    arg = property(lambda self: self.__arg)
+    start = property(lambda self: self.__start)
+    stop = property(lambda self: self.__stop)
 
-    arg = property(lambda self: self._arg)
-    start = property(lambda self: self._start)
-    stop = property(lambda self: self._stop)
+    def __new__(cls, arg, start, stop):
+        return Expr.get_object(cls, (arg, start, stop))
 
     def __str__(self):
         return "%s[%d:%d]" % (str(self._arg), self._start, self._stop)
@@ -927,10 +983,10 @@ class ExprSlice(Expr):
         return "%s(%r, %d, %d)" % (self.__class__.__name__, self._arg,
                                    self._start, self._stop)
 
-    def __contains__(self, e):
-        if self == e:
+    def __contains__(self, expr):
+        if self == expr:
             return True
-        return self._arg.__contains__(e)
+        return self.__arg.__contains__(expr)
 
     @visit_chk
     def visit(self, cb, tv=None):
@@ -1009,31 +1065,35 @@ class ExprCompose(Expr):
         for e, a, b in args:
             assert(a >= 0 and b >= 0)
             o.append(tuple([e, a, b]))
-        self._args = tuple(o)
+        self.__args = tuple(o)
+
+        self.__size = self.__args[-1][2]
 
-        self._size = self._args[-1][2]
+    size = property(lambda self: self.__size)
+    args = property(lambda self: self.__args)
 
-    args = property(lambda self: self._args)
+    def __new__(cls, args):
+        return Expr.get_object(cls, tuple(args))
 
     def __str__(self):
         return '{' + ', '.join(['%s,%d,%d' %
-                                (str(arg[0]), arg[1], arg[2]) for arg in self._args]) + '}'
+                                (str(arg[0]), arg[1], arg[2]) for arg in self.__args]) + '}'
 
     def get_r(self, mem_read=False, cst_read=False):
         return reduce(lambda elements, arg:
-                      elements.union(arg[0].get_r(mem_read, cst_read)), self._args, set())
+                      elements.union(arg[0].get_r(mem_read, cst_read)), self.__args, set())
 
     def get_w(self):
         return reduce(lambda elements, arg:
-                      elements.union(arg[0].get_w()), self._args, set())
+                      elements.union(arg[0].get_w()), self.__args, set())
 
     def _exprhash(self):
         h_args = [EXPRCOMPOSE] + [(hash(arg[0]), arg[1], arg[2])
-                                  for arg in self._args]
+                                  for arg in self.__args]
         return hash(tuple(h_args))
 
     def _exprrepr(self):
-        return "%s(%r)" % (self.__class__.__name__, self._args)
+        return "%s(%r)" % (self.__class__.__name__, self.__args)
 
     def __contains__(self, e):
         if self == e:
diff --git a/miasm2/expression/expression_helper.py b/miasm2/expression/expression_helper.py
index 0c661c2a..3b85d720 100644
--- a/miasm2/expression/expression_helper.py
+++ b/miasm2/expression/expression_helper.py
@@ -86,7 +86,7 @@ def merge_sliceto_slice(args):
             sorted_s.pop()
             out[1] = s_start
         out[0] = m2_expr.ExprInt(int(out[0]), size)
-        final_sources.append((start, out))
+        final_sources.append((start, tuple(out)))
 
     final_sources_int = final_sources
     # check if same sources have corresponding start/stop