about summary refs log tree commit diff stats
diff options
context:
space:
mode:
authorserpilliere <devnull@localhost>2012-06-12 13:08:34 +0200
committerserpilliere <devnull@localhost>2012-06-12 13:08:34 +0200
commitf29e0cfd42e6207929205e26dc6afebb686ca339 (patch)
tree67af001655f4ef7105e95582001f6d22b95a47fa
parent7349000a9649cbce710a3e5df7a779e56c0e7ae4 (diff)
downloadmiasm-f29e0cfd42e6207929205e26dc6afebb686ca339.tar.gz
miasm-f29e0cfd42e6207929205e26dc6afebb686ca339.zip
expression: replace reload_expr with replace_expr using visitor
-rw-r--r--example/expression/find_conditions.py6
-rw-r--r--miasm/arch/ia32_sem.py2
-rw-r--r--miasm/expression/expression.py80
-rw-r--r--miasm/expression/expression_eval_abstract.py2
-rw-r--r--miasm/expression/expression_helper.py4
-rw-r--r--miasm/tools/to_c_helper.py6
6 files changed, 33 insertions, 67 deletions
diff --git a/example/expression/find_conditions.py b/example/expression/find_conditions.py
index 9fa5def8..3fe64f2c 100644
--- a/example/expression/find_conditions.py
+++ b/example/expression/find_conditions.py
@@ -75,8 +75,8 @@ def emul_mn(states_todo, states_done, all_blocs, job_done):
             c2 = {ad.cond: ExprInt(uint32(1))}
             p1[ad.cond] = ExprInt(uint32(0))
             p2[ad.cond] = ExprInt(uint32(1))
-            ad1 = machine.eval_expr(ad.reload_expr(c1), {})
-            ad2 = machine.eval_expr(ad.reload_expr(c2), {})
+            ad1 = machine.eval_expr(ad.replace_expr(c1), {})
+            ad2 = machine.eval_expr(ad.replace_expr(c2), {})
             if not (isinstance(ad1, ExprInt) and isinstance(ad2, ExprInt)):
                 print str(ad1), str(ad2)
                 raise ValueError("zarb condition")
@@ -120,4 +120,4 @@ for ad, pool in states_done:
 
 machine = x86_machine()
 for k, v in list(all_info):
-    print machine.eval_expr(k.reload_expr({}), {}), "=", v
+    print machine.eval_expr(k.replace_expr({}), {}), "=", v
diff --git a/miasm/arch/ia32_sem.py b/miasm/arch/ia32_sem.py
index df94e071..6870639d 100644
--- a/miasm/arch/ia32_sem.py
+++ b/miasm/arch/ia32_sem.py
@@ -948,7 +948,7 @@ def pop(info, a):
     e.append(ExprAff(esp, new_esp))
     #XXX FIX XXX for pop [esp]
     if isinstance(a, ExprMem):
-        a =a.reload_expr({esp:new_esp})
+        a =a.replace_expr({esp:new_esp})
     e.append(ExprAff(a, ExprMem(esp, s)))
     return e
 
diff --git a/miasm/expression/expression.py b/miasm/expression/expression.py
index ad95e45b..a705bee6 100644
--- a/miasm/expression/expression.py
+++ b/miasm/expression/expression.py
@@ -75,14 +75,14 @@ def visit_chk(visitor):
         e_new = visitor(e, cb)
         e_new2 = cb(e_new)
         if e_new2 == e:
-            return e
+            return e_new2
         if e_new2 == e_new:
-            return e_new
+            return e_new2
         while True:
             #print 'NEW', e, e_new
             e = cb(e_new2)
             if e_new2 == e:
-                return e
+                return e_new2
             e_new2 = e
     return wrapped
 
@@ -138,6 +138,20 @@ class Expr:
     def __invert__(self):
         s = self.get_size()
         return ExprOp('^', self, ExprInt(size2type[s](my_size_mask[s])))
+    def copy(self):
+        """
+        deep copy of the expression
+        """
+        return self.visit(lambda x:x)
+    def replace_expr(self, dct = {}):
+        """
+        find and replace sub expression using dct
+        """
+        def my_replace(e, dct):
+            if e in dct:
+                return dct[e]
+            return e
+        return self.visit(lambda e:my_replace(e, dct))
 
 class ExprTop(Expr):
     def __init__(self, e=None):
