about summary refs log tree commit diff stats
path: root/miasm2/ir/ir.py
diff options
context:
space:
mode:
Diffstat (limited to 'miasm2/ir/ir.py')
-rw-r--r-- miasm2/ir/ir.py 235
1 files changed, 153 insertions, 82 deletions
diff --git a/miasm2/ir/ir.py b/miasm2/ir/ir.py
index f8ac6722..caf2c0ed 100644
--- a/miasm2/ir/ir.py
+++ b/miasm2/ir/ir.py
@@ -29,19 +29,41 @@ from miasm2.core.asmblock import AsmSymbolPool, expr_is_label, AsmLabel, \
 from miasm2.core.graph import DiGraph
 
 
-class AssignBlock(dict):
+class AssignBlock(object):
+    """Represent parallel IR assignment, such as:
+    EAX = EBX
+    EBX = EAX
 
-    def __init__(self, irs=None):
-        """@irs seq"""
+    Also provides common manipulations on these assignments
+    """
+    __slots__ = ["_assigns", "_instr"]
+
+    def __init__(self, irs=None, instr=None):
+        """Create a new AssignBlock
+        @irs: (optional) sequence of ExprAff, or dictionary dst (Expr) -> src
+              (Expr)
+        @instr: (optional) associate an instruction with this AssignBlock
+
+        """
         if irs is None:
             irs = []
-        super(AssignBlock, self).__init__()
+        self._instr = instr
+        self._assigns = {} # ExprAff.dst -> ExprAff.src
 
-        for expraff in irs:
-            # Concurrent assignments are handled in __setitem__
-            self[expraff.dst] = expraff.src
+        # Concurrent assignments are handled in _set
+        if hasattr(irs, "iteritems"):
+            for dst, src in irs.iteritems():
+                self._set(dst, src)
+        else:
+            for expraff in irs:
+                self._set(expraff.dst, expraff.src)
 
-    def __setitem__(self, dst, src):
+    @property
+    def instr(self):
+        """Return the associated instruction, if any"""
+        return self._instr
+
+    def _set(self, dst, src):
         """
         Special cases:
         * if dst is an ExprSlice, expand it to affect the full Expression
@@ -64,7 +86,7 @@ class AssignBlock(dict):
         else:
             new_dst, new_src = dst, src
 
-        if new_dst in self and isinstance(new_src, m2_expr.ExprCompose):
+        if new_dst in self._assigns and isinstance(new_src, m2_expr.ExprCompose):
             if not isinstance(self[new_dst], m2_expr.ExprCompose):
                 # prev_RAX = 0x1122334455667788
                 # input_RAX[0:8] = 0x89
@@ -103,7 +125,51 @@ class AssignBlock(dict):
             args = [expr for (expr, _, _) in args]
             new_src = m2_expr.ExprCompose(*args)
 
-        super(AssignBlock, self).__setitem__(new_dst, new_src)
+        self._assigns[new_dst] = new_src
+
+    def __setitem__(self, dst, src):
+        raise RuntimeError('AssignBlock is immutable')
+
+    def __getitem__(self, key):
+        return self._assigns[key]
+
+    def __contains__(self, key):
+        return key in self._assigns
+
+    def iteritems(self):
+        for dst, src in self._assigns.iteritems():
+            yield dst, src
+
+    def itervalues(self):
+        for src in self._assigns.itervalues():
+            yield src
+
+    def keys(self):
+        return self._assigns.keys()
+
+    def values(self):
+        return self._assigns.values()
+
+    def __iter__(self):
+        for dst in self._assigns:
+            yield dst
+
+    def __delitem__(self, _):
+        raise RuntimeError('AssignBlock is immutable')
+
+    def update(self, _):
+        raise RuntimeError('AssignBlock is immutable')
+
+    def __eq__(self, other):
+        if self.keys() != other.keys():
+            return False
+        return all(other[dst] == src for dst, src in self.iteritems())
+
+    def __len__(self):
+        return len(self._assigns)
+
+    def get(self, key, default):
+        return self._assigns.get(key, default)
 
     @staticmethod
     def get_modified_slice(dst, src):
@@ -152,7 +218,7 @@ class AssignBlock(dict):
 
     def __str__(self):
         out = []
-        for dst, src in sorted(self.iteritems()):
+        for dst, src in sorted(self._assigns.iteritems()):
             out.append("%s = %s" % (dst, src))
         return "\n".join(out)
 
@@ -161,6 +227,18 @@ class AssignBlock(dict):
         @dst: Expr instance"""
         return m2_expr.ExprAff(dst, self[dst])
 
+    def simplify(self, simplifier):
+        """Return a new AssignBlock with expression simplified
+        @simplifier: ExpressionSimplifier instance"""
+        new_assignblk = {}
+        for dst, src in self.iteritems():
+            if dst == src:
+                continue
+            src = simplifier(src)
+            dst = simplifier(dst)
+            new_assignblk[dst] = src
+        return AssignBlock(irs=new_assignblk, instr=self.instr)
+
 
 class IRBlock(object):
     """Intermediate representation block object.
@@ -168,19 +246,15 @@ class IRBlock(object):
     Stand for an intermediate representation  basic block.
     """
 
