about summary refs log tree commit diff stats
path: root/miasm2/ir/symbexec.py
diff options
context:
space:
mode:
Diffstat (limited to 'miasm2/ir/symbexec.py')
-rw-r--r--miasm2/ir/symbexec.py603
1 files changed, 301 insertions, 302 deletions
diff --git a/miasm2/ir/symbexec.py b/miasm2/ir/symbexec.py
index 1dc8dde1..d3c56f70 100644
--- a/miasm2/ir/symbexec.py
+++ b/miasm2/ir/symbexec.py
@@ -3,6 +3,8 @@ from miasm2.expression.modint import int32
 from miasm2.expression.simplifications import expr_simp
 from miasm2.core import asmbloc
 from miasm2.ir.ir import AssignBlock
+from miasm2.core.interval import interval
+
 import logging
 
 
@@ -13,72 +15,82 @@ log.addHandler(console_handler)
 log.setLevel(logging.INFO)
 
 
-class symbols():
+class symbols(object):
 
     def __init__(self, init=None):
         if init is None:
             init = {}
         self.symbols_id = {}
         self.symbols_mem = {}
-        for k, v in init.items():
-            self[k] = v
+        for expr, value in init.items():
+            self[expr] = value
 
-    def __contains__(self, a):
-        if not isinstance(a, m2_expr.ExprMem):
-            return self.symbols_id.__contains__(a)
-        if not self.symbols_mem.__contains__(a.arg):
+    def __contains__(self, expr):
+        if not isinstance(expr, m2_expr.ExprMem):
+            return self.symbols_id.__contains__(expr)
+        if not self.symbols_mem.__contains__(expr.arg):
             return False
-        return self.symbols_mem[a.arg][0].size == a.size
-
-    def __getitem__(self, a):
-        if not isinstance(a, m2_expr.ExprMem):
-            return self.symbols_id.__getitem__(a)
-        if not a.arg in self.symbols_mem:
-            raise KeyError(a)
-        m = self.symbols_mem.__getitem__(a.arg)
-        if m[0].size != a.size:
-            raise KeyError(a)
-        return m[1]
-
-    def __setitem__(self, a, v):
-        if not isinstance(a, m2_expr.ExprMem):
-            self.symbols_id.__setitem__(a, v)
+        return self.symbols_mem[expr.arg][0].size == expr.size
+
+    def __getitem__(self, expr):
+        if not isinstance(expr, m2_expr.ExprMem):
+            return self.symbols_id.__getitem__(expr)
+        if not expr.arg in self.symbols_mem:
+            raise KeyError(expr)
+        mem, value = self.symbols_mem.__getitem__(expr.arg)
+        if mem.size != expr.size:
+            raise KeyError(expr)
+        return value
+
+    def get(self, expr, default=None):
+        if not isinstance(expr, m2_expr.ExprMem):
+            return self.symbols_id.get(expr, default)
+        if not expr.arg in self.symbols_mem:
+            return default
+        mem, value = self.symbols_mem.__getitem__(expr.arg)
+        if mem.size != expr.size:
+            return default
+        return value
+
+    def __setitem__(self, expr, value):
+        if not isinstance(expr, m2_expr.ExprMem):
+            self.symbols_id.__setitem__(expr, value)
             return
-        self.symbols_mem.__setitem__(a.arg, (a, v))
+        assert expr.size == value.size
+        self.symbols_mem.__setitem__(expr.arg, (expr, value))
 
     def __iter__(self):
-        for a in self.symbols_id:
-            yield a
-        for a in self.symbols_mem:
-            yield self.symbols_mem[a][0]
-
-    def __delitem__(self, a):
-        if not isinstance(a, m2_expr.ExprMem):
-            self.symbols_id.__delitem__(a)
+        for expr in self.symbols_id:
+            yield expr
+        for expr in self.symbols_mem:
+            yield self.symbols_mem[expr][0]
+
+    def __delitem__(self, expr):
+        if not isinstance(expr, m2_expr.ExprMem):
+            self.symbols_id.__delitem__(expr)
         else:
-            self.symbols_mem.__delitem__(a.arg)
+            self.symbols_mem.__delitem__(expr.arg)
 
     def items(self):
