diff options
Diffstat (limited to 'miasm/analysis/data_flow.py')
| -rw-r--r-- | miasm/analysis/data_flow.py | 1065 |
1 files changed, 790 insertions, 275 deletions
"""Data flow analysis based on miasm intermediate representation"""
from builtins import range
from collections import namedtuple, Counter
from pprint import pprint as pp
from future.utils import viewitems, viewvalues
from miasm.core.utils import encode_hex
from miasm.core.graph import DiGraph
from miasm.ir.ir import AssignBlock, IRBlock
from miasm.expression.expression import ExprLoc, ExprMem, ExprId, ExprInt,\
    ExprAssign, ExprOp, ExprWalk, ExprSlice, \
    is_function_call, ExprVisitorCallbackBottomToTop
from miasm.expression.simplifications import expr_simp
from miasm.core.interval import interval
from miasm.expression.expression_helper import possible_values
from miasm.analysis.ssa import get_phi_sources_parent_block, \
    irblock_has_phi
from miasm.ir.symbexec import get_expr_base_offset


# NOTE(review): this listing was rebuilt from a whitespace-collapsed diff.
# Code the diff left unchanged is not shown (ReachingDefinitions ...
# expr_has_mem(), stack_to_reg() ... DiGraphLivenessSSA), and the classes the
# diff deletes (PropagateThroughExprId, PropagateExprIntThroughExprId,
# PropagateThroughExprMem) are intentionally absent. Block nesting was
# re-derived from statement order -- confirm against the repository.


def get_phi_sources(phi_src, phi_dsts, ids_to_src):
    """
    Resolve the non-phi sources of a Phi expression.

    Return False if the @phi_src has more than one non-phi source
    Return True if the @phi_src has no non-phi source at all
    Else, return its unique source
    @phi_src: Phi expression to analyse
    @phi_dsts: set of phi destinations already being resolved (updated in
               place while recursing through nested phis)
    @ids_to_src: Dictionary linking phi source to its definition
    """
    true_values = set()
    for src in phi_src.args:
        if src in phi_dsts:
            # Source is phi dst => skip
            continue
        true_src = ids_to_src[src]
        if true_src in phi_dsts:
            # Source is phi dst => skip
            continue
        # Check if src is not also a phi
        if true_src.is_op('Phi'):
            phi_dsts.add(src)
            true_src = get_phi_sources(true_src, phi_dsts, ids_to_src)
        if true_src is False:
            return False
        if true_src is True:
            continue
        true_values.add(true_src)
        # Bail out as soon as a second distinct source shows up
        if len(true_values) != 1:
            return False
    if not true_values:
        return True
    return true_values.pop()


class DelDummyPhi(object):
    """
    Del dummy phi: replace a phi whose sources all resolve to one single
    definition by a plain assignment of that definition
    """

    def del_dummy_phi(self, ssa, head):
        """
        Remove dummy phis from @ssa; return True if a block was modified
        @ssa: SSA object; its blocks are read/written through ssa.graph.blocks
        @head: unused here
        """
        ids_to_src = {}
        for block in viewvalues(ssa.graph.blocks):
            for assignblock in block:
                for dst, src in viewitems(assignblock):
                    if not dst.is_id():
                        continue
                    ids_to_src[dst] = src

        modified = False
        # Iterate over a snapshot: blocks are replaced while looping
        for block in list(viewvalues(ssa.graph.blocks)):
            if not irblock_has_phi(block):
                continue
            assignblk = block[0]
            for dst, phi_src in viewitems(assignblk):
                assert phi_src.is_op('Phi')
                true_value = get_phi_sources(phi_src, set([dst]), ids_to_src)
                # False: several sources; True: no source at all. Neither is
                # an expression, so there is nothing to assign to dst.
                if true_value is False or true_value is True:
                    continue
                if expr_has_mem(true_value):
                    continue
                fixed_phis = {}
                for old_dst, old_phi_src in viewitems(assignblk):
                    if old_dst == dst:
                        continue
                    fixed_phis[old_dst] = old_phi_src

                modified = True

                # NOTE: the new block is rebuilt from @block as it was when
                # the loop started; if several phis of one block qualify, only
                # the last rewrite is stored and the others stay as phis.
                assignblks = list(block)
                assignblks[0] = AssignBlock(fixed_phis, assignblk.instr)
                assignblks[1:1] = [AssignBlock({dst: true_value}, assignblk.instr)]
                new_irblock = IRBlock(block.loc_key, assignblks)
                ssa.graph.blocks[block.loc_key] = new_irblock

        return modified


