diff options
Diffstat (limited to 'src/miasm/analysis/data_flow.py')
| -rw-r--r-- | src/miasm/analysis/data_flow.py | 2356 |
1 files changed, 2356 insertions, 0 deletions
class ReachingDefinitions(dict):
    """
    Reaching definitions of every assignblock of an IRCFG.

    The instance is itself the result, a dictionary shaped as
        { (block, index): { lvalue: set((block, index)) } }
    where the key (block, index) designates the program point located just
    *before* assignblock @index of @block, and (block, len(block)) the point
    at the exit of the block.

    Example:
    IR block:
    lbl0:
       0 A = 1
         B = 3
       1 B = 2
       2 A = A + B + 4

    Reach definition of lbl0:
    (lbl0, 0) => {}
    (lbl0, 1) => {A: {(lbl0, 0)}, B: {(lbl0, 0)}}
    (lbl0, 2) => {A: {(lbl0, 0)}, B: {(lbl0, 1)}}
    (lbl0, 3) => {A: {(lbl0, 2)}, B: {(lbl0, 1)}}

    Reference: set 'REACHES' in Kennedy, K. (1979).
    A survey of data flow analysis techniques.
    IBM Thomas J. Watson Research Division, Algorithm MK
    """

    ircfg = None

    def __init__(self, ircfg):
        """
        Run the analysis on @ircfg; the result is available right after
        construction.
        @ircfg: IRCFG instance
        """
        super(ReachingDefinitions, self).__init__()
        self.ircfg = ircfg
        self.compute()

    def get_definitions(self, block_lbl, assignblk_index):
        """
        Return the dict { lvalue: set((def_block_lbl, def_index)) } reaching
        the point (@block_lbl, @assignblk_index), or a new empty dict when
        this point has not been computed yet.
        """
        return self.get((block_lbl, assignblk_index), {})

    def compute(self):
        """
        Fixpoint driver: sweep over every block until a complete sweep
        leaves all the states untouched.
        """
        while True:
            changed = False
            for irblock in viewvalues(self.ircfg.blocks):
                # Every block is processed on each sweep, even once a change
                # has already been observed
                if self.process_block(irblock):
                    changed = True
            if not changed:
                break

    def process_block(self, block):
        """
        Merge the exit states of the known predecessors of @block into its
        entry state, then propagate through its assignblocks.
        Return True if the entry state changed, False otherwise.
        """
        known_blocks = self.ircfg.blocks
        entry_state = {}
        for pred_lbl in self.ircfg.predecessors(block.loc_key):
            if pred_lbl in known_blocks:
                exit_index = len(known_blocks[pred_lbl])
                exit_state = self.get_definitions(pred_lbl, exit_index)
                for lval, definitions in viewitems(exit_state):
                    if lval not in entry_state:
                        entry_state[lval] = set()
                    entry_state[lval].update(definitions)

        entry_point = (block.loc_key, 0)
        if self.get(entry_point) == entry_state:
            # Same input as the previous sweep: nothing to propagate
            return False
        self[entry_point] = entry_state

        for index in range(len(block)):
            self.process_assignblock(block, index)
        # The entry state changed, so the block is reported as modified
        # whatever the per-assignblock results are
        return True

    def process_assignblock(self, block, assignblk_index):
        """
        Apply the definitions made by assignblock @assignblk_index of @block.
        NB: the resulting state is stored at index @assignblk_index + 1.
        Return True if the stored state changed.
        """
        point_before = (block.loc_key, assignblk_index)
        point_after = (block.loc_key, assignblk_index + 1)

        # Shallow copy: the state of the previous point must stay untouched
        state = dict(self.get_definitions(*point_before))
        for lval in block[assignblk_index]:
            # A new definition kills every previous definition of @lval
            state[lval] = set([point_before])

        changed = self.get(point_after) != state
        if changed:
            self[point_after] = state
        return changed
