about summary refs log tree commit diff stats
path: root/miasm/expression/expression.py
diff options
context:
space:
mode:
authorserpilliere <devnull@localhost>2012-12-19 21:36:24 +0100
committerserpilliere <devnull@localhost>2012-12-19 21:36:24 +0100
commitddddebe228d24e3bae3fa5fcfd63e6c6ef9497e0 (patch)
tree6964441c432d93851c0b0f191c407b591be722ae /miasm/expression/expression.py
parent9813a49e961b674ff0108d0f60c968de17de8a1b (diff)
downloadfocaccia-miasm-ddddebe228d24e3bae3fa5fcfd63e6c6ef9497e0.tar.gz
focaccia-miasm-ddddebe228d24e3bae3fa5fcfd63e6c6ef9497e0.zip
rewrite expression visitor; expr_simp
Diffstat (limited to 'miasm/expression/expression.py')
-rw-r--r--miasm/expression/expression.py81
1 files changed, 61 insertions, 20 deletions
diff --git a/miasm/expression/expression.py b/miasm/expression/expression.py
index df8b4c2c..1029b319 100644
--- a/miasm/expression/expression.py
+++ b/miasm/expression/expression.py
@@ -67,20 +67,10 @@ def get_missing_interval(all_intervals, i_min = 0, i_max = 32):
 
 def visit_chk(visitor):
     def wrapped(e, cb):
-        #print "wrap", e, visitor, cb
+        #print 'visit', e
         e_new = visitor(e, cb)
         e_new2 = cb(e_new)
-        if e_new2 == e:
-            return e_new2
-        if e_new2 == e_new:
-            return e_new2
-        while True:
-            #print 'NEW', e, e_new2
-            #e = cb(e_new2)
-            e = e_new2.visit(cb)
-            if e_new2 == e:
-                return e_new2
-            e_new2 = e
+        return e_new2
     return wrapped
 
 
@@ -211,6 +201,8 @@ class ExprInt(Expr):
         return str(self)
     @visit_chk
     def visit(self, cb):
+        return self
+    def copy(self):
         return ExprInt(self.arg)
 
 class ExprId(Expr):
@@ -241,6 +233,8 @@ class ExprId(Expr):
         return str(self)
     @visit_chk
     def visit(self, cb):
+        return self
+    def copy(self):
         return ExprId(self.name, self.size)
 
 memreg = ExprId('MEM')
@@ -292,7 +286,13 @@ class ExprAff(Expr):
         return modified_s
     @visit_chk
     def visit(self, cb):
-        return ExprAff(self.dst.visit(cb), self.src.visit(cb))
+        dst, src = self.dst.visit(cb), self.src.visit(cb)
+        if dst == self.dst and src == self.src:
+            return self
+        else:
+            return ExprAff(dst, src)
+    def copy(self):
+        return ExprAff(self.dst.copy(), self.src.copy())
 
 class ExprCond(Expr):
     def __init__(self, cond, src1, src2):
@@ -319,7 +319,18 @@ class ExprCond(Expr):
         return "(%s?%s:%s)"%(self.cond.toC(), self.src1.toC(), self.src2.toC())
     @visit_chk
     def visit(self, cb):
-        return ExprCond(self.cond.visit(cb), self.src1.visit(cb), self.src2.visit(cb))
+        cond = self.cond.visit(cb)
+        src1 = self.src1.visit(cb)
+        src2 = self.src2.visit(cb)
+        if cond == self.cond and \
+                src1 == self.src1 and \
+                src2 == self.src2:
+            return self
+        return ExprCond(cond, src1, src2)
+    def copy(self):
+        return ExprCond(self.cond.copy(),
+                        self.src1.copy(),
+                        self.src2.copy())
 
 class ExprMem(Expr):
     def __init__(self, arg, size = 32, segm = None):
@@ -357,7 +368,21 @@ class ExprMem(Expr):
         segm = self.segm
         if isinstance(segm, Expr):
             segm = self.segm.visit(cb)
-        return ExprMem(self.arg.visit(cb), self.size, segm)
+        else:
+            segm = None
+        arg = self.arg.visit(cb)
+        if segm == self.segm and arg == self.arg:
+            return self
+        return ExprMem(arg, self.size, segm)
+    def copy(self):
+        arg = self.arg.copy()
+        if self.segm:
+            segm = self.segm.copy()
+        else:
+            segm = None
+        return ExprMem(arg, size = self.size, segm = segm)
+
+
 op_assoc = ['+', '*', '^', '&', '|']
 class ExprOp(Expr):
     def __init__(self, op, *args):
@@ -543,6 +568,12 @@ class ExprOp(Expr):
     @visit_chk
     def visit(self, cb):
         args = [a.visit(cb) for a in self.args]
+        modified = any([x[0] != x[1] for x in zip(self.args, args)])
+        if modified:
+            return ExprOp(self.op, *args)
+        return self
+    def copy(self):
+        args = [a.copy() for a in self.args]
         return ExprOp(self.op, *args)
 
 class ExprSlice(Expr):
@@ -573,8 +604,12 @@ class ExprSlice(Expr):
                                     (1<<(self.stop-self.start))-1)
     @visit_chk
     def visit(self, cb):
-        return ExprSlice(self.arg.visit(cb), self.start, self.stop)
-
+        arg = self.arg.visit(cb)
+        if arg == self.arg:
+            return self
+        return ExprSlice(arg, self.start, self.stop)
+    def copy(self):
+        return ExprSlice(self.arg.copy(), self.start, self.stop)
 
 class ExprCompose(Expr):
     def __init__(self, args):
@@ -599,10 +634,10 @@ class ExprCompose(Expr):
     def __eq__(self, a):
         if not isinstance(a, ExprCompose):
             return False
-        if not len(self.args) == len(a.args):
+        if len(self.args) != len(a.args):
             return False
-        for i, x in enumerate(self.args):
-            if not x == a.args[i]:
+        for (e1, start1, stop1), (e2, start2, stop2) in zip(self.args, a.args):
+            if e1 != e2 or start1 != start2 or stop1 != stop2:
                 return False
         return True
     def __hash__(self):
@@ -622,6 +657,12 @@ class ExprCompose(Expr):
     @visit_chk
     def visit(self, cb):
         args = [(a[0].visit(cb), a[1], a[2]) for a in self.args]
+        modified = any([x[0] != x[1] for x in zip(self.args, args)])
+        if modified:
+            return ExprCompose(args)
+        return self
+    def copy(self):
+        args = [(a[0].copy(), a[1], a[2]) for a in self.args]
         return ExprCompose(args)
 
 class set_expr: