about summary refs log tree commit diff stats
path: root/miasm2/ir/ir.py
diff options
context:
space:
mode:
Diffstat (limited to 'miasm2/ir/ir.py')
-rw-r--r-- miasm2/ir/ir.py 354
1 files changed, 215 insertions, 139 deletions
diff --git a/miasm2/ir/ir.py b/miasm2/ir/ir.py
index fa34cd01..ffcf5480 100644
--- a/miasm2/ir/ir.py
+++ b/miasm2/ir/ir.py
@@ -18,7 +18,7 @@
 # with this program; if not, write to the Free Software Foundation, Inc.,
 # 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 #
-
+from itertools import chain
 
 import miasm2.expression.expression as m2_expr
 from miasm2.expression.expression_helper import get_missing_interval
@@ -28,6 +28,135 @@ from miasm2.core.asmbloc import asm_symbol_pool, expr_is_label, asm_label, \
 from miasm2.core.graph import DiGraph
 
 
+class AssignBlock(dict):
+
+    def __init__(self, irs=None):
+        """@irs: (optional) sequence of ExprAff to store as dst -> src"""
+        if irs is None:
+            irs = []
+        super(AssignBlock, self).__init__()
+
+        for expraff in irs:
+            # Concurrent assignments are handled in __setitem__
+            self[expraff.dst] = expraff.src
+
+    def __setitem__(self, dst, src):
+        """
+        Assign @src to @dst (RuntimeError on size mismatch or bit overlap):
+        * if dst is an ExprSlice, expand it to affect the full Expression
+        * if dst is already known and the source is an ExprCompose, merge both
+        """
+
+        if dst.size != src.size:
+            raise RuntimeError(
+                "sanitycheck: args must have same size! %s" %
+                ([(str(arg), arg.size) for arg in [dst, src]]))
+
+        if isinstance(dst, m2_expr.ExprSlice):
+            # Complete the source with missing slice parts
+            new_dst = dst.arg
+            rest = [(m2_expr.ExprSlice(dst.arg, r[0], r[1]), r[0], r[1])
+                    for r in dst.slice_rest()]
+            all_a = [(src, dst.start, dst.stop)] + rest
+            all_a.sort(key=lambda x: x[1])
+            new_src = m2_expr.ExprCompose(all_a)
+        else:
+            new_dst, new_src = dst, src
+
+        if new_dst in self and isinstance(new_src, m2_expr.ExprCompose):
+            if not isinstance(self[new_dst], m2_expr.ExprCompose):
+                # prev_RAX = 0x1122334455667788
+                # input_RAX[0:8] = 0x89
+                # final_RAX -> ? (assignments are in parallel)
+                raise RuntimeError("Concurent access on same bit not allowed")
+
+            # Consider slice grouping
+            expr_list = [(new_dst, new_src),
+                         (new_dst, self[new_dst])]
+            # Find collision
+            e_colision = reduce(lambda x, y: x.union(y),
+                                (self.get_modified_slice(dst, src)
+                                 for (dst, src) in expr_list),
+                                set())
+
+            # Sort interval collision
+            known_intervals = sorted([(x[1], x[2]) for x in e_colision])
+
+            for i, (_, stop) in enumerate(known_intervals[:-1]):
+                if stop > known_intervals[i + 1][0]:
+                    raise RuntimeError(
+                        "Concurent access on same bit not allowed")
+
+            # Fill with missing data
+            missing_i = get_missing_interval(known_intervals, 0, new_dst.size)
+            remaining = ((m2_expr.ExprSlice(new_dst, *interval),
+                          interval[0],
+                          interval[1])
+                         for interval in missing_i)
+
+            # Build the merging expression
+            new_src = m2_expr.ExprCompose(e_colision.union(remaining))
+
+        super(AssignBlock, self).__setitem__(new_dst, new_src)
+
+    @staticmethod
+    def get_modified_slice(dst, src):
+        """Return the args of @src (an ExprCompose) which are not the
+        unmodified slice of @dst at the same position"""
+
+        if not isinstance(src, m2_expr.ExprCompose):
+            raise ValueError("Get mod slice not on expraff slice", str(src))
+        modified_s = []
+        for arg in src.args:
+            if (not isinstance(arg[0], m2_expr.ExprSlice) or
+                    arg[0].arg != dst or
+                    arg[1] != arg[0].start or
+                    arg[2] != arg[0].stop):
+                # arg is not the initial expression slice
+                modified_s.append(arg)
+        return modified_s
+
+    def get_w(self):
+        """Return a set of elements written"""
+        return set(self.keys())
+
+    def get_rw(self, mem_read=False, cst_read=False):
+        """Return a dictionary associating written expressions to a set of
+        their read requirements
+        @mem_read: (optional) mem_read argument of `get_r`
+        @cst_read: (optional) cst_read argument of `get_r`
+        """
+        out = {}
+        for dst, src in self.iteritems():
+            src_read = src.get_r(mem_read=mem_read, cst_read=cst_read)
+            if isinstance(dst, m2_expr.ExprMem):
+                # Read on destination happens only with ExprMem
+                src_read.update(dst.arg.get_r(mem_read=mem_read,
+                                              cst_read=cst_read))
+            out[dst] = src_read
+        return out
+
+    def get_r(self, mem_read=False, cst_read=False):
+        """Return a set of elements read
+        @mem_read: (optional) mem_read argument of `get_r`
+        @cst_read: (optional) cst_read argument of `get_r`
+        """
+        return set(
+            chain.from_iterable(self.get_rw(mem_read=mem_read,
+                                            cst_read=cst_read).itervalues()))
+
+    def __str__(self):
+        out = []
+        for dst, src in sorted(self.iteritems()):
+            out.append("%s = %s" % (dst, src))
+        return "\n".join(out)
+
+    def dst2ExprAff(self, dst):
+        """Return an ExprAff corresponding to @dst equation
+        @dst: Expr instance"""
+        return m2_expr.ExprAff(dst, self[dst])
+
+
 class irbloc(object):
 
     def __init__(self, label, irs, lines=None):