class DeadRemoval(object):
    """
    Dead assignment removal on an IRCFG.

    An assignment is kept when:
    - it is "unkillable" (see is_unkillable_destination),
    - it defines an out register (lifter.get_out_regs) reaching a block
      without successors, or reaches the end of a block with an unknown
      successor,
    - or one of the above depends on it in the def-use graph.
    Every other assignment is removed.

    Instances are callable: calling one on an IRCFG runs do_dead_removal.
    """

    def __init__(self, lifter, expr_to_original_expr=None):
        """
        @lifter: lifter instance (used for IRDst and get_out_regs)
        @expr_to_original_expr: (optional) dictionary linking an expression
        to the original expression it stands for; it is consulted by
        is_tracked_var. The dictionary is stored as is (not copied).
        """
        self.lifter = lifter
        if expr_to_original_expr is None:
            expr_to_original_expr = {}
        self.expr_to_original_expr = expr_to_original_expr

    def add_expr_to_original_expr(self, expr_to_original_expr):
        """Merge @expr_to_original_expr into the current translation table"""
        self.expr_to_original_expr.update(expr_to_original_expr)

    def is_unkillable_destination(self, lval, rval):
        """
        Return True if the assignment @lval = @rval must always be kept:
        memory write, IRDst, "exception_flags" or function call source.
        """
        return bool(
            lval.is_mem() or
            self.lifter.IRDst == lval or
            lval.is_id("exception_flags") or
            is_function_call(rval)
        )

    def get_block_useful_destinations(self, block):
        """
        Return the set of AssignblkNode of @block which are unkillable
        @block: IRBlock instance
        """
        useful = set()
        for index, assignblk in enumerate(block):
            for lval, rval in viewitems(assignblk):
                if self.is_unkillable_destination(lval, rval):
                    useful.add(AssignblkNode(block.loc_key, index, lval))
        return useful

    def is_tracked_var(self, lval, variable):
        """
        Return True if @lval, once translated through expr_to_original_expr,
        is equal to @variable
        """
        new_lval = self.expr_to_original_expr.get(lval, lval)
        return new_lval == variable

    def _find_last_definition(self, block, variable):
        """
        Return (index, dst) for the last assignblock of @block defining
        @variable, or None if @block does not define it
        """
        for index, assignblk in reversed(list(enumerate(block))):
            for dst in assignblk:
                if self.is_tracked_var(dst, variable):
                    return index, dst
        return None

    def find_definitions_from_worklist(self, worklist, ircfg):
        """
        Find variables definition in @worklist by browsing the @ircfg
        backward.

        @worklist: set of (variable, loc_key). WARNING: the set is consumed,
        it is empty when the method returns.
        @ircfg: IRCFG instance
        Return a set of AssignblkNode
        """
        locs_done = set()
        defs = set()

        while worklist:
            elt = worklist.pop()
            if elt in locs_done:
                continue
            locs_done.add(elt)
            variable, loc_key = elt
            block = ircfg.get_block(loc_key)

            if block is None:
                # Consider no sources in incomplete graph
                continue

            definition = self._find_last_definition(block, variable)
            if definition is not None:
                index, dst = definition
                defs.add(AssignblkNode(loc_key, index, dst))
                continue

            # Not defined here: keep looking in the predecessors
            for predecessor in ircfg.predecessors(loc_key):
                worklist.add((variable, predecessor))

        return defs

    def find_out_regs_definitions_from_block(self, block, ircfg):
        """
        Find definitions of out regs starting from @block
        Return a set of AssignblkNode
        """
        worklist = set()
        for reg in self.lifter.get_out_regs(block):
            worklist.add((reg, block.loc_key))
        return self.find_definitions_from_worklist(worklist, ircfg)

    def add_def_for_incomplete_leaf(self, block, ircfg, reaching_defs):
        """
        Return the definitions valid at the end of @block plus the
        definitions of its out regs.

        @block: IRBlock instance
        @ircfg: IRCFG instance
        @reaching_defs: ReachingDefinitions instance
        """
        valid_definitions = reaching_defs.get_definitions(
            block.loc_key,
            len(block)
        )
        worklist = set()
        for lval, definitions in viewitems(valid_definitions):
            if not definitions:
                # No definition reaches the end of the block for @lval
                continue
            new_lval = self.expr_to_original_expr.get(lval, lval)
            worklist.add((new_lval, block.loc_key))
        useful = self.find_definitions_from_worklist(worklist, ircfg)
        useful.update(self.find_out_regs_definitions_from_block(block, ircfg))
        return useful

    def get_useful_assignments(self, ircfg, defuse, reaching_defs):
        """
        Mark useful statements using previous reach analysis and defuse

        This is a generator: it yields AssignblkNode
        (block, assignblk number, lvalue) of useful definitions. The same
        node may be yielded several times; the caller is expected to gather
        them in a set.

        @ircfg: IRCFG instance
        @defuse: DiGraphDefUse instance
        @reaching_defs: ReachingDefinitions instance
        """

        useful = set()

        for block_lbl in ircfg.blocks:
            block = ircfg.get_block(block_lbl)
            if block is None:
                # skip unknown blocks: won't generate dependencies
                continue

            useful.update(self.get_block_useful_destinations(block))

            successors = ircfg.successors(block_lbl)
            keep_all_definitions = any(
                successor not in ircfg.blocks for successor in successors
            )

            if keep_all_definitions:
                # Unknown successor: anything alive at the end of the block
                # may be used
                useful.update(
                    self.add_def_for_incomplete_leaf(
                        block, ircfg, reaching_defs
                    )
                )
            elif len(successors) == 0:
                # Leaf: only out regs are used
                useful.update(
                    self.find_out_regs_definitions_from_block(block, ircfg)
                )

        # Useful nodes dependencies
        for node in useful:
            for parent in defuse.reachable_parents(node):
                yield parent

    def do_dead_removal(self, ircfg):
        """
        Remove useless assignments.

        This function is used to analyse relation of a * complete function *
        This means the blocks under study represent a solid full function
        graph.

        Source : Kennedy, K. (1979). A survey of data flow analysis
        techniques. IBM Thomas J. Watson Research Division, page 43

        Every block of @ircfg is rebuilt and replaced in @ircfg.blocks.

        @ircfg: IRCFG instance
        Return True if at least one assignment has been removed
        """

        modified = False
        reaching_defs = ReachingDefinitions(ircfg)
        defuse = DiGraphDefUse(reaching_defs, deref_mem=True)
        useful = set(
            self.get_useful_assignments(ircfg, defuse, reaching_defs)
        )
        for block in list(viewvalues(ircfg.blocks)):
            irs = []
            for idx, assignblk in enumerate(block):
                new_assignblk = dict(assignblk)
                for lval in assignblk:
                    if AssignblkNode(block.loc_key, idx, lval) not in useful:
                        del new_assignblk[lval]
                        modified = True
                irs.append(AssignBlock(new_assignblk, assignblk.instr))
            ircfg.blocks[block.loc_key] = IRBlock(
                block.loc_db, block.loc_key, irs
            )
        return modified

    def __call__(self, ircfg):
        """Shortcut for do_dead_removal(@ircfg)"""
        return self.do_dead_removal(ircfg)