-    def __init__(self, label, irs, lines=None):
+    def __init__(self, label, irs):
         """
         @label: AsmLabel of the IR basic block
         @irs: list of AssignBlock
-        @lines: list of native instructions
         """
 
         assert isinstance(label, AsmLabel)
-        if lines is None:
-            lines = []
         self.label = label
         self.irs = irs
-        self.lines = lines
         self.except_automod = True
         self._dst = None
         self._dst_linenb = None
@@ -207,15 +281,15 @@ class IRBlock(object):
         if self._dst_linenb is None:
             self._get_dst()
 
-        assignblk = self.irs[self._dst_linenb]
-        for dst in assignblk:
+        new_assignblk = dict(self.irs[self._dst_linenb])
+        for dst in new_assignblk:
             if isinstance(dst, m2_expr.ExprId) and dst.name == "IRDst":
-                del assignblk[dst]
-                assignblk[dst] = value
+                new_assignblk[dst] = value
                 # Sanity check is already done in _get_dst
                 break
         self._dst = value
-
+        instr = self.irs[self._dst_linenb].instr
+        self.irs[self._dst_linenb] = AssignBlock(new_assignblk, instr)
     dst = property(_get_dst, _set_dst)
 
     @property
@@ -262,7 +336,7 @@ class irbloc(IRBlock):
 
     def __init__(self, label, irs, lines=None):
         warnings.warn('DEPRECATION WARNING: use "IRBlock" instead of "irblock"')
-        super(irbloc, self).__init__(label, irs, lines)
+        super(irbloc, self).__init__(label, irs)
 
 
 class DiGraphIR(DiGraph):
@@ -349,54 +423,55 @@ class IntermediateRepresentation(object):
     def get_ir(self, instr):
         raise NotImplementedError("Abstract Method")
 
-    def instr2ir(self, l):
-        ir_bloc_cur, extra_assignblk = self.get_ir(l)
-        assignblk = AssignBlock(ir_bloc_cur)
+    def instr2ir(self, instr):
+        ir_bloc_cur, extra_assignblk = self.get_ir(instr)
         for irb in extra_assignblk:
-            irb.irs = map(AssignBlock, irb.irs)
+            irs = []
+            for assignblk in irb.irs:
+                irs.append(AssignBlock(assignblk, instr))
+            irb.irs = irs
+        assignblk = AssignBlock(ir_bloc_cur, instr)
         return assignblk, extra_assignblk
 
-    def get_label(self, ad):
+    def get_label(self, addr):
         """Transforms an ExprId/ExprInt/label/int into a label
-        @ad: an ExprId/ExprInt/label/int"""
-
-        if (isinstance(ad, m2_expr.ExprId) and
-                isinstance(ad.name, AsmLabel)):
-            ad = ad.name
-        if isinstance(ad, m2_expr.ExprInt):
-            ad = int(ad)
-        if isinstance(ad, (int, long)):
-            ad = self.symbol_pool.getby_offset_create(ad)
-        elif isinstance(ad, AsmLabel):
-            ad = self.symbol_pool.getby_name_create(ad.name)
-        return ad
-
-    def get_bloc(self, ad):
+        @addr: an ExprId/ExprInt/label/int"""
+
+        if (isinstance(addr, m2_expr.ExprId) and
+                isinstance(addr.name, AsmLabel)):
+            addr = addr.name
+        if isinstance(addr, m2_expr.ExprInt):
+            addr = int(addr)
+        if isinstance(addr, (int, long)):
+            addr = self.symbol_pool.getby_offset_create(addr)
+        elif isinstance(addr, AsmLabel):
+            addr = self.symbol_pool.getby_name_create(addr.name)
+        return addr
+
+    def get_bloc(self, addr):
         """Returns the irbloc associated to an ExprId/ExprInt/label/int
-        @ad: an ExprId/ExprInt/label/int"""
+        @addr: an ExprId/ExprInt/label/int"""
 
-        label = self.get_label(ad)
+        label = self.get_label(addr)
         return self.blocks.get(label, None)
 
-    def add_instr(self, l, ad=0, gen_pc_updt=False):
-        b = AsmBlock(self.gen_label())
-        b.lines = [l]
-        self.add_bloc(b, gen_pc_updt)
+    def add_instr(self, line, addr=0, gen_pc_updt=False):
+        block = AsmBlock(self.gen_label())
+        block.lines = [line]
+        self.add_bloc(block, gen_pc_updt)
 
     def getby_offset(self, offset):
         out = set()
         for irb in self.blocks.values():
