about summary refs log tree commit diff stats
diff options
context:
space:
mode:
authorCamille Mougey <camille.mougey@cea.fr>2016-02-26 14:11:14 +0100
committerFabrice Desclaux <fabrice.desclaux@cea.fr>2016-02-26 15:53:53 +0100
commitf0ed13ea3d1a7bc0255c366ca31d38591c5a1aad (patch)
tree2258ac5df571524b40fae4e86cc1a91e511d90cc
parent75271c4e1f1917eee58ce71aeaf4bd6acb228ebf (diff)
downloadmiasm-f0ed13ea3d1a7bc0255c366ca31d38591c5a1aad.tar.gz
miasm-f0ed13ea3d1a7bc0255c366ca31d38591c5a1aad.zip
Move dead_simp structures into AssignBlock
-rw-r--r--miasm2/ir/analysis.py102
-rw-r--r--miasm2/ir/ir.py29
2 files changed, 59 insertions, 72 deletions
diff --git a/miasm2/ir/analysis.py b/miasm2/ir/analysis.py
index 9cb15811..c52b10ed 100644
--- a/miasm2/ir/analysis.py
+++ b/miasm2/ir/analysis.py
@@ -80,8 +80,8 @@ class ira(ir):
                     # Leaf has lost its son: don't remove anything
                     # reaching this block
                     for r in self.ira_regs_ids():
-                        useful.update(block.cur_reach[-1][r].union(
-                            block.defout[-1][r]))
+                        useful.update(block.irs[-1].cur_reach[r].union(
+                            block.irs[-1].defout[r]))
 
             # Function call, memory write or IRDst affectation
             for idx, assignblk in enumerate(block.irs):
@@ -91,14 +91,14 @@ class ira(ir):
                         useful.add((block.label, idx, dst))
                     if isinstance(dst, ExprMem):
                         useful.add((block.label, idx, dst))
-                    useful.update(block.defout[idx][self.IRDst])
+                    useful.update(block.irs[idx].defout[self.IRDst])
 
             # Affecting return registers
             if not has_son:
                 for r in self.get_out_regs(block):
-                    useful.update(block.defout[-1][r]
-                                  if block.defout[-1][r] else
-                                  block.cur_reach[-1][r])
+                    useful.update(block.irs[-1].defout[r]
+                                  if block.irs[-1].defout[r] else
+                                  block.irs[-1].cur_reach[r])
 
         return useful
 
@@ -119,25 +119,22 @@ class ira(ir):
         while worklist:
             elem = worklist.pop()
             useful.add(elem)
-            irb, irs_ind, dst = elem
+            irb_label, irs_ind, dst = elem
 