-        k = self.symbols_id.items() + [x for x in self.symbols_mem.values()]
-        return k
+        return self.symbols_id.items() + [x for x in self.symbols_mem.values()]
 
     def keys(self):
-        k = self.symbols_id.keys() + [x[0] for x in self.symbols_mem.values()]
-        return k
+        return (self.symbols_id.keys() +
+                [x[0] for x in self.symbols_mem.values()])
 
     def copy(self):
-        p = symbols()
-        p.symbols_id = dict(self.symbols_id)
-        p.symbols_mem = dict(self.symbols_mem)
-        return p
+        new_symbols = symbols()
+        new_symbols.symbols_id = dict(self.symbols_id)
+        new_symbols.symbols_mem = dict(self.symbols_mem)
+        return new_symbols
 
     def inject_info(self, info):
-        s = symbols()
-        for k, v in self.items():
-            k = expr_simp(k.replace_expr(info))
-            v = expr_simp(v.replace_expr(info))
-            s[k] = v
-        return s
+        new_symbols = symbols()
+        for expr, value in self.items():
+            expr = expr_simp(expr.replace_expr(info))
+            value = expr_simp(value.replace_expr(info))
+            new_symbols[expr] = value
+        return new_symbols
 
 
 class symbexec(object):
@@ -88,154 +100,152 @@ class symbexec(object):
                  func_write=None,
                  sb_expr_simp=expr_simp):
         self.symbols = symbols()
-        for k, v in known_symbols.items():
-            self.symbols[k] = v
+        for expr, value in known_symbols.items():
+            self.symbols[expr] = value
         self.func_read = func_read
         self.func_write = func_write
         self.ir_arch = ir_arch
         self.expr_simp = sb_expr_simp
 
-    def find_mem_by_addr(self, e):
-        if e in self.symbols.symbols_mem:
-            return self.symbols.symbols_mem[e][0]
+    def find_mem_by_addr(self, expr):
+        """
+        Return memory keys with pointer equal to @expr
+        @expr: address of the searched memory variable
+        """
+        if expr in self.symbols.symbols_mem:
+            return self.symbols.symbols_mem[expr][0]
         return None
 
