about summary refs log tree commit diff stats
path: root/miasm2/ir/analysis.py
diff options
context:
space:
mode:
Diffstat (limited to 'miasm2/ir/analysis.py')
-rw-r--r--miasm2/ir/analysis.py350
1 files changed, 202 insertions, 148 deletions
diff --git a/miasm2/ir/analysis.py b/miasm2/ir/analysis.py
index ad37a1df..87c53d44 100644
--- a/miasm2/ir/analysis.py
+++ b/miasm2/ir/analysis.py
@@ -6,7 +6,7 @@ import logging
 from miasm2.ir.symbexec import symbexec
 from miasm2.core.graph import DiGraph
 from miasm2.expression.expression \
-    import ExprAff, ExprCond, ExprId, ExprInt, ExprMem, ExprOp
+    import ExprAff, ExprCond, ExprId, ExprInt, ExprMem
 
 log = logging.getLogger("analysis")
 console_handler = logging.StreamHandler()
@@ -16,6 +16,10 @@ log.setLevel(logging.WARNING)
 
 class ira:
 
+    def ira_regs_ids(self):
+        """Returns ids of all registers used in the IR"""
+        return self.arch.regs.all_regs_ids + [self.IRDst]
+
     def sort_dst(self, todo, done):
         out = set()
         while todo:
@@ -115,194 +119,244 @@ class ira:
         out += '}'
         return out
 
-    def remove_dead(self, irb):
-        """Remove dead affectations using previous liveness analysis
+    def remove_dead_instr(self, irb, useful):
+        """Remove dead affectations using previous reaches analysis
         @irb: irbloc instance
-        Return True iff the bloc state has changed
-        PRE: compute_in_out(@irb)
+        @useful: useful statements from previous reach analysis
+        Return True iff the block state has changed
+        PRE: compute_reach(self)
         """
-
-        # print 'state1'
-        # self.dump_bloc_state(irb)
-
         modified = False
-        for ir, _, c_out in zip(irb.irs, irb.c_in, irb.c_out):
+        for k, ir in enumerate(irb.irs):
             j = 0
             while j < len(ir):
-                i_cur = ir[j]
-                if not isinstance(i_cur.dst, ExprId):
-                    pass
-                elif i_cur.dst == self.IRDst:
-                    # never delete irdst
-                    pass
-                elif (isinstance(i_cur.src, ExprOp) and
-                    i_cur.src.op.startswith('call')):
-                    # /!\ never remove ir calls
-                    pass
-                elif i_cur.dst not in c_out:
-                    del(ir[j])
+                cur_instr = ir[j]
+                if (isinstance(cur_instr.dst, ExprId)
+                    and (irb.label, k, cur_instr) not in useful):
+                    del ir[j]
                     modified = True
-                    continue
-                j += 1
+                else:
+                    j += 1
+        return modified
 
-        # print 'state2'
-        # self.dump_bloc_state(irb)
+    def init_useful_instr(self):
+        """Computes a set of triples (block, instruction number, instruction)
+        containing initially useful instructions :
+          - Instructions affecting final value of return registers
+          - Instructions affecting IRDst register
+          - Instructions writing in memory
+          - Function call instructions
+        Return set of intial useful instructions
+        """
 
-        return modified
+        useful = set()
+
+        for node in self.g.nodes():
+            if node not in self.blocs:
+                continue
+
+            block = self.blocs[node]
+            successors = self.g.successors(node)
+            has_son = bool(successors)
+            for p_son in successors:
+                if p_son not in self.blocs:
+                    # Leaf has lost its son: don't remove anything
+                    # reaching this block
+                    for r in self.ira_regs_ids():
+                        useful.update(block.cur_reach[-1][r].union(
+                                block.defout[-1][r]))
+
+            # Function call, memory write or IRDst affectation
+            for k, ir in enumerate(block.irs):
+                for i_cur in ir:
+                    if i_cur.is_function_call():
+                        # /!\ never remove ir calls
+                        useful.add((block.label, k, i_cur))
+                    if isinstance(i_cur.dst, ExprMem):
+                        useful.add((block.label, k, i_cur))
+                    useful.update(block.defout[k][self.IRDst])
+
+            # Affecting return registers
+            if not has_son:
+                for r in self.get_out_regs(block):
+                    useful.update(block.defout[-1][r]
+                                  if block.defout[-1][r] else
+                                  block.cur_reach[-1][r])
+
+        return useful
+
+    def _mark_useful_code(self):
+        """Mark useful statements using previous reach analysis
+
+        Source : Kennedy, K. (1979). A survey of data flow analysis techniques.
+        IBM Thomas J. Watson Research Division,  Algorithm MK
+
+        Return a set of triplets (block, instruction number, instruction) of
+        useful instructions
+        PRE: compute_reach(self)
 
-    def remove_blocs_dead(self):
-        """Call remove_dead on each irbloc
-        Return True iff one of the bloc state has changed
         """
