about summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--example/expression/manip_expression6.py135
-rw-r--r--miasm/expression/expression.py90
-rw-r--r--miasm/expression/expression_helper.py18
3 files changed, 164 insertions, 79 deletions
diff --git a/example/expression/manip_expression6.py b/example/expression/manip_expression6.py
index 180022e1..7e32da4b 100644
--- a/example/expression/manip_expression6.py
+++ b/example/expression/manip_expression6.py
@@ -51,69 +51,88 @@ print u
 print z
 print u == z
 
-to_test = [(ExprInt32(5)+c+a+b-a+ExprInt32(1)-ExprInt32(5)),
-           a+b+c-a-b-c+a,
-           a+a+b+c-(a+(b+c)),
-           c^b^a^c^b,
-           a^ExprInt32(0),
-           (a+b)-b,
-           -(ExprInt32(0)-((a+b)-b)),
-
-           ExprOp('<<<', a, ExprInt32(32)),
-           ExprOp('>>>', a, ExprInt32(32)),
-           ExprOp('>>>', a, ExprInt32(0)),
-           ExprOp('<<', a, ExprInt32(0)),
-
-           ExprOp('<<<', a, ExprOp('<<<', b, c)),
-           ExprOp('<<<', ExprOp('<<<', a, b), c),
-           ExprOp('<<<', ExprOp('>>>', a, b), c),
-           ExprOp('>>>', ExprOp('<<<', a, b), c),
-           ExprOp('>>>', ExprOp('<<<', a, b), b),
-
-
-           ExprOp('>>>', ExprOp('<<<', a, ExprInt32(10)), ExprInt32(2)),
-
-           ExprOp('>>>', ExprOp('<<<', a, ExprInt32(10)), ExprInt32(2)) ^ ExprOp('>>>', ExprOp('<<<', a, ExprInt32(10)), ExprInt32(2)),
-           ExprOp(">>", (a & ExprInt32(0xF)), ExprInt32(0x15)),
-           ExprOp("==", ExprInt32(12), ExprInt32(10)),
-           ExprOp("==", ExprInt32(12), ExprInt32(12)),
-           ExprOp("==", a|ExprInt32(12), ExprInt32(0)),
-           ExprOp("==", a|ExprInt32(12), ExprInt32(14)),
-           ExprOp("parity", ExprInt32(0xf)),
-           ExprOp("parity", ExprInt32(0xe)),
-           ExprInt32(0x4142)[:32],
-           ExprInt32(0x4142)[:8],
-           ExprInt32(0x4142)[8:16],
-           a[:32],
-           a[:8][:8],
-           a[:16][:8],
-           a[8:16][:8],
-           a[8:32][:8],
-           a[:16][8:16],
-           ExprCompose([(a, 0, 32)]),
-           ExprCompose([(a[:16], 0, 16)]),
-           ExprCompose([(a[:16], 0, 16), (a, 16, 32)]),
-           ExprCompose([(a[:16], 0, 16), (a[16:32], 16, 32)]),
-
-           ExprMem(a)[:32],
-           ExprMem(a)[:16],
-
-           ExprCond(ExprInt32(1), a, b),
-           ExprCond(ExprInt32(0), b, a),
-
-           ExprInt32(0x80000000)[31:32],
-           ExprCompose([(ExprInt16(0x1337)[:8], 0, 8),(ExprInt16(0x1337)[8:16], 8, 16)]),
-
-           ExprCompose([(ExprInt32(0x1337beef)[8:16], 8, 16),
+to_test = [(ExprInt32(1)-ExprInt32(1), ExprInt32(0)),
+           ((ExprInt32(5)+c+a+b-a+ExprInt32(1)-ExprInt32(5)),b+c+ExprInt32(1)),
+           (a+b+c-a-b-c+a,a),
+           (a+a+b+c-(a+(b+c)),a),
+           (c^b^a^c^b,a),
+           (a^ExprInt32(0),a),
+           ((a+b)-b,a),
+           (-(ExprInt32(0)-((a+b)-b)),a),
+
+           (ExprOp('<<<', a, ExprInt32(32)),a),
+           (ExprOp('>>>', a, ExprInt32(32)),a),
+           (ExprOp('>>>', a, ExprInt32(0)),a),
+           (ExprOp('<<', a, ExprInt32(0)),a),
+
+           (ExprOp('<<<', a, ExprOp('<<<', b, c)),
+            ExprOp('<<<', a, ExprOp('<<<', b, c))),
+           (ExprOp('<<<', ExprOp('<<<', a, b), c),
+            ExprOp('<<<', ExprOp('<<<', a, b), c)),
+           (ExprOp('<<<', ExprOp('>>>', a, b), c),
+            ExprOp('<<<', ExprOp('>>>', a, b), c)),
+           (ExprOp('>>>', ExprOp('<<<', a, b), c),
+            ExprOp('>>>', ExprOp('<<<', a, b), c)),
+           (ExprOp('>>>', ExprOp('<<<', a, b), b),
+            ExprOp('>>>', ExprOp('<<<', a, b), b)),
+
+
+           (ExprOp('>>>', ExprOp('<<<', a, ExprInt32(10)), ExprInt32(2)),
+            ExprOp('<<<', a, ExprInt32(8))),
+
+           (ExprOp('>>>', ExprOp('<<<', a, ExprInt32(10)), ExprInt32(2)) ^ ExprOp('>>>', ExprOp('<<<', a, ExprInt32(10)), ExprInt32(2)),
+            ExprInt32(0)),
+           (ExprOp(">>", (a & ExprInt32(0xF)), ExprInt32(0x15)),
+            ExprInt32(0)),
+           (ExprOp("==", ExprInt32(12), ExprInt32(10)), ExprInt32(0)),
+           (ExprOp("==", ExprInt32(12), ExprInt32(12)), ExprInt32(1)),
+           (ExprOp("==", a|ExprInt32(12), ExprInt32(0)),ExprInt32(0)),
+           (ExprOp("==", a|ExprInt32(12), ExprInt32(14)),
+            ExprOp("==", a|ExprInt32(12), ExprInt32(14))),
+           (ExprOp("parity", ExprInt32(0xf)), ExprInt32(1)),
+           (ExprOp("parity", ExprInt32(0xe)), ExprInt32(0)),
+           (ExprInt32(0x4142)[:32],ExprInt32(0x4142)),
+           (ExprInt32(0x4142)[:8],ExprInt8(0x42)),
+           (ExprInt32(0x4142)[8:16],ExprInt8(0x41)),
+           (a[:32], a),
+           (a[:8][:8],a[:8]),
+           (a[:16][:8],a[:8]),
+           (a[8:16][:8],a[8:16]),
+           (a[8:32][:8],a[8:16]),
+           (a[:16][8:16],a[8:16]),
+           (ExprCompose([(a, 0, 32)]),a),
+           (ExprCompose([(a[:16], 0, 16)]), a[:16]),
+           (ExprCompose([(a[:16], 0, 16), (a, 16, 32)]),
+            ExprCompose([(a[:16], 0, 16), (a, 16, 32)]),),
+           (ExprCompose([(a[:16], 0, 16), (a[16:32], 16, 32)]), a),
+
+           (ExprMem(a)[:32], ExprMem(a)),
+           (ExprMem(a)[:16], ExprMem(a, size=16)),
+
+           (ExprCond(ExprInt32(1), a, b), a),
+           (ExprCond(ExprInt32(0), b, a), a),
+
+           (ExprInt32(0x80000000)[31:32], ExprInt32(1)),
+           (ExprCompose([(ExprInt16(0x1337)[:8], 0, 8),(ExprInt16(0x1337)[8:16], 8, 16)]),
+            ExprInt16(0x1337)),
+
+           (ExprCompose([(ExprInt32(0x1337beef)[8:16], 8, 16),
                         (ExprInt32(0x1337beef)[:8], 0, 8),
                         (ExprInt32(0x1337beef)[16:32], 16, 32)]),
+            ExprInt32(0x1337BEEF)),
 
 
            ]
 
 
-for e in to_test:
+for e, e_check in to_test[:]:
+    #
     print "#"*80
-    print e
-    print expr_simp(e)
-
+    e_check = expr_simp(e_check)
+    print "#"*80
+    print str(e), str(e_check)
+    e_new = expr_simp(e)
+    print "orig", str(e), "new", str(e_new), "check", str(e_check)
+    rez = e_new == e_check
+    if not rez:
+        fdsfds
diff --git a/miasm/expression/expression.py b/miasm/expression/expression.py
index f1a1f78c..6018be4a 100644
--- a/miasm/expression/expression.py
+++ b/miasm/expression/expression.py
@@ -67,20 +67,10 @@ def get_missing_interval(all_intervals, i_min = 0, i_max = 32):
 
 def visit_chk(visitor):
     def wrapped(e, cb):
-        #print "wrap", e, visitor, cb
+        #print 'visit', e
         e_new = visitor(e, cb)
         e_new2 = cb(e_new)
-        if e_new2 == e:
-            return e_new2
-        if e_new2 == e_new:
-            return e_new2
-        while True:
-            #print 'NEW', e, e_new2
-            #e = cb(e_new2)
-            e = e_new2.visit(cb)
-            if e_new2 == e:
-                return e_new2
-            e_new2 = e
+        return e_new2
     return wrapped
 
 
@@ -211,6 +201,8 @@ class ExprInt(Expr):
         return str(self)
     @visit_chk
     def visit(self, cb):
+        return self
+    def copy(self):
         return ExprInt(self.arg)
 
 class ExprId(Expr):
@@ -241,6 +233,8 @@ class ExprId(Expr):
         return str(self)
     @visit_chk
     def visit(self, cb):
+        return self
+    def copy(self):
         return ExprId(self.name, self.size)
 
 memreg = ExprId('MEM')
@@ -292,7 +286,13 @@ class ExprAff(Expr):
         return modified_s
     @visit_chk
     def visit(self, cb):
-        return ExprAff(self.dst.visit(cb), self.src.visit(cb))
+        dst, src = self.dst.visit(cb), self.src.visit(cb)
+        if dst == self.dst and src == self.src:
+            return self
+        else:
+            return ExprAff(dst, src)
+    def copy(self):
+        return ExprAff(self.dst.copy(), self.src.copy())
 
 class ExprCond(Expr):
     def __init__(self, cond, src1, src2):
@@ -319,7 +319,18 @@ class ExprCond(Expr):
         return "(%s?%s:%s)"%(self.cond.toC(), self.src1.toC(), self.src2.toC())
     @visit_chk
     def visit(self, cb):
-        return ExprCond(self.cond.visit(cb), self.src1.visit(cb), self.src2.visit(cb))
+        cond = self.cond.visit(cb)
+        src1 = self.src1.visit(cb)
+        src2 = self.src2.visit(cb)
+        if cond == self.cond and \
+                src1 == self.src1 and \
+                src2 == self.src2:
+            return self
+        return ExprCond(cond, src1, src2)
+    def copy(self):
+        return ExprCond(self.cond.copy(),
+                        self.src1.copy(),
+                        self.src2.copy())
 
 class ExprMem(Expr):
     def __init__(self, arg, size = 32, segm = None):
@@ -357,7 +368,21 @@ class ExprMem(Expr):
         segm = self.segm
         if isinstance(segm, Expr):
             segm = self.segm.visit(cb)
-        return ExprMem(self.arg.visit(cb), self.size, segm)
+        else:
+            segm = None
+        arg = self.arg.visit(cb)
+        if segm == self.segm and arg == self.arg:
+            return self
+        return ExprMem(arg, self.size, segm)
+    def copy(self):
+        arg = self.arg.copy()
+        if self.segm:
+            segm = self.segm.copy()
+        else:
+            segm = None
+        return ExprMem(arg, size = self.size, segm = segm)
+
+
 op_assoc = ['+', '*', '^', '&', '|']
 class ExprOp(Expr):
     def __init__(self, op, *args):
@@ -543,6 +568,12 @@ class ExprOp(Expr):
     @visit_chk
     def visit(self, cb):
         args = [a.visit(cb) for a in self.args]
+        modified = any([x[0] != x[1] for x in zip(self.args, args)])
+        if modified:
+            return ExprOp(self.op, *args)
+        return self
+    def copy(self):
+        args = [a.copy() for a in self.args]
         return ExprOp(self.op, *args)
 
 class ExprSlice(Expr):
@@ -573,8 +604,12 @@ class ExprSlice(Expr):
                                     (1<<(self.stop-self.start))-1)
     @visit_chk
     def visit(self, cb):
-        return ExprSlice(self.arg.visit(cb), self.start, self.stop)
-
+        arg = self.arg.visit(cb)
+        if arg == self.arg:
+            return self
+        return ExprSlice(arg, self.start, self.stop)
+    def copy(self):
+        return ExprSlice(self.arg.copy(), self.start, self.stop)
 
 class ExprCompose(Expr):
     def __init__(self, args):
@@ -599,10 +634,10 @@ class ExprCompose(Expr):
     def __eq__(self, a):
         if not isinstance(a, ExprCompose):
             return False
-        if not len(self.args) == len(a.args):
+        if len(self.args) != len(a.args):
             return False
-        for i, x in enumerate(self.args):
-            if not x == a.args[i]:
+        for (e1, start1, stop1), (e2, start2, stop2) in zip(self.args, a.args):
+            if e1 != e2 or start1 != start2 or stop1 != stop2:
                 return False
         return True
     def __hash__(self):
@@ -622,6 +657,12 @@ class ExprCompose(Expr):
     @visit_chk
     def visit(self, cb):
         args = [(a[0].visit(cb), a[1], a[2]) for a in self.args]
+        modified = any([x[0] != x[1] for x in zip(self.args, args)])
+        if modified:
+            return ExprCompose(args)
+        return self
+    def copy(self):
+        args = [(a[0].copy(), a[1], a[2]) for a in self.args]
         return ExprCompose(args)
 
 class set_expr:
@@ -776,6 +817,15 @@ def ExprInt_from(e, i):
     return ExprInt(tab_uintsize[e.get_size()](i))
 
 
+def get_expr_ids_visit(e, ids):
+    if isinstance(e, ExprId):
+        ids.add(e)
+    return e
+
+def get_expr_ids(e):
+    ids = set()
+    e.visit(lambda x:get_expr_ids_visit(x, ids))
+    return ids
 
 def test_set(e, v, tks, result):
     if not v in tks:
diff --git a/miasm/expression/expression_helper.py b/miasm/expression/expression_helper.py
index 17472e7f..6d82d52d 100644
--- a/miasm/expression/expression_helper.py
+++ b/miasm/expression/expression_helper.py
@@ -130,9 +130,25 @@ op_assoc = ['+', '*', '^', '&', '|']
 
 
 def expr_simp(e):
-    return e.visit(_expr_simp)
+    return  e.visit(_expr_simp_w)
+
+def _expr_simp_w(e):
+    if not hasattr(e, 'simp'):
+        e.simp = False
+    if e.simp:
+        #print 'done'
+        return e
+    while True:
+        e_new = _expr_simp(e)
+        if e_new == e:
+            break
+        e = expr_simp(e_new)
+    e.simp = True
+    #print 'return', e
+    return e
 
 def _expr_simp(e):
+    #print 'simp', e
     if isinstance(e, ExprOp):
         # merge associatif op
         # ((a+b) + c) => (a + b + c)