-    def eval_ExprId(self, e, eval_cache=None):
-        if eval_cache is None:
-            eval_cache = {}
-        if isinstance(e.name, asmbloc.asm_label) and e.name.offset is not None:
-            return m2_expr.ExprInt_from(e, e.name.offset)
-        if not e in self.symbols:
-            # raise ValueError('unknown symbol %s'% e)
-            return e
-        return self.symbols[e]
-
-    def eval_ExprInt(self, e, eval_cache=None):
-        return e
-
-    def eval_ExprMem(self, e, eval_cache=None):
-        if eval_cache is None:
-            eval_cache = {}
-        a_val = self.expr_simp(self.eval_expr(e.arg, eval_cache))
-        if a_val != e.arg:
-            a = self.expr_simp(m2_expr.ExprMem(a_val, size=e.size))
-        else:
-            a = e
-        if a in self.symbols:
-            return self.symbols[a]
-        tmp = None
-        # test if mem lookup is known
-        if a_val in self.symbols.symbols_mem:
-            tmp = self.symbols.symbols_mem[a_val][0]
-        if tmp is None:
-
-            v = self.find_mem_by_addr(a_val)
-            if not v:
-                out = []
-                ov = self.get_mem_overlapping(a, eval_cache)
-                off_base = 0
-                ov.sort()
-                # ov.reverse()
-                for off, x in ov:
-                    # off_base = off * 8
-                    # x_size = self.symbols[x].size
-                    if off >= 0:
-                        m = min(a.size - off * 8, x.size)
-                        ee = m2_expr.ExprSlice(self.symbols[x], 0, m)
-                        ee = self.expr_simp(ee)
-                        out.append((ee, off_base, off_base + m))
-                        off_base += m
-                    else:
-                        m = min(a.size - off * 8, x.size)
-                        ee = m2_expr.ExprSlice(self.symbols[x], -off * 8, m)
-                        ff = self.expr_simp(ee)
-                        new_off_base = off_base + m + off * 8
-                        out.append((ff, off_base, new_off_base))
-                        off_base = new_off_base
-                if out:
-                    missing_slice = self.rest_slice(out, 0, a.size)
-                    for sa, sb in missing_slice:
-                        ptr = self.expr_simp(
-                            a_val + m2_expr.ExprInt_from(a_val, sa / 8)
-                        )
-                        mm = m2_expr.ExprMem(ptr, size=sb - sa)
-                        mm.is_term = True
-                        mm.is_simp = True
-                        out.append((mm, sa, sb))
-                    out.sort(key=lambda x: x[1])
-                    # for e, sa, sb in out:
-                    #    print str(e), sa, sb
-                    ee = m2_expr.ExprSlice(m2_expr.ExprCompose(out), 0, a.size)
-                    ee = self.expr_simp(ee)
-                    return ee
-            if self.func_read and isinstance(a.arg, m2_expr.ExprInt):
-                return self.func_read(a)
+    def get_mem_state(self, expr):
+        """
+        Evaluate the @expr memory in the current state using @cache
+        @expr: the memory key
+        """
+        ptr, size = expr.arg, expr.size
+        ret = self.find_mem_by_addr(ptr)
+        if not ret:
+            out = []
+            overlaps = self.get_mem_overlapping(expr)
+            off_base = 0
+            for off, mem in overlaps:
+                if off >= 0:
+                    new_size = min(size - off * 8, mem.size)
+                    tmp = self.expr_simp(self.symbols[mem][0:new_size])
+                    out.append((tmp, off_base, off_base + new_size))
+                    off_base += new_size
+                else:
+                    new_size = min(size - off * 8, mem.size)
+                    tmp = self.expr_simp(self.symbols[mem][-off * 8:new_size])
+                    new_off_base = off_base + new_size + off * 8
+                    out.append((tmp, off_base, new_off_base))
+                    off_base = new_off_base
+            if out:
+                missing_slice = self.rest_slice(out, 0, size)
+                for slice_start, slice_stop in missing_slice:
+                    ptr = self.expr_simp(ptr + m2_expr.ExprInt(slice_start / 8, ptr.size))
+                    mem = m2_expr.ExprMem(ptr, slice_stop - slice_start)
+                    out.append((mem, slice_start, slice_stop))
+                out.sort(key=lambda x: x[1])
+                tmp = m2_expr.ExprSlice(m2_expr.ExprCompose(out), 0, size)
+                tmp = self.expr_simp(tmp)
+                return tmp
+
+
+            if self.func_read and isinstance(ptr, m2_expr.ExprInt):
+                return self.func_read(expr)
             else:
-                # XXX hack test
-                a.is_term = True
-                return a
+                return expr
         # bigger lookup
-        if a.size > tmp.size:
-            rest = a.size
-            ptr = a_val
+        if size > ret.size:
+            rest = size
+            ptr = ptr
             out = []
             ptr_index = 0
             while rest:
-                v = self.find_mem_by_addr(ptr)
-                if v is None:
-                    # raise ValueError("cannot find %s in mem"%str(ptr))
-                    val = m2_expr.ExprMem(ptr, 8)
-                    v = val
+                mem = self.find_mem_by_addr(ptr)
+                if mem is None:
+                    value = m2_expr.ExprMem(ptr, 8)
+                    mem = value
                     diff_size = 8
-                elif rest >= v.size:
-                    val = self.symbols[v]
-                    diff_size = v.size
+                elif rest >= mem.size:
+                    value = self.symbols[mem]
+                    diff_size = mem.size
                 else:
                     diff_size = rest
-                    val = self.symbols[v][0:diff_size]
-                val = (val, ptr_index, ptr_index + diff_size)
-                out.append(val)
+                    value = self.symbols[mem][0:diff_size]
+                out.append((value, ptr_index, ptr_index + diff_size))
                 ptr_index += diff_size
                 rest -= diff_size
-                ptr = self.expr_simp(
-                    self.eval_expr(
-                        m2_expr.ExprOp('+', ptr,
-                                       m2_expr.ExprInt_from(ptr, v.size / 8)),
-                        eval_cache)
-                )
-            e = self.expr_simp(m2_expr.ExprCompose(out))
-            return e
+                ptr = self.expr_simp(ptr + m2_expr.ExprInt(mem.size / 8, ptr.size))
+            ret = self.expr_simp(m2_expr.ExprCompose(out))
+            return ret
         # part lookup