@@ -45,26 +174,29 @@ class irbloc(object):
         """Find the IRDst affectation and update dst, dst_linenb accordingly"""
         if self._dst is not None:
             return self._dst
-        dst = None
-        for linenb, ir in enumerate(self.irs):
-            for i in ir:
-                if isinstance(i.dst, m2_expr.ExprId) and i.dst.name == "IRDst":
-                    if dst is not None:
+        final_dst = None
+        for linenb, assignblk in enumerate(self.irs):
+            for dst, src in assignblk.iteritems():
+                if isinstance(dst, m2_expr.ExprId) and dst.name == "IRDst":
+                    if final_dst is not None:
                         raise ValueError('Multiple destinations!')
-                    dst = i.src
-        self._dst = dst
+                    final_dst = src
+        self._dst = final_dst
         self._dst_linenb = linenb
-        return dst
+        return final_dst
 
     def _set_dst(self, value):
         """Find and replace the IRDst affectation's source by @value"""
         if self._dst_linenb is None:
             self._get_dst()
 
-        ir = self.irs[self._dst_linenb]
-        for i, expr in enumerate(ir):
-            if isinstance(expr.dst, m2_expr.ExprId) and expr.dst.name == "IRDst":
-                ir[i] = m2_expr.ExprAff(expr.dst, value)
+        assignblk = self.irs[self._dst_linenb]
+        for dst in assignblk:
+            if isinstance(dst, m2_expr.ExprId) and dst.name == "IRDst":
+                del(assignblk[dst])
+                assignblk[dst] = value
+                # Sanity check is already done in _get_dst
+                break
         self._dst = value
 
     dst = property(_get_dst, _set_dst)
@@ -90,34 +222,32 @@ class irbloc(object):
                          for _ in xrange(len(self.irs))]
         self.prev_kill = [{reg: set() for reg in regs_ids}
                           for _ in xrange(len(self.irs))]
+        # LineNumber -> dict:
+        #               Register: set(definition(irb label, index))
         self.defout = [{reg: set() for reg in regs_ids}
                        for _ in xrange(len(self.irs))]
-
-        for k, ir in enumerate(self.irs):
-            r, w = set(), set()
-            for i in ir:
-                r.update(x for x in i.get_r(True)
-                         if isinstance(x, m2_expr.ExprId))
-                w.update(x for x in i.get_w()
-                         if isinstance(x, m2_expr.ExprId))
-                if isinstance(i.dst, m2_expr.ExprMem):
-                    r.update(x for x in i.dst.arg.get_r(True)
-                             if isinstance(x, m2_expr.ExprId))
-                self.defout[k].update((x, {(self.label, k, i)})
-                                      for x in i.get_w()
-                                      if isinstance(x, m2_expr.ExprId))
-            self.r.append(r)
-            self.w.append(w)
+        keep_exprid = lambda elts: filter(lambda expr: isinstance(expr,
+                                                                  m2_expr.ExprId),
+                                          elts)
+        for idx, assignblk in enumerate(self.irs):
+            read, write = map(keep_exprid,
+                              (assignblk.get_r(mem_read=True),
+                               assignblk.get_w()))
+
+            self.defout[idx].update({dst: set([(self.label, idx, dst)])
+                                     for dst in assignblk
+                                     if isinstance(dst, m2_expr.ExprId)})
+            self.r.append(read)
+            self.w.append(write)
 
     def __str__(self):
