about summary refs log tree commit diff stats
path: root/miasm/expression
diff options
context:
space:
mode:
authorserpilliere <devnull@localhost>2012-06-12 13:08:34 +0200
committerserpilliere <devnull@localhost>2012-06-12 13:08:34 +0200
commitf29e0cfd42e6207929205e26dc6afebb686ca339 (patch)
tree67af001655f4ef7105e95582001f6d22b95a47fa /miasm/expression
parent7349000a9649cbce710a3e5df7a779e56c0e7ae4 (diff)
downloadfocaccia-miasm-f29e0cfd42e6207929205e26dc6afebb686ca339.tar.gz
focaccia-miasm-f29e0cfd42e6207929205e26dc6afebb686ca339.zip
expression: replace reload_expr with replace_expr using visitor
Diffstat (limited to 'miasm/expression')
-rw-r--r--miasm/expression/expression.py80
-rw-r--r--miasm/expression/expression_eval_abstract.py2
-rw-r--r--miasm/expression/expression_helper.py4
3 files changed, 26 insertions, 60 deletions
diff --git a/miasm/expression/expression.py b/miasm/expression/expression.py
index ad95e45b..a705bee6 100644
--- a/miasm/expression/expression.py
+++ b/miasm/expression/expression.py
@@ -75,14 +75,14 @@ def visit_chk(visitor):
         e_new = visitor(e, cb)
         e_new2 = cb(e_new)
         if e_new2 == e:
-            return e
+            return e_new2
         if e_new2 == e_new:
-            return e_new
+            return e_new2
         while True:
             #print 'NEW', e, e_new
             e = cb(e_new2)
             if e_new2 == e:
-                return e
+                return e_new2
             e_new2 = e
     return wrapped
 
@@ -138,6 +138,20 @@ class Expr:
     def __invert__(self):
         s = self.get_size()
         return ExprOp('^', self, ExprInt(size2type[s](my_size_mask[s])))
+    def copy(self):
+        """
+        deep copy of the expression
+        """
+        return self.visit(lambda x:x)
+    def replace_expr(self, dct = {}):
+        """
+        find and replace sub expression using dct
+        """
+        def my_replace(e, dct):
+            if e in dct:
+                return dct[e]
+            return e
+        return self.visit(lambda e:my_replace(e, dct))
 
 class ExprTop(Expr):
     def __init__(self, e=None):
@@ -152,8 +166,6 @@ class ExprTop(Expr):
         raise ValueError("get_r on TOP")
     def get_size(self):
         raise ValueError("get_size on TOP")
-    def reload_expr(self, g = {}):
-        return ExprTop(self.e)
     def __eq__(self, a):
         return isinstance(a, ExprTop)
     def __hash__(self):
@@ -178,10 +190,6 @@ class ExprInt(Expr):
         return set()
     def get_size(self):
         return 8*self.arg.nbytes
-    def reload_expr(self, g = {}):
-        if self in g:
-            return g[self]
-        return ExprInt(self.arg)
     def __contains__(self, e):
         return self == e
     def __eq__(self, a):
@@ -198,7 +206,7 @@ class ExprInt(Expr):
         return self
     @visit_chk
     def visit(self, cb):
-        return self
+        return ExprInt(self.arg)
 
 class ExprId(Expr):
     def __init__(self, name, size = 32, is_term = False):
@@ -212,11 +220,6 @@ class ExprId(Expr):
         return set([self])
     def get_size(self):
         return self.size
-    def reload_expr(self, g = {}):
-        if self in g:
-            return g[self]
-        else:
-            return ExprId(self.name, self.size)
     def __contains__(self, e):
         return self == e
     def __eq__(self, a):
@@ -235,7 +238,7 @@ class ExprId(Expr):
         return self
     @visit_chk
     def visit(self, cb):
-        return self
+        return ExprId(self.name, self.size)
 
 memreg = ExprId('MEM')
 
@@ -264,12 +267,6 @@ class ExprAff(Expr):
     #return dst size? XXX
     def get_size(self):
         return self.dst.get_size()
-    def reload_expr(self, g = {}):
-        if self in g:
-            return g[self]
-        dst = self.dst.reload_expr(g)
-        src = self.src.reload_expr(g)
-        return ExprAff(dst, src )
     def __contains__(self, e):
         return self == e or self.src.__contains__(e) or self.dst.__contains__(e)
     def __eq__(self, a):
@@ -309,13 +306,6 @@ class ExprCond(Expr):
     #return src1 size? XXX
     def get_size(self):
         return self.src1.get_size()
