about summary refs log tree commit diff stats
path: root/miasm2/ir/analysis.py
diff options
context:
space:
mode:
Diffstat (limited to 'miasm2/ir/analysis.py')
-rw-r--r--miasm2/ir/analysis.py153
1 files changed, 84 insertions, 69 deletions
diff --git a/miasm2/ir/analysis.py b/miasm2/ir/analysis.py
index 40a3bf64..2aa853f0 100644
--- a/miasm2/ir/analysis.py
+++ b/miasm2/ir/analysis.py
@@ -4,9 +4,9 @@
 import logging
 
 from miasm2.ir.symbexec import symbexec
-from miasm2.ir.ir import ir
+from miasm2.ir.ir import ir, AssignBlock
 from miasm2.expression.expression \
-    import ExprAff, ExprCond, ExprId, ExprInt, ExprMem
+    import ExprAff, ExprCond, ExprId, ExprInt, ExprMem, ExprOp
 
 log = logging.getLogger("analysis")
 console_handler = logging.StreamHandler()
@@ -29,6 +29,17 @@ class ira(ir):
         """Returns ids of all registers used in the IR"""
         return self.arch.regs.all_regs_ids + [self.IRDst]
 
+    def call_effects(self, ad):
+        """
+        Default simulation of a function call to @ad
+        @ad: (Expr) address of the called function
+        """
+        return [AssignBlock(
+            [ExprAff(self.ret_reg, ExprOp('call_func_ret', ad, self.sp)),
+             ExprAff(self.sp, ExprOp(
+                 'call_func_stack', ad, self.sp)),
+             ])]
+
     def remove_dead_instr(self, irb, useful):
         """Remove dead affectations using previous reaches analysis
         @irb: irbloc instance
@@ -37,16 +48,12 @@ class ira(ir):
         PRE: compute_reach(self)
         """
         modified = False
-        for k, ir in enumerate(irb.irs):
-            j = 0
-            while j < len(ir):
-                cur_instr = ir[j]
-                if (isinstance(cur_instr.dst, ExprId)
-                    and (irb.label, k, cur_instr) not in useful):
-                    del ir[j]
+        for idx, assignblk in enumerate(irb.irs):
+            for dst in assignblk.keys():
+                if (isinstance(dst, ExprId) and
+                        (irb.label, idx, dst) not in useful):
+                    del assignblk[dst]
                     modified = True
-                else:
-                    j += 1
         return modified
 
     def init_useful_instr(self):
@@ -73,25 +80,25 @@ class ira(ir):
                     # Leaf has lost its son: don't remove anything
                     # reaching this block
                     for r in self.ira_regs_ids():
-                        useful.update(block.cur_reach[-1][r].union(
-                                block.defout[-1][r]))
+                        useful.update(block.irs[-1]._cur_reach[r].union(
+                            block.irs[-1].defout[r]))
 
             # Function call, memory write or IRDst affectation
-            for k, ir in enumerate(block.irs):
-                for i_cur in ir:
-                    if i_cur.src.is_function_call():
+            for idx, assignblk in enumerate(block.irs):
+                for dst, src in assignblk.iteritems():
+                    if src.is_function_call():
                         # /!\ never remove ir calls
-                        useful.add((block.label, k, i_cur))
-                    if isinstance(i_cur.dst, ExprMem):
-                        useful.add((block.label, k, i_cur))
-                    useful.update(block.defout[k][self.IRDst])
+                        useful.add((block.label, idx, dst))
+                    if isinstance(dst, ExprMem):
+                        useful.add((block.label, idx, dst))
+                    useful.update(block.irs[idx].defout[self.IRDst])
 
             # Affecting return registers
             if not has_son:
                 for r in self.get_out_regs(block):
-                    useful.update(block.defout[-1][r]
-                                  if block.defout[-1][r] else
-                                  block.cur_reach[-1][r])
+                    useful.update(block.irs[-1].defout[r]
+                                  if block.irs[-1].defout[r] else
+                                  block.irs[-1]._cur_reach[r])
 
         return useful
 
@@ -112,24 +119,23 @@ class ira(ir):
         while worklist:
             elem = worklist.pop()
             useful.add(elem)
-            irb, irs_ind, ins = elem
+            irb_label, irs_ind, dst = elem
 