-        tmp = self.expr_simp(m2_expr.ExprSlice(self.symbols[tmp], 0, a.size))
-        return tmp
-
-    def eval_expr_visit(self, e, eval_cache=None):
-        if eval_cache is None:
-            eval_cache = {}
-        # print 'visit', e, e.is_term
-        if e.is_term:
-            return e
-        if e in eval_cache:
-            return eval_cache[e]
-        c = e.__class__
-        deal_class = {m2_expr.ExprId: self.eval_ExprId,
-                      m2_expr.ExprInt: self.eval_ExprInt,
-                      m2_expr.ExprMem: self.eval_ExprMem,
-                      }
-        # print 'eval', e
-        if c in deal_class:
-            e = deal_class[c](e, eval_cache)
-        # print "ret", e
-        if not (isinstance(e, m2_expr.ExprId) or isinstance(e,
-                                                            m2_expr.ExprInt)):
-            e.is_term = True
-        return e
-
-    def eval_expr(self, e, eval_cache=None):
-        if eval_cache is None:
-            eval_cache = {}
-        r = e.visit(lambda x: self.eval_expr_visit(x, eval_cache))
-        return r
+        ret = self.expr_simp(self.symbols[ret][:size])
+        return ret
+
+
+    def apply_expr_on_state_visit_cache(self, expr, state, cache, level=0):
+        """
+        Deep First evaluate nodes:
+            1. evaluate node's sons
+            2. simplify
+        """
+
+        #print '\t'*level, "Eval:", expr
+        if expr in cache:
+            ret = cache[expr]
+            #print "In cache!", ret
+        elif isinstance(expr, m2_expr.ExprInt):
+            return expr
+        elif isinstance(expr, m2_expr.ExprId):
+            if isinstance(expr.name, asmbloc.asm_label) and expr.name.offset is not None:
+                ret = m2_expr.ExprInt_from(expr, expr.name.offset)
+            else:
+                ret = state.get(expr, expr)
+        elif isinstance(expr, m2_expr.ExprMem):
+            ptr = self.apply_expr_on_state_visit_cache(expr.arg, state, cache, level+1)
+            ret = m2_expr.ExprMem(ptr, expr.size)
+            ret = self.get_mem_state(ret)
+            assert expr.size == ret.size
+        elif isinstance(expr, m2_expr.ExprCond):
+            cond = self.apply_expr_on_state_visit_cache(expr.cond, state, cache, level+1)
+            src1 = self.apply_expr_on_state_visit_cache(expr.src1, state, cache, level+1)
+            src2 = self.apply_expr_on_state_visit_cache(expr.src2, state, cache, level+1)
+            ret = m2_expr.ExprCond(cond, src1, src2)
+        elif isinstance(expr, m2_expr.ExprSlice):
+            arg = self.apply_expr_on_state_visit_cache(expr.arg, state, cache, level+1)
+            ret = m2_expr.ExprSlice(arg, expr.start, expr.stop)
+        elif isinstance(expr, m2_expr.ExprOp):
+            args = []
+            for oarg in expr.args:
+                arg = self.apply_expr_on_state_visit_cache(oarg, state, cache, level+1)
+                assert oarg.size == arg.size
+                args.append(arg)
+            ret = m2_expr.ExprOp(expr.op, *args)
+        elif isinstance(expr, m2_expr.ExprCompose):
+            args = []
+            for (arg, start, stop) in expr.args:
+                arg = self.apply_expr_on_state_visit_cache(arg, state, cache, level+1)
+                args.append((arg, start, stop))
+            ret = m2_expr.ExprCompose(args)
+        else:
+            raise TypeError("Unknown expr type")
+        #print '\t'*level, "Result", ret
+        ret = self.expr_simp(ret)
+        #print '\t'*level, "Result simpl", ret
+
+        assert expr.size == ret.size
+        cache[expr] = ret
+        return ret
+
+    def apply_expr_on_state(self, expr, cache):
+        if cache is None:
+            cache = {}
+        ret = self.apply_expr_on_state_visit_cache(expr, self.symbols, cache)
+        return ret
+
+    def eval_expr(self, expr, eval_cache=None):
+        return self.apply_expr_on_state(expr, eval_cache)
 
     def modified_regs(self, init_state=None):
         if init_state is None:
@@ -250,121 +260,111 @@ class symbexec(object):
             yield i
 
     def modified_mems(self, init_state=None):
+        if init_state is None:
+            init_state = self.ir_arch.arch.regs.regs_init
         mems = self.symbols.symbols_mem.values()
         mems.sort()
-        for m, _ in mems:
-            yield m
+        for mem, _ in mems:
+            if mem in init_state and \
+                    mem in self.symbols.symbols_mem and \
+                    self.symbols.symbols_mem[mem] == init_state[mem]:
+                continue
+            yield mem
 
     def modified(self, init_state=None):
-        for r in self.modified_regs(init_state):
-            yield r
-        for m in self.modified_mems(init_state):
-            yield m
+        for reg in self.modified_regs(init_state):
+            yield reg
+        for mem in self.modified_mems(init_state):
+            yield mem
 
     def dump_id(self):
+        """
+        Dump modififed registers symbols only
+        """
         ids = self.symbols.symbols_id.keys()
         ids.sort()
-        for i in ids:
-            if i in self.ir_arch.arch.regs.regs_init and \
-                    i in self.symbols.symbols_id and \
-                    self.symbols.symbols_id[i] == self.ir_arch.arch.regs.regs_init[i]:
+        for expr in ids:
+            if (expr in self.ir_arch.arch.regs.regs_init and
+                expr in self.symbols.symbols_id and
+                self.symbols.symbols_id[expr] == self.ir_arch.arch.regs.regs_init[expr]):
                 continue
-            print i, self.symbols.symbols_id[i]
+            print expr, "=", self.symbols.symbols_id[expr]
 
     def dump_mem(self):
+        """
+        Dump modififed memory symbols
+        """
         mems = self.symbols.symbols_mem.values()
         mems.sort()
-        for m, v in mems:
-            print m, v
+        for mem, value in mems:
+            print mem, value
 
     def rest_slice(self, slices, start, stop):
-        o = []
+        """
+        Return the complementary slices of @slices in the range @start, @stop
+        @slices: base slices
+        @start, @stop: interval range
+        """
+        out = []
         last = start
-        for _, a, b in slices:
-            if a == last:
-                last = b
+        for _, slice_start, slice_stop in slices:
+            if slice_start == last:
+                last = slice_stop
                 continue
-            o.append((last, a))
-            last = b
+            out.append((last, slice_start))
+            last = slice_stop
         if last != stop:
-            o.append((b, stop))
-        return o
-
-    def substract_mems(self, a, b):
-        ex = b.arg - a.arg
-        ex = self.expr_simp(self.eval_expr(ex, {}))
-        if not isinstance(ex, m2_expr.ExprInt):
-            return None
-        ptr_diff = int(int32(ex.arg))
-        out = []
-        if ptr_diff < 0:
-            #    [a     ]
-            #[b      ]XXX
-            sub_size = b.size + ptr_diff * 8
-            if sub_size >= a.size:
-                pass
-            else:
-                ex = m2_expr.ExprOp('+', a.arg,
-                                    m2_expr.ExprInt_from(a.arg, sub_size / 8))
-                ex = self.expr_simp(self.eval_expr(ex, {}))
+            out.append((slice_stop, stop))
+        return out
 
-                rest_ptr = ex
-                rest_size = a.size - sub_size
+    def substract_mems(self, arg1, arg2):
+        """
+        Return the remaining memory areas of @arg1 - @arg2
+        @arg1, @arg2: ExprMem
+        """
 
-                val = self.symbols[a][sub_size:a.size]
-                out = [(m2_expr.ExprMem(rest_ptr, rest_size), val)]
-        else:
-            #[a         ]
-            # XXXX[b   ]YY
+        ptr_diff = self.expr_simp(arg2.arg - arg1.arg)
+        ptr_diff = int(int32(ptr_diff.arg))
 