-            for l in irb.lines:
-                if l.offset <= offset < l.offset + l.l:
+            for assignblk in irb.irs:
+                instr = assignblk.instr
+                if instr.offset <= offset < instr.offset + instr.l:
                     out.add(irb)
         return out
 
-    def gen_pc_update(self, c, l):
-        c.irs.append(AssignBlock([m2_expr.ExprAff(self.pc,
-                                                  m2_expr.ExprInt(l.offset,
-                                                                  self.pc.size)
-                                                 )]))
-        c.lines.append(l)
+    def gen_pc_update(self, irblock, instr):
+        irblock.irs.append(AssignBlock({self.pc: m2_expr.ExprInt(instr.offset, self.pc.size)},
+                                       instr))
 
     def pre_add_instr(self, block, instr, irb_cur, ir_blocks_all, gen_pc_updt):
         """Function called before adding an instruction from the the native @block to
@@ -439,11 +514,8 @@ class IntermediateRepresentation(object):
             self.gen_pc_update(irb_cur, instr)
 
         irb_cur.irs.append(assignblk)
-        irb_cur.lines.append(instr)
 
         if ir_blocks_extra:
-            for irblock in ir_blocks_extra:
-                irblock.lines = [instr] * len(irblock.irs)
             ir_blocks_all += ir_blocks_extra
             irb_cur = None
         return irb_cur
@@ -460,28 +532,28 @@ class IntermediateRepresentation(object):
         for instr in block.lines:
             if irb_cur is None:
                 label = self.get_instr_label(instr)
-                irb_cur = IRBlock(label, [], [])
+                irb_cur = IRBlock(label, [])
                 ir_blocks_all.append(irb_cur)
             irb_cur = self.add_instr_to_irblock(block, instr, irb_cur,
                                                 ir_blocks_all, gen_pc_updt)
         self.post_add_bloc(block, ir_blocks_all)
         return ir_blocks_all
 
-    def expr_fix_regs_for_mode(self, e, *args, **kwargs):
-        return e
+    def expr_fix_regs_for_mode(self, expr, *args, **kwargs):
+        return expr
 
-    def expraff_fix_regs_for_mode(self, e, *args, **kwargs):
-        return e
+    def expraff_fix_regs_for_mode(self, expr, *args, **kwargs):
+        return expr
 
     def irbloc_fix_regs_for_mode(self, irbloc, *args, **kwargs):
         return
 
-    def is_pc_written(self, b):
+    def is_pc_written(self, block):
         all_pc = self.arch.pc.values()
-        for irs in b.irs:
-            for ir in irs:
-                if ir.dst in all_pc:
-                    return ir
+        for irs in block.irs:
+            for assignblk in irs:
+                if assignblk.dst in all_pc:
+                    return assignblk
         return None
 
     def set_empty_dst_to_next(self, block, ir_blocks):
@@ -495,8 +567,8 @@ class IntermediateRepresentation(object):
             else:
                 dst = m2_expr.ExprId(next_lbl,
                                      self.pc.size)
-            irblock.irs.append(AssignBlock([m2_expr.ExprAff(self.IRDst, dst)]))
-            irblock.lines.append(irblock.lines[-1])
+            irblock.irs.append(AssignBlock({self.IRDst: dst},
+                                           irblock.irs[-1].instr))
 
     def post_add_bloc(self, block, ir_blocks):
         self.set_empty_dst_to_next(block, ir_blocks)
@@ -516,12 +588,12 @@ class IntermediateRepresentation(object):
 
     def gen_label(self):
         # TODO: fix hardcoded offset
-        l = self.symbol_pool.gen_label()
-        return l
+        label = self.symbol_pool.gen_label()
+        return label
 
     def get_next_label(self, instr):
-        l = self.symbol_pool.getby_offset_create(instr.offset + instr.l)
-        return l
+        label = self.symbol_pool.getby_offset_create(instr.offset + instr.l)
+        return label
 
     def simplify_blocs(self):
         for irblock in self.blocks.values():
@@ -596,15 +668,14 @@ class IntermediateRepresentation(object):
         Gen irbloc digraph
         """
         self._graph = DiGraphIR(self.blocks)
-        for lbl, b in self.blocks.iteritems():
+        for lbl, block in self.blocks.iteritems():
             self._graph.add_node(lbl)
-            dst = self.dst_trackback(b)
-            for d in dst:
-                if isinstance(d, m2_expr.ExprInt):
-                    d = m2_expr.ExprId(
-                        self.symbol_pool.getby_offset_create(int(d)))
-                if expr_is_label(d):
-                    self._graph.add_edge(lbl, d.name)
+            for dst in self.dst_trackback(block):
+                if dst.is_int():
+                    dst_lbl = self.symbol_pool.getby_offset_create(int(dst))
+                    dst = m2_expr.ExprId(dst_lbl)
+                if expr_is_label(dst):
+                    self._graph.add_edge(lbl, dst.name)
 
     @property
     def graph(self):