-            block = self.blocs[irb]
-            instr_defout = block.defout[irs_ind]
-            cur_kill = block.cur_kill[irs_ind]
-            cur_reach = block.cur_reach[irs_ind]
+            assignblk = self.blocs[irb_label].irs[irs_ind]
+            ins = assignblk.dst2ExprAff(dst)
 
             # Handle dependencies of used variables in ins
             for reg in ins.get_r(True).intersection(self.ira_regs_ids()):
                 worklist.update(
-                    cur_reach[reg].difference(useful).difference(
-                        cur_kill[reg]
-                        if not instr_defout[reg] else
+                    assignblk._cur_reach[reg].difference(useful).difference(
+                        assignblk._cur_kill[reg]
+                        if not assignblk.defout[reg] else
                         set()))
-                for _, _, i in instr_defout[reg]:
-                    # Loop case (i in defout of current block)
-                    if i == ins:
-                        worklist.update(cur_reach[reg].difference(useful))
+                for _, _, defout_dst in assignblk.defout[reg]:
+                    # Loop case (dst in defout of current irb)
+                    if defout_dst == dst:
+                        worklist.update(
+                            assignblk._cur_reach[reg].difference(useful))
         return useful
 
     def remove_dead_code(self):
@@ -142,6 +148,12 @@ class ira(ir):
         modified = False
         for block in self.blocs.values():
             modified |= self.remove_dead_instr(block, useful)
+            # Remove useless structures
+            for assignblk in block.irs:
+                del assignblk._cur_kill
+                del assignblk._prev_kill
+                del assignblk._cur_reach
+                del assignblk._prev_reach
         return modified
 
     def set_dead_regs(self, b):
@@ -159,22 +171,22 @@ class ira(ir):
             print '    (%s, %s, %s)' % p
 
     def dump_bloc_state(self, irb):
-        print '*'*80
-        for k, irs in enumerate(irb.irs):
-            for i in xrange(len(irs)):
-                print 5*"-"
-                print 'instr', k, irs[i]
-                print 5*"-"
+        print '*' * 80
+        for irs in irb.irs:
+            for assignblk in irs:
+                print 5 * "-"
+                print 'instr', assignblk
+                print 5 * "-"
                 for v in self.ira_regs_ids():
-                    if irb.cur_reach[k][v]:
+                    if assignblk._cur_reach[v]:
                         print 'REACH[%d][%s]' % (k, v)
-                        self.print_set(irb.cur_reach[k][v])
-                    if irb.cur_kill[k][v]:
+                        self.print_set(assignblk._cur_reach[v])
+                    if assignblk._cur_kill[v]:
                         print 'KILL[%d][%s]' % (k, v)
-                        self.print_set(irb.cur_kill[k][v])
-                    if irb.defout[k][v]:
+                        self.print_set(assignblk._cur_kill[v])
+                    if assignblk.defout[v]:
                         print 'DEFOUT[%d][%s]' % (k, v)
-                        self.print_set(irb.defout[k][v])
+                        self.print_set(assignblk.defout[v])
 
     def compute_reach_block(self, irb):
         """Variable influence computation for a single block
@@ -183,7 +195,7 @@ class ira(ir):
         """
 
         reach_block = {key: value.copy()
-                      for key, value in irb.cur_reach[0].iteritems()}
+                       for key, value in irb.irs[0]._cur_reach.iteritems()}
 
         # Compute reach from predecessors
         for n_pred in self.graph.predecessors(irb.label):
@@ -192,33 +204,33 @@ class ira(ir):
             # Handle each register definition
             for c_reg in self.ira_regs_ids():
                 # REACH(n) = U[p in pred] DEFOUT(p) U REACH(p)\KILL(p)
-                pred_through = p_block.defout[-1][c_reg].union(
-                    p_block.cur_reach[-1][c_reg].difference(
-                        p_block.cur_kill[-1][c_reg]))
+                pred_through = p_block.irs[-1].defout[c_reg].union(
+                    p_block.irs[-1]._cur_reach[c_reg].difference(
+                        p_block.irs[-1]._cur_kill[c_reg]))
                 reach_block[c_reg].update(pred_through)
 
         # If a predecessor has changed
-        if reach_block != irb.cur_reach[0]:
-            irb.cur_reach[0] = reach_block
+        if reach_block != irb.irs[0]._cur_reach:
+            irb.irs[0]._cur_reach = reach_block
             for c_reg in self.ira_regs_ids():
-                if irb.defout[0][c_reg]:
+                if irb.irs[0].defout[c_reg]:
                     # KILL(n) = DEFOUT(n) ? REACH(n)\DEFOUT(n) : EMPTY
-                    irb.cur_kill[0][c_reg].update(
-                        reach_block[c_reg].difference(irb.defout[0][c_reg]))
+                    irb.irs[0]._cur_kill[c_reg].update(
+                        reach_block[c_reg].difference(irb.irs[0].defout[c_reg]))
 
         # Compute reach and kill for block's instructions
         for i in xrange(1, len(irb.irs)):
             for c_reg in self.ira_regs_ids():
                 # REACH(n) = U[p in pred] DEFOUT(p) U REACH(p)\KILL(p)
-                pred_through = irb.defout[i - 1][c_reg].union(
-                    irb.cur_reach[i - 1][c_reg].difference(
-                        irb.cur_kill[i - 1][c_reg]))
-                irb.cur_reach[i][c_reg].update(pred_through)
-                if irb.defout[i][c_reg]:
+                pred_through = irb.irs[i - 1].defout[c_reg].union(
+                    irb.irs[i - 1]._cur_reach[c_reg].difference(
+                        irb.irs[i - 1]._cur_kill[c_reg]))
+                irb.irs[i]._cur_reach[c_reg].update(pred_through)
+                if irb.irs[i].defout[c_reg]:
                     # KILL(n) = DEFOUT(n) ? REACH(n)\DEFOUT(n) : EMPTY
-                    irb.cur_kill[i][c_reg].update(
-                        irb.cur_reach[i][c_reg].difference(
-                            irb.defout[i][c_reg]))
+                    irb.irs[i]._cur_kill[c_reg].update(
+                        irb.irs[i]._cur_reach[c_reg].difference(
+                            irb.irs[i].defout[c_reg]))
 
     def _test_kill_reach_fix(self):
         """Return True iff a fixed point has been reached during reach
@@ -228,11 +240,14 @@ class ira(ir):
         for node in self.graph.nodes():
             if node in self.blocs:
                 irb = self.blocs[node]
-                if (irb.cur_reach != irb.prev_reach or
-                    irb.cur_kill != irb.prev_kill):
-                    fixed = False
-                    irb.prev_reach = irb.cur_reach[:]
-                    irb.prev_kill = irb.cur_kill[:]
+                for assignblk in irb.irs:
+                    if (assignblk._cur_reach != assignblk._prev_reach or
+                            assignblk._cur_kill != assignblk._prev_kill):
+                        fixed = False
+                        # This is not a deepcopy, but cur_reach is assigned to a
+                        # new dictionary on change in `compute_reach_block`
+                        assignblk._prev_reach = assignblk._cur_reach.copy()
+                        assignblk._prev_kill = assignblk._cur_kill.copy()
         return fixed
 
     def compute_reach(self):