-        o = []
-        o.append('%s' % self.label)
-        for expr in self.irs:
-            for e in expr:
-                o.append('\t%s' % e)
-            o.append("")
-
-        return "\n".join(o)
+        out = []
+        out.append('%s' % self.label)
+        for assignblk in self.irs:
+            for dst, src in assignblk.iteritems():
+                out.append('\t%s = %s' % (dst, src))
+            out.append("")
+        return "\n".join(out)
 
 
 class DiGraphIR(DiGraph):
@@ -139,13 +269,14 @@ class DiGraphIR(DiGraph):
         if node not in self._blocks:
             yield [self.DotCellDescription(text="NOT PRESENT", attr={})]
             raise StopIteration
-        for i, exprs in enumerate(self._blocks[node].irs):
-            for expr in exprs:
+        for i, assignblk in enumerate(self._blocks[node].irs):
+            for dst, src in assignblk.iteritems():
+                line = "%s = %s" % (dst, src)
                 if self._dot_offset:
                     yield [self.DotCellDescription(text="%-4d" % i, attr={}),
-                           self.DotCellDescription(text=str(expr), attr={})]
+                           self.DotCellDescription(text=line, attr={})]
                 else:
-                    yield self.DotCellDescription(text=str(expr), attr={})
+                    yield self.DotCellDescription(text=line, attr={})
             yield self.DotCellDescription(text="", attr={})
 
     def edge_attr(self, src, dst):
@@ -190,9 +321,15 @@ class ir(object):
         # Lazy structure
         self._graph = None
 
+    def get_ir(self, instr):
+        raise NotImplementedError("Abstract Method")
+
     def instr2ir(self, l):
