diff options
Diffstat (limited to 'miasm/analysis')
| -rw-r--r-- | miasm/analysis/data_flow.py | 1132 | ||||
| -rw-r--r-- | miasm/analysis/depgraph.py | 121 | ||||
| -rw-r--r-- | miasm/analysis/dse.py | 16 | ||||
| -rw-r--r-- | miasm/analysis/gdbserver.py | 4 | ||||
| -rw-r--r-- | miasm/analysis/sandbox.py | 3 | ||||
| -rw-r--r-- | miasm/analysis/simplifier.py | 33 |
6 files changed, 909 insertions, 400 deletions
diff --git a/miasm/analysis/data_flow.py b/miasm/analysis/data_flow.py index 7bd6d72f..7340c023 100644 --- a/miasm/analysis/data_flow.py +++ b/miasm/analysis/data_flow.py @@ -1,19 +1,21 @@ """Data flow analysis based on miasm intermediate representation""" from builtins import range -from collections import namedtuple - +from collections import namedtuple, Counter +from pprint import pprint as pp from future.utils import viewitems, viewvalues from miasm.core.utils import encode_hex from miasm.core.graph import DiGraph from miasm.ir.ir import AssignBlock, IRBlock from miasm.expression.expression import ExprLoc, ExprMem, ExprId, ExprInt,\ - ExprAssign, ExprOp -from miasm.expression.simplifications import expr_simp + ExprAssign, ExprOp, ExprWalk, ExprSlice, \ + is_function_call, ExprVisitorCallbackBottomToTop +from miasm.expression.simplifications import expr_simp, expr_simp_explicit from miasm.core.interval import interval from miasm.expression.expression_helper import possible_values from miasm.analysis.ssa import get_phi_sources_parent_block, \ irblock_has_phi - +from miasm.ir.symbexec import get_expr_base_offset +from collections import deque class ReachingDefinitions(dict): """ @@ -131,7 +133,7 @@ class DiGraphDefUse(DiGraph): def __init__(self, reaching_defs, - deref_mem=False, *args, **kwargs): + deref_mem=False, apply_simp=False, *args, **kwargs): """Instantiate a DiGraph @blocks: IR blocks """ @@ -144,7 +146,8 @@ class DiGraphDefUse(DiGraph): super(DiGraphDefUse, self).__init__(*args, **kwargs) self._compute_def_use(reaching_defs, - deref_mem=deref_mem) + deref_mem=deref_mem, + apply_simp=apply_simp) def edge_attr(self, src, dst): """ @@ -155,18 +158,20 @@ class DiGraphDefUse(DiGraph): return self._edge_attr[(src, dst)] def _compute_def_use(self, reaching_defs, - deref_mem=False): + deref_mem=False, apply_simp=False): for block in viewvalues(self._blocks): self._compute_def_use_block(block, reaching_defs, - deref_mem=deref_mem) + deref_mem=deref_mem, + apply_simp=apply_simp) - def _compute_def_use_block(self, block, reaching_defs, deref_mem=False): + def _compute_def_use_block(self, block, reaching_defs, deref_mem=False, apply_simp=False): for index, assignblk in enumerate(block): assignblk_reaching_defs = reaching_defs.get_definitions(block.loc_key, index) for lval, expr in viewitems(assignblk): self.add_node(AssignblkNode(block.loc_key, index, lval)) + expr = expr_simp_explicit(expr) if apply_simp else expr read_vars = expr.get_r(mem_read=deref_mem) if deref_mem and lval.is_mem(): read_vars.update(lval.ptr.get_r(mem_read=deref_mem)) @@ -223,7 +228,7 @@ class DeadRemoval(object): lval.is_mem() or self.ir_arch.IRDst == lval or lval.is_id("exception_flags") or - rval.is_function_call() + is_function_call(rval) ): return True return False @@ -723,307 +728,16 @@ class SSADefUse(DiGraph): - -def expr_test_visit(expr, test): - result = set() - expr.visit( - lambda expr: expr, - lambda expr: test(expr, result) - ) - if result: - return True - else: - return False - - -def expr_has_mem_test(expr, result): - if result: - # Don't analyse if we already found a candidate - return False - if expr.is_mem(): - result.add(expr) - return False - return True - - def expr_has_mem(expr): """ Return True if expr contains at least one memory access @expr: Expr instance """ - return expr_test_visit(expr, expr_has_mem_test) - - -class PropagateThroughExprId(object): - """ - Propagate expressions though ExprId - """ - - def has_propagation_barrier(self, assignblks): - """ - Return True if propagation cannot cross the @assignblks - @assignblks: list of AssignBlock to check - """ - for assignblk in assignblks: - for dst, src in viewitems(assignblk): - if src.is_function_call(): - return True - if dst.is_mem(): - return True - return False - - def is_mem_written(self, ssa, node_a, node_b): - """ - Return True if memory is written at least once between @node_a and - @node_b - - @node: AssignblkNode representing the start position - @successor: AssignblkNode representing the end position - """ - - block_b = ssa.graph.blocks[node_b.label] - nodes_to_do = self.compute_reachable_nodes_from_a_to_b(ssa.graph, node_a.label, node_b.label) - - if node_a.label == node_b.label: - # src is dst - assert nodes_to_do == set([node_a.label]) - if self.has_propagation_barrier(block_b.assignblks[node_a.index:node_b.index]): - return True - else: - # Check everyone but node_a.label and node_b.label - for loc in nodes_to_do - set([node_a.label, node_b.label]): - if loc not in ssa.graph.blocks: - continue - block = ssa.graph.blocks[loc] - if self.has_propagation_barrier(block.assignblks): - return True - # Check node_a.label partially - block_a = ssa.graph.blocks[node_a.label] - if self.has_propagation_barrier(block_a.assignblks[node_a.index:]): - return True - if nodes_to_do.intersection(ssa.graph.successors(node_b.label)): - # There is a path from node_b.label to node_b.label => Check node_b.label fully - if self.has_propagation_barrier(block_b.assignblks): - return True - else: - # Check node_b.label partially - if self.has_propagation_barrier(block_b.assignblks[:node_b.index]): - return True - return False - - def compute_reachable_nodes_from_a_to_b(self, ssa, loc_a, loc_b): - reachables_a = set(ssa.reachable_sons(loc_a)) - reachables_b = set(ssa.reachable_parents_stop_node(loc_b, loc_a)) - return reachables_a.intersection(reachables_b) - - def propagation_allowed(self, ssa, to_replace, node_a, node_b): - """ - Return True if we can replace @node_a source present in @to_replace into - @node_b - @node_a: AssignblkNode position - @node_b: AssignblkNode position - """ - if not expr_has_mem(to_replace[node_a.var]): - return True - if self.is_mem_written(ssa, node_a, node_b): - return False - return True - - - def get_var_definitions(self, ssa): - """ - Return a dictionary linking variable to its assignment location - @ssa: SSADiGraph instance - """ - ircfg = ssa.graph - def_dct = {} - for node in ircfg.nodes(): - block = ircfg.blocks.get(node, None) - if block is None: - continue - for index, assignblk in enumerate(block): - for dst, src in viewitems(assignblk): - if not dst.is_id(): - continue - if dst in ssa.immutable_ids: - continue - assert dst not in def_dct - def_dct[dst] = node, index - return def_dct - - def get_candidates(self, ssa, head, max_expr_depth): - def_dct = self.get_var_definitions(ssa) - defuse = SSADefUse.from_ssa(ssa) - to_replace = {} - node_to_reg = {} - for node in defuse.nodes(): - if node.var in ssa.immutable_ids: - continue - src = defuse.get_node_target(node) - if max_expr_depth is not None and len(str(src)) > max_expr_depth: - continue - if src.is_function_call(): - continue - if node.var.is_mem(): - continue - if src.is_op('Phi'): - continue - to_replace[node.var] = src - node_to_reg[node] = node.var - return node_to_reg, to_replace, defuse - - def propagate(self, ssa, head, max_expr_depth=None): - """ - Do expression propagation - @ssa: SSADiGraph instance - @head: the head location of the graph - @max_expr_depth: the maximum allowed depth of an expression - """ - node_to_reg, to_replace, defuse = self.get_candidates(ssa, head, max_expr_depth) - modified = False - for node, reg in viewitems(node_to_reg): - for successor in defuse.successors(node): - if not self.propagation_allowed(ssa, to_replace, node, successor): - continue - - node_a = node - node_b = successor - block = ssa.graph.blocks[node_b.label] - - replace = {node_a.var: to_replace[node_a.var]} - # Replace - assignblks = list(block) - assignblk = block[node_b.index] - out = {} - for dst, src in viewitems(assignblk): - if src.is_op('Phi'): - out[dst] = src - continue - - if src.is_mem(): - ptr = src.ptr.replace_expr(replace) - new_src = ExprMem(ptr, src.size) - else: - new_src = src.replace_expr(replace) - - if dst.is_id(): - new_dst = dst - elif dst.is_mem(): - ptr = dst.ptr.replace_expr(replace) - new_dst = ExprMem(ptr, dst.size) - else: - new_dst = dst.replace_expr(replace) - if not (new_dst.is_id() or new_dst.is_mem()): - new_dst = dst - if src != new_src or dst != new_dst: - modified = True - out[new_dst] = new_src - out = AssignBlock(out, assignblk.instr) - assignblks[node_b.index] = out - new_block = IRBlock(block.loc_key, assignblks) - ssa.graph.blocks[block.loc_key] = new_block - - return modified - - - -class PropagateExprIntThroughExprId(PropagateThroughExprId): - """ - Propagate ExprInt though ExprId: classic constant propagation - This is a sub family of PropagateThroughExprId. - It reduces leaves in expressions of a program. - """ - - def get_candidates(self, ssa, head, max_expr_depth): - defuse = SSADefUse.from_ssa(ssa) - - to_replace = {} - node_to_reg = {} - for node in defuse.nodes(): - src = defuse.get_node_target(node) - if not src.is_int(): - continue - if src.is_function_call(): - continue - if node.var.is_mem(): - continue - to_replace[node.var] = src - node_to_reg[node] = node.var - return node_to_reg, to_replace, defuse - - def propagation_allowed(self, ssa, to_replace, node_a, node_b): - """ - Propagating ExprInt is always ok - """ - return True - - -class PropagateThroughExprMem(object): - """ - Propagate through ExprMem in very simple cases: - - if no memory write between source and target - - if source does not contain any memory reference - """ - - def propagate(self, ssa, head, max_expr_depth=None): - ircfg = ssa.graph - todo = set() - modified = False - for block in viewvalues(ircfg.blocks): - for i, assignblk in enumerate(block): - for dst, src in viewitems(assignblk): - if not dst.is_mem(): - continue - if expr_has_mem(src): - continue - todo.add((block.loc_key, i + 1, dst, src)) - ptr = dst.ptr - for size in range(8, dst.size, 8): - todo.add((block.loc_key, i + 1, ExprMem(ptr, size), src[:size])) - - while todo: - loc_key, index, mem_dst, mem_src = todo.pop() - block = ircfg.blocks.get(loc_key, None) - if block is None: - continue - assignblks = list(block) - block_modified = False - for i in range(index, len(block)): - assignblk = block[i] - write_mem = False - assignblk_modified = False - out = dict(assignblk) - out_new = {} - for dst, src in viewitems(out): - if dst.is_mem(): - write_mem = True - ptr = dst.ptr.replace_expr({mem_dst:mem_src}) - dst = ExprMem(ptr, dst.size) - src = src.replace_expr({mem_dst:mem_src}) - out_new[dst] = src - if out != out_new: - assignblk_modified = True - - if assignblk_modified: - assignblks[i] = AssignBlock(out_new, assignblk.instr) - block_modified = True - if write_mem: - break - else: - # If no memory written, we may propagate to sons - # if son has only parent - for successor in ircfg.successors(loc_key): - predecessors = ircfg.predecessors(successor) - if len(predecessors) != 1: - continue - todo.add((successor, 0, mem_dst, mem_src)) - - if block_modified: - modified = True - new_block = IRBlock(block.loc_key, assignblks) - ircfg.blocks[block.loc_key] = new_block - return modified + def has_mem(self): + return self.is_mem() + visitor = ExprWalk(has_mem) + return visitor.visit(expr) def stack_to_reg(expr): @@ -1061,7 +775,11 @@ def visitor_get_stack_accesses(ir_arch_a, expr, stack_vars): def get_stack_accesses(ir_arch_a, expr): result = set() - expr.visit(lambda expr:visitor_get_stack_accesses(ir_arch_a, expr, result)) + def get_stack(expr_to_test): + visitor_get_stack_accesses(ir_arch_a, expr_to_test, result) + return None + visitor = ExprWalk(get_stack) + visitor.visit(expr) return result @@ -1207,11 +925,13 @@ def memlookup_test(expr, bs, is_addr_ro_variable, result): def memlookup_visit(expr, bs, is_addr_ro_variable): result = set() - expr.visit(lambda expr: expr, - lambda expr: memlookup_test(expr, bs, is_addr_ro_variable, result)) + def retrieve_memlookup(expr_to_test): + memlookup_test(expr_to_test, bs, is_addr_ro_variable, result) + return None + visitor = ExprWalk(retrieve_memlookup) + visitor.visit(expr) return result - def get_memlookup(expr, bs, is_addr_ro_variable): return memlookup_visit(expr, bs, is_addr_ro_variable) @@ -1696,3 +1416,795 @@ class DiGraphLivenessSSA(DiGraphLivenessIRA): parent_block.infos[-1].var_out = var_info todo.add(parent) + + +def get_phi_sources(phi_src, phi_dsts, ids_to_src): + """ + Return False if the @phi_src has more than one non-phi source + Else, return its source + @ids_to_src: Dictionary linking phi source to its definition + """ + true_values = set() + for src in phi_src.args: + if src in phi_dsts: + # Source is phi dst => skip + continue + true_src = ids_to_src[src] + if true_src in phi_dsts: + # Source is phi dst => skip + continue + # Check if src is not also a phi + if true_src.is_op('Phi'): + phi_dsts.add(src) + true_src = get_phi_sources(true_src, phi_dsts, ids_to_src) + if true_src is False: + return False + if true_src is True: + continue + true_values.add(true_src) + if len(true_values) != 1: + return False + if not true_values: + return True + if len(true_values) != 1: + return False + true_value = true_values.pop() + return true_value + + +class DelDummyPhi(object): + """ + Del dummy phi + """ + + def del_dummy_phi(self, ssa, head): + ids_to_src = {} + for block in viewvalues(ssa.graph.blocks): + for index, assignblock in enumerate(block): + for dst, src in viewitems(assignblock): + if not dst.is_id(): + continue + ids_to_src[dst] = src + + modified = False + for block in ssa.graph.blocks.values(): + if not irblock_has_phi(block): + continue + assignblk = block[0] + modified_assignblk = False + for dst, phi_src in viewitems(assignblk): + assert phi_src.is_op('Phi') + true_value = get_phi_sources(phi_src, set([dst]), ids_to_src) + if true_value is False: + continue + if expr_has_mem(true_value): + continue + fixed_phis = {} + for old_dst, old_phi_src in viewitems(assignblk): + if old_dst == dst: + continue + fixed_phis[old_dst] = old_phi_src + + modified = True + + assignblks = list(block) + assignblks[0] = AssignBlock(fixed_phis, assignblk.instr) + assignblks[1:1] = [AssignBlock({dst: true_value}, assignblk.instr)] + new_irblock = IRBlock(block.loc_key, assignblks) + ssa.graph.blocks[block.loc_key] = new_irblock + + return modified + + +def replace_expr_from_bottom(expr_orig, dct): + def replace(expr): + if expr in dct: + return dct[expr] + return expr + visitor = ExprVisitorCallbackBottomToTop(lambda expr:replace(expr)) + return visitor.visit(expr_orig) + + +def is_mem_sub_part(needle, mem): + """ + If @needle is a sub part of @mem, return the offset of @needle in @mem + Else, return False + @needle: ExprMem + @mem: ExprMem + """ + ptr_base_a, ptr_offset_a = get_expr_base_offset(needle.ptr) + ptr_base_b, ptr_offset_b = get_expr_base_offset(mem.ptr) + if ptr_base_a != ptr_base_b: + return False + # Test if sub part starts after mem + if not (ptr_offset_b <= ptr_offset_a < ptr_offset_b + mem.size // 8): + return False + # Test if sub part ends before mem + if not (ptr_offset_a + needle.size // 8 <= ptr_offset_b + mem.size // 8): + return False + return ptr_offset_a - ptr_offset_b + +class UnionFind(object): + """ + Implementation of UnionFind structure + __classes: a list of Set of equivalent elements + node_to_class: Dictionary linkink an element to its equivalent class + order: Dictionary link an element to it's weight + + The order attributes is used to allow the selection of a representative + element of an equivalence class + """ + + def __init__(self): + self.index = 0 + self.__classes = [] + self.node_to_class = {} + self.order = dict() + + def copy(self): + """ + Return a copy of the object + """ + unionfind = UnionFind() + unionfind.index = self.index + unionfind.__classes = [set(known_class) for known_class in self.__classes] + node_to_class = {} + for class_eq in unionfind.__classes: + for node in class_eq: + node_to_class[node] = class_eq + unionfind.node_to_class = node_to_class + unionfind.order = dict(self.order) + return unionfind + + def replace_node(self, old_node, new_node): + """ + Replace the @old_node by the @new_node + """ + classes = self.get_classes() + node_to_class = dict(self.node_to_class) + + new_classes = [] + replace_dct = {old_node:new_node} + for eq_class in classes: + new_class = set() + for node in eq_class: + new_class.add(replace_expr_from_bottom(node, replace_dct)) + new_classes.append(new_class) + + node_to_class = {} + for class_eq in new_classes: + for node in class_eq: + node_to_class[node] = class_eq + self.__classes = new_classes + self.node_to_class = node_to_class + new_order = dict() + for node,index in self.order.items(): + new_node = replace_expr_from_bottom(node, replace_dct) + new_order[new_node] = index + self.order = new_order + + def get_classes(self): + """ + Return a list of the equivalent classes + """ + classes = [] + for class_tmp in self.__classes: + classes.append(set(class_tmp)) + return classes + + def nodes(self): + for known_class in self.__classes: + for node in known_class: + yield node + + def __eq__(self, other): + if self is other: + return True + if self.__class__ is not other.__class__: + return False + + return Counter(frozenset(known_class) for known_class in self.__classes) == Counter(frozenset(known_class) for known_class in other.__classes) + + def __ne__(self, other): + # required Python 2.7.14 + return not self == other + + def __str__(self): + components = self.__classes + out = ['UnionFind<'] + for component in components: + out.append("\t" + (", ".join([str(node) for node in component]))) + out.append('>') + return "\n".join(out) + + def add_equivalence(self, node_a, node_b): + """ + Add the new equivalence @node_a == @node_b + @node_a is equivalent to @node_b, but @node_b is more representative + than @node_a + """ + if node_b not in self.order: + self.order[node_b] = self.index + self.index += 1 + # As node_a is destination, we always replace its index + self.order[node_a] = self.index + self.index += 1 + + if node_a not in self.node_to_class and node_b not in self.node_to_class: + new_class = set([node_a, node_b]) + self.node_to_class[node_a] = new_class + self.node_to_class[node_b] = new_class + self.__classes.append(new_class) + elif node_a in self.node_to_class and node_b not in self.node_to_class: + known_class = self.node_to_class[node_a] + known_class.add(node_b) + self.node_to_class[node_b] = known_class + elif node_a not in self.node_to_class and node_b in self.node_to_class: + known_class = self.node_to_class[node_b] + known_class.add(node_a) + self.node_to_class[node_a] = known_class + else: + raise RuntimeError("Two nodes cannot be in two classes") + + def _get_master(self, node): + if node not in self.node_to_class: + return None + known_class = self.node_to_class[node] + best_node = node + for node in known_class: + if self.order[node] < self.order[best_node]: + best_node = node + return best_node + + def get_master(self, node): + """ + Return the representative element of the equivalence class containing + @node + @node: ExprMem or ExprId + """ + if not node.is_mem(): + return self._get_master(node) + if node in self.node_to_class: + # Full expr mem is known + return self._get_master(node) + # Test if mem is sub part of known node + for expr in self.node_to_class: + if not expr.is_mem(): + continue + ret = is_mem_sub_part(node, expr) + if ret is False: + continue + master = self._get_master(expr) + master = master[ret * 8 : ret * 8 + node.size] + return master + + return self._get_master(node) + + + def del_element(self, node): + """ + Remove @node for the equivalence classes + """ + assert node in self.node_to_class + known_class = self.node_to_class[node] + known_class.discard(node) + del(self.node_to_class[node]) + del(self.order[node]) + + def del_get_new_master(self, node): + """ + Remove @node for the equivalence classes and return it's representative + equivalent element + @node: Element to delete + """ + if node not in self.node_to_class: + return None + known_class = self.node_to_class[node] + known_class.discard(node) + del(self.node_to_class[node]) + del(self.order[node]) + + if not known_class: + return None + best_node = list(known_class)[0] + for node in known_class: + if self.order[node] < self.order[best_node]: + best_node = node + return best_node + +class ExprToGraph(ExprWalk): + """ + Transform an Expression into a tree and add link nodes to an existing tree + """ + def __init__(self, graph): + super(ExprToGraph, self).__init__(self.link_nodes) + self.graph = graph + + def link_nodes(self, expr, *args, **kwargs): + """ + Transform an Expression @expr into a tree and add link nodes to the + current tree + @expr: Expression + """ + if expr in self.graph.nodes(): + return None + self.graph.add_node(expr) + if expr.is_mem(): + self.graph.add_uniq_edge(expr, expr.ptr) + elif expr.is_slice(): + self.graph.add_uniq_edge(expr, expr.arg) + elif expr.is_cond(): + self.graph.add_uniq_edge(expr, expr.cond) + self.graph.add_uniq_edge(expr, expr.src1) + self.graph.add_uniq_edge(expr, expr.src2) + elif expr.is_compose(): + for arg in expr.args: + self.graph.add_uniq_edge(expr, arg) + elif expr.is_op(): + for arg in expr.args: + self.graph.add_uniq_edge(expr, arg) + return None + +class State(object): + """ + Object representing the state of a program at a given point + The state is represented using equivalence classes + + Each assignment can create/destroy equivalence classes. Interferences + between expression is computed using `may_interfer` function + """ + + def __init__(self): + self.equivalence_classes = UnionFind() + self.undefined = set() + + def copy(self): + state = self.__class__() + state.equivalence_classes = self.equivalence_classes.copy() + state.undefined = self.undefined.copy() + return state + + def __eq__(self, other): + if self is other: + return True + if self.__class__ is not other.__class__: + return False + return ( + set(self.equivalence_classes.nodes()) == set(other.equivalence_classes.nodes()) and + sorted(self.equivalence_classes.edges()) == sorted(other.equivalence_classes.edges()) and + self.undefined == other.undefined + ) + + def __ne__(self, other): + # required Python 2.7.14 + return not self == other + + def may_interfer(self, dsts, src): + """ + Return True is @src may interfer with expressions in @dsts + @dsts: Set of Expressions + @src: expression to test + """ + + srcs = src.get_r() + for src in srcs: + for dst in dsts: + if dst in src: + return True + if dst.is_mem() and src.is_mem(): + base1, offset1 = get_expr_base_offset(dst.ptr) + base2, offset2 = get_expr_base_offset(src.ptr) + if base1 != base2: + return True + assert offset1 + dst.size // 8 - 1 <= int(base1.mask) + assert offset2 + src.size // 8 - 1 <= int(base2.mask) + interval1 = interval([(offset1, offset1 + dst.size // 8 - 1)]) + interval2 = interval([(offset2, offset2 + src.size // 8 - 1)]) + if (interval1 & interval2).empty: + continue + return True + return False + + def _get_representative_expr(self, expr): + representative = self.equivalence_classes.get_master(expr) + if representative is None: + return expr + return representative + + def get_representative_expr(self, expr): + """ + Replace each sub expression of @expr by its representative element + @expr: Expression to analyse + """ + new_expr = expr.visit(self._get_representative_expr) + return new_expr + + def propagation_allowed(self, expr): + """ + Return True if @expr can be propagated + Don't propagate: + - Phi nodes + - call_func_ret / call_func_stack operants + """ + + if ( + expr.is_op('Phi') or + (expr.is_op() and expr.op.startswith("call_func")) + ): + return False + return True + + def eval_assignblock(self, assignblock): + """ + Evaluate the @assignblock on the current state + @assignblock: AssignBlock instance + """ + + out = dict(assignblock.items()) + new_out = dict() + # Replace sub expression by their equivalence class repesentative + for dst, src in out.items(): + if src.is_op('Phi'): + # Don't replace in phi + new_src = src + else: + new_src = self.get_representative_expr(src) + if dst.is_mem(): + new_ptr = self.get_representative_expr(dst.ptr) + new_dst = ExprMem(new_ptr, dst.size) + else: + new_dst = dst + new_dst = expr_simp(new_dst) + new_src = expr_simp(new_src) + new_out[new_dst] = new_src + + # For each destination, update (or delete) dependent's node according to + # equivalence classes + classes = self.equivalence_classes + + for dst in new_out: + + replacement = classes.del_get_new_master(dst) + if replacement is None: + to_del = set([dst]) + to_replace = {} + else: + to_del = set() + to_replace = {dst:replacement} + + graph = DiGraph() + # Build en expression graph linking all classes + has_parents = False + for node in classes.nodes(): + if dst in node: + # Only dependent nodes are interesting here + has_parents = True + expr_to_graph = ExprToGraph(graph) + expr_to_graph.visit(node) + + if not has_parents: + continue + + todo = graph.leaves() + done = set() + + while todo: + node = todo.pop(0) + if node in done: + continue + # If at least one son is not done, re do later + if [son for son in graph.successors(node) if son not in done]: + todo.append(node) + continue + done.add(node) + + # If at least one son cannot be replaced (deleted), our last + # chance is to have an equivalence + if any(son in to_del for son in graph.successors(node)): + # One son has been deleted! + # Try to find a replacement of the whole expression + replacement = classes.del_get_new_master(node) + if replacement is None: + to_del.add(node) + for predecessor in graph.predecessors(node): + if predecessor not in todo: + todo.append(predecessor) + continue + else: + to_replace[node] = replacement + # Continue with replacement + + # Everyson is live or has been replaced + new_node = node.replace_expr(to_replace) + + if new_node == node: + # If node is not touched (Ex: leaf node) + for predecessor in graph.predecessors(node): + if predecessor not in todo: + todo.append(predecessor) + continue + + # Node has been modified, update equivalence classes + classes.replace_node(node, new_node) + to_replace[node] = new_node + + for predecessor in graph.predecessors(node): + if predecessor not in todo: + todo.append(predecessor) + + continue + + new_assignblk = AssignBlock(new_out, assignblock.instr) + dsts = new_out.keys() + + # Remove interfering known classes + to_del = set() + for node in list(classes.nodes()): + if self.may_interfer(dsts, node): + # Interfer with known equivalence class + self.equivalence_classes.del_element(node) + if node.is_id() or node.is_mem(): + self.undefined.add(node) + + + # Update equivalence classes + for dst, src in new_out.items(): + # Delete equivalence class interfering with dst + to_del = set() + classes = self.equivalence_classes + for node in classes.nodes(): + if dst in node: + to_del.add(node) + for node in to_del: + self.equivalence_classes.del_element(node) + if node.is_id() or node.is_mem(): + self.undefined.add(node) + + # Don't create equivalence if self interfer + if self.may_interfer(dsts, src): + if dst in self.equivalence_classes.nodes(): + self.equivalence_classes.del_element(dst) + if dst.is_id() or dst.is_mem(): + self.undefined.add(dst) + continue + + if not self.propagation_allowed(src): + continue + + ## Dont create equivalence if dependence on undef + if dst.is_mem() and self.may_interfer(self.undefined, dst.ptr): + continue + + self.undefined.discard(dst) + if dst in self.equivalence_classes.nodes(): + self.equivalence_classes.del_element(dst) + self.equivalence_classes.add_equivalence(dst, src) + + return new_assignblk + + + def merge(self, other): + """ + Merge the current state with @other + @other: State instance + """ + classes1 = self.equivalence_classes + classes2 = other.equivalence_classes + + undefined = set(node for node in self.undefined if node.is_id() or node.is_mem()) + undefined.update(set(node for node in other.undefined if node.is_id() or node.is_mem())) + # Should we compute interference between srcs and undefined ? + # Nop => should already interfer in other state + components1 = classes1.get_classes() + components2 = classes2.get_classes() + + node_to_component2 = {} + for component in components2: + for node in component: + node_to_component2[node] = component + + out = [] + nodes_ok = set() + while components1: + component1 = components1.pop() + new_component1 = set() + for node in component1: + if node in undefined: + continue + component2 = node_to_component2.get(node) + if component2 is None: + if node.is_id() or node.is_mem(): + assert(node not in nodes_ok) + undefined.add(node) + continue + if node not in component2: + continue + common = component1.intersection(component2) + if len(common) == 1: + if node.is_id() or node.is_mem(): + assert(node not in nodes_ok) + undefined.add(node) + component2.discard(common.pop()) + continue + if common: + nodes_ok.update(common) + out.append(common) + diff = component1.difference(common) + if diff: + components1.append(diff) + component2.difference_update(common) + break + + # Discard remaining components2 elements + for component in components2: + for node in component: + if node.is_id() or node.is_mem(): + assert(node not in nodes_ok) + undefined.add(node) + + all_nodes = set() + for common in out: + all_nodes.update(common) + + new_order = dict( + (node, index) for (node, index) in classes1.order.items() + if node in all_nodes + ) + + unionfind = UnionFind() + new_classes = [] + global_max_index = 0 + for common in out: + min_index = None + master = None + for node in common: + index = new_order[node] + global_max_index = max(index, global_max_index) + if min_index is None or min_index > index: + min_index = index + master = node + for node in common: + if node == master: + continue + unionfind.add_equivalence(node, master) + + unionfind.index = global_max_index + unionfind.order = new_order + state = self.__class__() + state.equivalence_classes = unionfind + state.undefined = undefined + + return state + + +class PropagateExpressions(object): + """ + Propagate expressions + + The algorithm propagates equivalence classes expressions from the entry + point. During the analyse, we replace source nodes by its equivalence + classes representative. Equivalence classes can be modified during analyse + due to memory aliasing. + + For example: + B = A+1 + C = A + A = 6 + D = [B] + + Will result in: + B = A+1 + C = A + A = 6 + D = [C+1] + """ + + @staticmethod + def new_state(): + return State() + + def merge_prev_states(self, ircfg, states, loc_key): + """ + Merge predecessors states of irblock at location @loc_key + @ircfg: IRCfg instance + @sates: Dictionary linking locations to state + @loc_key: location of the current irblock + """ + + prev_states = [] + for predecessor in ircfg.predecessors(loc_key): + prev_states.append((predecessor, states[predecessor])) + + filtered_prev_states = [] + for (_, prev_state) in prev_states: + if prev_state is not None: + filtered_prev_states.append(prev_state) + + prev_states = filtered_prev_states + if not prev_states: + state = self.new_state() + elif len(prev_states) == 1: + state = prev_states[0].copy() + else: + while prev_states: + state = prev_states.pop() + if state is not None: + break + for prev_state in prev_states: + state = state.merge(prev_state) + + return state + + def update_state(self, irblock, state): + """ + Propagate the @state through the @irblock + @irblock: IRBlock instance + @state: State instance + """ + new_assignblocks = [] + modified = False + + for index, assignblock in enumerate(irblock): + if not assignblock.items(): + continue + new_assignblk = state.eval_assignblock(assignblock) + new_assignblocks.append(new_assignblk) + if new_assignblk != assignblock: + modified = True + + new_irblock = IRBlock(irblock.loc_key, new_assignblocks) + + return new_irblock, modified + + def propagate(self, ssa, head, max_expr_depth=None): + """ + Apply algorithm on the @ssa graph + """ + ircfg = ssa.ircfg + self.loc_db = ircfg.loc_db + irblocks = ssa.ircfg.blocks + states = {} + for loc_key, irblock in irblocks.items(): + states[loc_key] = None + + todo = deque([head]) + while todo: + loc_key = todo.popleft() + irblock = irblocks.get(loc_key) + if irblock is None: + continue + + state_orig = states[loc_key] + state = self.merge_prev_states(ircfg, states, loc_key) + state = state.copy() + + new_irblock, modified_irblock = self.update_state(irblock, state) + if ( + state_orig is not None and + state.equivalence_classes == state_orig.equivalence_classes and + state.undefined == state_orig.undefined + ): + continue + + if state_orig: + state.undefined.update(state_orig.undefined) + states[loc_key] = state + # Propagate to sons + for successor in ircfg.successors(loc_key): + todo.append(successor) + + # Update blocks + todo = set(loc_key for loc_key in irblocks) + modified = False + while todo: + loc_key = todo.pop() + irblock = irblocks.get(loc_key) + if irblock is None: + continue + + state = self.merge_prev_states(ircfg, states, loc_key) + new_irblock, modified_irblock = self.update_state(irblock, state) + modified |= modified_irblock + irblocks[new_irblock.loc_key] = new_irblock + + return modified diff --git a/miasm/analysis/depgraph.py b/miasm/analysis/depgraph.py index 7113dd51..964dcef4 100644 --- a/miasm/analysis/depgraph.py +++ b/miasm/analysis/depgraph.py @@ -4,7 +4,8 @@ from functools import total_ordering from future.utils import viewitems -from miasm.expression.expression import ExprInt, ExprLoc, ExprAssign +from miasm.expression.expression import ExprInt, ExprLoc, ExprAssign, \ + ExprWalk, canonize_to_exprloc from miasm.core.graph import DiGraph from miasm.core.locationdb import LocationDB from miasm.expression.simplifications import expr_simp_explicit @@ -333,10 +334,10 @@ class DependencyResultImplicit(DependencyResult): generated loc_keys """ out = [] - expected = self._ircfg.loc_db.canonize_to_exprloc(expected) + expected = canonize_to_exprloc(self._ircfg.loc_db, expected) expected_is_loc_key = expected.is_loc() for consval in possible_values(expr): - value = self._ircfg.loc_db.canonize_to_exprloc(consval.value) + value = canonize_to_exprloc(self._ircfg.loc_db, consval.value) if expected_is_loc_key and value != expected: continue if not expected_is_loc_key and value.is_loc_key(): @@ -449,6 +450,50 @@ class FollowExpr(object): if not(only_follow) or follow_expr.follow) +class FilterExprSources(ExprWalk): + """ + Walk Expression to find sources to track + @follow_mem: (optional) Track memory syntactically + @follow_call: (optional) Track through "call" + """ + def __init__(self, follow_mem, follow_call): + super(FilterExprSources, self).__init__(lambda x:None) + self.follow_mem = follow_mem + self.follow_call = follow_call + self.nofollow = set() + self.follow = set() + + def visit(self, expr, *args, **kwargs): + if expr in self.cache: + return None + ret = self.visit_inner(expr, *args, **kwargs) + self.cache.add(expr) + return ret + + def visit_inner(self, expr, *args, **kwargs): + if expr.is_id(): + self.follow.add(expr) + elif expr.is_int(): + self.nofollow.add(expr) + elif expr.is_loc(): + self.nofollow.add(expr) + elif expr.is_mem(): + if self.follow_mem: + self.follow.add(expr) + else: + self.nofollow.add(expr) + return None + elif expr.is_function_call(): + if self.follow_call: + self.follow.add(expr) + else: + self.nofollow.add(expr) + return None + + ret = super(FilterExprSources, self).visit(expr, *args, **kwargs) + return ret + + class DependencyGraph(object): """Implementation of a dependency graph @@ -480,10 +525,14 @@ class DependencyGraph(object): self._cb_follow = [] if apply_simp: self._cb_follow.append(self._follow_simp_expr) - self._cb_follow.append(lambda exprs: self._follow_exprs(exprs, - follow_mem, - follow_call)) - self._cb_follow.append(self._follow_no_loc_key) + self._cb_follow.append(lambda exprs: self.do_follow(exprs, follow_mem, follow_call)) + + @staticmethod + def do_follow(exprs, follow_mem, follow_call): + visitor = FilterExprSources(follow_mem, follow_call) + for expr in exprs: + visitor.visit(expr) + return visitor.follow, visitor.nofollow @staticmethod def _follow_simp_expr(exprs): @@ -495,64 +544,6 @@ class DependencyGraph(object): follow.add(expr_simp_explicit(expr)) return follow, set() - @staticmethod - def get_expr(expr, follow, nofollow): - """Update @follow/@nofollow according to insteresting nodes - Returns same expression (non modifier visitor). - - @expr: expression to handle - @follow: set of nodes to follow - @nofollow: set of nodes not to follow - """ - if expr.is_id(): - follow.add(expr) - elif expr.is_int(): - nofollow.add(expr) - elif expr.is_mem(): - follow.add(expr) - return expr - - @staticmethod - def follow_expr(expr, _, nofollow, follow_mem=False, follow_call=False): - """Returns True if we must visit sub expressions. - @expr: expression to browse - @follow: set of nodes to follow - @nofollow: set of nodes not to follow - @follow_mem: force the visit of memory sub expressions - @follow_call: force the visit of call sub expressions - """ - if not follow_mem and expr.is_mem(): - nofollow.add(expr) - return False - if not follow_call and expr.is_function_call(): - nofollow.add(expr) - return False - return True - - @classmethod - def _follow_exprs(cls, exprs, follow_mem=False, follow_call=False): - """Extracts subnodes from exprs and returns followed/non followed - expressions according to @follow_mem/@follow_call - - """ - follow, nofollow = set(), set() - for expr in exprs: - expr.visit(lambda x: cls.get_expr(x, follow, nofollow), - lambda x: cls.follow_expr(x, follow, nofollow, - follow_mem, follow_call)) - return follow, nofollow - - @staticmethod - def _follow_no_loc_key(exprs): - """Do not follow loc_keys""" - follow = set() - for expr in exprs: - if expr.is_int() or expr.is_loc(): - continue - follow.add(expr) - - return follow, set() - def _follow_apply_cb(self, expr): """Apply callback functions to @expr @expr : FollowExpr instance""" diff --git a/miasm/analysis/dse.py b/miasm/analysis/dse.py index 3a0482a3..9cc342c7 100644 --- a/miasm/analysis/dse.py +++ b/miasm/analysis/dse.py @@ -59,7 +59,7 @@ from future.utils import viewitems from miasm.core.utils import encode_hex, force_bytes from miasm.expression.expression import ExprMem, ExprInt, ExprCompose, \ - ExprAssign, ExprId, ExprLoc, LocKey + ExprAssign, ExprId, ExprLoc, LocKey, canonize_to_exprloc from miasm.core.bin_stream import bin_stream_vm from miasm.jitter.emulatedsymbexec import EmulatedSymbExec from miasm.expression.expression_helper import possible_values @@ -258,7 +258,7 @@ class DSEEngine(object): # lambda cannot contain statement def default_func(dse): - fname = b"%s_symb" % libimp.fad2cname[dse.jitter.pc] + fname = b"%s_symb" % force_bytes(libimp.fad2cname[dse.jitter.pc]) raise RuntimeError("Symbolic stub '%s' not found" % fname) for addr, fname in viewitems(libimp.fad2cname): @@ -333,8 +333,8 @@ class DSEEngine(object): self.handle(ExprInt(cur_addr, self.ir_arch.IRDst.size)) # Avoid memory issue in ExpressionSimplifier - if len(self.symb.expr_simp.simplified_exprs) > 100000: - self.symb.expr_simp.simplified_exprs.clear() + if len(self.symb.expr_simp.cache) > 100000: + self.symb.expr_simp.cache.clear() # Get IR blocks if cur_addr in self.addr_to_cacheblocks: @@ -633,19 +633,17 @@ class DSEPathConstraint(DSEEngine): self.cur_solver.add(self.z3_trans.from_expr(cons)) def handle(self, cur_addr): - cur_addr = self.ir_arch.loc_db.canonize_to_exprloc(cur_addr) + cur_addr = canonize_to_exprloc(self.ir_arch.loc_db, cur_addr) symb_pc = self.eval_expr(self.ir_arch.IRDst) possibilities = possible_values(symb_pc) cur_path_constraint = set() # path_constraint for the concrete path if len(possibilities) == 1: dst = next(iter(possibilities)).value - dst = self.ir_arch.loc_db.canonize_to_exprloc(dst) + dst = canonize_to_exprloc(self.ir_arch.loc_db, dst) assert dst == cur_addr else: for possibility in possibilities: - target_addr = self.ir_arch.loc_db.canonize_to_exprloc( - possibility.value - ) + target_addr = canonize_to_exprloc(self.ir_arch.loc_db, possibility.value) path_constraint = set() # Set of ExprAssign for the possible path # Get constraint associated to the possible path diff --git a/miasm/analysis/gdbserver.py b/miasm/analysis/gdbserver.py index ac58cdad..b45e9f35 100644 --- a/miasm/analysis/gdbserver.py +++ b/miasm/analysis/gdbserver.py @@ -251,8 +251,8 @@ class GdbServer(object): else: raise NotImplementedError("Unknown Except") elif isinstance(ret, debugging.DebugBreakpointTerminate): - # Connexion should close, but keep it running as a TRAP - # The connexion will be close on instance destruction + # Connection should close, but keep it running as a TRAP + # The connection will be close on instance destruction print(ret) self.status = b"S05" self.send_queue.append(b"S05") diff --git a/miasm/analysis/sandbox.py b/miasm/analysis/sandbox.py index 3040a1a8..1449d7be 100644 --- a/miasm/analysis/sandbox.py +++ b/miasm/analysis/sandbox.py @@ -213,6 +213,7 @@ class OS_Win(OS): fstream.read(), load_hdr=self.options.load_hdr, name=self.fname, + winobjs=win_api_x86_32.winobjs, **kwargs ) self.name2module[fname_basename] = self.pe @@ -227,6 +228,7 @@ class OS_Win(OS): self.ALL_IMP_DLL, libs, self.modules_path, + winobjs=win_api_x86_32.winobjs, **kwargs ) ) @@ -242,6 +244,7 @@ class OS_Win(OS): self.name2module, libs, self.modules_path, + winobjs=win_api_x86_32.winobjs, **kwargs ) diff --git a/miasm/analysis/simplifier.py b/miasm/analysis/simplifier.py index 8e9005a8..43623476 100644 --- a/miasm/analysis/simplifier.py +++ b/miasm/analysis/simplifier.py @@ -11,8 +11,8 @@ from miasm.expression.simplifications import expr_simp from miasm.ir.ir import AssignBlock, IRBlock from miasm.analysis.data_flow import DeadRemoval, \ merge_blocks, remove_empty_assignblks, \ - PropagateExprIntThroughExprId, PropagateThroughExprId, \ - PropagateThroughExprMem, del_unused_edges + del_unused_edges, \ + PropagateExpressions, DelDummyPhi log = logging.getLogger("simplifier") @@ -129,9 +129,7 @@ class IRCFGSimplifierSSA(IRCFGSimplifierCommon): and apply out-of-ssa. Final passes of IRcfgSimplifier are applied This class apply following pass until reaching a fix point: - - do_propagate_int - - do_propagate_mem - - do_propagate_expr + - do_propagate_expressions - do_dead_simp_ssa """ @@ -143,9 +141,9 @@ class IRCFGSimplifierSSA(IRCFGSimplifierCommon): self.ssa_forbidden_regs = self.get_forbidden_regs() - self.propag_int = PropagateExprIntThroughExprId() - self.propag_expr = PropagateThroughExprId() - self.propag_mem = PropagateThroughExprMem() + self.propag_expressions = PropagateExpressions() + self.del_dummy_phi = DelDummyPhi() + self.deadremoval = DeadRemoval(self.ir_arch, self.all_ssa_vars) def get_forbidden_regs(self): @@ -167,9 +165,8 @@ class IRCFGSimplifierSSA(IRCFGSimplifierCommon): """ self.passes = [ self.simplify_ssa, - self.do_propagate_int, - self.do_propagate_mem, - self.do_propagate_expr, + self.do_propagate_expressions, + self.do_del_dummy_phi, self.do_dead_simp_ssa, self.do_remove_empty_assignblks, self.do_del_unused_edges, @@ -245,13 +242,21 @@ class IRCFGSimplifierSSA(IRCFGSimplifierCommon): modified = self.propag_mem.propagate(ssa, head) return modified - @fix_point - def do_propagate_expr(self, ssa, head): + def do_propagate_expressions(self, ssa, head): """ Expressions propagation through ExprId in the @ssa graph @head: Location instance of the graph head """ - modified = self.propag_expr.propagate(ssa, head) + modified = self.propag_expressions.propagate(ssa, head) + return modified + + @fix_point + def do_del_dummy_phi(self, ssa, head): + """ + Del dummy phi + @head: Location instance of the graph head + """ + modified = self.del_dummy_phi.del_dummy_phi(ssa, head) return modified @fix_point |