+
+        useful = self.init_useful_instr()
+        worklist = useful.copy()
+        while worklist:
+            elem = worklist.pop()
+            useful.add(elem)
+            irb, irs_ind, ins = elem
+
+            block = self.blocs[irb]
+            instr_defout = block.defout[irs_ind]
+            cur_kill = block.cur_kill[irs_ind]
+            cur_reach = block.cur_reach[irs_ind]
+
+            # Handle dependencies of used variables in ins
+            for reg in ins.get_r(True).intersection(self.ira_regs_ids()):
+                worklist.update(
+                    cur_reach[reg].difference(useful).difference(
+                        cur_kill[reg]
+                        if not instr_defout[reg] else
+                        set()))
+                for _, _, i in instr_defout[reg]:
+                    # Loop case (i in defout of current block)
+                    if i == ins:
+                        worklist.update(cur_reach[reg].difference(useful))
+        return useful
+
+    def remove_dead_code(self):
+        """Remove dead instructions in each block of the graph using the reach
+        analysis .
+        Returns True if a block has been modified
+        PRE : compute_reach(self)
+        """
+        useful = self._mark_useful_code()
         modified = False
-        for b in self.blocs.values():
-            modified |= self.remove_dead(b)
+        for block in self.blocs.values():
+            modified |= self.remove_dead_instr(block, useful)
         return modified
 
-    # for test XXX TODO
     def set_dead_regs(self, b):
         pass
 
     def add_unused_regs(self):
         pass
 
+    @staticmethod
+    def print_set(v_set):
+        """Print each triplet contained in a set
+        @v_set: set containing triplets elements
+        """
+        for p in v_set:
+            print '    (%s, %s, %s)' % p
+
     def dump_bloc_state(self, irb):
         print '*'*80
-        for i, (ir, c_in, c_out) in enumerate(zip(irb.irs, irb.c_in, irb.c_out)):
-            print 'ir'
-            for x in ir:
-                print '\t', x
-            print 'R', [str(x) for x in irb.r[i]]#c_in]
-            print 'W', [str(x) for x in irb.w[i]]#c_out]
-            print 'IN', [str(x) for x in c_in]
-            print 'OUT', [str(x) for x in c_out]
-
-
-    def compute_in_out(self, irb):
-        """Liveness computation for a single bloc
+        for k, irs in enumerate(irb.irs):
+            for i in xrange(len(irs)):
+                print 5*"-"
+                print 'instr', k, irs[i]
+                print 5*"-"
+                for v in self.ira_regs_ids():
+                    if irb.cur_reach[k][v]:
+                        print 'REACH[%d][%s]' % (k, v)
+                        self.print_set(irb.cur_reach[k][v])
+                    if irb.cur_kill[k][v]:
+                        print 'KILL[%d][%s]' % (k, v)
+                        self.print_set(irb.cur_kill[k][v])
+                    if irb.defout[k][v]:
+                        print 'DEFOUT[%d][%s]' % (k, v)
+                        self.print_set(irb.defout[k][v])
+
+    def compute_reach_block(self, irb):
+        """Variable influence computation for a single block
         @irb: irbloc instance
-        Return True iff bloc state has changed
+        PRE: init_reach()
         """
-        modified = False
-
-        # Compute OUT for last irb entry
-        c_out = set()
-        has_son = False
-        for n_son in self.g.successors(irb.label):
-            has_son = True
-            if n_son not in self.blocs:
-                # If the son is not defined, we will propagate our current out
-                # nodes to the in nodes's son
-                son_c_in = irb.c_out_missing
-            else:
-                son_c_in = self.blocs[n_son].c_in[0]
-            c_out.update(son_c_in)
-        if not has_son:
-            # Special case: leaf nodes architecture dependant
-            c_out = self.get_out_regs(irb)
-
-        if irb.c_out[-1] != c_out:
-            irb.c_out[-1] = c_out
-            modified = True
-
-        # Compute out/in intra bloc
-        for i in reversed(xrange(len(irb.irs))):
-            new_in = set(irb.r[i].union(irb.c_out[i].difference(irb.w[i])))
-            if irb.c_in[i] != new_in:
-                irb.c_in[i] = new_in
-                modified = True
-
-            if i >= len(irb.irs) - 1:
-                # Last out has been previously updated
-                continue
-            new_out = set(irb.c_in[i + 1])
-            if irb.c_out[i] != new_out:
-                irb.c_out[i] = new_out
-                modified = True
 
-        return modified
-
-    def test_in_out_fix(self):
-        """Return True iff a fixed point has been reached during liveness
+        reach_block = {key: value.copy()
+                      for key, value in irb.cur_reach[0].iteritems()}
+
+        # Compute reach from predecessors
+        for n_pred in self.g.predecessors(irb.label):
+            p_block = self.blocs[n_pred]
+
+            # Handle each register definition
+            for c_reg in self.ira_regs_ids():
+                # REACH(n) = U[p in pred] DEFOUT(p) U REACH(p)\KILL(p)
+                pred_through = p_block.defout[-1][c_reg].union(
+                    p_block.cur_reach[-1][c_reg].difference(
+                        p_block.cur_kill[-1][c_reg]))
+                reach_block[c_reg].update(pred_through)
+
+        # If a predecessor has changed
+        if reach_block != irb.cur_reach[0]:
+            irb.cur_reach[0] = reach_block
+            for c_reg in self.ira_regs_ids():
+                if irb.defout[0][c_reg]:
+                    # KILL(n) = DEFOUT(n) ? REACH(n)\DEFOUT(n) : EMPTY
+                    irb.cur_kill[0][c_reg].update(
+                        reach_block[c_reg].difference(irb.defout[0][c_reg]))
+
+        # Compute reach and kill for block's instructions
+        for i in xrange(1, len(irb.irs)):
+            for c_reg in self.ira_regs_ids():
+                # REACH(n) = U[p in pred] DEFOUT(p) U REACH(p)\KILL(p)
+                pred_through = irb.defout[i - 1][c_reg].union(
+                    irb.cur_reach[i - 1][c_reg].difference(
+                        irb.cur_kill[i - 1][c_reg]))
+                irb.cur_reach[i][c_reg].update(pred_through)
+                if irb.defout[i][c_reg]:
+                    # KILL(n) = DEFOUT(n) ? REACH(n)\DEFOUT(n) : EMPTY
+                    irb.cur_kill[i][c_reg].update(
+                        irb.cur_reach[i][c_reg].difference(
+                            irb.defout[i][c_reg]))
+
+    def _test_kill_reach_fix(self):
+        """Return True iff a fixed point has been reached during reach
         analysis"""
 
         fixed = True
         for node in self.g.nodes():
