diff options
Diffstat (limited to 'miasm2/ir/symbexec.py')
| -rw-r--r-- | miasm2/ir/symbexec.py | 603 |
1 files changed, 301 insertions, 302 deletions
diff --git a/miasm2/ir/symbexec.py b/miasm2/ir/symbexec.py index 1dc8dde1..d3c56f70 100644 --- a/miasm2/ir/symbexec.py +++ b/miasm2/ir/symbexec.py @@ -3,6 +3,8 @@ from miasm2.expression.modint import int32 from miasm2.expression.simplifications import expr_simp from miasm2.core import asmbloc from miasm2.ir.ir import AssignBlock +from miasm2.core.interval import interval + import logging @@ -13,72 +15,82 @@ log.addHandler(console_handler) log.setLevel(logging.INFO) -class symbols(): +class symbols(object): def __init__(self, init=None): if init is None: init = {} self.symbols_id = {} self.symbols_mem = {} - for k, v in init.items(): - self[k] = v + for expr, value in init.items(): + self[expr] = value - def __contains__(self, a): - if not isinstance(a, m2_expr.ExprMem): - return self.symbols_id.__contains__(a) - if not self.symbols_mem.__contains__(a.arg): + def __contains__(self, expr): + if not isinstance(expr, m2_expr.ExprMem): + return self.symbols_id.__contains__(expr) + if not self.symbols_mem.__contains__(expr.arg): return False - return self.symbols_mem[a.arg][0].size == a.size - - def __getitem__(self, a): - if not isinstance(a, m2_expr.ExprMem): - return self.symbols_id.__getitem__(a) - if not a.arg in self.symbols_mem: - raise KeyError(a) - m = self.symbols_mem.__getitem__(a.arg) - if m[0].size != a.size: - raise KeyError(a) - return m[1] - - def __setitem__(self, a, v): - if not isinstance(a, m2_expr.ExprMem): - self.symbols_id.__setitem__(a, v) + return self.symbols_mem[expr.arg][0].size == expr.size + + def __getitem__(self, expr): + if not isinstance(expr, m2_expr.ExprMem): + return self.symbols_id.__getitem__(expr) + if not expr.arg in self.symbols_mem: + raise KeyError(expr) + mem, value = self.symbols_mem.__getitem__(expr.arg) + if mem.size != expr.size: + raise KeyError(expr) + return value + + def get(self, expr, default=None): + if not isinstance(expr, m2_expr.ExprMem): + return self.symbols_id.get(expr, default) + if not expr.arg in self.symbols_mem: + return default + mem, value = self.symbols_mem.__getitem__(expr.arg) + if mem.size != expr.size: + return default + return value + + def __setitem__(self, expr, value): + if not isinstance(expr, m2_expr.ExprMem): + self.symbols_id.__setitem__(expr, value) return - self.symbols_mem.__setitem__(a.arg, (a, v)) + assert expr.size == value.size + self.symbols_mem.__setitem__(expr.arg, (expr, value)) def __iter__(self): - for a in self.symbols_id: - yield a - for a in self.symbols_mem: - yield self.symbols_mem[a][0] - - def __delitem__(self, a): - if not isinstance(a, m2_expr.ExprMem): - self.symbols_id.__delitem__(a) + for expr in self.symbols_id: + yield expr + for expr in self.symbols_mem: + yield self.symbols_mem[expr][0] + + def __delitem__(self, expr): + if not isinstance(expr, m2_expr.ExprMem): + self.symbols_id.__delitem__(expr) else: - self.symbols_mem.__delitem__(a.arg) + self.symbols_mem.__delitem__(expr.arg) def items(self): - k = self.symbols_id.items() + [x for x in self.symbols_mem.values()] - return k + return self.symbols_id.items() + [x for x in self.symbols_mem.values()] def keys(self): - k = self.symbols_id.keys() + [x[0] for x in self.symbols_mem.values()] - return k + return (self.symbols_id.keys() + + [x[0] for x in self.symbols_mem.values()]) def copy(self): - p = symbols() - p.symbols_id = dict(self.symbols_id) - p.symbols_mem = dict(self.symbols_mem) - return p + new_symbols = symbols() + new_symbols.symbols_id = dict(self.symbols_id) + new_symbols.symbols_mem = dict(self.symbols_mem) + return new_symbols def inject_info(self, info): - s = symbols() - for k, v in self.items(): - k = expr_simp(k.replace_expr(info)) - v = expr_simp(v.replace_expr(info)) - s[k] = v - return s + new_symbols = symbols() + for expr, value in self.items(): + expr = expr_simp(expr.replace_expr(info)) + value = expr_simp(value.replace_expr(info)) + new_symbols[expr] = value + return new_symbols class symbexec(object): @@ -88,154 +100,152 @@ class symbexec(object): func_write=None, sb_expr_simp=expr_simp): self.symbols = symbols() - for k, v in known_symbols.items(): - self.symbols[k] = v + for expr, value in known_symbols.items(): + self.symbols[expr] = value self.func_read = func_read self.func_write = func_write self.ir_arch = ir_arch self.expr_simp = sb_expr_simp - def find_mem_by_addr(self, e): - if e in self.symbols.symbols_mem: - return self.symbols.symbols_mem[e][0] + def find_mem_by_addr(self, expr): + """ + Return memory keys with pointer equal to @expr + @expr: address of the searched memory variable + """ + if expr in self.symbols.symbols_mem: + return self.symbols.symbols_mem[expr][0] return None - def eval_ExprId(self, e, eval_cache=None): - if eval_cache is None: - eval_cache = {} - if isinstance(e.name, asmbloc.asm_label) and e.name.offset is not None: - return m2_expr.ExprInt_from(e, e.name.offset) - if not e in self.symbols: - # raise ValueError('unknown symbol %s'% e) - return e - return self.symbols[e] - - def eval_ExprInt(self, e, eval_cache=None): - return e - - def eval_ExprMem(self, e, eval_cache=None): - if eval_cache is None: - eval_cache = {} - a_val = self.expr_simp(self.eval_expr(e.arg, eval_cache)) - if a_val != e.arg: - a = self.expr_simp(m2_expr.ExprMem(a_val, size=e.size)) - else: - a = e - if a in self.symbols: - return self.symbols[a] - tmp = None - # test if mem lookup is known - if a_val in self.symbols.symbols_mem: - tmp = self.symbols.symbols_mem[a_val][0] - if tmp is None: - - v = self.find_mem_by_addr(a_val) - if not v: - out = [] - ov = self.get_mem_overlapping(a, eval_cache) - off_base = 0 - ov.sort() - # ov.reverse() - for off, x in ov: - # off_base = off * 8 - # x_size = self.symbols[x].size - if off >= 0: - m = min(a.size - off * 8, x.size) - ee = m2_expr.ExprSlice(self.symbols[x], 0, m) - ee = self.expr_simp(ee) - out.append((ee, off_base, off_base + m)) - off_base += m - else: - m = min(a.size - off * 8, x.size) - ee = m2_expr.ExprSlice(self.symbols[x], -off * 8, m) - ff = self.expr_simp(ee) - new_off_base = off_base + m + off * 8 - out.append((ff, off_base, new_off_base)) - off_base = new_off_base - if out: - missing_slice = self.rest_slice(out, 0, a.size) - for sa, sb in missing_slice: - ptr = self.expr_simp( - a_val + m2_expr.ExprInt_from(a_val, sa / 8) - ) - mm = m2_expr.ExprMem(ptr, size=sb - sa) - mm.is_term = True - mm.is_simp = True - out.append((mm, sa, sb)) - out.sort(key=lambda x: x[1]) - # for e, sa, sb in out: - # print str(e), sa, sb - ee = m2_expr.ExprSlice(m2_expr.ExprCompose(out), 0, a.size) - ee = self.expr_simp(ee) - return ee - if self.func_read and isinstance(a.arg, m2_expr.ExprInt): - return self.func_read(a) + def get_mem_state(self, expr): + """ + Evaluate the @expr memory in the current state using @cache + @expr: the memory key + """ + ptr, size = expr.arg, expr.size + ret = self.find_mem_by_addr(ptr) + if not ret: + out = [] + overlaps = self.get_mem_overlapping(expr) + off_base = 0 + for off, mem in overlaps: + if off >= 0: + new_size = min(size - off * 8, mem.size) + tmp = self.expr_simp(self.symbols[mem][0:new_size]) + out.append((tmp, off_base, off_base + new_size)) + off_base += new_size + else: + new_size = min(size - off * 8, mem.size) + tmp = self.expr_simp(self.symbols[mem][-off * 8:new_size]) + new_off_base = off_base + new_size + off * 8 + out.append((tmp, off_base, new_off_base)) + off_base = new_off_base + if out: + missing_slice = self.rest_slice(out, 0, size) + for slice_start, slice_stop in missing_slice: + ptr = self.expr_simp(ptr + m2_expr.ExprInt(slice_start / 8, ptr.size)) + mem = m2_expr.ExprMem(ptr, slice_stop - slice_start) + out.append((mem, slice_start, slice_stop)) + out.sort(key=lambda x: x[1]) + tmp = m2_expr.ExprSlice(m2_expr.ExprCompose(out), 0, size) + tmp = self.expr_simp(tmp) + return tmp + + + if self.func_read and isinstance(ptr, m2_expr.ExprInt): + return self.func_read(expr) else: - # XXX hack test - a.is_term = True - return a + return expr # bigger lookup - if a.size > tmp.size: - rest = a.size - ptr = a_val + if size > ret.size: + rest = size + ptr = ptr out = [] ptr_index = 0 while rest: - v = self.find_mem_by_addr(ptr) - if v is None: - # raise ValueError("cannot find %s in mem"%str(ptr)) - val = m2_expr.ExprMem(ptr, 8) - v = val + mem = self.find_mem_by_addr(ptr) + if mem is None: + value = m2_expr.ExprMem(ptr, 8) + mem = value diff_size = 8 - elif rest >= v.size: - val = self.symbols[v] - diff_size = v.size + elif rest >= mem.size: + value = self.symbols[mem] + diff_size = mem.size else: diff_size = rest - val = self.symbols[v][0:diff_size] - val = (val, ptr_index, ptr_index + diff_size) - out.append(val) + value = self.symbols[mem][0:diff_size] + out.append((value, ptr_index, ptr_index + diff_size)) ptr_index += diff_size rest -= diff_size - ptr = self.expr_simp( - self.eval_expr( - m2_expr.ExprOp('+', ptr, - m2_expr.ExprInt_from(ptr, v.size / 8)), - eval_cache) - ) - e = self.expr_simp(m2_expr.ExprCompose(out)) - return e + ptr = self.expr_simp(ptr + m2_expr.ExprInt(mem.size / 8, ptr.size)) + ret = self.expr_simp(m2_expr.ExprCompose(out)) + return ret # part lookup - tmp = self.expr_simp(m2_expr.ExprSlice(self.symbols[tmp], 0, a.size)) - return tmp - - def eval_expr_visit(self, e, eval_cache=None): - if eval_cache is None: - eval_cache = {} - # print 'visit', e, e.is_term - if e.is_term: - return e - if e in eval_cache: - return eval_cache[e] - c = e.__class__ - deal_class = {m2_expr.ExprId: self.eval_ExprId, - m2_expr.ExprInt: self.eval_ExprInt, - m2_expr.ExprMem: self.eval_ExprMem, - } - # print 'eval', e - if c in deal_class: - e = deal_class[c](e, eval_cache) - # print "ret", e - if not (isinstance(e, m2_expr.ExprId) or isinstance(e, - m2_expr.ExprInt)): - e.is_term = True - return e - - def eval_expr(self, e, eval_cache=None): - if eval_cache is None: - eval_cache = {} - r = e.visit(lambda x: self.eval_expr_visit(x, eval_cache)) - return r + ret = self.expr_simp(self.symbols[ret][:size]) + return ret + + + def apply_expr_on_state_visit_cache(self, expr, state, cache, level=0): + """ + Deep First evaluate nodes: + 1. evaluate node's sons + 2. simplify + """ + + #print '\t'*level, "Eval:", expr + if expr in cache: + ret = cache[expr] + #print "In cache!", ret + elif isinstance(expr, m2_expr.ExprInt): + return expr + elif isinstance(expr, m2_expr.ExprId): + if isinstance(expr.name, asmbloc.asm_label) and expr.name.offset is not None: + ret = m2_expr.ExprInt_from(expr, expr.name.offset) + else: + ret = state.get(expr, expr) + elif isinstance(expr, m2_expr.ExprMem): + ptr = self.apply_expr_on_state_visit_cache(expr.arg, state, cache, level+1) + ret = m2_expr.ExprMem(ptr, expr.size) + ret = self.get_mem_state(ret) + assert expr.size == ret.size + elif isinstance(expr, m2_expr.ExprCond): + cond = self.apply_expr_on_state_visit_cache(expr.cond, state, cache, level+1) + src1 = self.apply_expr_on_state_visit_cache(expr.src1, state, cache, level+1) + src2 = self.apply_expr_on_state_visit_cache(expr.src2, state, cache, level+1) + ret = m2_expr.ExprCond(cond, src1, src2) + elif isinstance(expr, m2_expr.ExprSlice): + arg = self.apply_expr_on_state_visit_cache(expr.arg, state, cache, level+1) + ret = m2_expr.ExprSlice(arg, expr.start, expr.stop) + elif isinstance(expr, m2_expr.ExprOp): + args = [] + for oarg in expr.args: + arg = self.apply_expr_on_state_visit_cache(oarg, state, cache, level+1) + assert oarg.size == arg.size + args.append(arg) + ret = m2_expr.ExprOp(expr.op, *args) + elif isinstance(expr, m2_expr.ExprCompose): + args = [] + for (arg, start, stop) in expr.args: + arg = self.apply_expr_on_state_visit_cache(arg, state, cache, level+1) + args.append((arg, start, stop)) + ret = m2_expr.ExprCompose(args) + else: + raise TypeError("Unknown expr type") + #print '\t'*level, "Result", ret + ret = self.expr_simp(ret) + #print '\t'*level, "Result simpl", ret + + assert expr.size == ret.size + cache[expr] = ret + return ret + + def apply_expr_on_state(self, expr, cache): + if cache is None: + cache = {} + ret = self.apply_expr_on_state_visit_cache(expr, self.symbols, cache) + return ret + + def eval_expr(self, expr, eval_cache=None): + return self.apply_expr_on_state(expr, eval_cache) def modified_regs(self, init_state=None): if init_state is None: @@ -250,121 +260,111 @@ class symbexec(object): yield i def modified_mems(self, init_state=None): + if init_state is None: + init_state = self.ir_arch.arch.regs.regs_init mems = self.symbols.symbols_mem.values() mems.sort() - for m, _ in mems: - yield m + for mem, _ in mems: + if mem in init_state and \ + mem in self.symbols.symbols_mem and \ + self.symbols.symbols_mem[mem] == init_state[mem]: + continue + yield mem def modified(self, init_state=None): - for r in self.modified_regs(init_state): - yield r - for m in self.modified_mems(init_state): - yield m + for reg in self.modified_regs(init_state): + yield reg + for mem in self.modified_mems(init_state): + yield mem def dump_id(self): + """ + Dump modififed registers symbols only + """ ids = self.symbols.symbols_id.keys() ids.sort() - for i in ids: - if i in self.ir_arch.arch.regs.regs_init and \ - i in self.symbols.symbols_id and \ - self.symbols.symbols_id[i] == self.ir_arch.arch.regs.regs_init[i]: + for expr in ids: + if (expr in self.ir_arch.arch.regs.regs_init and + expr in self.symbols.symbols_id and + self.symbols.symbols_id[expr] == self.ir_arch.arch.regs.regs_init[expr]): continue - print i, self.symbols.symbols_id[i] + print expr, "=", self.symbols.symbols_id[expr] def dump_mem(self): + """ + Dump modififed memory symbols + """ mems = self.symbols.symbols_mem.values() mems.sort() - for m, v in mems: - print m, v + for mem, value in mems: + print mem, value def rest_slice(self, slices, start, stop): - o = [] + """ + Return the complementary slices of @slices in the range @start, @stop + @slices: base slices + @start, @stop: interval range + """ + out = [] last = start - for _, a, b in slices: - if a == last: - last = b + for _, slice_start, slice_stop in slices: + if slice_start == last: + last = slice_stop continue - o.append((last, a)) - last = b + out.append((last, slice_start)) + last = slice_stop if last != stop: - o.append((b, stop)) - return o - - def substract_mems(self, a, b): - ex = b.arg - a.arg - ex = self.expr_simp(self.eval_expr(ex, {})) - if not isinstance(ex, m2_expr.ExprInt): - return None - ptr_diff = int(int32(ex.arg)) - out = [] - if ptr_diff < 0: - # [a ] - #[b ]XXX - sub_size = b.size + ptr_diff * 8 - if sub_size >= a.size: - pass - else: - ex = m2_expr.ExprOp('+', a.arg, - m2_expr.ExprInt_from(a.arg, sub_size / 8)) - ex = self.expr_simp(self.eval_expr(ex, {})) + out.append((slice_stop, stop)) + return out - rest_ptr = ex - rest_size = a.size - sub_size + def substract_mems(self, arg1, arg2): + """ + Return the remaining memory areas of @arg1 - @arg2 + @arg1, @arg2: ExprMem + """ - val = self.symbols[a][sub_size:a.size] - out = [(m2_expr.ExprMem(rest_ptr, rest_size), val)] - else: - #[a ] - # XXXX[b ]YY + ptr_diff = self.expr_simp(arg2.arg - arg1.arg) + ptr_diff = int(int32(ptr_diff.arg)) - #[a ] - # XXXX[b ] + zone1 = interval([(0, arg1.size/8-1)]) + zone2 = interval([(ptr_diff, ptr_diff + arg2.size/8-1)]) + zones = zone1 - zone2 + + out = [] + for start, stop in zones: + ptr = arg1.arg + m2_expr.ExprInt(start, arg1.arg.size) + ptr = self.expr_simp(ptr) + value = self.expr_simp(self.symbols[arg1][start*8:(stop+1)*8]) + mem = m2_expr.ExprMem(ptr, (stop - start + 1)*8) + assert mem.size == value.size + out.append((mem, value)) - out = [] - # part X - if ptr_diff > 0: - val = self.symbols[a][0:ptr_diff * 8] - out.append((m2_expr.ExprMem(a.arg, ptr_diff * 8), val)) - # part Y - if ptr_diff * 8 + b.size < a.size: - - ex = m2_expr.ExprOp('+', b.arg, - m2_expr.ExprInt_from(b.arg, b.size / 8)) - ex = self.expr_simp(self.eval_expr(ex, {})) - - rest_ptr = ex - rest_size = a.size - (ptr_diff * 8 + b.size) - val = self.symbols[a][ptr_diff * 8 + b.size:a.size] - out.append((m2_expr.ExprMem(ex, val.size), val)) return out - # give mem stored overlapping requested mem ptr - def get_mem_overlapping(self, e, eval_cache=None): - if eval_cache is None: - eval_cache = {} - if not isinstance(e, m2_expr.ExprMem): - raise ValueError('mem overlap bad arg') - ov = [] - # suppose max mem size is 64 bytes, compute all reachable addresses - to_test = [] - base_ptr = self.expr_simp(e.arg) - for i in xrange(-7, e.size / 8): - ex = self.expr_simp( - self.eval_expr(base_ptr + m2_expr.ExprInt_from(e.arg, i), - eval_cache)) - to_test.append((i, ex)) - - for i, x in to_test: - if not x in self.symbols.symbols_mem: + + def get_mem_overlapping(self, expr): + """ + Gives mem stored overlapping memory in @expr + Hypothesis: Max mem size is 64 bytes, compute all reachable addresses + @expr: target memory + """ + + overlaps = [] + base_ptr = self.expr_simp(expr.arg) + for i in xrange(-7, expr.size / 8): + new_ptr = base_ptr + m2_expr.ExprInt(i, expr.arg.size) + new_ptr = self.expr_simp(new_ptr) + + mem, origin = self.symbols.symbols_mem.get(new_ptr, (None, None)) + if mem is None: continue - ex = self.expr_simp(self.eval_expr(e.arg - x, eval_cache)) - if not isinstance(ex, m2_expr.ExprInt): - raise ValueError('ex is not ExprInt') - ptr_diff = int32(ex.arg) - if ptr_diff >= self.symbols.symbols_mem[x][1].size / 8: - # print "too long!" + + ptr_diff = -i + if ptr_diff >= origin.size / 8: + # access is too small to overlap the memory target continue - ov.append((i, self.symbols.symbols_mem[x][0])) - return ov + overlaps.append((i, mem)) + + return overlaps def eval_ir_expr(self, assignblk): """ @@ -372,16 +372,14 @@ class symbexec(object): @assignblk: AssignBlock instance """ pool_out = {} - - eval_cache = dict(self.symbols.items()) + eval_cache = {} for dst, src in assignblk.iteritems(): src = self.eval_expr(src, eval_cache) if isinstance(dst, m2_expr.ExprMem): - a = self.eval_expr(dst.arg, eval_cache) - a = self.expr_simp(a) + ptr = self.eval_expr(dst.arg, eval_cache) # test if mem lookup is known - tmp = m2_expr.ExprMem(a, dst.size) + tmp = m2_expr.ExprMem(ptr, dst.size) pool_out[tmp] = src elif isinstance(dst, m2_expr.ExprId): @@ -398,18 +396,18 @@ class symbexec(object): """ mem_dst = [] src_dst = self.eval_ir_expr(assignblk) - eval_cache = dict(self.symbols.items()) for dst, src in src_dst: if isinstance(dst, m2_expr.ExprMem): - mem_overlap = self.get_mem_overlapping(dst, eval_cache) + mem_overlap = self.get_mem_overlapping(dst) for _, base in mem_overlap: diff_mem = self.substract_mems(base, dst) del self.symbols[base] for new_mem, new_val in diff_mem: - new_val.is_term = True self.symbols[new_mem] = new_val src_o = self.expr_simp(src) self.symbols[dst] = src_o + if dst == src_o: + del self.symbols[dst] if isinstance(dst, m2_expr.ExprMem): if self.func_write and isinstance(dst.arg, m2_expr.ExprInt): self.func_write(self, dst, src_o) @@ -424,51 +422,52 @@ class symbexec(object): @step: display intermediate steps """ for assignblk in irb.irs: - self.eval_ir(assignblk) if step: + print 'Assignblk:' + print assignblk print '_' * 80 + self.eval_ir(assignblk) + if step: self.dump_id() - eval_cache = dict(self.symbols.items()) - return self.eval_expr(self.ir_arch.IRDst, eval_cache) + self.dump_mem() + print '_' * 80 + return self.eval_expr(self.ir_arch.IRDst) - def emul_ir_bloc(self, myir, ad, step=False): - b = myir.get_bloc(ad) - if b is not None: - ad = self.emulbloc(b, step=step) - return ad + def emul_ir_bloc(self, myir, addr, step=False): + irblock = myir.get_bloc(addr) + if irblock is not None: + addr = self.emulbloc(irblock, step=step) + return addr - def emul_ir_blocs(self, myir, ad, lbl_stop=None, step=False): + def emul_ir_blocs(self, myir, addr, lbl_stop=None, step=False): while True: - b = myir.get_bloc(ad) - if b is None: + irblock = myir.get_bloc(addr) + if irblock is None: break - if b.label == lbl_stop: + if irblock.label == lbl_stop: break - ad = self.emulbloc(b, step=step) - return ad - - def del_mem_above_stack(self, sp): - sp_val = self.symbols[sp] - for mem_ad, (mem, _) in self.symbols.symbols_mem.items(): - # print mem_ad, sp_val - diff = self.eval_expr(mem_ad - sp_val, {}) - diff = expr_simp(diff) + addr = self.emulbloc(irblock, step=step) + return addr + + def del_mem_above_stack(self, stack_ptr): + stack_ptr = self.eval_expr(stack_ptr) + for mem_addr, (mem, _) in self.symbols.symbols_mem.items(): + diff = self.expr_simp(mem_addr - stack_ptr) if not isinstance(diff, m2_expr.ExprInt): continue - m = expr_simp(diff.msb()) - if m.arg == 1: + sign_bit = self.expr_simp(diff.msb()) + if sign_bit.arg == 1: del self.symbols[mem] def apply_expr(self, expr): """Evaluate @expr and apply side effect if needed (ie. if expr is an assignment). Return the evaluated value""" - # Eval expression - to_eval = expr.src if isinstance(expr, m2_expr.ExprAff) else expr - ret = self.expr_simp(self.eval_expr(to_eval)) - # Update value if needed if isinstance(expr, m2_expr.ExprAff): - self.eval_ir(AssignBlock([m2_expr.ExprAff(expr.dst, ret)])) + ret = self.eval_expr(expr.src) + self.eval_ir(AssignBlock([expr])) + else: + ret = self.eval_expr(expr) return ret |