Diffstat (limited to 'miasm2/analysis')
-rw-r--r--  miasm2/analysis/__init__.py          |    1 -
-rw-r--r--  miasm2/analysis/binary.py            |  236 -
-rw-r--r--  miasm2/analysis/cst_propag.py        |  185 -
-rw-r--r--  miasm2/analysis/data_analysis.py     |  204 -
-rw-r--r--  miasm2/analysis/data_flow.py         | 1579 -
-rw-r--r--  miasm2/analysis/debugging.py         |  499 -
-rw-r--r--  miasm2/analysis/depgraph.py          |  651 -
-rw-r--r--  miasm2/analysis/disasm_cb.py         |  128 -
-rw-r--r--  miasm2/analysis/dse.py               |  708 -
-rw-r--r--  miasm2/analysis/expression_range.py  |   70 -
-rw-r--r--  miasm2/analysis/gdbserver.py         |  453 -
-rw-r--r--  miasm2/analysis/machine.py           |  265 -
-rw-r--r--  miasm2/analysis/modularintervals.py  |  530 -
-rw-r--r--  miasm2/analysis/outofssa.py          |  413 -
-rw-r--r--  miasm2/analysis/sandbox.py           | 1026 -
-rw-r--r--  miasm2/analysis/simplifier.py        |  303 -
-rw-r--r--  miasm2/analysis/ssa.py               | 1118 -
17 files changed, 0 insertions(+), 8369 deletions(-)
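
For context on what this commit removes: a minimal usage sketch of the deleted Container API, as it behaved before the removal (the input file name is hypothetical):

# Minimal sketch of the pre-removal miasm2 Container API.
# "target.bin" is a hypothetical input file.
from miasm2.analysis.binary import Container

with open("target.bin", "rb") as fdesc:
    cont = Container.from_string(fdesc.read())

# from_string() tries each registered format in order (PE, then ELF)
# and falls back to ContainerUnknown, which wraps the raw bytes.
print(cont.arch)              # guessed architecture, e.g. "x86_32"
print(hex(cont.entry_point))  # detected entry point (0 for raw input)
bin_stream = cont.bin_stream  # bin_stream over the parsed executable
loc_db = cont.loc_db          # LocationDB preloaded with symbols, if any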
diff --git a/miasm2/analysis/__init__.py b/miasm2/analysis/__init__.py deleted file mode 100644 index 5abdd3a3..00000000 --- a/miasm2/analysis/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"High-level tools for binary analysis" diff --git a/miasm2/analysis/binary.py b/miasm2/analysis/binary.py deleted file mode 100644 index ee733d79..00000000 --- a/miasm2/analysis/binary.py +++ /dev/null @@ -1,236 +0,0 @@ -import logging -import warnings - -from miasm2.core.bin_stream import bin_stream_str, bin_stream_elf, bin_stream_pe -from miasm2.jitter.csts import PAGE_READ -from miasm2.core.locationdb import LocationDB - - -log = logging.getLogger("binary") -console_handler = logging.StreamHandler() -console_handler.setFormatter(logging.Formatter("%(levelname)-5s: %(message)s")) -log.addHandler(console_handler) -log.setLevel(logging.ERROR) - - -# Container -## Exceptions -class ContainerSignatureException(Exception): - "The container does not match the current container signature" - - -class ContainerParsingException(Exception): - "Error during container parsing" - - -## Parent class -class Container(object): - """Container abstraction layer - - This class aims to offer a common interface for abstracting container - such as PE or ELF. - """ - - available_container = [] # Available container formats - fallback_container = None # Fallback container format - - @classmethod - def from_string(cls, data, *args, **kwargs): - """Instantiate a container and parse the binary - @data: str containing the binary - """ - log.info('Load binary') - # Try each available format - for container_type in cls.available_container: - try: - return container_type(data, *args, **kwargs) - except ContainerSignatureException: - continue - except ContainerParsingException as error: - log.error(error) - - # Fallback mode - log.warning('Fallback to string input') - return cls.fallback_container(data, *args, **kwargs) - - @classmethod - def register_container(cls, container): - "Add a Container format" - cls.available_container.append(container) - - @classmethod - def register_fallback(cls, container): - "Set the Container fallback format" - cls.fallback_container = container - - @classmethod - def from_stream(cls, stream, *args, **kwargs): - """Instantiate a container and parse the binary - @stream: stream to use as binary - @vm: (optional) VmMngr instance to link with the executable - @addr: (optional) Base address of the parsed binary. 
If set, - force the unknown format - """ - return Container.from_string(stream.read(), *args, **kwargs) - - def parse(self, data, *args, **kwargs): - """Launch parsing of @data - @data: str containing the binary - """ - raise NotImplementedError("Abstract method") - - def __init__(self, data, loc_db=None, **kwargs): - "Alias for 'parse'" - # Init attributes - self._executable = None - self._bin_stream = None - self._entry_point = None - self._arch = None - if loc_db is None: - self._loc_db = LocationDB() - else: - self._loc_db = loc_db - - # Launch parsing - self.parse(data, **kwargs) - - @property - def bin_stream(self): - "Return the BinStream instance corresponding to container content" - return self._bin_stream - - @property - def executable(self): - "Return the abstract instance standing for parsed executable" - return self._executable - - @property - def entry_point(self): - "Return the detected entry_point" - return self._entry_point - - @property - def arch(self): - "Return the guessed architecture" - return self._arch - - @property - def loc_db(self): - "LocationDB instance preloaded with container symbols (if any)" - return self._loc_db - - @property - def symbol_pool(self): - "[DEPRECATED API]" - warnings.warn("Deprecated API: use 'loc_db'") - return self.loc_db - -## Format dependent classes -class ContainerPE(Container): - "Container abstraction for PE" - - def parse(self, data, vm=None, **kwargs): - from miasm2.jitter.loader.pe import vm_load_pe, guess_arch - from elfesteem import pe_init - - # Parse signature - if not data.startswith(b'MZ'): - raise ContainerSignatureException() - - # Build executable instance - try: - if vm is not None: - self._executable = vm_load_pe(vm, data) - else: - self._executable = pe_init.PE(data) - except Exception as error: - raise ContainerParsingException('Cannot read PE: %s' % error) - - # Check instance validity - if not self._executable.isPE() or \ - self._executable.NTsig.signature_value != 0x4550: - raise ContainerSignatureException() - - # Guess the architecture - self._arch = guess_arch(self._executable) - - # Build the bin_stream instance and set the entry point - try: - self._bin_stream = bin_stream_pe(self._executable) - ep_detected = self._executable.Opthdr.AddressOfEntryPoint - self._entry_point = self._executable.rva2virt(ep_detected) - except Exception as error: - raise ContainerParsingException('Cannot read PE: %s' % error) - - -class ContainerELF(Container): - "Container abstraction for ELF" - - def parse(self, data, vm=None, addr=0, apply_reloc=False, **kwargs): - """Load an ELF from @data - @data: bytes containing the ELF bytes - @vm (optional): VmMngr instance. 
If set, load the ELF in virtual memory - @addr (optional): base address the ELF in virtual memory - @apply_reloc (optional): if set, apply relocation during ELF loading - - @addr and @apply_reloc are only meaningful in the context of a - non-empty @vm - """ - from miasm2.jitter.loader.elf import vm_load_elf, guess_arch, \ - fill_loc_db_with_symbols - from elfesteem import elf_init - - # Parse signature - if not data.startswith(b'\x7fELF'): - raise ContainerSignatureException() - - # Build executable instance - try: - if vm is not None: - self._executable = vm_load_elf( - vm, - data, - loc_db=self.loc_db, - base_addr=addr, - apply_reloc=apply_reloc - ) - else: - self._executable = elf_init.ELF(data) - except Exception as error: - raise ContainerParsingException('Cannot read ELF: %s' % error) - - # Guess the architecture - self._arch = guess_arch(self._executable) - - # Build the bin_stream instance and set the entry point - try: - self._bin_stream = bin_stream_elf(self._executable) - self._entry_point = self._executable.Ehdr.entry + addr - except Exception as error: - raise ContainerParsingException('Cannot read ELF: %s' % error) - - if vm is None: - # Add known symbols (vm_load_elf already does it) - fill_loc_db_with_symbols(self._executable, self.loc_db, addr) - - - -class ContainerUnknown(Container): - "Container abstraction for unknown format" - - def parse(self, data, vm=None, addr=0, **kwargs): - self._bin_stream = bin_stream_str(data, base_address=addr) - if vm is not None: - vm.add_memory_page( - addr, - PAGE_READ, - data - ) - self._executable = None - self._entry_point = 0 - - -## Register containers -Container.register_container(ContainerPE) -Container.register_container(ContainerELF) -Container.register_fallback(ContainerUnknown) diff --git a/miasm2/analysis/cst_propag.py b/miasm2/analysis/cst_propag.py deleted file mode 100644 index 25d66318..00000000 --- a/miasm2/analysis/cst_propag.py +++ /dev/null @@ -1,185 +0,0 @@ -import logging - -from future.utils import viewitems - -from miasm2.ir.symbexec import SymbolicExecutionEngine -from miasm2.expression.expression import ExprMem -from miasm2.expression.expression_helper import possible_values -from miasm2.expression.simplifications import expr_simp -from miasm2.ir.ir import IRBlock, AssignBlock - -LOG_CST_PROPAG = logging.getLogger("cst_propag") -CONSOLE_HANDLER = logging.StreamHandler() -CONSOLE_HANDLER.setFormatter(logging.Formatter("%(levelname)-5s: %(message)s")) -LOG_CST_PROPAG.addHandler(CONSOLE_HANDLER) -LOG_CST_PROPAG.setLevel(logging.WARNING) - - -class SymbExecState(SymbolicExecutionEngine): - """ - State manager for SymbolicExecution - """ - def __init__(self, ir_arch, ircfg, state): - super(SymbExecState, self).__init__(ir_arch, {}) - self.set_state(state) - - -def add_state(ircfg, todo, states, addr, state): - """ - Add or merge the computed @state for the block at @addr. Update @todo - @todo: modified block set - @states: dictionary linking a label to its entering state. 
- @addr: address of the considered block - @state: computed state - """ - addr = ircfg.get_loc_key(addr) - todo.add(addr) - if addr not in states: - states[addr] = state - else: - states[addr] = states[addr].merge(state) - - -def is_expr_cst(ir_arch, expr): - """Return true if @expr is only composed of ExprInt and init_regs - @ir_arch: IR instance - @expr: Expression to test""" - - elements = expr.get_r(mem_read=True) - for element in elements: - if element.is_mem(): - continue - if element.is_id() and element in ir_arch.arch.regs.all_regs_ids_init: - continue - if element.is_int(): - continue - return False - # Expr is a constant - return True - - -class SymbExecStateFix(SymbolicExecutionEngine): - """ - Emul blocks and replace expressions with their corresponding constant if - any. - - """ - # Function used to test if an Expression is considered as a constant - is_expr_cst = lambda _, ir_arch, expr: is_expr_cst(ir_arch, expr) - - def __init__(self, ir_arch, ircfg, state, cst_propag_link): - self.ircfg = ircfg - super(SymbExecStateFix, self).__init__(ir_arch, {}) - self.set_state(state) - self.cst_propag_link = cst_propag_link - - def propag_expr_cst(self, expr): - """Propagate constant expressions in @expr - @expr: Expression to update""" - elements = expr.get_r(mem_read=True) - to_propag = {} - for element in elements: - # Only ExprId can be safely propagated - if not element.is_id(): - continue - value = self.eval_expr(element) - if self.is_expr_cst(self.ir_arch, value): - to_propag[element] = value - return expr_simp(expr.replace_expr(to_propag)) - - def eval_updt_irblock(self, irb, step=False): - """ - Symbolic execution of the @irb on the current state - @irb: IRBlock instance - @step: display intermediate steps - """ - assignblks = [] - for index, assignblk in enumerate(irb): - new_assignblk = {} - links = {} - for dst, src in viewitems(assignblk): - src = self.propag_expr_cst(src) - if dst.is_mem(): - ptr = dst.ptr - ptr = self.propag_expr_cst(ptr) - dst = ExprMem(ptr, dst.size) - new_assignblk[dst] = src - - if assignblk.instr is not None: - for arg in assignblk.instr.args: - new_arg = self.propag_expr_cst(arg) - links[new_arg] = arg - self.cst_propag_link[(irb.loc_key, index)] = links - - self.eval_updt_assignblk(assignblk) - assignblks.append(AssignBlock(new_assignblk, assignblk.instr)) - self.ircfg.blocks[irb.loc_key] = IRBlock(irb.loc_key, assignblks) - - -def compute_cst_propagation_states(ir_arch, ircfg, init_addr, init_infos): - """ - Propagate "constant expressions" in a function. - The attribute "constant expression" is true if the expression is based on - constants or "init" regs values. 
- - @ir_arch: IntermediateRepresentation instance - @init_addr: analysis start address - @init_infos: dictionary linking expressions to their values at @init_addr - """ - - done = set() - state = SymbExecState.StateEngine(init_infos) - lbl = ircfg.get_loc_key(init_addr) - todo = set([lbl]) - states = {lbl: state} - - while todo: - if not todo: - break - lbl = todo.pop() - state = states[lbl] - if (lbl, state) in done: - continue - done.add((lbl, state)) - if lbl not in ircfg.blocks: - continue - - symbexec_engine = SymbExecState(ir_arch, ircfg, state) - addr = symbexec_engine.run_block_at(ircfg, lbl) - symbexec_engine.del_mem_above_stack(ir_arch.sp) - - for dst in possible_values(addr): - value = dst.value - if value.is_mem(): - LOG_CST_PROPAG.warning('Bad destination: %s', value) - continue - elif value.is_int(): - value = ircfg.get_loc_key(value) - add_state( - ircfg, todo, states, value, - symbexec_engine.get_state() - ) - - return states - - -def propagate_cst_expr(ir_arch, ircfg, addr, init_infos): - """ - Propagate "constant expressions" in a @ir_arch. - The attribute "constant expression" is true if the expression is based on - constants or "init" regs values. - - @ir_arch: IntermediateRepresentation instance - @addr: analysis start address - @init_infos: dictionary linking expressions to their values at @init_addr - - Returns a mapping between replaced Expression and their new values. - """ - states = compute_cst_propagation_states(ir_arch, ircfg, addr, init_infos) - cst_propag_link = {} - for lbl, state in viewitems(states): - if lbl not in ircfg.blocks: - continue - symbexec = SymbExecStateFix(ir_arch, ircfg, state, cst_propag_link) - symbexec.eval_updt_irblock(ircfg.blocks[lbl]) - return cst_propag_link diff --git a/miasm2/analysis/data_analysis.py b/miasm2/analysis/data_analysis.py deleted file mode 100644 index bd073fcb..00000000 --- a/miasm2/analysis/data_analysis.py +++ /dev/null @@ -1,204 +0,0 @@ -from __future__ import print_function - -from future.utils import viewitems - -from builtins import object -from functools import cmp_to_key -from miasm2.expression.expression \ - import get_expr_mem, get_list_rw, ExprId, ExprInt, \ - compare_exprs -from miasm2.ir.symbexec import SymbolicExecutionEngine - - -def get_node_name(label, i, n): - n_name = (label, i, n) - return n_name - - -def intra_block_flow_raw(ir_arch, ircfg, flow_graph, irb, in_nodes, out_nodes): - """ - Create data flow for an irbloc using raw IR expressions - """ - current_nodes = {} - for i, assignblk in enumerate(irb): - dict_rw = assignblk.get_rw(cst_read=True) - current_nodes.update(out_nodes) - - # gen mem arg to mem node links - all_mems = set() - for node_w, nodes_r in viewitems(dict_rw): - for n in nodes_r.union([node_w]): - all_mems.update(get_expr_mem(n)) - if not all_mems: - continue - - for n in all_mems: - node_n_w = get_node_name(irb.loc_key, i, n) - if not n in nodes_r: - continue - o_r = n.ptr.get_r(mem_read=False, cst_read=True) - for n_r in o_r: - if n_r in current_nodes: - node_n_r = current_nodes[n_r] - else: - node_n_r = get_node_name(irb.loc_key, i, n_r) - current_nodes[n_r] = node_n_r - in_nodes[n_r] = node_n_r - flow_graph.add_uniq_edge(node_n_r, node_n_w) - - # gen data flow links - for node_w, nodes_r in viewitems(dict_rw): - for n_r in nodes_r: - if n_r in current_nodes: - node_n_r = current_nodes[n_r] - else: - node_n_r = get_node_name(irb.loc_key, i, n_r) - current_nodes[n_r] = node_n_r - in_nodes[n_r] = node_n_r - - flow_graph.add_node(node_n_r) - - node_n_w = 
get_node_name(irb.loc_key, i + 1, node_w) - out_nodes[node_w] = node_n_w - - flow_graph.add_node(node_n_w) - flow_graph.add_uniq_edge(node_n_r, node_n_w) - - - -def inter_block_flow_link(ir_arch, ircfg, flow_graph, irb_in_nodes, irb_out_nodes, todo, link_exec_to_data): - lbl, current_nodes, exec_nodes = todo - current_nodes = dict(current_nodes) - - # link current nodes to bloc in_nodes - if not lbl in ircfg.blocks: - print("cannot find bloc!!", lbl) - return set() - irb = ircfg.blocks[lbl] - to_del = set() - for n_r, node_n_r in viewitems(irb_in_nodes[irb.loc_key]): - if not n_r in current_nodes: - continue - flow_graph.add_uniq_edge(current_nodes[n_r], node_n_r) - to_del.add(n_r) - - # if link exec to data, all nodes depends on exec nodes - if link_exec_to_data: - for n_x_r in exec_nodes: - for n_r, node_n_r in viewitems(irb_in_nodes[irb.loc_key]): - if not n_x_r in current_nodes: - continue - if isinstance(n_r, ExprInt): - continue - flow_graph.add_uniq_edge(current_nodes[n_x_r], node_n_r) - - # update current nodes using bloc out_nodes - for n_w, node_n_w in viewitems(irb_out_nodes[irb.loc_key]): - current_nodes[n_w] = node_n_w - - # get nodes involved in exec flow - x_nodes = tuple(sorted(irb.dst.get_r(), key=cmp_to_key(compare_exprs))) - - todo = set() - for lbl_dst in ircfg.successors(irb.loc_key): - todo.add((lbl_dst, tuple(viewitems(current_nodes)), x_nodes)) - - return todo - - -def create_implicit_flow(ir_arch, flow_graph, irb_in_nodes, irb_out_ndes): - - # first fix IN/OUT - # If a son read a node which in not in OUT, add it - todo = set(ir_arch.blocks.keys()) - while todo: - lbl = todo.pop() - irb = ir_arch.blocks[lbl] - for lbl_son in ir_arch.graph.successors(irb.loc_key): - if not lbl_son in ir_arch.blocks: - print("cannot find bloc!!", lbl) - continue - irb_son = ir_arch.blocks[lbl_son] - for n_r in irb_in_nodes[irb_son.loc_key]: - if n_r in irb_out_nodes[irb.loc_key]: - continue - if not isinstance(n_r, ExprId): - continue - - node_n_w = irb.loc_key, len(irb), n_r - irb_out_nodes[irb.loc_key][n_r] = node_n_w - if not n_r in irb_in_nodes[irb.loc_key]: - irb_in_nodes[irb.loc_key][n_r] = irb.loc_key, 0, n_r - node_n_r = irb_in_nodes[irb.loc_key][n_r] - for lbl_p in ir_arch.graph.predecessors(irb.loc_key): - todo.add(lbl_p) - - flow_graph.add_uniq_edge(node_n_r, node_n_w) - - -def inter_block_flow(ir_arch, ircfg, flow_graph, irb_0, irb_in_nodes, irb_out_nodes, link_exec_to_data=True): - - todo = set() - done = set() - todo.add((irb_0, (), ())) - - while todo: - state = todo.pop() - if state in done: - continue - done.add(state) - out = inter_block_flow_link(ir_arch, ircfg, flow_graph, irb_in_nodes, irb_out_nodes, state, link_exec_to_data) - todo.update(out) - - -class symb_exec_func(object): - - """ - This algorithm will do symbolic execution on a function, trying to propagate - states between basic blocks in order to extract inter-blocs dataflow. The - algorithm tries to merge states from blocks with multiple parents. - - There is no real magic here, loops and complex merging will certainly fail. - """ - - def __init__(self, ir_arch): - self.todo = set() - self.stateby_ad = {} - self.cpt = {} - self.states_var_done = set() - self.states_done = set() - self.total_done = 0 - self.ir_arch = ir_arch - - def add_state(self, parent, ad, state): - variables = dict(state.symbols) - - # get bloc dead, and remove from state - b = self.ir_arch.get_block(ad) - if b is None: - raise ValueError("unknown bloc! 
%s" % ad) - s = parent, ad, tuple(sorted(viewitems(variables))) - self.todo.add(s) - - def get_next_state(self): - state = self.todo.pop() - return state - - def do_step(self): - if len(self.todo) == 0: - return None - if self.total_done > 600: - print("symbexec watchdog!") - return None - self.total_done += 1 - print('CPT', self.total_done) - while self.todo: - state = self.get_next_state() - parent, ad, s = state - self.states_done.add(state) - self.states_var_done.add(state) - - sb = SymbolicExecutionEngine(self.ir_arch, dict(s)) - - return parent, ad, sb - return None diff --git a/miasm2/analysis/data_flow.py b/miasm2/analysis/data_flow.py deleted file mode 100644 index 3874b21b..00000000 --- a/miasm2/analysis/data_flow.py +++ /dev/null @@ -1,1579 +0,0 @@ -"""Data flow analysis based on miasm intermediate representation""" -from builtins import range -from collections import namedtuple -from future.utils import viewitems, viewvalues -from miasm2.core.utils import encode_hex -from miasm2.core.graph import DiGraph -from miasm2.ir.ir import AssignBlock, IRBlock -from miasm2.expression.expression import ExprLoc, ExprMem, ExprId, ExprInt,\ - ExprAssign, ExprOp -from miasm2.expression.simplifications import expr_simp -from miasm2.core.interval import interval -from miasm2.expression.expression_helper import possible_values -from miasm2.analysis.ssa import get_phi_sources_parent_block, \ - irblock_has_phi - - -class ReachingDefinitions(dict): - """ - Computes for each assignblock the set of reaching definitions. - Example: - IR block: - lbl0: - 0 A = 1 - B = 3 - 1 B = 2 - 2 A = A + B + 4 - - Reach definition of lbl0: - (lbl0, 0) => {} - (lbl0, 1) => {A: {(lbl0, 0)}, B: {(lbl0, 0)}} - (lbl0, 2) => {A: {(lbl0, 0)}, B: {(lbl0, 1)}} - (lbl0, 3) => {A: {(lbl0, 2)}, B: {(lbl0, 1)}} - - Source set 'REACHES' in: Kennedy, K. (1979). - A survey of data flow analysis techniques. - IBM Thomas J. Watson Research Division, Algorithm MK - - This class is usable as a dictionary whose structure is - { (block, index): { lvalue: set((block, index)) } } - """ - - ircfg = None - - def __init__(self, ircfg): - super(ReachingDefinitions, self).__init__() - self.ircfg = ircfg - self.compute() - - def get_definitions(self, block_lbl, assignblk_index): - """Returns the dict { lvalue: set((def_block_lbl, def_index)) } - associated with self.ircfg.@block.assignblks[@assignblk_index] - or {} if it is not yet computed - """ - return self.get((block_lbl, assignblk_index), {}) - - def compute(self): - """This is the main fixpoint""" - modified = True - while modified: - modified = False - for block in viewvalues(self.ircfg.blocks): - modified |= self.process_block(block) - - def process_block(self, block): - """ - Fetch reach definitions from predecessors and propagate it to - the assignblk in block @block. - """ - predecessor_state = {} - for pred_lbl in self.ircfg.predecessors(block.loc_key): - pred = self.ircfg.blocks[pred_lbl] - for lval, definitions in viewitems(self.get_definitions(pred_lbl, len(pred))): - predecessor_state.setdefault(lval, set()).update(definitions) - - modified = self.get((block.loc_key, 0)) != predecessor_state - if not modified: - return False - self[(block.loc_key, 0)] = predecessor_state - - for index in range(len(block)): - modified |= self.process_assignblock(block, index) - return modified - - def process_assignblock(self, block, assignblk_index): - """ - Updates the reach definitions with values defined at - assignblock @assignblk_index in block @block. 
- NB: the effect of assignblock @assignblk_index in stored at index - (@block, @assignblk_index + 1). - """ - - assignblk = block[assignblk_index] - defs = self.get_definitions(block.loc_key, assignblk_index).copy() - for lval in assignblk: - defs.update({lval: set([(block.loc_key, assignblk_index)])}) - - modified = self.get((block.loc_key, assignblk_index + 1)) != defs - if modified: - self[(block.loc_key, assignblk_index + 1)] = defs - - return modified - -ATTR_DEP = {"color" : "black", - "_type" : "data"} - -AssignblkNode = namedtuple('AssignblkNode', ['label', 'index', 'var']) - - -class DiGraphDefUse(DiGraph): - """Representation of a Use-Definition graph as defined by - Kennedy, K. (1979). A survey of data flow analysis techniques. - IBM Thomas J. Watson Research Division. - Example: - IR block: - lbl0: - 0 A = 1 - B = 3 - 1 B = 2 - 2 A = A + B + 4 - - Def use analysis: - (lbl0, 0, A) => {(lbl0, 2, A)} - (lbl0, 0, B) => {} - (lbl0, 1, B) => {(lbl0, 2, A)} - (lbl0, 2, A) => {} - - """ - - - def __init__(self, reaching_defs, - deref_mem=False, *args, **kwargs): - """Instantiate a DiGraph - @blocks: IR blocks - """ - self._edge_attr = {} - - # For dot display - self._filter_node = None - self._dot_offset = None - self._blocks = reaching_defs.ircfg.blocks - - super(DiGraphDefUse, self).__init__(*args, **kwargs) - self._compute_def_use(reaching_defs, - deref_mem=deref_mem) - - def edge_attr(self, src, dst): - """ - Return a dictionary of attributes for the edge between @src and @dst - @src: the source node of the edge - @dst: the destination node of the edge - """ - return self._edge_attr[(src, dst)] - - def _compute_def_use(self, reaching_defs, - deref_mem=False): - for block in viewvalues(self._blocks): - self._compute_def_use_block(block, - reaching_defs, - deref_mem=deref_mem) - - def _compute_def_use_block(self, block, reaching_defs, deref_mem=False): - for index, assignblk in enumerate(block): - assignblk_reaching_defs = reaching_defs.get_definitions(block.loc_key, index) - for lval, expr in viewitems(assignblk): - self.add_node(AssignblkNode(block.loc_key, index, lval)) - - read_vars = expr.get_r(mem_read=deref_mem) - if deref_mem and lval.is_mem(): - read_vars.update(lval.ptr.get_r(mem_read=deref_mem)) - for read_var in read_vars: - for reach in assignblk_reaching_defs.get(read_var, set()): - self.add_data_edge(AssignblkNode(reach[0], reach[1], read_var), - AssignblkNode(block.loc_key, index, lval)) - - def del_edge(self, src, dst): - super(DiGraphDefUse, self).del_edge(src, dst) - del self._edge_attr[(src, dst)] - - def add_uniq_labeled_edge(self, src, dst, edge_label): - """Adds the edge (@src, @dst) with label @edge_label. - if edge (@src, @dst) already exists, the previous label is overridden - """ - self.add_uniq_edge(src, dst) - self._edge_attr[(src, dst)] = edge_label - - def add_data_edge(self, src, dst): - """Adds an edge representing a data dependencie - and sets the label accordingly""" - self.add_uniq_labeled_edge(src, dst, ATTR_DEP) - - def node2lines(self, node): - lbl, index, reg = node - yield self.DotCellDescription(text="%s (%s)" % (lbl, index), - attr={'align': 'center', - 'colspan': 2, - 'bgcolor': 'grey'}) - src = self._blocks[lbl][index][reg] - line = "%s = %s" % (reg, src) - yield self.DotCellDescription(text=line, attr={}) - yield self.DotCellDescription(text="", attr={}) - - -def dead_simp_useful_assignblks(irarch, defuse, reaching_defs): - """Mark useful statements using previous reach analysis and defuse - - Source : Kennedy, K. (1979). 
A survey of data flow analysis techniques. - IBM Thomas J. Watson Research Division, Algorithm MK - - Return a set of triplets (block, assignblk number, lvalue) of - useful definitions - PRE: compute_reach(self) - - """ - ircfg = reaching_defs.ircfg - useful = set() - - for block_lbl, block in viewitems(ircfg.blocks): - successors = ircfg.successors(block_lbl) - for successor in successors: - if successor not in ircfg.blocks: - keep_all_definitions = True - break - else: - keep_all_definitions = False - - # Block has a nonexistent successor or is a leaf - if keep_all_definitions or (len(successors) == 0): - valid_definitions = reaching_defs.get_definitions(block_lbl, - len(block)) - for lval, definitions in viewitems(valid_definitions): - if lval in irarch.get_out_regs(block) or keep_all_definitions: - for definition in definitions: - useful.add(AssignblkNode(definition[0], definition[1], lval)) - - # Force keeping of specific cases - for index, assignblk in enumerate(block): - for lval, rval in viewitems(assignblk): - if (lval.is_mem() or - irarch.IRDst == lval or - lval.is_id("exception_flags") or - rval.is_function_call()): - useful.add(AssignblkNode(block_lbl, index, lval)) - - # Useful nodes dependencies - for node in useful: - for parent in defuse.reachable_parents(node): - yield parent - - -def dead_simp(irarch, ircfg): - """ - Remove useless assignments. - - This function is used to analyse relation of a * complete function * - This means the blocks under study represent a solid full function graph. - - Source : Kennedy, K. (1979). A survey of data flow analysis techniques. - IBM Thomas J. Watson Research Division, page 43 - - @ircfg: IntermediateRepresentation instance - """ - - modified = False - reaching_defs = ReachingDefinitions(ircfg) - defuse = DiGraphDefUse(reaching_defs, deref_mem=True) - useful = set(dead_simp_useful_assignblks(irarch, defuse, reaching_defs)) - for block in list(viewvalues(ircfg.blocks)): - irs = [] - for idx, assignblk in enumerate(block): - new_assignblk = dict(assignblk) - for lval in assignblk: - if AssignblkNode(block.loc_key, idx, lval) not in useful: - del new_assignblk[lval] - modified = True - irs.append(AssignBlock(new_assignblk, assignblk.instr)) - ircfg.blocks[block.loc_key] = IRBlock(block.loc_key, irs) - return modified - - -def _test_merge_next_block(ircfg, loc_key): - """ - Test if the irblock at @loc_key can be merge with its son - @ircfg: IRCFG instance - @loc_key: LocKey instance of the candidate parent irblock - """ - - if loc_key not in ircfg.blocks: - return None - sons = ircfg.successors(loc_key) - if len(sons) != 1: - return None - son = list(sons)[0] - if ircfg.predecessors(son) != [loc_key]: - return None - if son not in ircfg.blocks: - return None - - return son - - -def _do_merge_blocks(ircfg, loc_key, son_loc_key): - """ - Merge two irblocks at @loc_key and @son_loc_key - - @ircfg: DiGrpahIR - @loc_key: LocKey instance of the parent IRBlock - @loc_key: LocKey instance of the son IRBlock - """ - - assignblks = [] - for assignblk in ircfg.blocks[loc_key]: - if ircfg.IRDst not in assignblk: - assignblks.append(assignblk) - continue - affs = {} - for dst, src in viewitems(assignblk): - if dst != ircfg.IRDst: - affs[dst] = src - if affs: - assignblks.append(AssignBlock(affs, assignblk.instr)) - - assignblks += ircfg.blocks[son_loc_key].assignblks - new_block = IRBlock(loc_key, assignblks) - - ircfg.discard_edge(loc_key, son_loc_key) - - for son_successor in ircfg.successors(son_loc_key): - ircfg.add_uniq_edge(loc_key, 
son_successor) - ircfg.discard_edge(son_loc_key, son_successor) - del ircfg.blocks[son_loc_key] - ircfg.del_node(son_loc_key) - ircfg.blocks[loc_key] = new_block - - -def _test_jmp_only(ircfg, loc_key, heads): - """ - If irblock at @loc_key sets only IRDst to an ExprLoc, return the - corresponding loc_key target. - Avoid creating predecssors for heads LocKeys - None in other cases. - - @ircfg: IRCFG instance - @loc_key: LocKey instance of the candidate irblock - @heads: LocKey heads of the graph - - """ - - if loc_key not in ircfg.blocks: - return None - irblock = ircfg.blocks[loc_key] - if len(irblock.assignblks) != 1: - return None - items = list(viewitems(dict(irblock.assignblks[0]))) - if len(items) != 1: - return None - if len(ircfg.successors(loc_key)) != 1: - return None - # Don't create predecessors on heads - dst, src = items[0] - assert dst.is_id("IRDst") - if not src.is_loc(): - return None - dst = src.loc_key - if loc_key in heads: - predecessors = set(ircfg.predecessors(dst)) - predecessors.difference_update(set([loc_key])) - if predecessors: - return None - return dst - - -def _relink_block_node(ircfg, loc_key, son_loc_key, replace_dct): - """ - Link loc_key's parents to parents directly to son_loc_key - """ - for parent in set(ircfg.predecessors(loc_key)): - parent_block = ircfg.blocks.get(parent, None) - if parent_block is None: - continue - - new_block = parent_block.modify_exprs( - lambda expr:expr.replace_expr(replace_dct), - lambda expr:expr.replace_expr(replace_dct) - ) - - # Link parent to new dst - ircfg.add_uniq_edge(parent, son_loc_key) - - # Unlink block - ircfg.blocks[new_block.loc_key] = new_block - ircfg.del_node(loc_key) - - -def _remove_to_son(ircfg, loc_key, son_loc_key): - """ - Merge irblocks; The final block has the @son_loc_key loc_key - Update references - - Condition: - - irblock at @loc_key is a pure jump block - - @loc_key is not an entry point (can be removed) - - @irblock: IRCFG instance - @loc_key: LocKey instance of the parent irblock - @son_loc_key: LocKey instance of the son irblock - """ - - # Ircfg loop => don't mess - if loc_key == son_loc_key: - return False - - # Unlink block destinations - ircfg.del_edge(loc_key, son_loc_key) - - replace_dct = { - ExprLoc(loc_key, ircfg.IRDst.size):ExprLoc(son_loc_key, ircfg.IRDst.size) - } - - _relink_block_node(ircfg, loc_key, son_loc_key, replace_dct) - - ircfg.del_node(loc_key) - del ircfg.blocks[loc_key] - - return True - - -def _remove_to_parent(ircfg, loc_key, son_loc_key): - """ - Merge irblocks; The final block has the @loc_key loc_key - Update references - - Condition: - - irblock at @loc_key is a pure jump block - - @son_loc_key is not an entry point (can be removed) - - @irblock: IRCFG instance - @loc_key: LocKey instance of the parent irblock - @son_loc_key: LocKey instance of the son irblock - """ - - # Ircfg loop => don't mess - if loc_key == son_loc_key: - return False - - # Unlink block destinations - ircfg.del_edge(loc_key, son_loc_key) - - old_irblock = ircfg.blocks[son_loc_key] - new_irblock = IRBlock(loc_key, old_irblock.assignblks) - - ircfg.blocks[son_loc_key] = new_irblock - - ircfg.add_irblock(new_irblock) - - replace_dct = { - ExprLoc(son_loc_key, ircfg.IRDst.size):ExprLoc(loc_key, ircfg.IRDst.size) - } - - _relink_block_node(ircfg, son_loc_key, loc_key, replace_dct) - - - ircfg.del_node(son_loc_key) - del ircfg.blocks[son_loc_key] - - return True - - -def merge_blocks(ircfg, heads): - """ - This function modifies @ircfg to apply the following transformations: - - group an 
irblock with its son if the irblock has one and only one son and - this son has one and only one parent (spaghetti code). - - if an irblock is only made of an assignment to IRDst with a given label, - this irblock is dropped and its parent destination targets are - updated. The irblock must have a parent (avoid deleting the function head) - - if an irblock is a head of the graph and is only made of an assignment to - IRDst with a given label, this irblock is dropped and its son becomes the - head. References are fixed - - This function avoid creating predecessors on heads - - Return True if at least an irblock has been modified - - @ircfg: IRCFG instance - @heads: loc_key to keep - """ - - modified = False - todo = set(ircfg.nodes()) - while todo: - loc_key = todo.pop() - - # Test merge block - son = _test_merge_next_block(ircfg, loc_key) - if son is not None and son not in heads: - _do_merge_blocks(ircfg, loc_key, son) - todo.add(loc_key) - modified = True - continue - - # Test jmp only block - son = _test_jmp_only(ircfg, loc_key, heads) - if son is not None and loc_key not in heads: - ret = _remove_to_son(ircfg, loc_key, son) - modified |= ret - if ret: - todo.add(loc_key) - continue - - # Test head jmp only block - if (son is not None and - son not in heads and - son in ircfg.blocks): - # jmp only test done previously - ret = _remove_to_parent(ircfg, loc_key, son) - modified |= ret - if ret: - todo.add(loc_key) - continue - - - return modified - - -def remove_empty_assignblks(ircfg): - """ - Remove empty assignblks in irblocks of @ircfg - Return True if at least an irblock has been modified - - @ircfg: IRCFG instance - """ - modified = False - for loc_key, block in list(viewitems(ircfg.blocks)): - irs = [] - block_modified = False - for assignblk in block: - if len(assignblk): - irs.append(assignblk) - else: - block_modified = True - if block_modified: - new_irblock = IRBlock(loc_key, irs) - ircfg.blocks[loc_key] = new_irblock - modified = True - return modified - - -class SSADefUse(DiGraph): - """ - Generate DefUse information from SSA transformation - Links are not valid for ExprMem. 
- """ - - def add_var_def(self, node, src): - index2dst = self._links.setdefault(node.label, {}) - dst2src = index2dst.setdefault(node.index, {}) - dst2src[node.var] = src - - def add_def_node(self, def_nodes, node, src): - if node.var.is_id(): - def_nodes[node.var] = node - - def add_use_node(self, use_nodes, node, src): - sources = set() - if node.var.is_mem(): - sources.update(node.var.ptr.get_r(mem_read=True)) - sources.update(src.get_r(mem_read=True)) - for source in sources: - if not source.is_mem(): - use_nodes.setdefault(source, set()).add(node) - - def get_node_target(self, node): - return self._links[node.label][node.index][node.var] - - def set_node_target(self, node, src): - self._links[node.label][node.index][node.var] = src - - @classmethod - def from_ssa(cls, ssa): - """ - Return a DefUse DiGraph from a SSA graph - @ssa: SSADiGraph instance - """ - - graph = cls() - # First pass - # Link line to its use and def - def_nodes = {} - use_nodes = {} - graph._links = {} - for lbl in ssa.graph.nodes(): - block = ssa.graph.blocks.get(lbl, None) - if block is None: - continue - for index, assignblk in enumerate(block): - for dst, src in viewitems(assignblk): - node = AssignblkNode(lbl, index, dst) - graph.add_var_def(node, src) - graph.add_def_node(def_nodes, node, src) - graph.add_use_node(use_nodes, node, src) - - for dst, node in viewitems(def_nodes): - graph.add_node(node) - if dst not in use_nodes: - continue - for use in use_nodes[dst]: - graph.add_uniq_edge(node, use) - - return graph - - - - -def expr_test_visit(expr, test): - result = set() - expr.visit( - lambda expr: expr, - lambda expr: test(expr, result) - ) - if result: - return True - else: - return False - - -def expr_has_mem_test(expr, result): - if result: - # Don't analyse if we already found a candidate - return False - if expr.is_mem(): - result.add(expr) - return False - return True - - -def expr_has_mem(expr): - """ - Return True if expr contains at least one memory access - @expr: Expr instance - """ - return expr_test_visit(expr, expr_has_mem_test) - - -class PropagateThroughExprId(object): - """ - Propagate expressions though ExprId - """ - - def has_propagation_barrier(self, assignblks): - """ - Return True if propagation cannot cross the @assignblks - @assignblks: list of AssignBlock to check - """ - for assignblk in assignblks: - for dst, src in viewitems(assignblk): - if src.is_function_call(): - return True - if dst.is_mem(): - return True - return False - - def is_mem_written(self, ssa, node_a, node_b): - """ - Return True if memory is written at least once between @node_a and - @node_b - - @node: AssignblkNode representing the start position - @successor: AssignblkNode representing the end position - """ - - block_b = ssa.graph.blocks[node_b.label] - nodes_to_do = self.compute_reachable_nodes_from_a_to_b(ssa.graph, node_a.label, node_b.label) - - if node_a.label == node_b.label: - # src is dst - assert nodes_to_do == set([node_a.label]) - if self.has_propagation_barrier(block_b.assignblks[node_a.index:node_b.index]): - return True - else: - # Check everyone but node_a.label and node_b.label - for loc in nodes_to_do - set([node_a.label, node_b.label]): - block = ssa.graph.blocks[loc] - if self.has_propagation_barrier(block.assignblks): - return True - # Check node_a.label partially - block_a = ssa.graph.blocks[node_a.label] - if self.has_propagation_barrier(block_a.assignblks[node_a.index:]): - return True - if nodes_to_do.intersection(ssa.graph.successors(node_b.label)): - # There is a path from 
node_b.label to node_b.label => Check node_b.label fully - if self.has_propagation_barrier(block_b.assignblks): - return True - else: - # Check node_b.label partially - if self.has_propagation_barrier(block_b.assignblks[:node_b.index]): - return True - return False - - def compute_reachable_nodes_from_a_to_b(self, ssa, loc_a, loc_b): - reachables_a = set(ssa.reachable_sons(loc_a)) - reachables_b = set(ssa.reachable_parents_stop_node(loc_b, loc_a)) - return reachables_a.intersection(reachables_b) - - def propagation_allowed(self, ssa, to_replace, node_a, node_b): - """ - Return True if we can replace @node_a source present in @to_replace into - @node_b - - @node_a: AssignblkNode position - @node_b: AssignblkNode position - """ - if not expr_has_mem(to_replace[node_a.var]): - return True - if self.is_mem_written(ssa, node_a, node_b): - return False - return True - - - def get_var_definitions(self, ssa): - """ - Return a dictionary linking variable to its assignment location - @ssa: SSADiGraph instance - """ - ircfg = ssa.graph - def_dct = {} - for node in ircfg.nodes(): - for index, assignblk in enumerate(ircfg.blocks[node]): - for dst, src in viewitems(assignblk): - if not dst.is_id(): - continue - if dst in ssa.immutable_ids: - continue - assert dst not in def_dct - def_dct[dst] = node, index - return def_dct - - def phi_has_identical_sources(self, ssa, def_dct, var): - """ - If phi operation has identical source values, return it; else None - @ssa: SSADiGraph instance - @def_dct: dictionary linking variable to its assignment location - @var: Phi destination variable - """ - loc_key, index = def_dct[var] - sources = ssa.graph.blocks[loc_key][index][var] - assert sources.is_op('Phi') - sources_values = set() - for src in sources.args: - assert src in def_dct - loc_key, index = def_dct[src] - value = ssa.graph.blocks[loc_key][index][src] - sources_values.add(value) - if len(sources_values) != 1: - return None - return list(sources_values)[0] - - def get_candidates(self, ssa, head, max_expr_depth): - def_dct = self.get_var_definitions(ssa) - defuse = SSADefUse.from_ssa(ssa) - to_replace = {} - node_to_reg = {} - for node in defuse.nodes(): - if node.var in ssa.immutable_ids: - continue - src = defuse.get_node_target(node) - if max_expr_depth is not None and len(str(src)) > max_expr_depth: - continue - if src.is_function_call(): - continue - if node.var.is_mem(): - continue - if src.is_op('Phi'): - ret = self.phi_has_identical_sources(ssa, def_dct, node.var) - if ret: - to_replace[node.var] = ret - node_to_reg[node] = node.var - continue - to_replace[node.var] = src - node_to_reg[node] = node.var - return node_to_reg, to_replace, defuse - - def propagate(self, ssa, head, max_expr_depth=None): - """ - Do expression propagation - @ssa: SSADiGraph instance - @head: the head location of the graph - @max_expr_depth: the maximum allowed depth of an expression - """ - node_to_reg, to_replace, defuse = self.get_candidates(ssa, head, max_expr_depth) - modified = False - for node, reg in viewitems(node_to_reg): - for successor in defuse.successors(node): - if not self.propagation_allowed(ssa, to_replace, node, successor): - continue - - node_a = node - node_b = successor - block = ssa.graph.blocks[node_b.label] - - replace = {node_a.var: to_replace[node_a.var]} - # Replace - assignblks = list(block) - assignblk = block[node_b.index] - out = {} - for dst, src in viewitems(assignblk): - if src.is_op('Phi'): - out[dst] = src - continue - - if src.is_mem(): - ptr = src.ptr.replace_expr(replace) - new_src = 
ExprMem(ptr, src.size) - else: - new_src = src.replace_expr(replace) - - if dst.is_id(): - new_dst = dst - elif dst.is_mem(): - ptr = dst.ptr.replace_expr(replace) - new_dst = ExprMem(ptr, dst.size) - else: - new_dst = dst.replace_expr(replace) - if not (new_dst.is_id() or new_dst.is_mem()): - new_dst = dst - if src != new_src or dst != new_dst: - modified = True - out[new_dst] = new_src - out = AssignBlock(out, assignblk.instr) - assignblks[node_b.index] = out - new_block = IRBlock(block.loc_key, assignblks) - ssa.graph.blocks[block.loc_key] = new_block - - return modified - - - -class PropagateExprIntThroughExprId(PropagateThroughExprId): - """ - Propagate ExprInt though ExprId: classic constant propagation - This is a sub family of PropagateThroughExprId. - It reduces leaves in expressions of a program. - """ - - def get_candidates(self, ssa, head, max_expr_depth): - defuse = SSADefUse.from_ssa(ssa) - - to_replace = {} - node_to_reg = {} - for node in defuse.nodes(): - src = defuse.get_node_target(node) - if not src.is_int(): - continue - if src.is_function_call(): - continue - if node.var.is_mem(): - continue - to_replace[node.var] = src - node_to_reg[node] = node.var - return node_to_reg, to_replace, defuse - - def propagation_allowed(self, ssa, to_replace, node_a, node_b): - """ - Propagating ExprInt is always ok - """ - return True - - -class PropagateThroughExprMem(object): - """ - Propagate through ExprMem in very simple cases: - - if no memory write between source and target - - if source does not contain any memory reference - """ - - def propagate(self, ssa, head, max_expr_depth=None): - ircfg = ssa.graph - todo = set() - modified = False - for block in viewvalues(ircfg.blocks): - for i, assignblk in enumerate(block): - for dst, src in viewitems(assignblk): - if not dst.is_mem(): - continue - if expr_has_mem(src): - continue - todo.add((block.loc_key, i + 1, dst, src)) - ptr = dst.ptr - for size in range(8, dst.size, 8): - todo.add((block.loc_key, i + 1, ExprMem(ptr, size), src[:size])) - - while todo: - loc_key, index, mem_dst, mem_src = todo.pop() - block = ircfg.blocks[loc_key] - assignblks = list(block) - block_modified = False - for i in range(index, len(block)): - assignblk = block[i] - write_mem = False - assignblk_modified = False - out = dict(assignblk) - out_new = {} - for dst, src in viewitems(out): - if dst.is_mem(): - write_mem = True - ptr = dst.ptr.replace_expr({mem_dst:mem_src}) - dst = ExprMem(ptr, dst.size) - src = src.replace_expr({mem_dst:mem_src}) - out_new[dst] = src - if out != out_new: - assignblk_modified = True - - if assignblk_modified: - assignblks[i] = AssignBlock(out_new, assignblk.instr) - block_modified = True - if write_mem: - break - else: - # If no memory written, we may propagate to sons - # if son has only parent - for successor in ircfg.successors(loc_key): - predecessors = ircfg.predecessors(successor) - if len(predecessors) != 1: - continue - todo.add((successor, 0, mem_dst, mem_src)) - - if block_modified: - modified = True - new_block = IRBlock(block.loc_key, assignblks) - ircfg.blocks[block.loc_key] = new_block - return modified - - -def stack_to_reg(expr): - if expr.is_mem(): - ptr = expr.arg - SP = ir_arch_a.sp - if ptr == SP: - return ExprId("STACK.0", expr.size) - elif (ptr.is_op('+') and - len(ptr.args) == 2 and - ptr.args[0] == SP and - ptr.args[1].is_int()): - diff = int(ptr.args[1]) - assert diff % 4 == 0 - diff = (0 - diff) & 0xFFFFFFFF - return ExprId("STACK.%d" % (diff // 4), expr.size) - return False - - -def 
is_stack_access(ir_arch_a, expr): - if not expr.is_mem(): - return False - ptr = expr.ptr - diff = expr_simp(ptr - ir_arch_a.sp) - if not diff.is_int(): - return False - return expr - - -def visitor_get_stack_accesses(ir_arch_a, expr, stack_vars): - if is_stack_access(ir_arch_a, expr): - stack_vars.add(expr) - return expr - - -def get_stack_accesses(ir_arch_a, expr): - result = set() - expr.visit(lambda expr:visitor_get_stack_accesses(ir_arch_a, expr, result)) - return result - - -def get_interval_length(interval_in): - length = 0 - for start, stop in interval_in.intervals: - length += stop + 1 - start - return length - - -def check_expr_below_stack(ir_arch_a, expr): - """ - Return False if expr pointer is below original stack pointer - @ir_arch_a: ira instance - @expr: Expression instance - """ - ptr = expr.ptr - diff = expr_simp(ptr - ir_arch_a.sp) - if not diff.is_int(): - return True - if int(diff) == 0 or int(expr_simp(diff.msb())) == 0: - return False - return True - - -def retrieve_stack_accesses(ir_arch_a, ircfg): - """ - Walk the ssa graph and find stack based variables. - Return a dictionary linking stack base address to its size/name - @ir_arch_a: ira instance - @ircfg: IRCFG instance - """ - stack_vars = set() - for block in viewvalues(ircfg.blocks): - for assignblk in block: - for dst, src in viewitems(assignblk): - stack_vars.update(get_stack_accesses(ir_arch_a, dst)) - stack_vars.update(get_stack_accesses(ir_arch_a, src)) - stack_vars = [expr for expr in stack_vars if check_expr_below_stack(ir_arch_a, expr)] - - base_to_var = {} - for var in stack_vars: - base_to_var.setdefault(var.ptr, set()).add(var) - - - base_to_interval = {} - for addr, vars in viewitems(base_to_var): - var_interval = interval() - for var in vars: - offset = expr_simp(addr - ir_arch_a.sp) - if not offset.is_int(): - # skip non linear stack offset - continue - - start = int(offset) - stop = int(expr_simp(offset + ExprInt(var.size // 8, offset.size))) - mem = interval([(start, stop-1)]) - var_interval += mem - base_to_interval[addr] = var_interval - if not base_to_interval: - return {} - # Check if not intervals overlap - _, tmp = base_to_interval.popitem() - while base_to_interval: - addr, mem = base_to_interval.popitem() - assert (tmp & mem).empty - tmp += mem - - base_to_info = {} - for addr, vars in viewitems(base_to_var): - name = "var_%d" % (len(base_to_info)) - size = max([var.size for var in vars]) - base_to_info[addr] = size, name - return base_to_info - - -def fix_stack_vars(expr, base_to_info): - """ - Replace local stack accesses in expr using information in @base_to_info - @expr: Expression instance - @base_to_info: dictionary linking stack base address to its size/name - """ - if not expr.is_mem(): - return expr - ptr = expr.ptr - if ptr not in base_to_info: - return expr - size, name = base_to_info[ptr] - var = ExprId(name, size) - if size == expr.size: - return var - assert expr.size < size - return var[:expr.size] - - -def replace_mem_stack_vars(expr, base_to_info): - return expr.visit(lambda expr:fix_stack_vars(expr, base_to_info)) - - -def replace_stack_vars(ir_arch_a, ircfg): - """ - Try to replace stack based memory accesses by variables. - - Hypothesis: the input ircfg must have all it's accesses to stack explicitly - done through the stack register, ie every aliases on those variables is - resolved. 
- - WARNING: may fail - - @ir_arch_a: ira instance - @ircfg: IRCFG instance - """ - - base_to_info = retrieve_stack_accesses(ir_arch_a, ircfg) - modified = False - for block in list(viewvalues(ircfg.blocks)): - assignblks = [] - for assignblk in block: - out = {} - for dst, src in viewitems(assignblk): - new_dst = dst.visit(lambda expr:replace_mem_stack_vars(expr, base_to_info)) - new_src = src.visit(lambda expr:replace_mem_stack_vars(expr, base_to_info)) - if new_dst != dst or new_src != src: - modified |= True - - out[new_dst] = new_src - - out = AssignBlock(out, assignblk.instr) - assignblks.append(out) - new_block = IRBlock(block.loc_key, assignblks) - ircfg.blocks[block.loc_key] = new_block - return modified - - -def memlookup_test(expr, bs, is_addr_ro_variable, result): - if expr.is_mem() and expr.ptr.is_int(): - ptr = int(expr.ptr) - if is_addr_ro_variable(bs, ptr, expr.size): - result.add(expr) - return False - return True - - -def memlookup_visit(expr, bs, is_addr_ro_variable): - result = set() - expr.visit(lambda expr: expr, - lambda expr: memlookup_test(expr, bs, is_addr_ro_variable, result)) - return result - - -def get_memlookup(expr, bs, is_addr_ro_variable): - return memlookup_visit(expr, bs, is_addr_ro_variable) - - -def read_mem(bs, expr): - ptr = int(expr.ptr) - var_bytes = bs.getbytes(ptr, expr.size // 8)[::-1] - try: - value = int(encode_hex(var_bytes), 16) - except ValueError: - return expr - return ExprInt(value, expr.size) - - -def load_from_int(ir_arch, bs, is_addr_ro_variable): - """ - Replace memory read based on constant with static value - @ir_arch: ira instance - @bs: binstream instance - @is_addr_ro_variable: callback(addr, size) to test memory candidate - """ - - modified = False - for block in list(viewvalues(ir_arch.blocks)): - assignblks = list() - for assignblk in block: - out = {} - for dst, src in viewitems(assignblk): - # Test src - mems = get_memlookup(src, bs, is_addr_ro_variable) - src_new = src - if mems: - replace = {} - for mem in mems: - value = read_mem(bs, mem) - replace[mem] = value - src_new = src.replace_expr(replace) - if src_new != src: - modified = True - # Test dst pointer if dst is mem - if dst.is_mem(): - ptr = dst.ptr - mems = get_memlookup(ptr, bs, is_addr_ro_variable) - if mems: - replace = {} - for mem in mems: - value = read_mem(bs, mem) - replace[mem] = value - ptr_new = ptr.replace_expr(replace) - if ptr_new != ptr: - modified = True - dst = ExprMem(ptr_new, dst.size) - out[dst] = src_new - out = AssignBlock(out, assignblk.instr) - assignblks.append(out) - block = IRBlock(block.loc_key, assignblks) - ir_arch.blocks[block.loc_key] = block - return modified - - -class AssignBlockLivenessInfos(object): - """ - Description of live in / live out of an AssignBlock - """ - - __slots__ = ["gen", "kill", "var_in", "var_out", "live", "assignblk"] - - def __init__(self, assignblk, gen, kill): - self.gen = gen - self.kill = kill - self.var_in = set() - self.var_out = set() - self.live = set() - self.assignblk = assignblk - - def __str__(self): - out = [] - out.append("\tVarIn:" + ", ".join(str(x) for x in self.var_in)) - out.append("\tGen:" + ", ".join(str(x) for x in self.gen)) - out.append("\tKill:" + ", ".join(str(x) for x in self.kill)) - out.append( - '\n'.join( - "\t%s = %s" % (dst, src) - for (dst, src) in viewitems(self.assignblk) - ) - ) - out.append("\tVarOut:" + ", ".join(str(x) for x in self.var_out)) - return '\n'.join(out) - - -class IRBlockLivenessInfos(object): - """ - Description of live in / live out of an AssignBlock - 
""" - __slots__ = ["loc_key", "infos", "assignblks"] - - - def __init__(self, irblock): - self.loc_key = irblock.loc_key - self.infos = [] - self.assignblks = [] - for assignblk in irblock: - gens, kills = set(), set() - for dst, src in viewitems(assignblk): - expr = ExprAssign(dst, src) - read = expr.get_r(mem_read=True) - write = expr.get_w() - gens.update(read) - kills.update(write) - self.infos.append(AssignBlockLivenessInfos(assignblk, gens, kills)) - self.assignblks.append(assignblk) - - def __getitem__(self, index): - """Getitem on assignblks""" - return self.assignblks.__getitem__(index) - - def __str__(self): - out = [] - out.append("%s:" % self.loc_key) - for info in self.infos: - out.append(str(info)) - out.append('') - return "\n".join(out) - - -class DiGraphLiveness(DiGraph): - """ - DiGraph representing variable liveness - """ - - def __init__(self, ircfg, loc_db=None): - super(DiGraphLiveness, self).__init__() - self.ircfg = ircfg - self.loc_db = loc_db - self._blocks = {} - # Add irblocks gen/kill - for node in ircfg.nodes(): - irblock = ircfg.blocks[node] - irblockinfos = IRBlockLivenessInfos(irblock) - self.add_node(irblockinfos.loc_key) - self.blocks[irblockinfos.loc_key] = irblockinfos - for succ in ircfg.successors(node): - self.add_uniq_edge(node, succ) - for pred in ircfg.predecessors(node): - self.add_uniq_edge(pred, node) - - @property - def blocks(self): - return self._blocks - - def init_var_info(self): - """Add ircfg out regs""" - raise NotImplementedError("Abstract method") - - def node2lines(self, node): - """ - Output liveness information in dot format - """ - if self.loc_db is None: - node_name = str(node) - else: - names = self.loc_db.get_location_names(node) - if not names: - node_name = self.loc_db.pretty_str(node) - else: - node_name = "".join("%s:\n" % name for name in names) - yield self.DotCellDescription( - text="%s" % node_name, - attr={ - 'align': 'center', - 'colspan': 2, - 'bgcolor': 'grey', - } - ) - if node not in self._blocks: - yield [self.DotCellDescription(text="NOT PRESENT", attr={})] - return - - for i, info in enumerate(self._blocks[node].infos): - var_in = "VarIn:" + ", ".join(str(x) for x in info.var_in) - var_out = "VarOut:" + ", ".join(str(x) for x in info.var_out) - - assignmnts = ["%s = %s" % (dst, src) for (dst, src) in viewitems(info.assignblk)] - - if i == 0: - yield self.DotCellDescription( - text=var_in, - attr={ - 'bgcolor': 'green', - } - ) - - for assign in assignmnts: - yield self.DotCellDescription(text=assign, attr={}) - yield self.DotCellDescription( - text=var_out, - attr={ - 'bgcolor': 'green', - } - ) - yield self.DotCellDescription(text="", attr={}) - - def back_propagate_compute(self, block): - """ - Compute the liveness information in the @block. - @block: AssignBlockLivenessInfos instance - """ - infos = block.infos - modified = False - for i in reversed(range(len(infos))): - new_vars = set(infos[i].gen.union(infos[i].var_out.difference(infos[i].kill))) - if infos[i].var_in != new_vars: - modified = True - infos[i].var_in = new_vars - if i > 0 and infos[i - 1].var_out != set(infos[i].var_in): - modified = True - infos[i - 1].var_out = set(infos[i].var_in) - return modified - - def back_propagate_to_parent(self, todo, node, parent): - """ - Back propagate the liveness information from @node to @parent. 
- @node: loc_key of the source node - @parent: loc_key of the node to update - """ - parent_block = self.blocks[parent] - cur_block = self.blocks[node] - if cur_block.infos[0].var_in == parent_block.infos[-1].var_out: - return - var_info = cur_block.infos[0].var_in.union(parent_block.infos[-1].var_out) - parent_block.infos[-1].var_out = var_info - todo.add(parent) - - def compute_liveness(self): - """ - Compute the liveness information for the digraph. - """ - todo = set(self.leaves()) - while todo: - node = todo.pop() - cur_block = self.blocks[node] - modified = self.back_propagate_compute(cur_block) - if not modified: - continue - # We modified parent in, propagate to parents - for pred in self.predecessors(node): - self.back_propagate_to_parent(todo, node, pred) - return True - - -class DiGraphLivenessIRA(DiGraphLiveness): - """ - DiGraph representing variable liveness for IRA - """ - - def init_var_info(self, ir_arch_a): - """Add ircfg out regs""" - - for node in self.leaves(): - irblock = self.ircfg.blocks[node] - var_out = ir_arch_a.get_out_regs(irblock) - irblock_liveness = self.blocks[node] - irblock_liveness.infos[-1].var_out = var_out - - -def discard_phi_sources(ircfg, deleted_vars): - """ - Remove phi sources in @ircfg belonging to @deleted_vars set - @ircfg: IRCFG instance in ssa form - @deleted_vars: unused phi sources - """ - for block in list(viewvalues(ircfg.blocks)): - if not block.assignblks: - continue - assignblk = block[0] - todo = {} - modified = False - for dst, src in viewitems(assignblk): - if not src.is_op('Phi'): - todo[dst] = src - continue - srcs = set(expr for expr in src.args if expr not in deleted_vars) - assert(srcs) - if len(srcs) > 1: - todo[dst] = srcs - continue - todo[dst] = srcs.pop() - modified = True - if not modified: - continue - assignblks = list(block) - assignblk = dict(assignblk) - assignblk.update(todo) - assignblk = AssignBlock(assignblk, assignblks[0].instr) - assignblks[0] = assignblk - new_irblock = IRBlock(block.loc_key, assignblks) - ircfg.blocks[block.loc_key] = new_irblock - return True - - -def get_unreachable_nodes(ircfg, edges_to_del, heads): - """ - Return the unreachable nodes starting from heads and the associated edges to - be deleted. 
- - @ircfg: IRCFG instance - @edges_to_del: edges already marked as deleted - heads: locations of graph heads - """ - todo = set(heads) - visited_nodes = set() - new_edges_to_del = set() - while todo: - node = todo.pop() - if node in visited_nodes: - continue - visited_nodes.add(node) - for successor in ircfg.successors(node): - if (node, successor) not in edges_to_del: - todo.add(successor) - all_nodes = set(ircfg.nodes()) - nodes_to_del = all_nodes.difference(visited_nodes) - for node in nodes_to_del: - for successor in ircfg.successors(node): - if successor not in nodes_to_del: - # Frontier: link from a deleted node to a living node - new_edges_to_del.add((node, successor)) - return nodes_to_del, new_edges_to_del - - -def update_phi_with_deleted_edges(ircfg, edges_to_del): - """ - Update phi which have a source present in @edges_to_del - @ssa: IRCFG instance in ssa form - @edges_to_del: edges to delete - """ - - modified = False - blocks = dict(ircfg.blocks) - for loc_src, loc_dst in edges_to_del: - block = ircfg.blocks[loc_dst] - assert block.assignblks - assignblks = list(block) - assignblk = assignblks[0] - out = {} - for dst, phi_sources in viewitems(assignblk): - if not phi_sources.is_op('Phi'): - out = assignblk - break - var_to_parents = get_phi_sources_parent_block( - ircfg, - loc_dst, - phi_sources.args - ) - to_keep = set(phi_sources.args) - for src in phi_sources.args: - parents = var_to_parents[src] - if loc_src in parents: - to_keep.discard(src) - modified = True - assert to_keep - if len(to_keep) == 1: - out[dst] = to_keep.pop() - else: - out[dst] = ExprOp('Phi', *to_keep) - assignblk = AssignBlock(out, assignblks[0].instr) - assignblks[0] = assignblk - new_irblock = IRBlock(loc_dst, assignblks) - blocks[block.loc_key] = new_irblock - - for loc_key, block in viewitems(blocks): - ircfg.blocks[loc_key] = block - return modified - - -def del_unused_edges(ircfg, heads): - """ - Delete non accessible edges in the @ircfg graph. 
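get_unreachable_nodes above is a plain forward reachability walk that also reports the frontier edges leaving the dead region. The same logic on a dict-based graph (a sketch; the IRCFG is assumed to behave like this adjacency map):

    def unreachable(succs, heads, edges_to_del):
        todo, seen = set(heads), set()
        while todo:
            node = todo.pop()
            if node in seen:
                continue
            seen.add(node)
            todo.update(s for s in succs[node]
                        if (node, s) not in edges_to_del)
        dead = set(succs) - seen
        frontier = set((n, s) for n in dead for s in succs[n]
                       if s not in dead)
        return dead, frontier

    succs = {"A": ["B", "C"], "B": ["D"], "C": ["D"], "D": []}
    print(unreachable(succs, {"A"}, {("A", "C")}))  # ({'C'}, {('C', 'D')})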
- @ircfg: IRCFG instance in ssa form - @heads: location of the heads of the graph - """ - - deleted_vars = set() - modified = False - edges_to_del_1 = set() - for node in ircfg.nodes(): - successors = set(ircfg.successors(node)) - block = ircfg.blocks[node] - dst = block.dst - possible_dsts = set(solution.value for solution in possible_values(dst)) - if not all(dst.is_loc() for dst in possible_dsts): - continue - possible_dsts = set(dst.loc_key for dst in possible_dsts) - if len(possible_dsts) == len(successors): - continue - dsts_to_del = successors.difference(possible_dsts) - for dst in dsts_to_del: - edges_to_del_1.add((node, dst)) - - # Remove edges and update phi accordingly - # Two cases here: - # - edge is directly linked to a phi node - # - edge is indirect linked to a phi node - nodes_to_del, edges_to_del_2 = get_unreachable_nodes(ircfg, edges_to_del_1, heads) - modified |= update_phi_with_deleted_edges(ircfg, edges_to_del_1.union(edges_to_del_2)) - - for src, dst in edges_to_del_1.union(edges_to_del_2): - ircfg.del_edge(src, dst) - for node in nodes_to_del: - block = ircfg.blocks[node] - ircfg.del_node(node) - for assignblock in block: - for dst in assignblock: - deleted_vars.add(dst) - - if deleted_vars: - modified |= discard_phi_sources(ircfg, deleted_vars) - - return modified - - -class DiGraphLivenessSSA(DiGraphLivenessIRA): - """ - DiGraph representing variable liveness is a SSA graph - """ - def __init__(self, ircfg): - super(DiGraphLivenessSSA, self).__init__(ircfg) - - self.loc_key_to_phi_parents = {} - for irblock in viewvalues(self.blocks): - if not irblock_has_phi(irblock): - continue - out = {} - for sources in viewvalues(irblock[0]): - var_to_parents = get_phi_sources_parent_block(self, irblock.loc_key, sources.args) - for var, var_parents in viewitems(var_to_parents): - out.setdefault(var, set()).update(var_parents) - self.loc_key_to_phi_parents[irblock.loc_key] = out - - def back_propagate_to_parent(self, todo, node, parent): - parent_block = self.blocks[parent] - cur_block = self.blocks[node] - irblock = self.ircfg.blocks[node] - if cur_block.infos[0].var_in == parent_block.infos[-1].var_out: - return - var_info = cur_block.infos[0].var_in.union(parent_block.infos[-1].var_out) - - if irblock_has_phi(irblock): - # Remove phi special case - out = set() - phi_sources = self.loc_key_to_phi_parents[irblock.loc_key] - for var in var_info: - if var not in phi_sources: - out.add(var) - continue - if parent in phi_sources[var]: - out.add(var) - var_info = out - - parent_block.infos[-1].var_out = var_info - todo.add(parent) diff --git a/miasm2/analysis/debugging.py b/miasm2/analysis/debugging.py deleted file mode 100644 index 824b62ce..00000000 --- a/miasm2/analysis/debugging.py +++ /dev/null @@ -1,499 +0,0 @@ -from __future__ import print_function -from builtins import map -from builtins import range -import cmd -from future.utils import viewitems - -from miasm2.core.utils import hexdump -from miasm2.core.interval import interval -import miasm2.jitter.csts as csts -from miasm2.jitter.jitload import ExceptionHandle - - -class DebugBreakpoint(object): - - "Debug Breakpoint parent class" - pass - - -class DebugBreakpointSoft(DebugBreakpoint): - - "Stand for software breakpoint" - - def __init__(self, addr): - self.addr = addr - - def __str__(self): - return "Soft BP @0x%08x" % self.addr - - -class DebugBreakpointTerminate(DebugBreakpoint): - "Stand for an execution termination" - - def __init__(self, status): - self.status = status - - def __str__(self): - return "Terminate 
with %s" % self.status - - -class DebugBreakpointMemory(DebugBreakpoint): - - "Stand for memory breakpoint" - - type2str = {csts.BREAKPOINT_READ: "R", - csts.BREAKPOINT_WRITE: "W"} - - def __init__(self, addr, size, access_type): - self.addr = addr - self.access_type = access_type - self.size = size - - def __str__(self): - bp_type = "" - for k, v in viewitems(self.type2str): - if k & self.access_type != 0: - bp_type += v - return "Memory BP @0x%08x, Size 0x%08x, Type %s" % ( - self.addr, - self.size, - bp_type - ) - - @classmethod - def get_access_type(cls, read=False, write=False): - value = 0 - for k, v in viewitems(cls.type2str): - if v == "R" and read is True: - value += k - if v == "W" and write is True: - value += k - return value - - -class Debugguer(object): - - "Debugguer linked with a Jitter instance" - - def __init__(self, myjit): - "myjit : jitter instance" - self.myjit = myjit - self.bp_list = [] # DebugBreakpointSoft list - self.hw_bp_list = [] # DebugBreakpointHard list - self.mem_watched = [] # Memory areas watched - - def init_run(self, addr): - self.myjit.init_run(addr) - - def add_breakpoint(self, addr): - "Add bp @addr" - bp = DebugBreakpointSoft(addr) - func = lambda x: bp - bp.func = func - self.bp_list.append(bp) - self.myjit.add_breakpoint(addr, func) - - def init_memory_breakpoint(self): - "Set exception handler on EXCEPT_BREAKPOINT_MEMORY" - raise NotImplementedError("Not implemented") - - def add_memory_breakpoint(self, addr, size, read=False, write=False): - "add mem bp @[addr, addr + size], on read/write/both" - access_type = DebugBreakpointMemory.get_access_type(read=read, - write=write) - dbm = DebugBreakpointMemory(addr, size, access_type) - self.hw_bp_list.append(dbm) - self.myjit.vm.add_memory_breakpoint(addr, size, access_type) - - def remove_breakpoint(self, dbs): - "remove the DebugBreakpointSoft instance" - self.bp_list.remove(dbs) - self.myjit.remove_breakpoints_by_callback(dbs.func) - - def remove_breakpoint_by_addr(self, addr): - "remove breakpoints @ addr" - for bp in self.get_breakpoint_by_addr(addr): - self.remove_breakpoint(bp) - - def remove_memory_breakpoint(self, dbm): - "remove the DebugBreakpointMemory instance" - self.hw_bp_list.remove(dbm) - self.myjit.vm.remove_memory_breakpoint(dbm.addr, dbm.access_type) - - def remove_memory_breakpoint_by_addr_access(self, addr, read=False, - write=False): - "remove breakpoints @ addr" - access_type = DebugBreakpointMemory.get_access_type(read=read, - write=write) - for bp in self.hw_bp_list: - if bp.addr == addr and bp.access_type == access_type: - self.remove_memory_breakpoint(bp) - - def get_breakpoint_by_addr(self, addr): - ret = [] - for dbgsoft in self.bp_list: - if dbgsoft.addr == addr: - ret.append(dbgsoft) - return ret - - def get_breakpoints(self): - return self.bp_list - - def active_trace(self, mn=None, regs=None, newbloc=None): - if mn is not None: - self.myjit.jit.log_mn = mn - if regs is not None: - self.myjit.jit.log_regs = regs - if newbloc is not None: - self.myjit.jit.log_newbloc = newbloc - - def handle_exception(self, res): - if not res: - # A breakpoint has stopped the execution - return DebugBreakpointTerminate(res) - - if isinstance(res, DebugBreakpointSoft): - print("Breakpoint reached @0x%08x" % res.addr) - elif isinstance(res, ExceptionHandle): - if res == ExceptionHandle.memoryBreakpoint(): - print("Memory breakpoint reached!") - - # Remove flag - except_flag = self.myjit.vm.get_exception() - self.myjit.vm.set_exception(except_flag ^ res.except_flag) - - else: - raise 
NotImplementedError("Unknown Except") - else: - raise NotImplementedError("type res") - - # Repropagate res - return res - - def step(self): - "Step in jit" - - self.myjit.jit.set_options(jit_maxline=1) - # Reset all jitted blocks - self.myjit.jit.clear_jitted_blocks() - - res = self.myjit.continue_run(step=True) - self.handle_exception(res) - - self.myjit.jit.set_options(jit_maxline=50) - self.on_step() - - return res - - def run(self): - status = self.myjit.continue_run() - return self.handle_exception(status) - - def get_mem(self, addr, size=0xF): - "hexdump @addr, size" - - hexdump(self.myjit.vm.get_mem(addr, size)) - - def get_mem_raw(self, addr, size=0xF): - "hexdump @addr, size" - return self.myjit.vm.get_mem(addr, size) - - def watch_mem(self, addr, size=0xF): - self.mem_watched.append((addr, size)) - - def on_step(self): - for addr, size in self.mem_watched: - print("@0x%08x:" % addr) - self.get_mem(addr, size) - - def get_reg_value(self, reg_name): - return getattr(self.myjit.cpu, reg_name) - - def set_reg_value(self, reg_name, value): - - # Handle PC case - if reg_name == self.myjit.ir_arch.pc.name: - self.init_run(value) - - setattr(self.myjit.cpu, reg_name, value) - - def get_gpreg_all(self): - "Return general purposes registers" - return self.myjit.cpu.get_gpreg() - - -class DebugCmd(cmd.Cmd, object): - - "CommandLineInterpreter for Debugguer instance" - - color_g = '\033[92m' - color_e = '\033[0m' - color_b = '\033[94m' - color_r = '\033[91m' - - intro = color_g + "=== Miasm2 Debugging shell ===\nIf you need help, " - intro += "type 'help' or '?'" + color_e - prompt = color_b + "$> " + color_e - - def __init__(self, dbg): - "dbg : Debugguer" - self.dbg = dbg - super(DebugCmd, self).__init__() - - # Debug methods - - def print_breakpoints(self): - bp_list = self.dbg.bp_list - if len(bp_list) == 0: - print("No breakpoints.") - else: - for i, b in enumerate(bp_list): - print("%d\t0x%08x" % (i, b.addr)) - - def print_watchmems(self): - watch_list = self.dbg.mem_watched - if len(watch_list) == 0: - print("No memory watchpoints.") - else: - print("Num\tAddress \tSize") - for i, w in enumerate(watch_list): - addr, size = w - print("%d\t0x%08x\t0x%08x" % (i, addr, size)) - - def print_registers(self): - regs = self.dbg.get_gpreg_all() - - # Display settings - title1 = "Registers" - title2 = "Values" - max_name_len = max(map(len, list(regs) + [title1])) - - # Print value table - s = "%s%s | %s" % ( - title1, " " * (max_name_len - len(title1)), title2) - print(s) - print("-" * len(s)) - for name, value in sorted(viewitems(regs), key=lambda x: x[0]): - print( - "%s%s | %s" % ( - name, - " " * (max_name_len - len(name)), - hex(value).replace("L", "") - ) - ) - - def add_breakpoints(self, bp_addr): - for addr in bp_addr: - addr = int(addr, 0) - - good = True - for i, dbg_obj in enumerate(self.dbg.bp_list): - if dbg_obj.addr == addr: - good = False - break - if good is False: - print("Breakpoint 0x%08x already set (%d)" % (addr, i)) - else: - l = len(self.dbg.bp_list) - self.dbg.add_breakpoint(addr) - print("Breakpoint 0x%08x successfully added ! 
(%d)" % (addr, l)) - - display_mode = { - "mn": None, - "regs": None, - "newbloc": None - } - - def update_display_mode(self): - self.display_mode = { - "mn": self.dbg.myjit.jit.log_mn, - "regs": self.dbg.myjit.jit.log_regs, - "newbloc": self.dbg.myjit.jit.log_newbloc - } - - # Command line methods - def print_warning(self, s): - print(self.color_r + s + self.color_e) - - def onecmd(self, line): - cmd_translate = { - "h": "help", - "q": "exit", - "e": "exit", - "!": "exec", - "r": "run", - "i": "info", - "b": "breakpoint", - "s": "step", - "d": "dump" - } - - if len(line) >= 2 and \ - line[1] == " " and \ - line[:1] in cmd_translate: - line = cmd_translate[line[:1]] + line[1:] - - if len(line) == 1 and line in cmd_translate: - line = cmd_translate[line] - - r = super(DebugCmd, self).onecmd(line) - return r - - def can_exit(self): - return True - - def do_display(self, arg): - if arg == "": - self.help_display() - return - - args = arg.split(" ") - if args[-1].lower() not in ["on", "off"]: - self.print_warning("/!\ %s not in 'on' / 'off'" % args[-1]) - return - mode = args[-1].lower() == "on" - d = {} - for a in args[:-1]: - d[a] = mode - self.dbg.active_trace(**d) - self.update_display_mode() - - def help_display(self): - print("Enable/Disable tracing.") - print("Usage: display <mode1> <mode2> ... on|off") - print("Available modes are:") - for k in self.display_mode: - print("\t%s" % k) - print("Use 'info display' to get current values") - - def do_watchmem(self, arg): - if arg == "": - self.help_watchmem() - return - - args = arg.split(" ") - if len(args) >= 2: - size = int(args[1], 0) - else: - size = 0xF - - addr = int(args[0], 0) - - self.dbg.watch_mem(addr, size) - - def help_watchmem(self): - print("Add a memory watcher.") - print("Usage: watchmem <addr> [size]") - print("Use 'info watchmem' to get current memory watchers") - - def do_info(self, arg): - av_info = [ - "registers", - "display", - "breakpoints", - "watchmem" - ] - - if arg == "": - print("'info' must be followed by the name of an info command.") - print("List of info subcommands:") - for k in av_info: - print("\t%s" % k) - - if arg.startswith("b"): - # Breakpoint - self.print_breakpoints() - - if arg.startswith("d"): - # Display - self.update_display_mode() - for k, v in viewitems(self.display_mode): - print("%s\t\t%s" % (k, v)) - - if arg.startswith("w"): - # Watchmem - self.print_watchmems() - - if arg.startswith("r"): - # Registers - self.print_registers() - - def help_info(self): - print("Generic command for showing things about the program being") - print("debugged. 
Use 'info' without arguments to get the list of") - print("available subcommands.") - - def do_breakpoint(self, arg): - if arg == "": - self.help_breakpoint() - else: - addrs = arg.split(" ") - self.add_breakpoints(addrs) - - def help_breakpoint(self): - print("Add breakpoints to argument addresses.") - print("Example:") - print("\tbreakpoint 0x11223344") - print("\tbreakpoint 1122 0xabcd") - - def do_step(self, arg): - if arg == "": - nb = 1 - else: - nb = int(arg) - for _ in range(nb): - self.dbg.step() - - def help_step(self): - print("Step program until it reaches a different source line.") - print("Argument N means do this N times (or till program stops") - print("for another reason).") - - def do_dump(self, arg): - if arg == "": - self.help_dump() - else: - args = arg.split(" ") - if len(args) >= 2: - size = int(args[1], 0) - else: - size = 0xF - addr = int(args[0], 0) - - self.dbg.get_mem(addr, size) - - def help_dump(self): - print("Dump <addr> [size]. Dump size bytes at addr.") - - def do_run(self, _): - self.dbg.run() - - def help_run(self): - print("Launch or continue the current program") - - def do_exit(self, _): - return True - - def do_exec(self, line): - try: - print(eval(line)) - except Exception as error: - print("*** Error: %s" % error) - - def help_exec(self): - print("Exec a python command.") - print("You can also use '!' shortcut.") - - def help_exit(self): - print("Exit the interpreter.") - print("You can also use the Ctrl-D shortcut.") - - def help_help(self): - print("Print help") - - def postloop(self): - print('\nGoodbye !') - super(DebugCmd, self).postloop() - - do_EOF = do_exit - help_EOF = help_exit diff --git a/miasm2/analysis/depgraph.py b/miasm2/analysis/depgraph.py deleted file mode 100644 index 4bfae67f..00000000 --- a/miasm2/analysis/depgraph.py +++ /dev/null @@ -1,651 +0,0 @@ -"""Provide dependency graph""" - -from functools import total_ordering - -from future.utils import viewitems - -from miasm2.expression.expression import ExprInt, ExprLoc, ExprAssign -from miasm2.core.graph import DiGraph -from miasm2.core.locationdb import LocationDB -from miasm2.expression.simplifications import expr_simp_explicit -from miasm2.ir.symbexec import SymbolicExecutionEngine -from miasm2.ir.ir import IRBlock, AssignBlock -from miasm2.ir.translators import Translator -from miasm2.expression.expression_helper import possible_values - -try: - import z3 -except ImportError: - pass - -@total_ordering -class DependencyNode(object): - - """Node elements of a DependencyGraph - - A dependency node stands for the dependency on the @element at line number - @line_nb in the IRblock named @loc_key, *before* the evaluation of this - line. 
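Before depgraph.py, a sketch of how the Debugguer/DebugCmd pair from debugging.py above is typically wired to a jitter; the 'jitter' instance and both addresses are assumptions:

    from miasm2.analysis.debugging import Debugguer, DebugCmd

    dbg = Debugguer(jitter)              # jitter built elsewhere
    dbg.add_breakpoint(0x401000)         # soft breakpoint (hypothetical addr)
    dbg.add_memory_breakpoint(0x500000, 0x10, read=True, write=True)
    DebugCmd(dbg).cmdloop()              # 'b', 's', 'r', 'i', 'd' shell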
- """ - - __slots__ = ["_loc_key", "_element", "_line_nb", "_hash"] - - def __init__(self, loc_key, element, line_nb): - """Create a dependency node with: - @loc_key: LocKey instance - @element: Expr instance - @line_nb: int - """ - self._loc_key = loc_key - self._element = element - self._line_nb = line_nb - self._hash = hash( - (self._loc_key, self._element, self._line_nb)) - - def __hash__(self): - """Returns a hash of @self to uniquely identify @self""" - return self._hash - - def __eq__(self, depnode): - """Returns True if @self and @depnode are equals.""" - if not isinstance(depnode, self.__class__): - return False - return (self.loc_key == depnode.loc_key and - self.element == depnode.element and - self.line_nb == depnode.line_nb) - - def __ne__(self, depnode): - # required Python 2.7.14 - return not self == depnode - - def __lt__(self, node): - """Compares @self with @node.""" - if not isinstance(node, self.__class__): - return NotImplemented - - return ((self.loc_key, self.element, self.line_nb) < - (node.loc_key, node.element, node.line_nb)) - - def __str__(self): - """Returns a string representation of DependencyNode""" - return "<%s %s %s %s>" % (self.__class__.__name__, - self.loc_key, self.element, - self.line_nb) - - def __repr__(self): - """Returns a string representation of DependencyNode""" - return self.__str__() - - @property - def loc_key(self): - "Name of the current IRBlock" - return self._loc_key - - @property - def element(self): - "Current tracked Expr" - return self._element - - @property - def line_nb(self): - "Line in the current IRBlock" - return self._line_nb - - -class DependencyState(object): - - """ - Store intermediate depnodes states during dependencygraph analysis - """ - - def __init__(self, loc_key, pending, line_nb=None): - self.loc_key = loc_key - self.history = [loc_key] - self.pending = {k: set(v) for k, v in viewitems(pending)} - self.line_nb = line_nb - self.links = set() - - # Init lazy elements - self._graph = None - - def __repr__(self): - return "<State: %r (%r) (%r)>" % ( - self.loc_key, - self.pending, - self.links - ) - - def extend(self, loc_key): - """Return a copy of itself, with itself in history - @loc_key: LocKey instance for the new DependencyState's loc_key - """ - new_state = self.__class__(loc_key, self.pending) - new_state.links = set(self.links) - new_state.history = self.history + [loc_key] - return new_state - - def get_done_state(self): - """Returns immutable object representing current state""" - return (self.loc_key, frozenset(self.links)) - - def as_graph(self): - """Generates a Digraph of dependencies""" - graph = DiGraph() - for node_a, node_b in self.links: - if not node_b: - graph.add_node(node_a) - else: - graph.add_edge(node_a, node_b) - for parent, sons in viewitems(self.pending): - for son in sons: - graph.add_edge(parent, son) - return graph - - @property - def graph(self): - """Returns a DiGraph instance representing the DependencyGraph""" - if self._graph is None: - self._graph = self.as_graph() - return self._graph - - def remove_pendings(self, nodes): - """Remove resolved @nodes""" - for node in nodes: - del self.pending[node] - - def add_pendings(self, future_pending): - """Add @future_pending to the state""" - for node, depnodes in viewitems(future_pending): - if node not in self.pending: - self.pending[node] = depnodes - else: - self.pending[node].update(depnodes) - - def link_element(self, element, line_nb): - """Link element to its dependencies - @element: the element to link - @line_nb: the element's 
line - """ - - depnode = DependencyNode(self.loc_key, element, line_nb) - if not self.pending[element]: - # Create start node - self.links.add((depnode, None)) - else: - # Link element to its known dependencies - for node_son in self.pending[element]: - self.links.add((depnode, node_son)) - - def link_dependencies(self, element, line_nb, dependencies, - future_pending): - """Link unfollowed dependencies and create remaining pending elements. - @element: the element to link - @line_nb: the element's line - @dependencies: the element's dependencies - @future_pending: the future dependencies - """ - - depnode = DependencyNode(self.loc_key, element, line_nb) - - # Update pending, add link to unfollowed nodes - for dependency in dependencies: - if not dependency.follow: - # Add non followed dependencies to the dependency graph - parent = DependencyNode( - self.loc_key, dependency.element, line_nb) - self.links.add((parent, depnode)) - continue - # Create future pending between new dependency and the current - # element - future_pending.setdefault(dependency.element, set()).add(depnode) - - -class DependencyResult(DependencyState): - - """Container and methods for DependencyGraph results""" - - def __init__(self, ircfg, initial_state, state, inputs): - - super(DependencyResult, self).__init__(state.loc_key, state.pending) - self.initial_state = initial_state - self.history = state.history - self.pending = state.pending - self.line_nb = state.line_nb - self.inputs = inputs - self.links = state.links - self._ircfg = ircfg - - # Init lazy elements - self._has_loop = None - - @property - def unresolved(self): - """Set of nodes whose dependencies weren't found""" - return set(element for element in self.pending - if element != self._ircfg.IRDst) - - @property - def relevant_nodes(self): - """Set of nodes directly and indirectly influencing inputs""" - output = set() - for node_a, node_b in self.links: - output.add(node_a) - if node_b is not None: - output.add(node_b) - return output - - @property - def relevant_loc_keys(self): - """List of loc_keys containing nodes influencing inputs. - The history order is preserved.""" - # Get used loc_keys - used_loc_keys = set(depnode.loc_key for depnode in self.relevant_nodes) - - # Keep history order - output = [] - for loc_key in self.history: - if loc_key in used_loc_keys: - output.append(loc_key) - - return output - - @property - def has_loop(self): - """True iff there is at least one data dependencies cycle (regarding - the associated depgraph)""" - if self._has_loop is None: - self._has_loop = self.graph.has_loop() - return self._has_loop - - def irblock_slice(self, irb, max_line=None): - """Slice of the dependency nodes on the irblock @irb - @irb: irbloc instance - """ - - assignblks = [] - line2elements = {} - for depnode in self.relevant_nodes: - if depnode.loc_key != irb.loc_key: - continue - line2elements.setdefault(depnode.line_nb, - set()).add(depnode.element) - - for line_nb, elements in sorted(viewitems(line2elements)): - if max_line is not None and line_nb >= max_line: - break - assignmnts = {} - for element in elements: - if element in irb[line_nb]: - # constants, loc_key, ... 
are not in destination - assignmnts[element] = irb[line_nb][element] - assignblks.append(AssignBlock(assignmnts)) - - return IRBlock(irb.loc_key, assignblks) - - def emul(self, ir_arch, ctx=None, step=False): - """Symbolic execution of relevant nodes according to the history - Return the values of inputs nodes' elements - @ir_arch: IntermediateRepresentation instance - @ctx: (optional) Initial context as dictionary - @step: (optional) Verbose execution - Warning: The emulation is not sound if the inputs nodes depend on loop - variant. - """ - # Init - ctx_init = {} - if ctx is not None: - ctx_init.update(ctx) - assignblks = [] - - # Build a single assignment block according to history - last_index = len(self.relevant_loc_keys) - for index, loc_key in enumerate(reversed(self.relevant_loc_keys), 1): - if index == last_index and loc_key == self.initial_state.loc_key: - line_nb = self.initial_state.line_nb - else: - line_nb = None - assignblks += self.irblock_slice(self._ircfg.blocks[loc_key], - line_nb).assignblks - - # Eval the block - loc_db = LocationDB() - temp_loc = loc_db.get_or_create_name_location("Temp") - symb_exec = SymbolicExecutionEngine(ir_arch, ctx_init) - symb_exec.eval_updt_irblock(IRBlock(temp_loc, assignblks), step=step) - - # Return only inputs values (others could be wrongs) - return {element: symb_exec.symbols[element] - for element in self.inputs} - - -class DependencyResultImplicit(DependencyResult): - - """Stand for a result of a DependencyGraph with implicit option - - Provide path constraints using the z3 solver""" - # Z3 Solver instance - _solver = None - - unsat_expr = ExprAssign(ExprInt(0, 1), ExprInt(1, 1)) - - def _gen_path_constraints(self, translator, expr, expected): - """Generate path constraint from @expr. Handle special case with - generated loc_keys - """ - out = [] - expected = self._ircfg.loc_db.canonize_to_exprloc(expected) - expected_is_loc_key = expected.is_loc() - for consval in possible_values(expr): - value = self._ircfg.loc_db.canonize_to_exprloc(consval.value) - if expected_is_loc_key and value != expected: - continue - if not expected_is_loc_key and value.is_loc_key(): - continue - - conds = z3.And(*[translator.from_expr(cond.to_constraint()) - for cond in consval.constraints]) - if expected != value: - conds = z3.And( - conds, - translator.from_expr( - ExprAssign(value, - expected)) - ) - out.append(conds) - - if out: - conds = z3.Or(*out) - else: - # Ex: expr: lblgen1, expected: 0x1234 - # -> Avoid unconsistent solution lblgen1 = 0x1234 - conds = translator.from_expr(self.unsat_expr) - return conds - - def emul(self, ir_arch, ctx=None, step=False): - # Init - ctx_init = {} - if ctx is not None: - ctx_init.update(ctx) - solver = z3.Solver() - symb_exec = SymbolicExecutionEngine(ir_arch, ctx_init) - history = self.history[::-1] - history_size = len(history) - translator = Translator.to_language("z3") - size = self._ircfg.IRDst.size - - for hist_nb, loc_key in enumerate(history, 1): - if hist_nb == history_size and loc_key == self.initial_state.loc_key: - line_nb = self.initial_state.line_nb - else: - line_nb = None - irb = self.irblock_slice(self._ircfg.blocks[loc_key], line_nb) - - # Emul the block and get back destination - dst = symb_exec.eval_updt_irblock(irb, step=step) - - # Add constraint - if hist_nb < history_size: - next_loc_key = history[hist_nb] - expected = symb_exec.eval_expr(ExprLoc(next_loc_key, size)) - solver.add(self._gen_path_constraints(translator, dst, expected)) - # Save the solver - self._solver = solver - - # Return 
only inputs values (others could be wrongs) - return { - element: symb_exec.eval_expr(element) - for element in self.inputs - } - - @property - def is_satisfiable(self): - """Return True iff the solution path admits at least one solution - PRE: 'emul' - """ - return self._solver.check() == z3.sat - - @property - def constraints(self): - """If satisfiable, return a valid solution as a Z3 Model instance""" - if not self.is_satisfiable: - raise ValueError("Unsatisfiable") - return self._solver.model() - - -class FollowExpr(object): - - "Stand for an element (expression, depnode, ...) to follow or not" - __slots__ = ["follow", "element"] - - def __init__(self, follow, element): - self.follow = follow - self.element = element - - def __repr__(self): - return '%s(%r, %r)' % (self.__class__.__name__, self.follow, self.element) - - @staticmethod - def to_depnodes(follow_exprs, loc_key, line): - """Build a set of FollowExpr(DependencyNode) from the @follow_exprs set - of FollowExpr - @follow_exprs: set of FollowExpr - @loc_key: LocKey instance - @line: integer - """ - dependencies = set() - for follow_expr in follow_exprs: - dependencies.add(FollowExpr(follow_expr.follow, - DependencyNode(loc_key, - follow_expr.element, - line))) - return dependencies - - @staticmethod - def extract_depnodes(follow_exprs, only_follow=False): - """Extract depnodes from a set of FollowExpr(Depnodes) - @only_follow: (optional) extract only elements to follow""" - return set(follow_expr.element - for follow_expr in follow_exprs - if not(only_follow) or follow_expr.follow) - - -class DependencyGraph(object): - - """Implementation of a dependency graph - - A dependency graph contains DependencyNode as nodes. The oriented edges - stand for a dependency. - The dependency graph is made of the lines of a group of IRblock - *explicitly* or *implicitly* involved in the equation of given element. - """ - - def __init__(self, ircfg, - implicit=False, apply_simp=True, follow_mem=True, - follow_call=True): - """Create a DependencyGraph linked to @ircfg - - @ircfg: IRCFG instance - @implicit: (optional) Track IRDst for each block in the resulting path - - Following arguments define filters used to generate dependencies - @apply_simp: (optional) Apply expr_simp_explicit - @follow_mem: (optional) Track memory syntactically - @follow_call: (optional) Track through "call" - """ - # Init - self._ircfg = ircfg - self._implicit = implicit - - # Create callback filters. The order is relevant. - self._cb_follow = [] - if apply_simp: - self._cb_follow.append(self._follow_simp_expr) - self._cb_follow.append(lambda exprs: self._follow_exprs(exprs, - follow_mem, - follow_call)) - self._cb_follow.append(self._follow_no_loc_key) - - @staticmethod - def _follow_simp_expr(exprs): - """Simplify expression so avoid tracking useless elements, - as: XOR EAX, EAX - """ - follow = set() - for expr in exprs: - follow.add(expr_simp_explicit(expr)) - return follow, set() - - @staticmethod - def get_expr(expr, follow, nofollow): - """Update @follow/@nofollow according to insteresting nodes - Returns same expression (non modifier visitor). - - @expr: expression to handle - @follow: set of nodes to follow - @nofollow: set of nodes not to follow - """ - if expr.is_id(): - follow.add(expr) - elif expr.is_int(): - nofollow.add(expr) - elif expr.is_mem(): - follow.add(expr) - return expr - - @staticmethod - def follow_expr(expr, _, nofollow, follow_mem=False, follow_call=False): - """Returns True if we must visit sub expressions. 
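A hedged usage sketch for the implicit mode above; 'ircfg', 'ir_arch', 'loc_key', 'line_nb', 'heads' and the tracked register EAX are assumed to come from the surrounding analysis:

    dg = DependencyGraph(ircfg, implicit=True)
    for result in dg.get(loc_key, {EAX}, line_nb, heads):
        values = result.emul(ir_arch)   # also fills the z3 solver
        if result.is_satisfiable:
            print(values)               # inputs' symbolic values
            print(result.constraints)   # a z3 model for this path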
- @expr: expression to browse - @follow: set of nodes to follow - @nofollow: set of nodes not to follow - @follow_mem: force the visit of memory sub expressions - @follow_call: force the visit of call sub expressions - """ - if not follow_mem and expr.is_mem(): - nofollow.add(expr) - return False - if not follow_call and expr.is_function_call(): - nofollow.add(expr) - return False - return True - - @classmethod - def _follow_exprs(cls, exprs, follow_mem=False, follow_call=False): - """Extracts subnodes from exprs and returns followed/non followed - expressions according to @follow_mem/@follow_call - - """ - follow, nofollow = set(), set() - for expr in exprs: - expr.visit(lambda x: cls.get_expr(x, follow, nofollow), - lambda x: cls.follow_expr(x, follow, nofollow, - follow_mem, follow_call)) - return follow, nofollow - - @staticmethod - def _follow_no_loc_key(exprs): - """Do not follow loc_keys""" - follow = set() - for expr in exprs: - if expr.is_int() or expr.is_loc(): - continue - follow.add(expr) - - return follow, set() - - def _follow_apply_cb(self, expr): - """Apply callback functions to @expr - @expr : FollowExpr instance""" - follow = set([expr]) - nofollow = set() - - for callback in self._cb_follow: - follow, nofollow_tmp = callback(follow) - nofollow.update(nofollow_tmp) - - out = set(FollowExpr(True, expr) for expr in follow) - out.update(set(FollowExpr(False, expr) for expr in nofollow)) - return out - - def _track_exprs(self, state, assignblk, line_nb): - """Track pending expression in an assignblock""" - future_pending = {} - node_resolved = set() - for dst, src in viewitems(assignblk): - # Only track pending - if dst not in state.pending: - continue - # Track IRDst in implicit mode only - if dst == self._ircfg.IRDst and not self._implicit: - continue - assert dst not in node_resolved - node_resolved.add(dst) - dependencies = self._follow_apply_cb(src) - - state.link_element(dst, line_nb) - state.link_dependencies(dst, line_nb, - dependencies, future_pending) - - # Update pending nodes - state.remove_pendings(node_resolved) - state.add_pendings(future_pending) - - def _compute_intrablock(self, state): - """Follow dependencies tracked in @state in the current irbloc - @state: instance of DependencyState""" - - irb = self._ircfg.blocks[state.loc_key] - line_nb = len(irb) if state.line_nb is None else state.line_nb - - for cur_line_nb, assignblk in reversed(list(enumerate(irb[:line_nb]))): - self._track_exprs(state, assignblk, cur_line_nb) - - def get(self, loc_key, elements, line_nb, heads): - """Compute the dependencies of @elements at line number @line_nb in - the block named @loc_key in the current IRCFG, before the execution of - this line. 
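The three callbacks registered in __init__ act as a pipeline: simplify, split sub-expressions into followed/ignored sets, then drop integers and loc_keys. A rough stand-alone restatement of the resulting classification (hypothetical helper, not the miasm2 API):

    def classify(expr, follow_mem=True, follow_call=True):
        if expr.is_int() or expr.is_loc():
            return "nofollow"   # constants end a dependency chain
        if expr.is_mem():
            return "follow" if follow_mem else "nofollow"
        if expr.is_function_call():
            return "recurse" if follow_call else "nofollow"
        if expr.is_id():
            return "follow"     # identifiers feed the dependency graph
        return "recurse"        # other operators: visit their arguments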
Dependency check stop if one of @heads is reached - @loc_key: LocKey instance - @element: set of Expr instances - @line_nb: int - @heads: set of LocKey instances - Return an iterator on DiGraph(DependencyNode) - """ - # Init the algorithm - inputs = {element: set() for element in elements} - initial_state = DependencyState(loc_key, inputs, line_nb) - todo = set([initial_state]) - done = set() - dpResultcls = DependencyResultImplicit if self._implicit else DependencyResult - - while todo: - state = todo.pop() - self._compute_intrablock(state) - done_state = state.get_done_state() - if done_state in done: - continue - done.add(done_state) - if (not state.pending or - state.loc_key in heads or - not self._ircfg.predecessors(state.loc_key)): - yield dpResultcls(self._ircfg, initial_state, state, elements) - if not state.pending: - continue - - if self._implicit: - # Force IRDst to be tracked, except in the input block - state.pending[self._ircfg.IRDst] = set() - - # Propagate state to parents - for pred in self._ircfg.predecessors_iter(state.loc_key): - todo.add(state.extend(pred)) - - def get_from_depnodes(self, depnodes, heads): - """Alias for the get() method. Use the attributes of @depnodes as - argument. - PRE: Loc_Keys and lines of depnodes have to be equals - @depnodes: set of DependencyNode instances - @heads: set of LocKey instances - """ - lead = list(depnodes)[0] - elements = set(depnode.element for depnode in depnodes) - return self.get(lead.loc_key, elements, lead.line_nb, heads) diff --git a/miasm2/analysis/disasm_cb.py b/miasm2/analysis/disasm_cb.py deleted file mode 100644 index 36e120b6..00000000 --- a/miasm2/analysis/disasm_cb.py +++ /dev/null @@ -1,128 +0,0 @@ -#-*- coding:utf-8 -*- - -from __future__ import print_function - -from future.utils import viewvalues - -from miasm2.expression.expression import ExprInt, ExprId, ExprMem, match_expr -from miasm2.expression.simplifications import expr_simp -from miasm2.core.asmblock import AsmConstraintNext, AsmConstraintTo -from miasm2.core.locationdb import LocationDB -from miasm2.core.utils import upck32 - - -def get_ira(mnemo, attrib): - arch = mnemo.name, attrib - if arch == ("arm", "arm"): - from miasm2.arch.arm.ira import ir_a_arm_base as ira - elif arch == ("x86", 32): - from miasm2.arch.x86.ira import ir_a_x86_32 as ira - elif arch == ("x86", 64): - from miasm2.arch.x86.ira import ir_a_x86_64 as ira - else: - raise ValueError('unknown architecture: %s' % mnemo.name) - return ira - - -def arm_guess_subcall( - mnemo, attrib, pool_bin, cur_bloc, offsets_to_dis, loc_db): - ira = get_ira(mnemo, attrib) - - sp = LocationDB() - ir_arch = ira(sp) - ircfg = ira.new_ircfg() - print('###') - print(cur_bloc) - ir_arch.add_asmblock_to_ircfg(cur_bloc, ircfg) - - to_add = set() - for irblock in viewvalues(ircfg.blocks): - pc_val = None - lr_val = None - for exprs in irblock: - for e in exprs: - if e.dst == ir_arch.pc: - pc_val = e.src - if e.dst == mnemo.regs.LR: - lr_val = e.src - if pc_val is None or lr_val is None: - continue - if not isinstance(lr_val, ExprInt): - continue - - l = cur_bloc.lines[-1] - if lr_val.arg != l.offset + l.l: - continue - l = loc_db.get_or_create_offset_location(int(lr_val)) - c = AsmConstraintNext(l) - - to_add.add(c) - offsets_to_dis.add(int(lr_val)) - - for c in to_add: - cur_bloc.addto(c) - - -def arm_guess_jump_table( - mnemo, attrib, pool_bin, cur_bloc, offsets_to_dis, loc_db): - ira = get_ira(mnemo, attrib) - - jra = ExprId('jra') - jrb = ExprId('jrb') - - sp = LocationDB() - ir_arch = ira(sp) - ircfg = 
ira.new_ircfg()
-    ir_arch.add_asmblock_to_ircfg(cur_bloc, ircfg)
-
-    for irblock in viewvalues(ircfg.blocks):
-        pc_val = None
-        for exprs in irblock:
-            for e in exprs:
-                if e.dst == ir_arch.pc:
-                    pc_val = e.src
-        if pc_val is None:
-            continue
-        if not isinstance(pc_val, ExprMem):
-            continue
-        assert pc_val.size == 32
-        print(pc_val)
-        ad = pc_val.arg
-        ad = expr_simp(ad)
-        print(ad)
-        res = match_expr(ad, jra + jrb, set([jra, jrb]))
-        if res is False:
-            raise NotImplementedError('not fully functional')
-        print(res)
-        if not isinstance(res[jrb], ExprInt):
-            raise NotImplementedError('not fully functional')
-        base_ad = int(res[jrb])
-        print(base_ad)
-        addrs = set()
-        i = -1
-        max_table_entry = 10000
-        max_diff_addr = 0x100000  # heuristic
-        while i < max_table_entry:
-            i += 1
-            try:
-                ad = upck32(pool_bin.getbytes(base_ad + 4 * i, 4))
-            except Exception:
-                break
-            if abs(ad - base_ad) > max_diff_addr:
-                break
-            addrs.add(ad)
-        print([hex(x) for x in addrs])
-
-        for ad in addrs:
-            offsets_to_dis.add(ad)
-            l = loc_db.get_or_create_offset_location(ad)
-            c = AsmConstraintTo(l)
-            cur_bloc.addto(c)
-
-guess_funcs = []
-
-
-def guess_multi_cb(
-        mnemo, attrib, pool_bin, cur_bloc, offsets_to_dis, loc_db):
-    for f in guess_funcs:
-        f(mnemo, attrib, pool_bin, cur_bloc, offsets_to_dis, loc_db)
diff --git a/miasm2/analysis/dse.py b/miasm2/analysis/dse.py
deleted file mode 100644
index fee85984..00000000
--- a/miasm2/analysis/dse.py
+++ /dev/null
@@ -1,708 +0,0 @@
-"""Dynamic symbolic execution module.
-
-Offers a way to run a symbolic execution alongside a concrete one.
-Basically, this is done through the DSEEngine class, with the scheme:
-
-dse = DSEEngine(Machine("x86_32"))
-dse.attach(jitter)
-
-The DSE state can be updated through:
-
- - .update_state_from_concrete: update the values from the CPU, so the symbolic
-   execution will be completely concrete from this point (until changes)
- - .update_state: inject information, for instance RAX = symbolic_RAX
- - .symbolize_memory: symbolize (using .memory_to_expr) memory areas (i.e.,
-   reading from an address in one of these areas yields a symbol)
-
-The DSE run can be instrumented through:
- - .add_handler: register a handler, modifying the state instead of the current
-   execution. Can be used for stubbing external APIs
- - .add_lib_handler: register handlers for libraries
- - .add_instrumentation: register a handler, modifying the state but continuing
-   the current execution. Can be used for logging facilities
-
-On branches, if the decision is symbolic, one can also collect "path
-constraints" and invert them to produce new inputs potentially reaching new
-paths.
-
-Basically, this is done through DSEPathConstraint. In order to produce a new
-solution, one can extend this class and override 'handle_solution' to produce a
-solution which fits its needs. It can avoid computing new solutions by
-overriding 'produce_solution'.
-
-If one is only interested in the constraints associated with its path, the
-option "produce_solution" should be set to False, to speed up emulation.
-The constraints are accumulated in the .z3_cur z3.Solver object.
-
-Here are a few remaining TODOs:
- - handle endianness in check_state / atomic read: currently, as in the other
-   Miasm2 symbolic engines, the endianness is not taken into account, and is
-   assumed to be Little Endian
-
- - too many memory dependencies in constraint tracking: in order to let z3 find
-   new solutions, it needs information on memory values (for instance, a
-   lookup in a table with a symbolic index).
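The heuristics from disasm_cb.py above are meant to be appended to guess_funcs and bound to a disassembler callback; a registration sketch (the dis_block_callback attribute name on the disassembly engine is an assumption here):

    from miasm2.analysis.disasm_cb import (
        guess_funcs, guess_multi_cb, arm_guess_subcall, arm_guess_jump_table)

    guess_funcs.append(arm_guess_subcall)
    guess_funcs.append(arm_guess_jump_table)
    mdis.dis_block_callback = guess_multi_cb   # 'mdis' assumed built for ARM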
The estimated possible involved - memory location could be too large to pass to the solver (threshold named - MAX_MEMORY_INJECT). One possible solution, not yet implemented, is to call - the solver for reducing the possible values thanks to its accumulated - constraints. -""" -from builtins import range -from collections import namedtuple - -try: - import z3 -except ImportError: - z3 = None - -from future.utils import viewitems - -from miasm2.core.utils import encode_hex, force_bytes -from miasm2.expression.expression import ExprMem, ExprInt, ExprCompose, \ - ExprAssign, ExprId, ExprLoc, LocKey -from miasm2.core.bin_stream import bin_stream_vm -from miasm2.jitter.emulatedsymbexec import EmulatedSymbExec -from miasm2.expression.expression_helper import possible_values -from miasm2.ir.translators import Translator -from miasm2.analysis.expression_range import expr_range -from miasm2.analysis.modularintervals import ModularIntervals -from miasm2.core.locationdb import LocationDB - -DriftInfo = namedtuple("DriftInfo", ["symbol", "computed", "expected"]) - -class DriftException(Exception): - """Raised when the emulation drift from the reference engine""" - - def __init__(self, info): - super(DriftException, self).__init__() - self.info = info - - def __str__(self): - if len(self.info) == 1: - return "Drift of %s: %s instead of %s" % ( - self.info[0].symbol, - self.info[0].computed, - self.info[0].expected, - ) - else: - return "Drift of:\n\t" + "\n\t".join("%s: %s instead of %s" % ( - dinfo.symbol, - dinfo.computed, - dinfo.expected) - for dinfo in self.info) - - -class ESETrackModif(EmulatedSymbExec): - """Extension of EmulatedSymbExec to be used by DSE engines - - Add the tracking of modified expressions, and the ability to symbolize - memory areas - """ - - def __init__(self, *args, **kwargs): - super(ESETrackModif, self).__init__(*args, **kwargs) - self.modified_expr = set() # Expr modified since the last reset - self.dse_memory_range = [] # List/Intervals of memory addresses to - # symbolize - self.dse_memory_to_expr = None # function(addr) -> Expr used to - # symbolize - - def mem_read(self, expr_mem): - if not expr_mem.ptr.is_int(): - return expr_mem - dst_addr = int(expr_mem.ptr) - - # Split access in atomic accesses - out = [] - for addr in range(dst_addr, dst_addr + expr_mem.size // 8): - if addr in self.dse_memory_range: - # Symbolize memory access - out.append(self.dse_memory_to_expr(addr)) - continue - atomic_access = ExprMem(ExprInt(addr, expr_mem.ptr.size), 8) - if atomic_access in self.symbols: - out.append( super(EmulatedSymbExec, self).mem_read(atomic_access)) - else: - # Get concrete value - atomic_access = ExprMem(ExprInt(addr, expr_mem.ptr.size), 8) - out.append(super(ESETrackModif, self).mem_read(atomic_access)) - - if len(out) == 1: - # Trivial case (optimization) - return out[0] - - # Simplify for constant merging (ex: {ExprInt(1, 8), ExprInt(2, 8)}) - return self.expr_simp(ExprCompose(*out)) - - def mem_write(self, expr, data): - # Call Symbolic mem_write (avoid side effects on vm) - return super(EmulatedSymbExec, self).mem_write(expr, data) - - def reset_modified(self): - """Reset modified expression tracker""" - self.modified_expr.clear() - - def apply_change(self, dst, src): - super(ESETrackModif, self).apply_change(dst, src) - self.modified_expr.add(dst) - - -class ESENoVMSideEffects(EmulatedSymbExec): - """ - Do EmulatedSymbExec without modifying memory - """ - def mem_write(self, expr, data): - return super(EmulatedSymbExec, self).mem_write(expr, data) - - -class 
DSEEngine(object): - """Dynamic Symbolic Execution Engine - - This class aims to be overridden for each specific purpose - """ - SYMB_ENGINE = ESETrackModif - - def __init__(self, machine): - self.machine = machine - self.loc_db = LocationDB() - self.handler = {} # addr -> callback(DSEEngine instance) - self.instrumentation = {} # addr -> callback(DSEEngine instance) - self.addr_to_cacheblocks = {} # addr -> {label -> IRBlock} - self.ir_arch = self.machine.ir(loc_db=self.loc_db) # corresponding IR - self.ircfg = self.ir_arch.new_ircfg() # corresponding IR - - # Defined after attachment - self.jitter = None # Jitload (concrete execution) - self.symb = None # SymbolicExecutionEngine - self.symb_concrete = None # Concrete SymbExec for path desambiguisation - self.mdis = None # DisasmEngine - - def prepare(self): - """Prepare the environment for attachment with a jitter""" - # Disassembler - self.mdis = self.machine.dis_engine(bin_stream_vm(self.jitter.vm), - lines_wd=1, - loc_db=self.loc_db) - - # Symbexec engine - ## Prepare symbexec engines - self.symb = self.SYMB_ENGINE(self.jitter.cpu, self.jitter.vm, - self.ir_arch, {}) - self.symb.enable_emulated_simplifications() - self.symb_concrete = ESENoVMSideEffects( - self.jitter.cpu, self.jitter.vm, - self.ir_arch, {} - ) - - ## Update registers value - self.symb.symbols[self.ir_arch.IRDst] = ExprInt( - getattr(self.jitter.cpu, self.ir_arch.pc.name), - self.ir_arch.IRDst.size - ) - - # Activate callback on each instr - self.jitter.jit.set_options(max_exec_per_call=1, jit_maxline=1) - self.jitter.exec_cb = self.callback - - # Clean jit cache to avoid multi-line basic blocks already jitted - self.jitter.jit.clear_jitted_blocks() - - def attach(self, emulator): - """Attach the DSE to @emulator - @emulator: jitload (or API equivalent) instance - - To attach *DURING A BREAKPOINT*, one may consider using the following snippet: - - def breakpoint(self, jitter): - ... - dse.attach(jitter) - dse.update... - ... - # Additional call to the exec callback is necessary, as breakpoints are - # honored AFTER exec callback - jitter.exec_cb(jitter) - - return True - - Without it, one may encounteer a DriftException error due to a - "desynchronization" between jitter and dse states. Indeed, on 'handle' - call, the jitter must be one instruction AFTER the dse. - """ - self.jitter = emulator - self.prepare() - - def handle(self, cur_addr): - r"""Handle destination - @cur_addr: Expr of the next address in concrete execution - /!\ cur_addr may be a loc_key - - In this method, self.symb is in the "just before branching" state - """ - pass - - def add_handler(self, addr, callback): - """Add a @callback for address @addr before any state update. 
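A minimal attach sequence for the engine above (a sketch: binary loading, memory mapping and the jit backend choice are assumptions):

    from miasm2.analysis.machine import Machine
    from miasm2.analysis.dse import DSEEngine

    machine = Machine("x86_32")
    jitter = machine.jitter("python")   # backend choice is an assumption
    dse = DSEEngine(machine)
    dse.attach(jitter)                  # wires the per-instruction callback
    dse.update_state_from_concrete()    # start fully concrete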
- The state IS NOT updated after returning from the callback - @addr: int - @callback: func(dse instance)""" - self.handler[addr] = callback - - def add_lib_handler(self, libimp, namespace): - """Add search for handler based on a @libimp libimp instance - - Known functions will be looked by {name}_symb in the @namespace - """ - namespace = dict( - (force_bytes(name), func) for name, func in viewitems(namespace) - ) - - # lambda cannot contain statement - def default_func(dse): - fname = b"%s_symb" % libimp.fad2cname[dse.jitter.pc] - raise RuntimeError("Symbolic stub '%s' not found" % fname) - - for addr, fname in viewitems(libimp.fad2cname): - fname = force_bytes(fname) - fname = b"%s_symb" % fname - func = namespace.get(fname, None) - if func is not None: - self.add_handler(addr, func) - else: - self.add_handler(addr, default_func) - - def add_instrumentation(self, addr, callback): - """Add a @callback for address @addr before any state update. - The state IS updated after returning from the callback - @addr: int - @callback: func(dse instance)""" - self.instrumentation[addr] = callback - - def _check_state(self): - """Check the current state against the concrete one""" - errors = [] # List of DriftInfo - - for symbol in self.symb.modified_expr: - # Do not consider PC - if symbol in [self.ir_arch.pc, self.ir_arch.IRDst]: - continue - - # Consider only concrete values - symb_value = self.eval_expr(symbol) - if not symb_value.is_int(): - continue - symb_value = int(symb_value) - - # Check computed values against real ones - if symbol.is_id(): - if hasattr(self.jitter.cpu, symbol.name): - value = getattr(self.jitter.cpu, symbol.name) - if value != symb_value: - errors.append(DriftInfo(symbol, symb_value, value)) - elif symbol.is_mem() and symbol.ptr.is_int(): - value_chr = self.jitter.vm.get_mem( - int(symbol.ptr), - symbol.size // 8 - ) - exp_value = int(encode_hex(value_chr[::-1]), 16) - if exp_value != symb_value: - errors.append(DriftInfo(symbol, symb_value, exp_value)) - - # Check for drift, and act accordingly - if errors: - raise DriftException(errors) - - def callback(self, _): - """Called before each instruction""" - # Assert synchronization with concrete execution - self._check_state() - - # Call callbacks associated to the current address - cur_addr = self.jitter.pc - if isinstance(cur_addr, LocKey): - lbl = self.ir_arch.loc_db.loc_key_to_label(cur_addr) - cur_addr = lbl.offset - - if cur_addr in self.handler: - self.handler[cur_addr](self) - return True - - if cur_addr in self.instrumentation: - self.instrumentation[cur_addr](self) - - # Handle current address - self.handle(ExprInt(cur_addr, self.ir_arch.IRDst.size)) - - # Avoid memory issue in ExpressionSimplifier - if len(self.symb.expr_simp.simplified_exprs) > 100000: - self.symb.expr_simp.simplified_exprs.clear() - - # Get IR blocks - if cur_addr in self.addr_to_cacheblocks: - self.ircfg.blocks.clear() - self.ircfg.blocks.update(self.addr_to_cacheblocks[cur_addr]) - else: - - ## Reset cache structures - self.ircfg.blocks.clear()# = {} - - ## Update current state - asm_block = self.mdis.dis_block(cur_addr) - self.ir_arch.add_asmblock_to_ircfg(asm_block, self.ircfg) - self.addr_to_cacheblocks[cur_addr] = dict(self.ircfg.blocks) - - # Emulate the current instruction - self.symb.reset_modified() - - # Is the symbolic execution going (potentially) to jump on a lbl_gen? 
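A stub sketch built on add_handler above; the address, the register access and the return-value convention are assumptions, and the return-to-caller step is left out:

    from miasm2.expression.expression import ExprId

    def strlen_symb(dse):
        # Hypothetical stub: make the call return a fresh 32-bit symbol
        ret = ExprId("strlen_ret", 32)
        dse.update_state({dse.ir_arch.arch.regs.EAX: ret})
        # emulating the 'ret' back to the caller is elided here

    dse.add_handler(0x401234, strlen_symb)   # hypothetical address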
- if len(self.ircfg.blocks) == 1: - self.symb.run_at(self.ircfg, cur_addr) - else: - # Emulation could stuck in generated IR blocks - # But concrete execution callback is not enough precise to obtain - # the full IR blocks path - # -> Use a fully concrete execution to get back path - - # Update the concrete execution - self._update_state_from_concrete_symb( - self.symb_concrete, cpu=True, mem=True - ) - while True: - - next_addr_concrete = self.symb_concrete.run_block_at( - self.ircfg, cur_addr - ) - self.symb.run_block_at(self.ircfg, cur_addr) - - if not (isinstance(next_addr_concrete, ExprLoc) and - self.ir_arch.loc_db.get_location_offset( - next_addr_concrete.loc_key - ) is None): - # Not a lbl_gen, exit - break - - # Call handle with lbl_gen state - self.handle(next_addr_concrete) - cur_addr = next_addr_concrete - - - # At this stage, symbolic engine is one instruction after the concrete - # engine - - return True - - def _get_gpregs(self): - """Return a dict of regs: value from the jitter - This version use the regs associated to the attrib (!= cpu.get_gpreg()) - """ - out = {} - regs = self.ir_arch.arch.regs.attrib_to_regs[self.ir_arch.attrib] - for reg in regs: - if hasattr(self.jitter.cpu, reg.name): - out[reg.name] = getattr(self.jitter.cpu, reg.name) - return out - - def take_snapshot(self): - """Return a snapshot of the current state (including jitter state)""" - snapshot = { - "mem": self.jitter.vm.get_all_memory(), - "regs": self._get_gpregs(), - "symb": self.symb.symbols.copy(), - } - return snapshot - - def restore_snapshot(self, snapshot, memory=True): - """Restore a @snapshot taken with .take_snapshot - @snapshot: .take_snapshot output - @memory: (optional) if set, also restore the memory - """ - # Restore memory - if memory: - self.jitter.vm.reset_memory_page_pool() - self.jitter.vm.reset_code_bloc_pool() - for addr, metadata in viewitems(snapshot["mem"]): - self.jitter.vm.add_memory_page( - addr, - metadata["access"], - metadata["data"] - ) - - # Restore registers - self.jitter.pc = snapshot["regs"][self.ir_arch.pc.name] - for reg, value in viewitems(snapshot["regs"]): - setattr(self.jitter.cpu, reg, value) - - # Reset intern elements - self.jitter.vm.set_exception(0) - self.jitter.cpu.set_exception(0) - self.jitter.bs._atomic_mode = False - - # Reset symb exec - for key, _ in list(viewitems(self.symb.symbols)): - del self.symb.symbols[key] - for expr, value in viewitems(snapshot["symb"]): - self.symb.symbols[expr] = value - - def update_state(self, assignblk): - """From this point, assume @assignblk in the symbolic execution - @assignblk: AssignBlock/{dst -> src} - """ - for dst, src in viewitems(assignblk): - self.symb.apply_change(dst, src) - - def _update_state_from_concrete_symb(self, symbexec, cpu=True, mem=False): - if mem: - # Values will be retrieved from the concrete execution if they are - # not present - symbexec.symbols.symbols_mem.base_to_memarray.clear() - if cpu: - regs = self.ir_arch.arch.regs.attrib_to_regs[self.ir_arch.attrib] - for reg in regs: - if hasattr(self.jitter.cpu, reg.name): - value = ExprInt(getattr(self.jitter.cpu, reg.name), - size=reg.size) - symbexec.symbols[reg] = value - - def update_state_from_concrete(self, cpu=True, mem=False): - r"""Update the symbolic state with concrete values from the concrete - engine - - @cpu: (optional) if set, update registers' value - @mem: (optional) if set, update memory value - - /!\ all current states will be loss. 
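take_snapshot/restore_snapshot above enable a rewind pattern around one concrete exploration (sketch; 'dse' and 'jitter' as attached earlier):

    snap = dse.take_snapshot()
    try:
        jitter.continue_run()          # run one concrete path
    finally:
        dse.restore_snapshot(snap)     # rewind memory, registers and symbols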
- This function is usually called when states are no more synchronized - (at the beginning, returning from an unstubbed syscall, ...) - """ - self._update_state_from_concrete_symb(self.symb, cpu, mem) - - def eval_expr(self, expr): - """Return the evaluation of @expr: - @expr: Expr instance""" - return self.symb.eval_expr(expr) - - @staticmethod - def memory_to_expr(addr): - """Translate an address to its corresponding symbolic ID (8bits) - @addr: int""" - return ExprId("MEM_0x%x" % int(addr), 8) - - def symbolize_memory(self, memory_range): - """Register a range of memory addresses to symbolize - @memory_range: object with support of __in__ operation (intervals, list, - ...) - """ - self.symb.dse_memory_range = memory_range - self.symb.dse_memory_to_expr = self.memory_to_expr - - -class DSEPathConstraint(DSEEngine): - """Dynamic Symbolic Execution Engine keeping the path constraint - - Possible new "solutions" are produced along the path, by inversing concrete - path constraint. Thus, a "solution" is a potential initial context leading - to a new path. - - In order to produce a new solution, one can extend this class, and override - 'handle_solution' to produce a solution which fit its needs. It could avoid - computing new solution by overriding 'produce_solution'. - - If one is only interested in constraints associated to its path, the option - "produce_solution" should be set to False, to speed up emulation. - The constraints are accumulated in the .z3_cur z3.Solver object. - - """ - - # Maximum memory size to inject in constraints solving - MAX_MEMORY_INJECT = 0x10000 - - # Produce solution strategies - PRODUCE_NO_SOLUTION = 0 - PRODUCE_SOLUTION_CODE_COV = 1 - PRODUCE_SOLUTION_BRANCH_COV = 2 - PRODUCE_SOLUTION_PATH_COV = 3 - - def __init__(self, machine, produce_solution=PRODUCE_SOLUTION_CODE_COV, - known_solutions=None, - **kwargs): - """Init a DSEPathConstraint - @machine: Machine of the targeted architecture instance - @produce_solution: (optional) if set, new solutions will be computed""" - super(DSEPathConstraint, self).__init__(machine, **kwargs) - - # Dependency check - assert z3 is not None - - # Init PathConstraint specifics structures - self.cur_solver = z3.Solver() - self.new_solutions = {} # solution identifier -> solution's model - self._known_solutions = set() # set of solution identifiers - self.z3_trans = Translator.to_language("z3") - self._produce_solution_strategy = produce_solution - self._previous_addr = None - self._history = None - if produce_solution == self.PRODUCE_SOLUTION_PATH_COV: - self._history = [] # List of addresses in the current path - - def take_snapshot(self, *args, **kwargs): - snap = super(DSEPathConstraint, self).take_snapshot(*args, **kwargs) - snap["new_solutions"] = { - dst: src.copy - for dst, src in viewitems(self.new_solutions) - } - snap["cur_constraints"] = self.cur_solver.assertions() - if self._produce_solution_strategy == self.PRODUCE_SOLUTION_PATH_COV: - snap["_history"] = list(self._history) - elif self._produce_solution_strategy == self.PRODUCE_SOLUTION_BRANCH_COV: - snap["_previous_addr"] = self._previous_addr - return snap - - def restore_snapshot(self, snapshot, keep_known_solutions=True, **kwargs): - """Restore a DSEPathConstraint snapshot - @keep_known_solutions: if set, do not forget solutions already found. 
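symbolize_memory above is typically fed an interval covering an input buffer, so each byte read there becomes a free 8-bit symbol via memory_to_expr (the buffer address and size below are assumptions):

    from miasm2.core.interval import interval

    buf_addr, buf_size = 0x20000000, 16
    dse.symbolize_memory(interval([(buf_addr, buf_addr + buf_size - 1)]))
    # A read at buf_addr now evaluates to ExprId("MEM_0x20000000", 8)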
-        -> They will not appear in 'new_solutions'
-        """
-        super(DSEPathConstraint, self).restore_snapshot(snapshot, **kwargs)
-        self.new_solutions.clear()
-        self.new_solutions.update(snapshot["new_solutions"])
-        self.cur_solver = z3.Solver()
-        self.cur_solver.add(snapshot["cur_constraints"])
-        if not keep_known_solutions:
-            self._known_solutions.clear()
-        if self._produce_solution_strategy == self.PRODUCE_SOLUTION_PATH_COV:
-            self._history = list(snapshot["_history"])
-        elif self._produce_solution_strategy == self.PRODUCE_SOLUTION_BRANCH_COV:
-            self._previous_addr = snapshot["_previous_addr"]
-
-    def _key_for_solution_strategy(self, destination):
-        """Return the associated identifier for the current solution strategy"""
-        if self._produce_solution_strategy == self.PRODUCE_NO_SOLUTION:
-            # Never produce a solution
-            return None
-        elif self._produce_solution_strategy == self.PRODUCE_SOLUTION_CODE_COV:
-            # Decision based on code coverage
-            # -> produce a solution if the destination has never been seen
-            key = destination
-
-        elif self._produce_solution_strategy == self.PRODUCE_SOLUTION_BRANCH_COV:
-            # Decision based on branch coverage
-            # -> produce a solution if the current branch has never been taken
-            key = (self._previous_addr, destination)
-
-        elif self._produce_solution_strategy == self.PRODUCE_SOLUTION_PATH_COV:
-            # Decision based on path coverage
-            # -> produce a solution if the current path has never been taken
-            key = tuple(self._history + [destination])
-        else:
-            raise ValueError("Unknown produce solution strategy")
-
-        return key
-
-    def produce_solution(self, destination):
-        """Called to determine if a solution for @destination should be tested
-        for satisfiability and computed
-        @destination: Expr instance of the target @destination
-        """
-        key = self._key_for_solution_strategy(destination)
-        if key is None:
-            return False
-        return key not in self._known_solutions
-
-    def handle_solution(self, model, destination):
-        """Called when a new solution for destination @destination is found
-        @model: z3 model instance
-        @destination: Expr instance for an addr which is not on the DSE path
-        """
-        key = self._key_for_solution_strategy(destination)
-        assert key is not None
-        self.new_solutions[key] = model
-        self._known_solutions.add(key)
-
-    def handle_correct_destination(self, destination, path_constraints):
-        """[DEV] Called by handle() to update internal structures given the
-        correct destination (the concrete execution one).
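
Taken together, the API above supports the classic DSE loop. The sketch below is a possible use, not part of the original file: the names machine, jitter, buf_addr and buf_len stand for an existing sandbox setup, and attach()/update_state_from_concrete() come from DSEEngine. It symbolizes an input buffer, runs once, then derives candidate inputs from the collected models:

    from future.utils import viewitems
    from miasm2.core.interval import interval
    from miasm2.analysis.dse import DSEPathConstraint

    dse = DSEPathConstraint(
        machine,
        produce_solution=DSEPathConstraint.PRODUCE_SOLUTION_BRANCH_COV,
    )
    dse.attach(jitter)
    dse.update_state_from_concrete()
    # Byte-granular symbols MEM_0x<addr> (see memory_to_expr above)
    dse.symbolize_memory(interval([(buf_addr, buf_addr + buf_len - 1)]))

    snapshot = dse.take_snapshot()
    jitter.continue_run()

    # Each model assigns the symbolized bytes; rebuild a candidate input
    for identifier, model in viewitems(dse.new_solutions):
        candidate = []
        for i in range(buf_len):
            sym = dse.memory_to_expr(buf_addr + i)
            value = model.eval(dse.z3_trans.from_expr(sym),
                               model_completion=True)
            candidate.append(value.as_long() & 0xFF)
        print(identifier, bytes(bytearray(candidate)))
        # dse.restore_snapshot(snapshot) would rewind for another run
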
- """ - - # Update structure used by produce_solution() - if self._produce_solution_strategy == self.PRODUCE_SOLUTION_PATH_COV: - self._history.append(destination) - elif self._produce_solution_strategy == self.PRODUCE_SOLUTION_BRANCH_COV: - self._previous_addr = destination - - # Update current solver - for cons in path_constraints: - self.cur_solver.add(self.z3_trans.from_expr(cons)) - - def handle(self, cur_addr): - cur_addr = self.ir_arch.loc_db.canonize_to_exprloc(cur_addr) - symb_pc = self.eval_expr(self.ir_arch.IRDst) - possibilities = possible_values(symb_pc) - cur_path_constraint = set() # path_constraint for the concrete path - if len(possibilities) == 1: - dst = next(iter(possibilities)).value - dst = self.ir_arch.loc_db.canonize_to_exprloc(dst) - assert dst == cur_addr - else: - for possibility in possibilities: - target_addr = self.ir_arch.loc_db.canonize_to_exprloc( - possibility.value - ) - path_constraint = set() # Set of ExprAssign for the possible path - - # Get constraint associated to the possible path - memory_to_add = ModularIntervals(symb_pc.size) - for cons in possibility.constraints: - eaff = cons.to_constraint() - # eaff.get_r(mem_read=True) is not enough - # ExprAssign consider a Memory access in dst as a write - mem = eaff.dst.get_r(mem_read=True) - mem.update(eaff.src.get_r(mem_read=True)) - for expr in mem: - if expr.is_mem(): - addr_range = expr_range(expr.ptr) - # At upper bounds, add the size of the memory access - # if addr (- [a, b], then @size[addr] reachables - # values are in @8[a, b + size[ - for start, stop in addr_range: - stop += expr.size // 8 - 1 - full_range = ModularIntervals( - symb_pc.size, - [(start, stop)] - ) - memory_to_add.update(full_range) - path_constraint.add(eaff) - - if memory_to_add.length > self.MAX_MEMORY_INJECT: - # TODO re-croncretize the constraint or z3-try - raise RuntimeError("Not implemented: too long memory area") - - # Inject memory - for start, stop in memory_to_add: - for address in range(start, stop + 1): - expr_mem = ExprMem(ExprInt(address, - self.ir_arch.pc.size), - 8) - value = self.eval_expr(expr_mem) - if not value.is_int(): - raise TypeError("Rely on a symbolic memory case, " \ - "address 0x%x" % address) - path_constraint.add(ExprAssign(expr_mem, value)) - - if target_addr == cur_addr: - # Add path constraint - cur_path_constraint = path_constraint - - elif self.produce_solution(target_addr): - # Looking for a new solution - self.cur_solver.push() - for cons in path_constraint: - trans = self.z3_trans.from_expr(cons) - trans = z3.simplify(trans) - self.cur_solver.add(trans) - - result = self.cur_solver.check() - if result == z3.sat: - model = self.cur_solver.model() - self.handle_solution(model, target_addr) - self.cur_solver.pop() - - self.handle_correct_destination(cur_addr, cur_path_constraint) diff --git a/miasm2/analysis/expression_range.py b/miasm2/analysis/expression_range.py deleted file mode 100644 index 8f498549..00000000 --- a/miasm2/analysis/expression_range.py +++ /dev/null @@ -1,70 +0,0 @@ -"""Naive range analysis for expression""" - -from future.builtins import zip -from functools import reduce - -from miasm2.analysis.modularintervals import ModularIntervals - -_op_range_handler = { - "+": lambda x, y: x + y, - "&": lambda x, y: x & y, - "|": lambda x, y: x | y, - "^": lambda x, y: x ^ y, - "*": lambda x, y: x * y, - "a>>": lambda x, y: x.arithmetic_shift_right(y), - "<<": lambda x, y: x << y, - ">>": lambda x, y: x >> y, - ">>>": lambda x, y: x.rotation_right(y), - "<<<": lambda x, y: 
x.rotation_left(y), -} - -def expr_range(expr): - """Return a ModularIntervals containing the range of possible values of - @expr""" - max_bound = (1 << expr.size) - 1 - if expr.is_int(): - return ModularIntervals(expr.size, [(int(expr), int(expr))]) - elif expr.is_id() or expr.is_mem(): - return ModularIntervals(expr.size, [(0, max_bound)]) - elif expr.is_slice(): - interval_mask = ((1 << expr.start) - 1) ^ ((1 << expr.stop) - 1) - arg = expr_range(expr.arg) - # Mask for possible range, and shift range - return ((arg & interval_mask) >> expr.start).size_update(expr.size) - elif expr.is_compose(): - sub_ranges = [expr_range(arg) for arg in expr.args] - args_idx = [info[0] for info in expr.iter_args()] - - # No shift for the first one - ret = sub_ranges[0].size_update(expr.size) - - # Doing it progressively (2 by 2) - for shift, sub_range in zip(args_idx[1:], sub_ranges[1:]): - ret |= sub_range.size_update(expr.size) << shift - return ret - elif expr.is_op(): - # A few operation are handled with care - # Otherwise, overapproximate (ie. full range interval) - if expr.op in _op_range_handler: - sub_ranges = [expr_range(arg) for arg in expr.args] - return reduce( - _op_range_handler[expr.op], - (sub_range for sub_range in sub_ranges[1:]), - sub_ranges[0] - ) - elif expr.op == "-": - assert len(expr.args) == 1 - return - expr_range(expr.args[0]) - elif expr.op == "%": - assert len(expr.args) == 2 - op, mod = [expr_range(arg) for arg in expr.args] - if mod.intervals.length == 1: - # Modulo intervals is not supported - return op % mod.intervals.hull()[0] - - # Operand not handled, return the full domain - return ModularIntervals(expr.size, [(0, max_bound)]) - elif expr.is_cond(): - return expr_range(expr.src1).union(expr_range(expr.src2)) - else: - raise TypeError("Unsupported type: %s" % expr.__class__) diff --git a/miasm2/analysis/gdbserver.py b/miasm2/analysis/gdbserver.py deleted file mode 100644 index 61ee8955..00000000 --- a/miasm2/analysis/gdbserver.py +++ /dev/null @@ -1,453 +0,0 @@ -#-*- coding:utf-8 -*- - -from __future__ import print_function -from future.builtins import map, range - -from miasm2.core.utils import decode_hex, encode_hex, int_to_byte - -import socket -import struct -import time -import logging -from io import BytesIO -import miasm2.analysis.debugging as debugging -from miasm2.jitter.jitload import ExceptionHandle - - -class GdbServer(object): - - "Debugguer binding for GDBServer protocol" - - general_registers_order = [] - general_registers_size = {} # RegName : Size in octet - status = b"S05" - - def __init__(self, dbg, port=4455): - server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - server.bind(('localhost', port)) - server.listen(1) - self.server = server - self.dbg = dbg - - # Communication methods - - def compute_checksum(self, data): - return encode_hex(int_to_byte(sum(map(ord, data)) % 256)) - - def get_messages(self): - all_data = b"" - while True: - data = self.sock.recv(4096) - if not data: - break - all_data += data - - logging.debug("<- %r", all_data) - self.recv_queue += self.parse_messages(all_data) - - def parse_messages(self, data): - buf = BytesIO(data) - msgs = [] - - while (buf.tell() < buf.len): - token = buf.read(1) - if token == b"+": - continue - if token == b"-": - raise NotImplementedError("Resend packet") - if token == b"$": - packet_data = b"" - c = buf.read(1) - while c != b"#": - packet_data += c - c = buf.read(1) - checksum = buf.read(2) - if checksum != 
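
A standalone illustration of the range analysis above (a sketch; hull() comes from miasm2.core.interval, and the asserted bounds follow from the interval rules):

    from miasm2.expression.expression import ExprId, ExprInt
    from miasm2.analysis.expression_range import expr_range

    a = ExprId("a", 32)
    # a & 0xF lies in [0, 0xF]; adding 0x10 cannot wrap, giving [0x10, 0x1F]
    rng = expr_range((a & ExprInt(0xF, 32)) + ExprInt(0x10, 32))
    assert rng.intervals.hull() == (0x10, 0x1F)
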
self.compute_checksum(packet_data): - raise ValueError("Incorrect checksum") - msgs.append(packet_data) - - return msgs - - def send_string(self, s): - self.send_queue.append(b"O" + encode_hex(s)) - - def process_messages(self): - - while self.recv_queue: - msg = self.recv_queue.pop(0) - buf = BytesIO(msg) - msg_type = buf.read(1) - - self.send_queue.append(b"+") - - if msg_type == b"q": - if msg.startswith(b"qSupported"): - self.send_queue.append(b"PacketSize=3fff") - elif msg.startswith(b"qC"): - # Current thread - self.send_queue.append(b"") - elif msg.startswith(b"qAttached"): - # Not supported - self.send_queue.append(b"") - elif msg.startswith(b"qTStatus"): - # Not supported - self.send_queue.append(b"") - elif msg.startswith(b"qfThreadInfo"): - # Not supported - self.send_queue.append(b"") - else: - raise NotImplementedError() - - elif msg_type == b"H": - # Set current thread - self.send_queue.append(b"OK") - - elif msg_type == b"?": - # Report why the target halted - self.send_queue.append(self.status) # TRAP signal - - elif msg_type == b"g": - # Report all general register values - self.send_queue.append(self.report_general_register_values()) - - elif msg_type == b"p": - # Read a specific register - reg_num = int(buf.read(), 16) - self.send_queue.append(self.read_register(reg_num)) - - elif msg_type == b"P": - # Set a specific register - reg_num, value = buf.read().split(b"=") - reg_num = int(reg_num, 16) - value = int(encode_hex(decode_hex(value)[::-1]), 16) - self.set_register(reg_num, value) - self.send_queue.append(b"OK") - - elif msg_type == b"m": - # Read memory - addr, size = (int(x, 16) for x in buf.read().split(b",", 1)) - self.send_queue.append(self.read_memory(addr, size)) - - elif msg_type == b"k": - # Kill - self.sock.close() - self.send_queue = [] - self.sock = None - - elif msg_type == b"!": - # Extending debugging will be used - self.send_queue.append(b"OK") - - elif msg_type == b"v": - if msg == b"vCont?": - # Is vCont supported ? 
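
For reference, the framing parsed above is the standard GDB Remote Serial Protocol one; a standalone sketch of building a packet:

    # "$<payload>#<checksum>", where checksum = sum(payload bytes) % 256 in hex
    payload = b"qSupported"
    checksum = sum(bytearray(payload)) % 256
    packet = b"$" + payload + ("#%02x" % checksum).encode()
    assert packet == b"$qSupported#37"
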
- self.send_queue.append(b"") - - elif msg_type == b"s": - # Step - self.dbg.step() - self.send_queue.append(b"S05") # TRAP signal - - elif msg_type == b"Z": - # Add breakpoint or watchpoint - bp_type = buf.read(1) - if bp_type == b"0": - # Exec breakpoint - assert(buf.read(1) == b",") - addr, size = (int(x, 16) for x in buf.read().split(b",", 1)) - - if size != 1: - raise NotImplementedError("Bigger size") - self.dbg.add_breakpoint(addr) - self.send_queue.append(b"OK") - - elif bp_type == b"1": - # Hardware BP - assert(buf.read(1) == b",") - addr, size = (int(x, 16) for x in buf.read().split(b",", 1)) - - self.dbg.add_memory_breakpoint( - addr, - size, - read=True, - write=True - ) - self.send_queue.append(b"OK") - - elif bp_type in [b"2", b"3", b"4"]: - # Memory breakpoint - assert(buf.read(1) == b",") - read = bp_type in [b"3", b"4"] - write = bp_type in [b"2", b"4"] - addr, size = (int(x, 16) for x in buf.read().split(b",", 1)) - - self.dbg.add_memory_breakpoint( - addr, - size, - read=read, - write=write - ) - self.send_queue.append(b"OK") - - else: - raise ValueError("Impossible value") - - elif msg_type == b"z": - # Remove breakpoint or watchpoint - bp_type = buf.read(1) - if bp_type == b"0": - # Exec breakpoint - assert(buf.read(1) == b",") - addr, size = (int(x, 16) for x in buf.read().split(b",", 1)) - - if size != 1: - raise NotImplementedError("Bigger size") - dbgsoft = self.dbg.get_breakpoint_by_addr(addr) - assert(len(dbgsoft) == 1) - self.dbg.remove_breakpoint(dbgsoft[0]) - self.send_queue.append(b"OK") - - elif bp_type == b"1": - # Hardware BP - assert(buf.read(1) == b",") - addr, size = (int(x, 16) for x in buf.read().split(b",", 1)) - self.dbg.remove_memory_breakpoint_by_addr_access( - addr, - read=True, - write=True - ) - self.send_queue.append(b"OK") - - elif bp_type in [b"2", b"3", b"4"]: - # Memory breakpoint - assert(buf.read(1) == b",") - read = bp_type in [b"3", b"4"] - write = bp_type in [b"2", b"4"] - addr, size = (int(x, 16) for x in buf.read().split(b",", 1)) - - self.dbg.remove_memory_breakpoint_by_addr_access( - addr, - read=read, - write=write - ) - self.send_queue.append(b"OK") - - else: - raise ValueError("Impossible value") - - elif msg_type == b"c": - # Continue - self.status = b"" - self.send_messages() - ret = self.dbg.run() - if isinstance(ret, debugging.DebugBreakpointSoft): - self.status = b"S05" - self.send_queue.append(b"S05") # TRAP signal - elif isinstance(ret, ExceptionHandle): - if ret == ExceptionHandle.memoryBreakpoint(): - self.status = b"S05" - self.send_queue.append(b"S05") - else: - raise NotImplementedError("Unknown Except") - elif isinstance(ret, debugging.DebugBreakpointTerminate): - # Connexion should close, but keep it running as a TRAP - # The connexion will be close on instance destruction - print(ret) - self.status = b"S05" - self.send_queue.append(b"S05") - else: - raise NotImplementedError() - - else: - raise NotImplementedError( - "Not implemented: message type %r" % msg_type - ) - - def send_messages(self): - for msg in self.send_queue: - if msg == b"+": - data = b"+" - else: - data = b"$%s#%s" % (msg, self.compute_checksum(msg)) - logging.debug("-> %r", data) - self.sock.send(data) - self.send_queue = [] - - def main_loop(self): - self.recv_queue = [] - self.send_queue = [] - - self.send_string(b"Test\n") - - while (self.sock): - self.get_messages() - self.process_messages() - self.send_messages() - - def run(self): - self.sock, self.address = self.server.accept() - self.main_loop() - - # Debugguer processing methods - def 
report_general_register_values(self): - s = b"" - for i in range(len(self.general_registers_order)): - s += self.read_register(i) - return s - - def read_register(self, reg_num): - reg_name = self.general_registers_order[reg_num] - reg_value = self.read_register_by_name(reg_name) - size = self.general_registers_size[reg_name] - - pack_token = "" - if size == 1: - pack_token = "<B" - elif size == 2: - pack_token = "<H" - elif size == 4: - pack_token = "<I" - elif size == 8: - pack_token = "<Q" - else: - raise NotImplementedError("Unknown size") - - return encode_hex(struct.pack(pack_token, reg_value)) - - def set_register(self, reg_num, value): - reg_name = self.general_registers_order[reg_num] - self.dbg.set_reg_value(reg_name, value) - - def read_register_by_name(self, reg_name): - return self.dbg.get_reg_value(reg_name) - - def read_memory(self, addr, size): - except_flag_vm = self.dbg.myjit.vm.get_exception() - try: - return encode_hex(self.dbg.get_mem_raw(addr, size)) - except RuntimeError: - self.dbg.myjit.vm.set_exception(except_flag_vm) - return b"00" * size - - -class GdbServer_x86_32(GdbServer): - - "Extend GdbServer for x86 32bits purposes" - - general_registers_order = [ - "EAX", "ECX", "EDX", "EBX", "ESP", "EBP", "ESI", - "EDI", "EIP", "EFLAGS", "CS", "SS", "DS", "ES", - "FS", "GS" - ] - - general_registers_size = { - "EAX": 4, - "ECX": 4, - "EDX": 4, - "EBX": 4, - "ESP": 4, - "EBP": 4, - "ESI": 4, - "EDI": 4, - "EIP": 4, - "EFLAGS": 2, - "CS": 2, - "SS": 2, - "DS": 2, - "ES": 2, - "FS": 2, - "GS": 2 - } - - register_ignore = [ - "tf", "i_f", "nt", "rf", "vm", "ac", "vif", "vip", "i_d" - ] - - def read_register_by_name(self, reg_name): - sup_func = super(GdbServer_x86_32, self).read_register_by_name - - # Assert EIP on pc jitter - if reg_name == "EIP": - return self.dbg.myjit.pc - - # EFLAGS case - if reg_name == "EFLAGS": - val = 0 - eflags_args = [ - "cf", 1, "pf", 0, "af", 0, "zf", "nf", "tf", "i_f", "df", "of" - ] - eflags_args += ["nt", 0, "rf", "vm", "ac", "vif", "vip", "i_d"] - eflags_args += [0] * 10 - - for i, arg in enumerate(eflags_args): - if isinstance(arg, str): - if arg not in self.register_ignore: - to_add = sup_func(arg) - else: - to_add = 0 - else: - to_add = arg - - val |= (to_add << i) - return val - else: - return sup_func(reg_name) - - -class GdbServer_msp430(GdbServer): - - "Extend GdbServer for msp430 purposes" - - general_registers_order = [ - "PC", "SP", "SR", "R3", "R4", "R5", "R6", "R7", - "R8", "R9", "R10", "R11", "R12", "R13", "R14", - "R15" - ] - - general_registers_size = { - "PC": 2, - "SP": 2, - "SR": 2, - "R3": 2, - "R4": 2, - "R5": 2, - "R6": 2, - "R7": 2, - "R8": 2, - "R9": 2, - "R10": 2, - "R11": 2, - "R12": 2, - "R13": 2, - "R14": 2, - "R15": 2 - } - - def read_register_by_name(self, reg_name): - sup_func = super(GdbServer_msp430, self).read_register_by_name - if reg_name == "SR": - o = sup_func('res') - o <<= 1 - o |= sup_func('of') - o <<= 1 - o |= sup_func('scg1') - o <<= 1 - o |= sup_func('scg0') - o <<= 1 - o |= sup_func('osc') - o <<= 1 - o |= sup_func('cpuoff') - o <<= 1 - o |= sup_func('gie') - o <<= 1 - o |= sup_func('nf') - o <<= 1 - o |= sup_func('zf') - o <<= 1 - o |= sup_func('cf') - - return o - else: - return sup_func(reg_name) - diff --git a/miasm2/analysis/machine.py b/miasm2/analysis/machine.py deleted file mode 100644 index f12b7e57..00000000 --- a/miasm2/analysis/machine.py +++ /dev/null @@ -1,265 +0,0 @@ -#-*- coding:utf-8 -*- - - -class Machine(object): - """Abstract machine architecture to restrict architecture 
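
The register replies built by read_register() are little-endian hex strings; for instance:

    import struct
    from miasm2.core.utils import encode_hex

    # A 4-byte register holding 0x11223344 is reported to gdb as b'44332211'
    assert encode_hex(struct.pack("<I", 0x11223344)) == b"44332211"
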
dependent code""" - - __dis_engine = None # Disassembly engine - __mn = None # Machine instance - __ira = None # IR analyser - __jitter = None # Jit engine - __gdbserver = None # GdbServer handler - - __available = ["arml", "armb", "armtl", "armtb", "sh4", "x86_16", "x86_32", - "x86_64", "msp430", "mips32b", "mips32l", - "aarch64l", "aarch64b", "ppc32b", "mepl", "mepb"] - - - def __init__(self, machine_name): - - dis_engine = None - mn = None - ira = None - ir = None - jitter = None - gdbserver = None - jit = None - jitter = None - log_jit = None - log_arch = None - - # Import on runtime for performance issue - if machine_name == "arml": - from miasm2.arch.arm.disasm import dis_arml as dis_engine - from miasm2.arch.arm import arch - try: - from miasm2.arch.arm import jit - jitter = jit.jitter_arml - except ImportError: - pass - mn = arch.mn_arm - from miasm2.arch.arm.ira import ir_a_arml as ira - from miasm2.arch.arm.sem import ir_arml as ir - elif machine_name == "armb": - from miasm2.arch.arm.disasm import dis_armb as dis_engine - from miasm2.arch.arm import arch - try: - from miasm2.arch.arm import jit - jitter = jit.jitter_armb - except ImportError: - pass - mn = arch.mn_arm - from miasm2.arch.arm.ira import ir_a_armb as ira - from miasm2.arch.arm.sem import ir_armb as ir - elif machine_name == "aarch64l": - from miasm2.arch.aarch64.disasm import dis_aarch64l as dis_engine - from miasm2.arch.aarch64 import arch - try: - from miasm2.arch.aarch64 import jit - jitter = jit.jitter_aarch64l - except ImportError: - pass - mn = arch.mn_aarch64 - from miasm2.arch.aarch64.ira import ir_a_aarch64l as ira - from miasm2.arch.aarch64.sem import ir_aarch64l as ir - elif machine_name == "aarch64b": - from miasm2.arch.aarch64.disasm import dis_aarch64b as dis_engine - from miasm2.arch.aarch64 import arch - try: - from miasm2.arch.aarch64 import jit - jitter = jit.jitter_aarch64b - except ImportError: - pass - mn = arch.mn_aarch64 - from miasm2.arch.aarch64.ira import ir_a_aarch64b as ira - from miasm2.arch.aarch64.sem import ir_aarch64b as ir - elif machine_name == "armtl": - from miasm2.arch.arm.disasm import dis_armtl as dis_engine - from miasm2.arch.arm import arch - mn = arch.mn_armt - from miasm2.arch.arm.ira import ir_a_armtl as ira - from miasm2.arch.arm.sem import ir_armtl as ir - try: - from miasm2.arch.arm import jit - jitter = jit.jitter_armtl - except ImportError: - pass - elif machine_name == "armtb": - from miasm2.arch.arm.disasm import dis_armtb as dis_engine - from miasm2.arch.arm import arch - mn = arch.mn_armt - from miasm2.arch.arm.ira import ir_a_armtb as ira - from miasm2.arch.arm.sem import ir_armtb as ir - elif machine_name == "sh4": - from miasm2.arch.sh4 import arch - mn = arch.mn_sh4 - elif machine_name == "x86_16": - from miasm2.arch.x86.disasm import dis_x86_16 as dis_engine - from miasm2.arch.x86 import arch - try: - from miasm2.arch.x86 import jit - jitter = jit.jitter_x86_16 - except ImportError: - pass - mn = arch.mn_x86 - from miasm2.arch.x86.ira import ir_a_x86_16 as ira - from miasm2.arch.x86.sem import ir_x86_16 as ir - elif machine_name == "x86_32": - from miasm2.arch.x86.disasm import dis_x86_32 as dis_engine - from miasm2.arch.x86 import arch - try: - from miasm2.arch.x86 import jit - jitter = jit.jitter_x86_32 - except ImportError: - pass - mn = arch.mn_x86 - from miasm2.arch.x86.ira import ir_a_x86_32 as ira - from miasm2.arch.x86.sem import ir_x86_32 as ir - try: - from miasm2.analysis.gdbserver import GdbServer_x86_32 as gdbserver - except ImportError: - pass 
- elif machine_name == "x86_64": - from miasm2.arch.x86.disasm import dis_x86_64 as dis_engine - from miasm2.arch.x86 import arch - try: - from miasm2.arch.x86 import jit - jitter = jit.jitter_x86_64 - except ImportError: - pass - mn = arch.mn_x86 - from miasm2.arch.x86.ira import ir_a_x86_64 as ira - from miasm2.arch.x86.sem import ir_x86_64 as ir - elif machine_name == "msp430": - from miasm2.arch.msp430.disasm import dis_msp430 as dis_engine - from miasm2.arch.msp430 import arch - try: - from miasm2.arch.msp430 import jit - jitter = jit.jitter_msp430 - except ImportError: - pass - mn = arch.mn_msp430 - from miasm2.arch.msp430.ira import ir_a_msp430 as ira - from miasm2.arch.msp430.sem import ir_msp430 as ir - try: - from miasm2.analysis.gdbserver import GdbServer_msp430 as gdbserver - except ImportError: - pass - elif machine_name == "mips32b": - from miasm2.arch.mips32.disasm import dis_mips32b as dis_engine - from miasm2.arch.mips32 import arch - try: - from miasm2.arch.mips32 import jit - jitter = jit.jitter_mips32b - except ImportError: - pass - mn = arch.mn_mips32 - from miasm2.arch.mips32.ira import ir_a_mips32b as ira - from miasm2.arch.mips32.sem import ir_mips32b as ir - elif machine_name == "mips32l": - from miasm2.arch.mips32.disasm import dis_mips32l as dis_engine - from miasm2.arch.mips32 import arch - try: - from miasm2.arch.mips32 import jit - jitter = jit.jitter_mips32l - except ImportError: - pass - mn = arch.mn_mips32 - from miasm2.arch.mips32.ira import ir_a_mips32l as ira - from miasm2.arch.mips32.sem import ir_mips32l as ir - elif machine_name == "ppc32b": - from miasm2.arch.ppc.disasm import dis_ppc32b as dis_engine - from miasm2.arch.ppc import arch - try: - from miasm2.arch.ppc import jit - jitter = jit.jitter_ppc32b - except ImportError: - pass - mn = arch.mn_ppc - from miasm2.arch.ppc.ira import ir_a_ppc32b as ira - from miasm2.arch.ppc.sem import ir_ppc32b as ir - elif machine_name == "mepb": - from miasm2.arch.mep.disasm import dis_mepb as dis_engine - from miasm2.arch.mep import arch - try: - from miasm2.arch.mep import jit - jitter = jit.jitter_mepb - except ImportError: - pass - mn = arch.mn_mep - from miasm2.arch.mep.ira import ir_a_mepb as ira - from miasm2.arch.mep.sem import ir_mepb as ir - elif machine_name == "mepl": - from miasm2.arch.mep.disasm import dis_mepl as dis_engine - from miasm2.arch.mep import arch - try: - from miasm2.arch.mep import jit - jitter = jit.jitter_mepl - except ImportError: - pass - mn = arch.mn_mep - from miasm2.arch.mep.ira import ir_a_mepl as ira - from miasm2.arch.mep.sem import ir_mepl as ir - else: - raise ValueError('Unknown machine: %s' % machine_name) - - # Loggers - if jit is not None: - log_jit = jit.log - log_arch = arch.log - - self.__dis_engine = dis_engine - self.__mn = mn - self.__ira = ira - self.__jitter = jitter - self.__gdbserver = gdbserver - self.__log_jit = log_jit - self.__log_arch = log_arch - self.__base_expr = arch.base_expr - self.__ir = ir - self.__name = machine_name - - @property - def dis_engine(self): - return self.__dis_engine - - @property - def mn(self): - return self.__mn - - @property - def ira(self): - return self.__ira - - @property - def ir(self): - return self.__ir - - @property - def jitter(self): - return self.__jitter - - @property - def gdbserver(self): - return self.__gdbserver - - @property - def log_jit(self): - return self.__log_jit - - @property - def log_arch(self): - return self.__log_arch - - @property - def base_expr(self): - return self.__base_expr - - @property - def 
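
Typical use of this dispatcher, together with the Container abstraction (a sketch; the raw bytes encode "mov eax, ebx ; ret"):

    from miasm2.analysis.binary import Container
    from miasm2.analysis.machine import Machine

    print(Machine.available_machine())   # list of supported machine names
    machine = Machine("x86_32")
    cont = Container.from_string(b"\x89\xd8\xc3")
    mdis = machine.dis_engine(cont.bin_stream, loc_db=cont.loc_db)
    asmcfg = mdis.dis_multiblock(0)      # disassemble recursively from offset 0
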
name(self): - return self.__name - - @classmethod - def available_machine(cls): - "Return a list of supported machines" - return cls.__available diff --git a/miasm2/analysis/modularintervals.py b/miasm2/analysis/modularintervals.py deleted file mode 100644 index 2195598b..00000000 --- a/miasm2/analysis/modularintervals.py +++ /dev/null @@ -1,530 +0,0 @@ -"""Intervals with a maximum size, supporting modular arithmetic""" - -from future.builtins import range -from builtins import int as int_types -from itertools import product - -from miasm2.core.interval import interval - -class ModularIntervals(object): - """Intervals with a maximum size, supporting modular arithmetic""" - - def __init__(self, size, intervals=None): - """Instantiate a ModularIntervals of size @size - @size: maximum size of elements - @intervals: (optional) interval instance, or any type supported by - interval initialisation; element of the current instance - """ - # Create or cast @intervals argument - if intervals is None: - intervals = interval() - if not isinstance(intervals, interval): - intervals = interval(intervals) - self.intervals = intervals - self.size = size - - # Sanity check - start, end = intervals.hull() - if start is not None: - assert start >= 0 - if end is not None: - assert end <= self.mask - - # Helpers - - @staticmethod - def size2mask(size): - """Return the bit mask of size @size""" - return (1 << size) - 1 - - def _range2interval(func): - """Convert a function taking 2 ranges to a function taking a ModularIntervals - and applying to the current instance""" - def ret_func(self, target): - ret = interval() - for left_i, right_i in product(self.intervals, target.intervals): - ret += func(self, left_i[0], left_i[1], right_i[0], - right_i[1]) - return self.__class__(self.size, ret) - return ret_func - - def _range2integer(func): - """Convert a function taking 1 range and optional arguments to a function - applying to the current instance""" - def ret_func(self, *args): - ret = interval() - for x_min, x_max in self.intervals: - ret += func(self, x_min, x_max, *args) - return self.__class__(self.size, ret) - return ret_func - - def _promote(func): - """Check and promote the second argument from integer to - ModularIntervals with one value""" - def ret_func(self, target): - if isinstance(target, int_types): - target = ModularIntervals(self.size, interval([(target, target)])) - if not isinstance(target, ModularIntervals): - raise TypeError("Unsupported operation with %s" % target.__class__) - if target.size != self.size: - raise TypeError("Size are not the same: %s vs %s" % (self.size, - target.size)) - return func(self, target) - return ret_func - - def _unsigned2signed(self, value): - """Return the signed value of @value, based on self.size""" - if (value & (1 << (self.size - 1))): - return -(self.mask ^ value) - 1 - else: - return value - - def _signed2unsigned(self, value): - """Return the unsigned value of @value, based on self.size""" - return value & self.mask - - # Operation internals - # - # Naming convention: - # _range_{op}: takes 2 interval bounds and apply op - # _range_{op}_uniq: takes 1 interval bounds and apply op - # _interval_{op}: apply op on an ModularIntervals - # _integer_{op}: apply op on itself with possible arguments - - def _range_add(self, x_min, x_max, y_min, y_max): - """Bounds interval for x + y, with - - x, y of size 'self.size' - - @x_min <= x <= @x_max - - @y_min <= y <= @y_max - - operations are considered unsigned - From Hacker's Delight: Chapter 4 - """ - max_bound = 
self.mask - if (x_min + y_min <= max_bound and - x_max + y_max >= max_bound + 1): - # HD returns 0, max_bound; but this is because it cannot handle multiple - # interval. - # x_max + y_max can only overflow once, so returns - # [result_min, overflow] U [0, overflow_rest] - return interval([(x_min + y_min, max_bound), - (0, (x_max + y_max) & max_bound)]) - else: - return interval([((x_min + y_min) & max_bound, - (x_max + y_max) & max_bound)]) - - _interval_add = _range2interval(_range_add) - - def _range_minus_uniq(self, x_min, x_max): - """Bounds interval for -x, with - - x of size self.size - - @x_min <= x <= @x_max - - operations are considered unsigned - From Hacker's Delight: Chapter 4 - """ - max_bound = self.mask - if (x_min == 0 and x_max != 0): - # HD returns 0, max_bound; see _range_add - return interval([(0, 0), ((- x_max) & max_bound, max_bound)]) - else: - return interval([((- x_max) & max_bound, (- x_min) & max_bound)]) - - _interval_minus = _range2integer(_range_minus_uniq) - - def _range_or_min(self, x_min, x_max, y_min, y_max): - """Interval min for x | y, with - - x, y of size self.size - - @x_min <= x <= @x_max - - @y_min <= y <= @y_max - - operations are considered unsigned - From Hacker's Delight: Chapter 4 - """ - max_bit = 1 << (self.size - 1) - while max_bit: - if ~x_min & y_min & max_bit: - temp = (x_min | max_bit) & - max_bit - if temp <= x_max: - x_min = temp - break - elif x_min & ~y_min & max_bit: - temp = (y_min | max_bit) & - max_bit - if temp <= y_max: - y_min = temp - break - max_bit >>= 1 - return x_min | y_min - - def _range_or_max(self, x_min, x_max, y_min, y_max): - """Interval max for x | y, with - - x, y of size self.size - - @x_min <= x <= @x_max - - @y_min <= y <= @y_max - - operations are considered unsigned - From Hacker's Delight: Chapter 4 - """ - max_bit = 1 << (self.size - 1) - while max_bit: - if x_max & y_max & max_bit: - temp = (x_max - max_bit) | (max_bit - 1) - if temp >= x_min: - x_max = temp - break - temp = (y_max - max_bit) | (max_bit - 1) - if temp >= y_min: - y_max = temp - break - max_bit >>= 1 - return x_max | y_max - - def _range_or(self, x_min, x_max, y_min, y_max): - """Interval bounds for x | y, with - - x, y of size self.size - - @x_min <= x <= @x_max - - @y_min <= y <= @y_max - - operations are considered unsigned - From Hacker's Delight: Chapter 4 - """ - return interval([(self._range_or_min(x_min, x_max, y_min, y_max), - self._range_or_max(x_min, x_max, y_min, y_max))]) - - _interval_or = _range2interval(_range_or) - - def _range_and_min(self, x_min, x_max, y_min, y_max): - """Interval min for x & y, with - - x, y of size self.size - - @x_min <= x <= @x_max - - @y_min <= y <= @y_max - - operations are considered unsigned - From Hacker's Delight: Chapter 4 - """ - max_bit = (1 << (self.size - 1)) - while max_bit: - if ~x_min & ~y_min & max_bit: - temp = (x_min | max_bit) & - max_bit - if temp <= x_max: - x_min = temp - break - temp = (y_min | max_bit) & - max_bit - if temp <= y_max: - y_min = temp - break - max_bit >>= 1 - return x_min & y_min - - def _range_and_max(self, x_min, x_max, y_min, y_max): - """Interval max for x & y, with - - x, y of size self.size - - @x_min <= x <= @x_max - - @y_min <= y <= @y_max - - operations are considered unsigned - From Hacker's Delight: Chapter 4 - """ - max_bit = (1 << (self.size - 1)) - while max_bit: - if x_max & ~y_max & max_bit: - temp = (x_max & ~max_bit) | (max_bit - 1) - if temp >= x_min: - x_max = temp - break - elif ~x_max & y_max & max_bit: - temp = (y_max & ~max_bit) | 
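
The wrap-aware bounds above can be cross-checked exhaustively at small sizes; a standalone sketch (int membership relies on miasm2.core.interval's __contains__):

    from miasm2.analysis.modularintervals import ModularIntervals

    x = ModularIntervals(8, [(0xF0, 0xFF)])
    y = ModularIntervals(8, [(0x10, 0x20)])
    computed = x + y                     # wraps modulo 2**8: [(0x00, 0x1F)]
    for a in range(0xF0, 0x100):
        for b in range(0x10, 0x21):
            assert ((a + b) & 0xFF) in computed
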
(max_bit - 1) - if temp >= y_min: - y_max = temp - break - max_bit >>= 1 - return x_max & y_max - - def _range_and(self, x_min, x_max, y_min, y_max): - """Interval bounds for x & y, with - - x, y of size @size - - @x_min <= x <= @x_max - - @y_min <= y <= @y_max - - operations are considered unsigned - From Hacker's Delight: Chapter 4 - """ - return interval([(self._range_and_min(x_min, x_max, y_min, y_max), - self._range_and_max(x_min, x_max, y_min, y_max))]) - - _interval_and = _range2interval(_range_and) - - def _range_xor(self, x_min, x_max, y_min, y_max): - """Interval bounds for x ^ y, with - - x, y of size self.size - - @x_min <= x <= @x_max - - @y_min <= y <= @y_max - - operations are considered unsigned - From Hacker's Delight: Chapter 4 - """ - not_size = lambda x: x ^ self.mask - min_xor = self._range_and_min(x_min, x_max, not_size(y_max), not_size(y_min)) | self._range_and_min(not_size(x_max), not_size(x_min), y_min, y_max) - max_xor = self._range_or_max(0, - self._range_and_max(x_min, x_max, not_size(y_max), not_size(y_min)), - 0, - self._range_and_max(not_size(x_max), not_size(x_min), y_min, y_max)) - return interval([(min_xor, max_xor)]) - - _interval_xor = _range2interval(_range_xor) - - def _range_mul(self, x_min, x_max, y_min, y_max): - """Interval bounds for x * y, with - - x, y of size self.size - - @x_min <= x <= @x_max - - @y_min <= y <= @y_max - - operations are considered unsigned - This is a naive version, going to TOP on overflow""" - max_bound = self.mask - if y_max * x_max > max_bound: - return interval([(0, max_bound)]) - else: - return interval([(x_min * y_min, x_max * y_max)]) - - _interval_mul = _range2interval(_range_mul) - - def _range_mod_uniq(self, x_min, x_max, mod): - """Interval bounds for x % @mod, with - - x, @mod of size self.size - - @x_min <= x <= @x_max - - operations are considered unsigned - """ - if (x_max - x_min) >= mod: - return interval([(0, mod - 1)]) - x_max = x_max % mod - x_min = x_min % mod - if x_max < x_min: - return interval([(0, x_max), (x_min, mod - 1)]) - else: - return interval([(x_min, x_max)]) - - _integer_modulo = _range2integer(_range_mod_uniq) - - def _range_shift_uniq(self, x_min, x_max, shift, op): - """Bounds interval for x @op @shift with - - x of size self.size - - @x_min <= x <= @x_max - - operations are considered unsigned - - shift <= self.size - """ - assert shift <= self.size - # Shift operations are monotonic, and overflow results in 0 - max_bound = self.mask - - if op == "<<": - obtain_max = x_max << shift - if obtain_max > max_bound: - # Overflow at least on max, best-effort - # result '0' often happen, include it - return interval([(0, 0), ((1 << shift) - 1, max_bound)]) - else: - return interval([(x_min << shift, obtain_max)]) - elif op == ">>": - return interval([((x_min >> shift) & max_bound, - (x_max >> shift) & max_bound)]) - elif op == "a>>": - # The Miasm2 version (Expr or ModInt) could have been used, but - # introduce unnecessary dependencies for this module - # Python >> is the arithmetic one - ashr = lambda x, y: self._signed2unsigned(self._unsigned2signed(x) >> y) - end_min, end_max = ashr(x_min, shift), ashr(x_max, shift) - end_min, end_max = min(end_min, end_max), max(end_min, end_max) - return interval([(end_min, end_max)]) - else: - raise ValueError("%s is not a shifter" % op) - - def _interval_shift(self, operation, shifter): - """Apply the shifting operation @operation with a shifting - ModularIntervals @shifter on the current instance""" - # Work on a copy of shifter intervals - shifter = 
interval(shifter.intervals) - if (shifter.hull()[1] >= self.size): - shifter += interval([(self.size, self.size)]) - shifter &= interval([(0, self.size)]) - ret = interval() - for shift_range in shifter: - for shift in range(shift_range[0], shift_range[1] + 1): - for x_min, x_max in self.intervals: - ret += self._range_shift_uniq(x_min, x_max, shift, operation) - return self.__class__(self.size, ret) - - def _range_rotate_uniq(self, x_min, x_max, shift, op): - """Bounds interval for x @op @shift with - - x of size self.size - - @x_min <= x <= @x_max - - operations are considered unsigned - - shift <= self.size - """ - assert shift <= self.size - # Divide in sub-operations: a op b: a left b | a right (size - b) - if op == ">>>": - left, right = ">>", "<<" - elif op == "<<<": - left, right = "<<", ">>" - else: - raise ValueError("Not a rotator: %s" % op) - - left_intervals = self._range_shift_uniq(x_min, x_max, shift, left) - right_intervals = self._range_shift_uniq(x_min, x_max, - self.size - shift, right) - - result = self.__class__(self.size, left_intervals) | self.__class__(self.size, right_intervals) - return result.intervals - - def _interval_rotate(self, operation, shifter): - """Apply the rotate operation @operation with a shifting - ModularIntervals @shifter on the current instance""" - # Consider only rotation without repetition, and enumerate - # -> apply a '% size' on shifter - shifter %= self.size - ret = interval() - for shift_range in shifter: - for shift in range(shift_range[0], shift_range[1] + 1): - for x_min, x_max in self.intervals: - ret += self._range_rotate_uniq(x_min, x_max, shift, - operation) - - return self.__class__(self.size, ret) - - # Operation wrappers - - @_promote - def __add__(self, to_add): - """Add @to_add to the current intervals - @to_add: ModularInstances or integer - """ - return self._interval_add(to_add) - - @_promote - def __or__(self, to_or): - """Bitwise OR @to_or to the current intervals - @to_or: ModularInstances or integer - """ - return self._interval_or(to_or) - - @_promote - def __and__(self, to_and): - """Bitwise AND @to_and to the current intervals - @to_and: ModularInstances or integer - """ - return self._interval_and(to_and) - - @_promote - def __xor__(self, to_xor): - """Bitwise XOR @to_xor to the current intervals - @to_xor: ModularInstances or integer - """ - return self._interval_xor(to_xor) - - @_promote - def __mul__(self, to_mul): - """Multiply @to_mul to the current intervals - @to_mul: ModularInstances or integer - """ - return self._interval_mul(to_mul) - - @_promote - def __rshift__(self, to_shift): - """Logical shift right the current intervals of @to_shift - @to_shift: ModularInstances or integer - """ - return self._interval_shift('>>', to_shift) - - @_promote - def __lshift__(self, to_shift): - """Logical shift left the current intervals of @to_shift - @to_shift: ModularInstances or integer - """ - return self._interval_shift('<<', to_shift) - - @_promote - def arithmetic_shift_right(self, to_shift): - """Arithmetic shift right the current intervals of @to_shift - @to_shift: ModularInstances or integer - """ - return self._interval_shift('a>>', to_shift) - - def __neg__(self): - """Negate the current intervals""" - return self._interval_minus() - - def __mod__(self, modulo): - """Apply % @modulo on the current intervals - @modulo: integer - """ - - if not isinstance(modulo, int_types): - raise TypeError("Modulo with %s is not supported" % modulo.__class__) - return self._integer_modulo(modulo) - - @_promote - def 
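
With the wrappers above, instances compose with ordinary Python operators; for example:

    from miasm2.analysis.modularintervals import ModularIntervals

    x = ModularIntervals(8, [(250, 255)])
    print(x + 10)    # addition wraps modulo 2**8: [(4, 9)]
    print(x >> 4)    # logical shift right: [(15, 15)]
    print(x % 7)     # modulo by a plain integer: [(0, 3), (5, 6)]
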
rotation_right(self, to_rotate): - """Right rotate the current intervals of @to_rotate - @to_rotate: ModularInstances or integer - """ - return self._interval_rotate('>>>', to_rotate) - - @_promote - def rotation_left(self, to_rotate): - """Left rotate the current intervals of @to_rotate - @to_rotate: ModularInstances or integer - """ - return self._interval_rotate('<<<', to_rotate) - - # Instance operations - - @property - def mask(self): - """Return the mask corresponding to the instance size""" - return ModularIntervals.size2mask(self.size) - - def __iter__(self): - return iter(self.intervals) - - @property - def length(self): - return self.intervals.length - - def __contains__(self, other): - if isinstance(other, ModularIntervals): - other = other.intervals - return other in self.intervals - - def __str__(self): - return "%s (Size: %s)" % (self.intervals, self.size) - - def size_update(self, new_size): - """Update the instance size to @new_size - The size of elements must be <= @new_size""" - - # Increasing size is always safe - if new_size < self.size: - # Check that current values are indeed included in the new range - assert self.intervals.hull()[1] <= ModularIntervals.size2mask(new_size) - - self.size = new_size - - # For easy chainning - return self - - # Mimic Python's set operations - - @_promote - def union(self, to_union): - """Union set operation with @to_union - @to_union: ModularIntervals instance""" - return ModularIntervals(self.size, self.intervals + to_union.intervals) - - @_promote - def update(self, to_union): - """Union set operation in-place with @to_union - @to_union: ModularIntervals instance""" - self.intervals += to_union.intervals - - @_promote - def intersection(self, to_intersect): - """Intersection set operation with @to_intersect - @to_intersect: ModularIntervals instance""" - return ModularIntervals(self.size, self.intervals & to_intersect.intervals) - - @_promote - def intersection_update(self, to_intersect): - """Intersection set operation in-place with @to_intersect - @to_intersect: ModularIntervals instance""" - self.intervals &= to_intersect.intervals diff --git a/miasm2/analysis/outofssa.py b/miasm2/analysis/outofssa.py deleted file mode 100644 index 41c665af..00000000 --- a/miasm2/analysis/outofssa.py +++ /dev/null @@ -1,413 +0,0 @@ -from future.utils import viewitems, viewvalues - -from miasm2.expression.expression import ExprId -from miasm2.ir.ir import IRBlock, AssignBlock -from miasm2.analysis.ssa import get_phi_sources_parent_block, \ - irblock_has_phi - - -class Varinfo(object): - """Store liveness information for a variable""" - __slots__ = ["live_index", "loc_key", "index"] - - def __init__(self, live_index, loc_key, index): - self.live_index = live_index - self.loc_key = loc_key - self.index = index - - -class UnSSADiGraph(object): - """ - Implements unssa algorithm - Revisiting Out-of-SSA Translation for Correctness, Code Quality, and - Efficiency - """ - - def __init__(self, ssa, head, cfg_liveness): - self.cfg_liveness = cfg_liveness - self.ssa = ssa - self.head = head - - # Set of created variables - self.copy_vars = set() - # Virtual parallel copies - - # On loc_key's Phi node dst -> set((parent, src)) - self.phi_parent_sources = {} - # On loc_key's Phi node, loc_key -> set(Phi dsts) - self.phi_destinations = {} - # Phi's dst -> new var - self.phi_new_var = {} - # For a new_var representing dst: - # new_var -> set(parents of Phi's src in dst = Phi(src,...)) - self.new_var_to_srcs_parents = {} - # new_var -> set(variables to be 
coalesced with, named "merge_set") - self.merge_state = {} - - # Launch the algorithm in several steps - self.isolate_phi_nodes_block() - self.init_phis_merge_state() - self.order_ssa_var_dom() - self.aggressive_coalesce_block() - self.insert_parallel_copy() - self.replace_merge_sets() - self.remove_assign_eq() - - def insert_parallel_copy(self): - """ - Naive Out-of-SSA from CSSA (without coalescing for now) - - Replace Phi - - Create room for parallel copies in Phi's parents - """ - ircfg = self.ssa.graph - - for irblock in list(viewvalues(ircfg.blocks)): - if not irblock_has_phi(irblock): - continue - - # Replace Phi with Phi's dst = new_var - parallel_copies = {} - for dst in self.phi_destinations[irblock.loc_key]: - new_var = self.phi_new_var[dst] - parallel_copies[dst] = new_var - - assignblks = list(irblock) - assignblks[0] = AssignBlock(parallel_copies, irblock[0].instr) - new_irblock = IRBlock(irblock.loc_key, assignblks) - ircfg.blocks[irblock.loc_key] = new_irblock - - # Insert new_var = src in each Phi's parent, at the end of the block - parent_to_parallel_copies = {} - parallel_copies = {} - for dst in irblock[0]: - new_var = self.phi_new_var[dst] - for parent, src in self.phi_parent_sources[dst]: - parent_to_parallel_copies.setdefault(parent, {})[new_var] = src - - for parent, parallel_copies in viewitems(parent_to_parallel_copies): - parent = ircfg.blocks[parent] - assignblks = list(parent) - assignblks.append(AssignBlock(parallel_copies, parent[-1].instr)) - new_irblock = IRBlock(parent.loc_key, assignblks) - ircfg.blocks[parent.loc_key] = new_irblock - - def create_copy_var(self, var): - """ - Generate a new var standing for @var - @var: variable to replace - """ - new_var = ExprId('var%d' % len(self.copy_vars), var.size) - self.copy_vars.add(new_var) - return new_var - - def isolate_phi_nodes_block(self): - """ - Init structures and virtually insert parallel copy before/after each phi - node - """ - ircfg = self.ssa.graph - for irblock in viewvalues(ircfg.blocks): - if not irblock_has_phi(irblock): - continue - for dst, sources in viewitems(irblock[0]): - assert sources.is_op('Phi') - new_var = self.create_copy_var(dst) - self.phi_new_var[dst] = new_var - - var_to_parents = get_phi_sources_parent_block( - self.ssa.graph, - irblock.loc_key, - sources.args - ) - - for src in sources.args: - parents = var_to_parents[src] - self.new_var_to_srcs_parents.setdefault(new_var, set()).update(parents) - for parent in parents: - self.phi_parent_sources.setdefault(dst, set()).add((parent, src)) - - self.phi_destinations[irblock.loc_key] = set(irblock[0]) - - def init_phis_merge_state(self): - """ - Generate trivial coalescing of phi variable and itself - """ - for phi_new_var in viewvalues(self.phi_new_var): - self.merge_state.setdefault(phi_new_var, set([phi_new_var])) - - def order_ssa_var_dom(self): - """Compute dominance order of each ssa variable""" - ircfg = self.ssa.graph - - # compute dominator tree - dominator_tree = ircfg.compute_dominator_tree(self.head) - - # variable -> Varinfo - self.var_to_varinfo = {} - # live_index can later be used to compare dominance of AssignBlocks - live_index = 0 - - # walk in DFS over the dominator tree - for loc_key in dominator_tree.walk_depth_first_forward(self.head): - irblock = ircfg.blocks[loc_key] - - # Create live index for phi new vars - # They do not exist in the graph yet, so index is set to None - if irblock_has_phi(irblock): - for dst in irblock[0]: - if not dst.is_id(): - continue - new_var = self.phi_new_var[dst] - 
self.var_to_varinfo[new_var] = Varinfo(live_index, loc_key, None) - - live_index += 1 - - # Create live index for remaining assignments - for index, assignblk in enumerate(irblock): - used = False - for dst in assignblk: - if not dst.is_id(): - continue - if dst in self.ssa.immutable_ids: - # Will not be considered by the current algo, ignore it - # (for instance, IRDst) - continue - - assert dst not in self.var_to_varinfo - self.var_to_varinfo[dst] = Varinfo(live_index, loc_key, index) - used = True - if used: - live_index += 1 - - - def ssa_def_dominates(self, node_a, node_b): - """ - Return living index order of @node_a and @node_b - @node_a: Varinfo instance - @node_b: Varinfo instance - """ - ret = self.var_to_varinfo[node_a].live_index <= self.var_to_varinfo[node_b].live_index - return ret - - def merge_set_sort(self, merge_set): - """ - Return a sorted list of (live_index, var) from @merge_set in dominance - order - @merge_set: set of coalescing variables - """ - return sorted( - (self.var_to_varinfo[var].live_index, var) - for var in merge_set - ) - - def ssa_def_is_live_at(self, node_a, node_b, parent): - """ - Return True if @node_a is live during @node_b definition - If @parent is None, this is a liveness test for a post phi variable; - Else, it is a liveness test for a variable source of the phi node - - @node_a: Varinfo instance - @node_b: Varinfo instance - @parent: Optional parent location of the phi source - """ - loc_key_b, index_b = self.var_to_varinfo[node_b].loc_key, self.var_to_varinfo[node_b].index - if parent and index_b is None: - index_b = 0 - if node_a not in self.new_var_to_srcs_parents: - # node_a is not a new var (it is a "classic" var) - # -> use a basic liveness test - liveness_b = self.cfg_liveness.blocks[loc_key_b].infos[index_b] - return node_a in liveness_b.var_out - - for def_loc_key in self.new_var_to_srcs_parents[node_a]: - # Consider node_a as defined at the end of its parents blocks - # and compute liveness check accordingly - - if def_loc_key == parent: - # Same path as node_a definition, so SSA ensure b cannot be live - # on this path (otherwise, a Phi would already happen earlier) - continue - liveness_end_block = self.cfg_liveness.blocks[def_loc_key].infos[-1] - if node_b in liveness_end_block.var_out: - return True - return False - - def merge_nodes_interfere(self, node_a, node_b, parent): - """ - Return True if @node_a and @node_b interfere - @node_a: variable - @node_b: variable - @parent: Optional parent location of the phi source for liveness tests - - Interference check is: is x live at y definition (or reverse) - TODO: add Value-based interference improvement - """ - if self.var_to_varinfo[node_a].live_index == self.var_to_varinfo[node_b].live_index: - # Defined in the same AssignBlock -> interfere - return True - - if self.var_to_varinfo[node_a].live_index < self.var_to_varinfo[node_b].live_index: - return self.ssa_def_is_live_at(node_a, node_b, parent) - return self.ssa_def_is_live_at(node_b, node_a, parent) - - def merge_sets_interfere(self, merge_a, merge_b, parent): - """ - Return True if no variable in @merge_a and @merge_b interferes. 
- - Implementation of "Algorithm 2: Check intersection in a set of variables" - - @merge_a: a dom ordered list of equivalent variables - @merge_b: a dom ordered list of equivalent variables - @parent: Optional parent location of the phi source for liveness tests - """ - if merge_a == merge_b: - # No need to consider interference if equal - return False - - merge_a_list = self.merge_set_sort(merge_a) - merge_b_list = self.merge_set_sort(merge_b) - dom = [] - while merge_a_list or merge_b_list: - if not merge_a_list: - _, current = merge_b_list.pop(0) - elif not merge_b_list: - _, current = merge_a_list.pop(0) - else: - # compare live_indexes (standing for dominance) - if merge_a_list[-1] < merge_b_list[-1]: - _, current = merge_a_list.pop(0) - else: - _, current = merge_b_list.pop(0) - while dom and not self.ssa_def_dominates(dom[-1], current): - dom.pop() - - # Don't test node in same merge_set - if ( - # Is stack not empty? - dom and - # Trivial non-interference if dom.top() and current come - # from the same merge set - not (dom[-1] in merge_a and current in merge_a) and - not (dom[-1] in merge_b and current in merge_b) and - # Actually test for interference - self.merge_nodes_interfere(current, dom[-1], parent) - ): - return True - dom.append(current) - return False - - def aggressive_coalesce_parallel_copy(self, parallel_copies, parent): - """ - Try to coalesce variables each dst/src couple together from - @parallel_copies - - @parallel_copies: a dictionary representing dst/src parallel - assignments. - @parent: Optional parent location of the phi source for liveness tests - """ - for dst, src in viewitems(parallel_copies): - dst_merge = self.merge_state.setdefault(dst, set([dst])) - src_merge = self.merge_state.setdefault(src, set([src])) - if not self.merge_sets_interfere(dst_merge, src_merge, parent): - dst_merge.update(src_merge) - for node in dst_merge: - self.merge_state[node] = dst_merge - - def aggressive_coalesce_block(self): - """Try to coalesce phi var with their pre/post variables""" - - ircfg = self.ssa.graph - - # Run coalesce on the post phi parallel copy - for irblock in viewvalues(ircfg.blocks): - if not irblock_has_phi(irblock): - continue - parallel_copies = {} - for dst in self.phi_destinations[irblock.loc_key]: - parallel_copies[dst] = self.phi_new_var[dst] - self.aggressive_coalesce_parallel_copy(parallel_copies, None) - - # Run coalesce on the pre phi parallel copy - - # Stand for the virtual parallel copies at the end of Phi's block - # parents - parent_to_parallel_copies = {} - for dst in irblock[0]: - new_var = self.phi_new_var[dst] - for parent, src in self.phi_parent_sources[dst]: - parent_to_parallel_copies.setdefault(parent, {})[new_var] = src - - for parent, parallel_copies in viewitems(parent_to_parallel_copies): - self.aggressive_coalesce_parallel_copy(parallel_copies, parent) - - def get_best_merge_set_name(self, merge_set): - """ - For a given @merge_set, prefer an original SSA variable instead of a - created copy. In other case, take a random name. 
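
In miasm2, this class is the back end of the SSA round-trip; a typical pipeline looks as follows (a sketch: the liveness helper names are assumed from miasm2.analysis.data_flow, and ir_arch, ircfg, head stand for an existing IR setup):

    from miasm2.analysis.ssa import SSADiGraph
    from miasm2.analysis.data_flow import DiGraphLivenessSSA
    from miasm2.analysis.outofssa import UnSSADiGraph

    ssa = SSADiGraph(ircfg)
    ssa.transform(head)                    # into SSA form
    cfg_liveness = DiGraphLivenessSSA(ssa.graph)
    cfg_liveness.init_var_info(ir_arch)
    cfg_liveness.compute_liveness()
    UnSSADiGraph(ssa, head, cfg_liveness)  # rewrites ssa.graph in place
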
- @merge_set: set of equivalent expressions - """ - if not merge_set: - raise RuntimeError("Merge set should not be empty") - for var in merge_set: - if var not in self.copy_vars: - return var - # Get random name - return var - - - def replace_merge_sets(self): - """ - In the graph, replace all variables from merge state by their - representative variable - """ - replace = {} - merge_sets = set() - - # Elect representative for merge sets - merge_set_to_name = {} - for merge_set in viewvalues(self.merge_state): - frozen_merge_set = frozenset(merge_set) - merge_sets.add(frozen_merge_set) - var_name = self.get_best_merge_set_name(merge_set) - merge_set_to_name[frozen_merge_set] = var_name - - # Generate replacement of variable by their representative - for merge_set in merge_sets: - var_name = merge_set_to_name[merge_set] - merge_set = list(merge_set) - for var in merge_set: - replace[var] = var_name - - self.ssa.graph.simplify(lambda x: x.replace_expr(replace)) - - def remove_phi(self): - """ - Remove phi operators in @ifcfg - @ircfg: IRDiGraph instance - """ - - for irblock in list(viewvalues(self.ssa.graph.blocks)): - assignblks = list(irblock) - out = {} - for dst, src in viewitems(assignblks[0]): - if src.is_op('Phi'): - assert set([dst]) == set(src.args) - continue - out[dst] = src - assignblks[0] = AssignBlock(out, assignblks[0].instr) - self.ssa.graph.blocks[irblock.loc_key] = IRBlock(irblock.loc_key, assignblks) - - def remove_assign_eq(self): - """ - Remove trivial expressions (a=a) in the current graph - """ - for irblock in list(viewvalues(self.ssa.graph.blocks)): - assignblks = list(irblock) - for i, assignblk in enumerate(assignblks): - out = {} - for dst, src in viewitems(assignblk): - if dst == src: - continue - out[dst] = src - assignblks[i] = AssignBlock(out, assignblk.instr) - self.ssa.graph.blocks[irblock.loc_key] = IRBlock(irblock.loc_key, assignblks) diff --git a/miasm2/analysis/sandbox.py b/miasm2/analysis/sandbox.py deleted file mode 100644 index d3e8fce1..00000000 --- a/miasm2/analysis/sandbox.py +++ /dev/null @@ -1,1026 +0,0 @@ -from __future__ import print_function -from builtins import range - -import os -import logging -from argparse import ArgumentParser - -from future.utils import viewitems, viewvalues - -from miasm2.core.utils import force_bytes -from miasm2.analysis.machine import Machine -from miasm2.jitter.csts import PAGE_READ, PAGE_WRITE -from miasm2.analysis import debugging -from miasm2.jitter.jitload import log_func - - - -class Sandbox(object): - - """ - Parent class for Sandbox abstraction - """ - - CALL_FINISH_ADDR = 0x13371acc - - @staticmethod - def code_sentinelle(jitter): - jitter.run = False - return False - - @classmethod - def _classes_(cls): - """ - Iterator on parent classes except Sanbox - """ - for base_cls in cls.__bases__: - # Avoid infinite loop - if base_cls == Sandbox: - continue - - yield base_cls - - classes = property(lambda x: x.__class__._classes_()) - - def __init__(self, fname, options, custom_methods=None, **kwargs): - """ - Initialize a sandbox - @fname: str file name - @options: namespace instance of specific options - @custom_methods: { str => func } for custom API implementations - """ - - # Initialize - self.fname = fname - self.options = options - if custom_methods is None: - custom_methods = {} - for cls in self.classes: - if cls == Sandbox: - continue - if issubclass(cls, OS): - cls.__init__(self, custom_methods, **kwargs) - else: - cls.__init__(self, **kwargs) - - # Logging options - self.jitter.set_trace_log( - 
trace_instr=self.options.singlestep, - trace_regs=self.options.singlestep, - trace_new_blocks=self.options.dumpblocs - ) - - if not self.options.quiet_function_calls: - log_func.setLevel(logging.INFO) - - @classmethod - def parser(cls, *args, **kwargs): - """ - Return instance of instance parser with expecting options. - Extra parameters are passed to parser initialisation. - """ - - parser = ArgumentParser(*args, **kwargs) - parser.add_argument('-a', "--address", - help="Force entry point address", default=None) - parser.add_argument('-b', "--dumpblocs", action="store_true", - help="Log disasm blocks") - parser.add_argument('-z', "--singlestep", action="store_true", - help="Log single step") - parser.add_argument('-d', "--debugging", action="store_true", - help="Debug shell") - parser.add_argument('-g', "--gdbserver", type=int, - help="Listen on port @port") - parser.add_argument("-j", "--jitter", - help="Jitter engine. Possible values are: gcc (default), llvm, python", - default="gcc") - parser.add_argument( - '-q', "--quiet-function-calls", action="store_true", - help="Don't log function calls") - parser.add_argument('-i', "--dependencies", action="store_true", - help="Load PE and its dependencies") - - for base_cls in cls._classes_(): - base_cls.update_parser(parser) - return parser - - def run(self, addr=None): - """ - Launch emulation (gdbserver, debugging, basic JIT). - @addr: (int) start address - """ - if addr is None and self.options.address is not None: - addr = int(self.options.address, 0) - - if any([self.options.debugging, self.options.gdbserver]): - dbg = debugging.Debugguer(self.jitter) - self.dbg = dbg - dbg.init_run(addr) - - if self.options.gdbserver: - port = self.options.gdbserver - print("Listen on port %d" % port) - gdb = self.machine.gdbserver(dbg, port) - self.gdb = gdb - gdb.run() - else: - cmd = debugging.DebugCmd(dbg) - self.cmd = cmd - cmd.cmdloop() - - else: - self.jitter.init_run(addr) - self.jitter.continue_run() - - def call(self, prepare_cb, addr, *args): - """ - Direct call of the function at @addr, with arguments @args prepare in - calling convention implemented by @prepare_cb - @prepare_cb: func(ret_addr, *args) - @addr: address of the target function - @args: arguments - """ - self.jitter.init_run(addr) - self.jitter.add_breakpoint(self.CALL_FINISH_ADDR, self.code_sentinelle) - prepare_cb(self.CALL_FINISH_ADDR, *args) - self.jitter.continue_run() - - - -class OS(object): - - """ - Parent class for OS abstraction - """ - - def __init__(self, custom_methods, **kwargs): - pass - - @classmethod - def update_parser(cls, parser): - pass - - -class Arch(object): - - """ - Parent class for Arch abstraction - """ - - # Architecture name - _ARCH_ = None - - def __init__(self, **kwargs): - self.machine = Machine(self._ARCH_) - self.jitter = self.machine.jitter(self.options.jitter) - - @classmethod - def update_parser(cls, parser): - pass - - -class OS_Win(OS): - # DLL to import - ALL_IMP_DLL = ["ntdll.dll", "kernel32.dll", "user32.dll", - "ole32.dll", "urlmon.dll", - "ws2_32.dll", 'advapi32.dll', "psapi.dll", - ] - modules_path = "win_dll" - - def __init__(self, custom_methods, *args, **kwargs): - from miasm2.jitter.loader.pe import vm_load_pe, vm_load_pe_libs,\ - preload_pe, libimp_pe, vm_load_pe_and_dependencies - from miasm2.os_dep import win_api_x86_32, win_api_x86_32_seh - methods = dict((name.encode(),func) for name, func in viewitems(win_api_x86_32.__dict__)) - methods.update(custom_methods) - - super(OS_Win, self).__init__(methods, *args, **kwargs) - - # 
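
Driving the concrete sandboxes assembled from these mix-ins (defined further down in this file, e.g. Sandbox_Linux_x86_32) usually follows this pattern:

    from miasm2.analysis.sandbox import Sandbox_Linux_x86_32

    parser = Sandbox_Linux_x86_32.parser(description="ELF sandboxer")
    parser.add_argument("filename", help="ELF filename")
    options = parser.parse_args()

    sb = Sandbox_Linux_x86_32(options.filename, options, globals())
    sb.run()    # honours -a/-b/-z/-d/-g, as registered above

The call() helper can likewise drive a single function, given an architecture-specific prepare_cb that pushes the arguments and the sentinel return address (sketch; 0x401000 is a hypothetical function address, push_uint32_t is the x86_32 jitter helper):

    def prepare(ret_addr, arg):
        sb.jitter.push_uint32_t(arg)        # cdecl-style: argument first
        sb.jitter.push_uint32_t(ret_addr)   # then the sentinel return address

    sb.call(prepare, 0x401000, 0x1234)
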
Import manager - libs = libimp_pe() - self.libs = libs - win_api_x86_32.winobjs.runtime_dll = libs - - self.name2module = {} - fname_basename = os.path.basename(self.fname).lower() - - # Load main pe - with open(self.fname, "rb") as fstream: - self.pe = vm_load_pe( - self.jitter.vm, - fstream.read(), - load_hdr=self.options.load_hdr, - name=self.fname, - **kwargs - ) - self.name2module[fname_basename] = self.pe - - # Load library - if self.options.loadbasedll: - - # Load libs in memory - self.name2module.update( - vm_load_pe_libs( - self.jitter.vm, - self.ALL_IMP_DLL, - libs, - self.modules_path, - **kwargs - ) - ) - - # Patch libs imports - for pe in viewvalues(self.name2module): - preload_pe(self.jitter.vm, pe, libs) - - if self.options.dependencies: - vm_load_pe_and_dependencies( - self.jitter.vm, - fname_basename, - self.name2module, - libs, - self.modules_path, - **kwargs - ) - - win_api_x86_32.winobjs.current_pe = self.pe - - # Fix pe imports - preload_pe(self.jitter.vm, self.pe, libs) - - # Library calls handler - self.jitter.add_lib_handler(libs, methods) - - # Manage SEH - if self.options.use_windows_structs: - win_api_x86_32_seh.main_pe_name = fname_basename - win_api_x86_32_seh.main_pe = self.pe - win_api_x86_32.winobjs.hcurmodule = self.pe.NThdr.ImageBase - win_api_x86_32_seh.name2module = self.name2module - win_api_x86_32_seh.set_win_fs_0(self.jitter) - win_api_x86_32_seh.init_seh(self.jitter) - - self.entry_point = self.pe.rva2virt( - self.pe.Opthdr.AddressOfEntryPoint) - - @classmethod - def update_parser(cls, parser): - parser.add_argument('-o', "--load-hdr", action="store_true", - help="Load pe hdr") - parser.add_argument('-y', "--use-windows-structs", action="store_true", - help="Create and use windows structures (peb, ldr, seh, ...)") - parser.add_argument('-l', "--loadbasedll", action="store_true", - help="Load base dll (path './win_dll')") - parser.add_argument('-r', "--parse-resources", - action="store_true", help="Load resources") - - -class OS_Linux(OS): - - PROGRAM_PATH = "./program" - - def __init__(self, custom_methods, *args, **kwargs): - from miasm2.jitter.loader.elf import vm_load_elf, preload_elf, libimp_elf - from miasm2.os_dep import linux_stdlib - methods = linux_stdlib.__dict__ - methods.update(custom_methods) - - super(OS_Linux, self).__init__(methods, *args, **kwargs) - - # Import manager - self.libs = libimp_elf() - - with open(self.fname, "rb") as fstream: - self.elf = vm_load_elf( - self.jitter.vm, - fstream.read(), - name=self.fname, - **kwargs - ) - preload_elf(self.jitter.vm, self.elf, self.libs) - - self.entry_point = self.elf.Ehdr.entry - - # Library calls handler - self.jitter.add_lib_handler(self.libs, methods) - linux_stdlib.ABORT_ADDR = self.CALL_FINISH_ADDR - - # Arguments - self.argv = [self.PROGRAM_PATH] - if self.options.command_line: - self.argv += self.options.command_line - self.envp = self.options.environment_vars - - @classmethod - def update_parser(cls, parser): - parser.add_argument('-c', '--command-line', - action="append", - default=[], - help="Command line arguments") - parser.add_argument('--environment-vars', - action="append", - default=[], - help="Environment variables arguments") - parser.add_argument('--mimic-env', - action="store_true", - help="Mimic the environment of a starting executable") - -class OS_Linux_str(OS): - - PROGRAM_PATH = "./program" - - def __init__(self, custom_methods, *args, **kwargs): - from miasm2.jitter.loader.elf import libimp_elf - from miasm2.os_dep import linux_stdlib - methods = 
linux_stdlib.__dict__ - methods.update(custom_methods) - - super(OS_Linux_str, self).__init__(methods, *args, **kwargs) - - # Import manager - libs = libimp_elf() - self.libs = libs - - data = open(self.fname, "rb").read() - self.options.load_base_addr = int(self.options.load_base_addr, 0) - self.jitter.vm.add_memory_page( - self.options.load_base_addr, PAGE_READ | PAGE_WRITE, data, - "Initial Str" - ) - - # Library calls handler - self.jitter.add_lib_handler(libs, methods) - linux_stdlib.ABORT_ADDR = self.CALL_FINISH_ADDR - - # Arguments - self.argv = [self.PROGRAM_PATH] - if self.options.command_line: - self.argv += self.options.command_line - self.envp = self.options.environment_vars - - @classmethod - def update_parser(cls, parser): - parser.add_argument('-c', '--command-line', - action="append", - default=[], - help="Command line arguments") - parser.add_argument('--environment-vars', - action="append", - default=[], - help="Environment variables arguments") - parser.add_argument('--mimic-env', - action="store_true", - help="Mimic the environment of a starting executable") - parser.add_argument("load_base_addr", help="load base address") - - -class Arch_x86(Arch): - _ARCH_ = None # Arch name - STACK_SIZE = 0x10000 - STACK_BASE = 0x130000 - - def __init__(self, **kwargs): - super(Arch_x86, self).__init__(**kwargs) - - if self.options.usesegm: - self.jitter.ir_arch.do_stk_segm = True - self.jitter.ir_arch.do_ds_segm = True - self.jitter.ir_arch.do_str_segm = True - self.jitter.ir_arch.do_all_segm = True - - # Init stack - self.jitter.stack_size = self.STACK_SIZE - self.jitter.stack_base = self.STACK_BASE - self.jitter.init_stack() - - @classmethod - def update_parser(cls, parser): - parser.add_argument('-s', "--usesegm", action="store_true", - help="Use segments") - - -class Arch_x86_32(Arch_x86): - _ARCH_ = "x86_32" - - -class Arch_x86_64(Arch_x86): - _ARCH_ = "x86_64" - - -class Arch_arml(Arch): - _ARCH_ = "arml" - STACK_SIZE = 0x100000 - STACK_BASE = 0x100000 - - def __init__(self, **kwargs): - super(Arch_arml, self).__init__(**kwargs) - - # Init stack - self.jitter.stack_size = self.STACK_SIZE - self.jitter.stack_base = self.STACK_BASE - self.jitter.init_stack() - - -class Arch_armb(Arch): - _ARCH_ = "armb" - STACK_SIZE = 0x100000 - STACK_BASE = 0x100000 - - def __init__(self, **kwargs): - super(Arch_armb, self).__init__(**kwargs) - - # Init stack - self.jitter.stack_size = self.STACK_SIZE - self.jitter.stack_base = self.STACK_BASE - self.jitter.init_stack() - - -class Arch_armtl(Arch): - _ARCH_ = "armtl" - STACK_SIZE = 0x100000 - STACK_BASE = 0x100000 - - def __init__(self, **kwargs): - super(Arch_armtl, self).__init__(**kwargs) - - # Init stack - self.jitter.stack_size = self.STACK_SIZE - self.jitter.stack_base = self.STACK_BASE - self.jitter.init_stack() - - -class Arch_mips32b(Arch): - _ARCH_ = "mips32b" - STACK_SIZE = 0x100000 - STACK_BASE = 0x100000 - - def __init__(self, **kwargs): - super(Arch_mips32b, self).__init__(**kwargs) - - # Init stack - self.jitter.stack_size = self.STACK_SIZE - self.jitter.stack_base = self.STACK_BASE - self.jitter.init_stack() - - -class Arch_aarch64l(Arch): - _ARCH_ = "aarch64l" - STACK_SIZE = 0x100000 - STACK_BASE = 0x100000 - - def __init__(self, **kwargs): - super(Arch_aarch64l, self).__init__(**kwargs) - - # Init stack - self.jitter.stack_size = self.STACK_SIZE - self.jitter.stack_base = self.STACK_BASE - self.jitter.init_stack() - - -class Arch_aarch64b(Arch): - _ARCH_ = "aarch64b" - STACK_SIZE = 0x100000 - STACK_BASE = 0x100000 - - def 
__init__(self, **kwargs): - super(Arch_aarch64b, self).__init__(**kwargs) - - # Init stack - self.jitter.stack_size = self.STACK_SIZE - self.jitter.stack_base = self.STACK_BASE - self.jitter.init_stack() - -class Arch_ppc(Arch): - _ARCH_ = None - -class Arch_ppc32(Arch): - _ARCH_ = None - -class Arch_ppc32b(Arch_ppc32): - _ARCH_ = "ppc32b" - -class Sandbox_Win_x86_32(Sandbox, Arch_x86_32, OS_Win): - - def __init__(self, *args, **kwargs): - Sandbox.__init__(self, *args, **kwargs) - - # Pre-stack some arguments - self.jitter.push_uint32_t(2) - self.jitter.push_uint32_t(1) - self.jitter.push_uint32_t(0) - self.jitter.push_uint32_t(self.CALL_FINISH_ADDR) - - # Set the runtime guard - self.jitter.add_breakpoint(self.CALL_FINISH_ADDR, self.__class__.code_sentinelle) - - def run(self, addr=None): - """ - If addr is not set, use entrypoint - """ - if addr is None and self.options.address is None: - addr = self.entry_point - super(Sandbox_Win_x86_32, self).run(addr) - - def call(self, addr, *args, **kwargs): - """ - Direct call of the function at @addr, with arguments @args - @addr: address of the target function - @args: arguments - """ - prepare_cb = kwargs.pop('prepare_cb', self.jitter.func_prepare_stdcall) - super(self.__class__, self).call(prepare_cb, addr, *args) - - -class Sandbox_Win_x86_64(Sandbox, Arch_x86_64, OS_Win): - - def __init__(self, *args, **kwargs): - Sandbox.__init__(self, *args, **kwargs) - - # reserve stack for local reg - for _ in range(0x4): - self.jitter.push_uint64_t(0) - - # Pre-stack return address - self.jitter.push_uint64_t(self.CALL_FINISH_ADDR) - - # Set the runtime guard - self.jitter.add_breakpoint( - self.CALL_FINISH_ADDR, - self.__class__.code_sentinelle - ) - - def run(self, addr=None): - """ - If addr is not set, use entrypoint - """ - if addr is None and self.options.address is None: - addr = self.entry_point - super(Sandbox_Win_x86_64, self).run(addr) - - def call(self, addr, *args, **kwargs): - """ - Direct call of the function at @addr, with arguments @args - @addr: address of the target function - @args: arguments - """ - prepare_cb = kwargs.pop('prepare_cb', self.jitter.func_prepare_stdcall) - super(self.__class__, self).call(prepare_cb, addr, *args) - - -class Sandbox_Linux_x86_32(Sandbox, Arch_x86_32, OS_Linux): - - def __init__(self, *args, **kwargs): - Sandbox.__init__(self, *args, **kwargs) - - # Pre-stack some arguments - if self.options.mimic_env: - env_ptrs = [] - for env in self.envp: - env = force_bytes(env) - env += b"\x00" - self.jitter.cpu.ESP -= len(env) - ptr = self.jitter.cpu.ESP - self.jitter.vm.set_mem(ptr, env) - env_ptrs.append(ptr) - argv_ptrs = [] - for arg in self.argv: - arg = force_bytes(arg) - arg += b"\x00" - self.jitter.cpu.ESP -= len(arg) - ptr = self.jitter.cpu.ESP - self.jitter.vm.set_mem(ptr, arg) - argv_ptrs.append(ptr) - - self.jitter.push_uint32_t(self.CALL_FINISH_ADDR) - self.jitter.push_uint32_t(0) - for ptr in reversed(env_ptrs): - self.jitter.push_uint32_t(ptr) - self.jitter.push_uint32_t(0) - for ptr in reversed(argv_ptrs): - self.jitter.push_uint32_t(ptr) - self.jitter.push_uint32_t(len(self.argv)) - else: - self.jitter.push_uint32_t(self.CALL_FINISH_ADDR) - - # Set the runtime guard - self.jitter.add_breakpoint( - self.CALL_FINISH_ADDR, - self.__class__.code_sentinelle - ) - - def run(self, addr=None): - """ - If addr is not set, use entrypoint - """ - if addr is None and self.options.address is None: - addr = self.entry_point - super(Sandbox_Linux_x86_32, self).run(addr) - - def call(self, addr, *args, 
**kwargs): - """ - Direct call of the function at @addr, with arguments @args - @addr: address of the target function - @args: arguments - """ - prepare_cb = kwargs.pop('prepare_cb', self.jitter.func_prepare_systemv) - super(self.__class__, self).call(prepare_cb, addr, *args) - - - -class Sandbox_Linux_x86_64(Sandbox, Arch_x86_64, OS_Linux): - - def __init__(self, *args, **kwargs): - Sandbox.__init__(self, *args, **kwargs) - - # Pre-stack some arguments - if self.options.mimic_env: - env_ptrs = [] - for env in self.envp: - env = force_bytes(env) - env += b"\x00" - self.jitter.cpu.RSP -= len(env) - ptr = self.jitter.cpu.RSP - self.jitter.vm.set_mem(ptr, env) - env_ptrs.append(ptr) - argv_ptrs = [] - for arg in self.argv: - arg = force_bytes(arg) - arg += b"\x00" - self.jitter.cpu.RSP -= len(arg) - ptr = self.jitter.cpu.RSP - self.jitter.vm.set_mem(ptr, arg) - argv_ptrs.append(ptr) - - self.jitter.push_uint64_t(self.CALL_FINISH_ADDR) - self.jitter.push_uint64_t(0) - for ptr in reversed(env_ptrs): - self.jitter.push_uint64_t(ptr) - self.jitter.push_uint64_t(0) - for ptr in reversed(argv_ptrs): - self.jitter.push_uint64_t(ptr) - self.jitter.push_uint64_t(len(self.argv)) - else: - self.jitter.push_uint64_t(self.CALL_FINISH_ADDR) - - # Set the runtime guard - self.jitter.add_breakpoint( - self.CALL_FINISH_ADDR, - self.__class__.code_sentinelle - ) - - def run(self, addr=None): - """ - If addr is not set, use entrypoint - """ - if addr is None and self.options.address is None: - addr = self.entry_point - super(Sandbox_Linux_x86_64, self).run(addr) - - def call(self, addr, *args, **kwargs): - """ - Direct call of the function at @addr, with arguments @args - @addr: address of the target function - @args: arguments - """ - prepare_cb = kwargs.pop('prepare_cb', self.jitter.func_prepare_systemv) - super(self.__class__, self).call(prepare_cb, addr, *args) - - -class Sandbox_Linux_arml(Sandbox, Arch_arml, OS_Linux): - - def __init__(self, *args, **kwargs): - Sandbox.__init__(self, *args, **kwargs) - - # Pre-stack some arguments - if self.options.mimic_env: - env_ptrs = [] - for env in self.envp: - env = force_bytes(env) - env += b"\x00" - self.jitter.cpu.SP -= len(env) - ptr = self.jitter.cpu.SP - self.jitter.vm.set_mem(ptr, env) - env_ptrs.append(ptr) - argv_ptrs = [] - for arg in self.argv: - arg = force_bytes(arg) - arg += b"\x00" - self.jitter.cpu.SP -= len(arg) - ptr = self.jitter.cpu.SP - self.jitter.vm.set_mem(ptr, arg) - argv_ptrs.append(ptr) - - # Round SP to 4 - self.jitter.cpu.SP = self.jitter.cpu.SP & ~ 3 - - self.jitter.push_uint32_t(0) - for ptr in reversed(env_ptrs): - self.jitter.push_uint32_t(ptr) - self.jitter.push_uint32_t(0) - for ptr in reversed(argv_ptrs): - self.jitter.push_uint32_t(ptr) - self.jitter.push_uint32_t(len(self.argv)) - - self.jitter.cpu.LR = self.CALL_FINISH_ADDR - - # Set the runtime guard - self.jitter.add_breakpoint( - self.CALL_FINISH_ADDR, - self.__class__.code_sentinelle - ) - - def run(self, addr=None): - if addr is None and self.options.address is None: - addr = self.entry_point - super(Sandbox_Linux_arml, self).run(addr) - - def call(self, addr, *args, **kwargs): - """ - Direct call of the function at @addr, with arguments @args - @addr: address of the target function - @args: arguments - """ - prepare_cb = kwargs.pop('prepare_cb', self.jitter.func_prepare_systemv) - super(self.__class__, self).call(prepare_cb, addr, *args) - - -class Sandbox_Linux_armtl(Sandbox, Arch_armtl, OS_Linux): - - def __init__(self, *args, **kwargs): - Sandbox.__init__(self, *args, 
**kwargs) - - # Pre-stack some arguments - if self.options.mimic_env: - env_ptrs = [] - for env in self.envp: - env = force_bytes(env) - env += b"\x00" - self.jitter.cpu.SP -= len(env) - ptr = self.jitter.cpu.SP - self.jitter.vm.set_mem(ptr, env) - env_ptrs.append(ptr) - argv_ptrs = [] - for arg in self.argv: - arg = force_bytes(arg) - arg += b"\x00" - self.jitter.cpu.SP -= len(arg) - ptr = self.jitter.cpu.SP - self.jitter.vm.set_mem(ptr, arg) - argv_ptrs.append(ptr) - - # Round SP to 4 - self.jitter.cpu.SP = self.jitter.cpu.SP & ~ 3 - - self.jitter.push_uint32_t(0) - for ptr in reversed(env_ptrs): - self.jitter.push_uint32_t(ptr) - self.jitter.push_uint32_t(0) - for ptr in reversed(argv_ptrs): - self.jitter.push_uint32_t(ptr) - self.jitter.push_uint32_t(len(self.argv)) - - self.jitter.cpu.LR = self.CALL_FINISH_ADDR - - # Set the runtime guard - self.jitter.add_breakpoint( - self.CALL_FINISH_ADDR, - self.__class__.code_sentinelle - ) - - def run(self, addr=None): - if addr is None and self.options.address is None: - addr = self.entry_point - super(Sandbox_Linux_armtl, self).run(addr) - - def call(self, addr, *args, **kwargs): - """ - Direct call of the function at @addr, with arguments @args - @addr: address of the target function - @args: arguments - """ - prepare_cb = kwargs.pop('prepare_cb', self.jitter.func_prepare_systemv) - super(self.__class__, self).call(prepare_cb, addr, *args) - - - -class Sandbox_Linux_mips32b(Sandbox, Arch_mips32b, OS_Linux): - - def __init__(self, *args, **kwargs): - Sandbox.__init__(self, *args, **kwargs) - - # Pre-stack some arguments - if self.options.mimic_env: - env_ptrs = [] - for env in self.envp: - env = force_bytes(env) - env += b"\x00" - self.jitter.cpu.SP -= len(env) - ptr = self.jitter.cpu.SP - self.jitter.vm.set_mem(ptr, env) - env_ptrs.append(ptr) - argv_ptrs = [] - for arg in self.argv: - arg = force_bytes(arg) - arg += b"\x00" - self.jitter.cpu.SP -= len(arg) - ptr = self.jitter.cpu.SP - self.jitter.vm.set_mem(ptr, arg) - argv_ptrs.append(ptr) - - self.jitter.push_uint32_t(0) - for ptr in reversed(env_ptrs): - self.jitter.push_uint32_t(ptr) - self.jitter.push_uint32_t(0) - for ptr in reversed(argv_ptrs): - self.jitter.push_uint32_t(ptr) - self.jitter.push_uint32_t(len(self.argv)) - - self.jitter.cpu.RA = 0x1337beef - - # Set the runtime guard - self.jitter.add_breakpoint( - 0x1337beef, - self.__class__.code_sentinelle - ) - - def run(self, addr=None): - if addr is None and self.options.address is None: - addr = self.entry_point - super(Sandbox_Linux_mips32b, self).run(addr) - - def call(self, addr, *args, **kwargs): - """ - Direct call of the function at @addr, with arguments @args - @addr: address of the target function - @args: arguments - """ - prepare_cb = kwargs.pop('prepare_cb', self.jitter.func_prepare_systemv) - super(self.__class__, self).call(prepare_cb, addr, *args) - - -class Sandbox_Linux_armb_str(Sandbox, Arch_armb, OS_Linux_str): - - def __init__(self, *args, **kwargs): - Sandbox.__init__(self, *args, **kwargs) - - self.jitter.cpu.LR = self.CALL_FINISH_ADDR - - # Set the runtime guard - self.jitter.add_breakpoint(self.CALL_FINISH_ADDR, self.__class__.code_sentinelle) - - def run(self, addr=None): - if addr is None and self.options.address is not None: - addr = int(self.options.address, 0) - super(Sandbox_Linux_armb_str, self).run(addr) - - -class Sandbox_Linux_arml_str(Sandbox, Arch_arml, OS_Linux_str): - - def __init__(self, *args, **kwargs): - Sandbox.__init__(self, *args, **kwargs) - - self.jitter.cpu.LR = 
self.CALL_FINISH_ADDR - - # Set the runtime guard - self.jitter.add_breakpoint(self.CALL_FINISH_ADDR, self.__class__.code_sentinelle) - - def run(self, addr=None): - if addr is None and self.options.address is not None: - addr = int(self.options.address, 0) - super(Sandbox_Linux_arml_str, self).run(addr) - - -class Sandbox_Linux_aarch64l(Sandbox, Arch_aarch64l, OS_Linux): - - def __init__(self, *args, **kwargs): - Sandbox.__init__(self, *args, **kwargs) - - # Pre-stack some arguments - if self.options.mimic_env: - env_ptrs = [] - for env in self.envp: - env = force_bytes(env) - env += b"\x00" - self.jitter.cpu.SP -= len(env) - ptr = self.jitter.cpu.SP - self.jitter.vm.set_mem(ptr, env) - env_ptrs.append(ptr) - argv_ptrs = [] - for arg in self.argv: - arg = force_bytes(arg) - arg += b"\x00" - self.jitter.cpu.SP -= len(arg) - ptr = self.jitter.cpu.SP - self.jitter.vm.set_mem(ptr, arg) - argv_ptrs.append(ptr) - - self.jitter.push_uint64_t(0) - for ptr in reversed(env_ptrs): - self.jitter.push_uint64_t(ptr) - self.jitter.push_uint64_t(0) - for ptr in reversed(argv_ptrs): - self.jitter.push_uint64_t(ptr) - self.jitter.push_uint64_t(len(self.argv)) - - self.jitter.cpu.LR = self.CALL_FINISH_ADDR - - # Set the runtime guard - self.jitter.add_breakpoint( - self.CALL_FINISH_ADDR, - self.__class__.code_sentinelle - ) - - def run(self, addr=None): - if addr is None and self.options.address is None: - addr = self.entry_point - super(Sandbox_Linux_aarch64l, self).run(addr) - - def call(self, addr, *args, **kwargs): - """ - Direct call of the function at @addr, with arguments @args - @addr: address of the target function - @args: arguments - """ - prepare_cb = kwargs.pop('prepare_cb', self.jitter.func_prepare_systemv) - super(self.__class__, self).call(prepare_cb, addr, *args) - -class Sandbox_Linux_ppc32b(Sandbox, Arch_ppc32b, OS_Linux): - - STACK_SIZE = 0x10000 - STACK_BASE = 0xbfce0000 - - # The glue between the kernel and the ELF ABI on Linux/PowerPC is - # implemented in glibc/sysdeps/powerpc/powerpc32/dl-start.S, so we - # have to play the role of ld.so here. - def __init__(self, *args, **kwargs): - super(Sandbox_Linux_ppc32b, self).__init__(*args, **kwargs) - - # Init stack - self.jitter.stack_size = self.STACK_SIZE - self.jitter.stack_base = self.STACK_BASE - self.jitter.init_stack() - self.jitter.cpu.R1 -= 8 - - # Pre-stack some arguments - if self.options.mimic_env: - env_ptrs = [] - for env in self.envp: - env = force_bytes(env) - env += b"\x00" - self.jitter.cpu.R1 -= len(env) - ptr = self.jitter.cpu.R1 - self.jitter.vm.set_mem(ptr, env) - env_ptrs.append(ptr) - argv_ptrs = [] - for arg in self.argv: - arg = force_bytes(arg) - arg += b"\x00" - self.jitter.cpu.R1 -= len(arg) - ptr = self.jitter.cpu.R1 - self.jitter.vm.set_mem(ptr, arg) - argv_ptrs.append(ptr) - - self.jitter.push_uint32_t(0) - for ptr in reversed(env_ptrs): - self.jitter.push_uint32_t(ptr) - self.jitter.cpu.R5 = self.jitter.cpu.R1 # envp - self.jitter.push_uint32_t(0) - for ptr in reversed(argv_ptrs): - self.jitter.push_uint32_t(ptr) - self.jitter.cpu.R4 = self.jitter.cpu.R1 # argv - self.jitter.cpu.R3 = len(self.argv) # argc - self.jitter.push_uint32_t(self.jitter.cpu.R3) - - self.jitter.cpu.R6 = 0 # auxp - self.jitter.cpu.R7 = 0 # termination function - - # From the glibc, we should push a 0 here to distinguish a - # dynamically linked executable from a statically linked one. - # We actually do not do it and attempt to be somehow compatible - # with both types of executables. 
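
All the mimic_env branches above build the same System V process-startup image: the argument and environment strings are copied below the stack pointer, then the envp and argv pointer arrays are pushed, each NULL-terminated, with argc ending up on top of the stack. A pure-Python sketch of the word order (pointer values are illustrative only):

    def startup_words(argc, argv_ptrs, env_ptrs):
        # Words in push order: the last pushed word (argc) ends up at
        # the lowest address, i.e. at the stack pointer.
        words = [0] + list(reversed(env_ptrs))    # envp[] + terminator
        words += [0] + list(reversed(argv_ptrs))  # argv[] + terminator
        words.append(argc)
        return words

    # startup_words(2, [0x1000, 0x1010], [0x1020]) pushes
    # [0, 0x1020, 0, 0x1010, 0x1000, 2]: reading memory upward from SP
    # gives argc, argv[0], argv[1], NULL, envp[0], NULL.
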
- #self.jitter.push_uint32_t(0) - - self.jitter.cpu.LR = self.CALL_FINISH_ADDR - - # Set the runtime guard - self.jitter.add_breakpoint( - self.CALL_FINISH_ADDR, - self.__class__.code_sentinelle - ) - - def run(self, addr=None): - """ - If addr is not set, use entrypoint - """ - if addr is None and self.options.address is None: - addr = self.entry_point - super(Sandbox_Linux_ppc32b, self).run(addr) - - def call(self, addr, *args, **kwargs): - """ - Direct call of the function at @addr, with arguments @args - @addr: address of the target function - @args: arguments - """ - prepare_cb = kwargs.pop('prepare_cb', self.jitter.func_prepare_systemv) - super(self.__class__, self).call(prepare_cb, addr, *args) diff --git a/miasm2/analysis/simplifier.py b/miasm2/analysis/simplifier.py deleted file mode 100644 index 10d5e092..00000000 --- a/miasm2/analysis/simplifier.py +++ /dev/null @@ -1,303 +0,0 @@ -""" -Apply simplification passes to an IR cfg -""" - -import logging -from functools import wraps -from miasm2.analysis.ssa import SSADiGraph -from miasm2.analysis.outofssa import UnSSADiGraph -from miasm2.analysis.data_flow import DiGraphLivenessSSA -from miasm2.expression.simplifications import expr_simp -from miasm2.analysis.data_flow import dead_simp, \ - merge_blocks, remove_empty_assignblks, \ - PropagateExprIntThroughExprId, PropagateThroughExprId, \ - PropagateThroughExprMem, del_unused_edges - - -log = logging.getLogger("simplifier") -console_handler = logging.StreamHandler() -console_handler.setFormatter(logging.Formatter("%(levelname)-5s: %(message)s")) -log.addHandler(console_handler) -log.setLevel(logging.WARNING) - - -def fix_point(func): - @wraps(func) - def ret_func(self, ircfg, head): - log.debug('[%s]: start', func.__name__) - has_been_modified = False - modified = True - while modified: - modified = func(self, ircfg, head) - has_been_modified |= modified - log.debug( - '[%s]: stop %r', - func.__name__, - has_been_modified - ) - return has_been_modified - return ret_func - - -class IRCFGSimplifier(object): - """ - Simplify an IRCFG - This class applies passes until reaching a fix point - """ - - def __init__(self, ir_arch): - self.ir_arch = ir_arch - self.init_passes() - - def init_passes(self): - """ - Init the array of simplification passes - """ - self.passes = [] - - @fix_point - def simplify(self, ircfg, head): - """ - Apply passes until reaching a fix point - Return True if the graph has been modified - - @ircfg: IRCFG instance to simplify - @head: Location instance of the ircfg head - """ - modified = False - for simplify_pass in self.passes: - modified |= simplify_pass(ircfg, head) - return modified - - def __call__(self, ircfg, head): - return self.simplify(ircfg, head) - - -class IRCFGSimplifierCommon(IRCFGSimplifier): - """ - Simplify an IRCFG - This class applies following passes until reaching a fix point: - - simplify_ircfg - - do_dead_simp_ircfg - """ - def __init__(self, ir_arch, expr_simp=expr_simp): - self.expr_simp = expr_simp - super(IRCFGSimplifierCommon, self).__init__(ir_arch) - - def init_passes(self): - self.passes = [ - self.simplify_ircfg, - self.do_dead_simp_ircfg, - ] - - @fix_point - def simplify_ircfg(self, ircfg, _head): - """ - Apply self.expr_simp on the @ircfg until reaching fix point - Return True if the graph has been modified - - @ircfg: IRCFG instance to simplify - """ - modified = ircfg.simplify(self.expr_simp) - return modified - - @fix_point - def do_dead_simp_ircfg(self, ircfg, head): - """ - Apply: - - dead_simp - - remove_empty_assignblks - 
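
The fix_point decorator defined earlier in this file is what turns each pass into an "iterate until stable" pass: a pass only reports whether it changed anything, and the wrapper loops while it does. A toy illustration of that contract (the Demo class is hypothetical):

    class Demo(object):
        def __init__(self):
            self.value = 10

        @fix_point
        def shrink(self, ircfg, head):  # (self, ircfg, head) is imposed
            if self.value > 0:
                self.value -= 3
                return True             # modified: run another round
            return False                # stable: stop iterating

    # Demo().shrink(None, None) runs four rounds and returns True,
    # meaning "the pass modified its input at least once".
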
- merge_blocks - on the @ircfg until reaching fix point - Return True if the graph has been modified - - @ircfg: IRCFG instance to simplify - @head: Location instance of the ircfg head - """ - modified = dead_simp(self.ir_arch, ircfg) - modified |= remove_empty_assignblks(ircfg) - modified |= merge_blocks(ircfg, set([head])) - return modified - - -class IRCFGSimplifierSSA(IRCFGSimplifierCommon): - """ - Simplify an IRCFG. - The IRCFG is first transformed into SSA, then transformation passes - are applied, followed by the out-of-SSA transformation. Final passes of IRCFGSimplifierCommon are applied - - This class applies the following passes until reaching a fix point: - - do_propagate_int - - do_propagate_mem - - do_propagate_expr - - do_dead_simp_ssa - """ - - def __init__(self, ir_arch, expr_simp=expr_simp): - super(IRCFGSimplifierSSA, self).__init__(ir_arch, expr_simp) - - self.ir_arch.ssa_var = {} - self.all_ssa_vars = {} - - self.ssa_forbidden_regs = self.get_forbidden_regs() - - self.propag_int = PropagateExprIntThroughExprId() - self.propag_expr = PropagateThroughExprId() - self.propag_mem = PropagateThroughExprMem() - - def get_forbidden_regs(self): - """ - Return the set of registers kept immutable during the SSA transformation - """ - regs = set( - [ - self.ir_arch.pc, - self.ir_arch.IRDst, - self.ir_arch.arch.regs.exception_flags - ] - ) - return regs - - def init_passes(self): - """ - Init the array of simplification passes - """ - self.passes = [ - self.simplify_ssa, - self.do_propagate_int, - self.do_propagate_mem, - self.do_propagate_expr, - self.do_dead_simp_ssa, - ] - - def ircfg_to_ssa(self, ircfg, head): - """ - Apply the SSA transformation to @ircfg using its @head - - @ircfg: IRCFG instance to simplify - @head: Location instance of the ircfg head - """ - ssa = SSADiGraph(ircfg) - ssa.immutable_ids.update(self.ssa_forbidden_regs) - ssa.ssa_variable_to_expr.update(self.all_ssa_vars) - ssa.transform(head) - self.all_ssa_vars.update(ssa.ssa_variable_to_expr) - self.ir_arch.ssa_var.update(ssa.ssa_variable_to_expr) - return ssa - - def ssa_to_unssa(self, ssa, head): - """ - Apply the out-of-ssa transformation to @ssa using its @head - - @ssa: SSADiGraph instance - @head: Location instance of the graph head - """ - cfg_liveness = DiGraphLivenessSSA(ssa.graph) - cfg_liveness.init_var_info(self.ir_arch) - cfg_liveness.compute_liveness() - - UnSSADiGraph(ssa, head, cfg_liveness) - return ssa.graph - - @fix_point - def simplify_ssa(self, ssa, _head): - """ - Apply self.expr_simp on the @ssa.graph until reaching fix point - Return True if the graph has been modified - - @ssa: SSADiGraph instance - """ - modified = ssa.graph.simplify(self.expr_simp) - return modified - - @fix_point - def do_propagate_int(self, ssa, head): - """ - Constant propagation in the @ssa graph - @head: Location instance of the graph head - """ - modified = self.propag_int.propagate(ssa, head) - modified |= ssa.graph.simplify(self.expr_simp) - modified |= del_unused_edges(ssa.graph, set([head])) - return modified - - @fix_point - def do_propagate_mem(self, ssa, head): - """ - Propagation of memory expressions in the @ssa graph - @head: Location instance of the graph head - """ - modified = self.propag_mem.propagate(ssa, head) - modified |= ssa.graph.simplify(self.expr_simp) - modified |= del_unused_edges(ssa.graph, set([head])) - return modified - - @fix_point - def do_propagate_expr(self, ssa, head): - """ - Expression propagation through ExprId in the @ssa graph - @head: Location instance of the graph head - """ - modified =
self.propag_expr.propagate(ssa, head) - modified |= ssa.graph.simplify(self.expr_simp) - modified |= del_unused_edges(ssa.graph, set([head])) - return modified - - @fix_point - def do_dead_simp_ssa(self, ssa, head): - """ - Apply: - - dead_simp - - remove_empty_assignblks - - del_unused_edges - - merge_blocks - on @ssa.graph until reaching fix point - Return True if the graph has been modified - - @ssa: SSADiGraph instance to simplify - @head: Location instance of the ircfg head - """ - modified = dead_simp(self.ir_arch, ssa.graph) - modified |= remove_empty_assignblks(ssa.graph) - modified |= del_unused_edges(ssa.graph, set([head])) - modified |= merge_blocks(ssa.graph, set([head])) - return modified - - def do_simplify(self, ssa, head): - """ - Apply passes until reaching a fix point - Return True if the graph has been modified - """ - return super(IRCFGSimplifierSSA, self).simplify(ssa, head) - - def do_simplify_loop(self, ssa, head): - """ - Apply do_simplify until reaching a fix point - SSA is updated between each do_simplify - Return the updated SSADiGraph instance - """ - modified = True - while modified: - modified = self.do_simplify(ssa, head) - # Update ssa structs - ssa = self.ircfg_to_ssa(ssa.graph, head) - return ssa - - def simplify(self, ircfg, head): - """ - Apply SSA transformation to @ircfg - Apply passes until reaching a fix point - Apply out-of-ssa transformation - Apply post simplification passes - - Update and return the simplified IRCFG instance - - @ircfg: IRCFG instance to simplify - @head: Location instance of the ircfg head - """ - ssa = self.ircfg_to_ssa(ircfg, head) - ssa = self.do_simplify_loop(ssa, head) - ircfg = self.ssa_to_unssa(ssa, head) - ircfg_simplifier = IRCFGSimplifierCommon(self.ir_arch) - ircfg_simplifier.simplify(ircfg, head) - return ircfg diff --git a/miasm2/analysis/ssa.py b/miasm2/analysis/ssa.py deleted file mode 100644 index 54d17dc1..00000000 --- a/miasm2/analysis/ssa.py +++ /dev/null @@ -1,1118 +0,0 @@ -from collections import deque -from future.utils import viewitems, viewvalues -  -from miasm2.expression.expression import ExprId, ExprAssign, ExprOp, \ - ExprLoc, get_expr_ids -from miasm2.ir.ir import AssignBlock, IRBlock - - -def sanitize_graph_head(ircfg, head): - """ - In several algorithms, the @head of the ircfg must not have predecessors.
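
The simplify pipeline above (ircfg_to_ssa, do_simplify_loop, ssa_to_unssa, then the common passes) is typically driven from a disassembled function; a hedged sketch, assuming a supported container and architecture, with "target.bin" as a placeholder:

    from miasm2.analysis.binary import Container
    from miasm2.analysis.machine import Machine
    from miasm2.analysis.simplifier import IRCFGSimplifierSSA

    with open("target.bin", "rb") as fdesc:
        cont = Container.from_string(fdesc.read())
    machine = Machine(cont.arch)
    mdis = machine.dis_engine(cont.bin_stream, loc_db=cont.loc_db)
    asmcfg = mdis.dis_multiblock(cont.entry_point)

    ir_arch = machine.ira(mdis.loc_db)
    ircfg = ir_arch.new_ircfg_from_asmcfg(asmcfg)
    head = mdis.loc_db.get_offset_location(cont.entry_point)
    ircfg = IRCFGSimplifierSSA(ir_arch)(ircfg, head)
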
- The function transforms the @ircfg in order to ensure this property - @ircfg: IRCFG instance - @head: the location of the graph's head - """ - - if not ircfg.predecessors(head): - return - original_edges = ircfg.predecessors(head) - sub_head = ircfg.loc_db.add_location() - - # Duplicate graph, replacing references to head by sub_head - replaced_expr = { - ExprLoc(head, ircfg.IRDst.size): - ExprLoc(sub_head, ircfg.IRDst.size) - } - ircfg.simplify( - lambda expr:expr.replace_expr(replaced_expr) - ) - # Duplicate head block - ircfg.add_irblock(IRBlock(sub_head, list(ircfg.blocks[head]))) - - # Remove original head block - ircfg.del_node(head) - - for src in original_edges: - ircfg.add_edge(src, sub_head) - - # Create new head, jumping to sub_head - assignblk = AssignBlock({ircfg.IRDst:ExprLoc(sub_head, ircfg.IRDst.size)}) - new_irblock = IRBlock(head, [assignblk]) - ircfg.add_irblock(new_irblock) - - -class SSA(object): - """ - Generic class for static single assignment (SSA) transformation - - Handling of - - variable generation - - variable renaming - - conversion of an IRCFG block into SSA - - Variables will be renamed to <variable>.<index>, whereby the - index will be increased in every definition of <variable>. - - Memory expressions are stateless. The addresses are in SSA form, - but memory aliasing will occur. For instance, if it holds - that RAX == RBX.0 + (-0x8) and - - @64[RBX.0 + (-0x8)] = RDX - RCX.0 = @64[RAX], - - then it cannot be tracked that RCX.0 == RDX. - """ - - - def __init__(self, ircfg): - """ - Initialises generic class for SSA - :param ircfg: instance of IRCFG - """ - # IRCFG instance - self.ircfg = ircfg - - # SSA blocks - self.blocks = {} - - # stack for RHS - self._stack_rhs = {} - # stack for LHS - self._stack_lhs = {} - - self.ssa_variable_to_expr = {} - - # dict of SSA expressions - self.expressions = {} - - # dict of SSA to original location - self.ssa_to_location = {} - - # Don't SSA IRDst - self.immutable_ids = set([self.ircfg.IRDst]) - - def get_regs(self, expr): - return get_expr_ids(expr) - - def transform(self, *args, **kwargs): - """Transforms into SSA""" - raise NotImplementedError("Abstract method") - - def get_block(self, loc_key): - """ - Returns an IRBlock - :param loc_key: LocKey instance - :return: IRBlock - """ - irblock = self.ircfg.blocks.get(loc_key, None) - - return irblock - - def reverse_variable(self, ssa_var): - """ - Transforms a variable in SSA form into non-SSA form - :param ssa_var: ExprId, variable in SSA form - :return: ExprId, variable in non-SSA form - """ - expr = self.ssa_variable_to_expr.get(ssa_var, ssa_var) - return expr - - def reset(self): - """Resets SSA transformation""" - self.blocks = {} - self.expressions = {} - self._stack_rhs = {} - self._stack_lhs = {} - self.ssa_to_location = {} - - def _gen_var_expr(self, expr, stack): - """ - Generates a variable expression in SSA form - :param expr: variable expression which will be translated - :param stack: self._stack_rhs or self._stack_lhs - :return: variable expression in SSA form - """ - index = stack[expr] - name = "%s.%d" % (expr.name, index) - ssa_var = ExprId(name, expr.size) - self.ssa_variable_to_expr[ssa_var] = expr - - return ssa_var - - def _transform_var_rhs(self, ssa_var): - """ - Transforms a variable on the right hand side into SSA - :param ssa_var: variable - :return: transformed variable - """ - # variable has never been on the LHS - if ssa_var not in self._stack_rhs: - return ssa_var - # variable has been on the LHS - stack = self._stack_rhs - return
self._gen_var_expr(ssa_var, stack) - - def _transform_var_lhs(self, expr): - """ - Transforms a variable on the left hand side into SSA - :param expr: variable - :return: transformed variable - """ - # check if variable has already been on the LHS - if expr not in self._stack_lhs: - self._stack_lhs[expr] = 0 - # save last value for RHS transformation - self._stack_rhs[expr] = self._stack_lhs[expr] - - # generate SSA expression - stack = self._stack_lhs - ssa_var = self._gen_var_expr(expr, stack) - - return ssa_var - - def _transform_expression_lhs(self, dst): - """ - Transforms an expression on the left hand side into SSA - :param dst: expression - :return: expression in SSA form - """ - if dst.is_mem(): - # transform with last RHS instance - ssa_var = self._transform_expression_rhs(dst) - else: - # transform LHS - ssa_var = self._transform_var_lhs(dst) - - # increase SSA variable counter - self._stack_lhs[dst] += 1 - - return ssa_var - - def _transform_expression_rhs(self, src): - """ - Transforms an expression on the right hand side into SSA - :param src: expression - :return: expression in SSA form - """ - # dissect the expression into variables - variables = self.get_regs(src) - src_ssa = src - # transform variables - for expr in variables: - ssa_var = self._transform_var_rhs(expr) - src_ssa = src_ssa.replace_expr({expr: ssa_var}) - - return src_ssa - - @staticmethod - def _parallel_instructions(assignblk): - """ - Extracts the instructions from an AssignBlock. - - Since instructions in an AssignBlock are evaluated - in parallel, memory instructions on the left hand - side will be inserted into the start of the list. - Thus, memory instructions on the LHS will be - transformed first. - - :param assignblk: assignblock - :return: sorted list of expressions - """ - instructions = [] - for dst in assignblk: - # dst = src - aff = assignblk.dst2ExprAssign(dst) - # insert memory expression into start of list - if dst.is_mem(): - instructions.insert(0, aff) - else: - instructions.append(aff) - - return instructions - - @staticmethod - def _convert_block(irblock, ssa_list): - """ - Transforms an IRBlock inplace into SSA - :param irblock: IRBlock to be transformed - :param ssa_list: list of SSA expressions - """ - # iterator over SSA expressions - ssa_iter = iter(ssa_list) - new_irs = [] - # walk over IR blocks' assignblocks - for assignblk in irblock.assignblks: - # list of instructions - instructions = [] - # insert SSA instructions - for _ in assignblk: - instructions.append(next(ssa_iter)) - # replace instructions of assignblock in IRBlock - new_irs.append(AssignBlock(instructions, assignblk.instr)) - return IRBlock(irblock.loc_key, new_irs) - - def _rename_expressions(self, loc_key): - """ - Transforms variables and expressions - of an IRBlock into SSA. - - IR representations of an assembly instruction are evaluated - in parallel. Thus, RHS and LHS instructions will be performed - separately.
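
The memory-first ordering done by _parallel_instructions matters because both statements of a parallel AssignBlock read the pre-assignment state. A small illustration (plain miasm2 expressions, register names arbitrary):

    from miasm2.expression.expression import ExprAssign, ExprId, ExprMem
    from miasm2.ir.ir import AssignBlock

    RAX = ExprId("RAX", 64)
    RBX = ExprId("RBX", 64)
    blk = AssignBlock([ExprAssign(ExprMem(RAX, 64), RBX),
                       ExprAssign(RAX, RBX)])
    # _parallel_instructions(blk) yields the memory store first, so the
    # address @64[RAX] is renamed with the old RAX.<n>, not the RAX.<n+1>
    # produced by the parallel assignment to RAX.
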
- :param loc_key: IRBlock loc_key - """ - # list of IRBlock's SSA expressions - ssa_expressions_block = [] - - # retrieve IRBlock - irblock = self.get_block(loc_key) - if irblock is None: - # Incomplete graph - return - - # iterate block's IR expressions - for index, assignblk in enumerate(irblock.assignblks): - # list of parallel instructions - instructions = self._parallel_instructions(assignblk) - # list for transformed RHS expressions - rhs = deque() - - # transform RHS - for expr in instructions: - src = expr.src - src_ssa = self._transform_expression_rhs(src) - # save transformed RHS - rhs.append(src_ssa) - - # transform LHS - for expr in instructions: - if expr.dst in self.immutable_ids or expr.dst in self.ssa_variable_to_expr: - dst_ssa = expr.dst - else: - dst_ssa = self._transform_expression_lhs(expr.dst) - - # retrieve corresponding RHS expression - src_ssa = rhs.popleft() - - # rebuild SSA expression - expr = ExprAssign(dst_ssa, src_ssa) - self.expressions[dst_ssa] = src_ssa - self.ssa_to_location[dst_ssa] = (loc_key, index) - - - # append ssa expression to list - ssa_expressions_block.append(expr) - - # replace blocks IR expressions with corresponding SSA transformations - new_irblock = self._convert_block(irblock, ssa_expressions_block) - self.ircfg.blocks[loc_key] = new_irblock - - -class SSABlock(SSA): - """ - SSA transformation on block level - - It handles - - transformation of a single IRBlock into SSA - - reassembling an SSA expression into a non-SSA - expression through iterative resolving of the RHS - """ - - def transform(self, loc_key): - """ - Transforms a block into SSA form - :param loc_key: IRBlock loc_key - """ - self._rename_expressions(loc_key) - - def reassemble_expr(self, expr): - """ - Reassembles an expression in SSA form into a solely non-SSA expression - :param expr: expression - :return: non-SSA expression - """ - # worklist - todo = {expr.copy()} - - while todo: - # current expression - cur = todo.pop() - # RHS of current expression - cur_rhs = self.expressions[cur] - - # replace cur with RHS in expr - expr = expr.replace_expr({cur: cur_rhs}) - - # parse ExprIDs on RHS - ids_rhs = self.get_regs(cur_rhs) - - # add RHS ids to worklist - for id_rhs in ids_rhs: - if id_rhs in self.expressions: - todo.add(id_rhs) - return expr - - -class SSAPath(SSABlock): - """ - SSA transformation on path level - - It handles - - transformation of a path of IRBlocks into SSA - """ - - def transform(self, path): - """ - Transforms a path into SSA - :param path: list of IRBlock loc_key - """ - for block in path: - self._rename_expressions(block) - - -class SSADiGraph(SSA): - """ - SSA transformation on DiGraph level - - It handles - - transformation of a DiGraph into SSA - - generation, insertion and filling of phi nodes - - The implemented SSA form is known as minimal SSA. 
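
The "minimal" property comes from placing phis only on the dominance frontier of each definition. The criterion is easy to reproduce on a toy diamond graph with miasm2's generic DiGraph (node names arbitrary):

    from miasm2.core.graph import DiGraph

    g = DiGraph()
    for src, dst in ((0, 1), (0, 2), (1, 3), (2, 3)):
        g.add_edge(src, dst)
    frontier = g.compute_dominance_frontier(0)
    # frontier[1] == frontier[2] == set([3]): a variable assigned in
    # both branches (1 and 2) gets exactly one phi node, at the join
    # block 3, which is what _place_phi below implements.
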
- """ - - PHI_STR = 'Phi' - - - def __init__(self, ircfg): - """ - Initialises SSA class for directed graphs - :param ircfg: instance of IRCFG - """ - super(SSADiGraph, self).__init__(ircfg) - - # variable definitions - self.defs = {} - - # dict of blocks' phi nodes - self._phinodes = {} - - # IRCFG control flow graph - self.graph = ircfg - - - def transform(self, head): - """Transforms into SSA""" - sanitize_graph_head(self.graph, head) - self._init_variable_defs(head) - self._place_phi(head) - self._rename(head) - self._insert_phi() - self._convert_phi() - self._fix_no_def_var(head) - - def reset(self): - """Resets SSA transformation""" - super(SSADiGraph, self).reset() - self.defs = {} - self._phinodes = {} - - def _init_variable_defs(self, head): - """ - Initialises all variable definitions and - assigns the corresponding IRBlocks. - - All variable definitions in self.defs contain - a set of IRBlocks in which the variable gets assigned - """ - - for loc_key in self.graph.walk_depth_first_forward(head): - irblock = self.get_block(loc_key) - if irblock is None: - # Incomplete graph - continue - - # search for block's IR definitions/destinations - for assignblk in irblock.assignblks: - for dst in assignblk: - # enforce ExprId - if dst.is_id(): - # exclude immutable ids - if dst in self.immutable_ids or dst in self.ssa_variable_to_expr: - continue - # map variable definition to blocks - self.defs.setdefault(dst, set()).add(irblock.loc_key) - - def _place_phi(self, head): - """ - For all blocks, empty phi functions will be placed for every - variable in the block's dominance frontier. - - self.phinodes contains a dict for every block in the - dominance frontier. In this dict, each variable - definition maps to its corresponding phi function. - - Source: Cytron, Ron, et al. - "An efficient method of computing static single assignment form" - Proceedings of the 16th ACM SIGPLAN-SIGACT symposium on - Principles of programming languages (1989), p. 30 - """ - # dominance frontier - frontier = self.graph.compute_dominance_frontier(head) - - for variable in self.defs: - done = set() - todo = set() - intodo = set() - - for loc_key in self.defs[variable]: - todo.add(loc_key) - intodo.add(loc_key) - - while todo: - loc_key = todo.pop() - - # walk through block's dominance frontier - for node in frontier.get(loc_key, []): - if node in done: - continue - # place empty phi functions for a variable - empty_phi = self._gen_empty_phi(variable) - - # add empty phi node for variable in node - self._phinodes.setdefault(node, {})[variable] = empty_phi.src - done.add(node) - - if node not in intodo: - intodo.add(node) - todo.add(node) - - def _gen_empty_phi(self, expr): - """ - Generates an empty phi function for a variable - :param expr: variable - :return: ExprAssign, empty phi function for expr - """ - phi = ExprId(self.PHI_STR, expr.size) - return ExprAssign(expr, phi) - - def _fill_phi(self, *args): - """ - Fills a phi function with variables. - - phi(x.1, x.5, x.6) - - :param args: list of ExprId - :return: ExprOp - """ - return ExprOp(self.PHI_STR, *set(args)) - - def _rename(self, head): - """ - Transforms each variable expression in the CFG into SSA - by traversing the dominator tree in depth-first search. - - 1. Transform variables of phi functions on LHS into SSA - 2. Transform all non-phi expressions into SSA - 3. Update the successor's phi functions' RHS with current SSA variables - 4. Save current SSA variable stack for successors in the dominator tree - - Source: Cytron, Ron, et al. 
- "An efficient method of computing static single assignment form" - Proceedings of the 16th ACM SIGPLAN-SIGACT symposium on - Principles of programming languages (1989), p. 31 - """ - # compute dominator tree - dominator_tree = self.graph.compute_dominator_tree(head) - - # init SSA variable stack - stack = [self._stack_rhs] - - # walk in DFS over the dominator tree - for loc_key in dominator_tree.walk_depth_first_forward(head): - # restore SSA variable stack of the predecessor in the dominator tree - self._stack_rhs = stack.pop().copy() - - # Transform variables of phi functions on LHS into SSA - self._rename_phi_lhs(loc_key) - - # Transform all non-phi expressions into SSA - self._rename_expressions(loc_key) - - # Update the successor's phi functions' RHS with current SSA variables - # walk over block's successors in the CFG - for successor in self.graph.successors_iter(loc_key): - self._rename_phi_rhs(successor) - - # Save current SSA variable stack for successors in the dominator tree - for _ in dominator_tree.successors_iter(loc_key): - stack.append(self._stack_rhs) - - def _rename_phi_lhs(self, loc_key): - """ - Transforms phi function's expressions of an IRBlock - on the left hand side into SSA - :param loc_key: IRBlock loc_key - """ - if loc_key in self._phinodes: - # create temporary list of phi function assignments for inplace renaming - tmp = list(self._phinodes[loc_key]) - - # iterate over all block's phi nodes - for dst in tmp: - # transform variables on LHS inplace - self._phinodes[loc_key][self._transform_expression_lhs(dst)] = self._phinodes[loc_key].pop(dst) - - def _rename_phi_rhs(self, successor): - """ - Transforms the right hand side of each successor's phi function - into SSA. Each transformed expression of a phi function's - right hand side is of the form - - phi(<var>.<index 1>, <var>.<index 2>, ..., <var>.<index n>) - - :param successor: loc_key of block's direct successor in the CFG - """ - # if successor is in block's dominance frontier - if successor in self._phinodes: - # walk over all variables on LHS - for dst, src in list(viewitems(self._phinodes[successor])): - # transform variable on RHS in non-SSA form - expr = self.reverse_variable(dst) - # transform expr into it's SSA form using current stack - src_ssa = self._transform_expression_rhs(expr) - - # Add src_ssa to phi args - if src.is_id(self.PHI_STR): - # phi function is empty - expr = self._fill_phi(src_ssa) - else: - # phi function contains at least one value - expr = self._fill_phi(src_ssa, *src.args) - - # update phi function - self._phinodes[successor][dst] = expr - - def _insert_phi(self): - """Inserts phi functions into the list of SSA expressions""" - for loc_key in self._phinodes: - for dst in self._phinodes[loc_key]: - self.expressions[dst] = self._phinodes[loc_key][dst] - - def _convert_phi(self): - """Inserts corresponding phi functions inplace - into IRBlock at the beginning""" - for loc_key in self._phinodes: - irblock = self.get_block(loc_key) - if irblock is None: - continue - assignblk = AssignBlock(self._phinodes[loc_key]) - # insert at the beginning - new_irs = IRBlock(loc_key, [assignblk] + list(irblock.assignblks)) - self.ircfg.blocks[loc_key] = new_irs - - def _fix_no_def_var(self, head): - """ - Replace phi source variables which are not ssa vars by ssa vars. 
- @head: loc_key of the graph head - """ - var_to_insert = set() - for loc_key in self._phinodes: - for dst, sources in viewitems(self._phinodes[loc_key]): - for src in sources.args: - if src in self.ssa_variable_to_expr: - continue - var_to_insert.add(src) - var_to_newname = {} - newname_to_var = {} - for var in var_to_insert: - new_var = self._transform_var_lhs(var) - var_to_newname[var] = new_var - newname_to_var[new_var] = var - - # Replace non-modified nodes used in phi with the new variables - self.ircfg.simplify(lambda expr:expr.replace_expr(var_to_newname)) - - if newname_to_var: - irblock = self.ircfg.blocks[head] - assignblks = list(irblock) - assignblks[0:0] = [AssignBlock(newname_to_var, assignblks[0].instr)] - self.ircfg.blocks[head] = IRBlock(head, assignblks) - - # Update structures - for loc_key in self._phinodes: - for dst, sources in viewitems(self._phinodes[loc_key]): - self._phinodes[loc_key][dst] = sources.replace_expr(var_to_newname) - - for var, (loc_key, index) in list(viewitems(self.ssa_to_location)): - if loc_key == head: - self.ssa_to_location[var] = loc_key, index + 1 - - for newname, var in viewitems(newname_to_var): - self.ssa_to_location[newname] = head, 0 - self.ssa_variable_to_expr[newname] = var - self.expressions[newname] = var - - -def irblock_has_phi(irblock): - """ - Return True if @irblock has Phi assignments - @irblock: IRBlock instance - """ - if not irblock.assignblks: - return False - for src in viewvalues(irblock[0]): - return src.is_op('Phi') - return False - - -class Varinfo(object): - """Store liveness information for a variable""" - __slots__ = ["live_index", "loc_key", "index"] - - def __init__(self, live_index, loc_key, index): - self.live_index = live_index - self.loc_key = loc_key - self.index = index - - -def get_var_assignment_src(ircfg, node, variables): - """ - Return the variable of @variables which is written by the irblock at @node - @node: Location - @variables: a set of variables to test - """ - irblock = ircfg.blocks[node] - for assignblk in irblock: - result = set(assignblk).intersection(variables) - if not result: - continue - assert len(result) == 1 - return list(result)[0] - return None - - -def get_phi_sources_parent_block(ircfg, loc_key, sources): - """ - Return a dictionary linking a variable to its direct parent labels - which belong to a path affecting the node.
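
In other words, for every predecessor edge of the phi block, the function walks backwards until it meets the block defining one of the phi sources. A toy, self-contained version of that backward walk (block names and defs are hypothetical):

    defs = {"b1": "x.1", "b2": "x.2"}            # block -> var defined
    preds = {"join": ["b1", "b2"], "b1": ["head"], "b2": ["head"]}

    def source_for(parent, sources):
        todo, done = set([parent]), set()
        while todo:
            node = todo.pop()
            if node in done:
                continue
            done.add(node)
            if defs.get(node) in sources:
                return defs[node]
            todo.update(preds.get(node, []))

    assert source_for("b1", set(["x.1", "x.2"])) == "x.1"
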
- @loc_key: the starting node - @sources: set of variables to resolve - """ - source_to_parent = {} - for parent in ircfg.predecessors(loc_key): - done = set() - todo = set([parent]) - found = False - while todo: - node = todo.pop() - if node in done: - continue - done.add(node) - ret = get_var_assignment_src(ircfg, node, sources) - if ret: - source_to_parent.setdefault(ret, set()).add(parent) - found = True - break - for pred in ircfg.predecessors(node): - todo.add(pred) - assert found - return source_to_parent - - -class UnSSADiGraph(object): - """ - Implements unssa algorithm - Revisiting Out-of-SSA Translation for Correctness, Code Quality, and - Efficiency - """ - - def __init__(self, ssa, head, cfg_liveness): - self.cfg_liveness = cfg_liveness - self.ssa = ssa - self.head = head - - # Set of created variables - self.copy_vars = set() - # Virtual parallel copies - - # On loc_key's Phi node dst -> set((parent, src)) - self.phi_parent_sources = {} - # On loc_key's Phi node, loc_key -> set(Phi dsts) - self.phi_destinations = {} - # Phi's dst -> new var - self.phi_new_var = {} - # For a new_var representing dst: - # new_var -> set(parents of Phi's src in dst = Phi(src,...)) - self.new_var_to_srcs_parents = {} - # new_var -> set(variables to be coalesced with, named "merge_set") - self.merge_state = {} - - # Launch the algorithm in several steps - self.isolate_phi_nodes_block() - self.init_phis_merge_state() - self.order_ssa_var_dom() - self.aggressive_coalesce_block() - self.insert_parallel_copy() - self.replace_merge_sets() - self.remove_assign_eq() - - def insert_parallel_copy(self): - """ - Naive Out-of-SSA from CSSA (without coalescing for now) - - Replace Phi - - Create room for parallel copies in Phi's parents - """ - ircfg = self.ssa.graph - - for irblock in list(viewvalues(ircfg.blocks)): - if not irblock_has_phi(irblock): - continue - - # Replace Phi with Phi's dst = new_var - parallel_copies = {} - for dst in self.phi_destinations[irblock.loc_key]: - new_var = self.phi_new_var[dst] - parallel_copies[dst] = new_var - - assignblks = list(irblock) - assignblks[0] = AssignBlock(parallel_copies, irblock[0].instr) - new_irblock = IRBlock(irblock.loc_key, assignblks) - ircfg.blocks[irblock.loc_key] = new_irblock - - # Insert new_var = src in each Phi's parent, at the end of the block - parent_to_parallel_copies = {} - parallel_copies = {} - for dst in irblock[0]: - new_var = self.phi_new_var[dst] - for parent, src in self.phi_parent_sources[dst]: - parent_to_parallel_copies.setdefault(parent, {})[new_var] = src - - for parent, parallel_copies in viewitems(parent_to_parallel_copies): - parent = ircfg.blocks[parent] - assignblks = list(parent) - assignblks.append(AssignBlock(parallel_copies, parent[-1].instr)) - new_irblock = IRBlock(parent.loc_key, assignblks) - ircfg.blocks[parent.loc_key] = new_irblock - - def create_copy_var(self, var): - """ - Generate a new var standing for @var - @var: variable to replace - """ - new_var = ExprId('var%d' % len(self.copy_vars), var.size) - self.copy_vars.add(new_var) - return new_var - - def isolate_phi_nodes_block(self): - """ - Init structures and virtually insert parallel copy before/after each phi - node - """ - ircfg = self.ssa.graph - for irblock in viewvalues(ircfg.blocks): - if not irblock_has_phi(irblock): - continue - for dst, sources in viewitems(irblock[0]): - assert sources.is_op('Phi') - new_var = self.create_copy_var(dst) - self.phi_new_var[dst] = new_var - - var_to_parents = get_phi_sources_parent_block( - self.ssa.graph, - 
irblock.loc_key, - sources.args - ) - - for src in sources.args: - parents = var_to_parents[src] - self.new_var_to_srcs_parents.setdefault(new_var, set()).update(parents) - for parent in parents: - self.phi_parent_sources.setdefault(dst, set()).add((parent, src)) - - self.phi_destinations[irblock.loc_key] = set(irblock[0]) - - def init_phis_merge_state(self): - """ - Generate trivial coalescing of phi variable and itself - """ - for phi_new_var in viewvalues(self.phi_new_var): - self.merge_state.setdefault(phi_new_var, set([phi_new_var])) - - def order_ssa_var_dom(self): - """Compute dominance order of each ssa variable""" - ircfg = self.ssa.graph - - # compute dominator tree - dominator_tree = ircfg.compute_dominator_tree(self.head) - - # variable -> Varinfo - self.var_to_varinfo = {} - # live_index can later be used to compare dominance of AssignBlocks - live_index = 0 - - # walk in DFS over the dominator tree - for loc_key in dominator_tree.walk_depth_first_forward(self.head): - irblock = ircfg.blocks[loc_key] - - # Create live index for phi new vars - # They do not exist in the graph yet, so index is set to None - if irblock_has_phi(irblock): - for dst in irblock[0]: - if not dst.is_id(): - continue - new_var = self.phi_new_var[dst] - self.var_to_varinfo[new_var] = Varinfo(live_index, loc_key, None) - - live_index += 1 - - # Create live index for remaining assignments - for index, assignblk in enumerate(irblock): - used = False - for dst in assignblk: - if not dst.is_id(): - continue - if dst in self.ssa.immutable_ids: - # Will not be considered by the current algo, ignore it - # (for instance, IRDst) - continue - - assert dst not in self.var_to_varinfo - self.var_to_varinfo[dst] = Varinfo(live_index, loc_key, index) - used = True - if used: - live_index += 1 - - - def ssa_def_dominates(self, node_a, node_b): - """ - Return True if the definition of @node_a dominates the definition of @node_b, compared through their live indexes - @node_a: variable - @node_b: variable - """ - ret = self.var_to_varinfo[node_a].live_index <= self.var_to_varinfo[node_b].live_index - return ret - - def merge_set_sort(self, merge_set): - """ - Return a sorted list of (live_index, var) from @merge_set in dominance - order - @merge_set: set of coalescing variables - """ - return sorted( - (self.var_to_varinfo[var].live_index, var) - for var in merge_set - ) - - def ssa_def_is_live_at(self, node_a, node_b, parent): - """ - Return True if @node_a is live during @node_b definition - If @parent is None, this is a liveness test for a post phi variable; - Else, it is a liveness test for a variable source of the phi node - - @node_a: variable - @node_b: variable - @parent: Optional parent location of the phi source - """ - loc_key_b, index_b = self.var_to_varinfo[node_b].loc_key, self.var_to_varinfo[node_b].index - if parent and index_b is None: - index_b = 0 - if node_a not in self.new_var_to_srcs_parents: - # node_a is not a new var (it is a "classic" var) - # -> use a basic liveness test - liveness_b = self.cfg_liveness.blocks[loc_key_b].infos[index_b] - return node_a in liveness_b.var_out - - for def_loc_key in self.new_var_to_srcs_parents[node_a]: - # Consider node_a as defined at the end of its parents blocks - # and compute liveness check accordingly - - if def_loc_key == parent: - # Same path as node_a definition, so SSA ensures b cannot be live - # on this path (otherwise, a Phi would already happen earlier) - continue - liveness_end_block = self.cfg_liveness.blocks[def_loc_key].infos[-1] - if node_b in liveness_end_block.var_out: - return
True - return False - - def merge_nodes_interfere(self, node_a, node_b, parent): - """ - Return True if @node_a and @node_b interfere - @node_a: variable - @node_b: variable - @parent: Optional parent location of the phi source for liveness tests - - Interference check is: is x live at y definition (or reverse) - TODO: add Value-based interference improvement - """ - if self.var_to_varinfo[node_a].live_index == self.var_to_varinfo[node_b].live_index: - # Defined in the same AssignBlock -> interfere - return True - - if self.var_to_varinfo[node_a].live_index < self.var_to_varinfo[node_b].live_index: - return self.ssa_def_is_live_at(node_a, node_b, parent) - return self.ssa_def_is_live_at(node_b, node_a, parent) - - def merge_sets_interfere(self, merge_a, merge_b, parent): - """ - Return True if any variable of @merge_a interferes with a variable of @merge_b. - - Implementation of "Algorithm 2: Check intersection in a set of variables" - - @merge_a: a set of equivalent variables - @merge_b: a set of equivalent variables - @parent: Optional parent location of the phi source for liveness tests - """ - if merge_a == merge_b: - # No need to consider interference if equal - return False - - merge_a_list = self.merge_set_sort(merge_a) - merge_b_list = self.merge_set_sort(merge_b) - dom = [] - while merge_a_list or merge_b_list: - if not merge_a_list: - _, current = merge_b_list.pop(0) - elif not merge_b_list: - _, current = merge_a_list.pop(0) - else: - # compare live_indexes (standing for dominance) - if merge_a_list[-1] < merge_b_list[-1]: - _, current = merge_a_list.pop(0) - else: - _, current = merge_b_list.pop(0) - while dom and not self.ssa_def_dominates(dom[-1], current): - dom.pop() - - # Don't test node in same merge_set - if ( - # Is stack not empty? - dom and - # Trivial non-interference if dom.top() and current come - # from the same merge set - not (dom[-1] in merge_a and current in merge_a) and - not (dom[-1] in merge_b and current in merge_b) and - # Actually test for interference - self.merge_nodes_interfere(current, dom[-1], parent) - ): - return True - dom.append(current) - return False - - def aggressive_coalesce_parallel_copy(self, parallel_copies, parent): - """ - Try to coalesce the variables of each dst/src couple from - @parallel_copies - - @parallel_copies: a dictionary representing dst/src parallel - assignments.
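
The dominance-stack walk above relies on merge_set_sort: merge sets are flattened into lists ordered by live_index, i.e. by dominance of their defining AssignBlocks. A stand-in illustration with live indexes as plain integers (names hypothetical):

    infos = {"a.1": 3, "b.0": 1, "var0": 2}      # var -> live_index
    merge_set = set(["a.1", "var0", "b.0"])
    ordered = sorted((infos[var], var) for var in merge_set)
    # [(1, 'b.0'), (2, 'var0'), (3, 'a.1')]: the interference check pops
    # candidates in this order while maintaining a stack of dominators,
    # as in "Algorithm 2" referenced above.
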
- @parent: Optional parent location of the phi source for liveness tests - """ - for dst, src in viewitems(parallel_copies): - dst_merge = self.merge_state.setdefault(dst, set([dst])) - src_merge = self.merge_state.setdefault(src, set([src])) - if not self.merge_sets_interfere(dst_merge, src_merge, parent): - dst_merge.update(src_merge) - for node in dst_merge: - self.merge_state[node] = dst_merge - - def aggressive_coalesce_block(self): - """Try to coalesce phi vars with their pre/post variables""" - - ircfg = self.ssa.graph - - # Run coalesce on the post phi parallel copy - for irblock in viewvalues(ircfg.blocks): - if not irblock_has_phi(irblock): - continue - parallel_copies = {} - for dst in self.phi_destinations[irblock.loc_key]: - parallel_copies[dst] = self.phi_new_var[dst] - self.aggressive_coalesce_parallel_copy(parallel_copies, None) - - # Run coalesce on the pre phi parallel copy - - # Stand for the virtual parallel copies at the end of Phi's block - # parents - parent_to_parallel_copies = {} - for dst in irblock[0]: - new_var = self.phi_new_var[dst] - for parent, src in self.phi_parent_sources[dst]: - parent_to_parallel_copies.setdefault(parent, {})[new_var] = src - - for parent, parallel_copies in viewitems(parent_to_parallel_copies): - self.aggressive_coalesce_parallel_copy(parallel_copies, parent) - - def get_best_merge_set_name(self, merge_set): - """ - For a given @merge_set, prefer an original SSA variable instead of a - created copy. Otherwise, take an arbitrary name. - @merge_set: set of equivalent expressions - """ - if not merge_set: - raise RuntimeError("Merge set should not be empty") - for var in merge_set: - if var not in self.copy_vars: - return var - # Arbitrary name - return var - - - def replace_merge_sets(self): - """ - In the graph, replace all variables from merge state by their - representative variable - """ - replace = {} - merge_sets = set() - - # Elect representative for merge sets - merge_set_to_name = {} - for merge_set in viewvalues(self.merge_state): - frozen_merge_set = frozenset(merge_set) - merge_sets.add(frozen_merge_set) - var_name = self.get_best_merge_set_name(merge_set) - merge_set_to_name[frozen_merge_set] = var_name - - # Generate replacements of variables by their representative - for merge_set in merge_sets: - var_name = merge_set_to_name[merge_set] - merge_set = list(merge_set) - for var in merge_set: - replace[var] = var_name - - self.ssa.graph.simplify(lambda x: x.replace_expr(replace)) - - def remove_phi(self): - """ - Remove phi operators from the current SSA graph - """ - - for irblock in list(viewvalues(self.ssa.graph.blocks)): - assignblks = list(irblock) - out = {} - for dst, src in viewitems(assignblks[0]): - if src.is_op('Phi'): - assert set([dst]) == set(src.args) - continue - out[dst] = src - assignblks[0] = AssignBlock(out, assignblks[0].instr) - self.ssa.graph.blocks[irblock.loc_key] = IRBlock(irblock.loc_key, assignblks) - - def remove_assign_eq(self): - """ - Remove trivial expressions (a=a) in the current graph - """ - for irblock in list(viewvalues(self.ssa.graph.blocks)): - assignblks = list(irblock) - for i, assignblk in enumerate(assignblks): - out = {} - for dst, src in viewitems(assignblk): - if dst == src: - continue - out[dst] = src - assignblks[i] = AssignBlock(out, assignblk.instr) - self.ssa.graph.blocks[irblock.loc_key] = IRBlock(irblock.loc_key, assignblks)
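
Net effect of the two final passes: replace_merge_sets rewrites every variable into its merge-set representative, which leaves self-assignments behind, and remove_assign_eq strips them. A minimal stand-in with plain dicts (names hypothetical):

    # After replacement, a coalesced block may read:
    #   {RAX.2: RAX.2, RBX.1: RCX.0}
    assignblk = {"RAX.2": "RAX.2", "RBX.1": "RCX.0"}
    out = dict((dst, src) for dst, src in assignblk.items() if dst != src)
    assert out == {"RBX.1": "RCX.0"}   # the trivial a = a is gone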