-            irb = self.blocs[irb]
-            ins = irb.irs[irs_ind].dst2ExprAff(dst)
-            instr_defout = irb.defout[irs_ind]
-            cur_kill = irb.cur_kill[irs_ind]
-            cur_reach = irb.cur_reach[irs_ind]
+            assignblk = self.blocs[irb_label].irs[irs_ind]
+            ins = assignblk.dst2ExprAff(dst)
 
             # Handle dependencies of used variables in ins
             for reg in ins.get_r(True).intersection(self.ira_regs_ids()):
                 worklist.update(
-                    cur_reach[reg].difference(useful).difference(
-                        cur_kill[reg]
-                        if not instr_defout[reg] else
+                    assignblk.cur_reach[reg].difference(useful).difference(
+                        assignblk.cur_kill[reg]
+                        if not assignblk.defout[reg] else
                         set()))
-                for _, _, defout_dst in instr_defout[reg]:
+                for _, _, defout_dst in assignblk.defout[reg]:
                     # Loop case (dst in defout of current irb)
                     if defout_dst == dst:
-                        worklist.update(cur_reach[reg].difference(useful))
+                        worklist.update(assignblk.cur_reach[reg].difference(useful))
         return useful
 
     def remove_dead_code(self):
@@ -167,22 +164,22 @@ class ira(ir):
             print '    (%s, %s, %s)' % p
 
     def dump_bloc_state(self, irb):
-        print '*'*80
-        for k, irs in enumerate(irb.irs):
-            for i in xrange(len(irs)):
-                print 5*"-"
-                print 'instr', k, irs[i]
-                print 5*"-"
+        print '*' * 80
+        for irs in irb.irs:
+            for assignblk in irs:
+                print 5 * "-"
+                print 'instr', assignblk
+                print 5 * "-"
                 for v in self.ira_regs_ids():
-                    if irb.cur_reach[k][v]:
+                    if assignblk.cur_reach[v]:
                         print 'REACH[%d][%s]' % (k, v)
-                        self.print_set(irb.cur_reach[k][v])
-                    if irb.cur_kill[k][v]:
+                        self.print_set(assignblk.cur_reach[v])
+                    if assignblk.cur_kill[v]:
                         print 'KILL[%d][%s]' % (k, v)
-                        self.print_set(irb.cur_kill[k][v])
-                    if irb.defout[k][v]:
+                        self.print_set(assignblk.cur_kill[v])
+                    if assignblk.defout[v]:
                         print 'DEFOUT[%d][%s]' % (k, v)
-                        self.print_set(irb.defout[k][v])
+                        self.print_set(assignblk.defout[v])
 
     def compute_reach_block(self, irb):
         """Variable influence computation for a single block
@@ -191,7 +188,7 @@ class ira(ir):
         """
 
         reach_block = {key: value.copy()
-                      for key, value in irb.cur_reach[0].iteritems()}
+                      for key, value in irb.irs[0].cur_reach.iteritems()}
 
         # Compute reach from predecessors
         for n_pred in self.graph.predecessors(irb.label):
@@ -200,33 +197,33 @@ class ira(ir):
             # Handle each register definition
             for c_reg in self.ira_regs_ids():
                 # REACH(n) = U[p in pred] DEFOUT(p) U REACH(p)\KILL(p)
-                pred_through = p_block.defout[-1][c_reg].union(
-                    p_block.cur_reach[-1][c_reg].difference(
-                        p_block.cur_kill[-1][c_reg]))
+                pred_through = p_block.irs[-1].defout[c_reg].union(
+                    p_block.irs[-1].cur_reach[c_reg].difference(
+                        p_block.irs[-1].cur_kill[c_reg]))
                 reach_block[c_reg].update(pred_through)
 
         # If a predecessor has changed
-        if reach_block != irb.cur_reach[0]:
-            irb.cur_reach[0] = reach_block
+        if reach_block != irb.irs[0].cur_reach:
+            irb.irs[0].cur_reach = reach_block
             for c_reg in self.ira_regs_ids():
-                if irb.defout[0][c_reg]:
+                if irb.irs[0].defout[c_reg]:
                     # KILL(n) = DEFOUT(n) ? REACH(n)\DEFOUT(n) : EMPTY
-                    irb.cur_kill[0][c_reg].update(
-                        reach_block[c_reg].difference(irb.defout[0][c_reg]))
+                    irb.irs[0].cur_kill[c_reg].update(
+                        reach_block[c_reg].difference(irb.irs[0].defout[c_reg]))
 
         # Compute reach and kill for block's instructions
         for i in xrange(1, len(irb.irs)):
             for c_reg in self.ira_regs_ids():
                 # REACH(n) = U[p in pred] DEFOUT(p) U REACH(p)\KILL(p)
-                pred_through = irb.defout[i - 1][c_reg].union(
-                    irb.cur_reach[i - 1][c_reg].difference(
-                        irb.cur_kill[i - 1][c_reg]))
-                irb.cur_reach[i][c_reg].update(pred_through)
-                if irb.defout[i][c_reg]:
+                pred_through = irb.irs[i - 1].defout[c_reg].union(
+                    irb.irs[i - 1].cur_reach[c_reg].difference(
+                        irb.irs[i - 1].cur_kill[c_reg]))
+                irb.irs[i].cur_reach[c_reg].update(pred_through)
+                if irb.irs[i].defout[c_reg]:
                     # KILL(n) = DEFOUT(n) ? REACH(n)\DEFOUT(n) : EMPTY