ircfg.add_uniq_edge(loc_key, son_successor) + ircfg.discard_edge(son_loc_key, son_successor) + del ircfg.blocks[son_loc_key] + ircfg.del_node(son_loc_key) + ircfg.blocks[loc_key] = new_block + + +def _test_jmp_only(ircfg, loc_key, heads): + """ + If irblock at @loc_key sets only IRDst to an ExprLoc, return the + corresponding loc_key target. + Avoid creating predecssors for heads LocKeys + None in other cases. + + @ircfg: IRCFG instance + @loc_key: LocKey instance of the candidate irblock + @heads: LocKey heads of the graph + + """ + + if loc_key not in ircfg.blocks: + return None + irblock = ircfg.blocks[loc_key] + if len(irblock.assignblks) != 1: + return None + items = list(viewitems(dict(irblock.assignblks[0]))) + if len(items) != 1: + return None + if len(ircfg.successors(loc_key)) != 1: + return None + # Don't create predecessors on heads + dst, src = items[0] + assert dst.is_id("IRDst") + if not src.is_loc(): + return None + dst = src.loc_key + if loc_key in heads: + predecessors = set(ircfg.predecessors(dst)) + predecessors.difference_update(set([loc_key])) + if predecessors: + return None + return dst + + +def _relink_block_node(ircfg, loc_key, son_loc_key, replace_dct): + """ + Link loc_key's parents to parents directly to son_loc_key + """ + for parent in set(ircfg.predecessors(loc_key)): + parent_block = ircfg.blocks.get(parent, None) + if parent_block is None: + continue + + new_block = parent_block.modify_exprs( + lambda expr:expr.replace_expr(replace_dct), + lambda expr:expr.replace_expr(replace_dct) + ) + + # Link parent to new dst + ircfg.add_uniq_edge(parent, son_loc_key) + + # Unlink block + ircfg.blocks[new_block.loc_key] = new_block + ircfg.del_node(loc_key) + + +def _remove_to_son(ircfg, loc_key, son_loc_key): + """ + Merge irblocks; The final block has the @son_loc_key loc_key + Update references + + Condition: + - irblock at @loc_key is a pure jump block + - @loc_key is not an entry point (can be removed) + + @irblock: IRCFG instance + 
def _remove_to_parent(ircfg, loc_key, son_loc_key):
    """
    Merge irblocks; The final block has the @loc_key loc_key and the
    assignblks of the irblock at @son_loc_key.
    Update references: ExprLoc(@son_loc_key) is replaced by ExprLoc(@loc_key)
    through _relink_block_node.

    Condition:
    - irblock at @loc_key is a pure jump block
    - @son_loc_key is not an entry point (can be removed)

    Return False (graph untouched) if @loc_key == @son_loc_key, True
    otherwise.

    @ircfg: IRCFG instance
    @loc_key: LocKey instance of the parent irblock
    @son_loc_key: LocKey instance of the son irblock
    """

    # Ircfg loop => don't mess
    if loc_key == son_loc_key:
        return False

    # Unlink block destinations
    ircfg.del_edge(loc_key, son_loc_key)

    # Rebuild the son's content under the parent's loc_key
    old_irblock = ircfg.blocks[son_loc_key]
    new_irblock = IRBlock(ircfg.loc_db, loc_key, old_irblock.assignblks)

    # The son entry temporarily holds the new block; this entry is deleted
    # at the end of the function
    ircfg.blocks[son_loc_key] = new_irblock

    # NOTE(review): presumably registers new_irblock (and its outgoing edges)
    # under @loc_key -- confirm against IRCFG.add_irblock
    ircfg.add_irblock(new_irblock)

    # Every reference to the son location must now target the parent
    replace_dct = {
        ExprLoc(son_loc_key, ircfg.IRDst.size):ExprLoc(loc_key, ircfg.IRDst.size)
    }

    # Redirect the son's parents to @loc_key
    _relink_block_node(ircfg, son_loc_key, loc_key, replace_dct)


    # Drop the son from both the graph and the block table
    ircfg.del_node(son_loc_key)
    del ircfg.blocks[son_loc_key]

    return True
def remove_empty_assignblks(ircfg):
    """
    Drop the assignblks which contain no assignment from the irblocks of
    @ircfg. Only the irblocks which actually hold an empty assignblk are
    rebuilt.

    Return True if at least an irblock has been modified

    @ircfg: IRCFG instance
    """
    modified = False
    # Snapshot the items: ircfg.blocks is updated during the walk
    for loc_key, block in list(viewitems(ircfg.blocks)):
        assignblks = list(block)
        kept = [assignblk for assignblk in assignblks if len(assignblk)]
        if len(kept) == len(assignblks):
            # No empty assignblk here
            continue
        ircfg.blocks[loc_key] = IRBlock(ircfg.loc_db, loc_key, kept)
        modified = True
    return modified
def expr_has_mem(expr):
    """
    Walk @expr with an ExprWalk whose test is "this sub expression is a
    memory access" and return the result of the walk.
    @expr: Expr instance
    """
    return ExprWalk(lambda node: node.is_mem()).visit(expr)
def check_expr_below_stack(lifter, expr):
    """
    Test the position of the memory access @expr relatively to the stack
    pointer, using the simplified offset (@expr.ptr - @lifter.sp):

    - offset is not a constant: return True
    - offset is null or positive (most significant bit clear): return False
    - offset is negative (most significant bit set): return True

    retrieve_stack_accesses keeps only the accesses for which this function
    returns True.

    @lifter: lifter_model_call instance
    @expr: ExprMem instance
    """
    offset = expr_simp(expr.ptr - lifter.sp)
    if not offset.is_int():
        return True
    # A null offset has a clear msb as well, so the sign bit alone decides
    return int(expr_simp(offset.msb())) != 0
def fix_stack_vars(expr, base_to_info):
    """
    Replace a stack memory access by its variable.

    If @expr is a memory access whose pointer is a key of @base_to_info,
    return the matching ExprId (sliced to @expr size when the access is
    narrower than the variable). Otherwise, return @expr itself.

    @expr: Expression instance
    @base_to_info: dictionary linking stack base address to its size/name
    """
    if not (expr.is_mem() and expr.ptr in base_to_info):
        return expr
    var_size, var_name = base_to_info[expr.ptr]
    stack_var = ExprId(var_name, var_size)
    if expr.size != var_size:
        # Partial access: only a narrower read of the variable is expected
        assert expr.size < var_size
        stack_var = stack_var[:expr.size]
    return stack_var
def memlookup_test(expr, bs, is_addr_ro_variable, result):
    """
    Collect @expr in @result if it is a memory access at a constant address
    accepted by @is_addr_ro_variable.

    Return False if @expr has been added to @result, True otherwise.

    @expr: Expression instance
    @bs: binstream instance, forwarded to @is_addr_ro_variable
    @is_addr_ro_variable: callback(bs, addr, size) to test memory candidate
    @result: set receiving the selected expressions
    """
    if not expr.is_mem():
        return True
    if not expr.ptr.is_int():
        return True
    if not is_addr_ro_variable(bs, int(expr.ptr), expr.size):
        return True
    result.add(expr)
    return False