-            #[a     ]
-            # XXXX[b     ]
+        zone1 = interval([(0, arg1.size/8-1)])
+        zone2 = interval([(ptr_diff, ptr_diff + arg2.size/8-1)])
+        zones = zone1 - zone2
+
+        out = []
+        for start, stop in zones:
+            ptr = arg1.arg + m2_expr.ExprInt(start, arg1.arg.size)
+            ptr = self.expr_simp(ptr)
+            value = self.expr_simp(self.symbols[arg1][start*8:(stop+1)*8])
+            mem = m2_expr.ExprMem(ptr, (stop - start + 1)*8)
+            assert mem.size == value.size
+            out.append((mem, value))
 
-            out = []
-            # part X
-            if ptr_diff > 0:
-                val = self.symbols[a][0:ptr_diff * 8]
-                out.append((m2_expr.ExprMem(a.arg, ptr_diff * 8), val))
-            # part Y
-            if ptr_diff * 8 + b.size < a.size:
-
-                ex = m2_expr.ExprOp('+', b.arg,
-                                    m2_expr.ExprInt_from(b.arg, b.size / 8))
-                ex = self.expr_simp(self.eval_expr(ex, {}))
-
-                rest_ptr = ex
-                rest_size = a.size - (ptr_diff * 8 + b.size)
-                val = self.symbols[a][ptr_diff * 8 + b.size:a.size]
-                out.append((m2_expr.ExprMem(ex, val.size), val))
         return out
-    # give mem stored overlapping requested mem ptr
-    def get_mem_overlapping(self, e, eval_cache=None):
-        if eval_cache is None:
-            eval_cache = {}
-        if not isinstance(e, m2_expr.ExprMem):
-            raise ValueError('mem overlap bad arg')
-        ov = []
-        # suppose max mem size is 64 bytes, compute all reachable addresses
-        to_test = []
-        base_ptr = self.expr_simp(e.arg)
-        for i in xrange(-7, e.size / 8):
-            ex = self.expr_simp(
-                self.eval_expr(base_ptr + m2_expr.ExprInt_from(e.arg, i),
-                               eval_cache))
-            to_test.append((i, ex))
-
-        for i, x in to_test:
-            if not x in self.symbols.symbols_mem:
+
+    def get_mem_overlapping(self, expr):
+        """
+        Gives mem stored overlapping memory in @expr
+        Hypothesis: Max mem size is 64 bytes, compute all reachable addresses
+        @expr: target memory
+        """
+
+        overlaps = []
+        base_ptr = self.expr_simp(expr.arg)
+        for i in xrange(-7, expr.size / 8):
+            new_ptr = base_ptr + m2_expr.ExprInt(i, expr.arg.size)
+            new_ptr = self.expr_simp(new_ptr)
+
+            mem, origin = self.symbols.symbols_mem.get(new_ptr, (None, None))
+            if mem is None:
                 continue
-            ex = self.expr_simp(self.eval_expr(e.arg - x, eval_cache))
-            if not isinstance(ex, m2_expr.ExprInt):
-                raise ValueError('ex is not ExprInt')
-            ptr_diff = int32(ex.arg)
-            if ptr_diff >= self.symbols.symbols_mem[x][1].size / 8:
-                # print "too long!"
+
+            ptr_diff = -i
+            if ptr_diff >= origin.size / 8:
+                # access is too small to overlap the memory target
                 continue
-            ov.append((i, self.symbols.symbols_mem[x][0]))
-        return ov
+            overlaps.append((i, mem))
+
+        return overlaps
 
     def eval_ir_expr(self, assignblk):
         """
@@ -372,16 +372,14 @@ class symbexec(object):
         @assignblk: AssignBlock instance
         """
         pool_out = {}
-
-        eval_cache = dict(self.symbols.items())
+        eval_cache = {}
 
         for dst, src in assignblk.iteritems():
             src = self.eval_expr(src, eval_cache)
             if isinstance(dst, m2_expr.ExprMem):
-                a = self.eval_expr(dst.arg, eval_cache)
-                a = self.expr_simp(a)
+                ptr = self.eval_expr(dst.arg, eval_cache)
                 # test if mem lookup is known
-                tmp = m2_expr.ExprMem(a, dst.size)
+                tmp = m2_expr.ExprMem(ptr, dst.size)
                 pool_out[tmp] = src
 
             elif isinstance(dst, m2_expr.ExprId):