@@ -152,8 +166,6 @@ class ExprTop(Expr):
         raise ValueError("get_r on TOP")
     def get_size(self):
         raise ValueError("get_size on TOP")
-    def reload_expr(self, g = {}):
-        return ExprTop(self.e)
     def __eq__(self, a):
         return isinstance(a, ExprTop)
     def __hash__(self):
@@ -178,10 +190,6 @@ class ExprInt(Expr):
         return set()
     def get_size(self):
         return 8*self.arg.nbytes
-    def reload_expr(self, g = {}):
-        if self in g:
-            return g[self]
-        return ExprInt(self.arg)
     def __contains__(self, e):
         return self == e
     def __eq__(self, a):
@@ -198,7 +206,7 @@ class ExprInt(Expr):
         return self
     @visit_chk
     def visit(self, cb):
-        return self
+        return ExprInt(self.arg)
 
 class ExprId(Expr):
     def __init__(self, name, size = 32, is_term = False):
@@ -212,11 +220,6 @@ class ExprId(Expr):
         return set([self])
     def get_size(self):
         return self.size
-    def reload_expr(self, g = {}):
-        if self in g:
-            return g[self]
-        else:
-            return ExprId(self.name, self.size)
     def __contains__(self, e):
         return self == e
     def __eq__(self, a):
@@ -235,7 +238,7 @@ class ExprId(Expr):
         return self
     @visit_chk
     def visit(self, cb):
-        return self
+        return ExprId(self.name, self.size)
 
 memreg = ExprId('MEM')
 
@@ -264,12 +267,6 @@ class ExprAff(Expr):
     #return dst size? XXX
     def get_size(self):
         return self.dst.get_size()
-    def reload_expr(self, g = {}):
-        if self in g:
-            return g[self]
-        dst = self.dst.reload_expr(g)
-        src = self.src.reload_expr(g)
-        return ExprAff(dst, src )
     def __contains__(self, e):
         return self == e or self.src.__contains__(e) or self.dst.__contains__(e)
     def __eq__(self, a):
@@ -309,13 +306,6 @@ class ExprCond(Expr):
     #return src1 size? XXX
     def get_size(self):
         return self.src1.get_size()
-    def reload_expr(self, g = {}):
-        if self in g:
-            return g[self]
-        cond = self.cond.reload_expr(g)
-        src1 = self.src1.reload_expr(g)
-        src2 = self.src2.reload_expr(g)
-        return ExprCond(cond, src1, src2)
     def __contains__(self, e):
         return self == e or self.cond.__contains__(e) or self.src1.__contains__(e) or self.src2.__contains__(e)
     def __eq__(self, a):
@@ -352,14 +342,6 @@ class ExprMem(Expr):
         return set([self]) #[memreg]
     def get_size(self):
         return self.size
-    def reload_expr(self, g = {}):
-        if self in g:
-            return g[self]
-        arg = self.arg.reload_expr(g)
-        segm = self.segm
-        if isinstance(segm, Expr):
-            segm = self.segm.reload_expr(g)
-        return ExprMem(arg, self.size, segm)
     def __contains__(self, e):
         return self == e or self.arg.__contains__(e)
     def __eq__(self, a):
@@ -377,7 +359,10 @@ class ExprMem(Expr):
         return ExprMem(self.arg.canonize(), size = self.size)
     @visit_chk
     def visit(self, cb):
-        return ExprMem(self.arg.visit(cb), self.size, self.segm)
+        segm = self.segm
+        if isinstance(segm, Expr):
+            segm = self.segm.visit(cb)
+        return ExprMem(self.arg.visit(cb), self.size, segm)
 
 class ExprOp(Expr):
     def __init__(self, op, *args):
@@ -400,13 +385,6 @@ class ExprOp(Expr):
             if not a:
                 a = self.args[1].get_size()
         return a