class AssignBlockLivenessInfos(object):
    """
    Liveness record of a single AssignBlock: its gen / kill sets along with
    the var_in / var_out / live sets, which all start empty.
    """

    __slots__ = ["gen", "kill", "var_in", "var_out", "live", "assignblk"]

    def __init__(self, assignblk, gen, kill):
        """
        @assignblk: the described AssignBlock
        @gen: set of expressions read by @assignblk
        @kill: set of expressions written by @assignblk
        """
        self.assignblk = assignblk
        self.gen, self.kill = gen, kill
        self.var_in, self.var_out, self.live = set(), set(), set()

    def __str__(self):
        def describe(prefix, variables):
            # One tab-indented line: prefix then comma separated variables
            return prefix + ", ".join(str(var) for var in variables)

        lines = [
            describe("\tVarIn:", self.var_in),
            describe("\tGen:", self.gen),
            describe("\tKill:", self.kill),
            '\n'.join(
                "\t%s = %s" % (dst, src)
                for (dst, src) in viewitems(self.assignblk)
            ),
            describe("\tVarOut:", self.var_out),
        ]
        return '\n'.join(lines)
class DiGraphLiveness(DiGraph):
    """
    DiGraph representing variable liveness
    """

    def __init__(self, ircfg):
        """
        Mirror @ircfg: one node per location owning an IRBlock, tied to an
        IRBlockLivenessInfos, plus the @ircfg edges touching that node.
        @ircfg: IRCFG instance
        """
        super(DiGraphLiveness, self).__init__()
        self.ircfg = ircfg
        self.loc_db = ircfg.loc_db
        self._blocks = {}
        # Add irblocks gen/kill
        for node in ircfg.nodes():
            irblock = ircfg.blocks.get(node, None)
            if irblock is None:
                continue
            irblockinfos = IRBlockLivenessInfos(irblock)
            self.add_node(irblockinfos.loc_key)
            self.blocks[irblockinfos.loc_key] = irblockinfos
            # Neighbours are linked without checking that they own an
            # IRBlock: an edge endpoint may be missing from self.blocks
            for succ in ircfg.successors(node):
                self.add_uniq_edge(node, succ)
            for pred in ircfg.predecessors(node):
                self.add_uniq_edge(pred, node)

    @property
    def blocks(self):
        """Dictionary linking a loc_key to its IRBlockLivenessInfos"""
        return self._blocks

    def init_var_info(self):
        """Add ircfg out regs"""
        raise NotImplementedError("Abstract method")

    def node2lines(self, node):
        """
        Output liveness information in dot format
        @node: loc_key of the node to describe
        """
        names = self.loc_db.get_location_names(node)
        if not names:
            node_name = self.loc_db.pretty_str(node)
        else:
            node_name = "".join("%s:\n" % name for name in names)
        yield self.DotCellDescription(
            text="%s" % node_name,
            attr={
                'align': 'center',
                'colspan': 2,
                'bgcolor': 'grey',
            }
        )
        if node not in self._blocks:
            yield [self.DotCellDescription(text="NOT PRESENT", attr={})]
            return

        for i, info in enumerate(self._blocks[node].infos):
            var_in = "VarIn:" + ", ".join(str(x) for x in info.var_in)
            var_out = "VarOut:" + ", ".join(str(x) for x in info.var_out)

            assignmnts = [
                "%s = %s" % (dst, src)
                for (dst, src) in viewitems(info.assignblk)
            ]

            if i == 0:
                # Live-in is only printed once, at the top of the block
                yield self.DotCellDescription(
                    text=var_in,
                    attr={
                        'bgcolor': 'green',
                    }
                )

            for assign in assignmnts:
                yield self.DotCellDescription(text=assign, attr={})
            yield self.DotCellDescription(
                text=var_out,
                attr={
                    'bgcolor': 'green',
                }
            )
            yield self.DotCellDescription(text="", attr={})

    def back_propagate_compute(self, block):
        """
        Compute the liveness information in the @block.
        @block: IRBlockLivenessInfos instance
        Return True if a var_in / var_out set of @block changed
        """
        infos = block.infos
        modified = False
        # Walk from the last assignblock to the first one
        for i in reversed(range(len(infos))):
            # in = gen U (out - kill)
            new_vars = set(
                infos[i].gen.union(infos[i].var_out.difference(infos[i].kill))
            )
            if infos[i].var_in != new_vars:
                modified = True
                infos[i].var_in = new_vars
            # Live-in of an assignblock is the live-out of the previous one
            if i > 0 and infos[i - 1].var_out != set(infos[i].var_in):
                modified = True
                infos[i - 1].var_out = set(infos[i].var_in)
        return modified

    def back_propagate_to_parent(self, todo, node, parent):
        """
        Back propagate the liveness information from @node to @parent.
        @todo: set of loc_keys to (re)process; @parent is added if updated
        @node: loc_key of the source node
        @parent: loc_key of the node to update
        """
        parent_block = self.blocks.get(parent, None)
        if parent_block is None:
            # Predecessor linked in __init__ but without IRBlock: there is
            # no liveness record to update
            return
        cur_block = self.blocks[node]
        if cur_block.infos[0].var_in == parent_block.infos[-1].var_out:
            return
        var_info = cur_block.infos[0].var_in.union(
            parent_block.infos[-1].var_out
        )
        parent_block.infos[-1].var_out = var_info
        todo.add(parent)

    def compute_liveness(self):
        """
        Compute the liveness information for the digraph.
        The worklist starts from the leaves; predecessors are revisited
        while a block's sets keep changing. Always return True.
        """
        todo = set(self.leaves())
        while todo:
            node = todo.pop()
            cur_block = self.blocks.get(node, None)
            if cur_block is None:
                continue
            modified = self.back_propagate_compute(cur_block)
            if not modified:
                continue
            # We modified parent in, propagate to parents
            for pred in self.predecessors(node):
                self.back_propagate_to_parent(todo, node, pred)
        return True