-            if node not in self.blocs:
-                # leaf has lost her son
-                continue
-            irb = self.blocs[node]
-            if irb.c_in != irb.l_in or irb.c_out != irb.l_out:
-                fixed = False
-            irb.l_in = [set(x) for x in irb.c_in]
-            irb.l_out = [set(x) for x in irb.c_out]
+            if node in self.blocs:
+                irb = self.blocs[node]
+                if (irb.cur_reach != irb.prev_reach or
+                    irb.cur_kill != irb.prev_kill):
+                    fixed = False
+                    irb.prev_reach = irb.cur_reach[:]
+                    irb.prev_kill = irb.cur_kill[:]
         return fixed
 
-    def fill_missing_son_c_in(self):
-        """Find nodes with missing sons in graph, and add virtual link to all
-        written variables of all parents.
-        PRE: gen_graph() and get_rw()"""
+    def compute_reach(self):
+        """
+        Compute reach, defout and kill sets until a fixed point is reached.
+
+        Source : Kennedy, K. (1979). A survey of data flow analysis techniques.
+        IBM Thomas J. Watson Research Division, page 43
 
-        for node in self.g.nodes():
-            if node not in self.blocs:
-                continue
-            self.blocs[node].c_out_missing = set()
-            has_all_son = True
-            for node_son in self.g.successors(node):
-                if node_son not in self.blocs:
-                    has_all_son = False
-                    break
-            if has_all_son:
-                continue
-            parents = self.g.reachable_parents(node)
-            for parent in parents:
-                irb = self.blocs[parent]
-                for var_w in irb.w:
-                    self.blocs[node].c_out_missing.update(var_w)
-
-    def compute_dead(self):
-        """Iterate liveness analysis until a fixed point is reached.
         PRE: gen_graph()
         """
-
-        it = 0
         fixed_point = False
         log.debug('iteration...')
         while not fixed_point:
-            log.debug(it)
-            it += 1
-            for n in self.g.nodes():
-                if n not in self.blocs:
-                    # leaf has lost her son
-                    continue
-                irb = self.blocs[n]
-                self.compute_in_out(irb)
-
-            fixed_point = self.test_in_out_fix()
+            for node in self.g.nodes():
+                if node in self.blocs:
+                    self.compute_reach_block(self.blocs[node])
+            fixed_point = self._test_kill_reach_fix()
 
     def dead_simp(self):
-        """This function is used to analyse relation of a * complete function *
-        This mean the blocs under study represent a solid full function graph.
-
-        Ref: CS 5470 Compiler Techniques and Principles (Liveness
-        analysis/Dataflow equations)
-
-        PRE: call to gen_graph
         """
+        This function is used to analyse relation of a * complete function *
+        This means the blocks under study represent a solid full function graph.
 
-        modified = True
-        while modified:
-            log.debug('dead_simp step')
-
-            # Update r/w variables for all irblocs
-            self.get_rw()
-            # Fill c_in for missing sons
-            self.fill_missing_son_c_in()
-
-            # Liveness step
-            self.compute_dead()
-            modified = self.remove_blocs_dead()
+        Source : Kennedy, K. (1979). A survey of data flow analysis techniques.
+        IBM Thomas J. Watson Research Division, page 43
 
+        PRE: gen_graph()
+        """
+        # Update r/w variables for all irblocs
+        self.get_rw(self.ira_regs_ids())
+        # Liveness step
+        self.compute_reach()
+        self.remove_dead_code()
         # Simplify expressions
         self.simplify_blocs()