-    def reload_expr(self, g = {}):
-        if self in g:
-            return g[self]
-        args = []
-        for a in self.args:
-            args.append(a.reload_expr(g))
-        return ExprOp(self.op, *args )
     def __contains__(self, e):
         if self == e:
             return True
@@ -578,11 +556,6 @@ class ExprSlice(Expr):
         return self.arg.get_w()
     def get_size(self):
         return self.stop-self.start
-    def reload_expr(self, g = {}):
-        if self in g:
-            return g[self]
-        arg = self.arg.reload_expr(g)
-        return ExprSlice(arg, self.start, self.stop )
     def __contains__(self, e):
         if self == e:
             return True
@@ -621,13 +594,6 @@ class ExprCompose(Expr):
         return reduce(lambda x,y:x.union(y[0].get_r(mem_read)), self.args, set())
     def get_size(self):
         return max([x[2] for x in self.args]) - min([x[1] for x in self.args])
-    def reload_expr(self, g = {}):
-        if self in g:
-            return g[self]
-        args = []
-        for a in self.args:
-            args.append((a[0].reload_expr(g), a[1], a[2]))
-        return ExprCompose(args )
     def __contains__(self, e):
         if self == e:
             return True
diff --git a/miasm/expression/expression_eval_abstract.py b/miasm/expression/expression_eval_abstract.py
index d8a37b40..39a41114 100644
--- a/miasm/expression/expression_eval_abstract.py
+++ b/miasm/expression/expression_eval_abstract.py
@@ -182,7 +182,7 @@ class eval_abs:
             print repr(g[str(xx)]), g[str(xx)]
 
             if isinstance(m.pool[x], Expr):
-                new_pool[g[str(xx)]] = m.pool[x].reload_expr(g)
+                new_pool[g[str(xx)]] = m.pool[x].replace_expr(g)
             else:
                 new_pool[g[str(xx)]] = m.pool[x]
 
diff --git a/miasm/expression/expression_helper.py b/miasm/expression/expression_helper.py
index 37a2a273..6a04e66f 100644
--- a/miasm/expression/expression_helper.py
+++ b/miasm/expression/expression_helper.py
@@ -72,7 +72,7 @@ def merge_sliceto_slice(args):
 
     while sorted_s:
         start, v = sorted_s.pop()
-        out = v[0].reload_expr(), v[1], v[2]
+        out = v[0].copy(), v[1], v[2]
         while sorted_s:
             if sorted_s[-1][1][2] != start:
                 break
@@ -101,7 +101,7 @@ def merge_sliceto_slice(args):
         sorted_s.sort()
         while sorted_s:
             start, v = sorted_s.pop()
-            out = v[0].reload_expr(), v[1], v[2]
+            out = v[0].copy(), v[1], v[2]
             while sorted_s:
                 if sorted_s[-1][1][2] != start:
                     break
diff --git a/miasm/tools/to_c_helper.py b/miasm/tools/to_c_helper.py
index 508c57d7..50d79d0b 100644
--- a/miasm/tools/to_c_helper.py
+++ b/miasm/tools/to_c_helper.py
@@ -228,7 +228,7 @@ for x in my_C_id:
     id2Cid[x] = ExprId('vmcpu.'+str(x), x.get_size())
 
 def patch_c_id(e):
-    return e.reload_expr(id2Cid)
+    return e.replace_expr(id2Cid)
 
 
 code_deal_exception_at_instr = r"""
@@ -331,7 +331,7 @@ def Exp2C(exprs, l = None, addr2label = None, gen_exception_code = False):
         if True:#e.dst != eip:
             src, dst = e.src, e.dst
             # reload src using prefetch
-            src = src.reload_expr(src_w_len)
+            src = src.replace_expr(src_w_len)
             str_src = patch_c_id(src).toC()
             str_dst = patch_c_id(dst).toC()
             if isinstance(dst, ExprId):
@@ -935,7 +935,7 @@ if __name__ == '__main__':
         print x
     print '#'*80
 
-    new_e = [x.reload_expr({ExprMem(eax): ExprId('ioio')}) for x in e]
+    new_e = [x.replace_expr({ExprMem(eax): ExprId('ioio')}) for x in e]
     for x in new_e:
         print x
     print '-'*80