def update_phi_with_deleted_edges(ircfg, edges_to_del):
    """
    Update phi which have a source present in @edges_to_del
    @ircfg: IRCFG instance in ssa form
    @edges_to_del: edges to delete, as (source loc_key, destination loc_key)
    Return True if at least one phi source has been dropped
    """

    # Group the sources of the deleted edges by destination
    dead_parents_of = {}
    for edge_src, edge_dst in edges_to_del:
        if edge_dst not in dead_parents_of:
            dead_parents_of[edge_dst] = set()
        dead_parents_of[edge_dst].add(edge_src)

    changed = False
    # Rebuilt blocks are staged here: ircfg.blocks is only written at the end
    staged = dict(ircfg.blocks)
    for loc_key, dead_parents in viewitems(dead_parents_of):
        if loc_key not in ircfg.blocks:
            continue
        irblock = ircfg.blocks[loc_key]
        if not irblock_has_phi(irblock):
            continue
        assignblks = list(irblock)
        new_head = {}
        for dst, value in viewitems(assignblks[0]):
            if not value.is_op('Phi'):
                new_head[dst] = value
                continue
            parents_of = get_phi_sources_parent_block(
                ircfg,
                loc_key,
                value.args
            )
            kept = set(value.args)
            for phi_arg in value.args:
                if parents_of[phi_arg].difference(dead_parents):
                    # Still provided by a parent which is not deleted
                    continue
                kept.discard(phi_arg)
                changed = True
            assert kept
            if len(kept) == 1:
                new_head[dst] = kept.pop()
            else:
                new_head[dst] = ExprOp('Phi', *kept)
        assignblks[0] = AssignBlock(new_head, assignblks[0].instr)
        staged[irblock.loc_key] = IRBlock(irblock.loc_db, loc_key, assignblks)

    for staged_key, staged_block in viewitems(staged):
        ircfg.blocks[staged_key] = staged_block
    return changed
def get_phi_sources(phi_src, phi_dsts, ids_to_src):
    """
    Resolve the non-phi value(s) feeding the phi @phi_src.
    Return False if @phi_src has more than one non-phi source, True if it
    has none (every argument is skipped as a phi destination), else return
    its single source.
    @phi_src: phi expression to analyse
    @phi_dsts: set of known phi destinations; arguments defined by a phi are
    added to it during the walk
    @ids_to_src: Dictionary linking phi source to its definition
    """
    candidates = set()
    for arg in phi_src.args:
        if arg in phi_dsts:
            # Argument is a phi destination => skip
            continue
        definition = ids_to_src[arg]
        if definition in phi_dsts:
            # Definition is a phi destination => skip
            continue
        if definition.is_op('Phi'):
            # Argument is itself defined by a phi: resolve it recursively
            phi_dsts.add(arg)
            definition = get_phi_sources(definition, phi_dsts, ids_to_src)
        if definition is False:
            return False
        if definition is True:
            continue
        candidates.add(definition)
        if len(candidates) != 1:
            # A second distinct source has just been found
            return False
    if not candidates:
        return True
    return candidates.pop()
