diff options
Diffstat (limited to 'miasm2/ir/ir.py')
| -rw-r--r-- | miasm2/ir/ir.py | 354 |
1 files changed, 215 insertions, 139 deletions
diff --git a/miasm2/ir/ir.py b/miasm2/ir/ir.py index fa34cd01..ffcf5480 100644 --- a/miasm2/ir/ir.py +++ b/miasm2/ir/ir.py @@ -18,7 +18,7 @@ # with this program; if not, write to the Free Software Foundation, Inc., # 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. # - +from itertools import chain import miasm2.expression.expression as m2_expr from miasm2.expression.expression_helper import get_missing_interval @@ -28,6 +28,135 @@ from miasm2.core.asmbloc import asm_symbol_pool, expr_is_label, asm_label, \ from miasm2.core.graph import DiGraph +class AssignBlock(dict): + + def __init__(self, irs=None): + """@irs seq""" + if irs is None: + irs = [] + super(AssignBlock, self).__init__() + + for expraff in irs: + # Concurrent assignments are handled in __setitem__ + self[expraff.dst] = expraff.src + + def __setitem__(self, dst, src): + """ + Special cases: + * if dst is an ExprSlice, expand it to affect the full Expression + * if dst already known, sources are merged + """ + + if dst.size != src.size: + raise RuntimeError( + "sanitycheck: args must have same size! %s" % + ([(str(arg), arg.size) for arg in [dst, src]])) + + if isinstance(dst, m2_expr.ExprSlice): + # Complete the source with missing slice parts + new_dst = dst.arg + rest = [(m2_expr.ExprSlice(dst.arg, r[0], r[1]), r[0], r[1]) + for r in dst.slice_rest()] + all_a = [(src, dst.start, dst.stop)] + rest + all_a.sort(key=lambda x: x[1]) + new_src = m2_expr.ExprCompose(all_a) + else: + new_dst, new_src = dst, src + + if new_dst in self and isinstance(new_src, m2_expr.ExprCompose): + if not isinstance(self[new_dst], m2_expr.ExprCompose): + # prev_RAX = 0x1122334455667788 + # input_RAX[0:8] = 0x89 + # final_RAX -> ? (assignment are in parallel) + raise RuntimeError("Concurent access on same bit not allowed") + + # Consider slice grouping + expr_list = [(new_dst, new_src), + (new_dst, self[new_dst])] + # Find collision + e_colision = reduce(lambda x, y: x.union(y), + (self.get_modified_slice(dst, src) + for (dst, src) in expr_list), + set()) + + # Sort interval collision + known_intervals = sorted([(x[1], x[2]) for x in e_colision]) + + for i, (_, stop) in enumerate(known_intervals[:-1]): + if stop > known_intervals[i + 1][0]: + raise RuntimeError( + "Concurent access on same bit not allowed") + + # Fill with missing data + missing_i = get_missing_interval(known_intervals, 0, new_dst.size) + remaining = ((m2_expr.ExprSlice(new_dst, *interval), + interval[0], + interval[1]) + for interval in missing_i) + + # Build the merging expression + new_src = m2_expr.ExprCompose(e_colision.union(remaining)) + + super(AssignBlock, self).__setitem__(new_dst, new_src) + + @staticmethod + def get_modified_slice(dst, src): + """Return an Expr list of extra expressions needed during the + object instanciation""" + + if not isinstance(src, m2_expr.ExprCompose): + raise ValueError("Get mod slice not on expraff slice", str(self)) + modified_s = [] + for arg in src.args: + if (not isinstance(arg[0], m2_expr.ExprSlice) or + arg[0].arg != dst or + arg[1] != arg[0].start or + arg[2] != arg[0].stop): + # If x is not the initial expression + modified_s.append(arg) + return modified_s + + def get_w(self): + """Return a set of elements written""" + return set(self.keys()) + + def get_rw(self, mem_read=False, cst_read=False): + """Return a dictionnary associating written expressions to a set of + their read requirements + @mem_read: (optional) mem_read argument of `get_r` + @cst_read: (optional) cst_read argument of `get_r` + """ + out = {} + for dst, src in self.iteritems(): + src_read = src.get_r(mem_read=mem_read, cst_read=cst_read) + if isinstance(dst, m2_expr.ExprMem): + # Read on destination happens only with ExprMem + src_read.update(dst.arg.get_r(mem_read=mem_read, + cst_read=cst_read)) + out[dst] = src_read + return out + + def get_r(self, mem_read=False, cst_read=False): + """Return a set of elements reads + @mem_read: (optional) mem_read argument of `get_r` + @cst_read: (optional) cst_read argument of `get_r` + """ + return set( + chain.from_iterable(self.get_rw(mem_read=mem_read, + cst_read=cst_read).itervalues())) + + def __str__(self): + out = [] + for dst, src in sorted(self.iteritems()): + out.append("%s = %s" % (dst, src)) + return "\n".join(out) + + def dst2ExprAff(self, dst): + """Return an ExprAff corresponding to @dst equation + @dst: Expr instance""" + return m2_expr.ExprAff(dst, self[dst]) + + class irbloc(object): def __init__(self, label, irs, lines=None): @@ -45,26 +174,29 @@ class irbloc(object): """Find the IRDst affectation and update dst, dst_linenb accordingly""" if self._dst is not None: return self._dst - dst = None - for linenb, ir in enumerate(self.irs): - for i in ir: - if isinstance(i.dst, m2_expr.ExprId) and i.dst.name == "IRDst": - if dst is not None: + final_dst = None + for linenb, assignblk in enumerate(self.irs): + for dst, src in assignblk.iteritems(): + if isinstance(dst, m2_expr.ExprId) and dst.name == "IRDst": + if final_dst is not None: raise ValueError('Multiple destinations!') - dst = i.src - self._dst = dst + final_dst = src + self._dst = final_dst self._dst_linenb = linenb - return dst + return final_dst def _set_dst(self, value): """Find and replace the IRDst affectation's source by @value""" if self._dst_linenb is None: self._get_dst() - ir = self.irs[self._dst_linenb] - for i, expr in enumerate(ir): - if isinstance(expr.dst, m2_expr.ExprId) and expr.dst.name == "IRDst": - ir[i] = m2_expr.ExprAff(expr.dst, value) + assignblk = self.irs[self._dst_linenb] + for dst in assignblk: + if isinstance(dst, m2_expr.ExprId) and dst.name == "IRDst": + del(assignblk[dst]) + assignblk[dst] = value + # Sanity check is already done in _get_dst + break self._dst = value dst = property(_get_dst, _set_dst) @@ -90,34 +222,32 @@ class irbloc(object): for _ in xrange(len(self.irs))] self.prev_kill = [{reg: set() for reg in regs_ids} for _ in xrange(len(self.irs))] + # LineNumber -> dict: + # Register: set(definition(irb label, index)) self.defout = [{reg: set() for reg in regs_ids} for _ in xrange(len(self.irs))] - - for k, ir in enumerate(self.irs): - r, w = set(), set() - for i in ir: - r.update(x for x in i.get_r(True) - if isinstance(x, m2_expr.ExprId)) - w.update(x for x in i.get_w() - if isinstance(x, m2_expr.ExprId)) - if isinstance(i.dst, m2_expr.ExprMem): - r.update(x for x in i.dst.arg.get_r(True) - if isinstance(x, m2_expr.ExprId)) - self.defout[k].update((x, {(self.label, k, i)}) - for x in i.get_w() - if isinstance(x, m2_expr.ExprId)) - self.r.append(r) - self.w.append(w) + keep_exprid = lambda elts: filter(lambda expr: isinstance(expr, + m2_expr.ExprId), + elts) + for idx, assignblk in enumerate(self.irs): + read, write = map(keep_exprid, + (assignblk.get_r(mem_read=True), + assignblk.get_w())) + + self.defout[idx].update({dst: set([(self.label, idx, dst)]) + for dst in assignblk + if isinstance(dst, m2_expr.ExprId)}) + self.r.append(read) + self.w.append(write) def __str__(self): - o = [] - o.append('%s' % self.label) - for expr in self.irs: - for e in expr: - o.append('\t%s' % e) - o.append("") - - return "\n".join(o) + out = [] + out.append('%s' % self.label) + for assignblk in self.irs: + for dst, src in assignblk.iteritems(): + out.append('\t%s = %s' % (dst, src)) + out.append("") + return "\n".join(out) class DiGraphIR(DiGraph): @@ -139,13 +269,14 @@ class DiGraphIR(DiGraph): if node not in self._blocks: yield [self.DotCellDescription(text="NOT PRESENT", attr={})] raise StopIteration - for i, exprs in enumerate(self._blocks[node].irs): - for expr in exprs: + for i, assignblk in enumerate(self._blocks[node].irs): + for dst, src in assignblk.iteritems(): + line = "%s = %s" % (dst, src) if self._dot_offset: yield [self.DotCellDescription(text="%-4d" % i, attr={}), - self.DotCellDescription(text=str(expr), attr={})] + self.DotCellDescription(text=line, attr={})] else: - yield self.DotCellDescription(text=str(expr), attr={}) + yield self.DotCellDescription(text=line, attr={}) yield self.DotCellDescription(text="", attr={}) def edge_attr(self, src, dst): @@ -190,9 +321,15 @@ class ir(object): # Lazy structure self._graph = None + def get_ir(self, instr): + raise NotImplementedError("Abstract Method") + def instr2ir(self, l): - ir_bloc_cur, ir_blocs_extra = self.get_ir(l) - return ir_bloc_cur, ir_blocs_extra + ir_bloc_cur, extra_assignblk = self.get_ir(l) + assignblk = AssignBlock(ir_bloc_cur) + for irb in extra_assignblk: + irb.irs = map(AssignBlock, irb.irs) + return assignblk, extra_assignblk def get_label(self, ad): """Transforms an ExprId/ExprInt/label/int into a label @@ -221,62 +358,6 @@ class ir(object): b.lines = [l] self.add_bloc(b, gen_pc_updt) - def merge_multi_affect(self, affect_list): - """ - If multiple affection to a same ExprId are present in @affect_list, - merge them (in place). - For instance, XCGH AH, AL semantic is - [ - RAX = {RAX[0:8],0,8, RAX[0:8],8,16, RAX[16:64],16,64} - RAX = {RAX[8:16],0,8, RAX[8:64],8,64} - ] - This function will update @affect_list to replace previous ExprAff by - [ - RAX = {RAX[8:16],0,8, RAX[0:8],8,16, RAX[16:64],16,64} - ] - """ - - # Extract side effect - effect = {} - for expr in affect_list: - effect[expr.dst] = effect.get(expr.dst, []) + [expr] - - # Find candidates - for dst, expr_list in effect.items(): - if len(expr_list) <= 1: - continue - - # Only treat ExprCompose list - if any(map(lambda e: not(isinstance(e.src, m2_expr.ExprCompose)), - expr_list)): - continue - - # Find collision - e_colision = reduce(lambda x, y: x.union(y), - (e.get_modified_slice() for e in expr_list), - set()) - # Sort interval collision - known_intervals = sorted([(x[1], x[2]) for x in e_colision]) - - # Fill with missing data - missing_i = get_missing_interval(known_intervals, 0, dst.size) - - remaining = ((m2_expr.ExprSlice(dst, *interval), - interval[0], - interval[1]) - for interval in missing_i) - - # Build the merging expression - slices = sorted(e_colision.union(remaining), key=lambda x: x[1]) - final_dst = m2_expr.ExprCompose(slices) - - # Remove unused expression - for expr in expr_list: - affect_list.remove(expr) - - # Add the merged one - affect_list.append(m2_expr.ExprAff(dst, final_dst)) - def getby_offset(self, offset): out = set() for irb in self.blocs.values(): @@ -286,8 +367,9 @@ class ir(object): return out def gen_pc_update(self, c, l): - c.irs.append([m2_expr.ExprAff(self.pc, m2_expr.ExprInt_from(self.pc, - l.offset))]) + c.irs.append(AssignBlock([m2_expr.ExprAff(self.pc, + m2_expr.ExprInt_from(self.pc, + l.offset))])) c.lines.append(l) def add_bloc(self, bloc, gen_pc_updt=False): @@ -298,12 +380,12 @@ class ir(object): label = self.get_instr_label(l) c = irbloc(label, [], []) ir_blocs_all.append(c) - ir_bloc_cur, ir_blocs_extra = self.instr2ir(l) + assignblk, ir_blocs_extra = self.instr2ir(l) if gen_pc_updt is not False: self.gen_pc_update(c, l) - c.irs.append(ir_bloc_cur) + c.irs.append(assignblk) c.lines.append(l) if ir_blocs_extra: @@ -337,23 +419,15 @@ class ir(object): continue dst = m2_expr.ExprId(self.get_next_label(bloc.lines[-1]), self.pc.size) - b.irs.append([m2_expr.ExprAff(self.IRDst, dst)]) + b.irs.append(AssignBlock([m2_expr.ExprAff(self.IRDst, dst)])) b.lines.append(b.lines[-1]) - def gen_edges(self, bloc, ir_blocs): - pass - def post_add_bloc(self, bloc, ir_blocs): self.set_empty_dst_to_next(bloc, ir_blocs) - self.gen_edges(bloc, ir_blocs) for irb in ir_blocs: self.irbloc_fix_regs_for_mode(irb, self.attrib) - # Detect multi-affectation - for affect_list in irb.irs: - self.merge_multi_affect(affect_list) - self.blocs[irb.label] = irb # Forget graph if any @@ -375,15 +449,17 @@ class ir(object): return l def simplify_blocs(self): - for b in self.blocs.values(): - for ir in b.irs: - for i, r in enumerate(ir): - ir[i] = m2_expr.ExprAff(expr_simp(r.dst), expr_simp(r.src)) + for irb in self.blocs.values(): + for assignblk in irb.irs: + for dst, src in assignblk.items(): + del assignblk[dst] + assignblk[expr_simp(dst)] = expr_simp(src) def replace_expr_in_ir(self, bloc, rep): - for irs in bloc.irs: - for i, l in enumerate(irs): - irs[i] = l.replace_expr(rep) + for assignblk in bloc.irs: + for dst, src in assignblk.items(): + del assignblk[dst] + assignblk[dst.replace_expr(rep)] = src.replace_expr(rep) def get_rw(self, regs_ids=None): """ @@ -395,7 +471,11 @@ class ir(object): for b in self.blocs.values(): b.get_rw(regs_ids) - def sort_dst(self, todo, done): + def _extract_dst(self, todo, done): + """ + Naive extraction of @todo destinations + WARNING: @todo and @done are modified + """ out = set() while todo: dst = todo.pop() @@ -412,30 +492,26 @@ class ir(object): done.add(dst) return out - def dst_trackback(self, b): - dst = b.dst - todo = set([dst]) + def dst_trackback(self, irb): + """ + Naive backtracking of IRDst + @irb: irbloc instance + """ + todo = set([irb.dst]) done = set() - for irs in reversed(b.irs): - if len(todo) == 0: + for assignblk in reversed(irb.irs): + if not todo: break - out = self.sort_dst(todo, done) + out = self._extract_dst(todo, done) found = set() follow = set() - for i in irs: - if not out: - break - for o in out: - if i.dst == o: - follow.add(i.src) - found.add(o) - for o in found: - out.remove(o) - - for o in out: - if o not in found: - follow.add(o) + for dst in out: + if dst in assignblk: + follow.add(assignblk[dst]) + found.add(dst) + + follow.update(out.difference(found)) todo = follow return done |