-                    irb.cur_kill[i][c_reg].update(
-                        irb.cur_reach[i][c_reg].difference(
-                            irb.defout[i][c_reg]))
+                    irb.irs[i].cur_kill[c_reg].update(
+                        irb.irs[i].cur_reach[c_reg].difference(
+                            irb.irs[i].defout[c_reg]))
 
     def _test_kill_reach_fix(self):
         """Return True iff a fixed point has been reached during reach
@@ -236,11 +233,14 @@ class ira(ir):
         for node in self.graph.nodes():
             if node in self.blocs:
                 irb = self.blocs[node]
-                if (irb.cur_reach != irb.prev_reach or
-                    irb.cur_kill != irb.prev_kill):
-                    fixed = False
-                    irb.prev_reach = irb.cur_reach[:]
-                    irb.prev_kill = irb.cur_kill[:]
+                for assignblk in irb.irs:
+                    if (assignblk.cur_reach != assignblk.prev_reach or
+                        assignblk.cur_kill != assignblk.prev_kill):
+                        fixed = False
+                        # This is not a deepcopy, but cur_reach is assigned to a
+                        # new dictionnary on change in `compute_reach_block`
+                        assignblk.prev_reach = assignblk.cur_reach.copy()
+                        assignblk.prev_kill = assignblk.cur_kill.copy()
         return fixed
 
     def compute_reach(self):
diff --git a/miasm2/ir/ir.py b/miasm2/ir/ir.py
index ffcf5480..6265faeb 100644
--- a/miasm2/ir/ir.py
+++ b/miasm2/ir/ir.py
@@ -212,33 +212,20 @@ class irbloc(object):
         Initialize attributes needed for in/out and reach computation.
         @regs_ids : ids of registers used in IR
         """
-        self.r = []
-        self.w = []
-        self.cur_reach = [{reg: set() for reg in regs_ids}
-                          for _ in xrange(len(self.irs))]
-        self.prev_reach = [{reg: set() for reg in regs_ids}
-                           for _ in xrange(len(self.irs))]
-        self.cur_kill = [{reg: set() for reg in regs_ids}
-                         for _ in xrange(len(self.irs))]
-        self.prev_kill = [{reg: set() for reg in regs_ids}
-                          for _ in xrange(len(self.irs))]
-        # LineNumber -> dict:
-        #               Register: set(definition(irb label, index))
-        self.defout = [{reg: set() for reg in regs_ids}
-                       for _ in xrange(len(self.irs))]
         keep_exprid = lambda elts: filter(lambda expr: isinstance(expr,
                                                                   m2_expr.ExprId),
                                           elts)
         for idx, assignblk in enumerate(self.irs):
-            read, write = map(keep_exprid,
-                              (assignblk.get_r(mem_read=True),
-                               assignblk.get_w()))
-
-            self.defout[idx].update({dst: set([(self.label, idx, dst)])
+            assignblk.cur_reach = {reg: set() for reg in regs_ids}
+            assignblk.prev_reach = {reg: set() for reg in regs_ids}
+            assignblk.cur_kill = {reg: set() for reg in regs_ids}
+            assignblk.prev_kill = {reg: set() for reg in regs_ids}
+            # LineNumber -> dict:
+            #               Register: set(definition(irb label, index))
+            assignblk.defout = {reg: set() for reg in regs_ids}
+            assignblk.defout.update({dst: set([(self.label, idx, dst)])
                                      for dst in assignblk
                                      if isinstance(dst, m2_expr.ExprId)})
-            self.r.append(read)
-            self.w.append(write)
 
     def __str__(self):
         out = []