@@ -398,18 +396,18 @@ class symbexec(object):
         """
         mem_dst = []
         src_dst = self.eval_ir_expr(assignblk)
-        eval_cache = dict(self.symbols.items())
         for dst, src in src_dst:
             if isinstance(dst, m2_expr.ExprMem):
-                mem_overlap = self.get_mem_overlapping(dst, eval_cache)
+                mem_overlap = self.get_mem_overlapping(dst)
                 for _, base in mem_overlap:
                     diff_mem = self.substract_mems(base, dst)
                     del self.symbols[base]
                     for new_mem, new_val in diff_mem:
-                        new_val.is_term = True
                         self.symbols[new_mem] = new_val
             src_o = self.expr_simp(src)
             self.symbols[dst] = src_o
+            if dst == src_o:
+                del self.symbols[dst]
             if isinstance(dst, m2_expr.ExprMem):
                 if self.func_write and isinstance(dst.arg, m2_expr.ExprInt):
                     self.func_write(self, dst, src_o)
@@ -424,51 +422,52 @@ class symbexec(object):
         @step: display intermediate steps
         """
         for assignblk in irb.irs:
-            self.eval_ir(assignblk)
             if step:
+                print 'Assignblk:'
+                print assignblk
                 print '_' * 80
+            self.eval_ir(assignblk)
+            if step:
                 self.dump_id()
-        eval_cache = dict(self.symbols.items())
-        return self.eval_expr(self.ir_arch.IRDst, eval_cache)
+                self.dump_mem()
+                print '_' * 80
+        return self.eval_expr(self.ir_arch.IRDst)
 
-    def emul_ir_bloc(self, myir, ad, step=False):
-        b = myir.get_bloc(ad)
-        if b is not None:
-            ad = self.emulbloc(b, step=step)
-        return ad
+    def emul_ir_bloc(self, myir, addr, step=False):
+        irblock = myir.get_bloc(addr)
+        if irblock is not None:
+            addr = self.emulbloc(irblock, step=step)
+        return addr
 
-    def emul_ir_blocs(self, myir, ad, lbl_stop=None, step=False):
+    def emul_ir_blocs(self, myir, addr, lbl_stop=None, step=False):
         while True:
-            b = myir.get_bloc(ad)
-            if b is None:
+            irblock = myir.get_bloc(addr)
+            if irblock is None:
                 break
-            if b.label == lbl_stop:
+            if irblock.label == lbl_stop:
                 break
-            ad = self.emulbloc(b, step=step)
-        return ad
-
-    def del_mem_above_stack(self, sp):
-        sp_val = self.symbols[sp]
-        for mem_ad, (mem, _) in self.symbols.symbols_mem.items():
-            # print mem_ad, sp_val
-            diff = self.eval_expr(mem_ad - sp_val, {})
-            diff = expr_simp(diff)
+            addr = self.emulbloc(irblock, step=step)
+        return addr
+
+    def del_mem_above_stack(self, stack_ptr):
+        stack_ptr = self.eval_expr(stack_ptr)
+        for mem_addr, (mem, _) in self.symbols.symbols_mem.items():
+            diff = self.expr_simp(mem_addr - stack_ptr)
             if not isinstance(diff, m2_expr.ExprInt):
                 continue
-            m = expr_simp(diff.msb())
-            if m.arg == 1:
+            sign_bit = self.expr_simp(diff.msb())
+            if sign_bit.arg == 1:
                 del self.symbols[mem]
 
     def apply_expr(self, expr):
         """Evaluate @expr and apply side effect if needed (ie. if expr is an
         assignment). Return the evaluated value"""
 
-        # Eval expression
-        to_eval = expr.src if isinstance(expr, m2_expr.ExprAff) else expr
-        ret = self.expr_simp(self.eval_expr(to_eval))
-
         # Update value if needed
         if isinstance(expr, m2_expr.ExprAff):
-            self.eval_ir(AssignBlock([m2_expr.ExprAff(expr.dst, ret)]))
+            ret = self.eval_expr(expr.src)
+            self.eval_ir(AssignBlock([expr]))
+        else:
+            ret = self.eval_expr(expr)
 
         return ret