-        ir_bloc_cur, ir_blocs_extra = self.get_ir(l)
-        return ir_bloc_cur, ir_blocs_extra
+        ir_bloc_cur, extra_assignblk = self.get_ir(l)
+        assignblk = AssignBlock(ir_bloc_cur)
+        for irb in extra_assignblk:
+            irb.irs = map(AssignBlock, irb.irs)
+        return assignblk, extra_assignblk
 
     def get_label(self, ad):
         """Transforms an ExprId/ExprInt/label/int into a label
@@ -221,62 +358,6 @@ class ir(object):
         b.lines = [l]
         self.add_bloc(b, gen_pc_updt)
 
-    def merge_multi_affect(self, affect_list):
-        """
-        If multiple affection to a same ExprId are present in @affect_list,
-        merge them (in place).
-        For instance, XCGH AH, AL semantic is
-        [
-            RAX = {RAX[0:8],0,8, RAX[0:8],8,16, RAX[16:64],16,64}
-            RAX = {RAX[8:16],0,8, RAX[8:64],8,64}
-        ]
-        This function will update @affect_list to replace previous ExprAff by
-        [
-            RAX = {RAX[8:16],0,8, RAX[0:8],8,16, RAX[16:64],16,64}
-        ]
-        """
-
-        # Extract side effect
-        effect = {}
-        for expr in affect_list:
-            effect[expr.dst] = effect.get(expr.dst, []) + [expr]
-
-        # Find candidates
-        for dst, expr_list in effect.items():
-            if len(expr_list) <= 1:
-                continue
-
-            # Only treat ExprCompose list
-            if any(map(lambda e: not(isinstance(e.src, m2_expr.ExprCompose)),
-                       expr_list)):
-                continue
-
-            # Find collision
-            e_colision = reduce(lambda x, y: x.union(y),
-                                (e.get_modified_slice() for e in expr_list),
-                                set())
-            # Sort interval collision
-            known_intervals = sorted([(x[1], x[2]) for x in e_colision])
-
-            # Fill with missing data
-            missing_i = get_missing_interval(known_intervals, 0, dst.size)
-
-            remaining = ((m2_expr.ExprSlice(dst, *interval),
-                          interval[0],
-                          interval[1])
-                         for interval in missing_i)
-
-            # Build the merging expression
-            slices = sorted(e_colision.union(remaining), key=lambda x: x[1])
-            final_dst = m2_expr.ExprCompose(slices)
-
-            # Remove unused expression
-            for expr in expr_list:
-                affect_list.remove(expr)
-
-            # Add the merged one
-            affect_list.append(m2_expr.ExprAff(dst, final_dst))
-
     def getby_offset(self, offset):
         out = set()
         for irb in self.blocs.values():
@@ -286,8 +367,9 @@ class ir(object):
         return out
 
     def gen_pc_update(self, c, l):
-        c.irs.append([m2_expr.ExprAff(self.pc, m2_expr.ExprInt_from(self.pc,
-                                                                    l.offset))])
+        c.irs.append(AssignBlock([m2_expr.ExprAff(self.pc,
+                                                  m2_expr.ExprInt_from(self.pc,
+                                                                       l.offset))]))
         c.lines.append(l)
 
     def add_bloc(self, bloc, gen_pc_updt=False):
@@ -298,12 +380,12 @@ class ir(object):
                 label = self.get_instr_label(l)
                 c = irbloc(label, [], [])
                 ir_blocs_all.append(c)
-            ir_bloc_cur, ir_blocs_extra = self.instr2ir(l)
+            assignblk, ir_blocs_extra = self.instr2ir(l)
 
             if gen_pc_updt is not False:
                 self.gen_pc_update(c, l)
 
-            c.irs.append(ir_bloc_cur)
+            c.irs.append(assignblk)
             c.lines.append(l)
 
             if ir_blocs_extra:
@@ -337,23 +419,15 @@ class ir(object):
                 continue
             dst = m2_expr.ExprId(self.get_next_label(bloc.lines[-1]),
                                  self.pc.size)
-            b.irs.append([m2_expr.ExprAff(self.IRDst, dst)])
+            b.irs.append(AssignBlock([m2_expr.ExprAff(self.IRDst, dst)]))
             b.lines.append(b.lines[-1])
 
-    def gen_edges(self, bloc, ir_blocs):
-        pass
-
     def post_add_bloc(self, bloc, ir_blocs):
         self.set_empty_dst_to_next(bloc, ir_blocs)
-        self.gen_edges(bloc, ir_blocs)
 
         for irb in ir_blocs:
             self.irbloc_fix_regs_for_mode(irb, self.attrib)
 
-            # Detect multi-affectation
-            for affect_list in irb.irs:
-                self.merge_multi_affect(affect_list)
-
             self.blocs[irb.label] = irb
 
         # Forget graph if any
@@ -375,15 +449,17 @@ class ir(object):
         return l
 
     def simplify_blocs(self):
-        for b in self.blocs.values():
-            for ir in b.irs:
-                for i, r in enumerate(ir):
-                    ir[i] = m2_expr.ExprAff(expr_simp(r.dst), expr_simp(r.src))
+        for irb in self.blocs.values():
+            for assignblk in irb.irs:
+                for dst, src in assignblk.items():
+                    del assignblk[dst]
+                    assignblk[expr_simp(dst)] = expr_simp(src)
 
     def replace_expr_in_ir(self, bloc, rep):
-        for irs in bloc.irs:
-            for i, l in enumerate(irs):
-                irs[i] = l.replace_expr(rep)
+        for assignblk in bloc.irs:
+            for dst, src in assignblk.items():
+                del assignblk[dst]
+                assignblk[dst.replace_expr(rep)] = src.replace_expr(rep)
 
     def get_rw(self, regs_ids=None):
         """
@@ -395,7 +471,11 @@ class ir(object):
         for b in self.blocs.values():
             b.get_rw(regs_ids)
 
-    def sort_dst(self, todo, done):
+    def _extract_dst(self, todo, done):
+        """
+        Naive extraction of @todo destinations
+        WARNING: @todo and @done are modified
+        """
         out = set()
         while todo:
             dst = todo.pop()
@@ -412,30 +492,26 @@ class ir(object):
                 done.add(dst)
         return out
 
-    def dst_trackback(self, b):
-        dst = b.dst
-        todo = set([dst])
+    def dst_trackback(self, irb):
+        """
+        Naive backtracking of IRDst
+        @irb: irbloc instance
+        """
+        todo = set([irb.dst])
         done = set()
 
-        for irs in reversed(b.irs):
-            if len(todo) == 0:
+        for assignblk in reversed(irb.irs):
+            if not todo:
                 break
-            out = self.sort_dst(todo, done)
+            out = self._extract_dst(todo, done)
             found = set()
             follow = set()
-            for i in irs:
-                if not out:
-                    break
-                for o in out:
-                    if i.dst == o:
-                        follow.add(i.src)
-                        found.add(o)
-                for o in found:
-                    out.remove(o)
-
-            for o in out:
-                if o not in found:
-                    follow.add(o)
+            for dst in out:
+                if dst in assignblk:
+                    follow.add(assignblk[dst])
+                    found.add(dst)
+
+            follow.update(out.difference(found))
             todo = follow
 
         return done