about summary refs log tree commit diff stats
path: root/miasm2/ir/symbexec.py
diff options
context:
space:
mode:
Diffstat (limited to 'miasm2/ir/symbexec.py')
-rw-r--r--miasm2/ir/symbexec.py435
1 files changed, 435 insertions, 0 deletions
diff --git a/miasm2/ir/symbexec.py b/miasm2/ir/symbexec.py
new file mode 100644
index 00000000..08608142
--- /dev/null
+++ b/miasm2/ir/symbexec.py
@@ -0,0 +1,435 @@
+from miasm2.expression.expression import *
+from miasm2.expression.simplifications import expr_simp
+from miasm2.core import asmbloc
+import logging
+
+
+log = logging.getLogger("symbexec")
+console_handler = logging.StreamHandler()
+console_handler.setFormatter(logging.Formatter("%(levelname)-5s: %(message)s"))
+log.addHandler(console_handler)
+log.setLevel(logging.INFO)
+
+
+class symbols():
+
+    def __init__(self, init=None):
+        if init is None:
+            init = {}
+        self.symbols_id = {}
+        self.symbols_mem = {}
+        for k, v in init.items():
+            self[k] = v
+
+    def __contains__(self, a):
+        if not isinstance(a, ExprMem):
+            return self.symbols_id.__contains__(a)
+        if not self.symbols_mem.__contains__(a.arg):
+            return False
+        return self.symbols_mem[a.arg][0].size == a.size
+
+    def __getitem__(self, a):
+        if not isinstance(a, ExprMem):
+            return self.symbols_id.__getitem__(a)
+        if not a.arg in self.symbols_mem:
+            raise KeyError, a
+        m = self.symbols_mem.__getitem__(a.arg)
+        if m[0].size != a.size:
+            raise KeyError, a
+        return m[1]
+
+    def __setitem__(self, a, v):
+        if not isinstance(a, ExprMem):
+            self.symbols_id.__setitem__(a, v)
+            return
+        self.symbols_mem.__setitem__(a.arg, (a, v))
+
+    def __iter__(self):
+        for a in self.symbols_id:
+            yield a
+        for a in self.symbols_mem:
+            yield self.symbols_mem[a][0]
+
+    def __delitem__(self, a):
+        if not isinstance(a, ExprMem):
+            self.symbols_id.__delitem__(a)
+        else:
+            self.symbols_mem.__delitem__(a.arg)
+
+    def items(self):
+        k = self.symbols_id.items() + [x for x in self.symbols_mem.values()]
+        return k
+
+    def keys(self):
+        k = self.symbols_id.keys() + [x[0] for x in self.symbols_mem.values()]
+        return k
+
+    def copy(self):
+        p = symbols()
+        p.symbols_id = dict(self.symbols_id)
+        p.symbols_mem = dict(self.symbols_mem)
+        return p
+
+    def inject_info(self, info):
+        s = symbols()
+        for k, v in self.items():
+            k = expr_simp(k.replace_expr(info))
+            v = expr_simp(v.replace_expr(info))
+            s[k] = v
+        return s
+
+
+class symbexec:
+
+    def __init__(self, arch, known_symbols,
+                 func_read=None,
+                 func_write=None,
+                 sb_expr_simp=expr_simp):
+        self.symbols = symbols()
+        for k, v in known_symbols.items():
+            self.symbols[k] = v
+        self.func_read = func_read
+        self.func_write = func_write
+        self.arch = arch
+        self.expr_simp = sb_expr_simp
+
+    def find_mem_by_addr(self, e):
+        if e in self.symbols.symbols_mem:
+            return self.symbols.symbols_mem[e][0]
+        return None
+
+    def eval_ExprId(self, e, eval_cache=None):
+        if isinstance(e.name, asmbloc.asm_label) and e.name.offset is not None:
+            return ExprInt_from(e, e.name.offset)
+        if not e in self.symbols:
+            # raise ValueError('unknown symbol %s'% e)
+            return e
+        return self.symbols[e]
+
+    def eval_ExprInt(self, e, eval_cache=None):
+        return e
+
+    def eval_ExprMem(self, e, eval_cache=None):
+        a_val = self.expr_simp(self.eval_expr(e.arg, eval_cache))
+        if a_val != e.arg:
+            a = self.expr_simp(ExprMem(a_val, size=e.size))
+        else:
+            a = e
+        if a in self.symbols:
+            return self.symbols[a]
+        tmp = None
+        # test if mem lookup is known
+        if a_val in self.symbols.symbols_mem:
+            tmp = self.symbols.symbols_mem[a_val][0]
+        if tmp is None:
+
+            v = self.find_mem_by_addr(a_val)
+            if not v:
+                out = []
+                ov = self.get_mem_overlapping(a, eval_cache)
+                off_base = 0
+                ov.sort()
+                # ov.reverse()
+                for off, x in ov:
+                    # off_base = off * 8
+                    # x_size = self.symbols[x].size
+                    if off >= 0:
+                        m = min(a.size - off * 8, x.size)
+                        ee = ExprSlice(self.symbols[x], 0, m)
+                        ee = self.expr_simp(ee)
+                        out.append((ee, off_base, off_base + m))
+                        off_base += m
+                    else:
+                        m = min(a.size - off * 8, x.size)
+                        ee = ExprSlice(self.symbols[x], -off * 8, m)
+                        ff = self.expr_simp(ee)
+                        new_off_base = off_base + m + off * 8
+                        out.append((ff, off_base, new_off_base))
+                        off_base = new_off_base
+                if out:
+                    missing_slice = self.rest_slice(out, 0, a.size)
+                    for sa, sb in missing_slice:
+                        ptr = self.expr_simp(a_val + ExprInt32(sa / 8))
+                        mm = ExprMem(ptr, size=sb - sa)
+                        mm.is_term = True
+                        mm.is_simp = True
+                        out.append((mm, sa, sb))
+                    out.sort(key=lambda x: x[1])
+                    # for e, sa, sb in out:
+                    #    print str(e), sa, sb
+                    ee = ExprSlice(ExprCompose(out), 0, a.size)
+                    ee = self.expr_simp(ee)
+                    return ee
+            if self.func_read and isinstance(a.arg, ExprInt):
+                return self.func_read(a)
+            else:
+                # XXX hack test
+                a.is_term = True
+                return a
+        # bigger lookup
+        if a.size > tmp.size:
+            rest = a.size
+            ptr = a_val
+            out = []
+            ptr_index = 0
+            while rest:
+                v = self.find_mem_by_addr(ptr)
+                if v is None:
+                    # raise ValueError("cannot find %s in mem"%str(ptr))
+                    val = ExprMem(ptr, 8)
+                    v = val
+                    diff_size = 8
+                elif rest >= v.size:
+                    val = self.symbols[v]
+                    diff_size = v.size
+                else:
+                    diff_size = rest
+                    val = self.symbols[v][0:diff_size]
+                val = (val, ptr_index, ptr_index + diff_size)
+                out.append(val)
+                ptr_index += diff_size
+                rest -= diff_size
+                ptr = self.expr_simp(self.eval_expr(ExprOp('+', ptr,
+                    ExprInt_from(ptr, v.size / 8)), eval_cache))
+            e = self.expr_simp(ExprCompose(out))
+            return e
+        # part lookup
+        tmp = self.expr_simp(ExprSlice(self.symbols[tmp], 0, a.size))
+        return tmp
+
+    def eval_expr_visit(self, e, eval_cache=None):
+        # print 'visit', e, e.is_term
+        if e.is_term:
+            return e
+        c = e.__class__
+        deal_class = {ExprId: self.eval_ExprId,
+                      ExprInt: self.eval_ExprInt,
+                      ExprMem: self.eval_ExprMem,
+                      }
+        # print 'eval', e
+        if c in deal_class:
+            e = deal_class[c](e, eval_cache)
+        # print "ret", e
+        if not (isinstance(e, ExprId) or isinstance(e, ExprInt)):
+            e.is_term = True
+        return e
+
+    def eval_expr(self, e, eval_cache=None):
+        r = e.visit(lambda x: self.eval_expr_visit(x, eval_cache))
+        return r
+
+    def modified_regs(self, init_state=None):
+        if init_state is None:
+            init_state = self.arch.regs.regs_init
+        ids = self.symbols.symbols_id.keys()
+        ids.sort()
+        for i in ids:
+            if i in init_state and \
+                    i in self.symbols.symbols_id and \
+                    self.symbols.symbols_id[i] == init_state[i]:
+                continue
+            yield i
+
+    def modified_mems(self, init_state=None):
+        mems = self.symbols.symbols_mem.values()
+        mems.sort()
+        for m, _ in mems:
+            yield m
+
+    def modified(self, init_state=None):
+        for r in self.modified_regs(init_state):
+            yield r
+        for m in self.modified_mems(init_state):
+            yield m
+
+    def dump_id(self):
+        ids = self.symbols.symbols_id.keys()
+        ids.sort()
+        for i in ids:
+            if i in self.arch.regs.regs_init and \
+                    i in self.symbols.symbols_id and \
+                    self.symbols.symbols_id[i] == self.arch.regs.regs_init[i]:
+                continue
+            print i, self.symbols.symbols_id[i]
+
+    def dump_mem(self):
+        mems = self.symbols.symbols_mem.values()
+        mems.sort()
+        for m, v in mems:
+            print m, v
+
+    def rest_slice(self, slices, start, stop):
+        o = []
+        last = start
+        for _, a, b in slices:
+            if a == last:
+                last = b
+                continue
+            o.append((last, a))
+            last = b
+        if last != stop:
+            o.append((b, stop))
+        return o
+
+    def substract_mems(self, a, b):
+        ex = ExprOp('-', b.arg, a.arg)
+        ex = self.expr_simp(self.eval_expr(ex, {}))
+        if not isinstance(ex, ExprInt):
+            return None
+        ptr_diff = int(int32(ex.arg))
+        out = []
+        if ptr_diff < 0:
+            #    [a     ]
+            #[b      ]XXX
+            sub_size = b.size + ptr_diff * 8
+            if sub_size >= a.size:
+                pass
+            else:
+                ex = ExprOp('+', a.arg, ExprInt_from(a.arg, sub_size / 8))
+                ex = self.expr_simp(self.eval_expr(ex, {}))
+
+                rest_ptr = ex
+                rest_size = a.size - sub_size
+
+                val = self.symbols[a][sub_size:a.size]
+                out = [(ExprMem(rest_ptr, rest_size), val)]
+        else:
+            #[a         ]
+            # XXXX[b   ]YY
+
+            #[a     ]
+            # XXXX[b     ]
+
+            out = []
+            # part X
+            if ptr_diff > 0:
+                val = self.symbols[a][0:ptr_diff * 8]
+                out.append((ExprMem(a.arg, ptr_diff * 8), val))
+            # part Y
+            if ptr_diff * 8 + b.size < a.size:
+
+                ex = ExprOp('+', b.arg, ExprInt_from(b.arg, b.size / 8))
+                ex = self.expr_simp(self.eval_expr(ex, {}))
+
+                rest_ptr = ex
+                rest_size = a.size - (ptr_diff * 8 + b.size)
+                val = self.symbols[a][ptr_diff * 8 + b.size:a.size]
+                out.append((ExprMem(ex, val.size), val))
+        return out
+
+    # give mem stored overlapping requested mem ptr
+    def get_mem_overlapping(self, e, eval_cache=None):
+        if not isinstance(e, ExprMem):
+            raise ValueError('mem overlap bad arg')
+        ov = []
+        # suppose max mem size is 64 bytes, compute all reachable addresses
+        to_test = []
+        base_ptr = self.expr_simp(e.arg)
+        for i in xrange(-7, e.size / 8):
+            ex = self.expr_simp(
+                self.eval_expr(base_ptr + ExprInt_from(e.arg, i), eval_cache))
+            to_test.append((i, ex))
+
+        for i, x in to_test:
+            if not x in self.symbols.symbols_mem:
+                continue
+            ex = self.expr_simp(self.eval_expr(e.arg - x, eval_cache))
+            if not isinstance(ex, ExprInt):
+                raise ValueError('ex is not ExprInt')
+            ptr_diff = int32(ex.arg)
+            if ptr_diff >= self.symbols.symbols_mem[x][1].size / 8:
+                # print "too long!"
+                continue
+            ov.append((i, self.symbols.symbols_mem[x][0]))
+        return ov
+
+    def eval_ir_expr(self, exprs):
+        pool_out = {}
+
+        eval_cache = {}
+
+        for e in exprs:
+            if not isinstance(e, ExprAff):
+                raise TypeError('not affect', str(e))
+
+            src = self.eval_expr(e.src, eval_cache)
+            if isinstance(e.dst, ExprMem):
+                a = self.eval_expr(e.dst.arg, eval_cache)
+                a = self.expr_simp(a)
+                # search already present mem
+                tmp = None
+                # test if mem lookup is known
+                tmp = ExprMem(a, e.dst.size)
+                dst = tmp
+                if self.func_write and isinstance(dst.arg, ExprInt):
+                    self.func_write(self, dst, src, pool_out)
+                else:
+                    pool_out[dst] = src
+
+            elif isinstance(e.dst, ExprId):
+                pool_out[e.dst] = src
+            else:
+                raise ValueError("affected zarb", str(e.dst))
+
+        return pool_out.items()
+
+    def eval_ir(self, ir):
+        mem_dst = []
+        # src_dst = [(x.src, x.dst) for x in ir]
+        src_dst = self.eval_ir_expr(ir)
+
+        for dst, src in src_dst:
+            if isinstance(dst, ExprMem):
+                mem_overlap = self.get_mem_overlapping(dst)
+                for _, base in mem_overlap:
+                    diff_mem = self.substract_mems(base, dst)
+                    del(self.symbols[base])
+                    for new_mem, new_val in diff_mem:
+                        new_val.is_term = True
+                        self.symbols[new_mem] = new_val
+            src_o = self.expr_simp(src)
+            # print 'SRCo', src_o
+            # src_o.is_term = True
+            self.symbols[dst] = src_o
+            if isinstance(dst, ExprMem):
+                mem_dst.append(dst)
+        return mem_dst
+
+    def emulbloc(self, bloc_ir, step=False):
+        for ir in bloc_ir.irs:
+            self.eval_ir(ir)
+            if step:
+                print '_' * 80
+                self.dump_id()
+        if bloc_ir.dst is None:
+            return None
+        return self.eval_expr(bloc_ir.dst)
+
+    def emul_ir_bloc(self, myir, ad):
+        b = myir.get_bloc(ad)
+        if b is not None:
+            ad = self.emulbloc(b)
+        return ad
+
+    def emul_ir_blocs(self, myir, ad, lbl_stop=None):
+        while True:
+            b = myir.get_bloc(ad)
+            if b is None:
+                break
+            if b.label == lbl_stop:
+                break
+            ad = self.emulbloc(b)
+        return ad
+
+    def del_mem_above_stack(self, sp):
+        sp_val = self.symbols[sp]
+        for mem_ad, (mem, _) in self.symbols.symbols_mem.items():
+            # print mem_ad, sp_val
+            diff = self.eval_expr(mem_ad - sp_val, {})
+            diff = expr_simp(diff)
+            if not isinstance(diff, ExprInt):
+                continue
+            m = expr_simp(diff.msb())
+            if m.arg == 1:
+                del(self.symbols[mem])
+