-    def reload_expr(self, g = {}):
-        if self in g:
-            return g[self]
-        cond = self.cond.reload_expr(g)
-        src1 = self.src1.reload_expr(g)
-        src2 = self.src2.reload_expr(g)
-        return ExprCond(cond, src1, src2)
     def __contains__(self, e):
         return self == e or self.cond.__contains__(e) or self.src1.__contains__(e) or self.src2.__contains__(e)
     def __eq__(self, a):
@@ -352,14 +342,6 @@ class ExprMem(Expr):
         return set([self]) #[memreg]
     def get_size(self):
         return self.size
-    def reload_expr(self, g = {}):
-        if self in g:
-            return g[self]
-        arg = self.arg.reload_expr(g)
-        segm = self.segm
-        if isinstance(segm, Expr):
-            segm = self.segm.reload_expr(g)
-        return ExprMem(arg, self.size, segm)
     def __contains__(self, e):
         return self == e or self.arg.__contains__(e)
     def __eq__(self, a):
@@ -377,7 +359,10 @@ class ExprMem(Expr):
         return ExprMem(self.arg.canonize(), size = self.size)
     @visit_chk
     def visit(self, cb):
-        return ExprMem(self.arg.visit(cb), self.size, self.segm)
+        segm = self.segm
+        if isinstance(segm, Expr):
+            segm = self.segm.visit(cb)
+        return ExprMem(self.arg.visit(cb), self.size, segm)
 
 class ExprOp(Expr):
     def __init__(self, op, *args):
@@ -400,13 +385,6 @@ class ExprOp(Expr):
             if not a:
                 a = self.args[1].get_size()
         return a
-    def reload_expr(self, g = {}):
-        if self in g:
-            return g[self]
-        args = []
-        for a in self.args:
-            args.append(a.reload_expr(g))
-        return ExprOp(self.op, *args )
     def __contains__(self, e):
         if self == e:
             return True
@@ -578,11 +556,6 @@ class ExprSlice(Expr):
         return self.arg.get_w()
     def get_size(self):
         return self.stop-self.start
-    def reload_expr(self, g = {}):
-        if self in g:
-            return g[self]
-        arg = self.arg.reload_expr(g)
-        return ExprSlice(arg, self.start, self.stop )
     def __contains__(self, e):
         if self == e:
             return True
@@ -621,13 +594,6 @@ class ExprCompose(Expr):
         return reduce(lambda x,y:x.union(y[0].get_r(mem_read)), self.args, set())
     def get_size(self):
         return max([x[2] for x in self.args]) - min([x[1] for x in self.args])
-    def reload_expr(self, g = {}):
-        if self in g:
-            return g[self]
-        args = []
-        for a in self.args:
-            args.append((a[0].reload_expr(g), a[1], a[2]))
-        return ExprCompose(args )
     def __contains__(self, e):
         if self == e:
             return True
diff --git a/miasm/expression/expression_eval_abstract.py b/miasm/expression/expression_eval_abstract.py
index d8a37b40..39a41114 100644
--- a/miasm/expression/expression_eval_abstract.py
+++ b/miasm/expression/expression_eval_abstract.py
@@ -182,7 +182,7 @@ class eval_abs:
             print repr(g[str(xx)]), g[str(xx)]
 
             if isinstance(m.pool[x], Expr):
-                new_pool[g[str(xx)]] = m.pool[x].reload_expr(g)
+                new_pool[g[str(xx)]] = m.pool[x].replace_expr(g)
             else:
                 new_pool[g[str(xx)]] = m.pool[x]
 
diff --git a/miasm/expression/expression_helper.py b/miasm/expression/expression_helper.py
index 37a2a273..6a04e66f 100644
--- a/miasm/expression/expression_helper.py
+++ b/miasm/expression/expression_helper.py
@@ -72,7 +72,7 @@ def merge_sliceto_slice(args):
 
     while sorted_s:
         start, v = sorted_s.pop()
-        out = v[0].reload_expr(), v[1], v[2]
+        out = v[0].copy(), v[1], v[2]
         while sorted_s:
             if sorted_s[-1][1][2] != start:
                 break
@@ -101,7 +101,7 @@ def merge_sliceto_slice(args):
         sorted_s.sort()
         while sorted_s:
             start, v = sorted_s.pop()
-            out = v[0].reload_expr(), v[1], v[2]
+            out = v[0].copy(), v[1], v[2]
             while sorted_s:
                 if sorted_s[-1][1][2] != start:
                     break