phi + Find nodes which are in the same equivalence class and replace phi nodes by + the class representative. + """ + + def src_gen_phi_node_srcs(self, equivalence_graph): + for node in equivalence_graph.nodes(): + if not node.is_op("Phi"): + continue + phi_successors = equivalence_graph.successors(node) + for head in phi_successors: + # Walk from head to find if we have a phi merging node + known = set([node]) + todo = set([head]) + done = set() + while todo: + node = todo.pop() + if node in done: + continue + + known.add(node) + is_ok = True + for parent in equivalence_graph.predecessors(node): + if parent not in known: + is_ok = False + break + if not is_ok: + continue + if node.is_op("Phi"): + successors = equivalence_graph.successors(node) + phi_node = successors.pop() + return set([phi_node]), phi_node, head, equivalence_graph + done.add(node) + for successor in equivalence_graph.successors(node): + todo.add(successor) + return None + + def get_equivalence_class(self, node, ids_to_src): + todo = set([node]) + done = set() + defined = set() + equivalence = set() + src_to_dst = {} + equivalence_graph = DiGraph() + while todo: + dst = todo.pop() + if dst in done: + continue + done.add(dst) + equivalence.add(dst) + src = ids_to_src.get(dst) + if src is None: + # Node is not defined + continue + src_to_dst[src] = dst + defined.add(dst) + if src.is_id(): + equivalence_graph.add_uniq_edge(src, dst) + todo.add(src) + elif src.is_op('Phi'): + equivalence_graph.add_uniq_edge(src, dst) + for arg in src.args: + assert arg.is_id() + equivalence_graph.add_uniq_edge(arg, src) + todo.add(arg) + else: + if src.is_mem() or (src.is_op() and src.op.startswith("call")): + if src in equivalence_graph.nodes(): + return None + equivalence_graph.add_uniq_edge(src, dst) + equivalence.add(src) + + if len(equivalence_graph.heads()) == 0: + raise RuntimeError("Inconsistent graph") + elif len(equivalence_graph.heads()) == 1: + # Every nodes in the equivalence graph may be equivalent to 
the root + head = equivalence_graph.heads().pop() + successors = equivalence_graph.successors(head) + if len(successors) == 1: + # If successor is an id + successor = successors.pop() + if successor.is_id(): + nodes = equivalence_graph.nodes() + nodes.discard(head) + nodes.discard(successor) + nodes = [node for node in nodes if node.is_id()] + return nodes, successor, head, equivalence_graph + else: + # Walk from head to find if we have a phi merging node + known = set() + todo = set([head]) + done = set() + while todo: + node = todo.pop() + if node in done: + continue + known.add(node) + is_ok = True + for parent in equivalence_graph.predecessors(node): + if parent not in known: + is_ok = False + break + if not is_ok: + continue + if node.is_op("Phi"): + successors = equivalence_graph.successors(node) + assert len(successors) == 1 + phi_node = successors.pop() + return set([phi_node]), phi_node, head, equivalence_graph + done.add(node) + for successor in equivalence_graph.successors(node): + todo.add(successor) + + return self.src_gen_phi_node_srcs(equivalence_graph) + + def del_dummy_phi(self, ssa, head): + ids_to_src = {} + def_to_loc = {} + for block in viewvalues(ssa.graph.blocks): + for index, assignblock in enumerate(block): + for dst, src in viewitems(assignblock): + if not dst.is_id(): + continue + ids_to_src[dst] = src + def_to_loc[dst] = block.loc_key + + + modified = False + for loc_key in ssa.graph.blocks.keys(): + block = ssa.graph.blocks[loc_key] + if not irblock_has_phi(block): + continue + assignblk = block[0] + for dst, phi_src in viewitems(assignblk): + assert phi_src.is_op('Phi') + result = self.get_equivalence_class(dst, ids_to_src) + if result is None: + continue + defined, node, true_value, equivalence_graph = result + if expr_has_mem(true_value): + # Don't propagate ExprMem + continue + if true_value.is_op() and true_value.op.startswith("call"): + # Don't propagate call + continue + # We have an equivalence of nodes + to_del = set(defined) + 
class UnionFind(object):
    """
    Implementation of UnionFind structure
    __classes: a list of sets of equivalent elements
    node_to_class: dictionary linking an element to its equivalence class
    order: dictionary linking an element to its weight

    The order attribute is used to allow the selection of a representative
    element of an equivalence class: the element with the lowest weight.
    """

    def __init__(self):
        self.index = 0
        self.__classes = []
        self.node_to_class = {}
        self.order = dict()

    def copy(self):
        """
        Return a copy of the object
        """
        unionfind = UnionFind()
        unionfind.index = self.index
        unionfind.__classes = [
            set(known_class) for known_class in self.__classes
        ]
        node_to_class = {}
        for class_eq in unionfind.__classes:
            for node in class_eq:
                node_to_class[node] = class_eq
        unionfind.node_to_class = node_to_class
        unionfind.order = dict(self.order)
        return unionfind

    def replace_node(self, old_node, new_node):
        """
        Replace the @old_node by the @new_node in every known element
        """
        replace_dct = {old_node: new_node}
        new_classes = []
        for eq_class in self.get_classes():
            new_classes.append(
                set(
                    replace_expr_from_bottom(node, replace_dct)
                    for node in eq_class
                )
            )

        node_to_class = {}
        for class_eq in new_classes:
            for node in class_eq:
                node_to_class[node] = class_eq
        self.__classes = new_classes
        self.node_to_class = node_to_class

        # Weights follow the rewritten elements
        new_order = dict()
        for node, index in self.order.items():
            new_order[replace_expr_from_bottom(node, replace_dct)] = index
        self.order = new_order

    def get_classes(self):
        """
        Return a list of the equivalent classes (each set is a copy)
        """
        return [set(known_class) for known_class in self.__classes]

    def nodes(self):
        """Yield every element present in an equivalence class"""
        for known_class in self.__classes:
            for node in known_class:
                yield node

    def __eq__(self, other):
        if self is other:
            return True
        if self.__class__ is not other.__class__:
            return False

        # Classes are compared as a multiset, regardless of their order
        return (
            Counter(frozenset(known_class) for known_class in self.__classes) ==
            Counter(frozenset(known_class) for known_class in other.__classes)
        )

    def __ne__(self, other):
        # required Python 2.7.14
        return not self == other

    def __str__(self):
        out = ['UnionFind<']
        for component in self.__classes:
            out.append("\t" + (", ".join([str(node) for node in component])))
        out.append('>')
        return "\n".join(out)

    def add_equivalence(self, node_a, node_b):
        """
        Add the new equivalence @node_a == @node_b
        @node_a is equivalent to @node_b, but @node_b is more representative
        than @node_a
        Raise RuntimeError if both nodes already belong to a class
        """
        if node_b not in self.order:
            self.order[node_b] = self.index
            self.index += 1
        # As node_a is destination, we always replace its index
        self.order[node_a] = self.index
        self.index += 1

        a_known = node_a in self.node_to_class
        b_known = node_b in self.node_to_class
        if not a_known and not b_known:
            new_class = set([node_a, node_b])
            self.node_to_class[node_a] = new_class
            self.node_to_class[node_b] = new_class
            self.__classes.append(new_class)
        elif a_known and not b_known:
            known_class = self.node_to_class[node_a]
            known_class.add(node_b)
            self.node_to_class[node_b] = known_class
        elif not a_known and b_known:
            known_class = self.node_to_class[node_b]
            known_class.add(node_a)
            self.node_to_class[node_a] = known_class
        else:
            raise RuntimeError("Two nodes cannot be in two classes")

    def _get_master(self, node):
        """
        Return the lowest weight element of the class of @node, or None if
        @node belongs to no class. On equal weights @node itself is kept.
        """
        if node not in self.node_to_class:
            return None
        best_node = node
        for candidate in self.node_to_class[node]:
            if self.order[candidate] < self.order[best_node]:
                best_node = candidate
        return best_node

    def get_master(self, node):
        """
        Return the representative element of the equivalence class containing
        @node
        @node: ExprMem or ExprId
        """
        if not node.is_mem():
            return self._get_master(node)
        if node in self.node_to_class:
            # Full expr mem is known
            return self._get_master(node)
        # Test if mem is sub part of known node
        for expr in self.node_to_class:
            if not expr.is_mem():
                continue
            ret = is_mem_sub_part(node, expr)
            if ret is False:
                continue
            # @ret is a byte offset: slice the matching bits of the master
            master = self._get_master(expr)
            return master[ret * 8 : ret * 8 + node.size]

        return self._get_master(node)

    def _forget(self, node):
        """
        Remove @node (which must be known) from its class and from the
        weights. A class left empty is dropped from the class list, so that
        get_classes / __eq__ / __str__ never see empty leftovers.
        Return the (possibly empty) class @node belonged to
        """
        known_class = self.node_to_class[node]
        known_class.discard(node)
        del self.node_to_class[node]
        del self.order[node]
        if not known_class:
            # Identity based removal: drop this very set only
            self.__classes = [
                class_eq for class_eq in self.__classes
                if class_eq is not known_class
            ]
        return known_class

    def del_element(self, node):
        """
        Remove @node for the equivalence classes
        """
        assert node in self.node_to_class
        self._forget(node)

    def del_get_new_master(self, node):
        """
        Remove @node for the equivalence classes and return its
        representative equivalent element (None if @node is unknown or was
        alone in its class)
        @node: Element to delete
        """
        if node not in self.node_to_class:
            return None
        known_class = self._forget(node)
        if not known_class:
            return None
        return min(known_class, key=self.order.__getitem__)
Interferences + between expression is computed using `may_interfer` function + """ + + def __init__(self): + self.equivalence_classes = UnionFind() + self.undefined = set() + + def __str__(self): + return "{0.equivalence_classes}\n{0.undefined}".format(self) + + def copy(self): + state = self.__class__() + state.equivalence_classes = self.equivalence_classes.copy() + state.undefined = self.undefined.copy() + return state + + def __eq__(self, other): + if self is other: + return True + if self.__class__ is not other.__class__: + return False + return ( + set(self.equivalence_classes.nodes()) == set(other.equivalence_classes.nodes()) and + sorted(self.equivalence_classes.edges()) == sorted(other.equivalence_classes.edges()) and + self.undefined == other.undefined + ) + + def __ne__(self, other): + # required Python 2.7.14 + return not self == other + + def may_interfer(self, dsts, src): + """ + Return True if @src may interfere with expressions in @dsts + @dsts: Set of Expressions + @src: expression to test + """ + + srcs = src.get_r() + for src in srcs: + for dst in dsts: + if dst in src: + return True + if dst.is_mem() and src.is_mem(): + dst_base, dst_offset = get_expr_base_offset(dst.ptr) + src_base, src_offset = get_expr_base_offset(src.ptr) + if dst_base != src_base: + return True + dst_size = dst.size // 8 + src_size = src.size // 8 + # Special case: + # @32[ESP + 0xFFFFFFFE], @32[ESP] + # Both memories alias + if dst_offset + dst_size <= int(dst_base.mask) + 1: + # @32[ESP + 0xFFFFFFFC] => [0xFFFFFFFC, 0xFFFFFFFF] + interval1 = interval([(dst_offset, dst_offset + dst.size // 8 - 1)]) + else: + # @32[ESP + 0xFFFFFFFE] => [0x0, 0x1] U [0xFFFFFFFE, 0xFFFFFFFF] + interval1 = interval([(dst_offset, int(dst_base.mask))]) + interval1 += interval([(0, dst_size - (int(dst_base.mask) + 1 - dst_offset) - 1 )]) + if src_offset + src_size <= int(src_base.mask) + 1: + # @32[ESP + 0xFFFFFFFC] => [0xFFFFFFFC, 0xFFFFFFFF] + interval2 = interval([(src_offset, src_offset + 
src.size // 8 - 1)]) + else: + # @32[ESP + 0xFFFFFFFE] => [0x0, 0x1] U [0xFFFFFFFE, 0xFFFFFFFF] + interval2 = interval([(src_offset, int(src_base.mask))]) + interval2 += interval([(0, src_size - (int(src_base.mask) + 1 - src_offset) - 1)]) + if (interval1 & interval2).empty: + continue + return True + return False + + def _get_representative_expr(self, expr): + representative = self.equivalence_classes.get_master(expr) + if representative is None: + return expr + return representative + + def get_representative_expr(self, expr): + """ + Replace each sub expression of @expr by its representative element + @expr: Expression to analyse + """ + new_expr = expr.visit(self._get_representative_expr) + return new_expr + + def propagation_allowed(self, expr): + """ + Return True if @expr can be propagated + Don't propagate: + - Phi nodes + - call_func_ret / call_func_stack operants + """ + + if ( + expr.is_op('Phi') or + (expr.is_op() and expr.op.startswith("call_func")) + ): + return False + return True + + def eval_assignblock(self, assignblock): + """ + Evaluate the @assignblock on the current state + @assignblock: AssignBlock instance + """ + + out = dict(assignblock.items()) + new_out = dict() + # Replace sub expression by their equivalence class repesentative + for dst, src in out.items(): + if src.is_op('Phi'): + # Don't replace in phi + new_src = src + else: + new_src = self.get_representative_expr(src) + if dst.is_mem(): + new_ptr = self.get_representative_expr(dst.ptr) + new_dst = ExprMem(new_ptr, dst.size) + else: + new_dst = dst + new_dst = expr_simp(new_dst) + new_src = expr_simp(new_src) + new_out[new_dst] = new_src + + # For each destination, update (or delete) dependent's node according to + # equivalence classes + classes = self.equivalence_classes + + for dst in new_out: + + replacement = classes.del_get_new_master(dst) + if replacement is None: + to_del = set([dst]) + to_replace = {} + else: + to_del = set() + to_replace = {dst:replacement} + + graph = 
DiGraph() + # Build en expression graph linking all classes + has_parents = False + for node in classes.nodes(): + if dst in node: + # Only dependent nodes are interesting here + has_parents = True + expr_to_graph = ExprToGraph(graph) + expr_to_graph.visit(node) + + if not has_parents: + continue + + todo = graph.leaves() + done = set() + + while todo: + node = todo.pop(0) + if node in done: + continue + # If at least one son is not done, re do later + if [son for son in graph.successors(node) if son not in done]: + todo.append(node) + continue + done.add(node) + + # If at least one son cannot be replaced (deleted), our last + # chance is to have an equivalence + if any(son in to_del for son in graph.successors(node)): + # One son has been deleted! + # Try to find a replacement of the whole expression + replacement = classes.del_get_new_master(node) + if replacement is None: + to_del.add(node) + for predecessor in graph.predecessors(node): + if predecessor not in todo: + todo.append(predecessor) + continue + else: + to_replace[node] = replacement + # Continue with replacement + + # Everyson is live or has been replaced + new_node = node.replace_expr(to_replace) + + if new_node == node: + # If node is not touched (Ex: leaf node) + for predecessor in graph.predecessors(node): + if predecessor not in todo: + todo.append(predecessor) + continue + + # Node has been modified, update equivalence classes + classes.replace_node(node, new_node) + to_replace[node] = new_node + + for predecessor in graph.predecessors(node): + if predecessor not in todo: + todo.append(predecessor) + + continue + + new_assignblk = AssignBlock(new_out, assignblock.instr) + dsts = new_out.keys() + + # Remove interfering known classes + to_del = set() + for node in list(classes.nodes()): + if self.may_interfer(dsts, node): + # Interfere with known equivalence class + self.equivalence_classes.del_element(node) + if node.is_id() or node.is_mem(): + self.undefined.add(node) + + + # Update equivalence 
classes + for dst, src in new_out.items(): + # Delete equivalence class interfering with dst + to_del = set() + classes = self.equivalence_classes + for node in classes.nodes(): + if dst in node: + to_del.add(node) + for node in to_del: + self.equivalence_classes.del_element(node) + if node.is_id() or node.is_mem(): + self.undefined.add(node) + + # Don't create equivalence if self interfer + if self.may_interfer(dsts, src): + if dst in self.equivalence_classes.nodes(): + self.equivalence_classes.del_element(dst) + if dst.is_id() or dst.is_mem(): + self.undefined.add(dst) + continue + + if not self.propagation_allowed(src): + continue + + self.undefined.discard(dst) + if dst in self.equivalence_classes.nodes(): + self.equivalence_classes.del_element(dst) + self.equivalence_classes.add_equivalence(dst, src) + + return new_assignblk + + + def merge(self, other): + """ + Merge the current state with @other + Merge rules: + - if two nodes are equal in both states => in equivalence class + - if node value is different or non present in another state => undefined + @other: State instance + """ + classes1 = self.equivalence_classes + classes2 = other.equivalence_classes + + undefined = set(node for node in self.undefined if node.is_id() or node.is_mem()) + undefined.update(set(node for node in other.undefined if node.is_id() or node.is_mem())) + # Should we compute interference between srcs and undefined ? 
+ # Nop => should already interfere in other state + components1 = classes1.get_classes() + components2 = classes2.get_classes() + + node_to_component2 = {} + for component in components2: + for node in component: + node_to_component2[node] = component + + # Compute intersection of equivalence classes of states + out = [] + nodes_ok = set() + while components1: + component1 = components1.pop() + for node in component1: + if node in undefined: + continue + component2 = node_to_component2.get(node) + if component2 is None: + if node.is_id() or node.is_mem(): + assert(node not in nodes_ok) + undefined.add(node) + continue + if node not in component2: + continue + # Found two classes containing node + common = component1.intersection(component2) + if len(common) == 1: + # Intersection contains only one node => undefine node + if node.is_id() or node.is_mem(): + assert(node not in nodes_ok) + undefined.add(node) + component2.discard(common.pop()) + continue + if common: + # Intersection contains multiple nodes + # Here, common nodes don't interfere with any undefined + nodes_ok.update(common) + out.append(common) + diff = component1.difference(common) + if diff: + components1.append(diff) + component2.difference_update(common) + break + + # Discard remaining components2 elements + for component in components2: + for node in component: + if node.is_id() or node.is_mem(): + assert(node not in nodes_ok) + undefined.add(node) + + all_nodes = set() + for common in out: + all_nodes.update(common) + + new_order = dict( + (node, index) for (node, index) in classes1.order.items() + if node in all_nodes + ) + + unionfind = UnionFind() + new_classes = [] + global_max_index = 0 + for common in out: + min_index = None + master = None + for node in common: + index = new_order[node] + global_max_index = max(index, global_max_index) + if min_index is None or min_index > index: + min_index = index + master = node + for node in common: + if node == master: + continue + 
unionfind.add_equivalence(node, master) + + unionfind.index = global_max_index + unionfind.order = new_order + state = self.__class__() + state.equivalence_classes = unionfind + state.undefined = undefined + + return state + + +class PropagateExpressions(object): + """ + Propagate expressions + + The algorithm propagates equivalence classes expressions from the entry + point. During the analyse, we replace source nodes by its equivalence + classes representative. Equivalence classes can be modified during analyse + due to memory aliasing. + + For example: + B = A+1 + C = A + A = 6 + D = [B] + + Will result in: + B = A+1 + C = A + A = 6 + D = [C+1] + """ + + @staticmethod + def new_state(): + return State() + + def merge_prev_states(self, ircfg, states, loc_key): + """ + Merge predecessors states of irblock at location @loc_key + @ircfg: IRCfg instance + @states: Dictionary linking locations to state + @loc_key: location of the current irblock + """ + + prev_states = [] + for predecessor in ircfg.predecessors(loc_key): + prev_states.append((predecessor, states[predecessor])) + + filtered_prev_states = [] + for (_, prev_state) in prev_states: + if prev_state is not None: + filtered_prev_states.append(prev_state) + + prev_states = filtered_prev_states + if not prev_states: + state = self.new_state() + elif len(prev_states) == 1: + state = prev_states[0].copy() + else: + while prev_states: + state = prev_states.pop() + if state is not None: + break + for prev_state in prev_states: + state = state.merge(prev_state) + + return state + + def update_state(self, irblock, state): + """ + Propagate the @state through the @irblock + @irblock: IRBlock instance + @state: State instance + """ + new_assignblocks = [] + modified = False + + for assignblock in irblock: + if not assignblock.items(): + continue + new_assignblk = state.eval_assignblock(assignblock) + new_assignblocks.append(new_assignblk) + if new_assignblk != assignblock: + modified = True + + new_irblock = 
IRBlock(irblock.loc_db, irblock.loc_key, new_assignblocks) + + return new_irblock, modified + + def propagate(self, ssa, head, max_expr_depth=None): + """ + Apply algorithm on the @ssa graph + """ + ircfg = ssa.ircfg + self.loc_db = ircfg.loc_db + irblocks = ssa.ircfg.blocks + states = {} + for loc_key, irblock in irblocks.items(): + states[loc_key] = None + + todo = deque([head]) + while todo: + loc_key = todo.popleft() + irblock = irblocks.get(loc_key) + if irblock is None: + continue + + state_orig = states[loc_key] + state = self.merge_prev_states(ircfg, states, loc_key) + state = state.copy() + + new_irblock, modified_irblock = self.update_state(irblock, state) + if state_orig is not None: + # Merge current and previous state + state = state.merge(state_orig) + if (state.equivalence_classes == state_orig.equivalence_classes and + state.undefined == state_orig.undefined + ): + continue + + states[loc_key] = state + # Propagate to sons + for successor in ircfg.successors(loc_key): + todo.append(successor) + + # Update blocks + todo = set(loc_key for loc_key in irblocks) + modified = False + while todo: + loc_key = todo.pop() + irblock = irblocks.get(loc_key) + if irblock is None: + continue + + state = self.merge_prev_states(ircfg, states, loc_key) + new_irblock, modified_irblock = self.update_state(irblock, state) + modified |= modified_irblock + irblocks[new_irblock.loc_key] = new_irblock + + return modified |