def replace_expr_from_bottom(expr_orig, dct):
    """
    Visit @expr_orig with ExprVisitorCallbackBottomToTop and replace every
    visited expression found in @dct by its associated value
    @expr_orig: Expression to rewrite
    @dct: Dictionary linking expressions to their replacement
    """
    def replace(expr):
        if expr in dct:
            return dct[expr]
        return expr
    visitor = ExprVisitorCallbackBottomToTop(replace)
    return visitor.visit(expr_orig)


def is_mem_sub_part(needle, mem):
    """
    If @needle is a sub part of @mem, return the offset (in bytes) of @needle
    in @mem
    Else, return False

    The offset may be 0: callers must test the result with `is False`.
    @needle: ExprMem
    @mem: ExprMem
    """
    ptr_base_a, ptr_offset_a = get_expr_base_offset(needle.ptr)
    ptr_base_b, ptr_offset_b = get_expr_base_offset(mem.ptr)
    if ptr_base_a != ptr_base_b:
        return False
    # Test if sub part starts inside mem
    if not (ptr_offset_b <= ptr_offset_a < ptr_offset_b + mem.size // 8):
        return False
    # Test if sub part ends before the end of mem
    if not (ptr_offset_a + needle.size // 8 <= ptr_offset_b + mem.size // 8):
        return False
    return ptr_offset_a - ptr_offset_b


class UnionFind(object):
    """
    Implementation of UnionFind structure
    __classes: a list of Set of equivalent elements
    node_to_class: Dictionary linking an element to its equivalence class
    order: Dictionary linking an element to its weight

    The order attribute is used to allow the selection of a representative
    element of an equivalence class (lowest weight wins)
    """

    def __init__(self):
        self.index = 0
        self.__classes = []
        self.node_to_class = {}
        self.order = dict()

    def copy(self):
        """
        Return a copy of the object (the sets of nodes are duplicated, the
        nodes themselves are shared)
        """
        unionfind = UnionFind()
        unionfind.index = self.index
        unionfind.__classes = [set(known_class) for known_class in self.__classes]
        node_to_class = {}
        for class_eq in unionfind.__classes:
            for node in class_eq:
                node_to_class[node] = class_eq
        unionfind.node_to_class = node_to_class
        unionfind.order = dict(self.order)
        return unionfind

    def replace_node(self, old_node, new_node):
        """
        Replace the @old_node by the @new_node in every known element
        (including elements which only contain @old_node as a sub expression)
        """
        replace_dct = {old_node: new_node}

        new_classes = []
        for eq_class in self.get_classes():
            new_class = set()
            for node in eq_class:
                new_class.add(replace_expr_from_bottom(node, replace_dct))
            new_classes.append(new_class)

        node_to_class = {}
        for class_eq in new_classes:
            for node in class_eq:
                node_to_class[node] = class_eq
        self.__classes = new_classes
        self.node_to_class = node_to_class

        # Weights follow the rewritten nodes
        new_order = dict()
        for node, index in self.order.items():
            new_order[replace_expr_from_bottom(node, replace_dct)] = index
        self.order = new_order

    def get_classes(self):
        """
        Return a list of the equivalence classes (each one as a fresh set)
        """
        return [set(class_tmp) for class_tmp in self.__classes]

    def nodes(self):
        """
        Generate every element stored in the equivalence classes
        """
        for known_class in self.__classes:
            for node in known_class:
                yield node

    def __eq__(self, other):
        if self is other:
            return True
        if self.__class__ is not other.__class__:
            return False

        # Classes are compared as a multiset: list order does not matter
        return (
            Counter(frozenset(known_class) for known_class in self.__classes) ==
            Counter(frozenset(known_class) for known_class in other.__classes)
        )

    def __ne__(self, other):
        # required Python 2.7.14
        return not self == other

    def __str__(self):
        components = self.__classes
        out = ['UnionFind<']
        for component in components:
            out.append("\t" + (", ".join([str(node) for node in component])))
        out.append('>')
        return "\n".join(out)

    def add_equivalence(self, node_a, node_b):
        """
        Add the new equivalence @node_a == @node_b
        @node_a is equivalent to @node_b, but @node_b is more representative
        than @node_a

        Raise RuntimeError if both nodes already belong to a class
        """
        if node_b not in self.order:
            self.order[node_b] = self.index
            self.index += 1
        # As node_a is destination, we always replace its index
        self.order[node_a] = self.index
        self.index += 1

        if node_a not in self.node_to_class and node_b not in self.node_to_class:
            new_class = set([node_a, node_b])
            self.node_to_class[node_a] = new_class
            self.node_to_class[node_b] = new_class
            self.__classes.append(new_class)
        elif node_a in self.node_to_class and node_b not in self.node_to_class:
            known_class = self.node_to_class[node_a]
            known_class.add(node_b)
            self.node_to_class[node_b] = known_class
        elif node_a not in self.node_to_class and node_b in self.node_to_class:
            known_class = self.node_to_class[node_b]
            known_class.add(node_a)
            self.node_to_class[node_a] = known_class
        else:
            raise RuntimeError("Two nodes cannot be in two classes")

    def _get_master(self, node):
        """
        Return the lowest weighted element of the class of @node, or None if
        @node is unknown
        """
        if node not in self.node_to_class:
            return None
        known_class = self.node_to_class[node]
        best_node = node
        for candidate in known_class:
            if self.order[candidate] < self.order[best_node]:
                best_node = candidate
        return best_node

    def get_master(self, node):
        """
        Return the representative element of the equivalence class containing
        @node (None if there is none)

        If @node is an unknown memory which is a sub part of a known memory,
        the matching slice of the known memory representative is returned
        @node: ExprMem or ExprId
        """
        if not node.is_mem():
            return self._get_master(node)
        if node in self.node_to_class:
            # Full expr mem is known
            return self._get_master(node)
        # Test if mem is sub part of known node
        for expr in self.node_to_class:
            if not expr.is_mem():
                continue
            ret = is_mem_sub_part(node, expr)
            if ret is False:
                continue
            master = self._get_master(expr)
            # ret is a byte offset, slices are expressed in bits
            master = master[ret * 8 : ret * 8 + node.size]
            return master

        return self._get_master(node)

    def del_element(self, node):
        """
        Remove @node from the equivalence classes (@node must be known)
        """
        assert node in self.node_to_class
        known_class = self.node_to_class[node]
        known_class.discard(node)
        del(self.node_to_class[node])
        del(self.order[node])

    def del_get_new_master(self, node):
        """
        Remove @node from the equivalence classes and return the
        representative element of what remains of its class
        Return None if @node is unknown or if its class is left empty
        @node: Element to delete
        """
        if node not in self.node_to_class:
            return None
        known_class = self.node_to_class[node]
        known_class.discard(node)
        del(self.node_to_class[node])
        del(self.order[node])

        if not known_class:
            return None
        # Lowest weight wins; first one met on tie
        return min(known_class, key=lambda candidate: self.order[candidate])


class ExprToGraph(ExprWalk):
    """
    Transform an Expression into a tree and add link nodes to an existing tree
    """
    def __init__(self, graph):
        super(ExprToGraph, self).__init__(self.link_nodes)
        self.graph = graph

    def link_nodes(self, expr, *args, **kwargs):
        """
        Transform an Expression @expr into a tree and add link nodes to the
        current tree. Expressions already in the graph are left untouched.
        Always return None.
        @expr: Expression
        """
        if expr in self.graph.nodes():
            return None
        self.graph.add_node(expr)
        if expr.is_mem():
            self.graph.add_uniq_edge(expr, expr.ptr)
        elif expr.is_slice():
            self.graph.add_uniq_edge(expr, expr.arg)
        elif expr.is_cond():
            self.graph.add_uniq_edge(expr, expr.cond)
            self.graph.add_uniq_edge(expr, expr.src1)
            self.graph.add_uniq_edge(expr, expr.src2)
        elif expr.is_compose() or expr.is_op():
            for arg in expr.args:
                self.graph.add_uniq_edge(expr, arg)
        return None


class State(object):
    """
    Object representing the state of a program at a given point
    The state is represented using equivalence classes

    Each assignment can create/destroy equivalence classes. Interferences
    between expressions are computed using the `may_interfer` function
    """

    def __init__(self):
        self.equivalence_classes = UnionFind()
        self.undefined = set()

    def copy(self):
        """
        Return a copy of the state
        """
        state = self.__class__()
        state.equivalence_classes = self.equivalence_classes.copy()
        state.undefined = self.undefined.copy()
        return state

    def __eq__(self, other):
        if self is other:
            return True
        if self.__class__ is not other.__class__:
            return False
        # UnionFind has no edges(): rely on its own __eq__, which compares
        # the equivalence classes (and therefore the nodes) as a whole
        return (
            self.equivalence_classes == other.equivalence_classes and
            self.undefined == other.undefined
        )

    def __ne__(self, other):
        # required Python 2.7.14
        return not self == other

    def may_interfer(self, dsts, src):
        """
        Return True if @src may interfere with expressions in @dsts
        @dsts: Set of Expressions
        @src: expression to test
        """
        for sub_src in src.get_r():
            for dst in dsts:
                if dst in sub_src:
                    return True
                if dst.is_mem() and sub_src.is_mem():
                    base1, offset1 = get_expr_base_offset(dst.ptr)
                    base2, offset2 = get_expr_base_offset(sub_src.ptr)
                    if base1 != base2:
                        # Unrelated bases: aliasing cannot be excluded
                        return True
                    assert offset1 + dst.size // 8 - 1 <= int(base1.mask)
                    assert offset2 + sub_src.size // 8 - 1 <= int(base2.mask)
                    interval1 = interval([(offset1, offset1 + dst.size // 8 - 1)])
                    interval2 = interval([(offset2, offset2 + sub_src.size // 8 - 1)])
                    if (interval1 & interval2).empty:
                        continue
                    return True
        return False

    def _get_representative_expr(self, expr):
        """
        Return the representative of @expr, or @expr itself if it has none
        """
        representative = self.equivalence_classes.get_master(expr)
        if representative is None:
            return expr
        return representative

    def get_representative_expr(self, expr):
        """
        Replace each sub expression of @expr by its representative element
        @expr: Expression to analyse
        """
        new_expr = expr.visit(self._get_representative_expr)
        return new_expr

    def propagation_allowed(self, expr):
        """
        Return True if @expr can be propagated
        Don't propagate:
        - Phi nodes
        - call_func_ret / call_func_stack operands
        """
        if (
                expr.is_op('Phi') or
                (expr.is_op() and expr.op.startswith("call_func"))
        ):
            return False
        return True

    def eval_assignblock(self, assignblock):
        """
        Evaluate the @assignblock on the current state and return the
        rewritten AssignBlock; the state is updated in place
        @assignblock: AssignBlock instance
        """

        out = dict(assignblock.items())
        new_out = dict()
        # Replace sub expressions by their equivalence class representative
        for dst, src in out.items():
            if src.is_op('Phi'):
                # Don't replace in phi
                new_src = src
            else:
                new_src = self.get_representative_expr(src)
            if dst.is_mem():
                new_ptr = self.get_representative_expr(dst.ptr)
                new_dst = ExprMem(new_ptr, dst.size)
            else:
                new_dst = dst
            new_dst = expr_simp(new_dst)
            new_src = expr_simp(new_src)
            new_out[new_dst] = new_src

        # For each destination, update (or delete) dependent nodes according
        # to equivalence classes
        classes = self.equivalence_classes

        for dst in new_out:

            replacement = classes.del_get_new_master(dst)
            if replacement is None:
                to_del = set([dst])
                to_replace = {}
            else:
                to_del = set()
                to_replace = {dst: replacement}

            graph = DiGraph()
            # Build an expression graph linking all classes
            has_parents = False
            for node in classes.nodes():
                if dst in node:
                    # Only dependent nodes are interesting here
                    has_parents = True
                    expr_to_graph = ExprToGraph(graph)
                    expr_to_graph.visit(node)

            if not has_parents:
                continue

            todo = graph.leaves()
            done = set()

            while todo:
                node = todo.pop(0)
                if node in done:
                    continue
                # If at least one son is not done, re do later
                if any(son not in done for son in graph.successors(node)):
                    todo.append(node)
                    continue
                done.add(node)

                deleted = False
                # If at least one son cannot be replaced (deleted), our last
                # chance is to have an equivalence
                if any(son in to_del for son in graph.successors(node)):
                    # One son has been deleted!
                    # Try to find a replacement of the whole expression
                    replacement = classes.del_get_new_master(node)
                    if replacement is None:
                        to_del.add(node)
                        deleted = True
                    else:
                        to_replace[node] = replacement

                if not deleted:
                    # Every son is live or has been replaced
                    new_node = node.replace_expr(to_replace)
                    if new_node != node:
                        # Node has been modified, update equivalence classes
                        classes.replace_node(node, new_node)
                        to_replace[node] = new_node

                # In every case, parents have to be (re)examined
                for predecessor in graph.predecessors(node):
                    if predecessor not in todo:
                        todo.append(predecessor)

        new_assignblk = AssignBlock(new_out, assignblock.instr)
        dsts = new_out.keys()

        # Remove interfering known classes
        for node in list(classes.nodes()):
            if self.may_interfer(dsts, node):
                # Interfere with known equivalence class
                self.equivalence_classes.del_element(node)
                self.undefined.add(node)

        # Update equivalence classes
        for dst, src in new_out.items():
            # Delete equivalence class interfering with dst
            to_del = set()
            classes = self.equivalence_classes
            for node in classes.nodes():
                if dst in node:
                    to_del.add(node)
            for node in to_del:
                self.equivalence_classes.del_element(node)
                self.undefined.add(node)

            # Don't create equivalence if self interfere
            if self.may_interfer(dsts, src):
                if dst in self.equivalence_classes.nodes():
                    self.equivalence_classes.del_element(dst)
                    self.undefined.add(dst)
                continue

            if not self.propagation_allowed(src):
                continue

            # Don't create equivalence if dependence on undef
            if dst.is_mem() and self.may_interfer(self.undefined, dst.ptr):
                continue

            self.undefined.discard(dst)
            if dst in self.equivalence_classes.nodes():
                self.equivalence_classes.del_element(dst)
            self.equivalence_classes.add_equivalence(dst, src)

        return new_assignblk

    def merge(self, other):
        """
        Merge the current state with @other and return the result as a new
        State: only equivalences known by both states are kept
        @other: State instance
        """
        classes1 = self.equivalence_classes
        classes2 = other.equivalence_classes

        undefined = set(node for node in self.undefined if node.is_id() or node.is_mem())
        undefined.update(set(node for node in other.undefined if node.is_id() or node.is_mem()))
        # Should we compute interference between srcs and undefined ?
        # Nop => should already interfere in other state
        components1 = classes1.get_classes()
        components2 = classes2.get_classes()

        node_to_component2 = {}
        for component in components2:
            for node in component:
                node_to_component2[node] = component

        out = []
        nodes_ok = set()
        while components1:
            component1 = components1.pop()
            for node in component1:
                if node in undefined:
                    continue
                component2 = node_to_component2.get(node)
                if component2 is None:
                    # Unknown in the other state
                    if node.is_id() or node.is_mem():
                        assert(node not in nodes_ok)
                        undefined.add(node)
                    continue
                if node not in component2:
                    continue
                common = component1.intersection(component2)
                if len(common) == 1:
                    # The node is alone in the intersection: no equivalence
                    # survives.
                    # NOTE(review): nesting rebuilt from a collapsed diff;
                    # only ids/mems are recorded as undefined, like the other
                    # branches of this method -- confirm.
                    if node.is_id() or node.is_mem():
                        assert(node not in nodes_ok)
                        undefined.add(node)
                        component2.discard(common.pop())
                    continue
                if common:
                    nodes_ok.update(common)
                    out.append(common)
                # Nodes left aside are examined again as their own component
                diff = component1.difference(common)
                if diff:
                    components1.append(diff)
                component2.difference_update(common)
                break

        # Discard remaining components2 elements
        for component in components2:
            for node in component:
                if node.is_id() or node.is_mem():
                    assert(node not in nodes_ok)
                    undefined.add(node)

        all_nodes = set()
        for common in out:
            all_nodes.update(common)

        new_order = dict(
            (node, index) for (node, index) in classes1.order.items()
            if node in all_nodes
        )

        unionfind = UnionFind()
        global_max_index = 0
        for common in out:
            # The master keeps the lowest weight known by the current state
            min_index = None
            master = None
            for node in common:
                index = new_order[node]
                global_max_index = max(index, global_max_index)
                if min_index is None or min_index > index:
                    min_index = index
                    master = node
            for node in common:
                if node == master:
                    continue
                unionfind.add_equivalence(node, master)

        # Restore the weights of the current state
        unionfind.index = global_max_index
        unionfind.order = new_order
        state = self.__class__()
        state.equivalence_classes = unionfind
        state.undefined = undefined

        return state


class PropagateExpressions(object):
    """
    Propagate expressions

    The algorithm propagates equivalence classes expressions from the entry
    point. During the analysis, we replace source nodes by their equivalence
    class representative. Equivalence classes can be modified during the
    analysis due to memory aliasing.

    For example:
    B = A+1
    C = A
    A = 6
    D = [B]

    Will result in:
    B = A+1
    C = A
    A = 6
    D = [C+1]
    """

    @staticmethod
    def new_state():
        """
        Return an empty State
        """
        return State()

    def merge_prev_states(self, ircfg, states, loc_key):
        """
        Merge predecessors states of irblock at location @loc_key
        Predecessors without a state yet (None) are ignored
        @ircfg: IRCfg instance
        @states: Dictionary linking locations to state
        @loc_key: location of the current irblock
        """
        prev_states = [
            states[predecessor]
            for predecessor in ircfg.predecessors(loc_key)
            if states[predecessor] is not None
        ]

        if not prev_states:
            return self.new_state()
        if len(prev_states) == 1:
            return prev_states[0].copy()

        # State.merge returns a new State: stored states are left untouched
        state = prev_states.pop()
        for prev_state in prev_states:
            state = state.merge(prev_state)
        return state

    def update_state(self, irblock, state):
        """
        Propagate the @state through the @irblock
        Return the rewritten IRBlock (empty assignblocks are dropped) and a
        boolean set if an assignblock has been modified; @state is updated in
        place
        @irblock: IRBlock instance
        @state: State instance
        """
        new_assignblocks = []
        modified = False

        for assignblock in irblock:
            if not assignblock.items():
                continue

            new_assignblk = state.eval_assignblock(assignblock)
            new_assignblocks.append(new_assignblk)
            if new_assignblk != assignblock:
                modified = True

        new_irblock = IRBlock(irblock.loc_key, new_assignblocks)

        return new_irblock, modified

    def propagate(self, ssa, head, max_expr_depth=None):
        """
        Apply algorithm on the @ssa graph; return True if a block changed
        @ssa: SSA object; its blocks are read/written through ssa.ircfg
        @head: location where the analysis starts
        @max_expr_depth: unused here
        """
        ircfg = ssa.ircfg
        self.loc_db = ircfg.loc_db
        irblocks = ssa.ircfg.blocks
        states = {}
        for loc_key in irblocks:
            states[loc_key] = None

        # First pass: compute the state at the end of each block, until a
        # fixpoint is reached
        todo = set([head])
        while todo:
            loc_key = todo.pop()
            irblock = irblocks.get(loc_key)
            if irblock is None:
                continue

            state_orig = states[loc_key]
            state = self.merge_prev_states(ircfg, states, loc_key)
            # Only the side effect on @state matters here
            self.update_state(irblock, state)
            if (
                    state_orig is not None and
                    state.equivalence_classes == state_orig.equivalence_classes and
                    state.undefined == state_orig.undefined
            ):
                continue

            states[loc_key] = state
            # Propagate to sons
            todo.update(ircfg.successors(loc_key))

        # Second pass: update blocks using the final states
        todo = set(loc_key for loc_key in irblocks)
        modified = False
        while todo:
            loc_key = todo.pop()
            irblock = irblocks.get(loc_key)
            if irblock is None:
                continue

            state = self.merge_prev_states(ircfg, states, loc_key)
            new_irblock, modified_irblock = self.update_state(irblock, state)
            modified |= modified_irblock
            irblocks[new_irblock.loc_key] = new_irblock

        return modified