Diffstat (limited to 'src/miasm/analysis')
-rw-r--r--   src/miasm/analysis/__init__.py         |    1
-rw-r--r--   src/miasm/analysis/binary.py           |  233
-rw-r--r--   src/miasm/analysis/cst_propag.py       |  185
-rw-r--r--   src/miasm/analysis/data_analysis.py    |  204
-rw-r--r--   src/miasm/analysis/data_flow.py        | 2356
-rw-r--r--   src/miasm/analysis/debugging.py        |  557
-rw-r--r--   src/miasm/analysis/depgraph.py         |  659
-rw-r--r--   src/miasm/analysis/disasm_cb.py        |  127
-rw-r--r--   src/miasm/analysis/dse.py              |  717
-rw-r--r--   src/miasm/analysis/expression_range.py |   70
-rw-r--r--   src/miasm/analysis/gdbserver.py        |  453
-rw-r--r--   src/miasm/analysis/machine.py          |  279
-rw-r--r--   src/miasm/analysis/modularintervals.py |  525
-rw-r--r--   src/miasm/analysis/outofssa.py         |  415
-rw-r--r--   src/miasm/analysis/sandbox.py          | 1033
-rw-r--r--   src/miasm/analysis/simplifier.py       |  325
-rw-r--r--   src/miasm/analysis/ssa.py              |  731
17 files changed, 8870 insertions, 0 deletions
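The Container class introduced in binary.py below is the usual entry point for this package. A minimal loading sketch, assuming LocationDB is importable from miasm.core.locationdb and that a "target.bin" file exists (both are assumptions, not part of this change):

from miasm.core.locationdb import LocationDB
from miasm.analysis.binary import Container

loc_db = LocationDB()
with open("target.bin", "rb") as fdesc:
    data = fdesc.read()

# from_string tries each registered container (PE, then ELF) and falls
# back to the raw string container when no signature matches
cont = Container.from_string(data, loc_db)
print(cont.arch, cont.entry_point)
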
diff --git a/src/miasm/analysis/__init__.py b/src/miasm/analysis/__init__.py new file mode 100644 index 00000000..5abdd3a3 --- /dev/null +++ b/src/miasm/analysis/__init__.py @@ -0,0 +1 @@ +"High-level tools for binary analysis" diff --git a/src/miasm/analysis/binary.py b/src/miasm/analysis/binary.py new file mode 100644 index 00000000..c278594b --- /dev/null +++ b/src/miasm/analysis/binary.py @@ -0,0 +1,233 @@ +import logging +import warnings + +from miasm.core.bin_stream import bin_stream_str, bin_stream_elf, bin_stream_pe +from miasm.jitter.csts import PAGE_READ + + +log = logging.getLogger("binary") +console_handler = logging.StreamHandler() +console_handler.setFormatter(logging.Formatter("[%(levelname)-8s]: %(message)s")) +log.addHandler(console_handler) +log.setLevel(logging.ERROR) + + +# Container +## Exceptions +class ContainerSignatureException(Exception): + "The container does not match the current container signature" + + +class ContainerParsingException(Exception): + "Error during container parsing" + + +## Parent class +class Container(object): + """Container abstraction layer + + This class aims to offer a common interface for abstracting container + such as PE or ELF. + """ + + available_container = [] # Available container formats + fallback_container = None # Fallback container format + + @classmethod + def from_string(cls, data, loc_db, *args, **kwargs): + """Instantiate a container and parse the binary + @data: str containing the binary + @loc_db: LocationDB instance + """ + log.info('Load binary') + # Try each available format + for container_type in cls.available_container: + try: + return container_type(data, loc_db, *args, **kwargs) + except ContainerSignatureException: + continue + except ContainerParsingException as error: + log.error(error) + + # Fallback mode + log.warning('Fallback to string input') + return cls.fallback_container(data, loc_db, *args, **kwargs) + + @classmethod + def register_container(cls, container): + "Add a Container format" + cls.available_container.append(container) + + @classmethod + def register_fallback(cls, container): + "Set the Container fallback format" + cls.fallback_container = container + + @classmethod + def from_stream(cls, stream, loc_db, *args, **kwargs): + """Instantiate a container and parse the binary + @stream: stream to use as binary + @vm: (optional) VmMngr instance to link with the executable + @addr: (optional) Base address of the parsed binary. 
If set, + force the unknown format + """ + return Container.from_string(stream.read(), loc_db, *args, **kwargs) + + def parse(self, data, *args, **kwargs): + """Launch parsing of @data + @data: str containing the binary + """ + raise NotImplementedError("Abstract method") + + def __init__(self, data, loc_db, **kwargs): + "Alias for 'parse'" + # Init attributes + self._executable = None + self._bin_stream = None + self._entry_point = None + self._arch = None + self._loc_db = loc_db + + # Launch parsing + self.parse(data, **kwargs) + + @property + def bin_stream(self): + "Return the BinStream instance corresponding to container content" + return self._bin_stream + + @property + def executable(self): + "Return the abstract instance standing for parsed executable" + return self._executable + + @property + def entry_point(self): + "Return the detected entry_point" + return self._entry_point + + @property + def arch(self): + "Return the guessed architecture" + return self._arch + + @property + def loc_db(self): + "LocationDB instance preloaded with container symbols (if any)" + return self._loc_db + + @property + def symbol_pool(self): + "[DEPRECATED API]" + warnings.warn("Deprecated API: use 'loc_db'") + return self.loc_db + +## Format dependent classes +class ContainerPE(Container): + "Container abstraction for PE" + + def parse(self, data, vm=None, **kwargs): + from miasm.jitter.loader.pe import vm_load_pe, guess_arch + from miasm.loader import pe_init + + # Parse signature + if not data.startswith(b'MZ'): + raise ContainerSignatureException() + + # Build executable instance + try: + if vm is not None: + self._executable = vm_load_pe(vm, data) + else: + self._executable = pe_init.PE(data) + except Exception as error: + raise ContainerParsingException('Cannot read PE: %s' % error) + + # Check instance validity + if not self._executable.isPE() or \ + self._executable.NTsig.signature_value != 0x4550: + raise ContainerSignatureException() + + # Guess the architecture + self._arch = guess_arch(self._executable) + + # Build the bin_stream instance and set the entry point + try: + self._bin_stream = bin_stream_pe(self._executable) + ep_detected = self._executable.Opthdr.AddressOfEntryPoint + self._entry_point = self._executable.rva2virt(ep_detected) + except Exception as error: + raise ContainerParsingException('Cannot read PE: %s' % error) + + +class ContainerELF(Container): + "Container abstraction for ELF" + + def parse(self, data, vm=None, addr=0, apply_reloc=False, **kwargs): + """Load an ELF from @data + @data: bytes containing the ELF bytes + @vm (optional): VmMngr instance. 
If set, load the ELF in virtual memory + @addr (optional): base address the ELF in virtual memory + @apply_reloc (optional): if set, apply relocation during ELF loading + + @addr and @apply_reloc are only meaningful in the context of a + non-empty @vm + """ + from miasm.jitter.loader.elf import vm_load_elf, guess_arch, \ + fill_loc_db_with_symbols + from miasm.loader import elf_init + + # Parse signature + if not data.startswith(b'\x7fELF'): + raise ContainerSignatureException() + + # Build executable instance + try: + if vm is not None: + self._executable = vm_load_elf( + vm, + data, + loc_db=self.loc_db, + base_addr=addr, + apply_reloc=apply_reloc + ) + else: + self._executable = elf_init.ELF(data) + except Exception as error: + raise ContainerParsingException('Cannot read ELF: %s' % error) + + # Guess the architecture + self._arch = guess_arch(self._executable) + + # Build the bin_stream instance and set the entry point + try: + self._bin_stream = bin_stream_elf(self._executable) + self._entry_point = self._executable.Ehdr.entry + addr + except Exception as error: + raise ContainerParsingException('Cannot read ELF: %s' % error) + + if vm is None: + # Add known symbols (vm_load_elf already does it) + fill_loc_db_with_symbols(self._executable, self.loc_db, addr) + + + +class ContainerUnknown(Container): + "Container abstraction for unknown format" + + def parse(self, data, vm=None, addr=0, **kwargs): + self._bin_stream = bin_stream_str(data, base_address=addr) + if vm is not None: + vm.add_memory_page( + addr, + PAGE_READ, + data + ) + self._executable = None + self._entry_point = 0 + + +## Register containers +Container.register_container(ContainerPE) +Container.register_container(ContainerELF) +Container.register_fallback(ContainerUnknown) diff --git a/src/miasm/analysis/cst_propag.py b/src/miasm/analysis/cst_propag.py new file mode 100644 index 00000000..cdb62d3c --- /dev/null +++ b/src/miasm/analysis/cst_propag.py @@ -0,0 +1,185 @@ +import logging + +from future.utils import viewitems + +from miasm.ir.symbexec import SymbolicExecutionEngine +from miasm.expression.expression import ExprMem +from miasm.expression.expression_helper import possible_values +from miasm.expression.simplifications import expr_simp +from miasm.ir.ir import IRBlock, AssignBlock + +LOG_CST_PROPAG = logging.getLogger("cst_propag") +CONSOLE_HANDLER = logging.StreamHandler() +CONSOLE_HANDLER.setFormatter(logging.Formatter("[%(levelname)-8s]: %(message)s")) +LOG_CST_PROPAG.addHandler(CONSOLE_HANDLER) +LOG_CST_PROPAG.setLevel(logging.WARNING) + + +class SymbExecState(SymbolicExecutionEngine): + """ + State manager for SymbolicExecution + """ + def __init__(self, lifter, ircfg, state): + super(SymbExecState, self).__init__(lifter, {}) + self.set_state(state) + + +def add_state(ircfg, todo, states, addr, state): + """ + Add or merge the computed @state for the block at @addr. Update @todo + @todo: modified block set + @states: dictionary linking a label to its entering state. 
+ @addr: address of the considered block + @state: computed state + """ + addr = ircfg.get_loc_key(addr) + todo.add(addr) + if addr not in states: + states[addr] = state + else: + states[addr] = states[addr].merge(state) + + +def is_expr_cst(lifter, expr): + """Return true if @expr is only composed of ExprInt and init_regs + @lifter: Lifter instance + @expr: Expression to test""" + + elements = expr.get_r(mem_read=True) + for element in elements: + if element.is_mem(): + continue + if element.is_id() and element in lifter.arch.regs.all_regs_ids_init: + continue + if element.is_int(): + continue + return False + # Expr is a constant + return True + + +class SymbExecStateFix(SymbolicExecutionEngine): + """ + Emul blocks and replace expressions with their corresponding constant if + any. + + """ + # Function used to test if an Expression is considered as a constant + is_expr_cst = lambda _, lifter, expr: is_expr_cst(lifter, expr) + + def __init__(self, lifter, ircfg, state, cst_propag_link): + self.ircfg = ircfg + super(SymbExecStateFix, self).__init__(lifter, {}) + self.set_state(state) + self.cst_propag_link = cst_propag_link + + def propag_expr_cst(self, expr): + """Propagate constant expressions in @expr + @expr: Expression to update""" + elements = expr.get_r(mem_read=True) + to_propag = {} + for element in elements: + # Only ExprId can be safely propagated + if not element.is_id(): + continue + value = self.eval_expr(element) + if self.is_expr_cst(self.lifter, value): + to_propag[element] = value + return expr_simp(expr.replace_expr(to_propag)) + + def eval_updt_irblock(self, irb, step=False): + """ + Symbolic execution of the @irb on the current state + @irb: IRBlock instance + @step: display intermediate steps + """ + assignblks = [] + for index, assignblk in enumerate(irb): + new_assignblk = {} + links = {} + for dst, src in viewitems(assignblk): + src = self.propag_expr_cst(src) + if dst.is_mem(): + ptr = dst.ptr + ptr = self.propag_expr_cst(ptr) + dst = ExprMem(ptr, dst.size) + new_assignblk[dst] = src + + if assignblk.instr is not None: + for arg in assignblk.instr.args: + new_arg = self.propag_expr_cst(arg) + links[new_arg] = arg + self.cst_propag_link[(irb.loc_key, index)] = links + + self.eval_updt_assignblk(assignblk) + assignblks.append(AssignBlock(new_assignblk, assignblk.instr)) + self.ircfg.blocks[irb.loc_key] = IRBlock(irb.loc_db, irb.loc_key, assignblks) + + +def compute_cst_propagation_states(lifter, ircfg, init_addr, init_infos): + """ + Propagate "constant expressions" in a function. + The attribute "constant expression" is true if the expression is based on + constants or "init" regs values. 
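    Example (illustrative): with @init_infos mapping the stack register
    to its "init" value, an expression such as ESP_init + 0x10 is a
    constant expression in the above sense (see is_expr_cst), while any
    expression built from a register written after @init_addr is not.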
+ + @lifter: Lifter instance + @init_addr: analysis start address + @init_infos: dictionary linking expressions to their values at @init_addr + """ + + done = set() + state = SymbExecState.StateEngine(init_infos) + lbl = ircfg.get_loc_key(init_addr) + todo = set([lbl]) + states = {lbl: state} + + while todo: + if not todo: + break + lbl = todo.pop() + state = states[lbl] + if (lbl, state) in done: + continue + done.add((lbl, state)) + if lbl not in ircfg.blocks: + continue + + symbexec_engine = SymbExecState(lifter, ircfg, state) + addr = symbexec_engine.run_block_at(ircfg, lbl) + symbexec_engine.del_mem_above_stack(lifter.sp) + + for dst in possible_values(addr): + value = dst.value + if value.is_mem(): + LOG_CST_PROPAG.warning('Bad destination: %s', value) + continue + elif value.is_int(): + value = ircfg.get_loc_key(value) + add_state( + ircfg, todo, states, value, + symbexec_engine.get_state() + ) + + return states + + +def propagate_cst_expr(lifter, ircfg, addr, init_infos): + """ + Propagate "constant expressions" in a @lifter. + The attribute "constant expression" is true if the expression is based on + constants or "init" regs values. + + @lifter: Lifter instance + @addr: analysis start address + @init_infos: dictionary linking expressions to their values at @init_addr + + Returns a mapping between replaced Expression and their new values. + """ + states = compute_cst_propagation_states(lifter, ircfg, addr, init_infos) + cst_propag_link = {} + for lbl, state in viewitems(states): + if lbl not in ircfg.blocks: + continue + symbexec = SymbExecStateFix(lifter, ircfg, state, cst_propag_link) + symbexec.eval_updt_irblock(ircfg.blocks[lbl]) + return cst_propag_link diff --git a/src/miasm/analysis/data_analysis.py b/src/miasm/analysis/data_analysis.py new file mode 100644 index 00000000..c7924cf2 --- /dev/null +++ b/src/miasm/analysis/data_analysis.py @@ -0,0 +1,204 @@ +from __future__ import print_function + +from future.utils import viewitems + +from builtins import object +from functools import cmp_to_key +from miasm.expression.expression \ + import get_expr_mem, get_list_rw, ExprId, ExprInt, \ + compare_exprs +from miasm.ir.symbexec import SymbolicExecutionEngine + + +def get_node_name(label, i, n): + n_name = (label, i, n) + return n_name + + +def intra_block_flow_raw(lifter, ircfg, flow_graph, irb, in_nodes, out_nodes): + """ + Create data flow for an irbloc using raw IR expressions + """ + current_nodes = {} + for i, assignblk in enumerate(irb): + dict_rw = assignblk.get_rw(cst_read=True) + current_nodes.update(out_nodes) + + # gen mem arg to mem node links + all_mems = set() + for node_w, nodes_r in viewitems(dict_rw): + for n in nodes_r.union([node_w]): + all_mems.update(get_expr_mem(n)) + if not all_mems: + continue + + for n in all_mems: + node_n_w = get_node_name(irb.loc_key, i, n) + if not n in nodes_r: + continue + o_r = n.ptr.get_r(mem_read=False, cst_read=True) + for n_r in o_r: + if n_r in current_nodes: + node_n_r = current_nodes[n_r] + else: + node_n_r = get_node_name(irb.loc_key, i, n_r) + current_nodes[n_r] = node_n_r + in_nodes[n_r] = node_n_r + flow_graph.add_uniq_edge(node_n_r, node_n_w) + + # gen data flow links + for node_w, nodes_r in viewitems(dict_rw): + for n_r in nodes_r: + if n_r in current_nodes: + node_n_r = current_nodes[n_r] + else: + node_n_r = get_node_name(irb.loc_key, i, n_r) + current_nodes[n_r] = node_n_r + in_nodes[n_r] = node_n_r + + flow_graph.add_node(node_n_r) + + node_n_w = get_node_name(irb.loc_key, i + 1, node_w) + out_nodes[node_w] = 
node_n_w + + flow_graph.add_node(node_n_w) + flow_graph.add_uniq_edge(node_n_r, node_n_w) + + + +def inter_block_flow_link(lifter, ircfg, flow_graph, irb_in_nodes, irb_out_nodes, todo, link_exec_to_data): + lbl, current_nodes, exec_nodes = todo + current_nodes = dict(current_nodes) + + # link current nodes to block in_nodes + if not lbl in ircfg.blocks: + print("cannot find block!!", lbl) + return set() + irb = ircfg.blocks[lbl] + to_del = set() + for n_r, node_n_r in viewitems(irb_in_nodes[irb.loc_key]): + if not n_r in current_nodes: + continue + flow_graph.add_uniq_edge(current_nodes[n_r], node_n_r) + to_del.add(n_r) + + # if link exec to data, all nodes depends on exec nodes + if link_exec_to_data: + for n_x_r in exec_nodes: + for n_r, node_n_r in viewitems(irb_in_nodes[irb.loc_key]): + if not n_x_r in current_nodes: + continue + if isinstance(n_r, ExprInt): + continue + flow_graph.add_uniq_edge(current_nodes[n_x_r], node_n_r) + + # update current nodes using block out_nodes + for n_w, node_n_w in viewitems(irb_out_nodes[irb.loc_key]): + current_nodes[n_w] = node_n_w + + # get nodes involved in exec flow + x_nodes = tuple(sorted(irb.dst.get_r(), key=cmp_to_key(compare_exprs))) + + todo = set() + for lbl_dst in ircfg.successors(irb.loc_key): + todo.add((lbl_dst, tuple(viewitems(current_nodes)), x_nodes)) + + return todo + + +def create_implicit_flow(lifter, flow_graph, irb_in_nodes, irb_out_nodes): + + # first fix IN/OUT + # If a son read a node which in not in OUT, add it + todo = set(lifter.blocks.keys()) + while todo: + lbl = todo.pop() + irb = lifter.blocks[lbl] + for lbl_son in lifter.graph.successors(irb.loc_key): + if not lbl_son in lifter.blocks: + print("cannot find block!!", lbl) + continue + irb_son = lifter.blocks[lbl_son] + for n_r in irb_in_nodes[irb_son.loc_key]: + if n_r in irb_out_nodes[irb.loc_key]: + continue + if not isinstance(n_r, ExprId): + continue + + node_n_w = irb.loc_key, len(irb), n_r + irb_out_nodes[irb.loc_key][n_r] = node_n_w + if not n_r in irb_in_nodes[irb.loc_key]: + irb_in_nodes[irb.loc_key][n_r] = irb.loc_key, 0, n_r + node_n_r = irb_in_nodes[irb.loc_key][n_r] + for lbl_p in lifter.graph.predecessors(irb.loc_key): + todo.add(lbl_p) + + flow_graph.add_uniq_edge(node_n_r, node_n_w) + + +def inter_block_flow(lifter, ircfg, flow_graph, irb_0, irb_in_nodes, irb_out_nodes, link_exec_to_data=True): + + todo = set() + done = set() + todo.add((irb_0, (), ())) + + while todo: + state = todo.pop() + if state in done: + continue + done.add(state) + out = inter_block_flow_link(lifter, ircfg, flow_graph, irb_in_nodes, irb_out_nodes, state, link_exec_to_data) + todo.update(out) + + +class symb_exec_func(object): + + """ + This algorithm will do symbolic execution on a function, trying to propagate + states between basic blocks in order to extract inter-blocks dataflow. The + algorithm tries to merge states from blocks with multiple parents. + + There is no real magic here, loops and complex merging will certainly fail. + """ + + def __init__(self, lifter): + self.todo = set() + self.stateby_ad = {} + self.cpt = {} + self.states_var_done = set() + self.states_done = set() + self.total_done = 0 + self.lifter = lifter + + def add_state(self, parent, ad, state): + variables = dict(state.symbols) + + # get block dead, and remove from state + b = self.lifter.get_block(ad) + if b is None: + raise ValueError("unknown block! 
%s" % ad) + s = parent, ad, tuple(sorted(viewitems(variables))) + self.todo.add(s) + + def get_next_state(self): + state = self.todo.pop() + return state + + def do_step(self): + if len(self.todo) == 0: + return None + if self.total_done > 600: + print("symbexec watchdog!") + return None + self.total_done += 1 + print('CPT', self.total_done) + while self.todo: + state = self.get_next_state() + parent, ad, s = state + self.states_done.add(state) + self.states_var_done.add(state) + + sb = SymbolicExecutionEngine(self.lifter, dict(s)) + + return parent, ad, sb + return None diff --git a/src/miasm/analysis/data_flow.py b/src/miasm/analysis/data_flow.py new file mode 100644 index 00000000..23d0b3dd --- /dev/null +++ b/src/miasm/analysis/data_flow.py @@ -0,0 +1,2356 @@ +"""Data flow analysis based on miasm intermediate representation""" +from builtins import range +from collections import namedtuple, Counter +from pprint import pprint as pp +from future.utils import viewitems, viewvalues +from miasm.core.utils import encode_hex +from miasm.core.graph import DiGraph +from miasm.ir.ir import AssignBlock, IRBlock +from miasm.expression.expression import ExprLoc, ExprMem, ExprId, ExprInt,\ + ExprAssign, ExprOp, ExprWalk, ExprSlice, \ + is_function_call, ExprVisitorCallbackBottomToTop +from miasm.expression.simplifications import expr_simp, expr_simp_explicit +from miasm.core.interval import interval +from miasm.expression.expression_helper import possible_values +from miasm.analysis.ssa import get_phi_sources_parent_block, \ + irblock_has_phi +from miasm.ir.symbexec import get_expr_base_offset +from collections import deque + +class ReachingDefinitions(dict): + """ + Computes for each assignblock the set of reaching definitions. + Example: + IR block: + lbl0: + 0 A = 1 + B = 3 + 1 B = 2 + 2 A = A + B + 4 + + Reach definition of lbl0: + (lbl0, 0) => {} + (lbl0, 1) => {A: {(lbl0, 0)}, B: {(lbl0, 0)}} + (lbl0, 2) => {A: {(lbl0, 0)}, B: {(lbl0, 1)}} + (lbl0, 3) => {A: {(lbl0, 2)}, B: {(lbl0, 1)}} + + Source set 'REACHES' in: Kennedy, K. (1979). + A survey of data flow analysis techniques. + IBM Thomas J. Watson Research Division, Algorithm MK + + This class is usable as a dictionary whose structure is + { (block, index): { lvalue: set((block, index)) } } + """ + + ircfg = None + + def __init__(self, ircfg): + super(ReachingDefinitions, self).__init__() + self.ircfg = ircfg + self.compute() + + def get_definitions(self, block_lbl, assignblk_index): + """Returns the dict { lvalue: set((def_block_lbl, def_index)) } + associated with self.ircfg.@block.assignblks[@assignblk_index] + or {} if it is not yet computed + """ + return self.get((block_lbl, assignblk_index), {}) + + def compute(self): + """This is the main fixpoint""" + modified = True + while modified: + modified = False + for block in viewvalues(self.ircfg.blocks): + modified |= self.process_block(block) + + def process_block(self, block): + """ + Fetch reach definitions from predecessors and propagate it to + the assignblk in block @block. 
+ """ + predecessor_state = {} + for pred_lbl in self.ircfg.predecessors(block.loc_key): + if pred_lbl not in self.ircfg.blocks: + continue + pred = self.ircfg.blocks[pred_lbl] + for lval, definitions in viewitems(self.get_definitions(pred_lbl, len(pred))): + predecessor_state.setdefault(lval, set()).update(definitions) + + modified = self.get((block.loc_key, 0)) != predecessor_state + if not modified: + return False + self[(block.loc_key, 0)] = predecessor_state + + for index in range(len(block)): + modified |= self.process_assignblock(block, index) + return modified + + def process_assignblock(self, block, assignblk_index): + """ + Updates the reach definitions with values defined at + assignblock @assignblk_index in block @block. + NB: the effect of assignblock @assignblk_index in stored at index + (@block, @assignblk_index + 1). + """ + + assignblk = block[assignblk_index] + defs = self.get_definitions(block.loc_key, assignblk_index).copy() + for lval in assignblk: + defs.update({lval: set([(block.loc_key, assignblk_index)])}) + + modified = self.get((block.loc_key, assignblk_index + 1)) != defs + if modified: + self[(block.loc_key, assignblk_index + 1)] = defs + + return modified + +ATTR_DEP = {"color" : "black", + "_type" : "data"} + +AssignblkNode = namedtuple('AssignblkNode', ['label', 'index', 'var']) + + +class DiGraphDefUse(DiGraph): + """Representation of a Use-Definition graph as defined by + Kennedy, K. (1979). A survey of data flow analysis techniques. + IBM Thomas J. Watson Research Division. + Example: + IR block: + lbl0: + 0 A = 1 + B = 3 + 1 B = 2 + 2 A = A + B + 4 + + Def use analysis: + (lbl0, 0, A) => {(lbl0, 2, A)} + (lbl0, 0, B) => {} + (lbl0, 1, B) => {(lbl0, 2, A)} + (lbl0, 2, A) => {} + + """ + + + def __init__(self, reaching_defs, + deref_mem=False, apply_simp=False, *args, **kwargs): + """Instantiate a DiGraph + @blocks: IR blocks + """ + self._edge_attr = {} + + # For dot display + self._filter_node = None + self._dot_offset = None + self._blocks = reaching_defs.ircfg.blocks + + super(DiGraphDefUse, self).__init__(*args, **kwargs) + self._compute_def_use(reaching_defs, + deref_mem=deref_mem, + apply_simp=apply_simp) + + def edge_attr(self, src, dst): + """ + Return a dictionary of attributes for the edge between @src and @dst + @src: the source node of the edge + @dst: the destination node of the edge + """ + return self._edge_attr[(src, dst)] + + def _compute_def_use(self, reaching_defs, + deref_mem=False, apply_simp=False): + for block in viewvalues(self._blocks): + self._compute_def_use_block(block, + reaching_defs, + deref_mem=deref_mem, + apply_simp=apply_simp) + + def _compute_def_use_block(self, block, reaching_defs, deref_mem=False, apply_simp=False): + for index, assignblk in enumerate(block): + assignblk_reaching_defs = reaching_defs.get_definitions(block.loc_key, index) + for lval, expr in viewitems(assignblk): + self.add_node(AssignblkNode(block.loc_key, index, lval)) + + expr = expr_simp_explicit(expr) if apply_simp else expr + read_vars = expr.get_r(mem_read=deref_mem) + if deref_mem and lval.is_mem(): + read_vars.update(lval.ptr.get_r(mem_read=deref_mem)) + for read_var in read_vars: + for reach in assignblk_reaching_defs.get(read_var, set()): + self.add_data_edge(AssignblkNode(reach[0], reach[1], read_var), + AssignblkNode(block.loc_key, index, lval)) + + def del_edge(self, src, dst): + super(DiGraphDefUse, self).del_edge(src, dst) + del self._edge_attr[(src, dst)] + + def add_uniq_labeled_edge(self, src, dst, edge_label): + """Adds the edge 
(@src, @dst) with label @edge_label. + if edge (@src, @dst) already exists, the previous label is overridden + """ + self.add_uniq_edge(src, dst) + self._edge_attr[(src, dst)] = edge_label + + def add_data_edge(self, src, dst): + """Adds an edge representing a data dependency + and sets the label accordingly""" + self.add_uniq_labeled_edge(src, dst, ATTR_DEP) + + def node2lines(self, node): + lbl, index, reg = node + yield self.DotCellDescription(text="%s (%s)" % (lbl, index), + attr={'align': 'center', + 'colspan': 2, + 'bgcolor': 'grey'}) + src = self._blocks[lbl][index][reg] + line = "%s = %s" % (reg, src) + yield self.DotCellDescription(text=line, attr={}) + yield self.DotCellDescription(text="", attr={}) + + +class DeadRemoval(object): + """ + Do dead removal + """ + + def __init__(self, lifter, expr_to_original_expr=None): + self.lifter = lifter + if expr_to_original_expr is None: + expr_to_original_expr = {} + self.expr_to_original_expr = expr_to_original_expr + + + def add_expr_to_original_expr(self, expr_to_original_expr): + self.expr_to_original_expr.update(expr_to_original_expr) + + def is_unkillable_destination(self, lval, rval): + if ( + lval.is_mem() or + self.lifter.IRDst == lval or + lval.is_id("exception_flags") or + is_function_call(rval) + ): + return True + return False + + def get_block_useful_destinations(self, block): + """ + Force keeping of specific cases + block: IRBlock instance + """ + useful = set() + for index, assignblk in enumerate(block): + for lval, rval in viewitems(assignblk): + if self.is_unkillable_destination(lval, rval): + useful.add(AssignblkNode(block.loc_key, index, lval)) + return useful + + def is_tracked_var(self, lval, variable): + new_lval = self.expr_to_original_expr.get(lval, lval) + return new_lval == variable + + def find_definitions_from_worklist(self, worklist, ircfg): + """ + Find variables definition in @worklist by browsing the @ircfg + """ + locs_done = set() + + defs = set() + + while worklist: + found = False + elt = worklist.pop() + if elt in locs_done: + continue + locs_done.add(elt) + variable, loc_key = elt + block = ircfg.get_block(loc_key) + + if block is None: + # Consider no sources in incomplete graph + continue + + for index, assignblk in reversed(list(enumerate(block))): + for dst, src in viewitems(assignblk): + if self.is_tracked_var(dst, variable): + defs.add(AssignblkNode(loc_key, index, dst)) + found = True + break + if found: + break + + if not found: + for predecessor in ircfg.predecessors(loc_key): + worklist.add((variable, predecessor)) + + return defs + + def find_out_regs_definitions_from_block(self, block, ircfg): + """ + Find definitions of out regs starting from @block + """ + worklist = set() + for reg in self.lifter.get_out_regs(block): + worklist.add((reg, block.loc_key)) + ret = self.find_definitions_from_worklist(worklist, ircfg) + return ret + + + def add_def_for_incomplete_leaf(self, block, ircfg, reaching_defs): + """ + Add valid definitions at end of @block plus out regs + """ + valid_definitions = reaching_defs.get_definitions( + block.loc_key, + len(block) + ) + worklist = set() + for lval, definitions in viewitems(valid_definitions): + for definition in definitions: + new_lval = self.expr_to_original_expr.get(lval, lval) + worklist.add((new_lval, block.loc_key)) + ret = self.find_definitions_from_worklist(worklist, ircfg) + useful = ret + useful.update(self.find_out_regs_definitions_from_block(block, ircfg)) + return useful + + def get_useful_assignments(self, ircfg, defuse, reaching_defs): + 
""" + Mark useful statements using previous reach analysis and defuse + + Return a set of triplets (block, assignblk number, lvalue) of + useful definitions + PRE: compute_reach(self) + + """ + + useful = set() + + for block_lbl, block in viewitems(ircfg.blocks): + block = ircfg.get_block(block_lbl) + if block is None: + # skip unknown blocks: won't generate dependencies + continue + + block_useful = self.get_block_useful_destinations(block) + useful.update(block_useful) + + + successors = ircfg.successors(block_lbl) + for successor in successors: + if successor not in ircfg.blocks: + keep_all_definitions = True + break + else: + keep_all_definitions = False + + if keep_all_definitions: + useful.update(self.add_def_for_incomplete_leaf(block, ircfg, reaching_defs)) + continue + + if len(successors) == 0: + useful.update(self.find_out_regs_definitions_from_block(block, ircfg)) + else: + continue + + + + # Useful nodes dependencies + for node in useful: + for parent in defuse.reachable_parents(node): + yield parent + + def do_dead_removal(self, ircfg): + """ + Remove useless assignments. + + This function is used to analyse relation of a * complete function * + This means the blocks under study represent a solid full function graph. + + Source : Kennedy, K. (1979). A survey of data flow analysis techniques. + IBM Thomas J. Watson Research Division, page 43 + + @ircfg: Lifter instance + """ + + modified = False + reaching_defs = ReachingDefinitions(ircfg) + defuse = DiGraphDefUse(reaching_defs, deref_mem=True) + useful = self.get_useful_assignments(ircfg, defuse, reaching_defs) + useful = set(useful) + for block in list(viewvalues(ircfg.blocks)): + irs = [] + for idx, assignblk in enumerate(block): + new_assignblk = dict(assignblk) + for lval in assignblk: + if AssignblkNode(block.loc_key, idx, lval) not in useful: + del new_assignblk[lval] + modified = True + irs.append(AssignBlock(new_assignblk, assignblk.instr)) + ircfg.blocks[block.loc_key] = IRBlock(block.loc_db, block.loc_key, irs) + return modified + + def __call__(self, ircfg): + ret = self.do_dead_removal(ircfg) + return ret + + +def _test_merge_next_block(ircfg, loc_key): + """ + Test if the irblock at @loc_key can be merge with its son + @ircfg: IRCFG instance + @loc_key: LocKey instance of the candidate parent irblock + """ + + if loc_key not in ircfg.blocks: + return None + sons = ircfg.successors(loc_key) + if len(sons) != 1: + return None + son = list(sons)[0] + if ircfg.predecessors(son) != [loc_key]: + return None + if son not in ircfg.blocks: + return None + + return son + + +def _do_merge_blocks(ircfg, loc_key, son_loc_key): + """ + Merge two irblocks at @loc_key and @son_loc_key + + @ircfg: DiGrpahIR + @loc_key: LocKey instance of the parent IRBlock + @loc_key: LocKey instance of the son IRBlock + """ + + assignblks = [] + for assignblk in ircfg.blocks[loc_key]: + if ircfg.IRDst not in assignblk: + assignblks.append(assignblk) + continue + affs = {} + for dst, src in viewitems(assignblk): + if dst != ircfg.IRDst: + affs[dst] = src + if affs: + assignblks.append(AssignBlock(affs, assignblk.instr)) + + assignblks += ircfg.blocks[son_loc_key].assignblks + new_block = IRBlock(ircfg.loc_db, loc_key, assignblks) + + ircfg.discard_edge(loc_key, son_loc_key) + + for son_successor in ircfg.successors(son_loc_key): + ircfg.add_uniq_edge(loc_key, son_successor) + ircfg.discard_edge(son_loc_key, son_successor) + del ircfg.blocks[son_loc_key] + ircfg.del_node(son_loc_key) + ircfg.blocks[loc_key] = new_block + + +def 
_test_jmp_only(ircfg, loc_key, heads): + """ + If irblock at @loc_key sets only IRDst to an ExprLoc, return the + corresponding loc_key target. + Avoid creating predecssors for heads LocKeys + None in other cases. + + @ircfg: IRCFG instance + @loc_key: LocKey instance of the candidate irblock + @heads: LocKey heads of the graph + + """ + + if loc_key not in ircfg.blocks: + return None + irblock = ircfg.blocks[loc_key] + if len(irblock.assignblks) != 1: + return None + items = list(viewitems(dict(irblock.assignblks[0]))) + if len(items) != 1: + return None + if len(ircfg.successors(loc_key)) != 1: + return None + # Don't create predecessors on heads + dst, src = items[0] + assert dst.is_id("IRDst") + if not src.is_loc(): + return None + dst = src.loc_key + if loc_key in heads: + predecessors = set(ircfg.predecessors(dst)) + predecessors.difference_update(set([loc_key])) + if predecessors: + return None + return dst + + +def _relink_block_node(ircfg, loc_key, son_loc_key, replace_dct): + """ + Link loc_key's parents to parents directly to son_loc_key + """ + for parent in set(ircfg.predecessors(loc_key)): + parent_block = ircfg.blocks.get(parent, None) + if parent_block is None: + continue + + new_block = parent_block.modify_exprs( + lambda expr:expr.replace_expr(replace_dct), + lambda expr:expr.replace_expr(replace_dct) + ) + + # Link parent to new dst + ircfg.add_uniq_edge(parent, son_loc_key) + + # Unlink block + ircfg.blocks[new_block.loc_key] = new_block + ircfg.del_node(loc_key) + + +def _remove_to_son(ircfg, loc_key, son_loc_key): + """ + Merge irblocks; The final block has the @son_loc_key loc_key + Update references + + Condition: + - irblock at @loc_key is a pure jump block + - @loc_key is not an entry point (can be removed) + + @irblock: IRCFG instance + @loc_key: LocKey instance of the parent irblock + @son_loc_key: LocKey instance of the son irblock + """ + + # Ircfg loop => don't mess + if loc_key == son_loc_key: + return False + + # Unlink block destinations + ircfg.del_edge(loc_key, son_loc_key) + + replace_dct = { + ExprLoc(loc_key, ircfg.IRDst.size):ExprLoc(son_loc_key, ircfg.IRDst.size) + } + + _relink_block_node(ircfg, loc_key, son_loc_key, replace_dct) + + ircfg.del_node(loc_key) + del ircfg.blocks[loc_key] + + return True + + +def _remove_to_parent(ircfg, loc_key, son_loc_key): + """ + Merge irblocks; The final block has the @loc_key loc_key + Update references + + Condition: + - irblock at @loc_key is a pure jump block + - @son_loc_key is not an entry point (can be removed) + + @irblock: IRCFG instance + @loc_key: LocKey instance of the parent irblock + @son_loc_key: LocKey instance of the son irblock + """ + + # Ircfg loop => don't mess + if loc_key == son_loc_key: + return False + + # Unlink block destinations + ircfg.del_edge(loc_key, son_loc_key) + + old_irblock = ircfg.blocks[son_loc_key] + new_irblock = IRBlock(ircfg.loc_db, loc_key, old_irblock.assignblks) + + ircfg.blocks[son_loc_key] = new_irblock + + ircfg.add_irblock(new_irblock) + + replace_dct = { + ExprLoc(son_loc_key, ircfg.IRDst.size):ExprLoc(loc_key, ircfg.IRDst.size) + } + + _relink_block_node(ircfg, son_loc_key, loc_key, replace_dct) + + + ircfg.del_node(son_loc_key) + del ircfg.blocks[son_loc_key] + + return True + + +def merge_blocks(ircfg, heads): + """ + This function modifies @ircfg to apply the following transformations: + - group an irblock with its son if the irblock has one and only one son and + this son has one and only one parent (spaghetti code). 
+ - if an irblock is only made of an assignment to IRDst with a given label, + this irblock is dropped and its parent destination targets are + updated. The irblock must have a parent (avoid deleting the function head) + - if an irblock is a head of the graph and is only made of an assignment to + IRDst with a given label, this irblock is dropped and its son becomes the + head. References are fixed + + This function avoid creating predecessors on heads + + Return True if at least an irblock has been modified + + @ircfg: IRCFG instance + @heads: loc_key to keep + """ + + modified = False + todo = set(ircfg.nodes()) + while todo: + loc_key = todo.pop() + + # Test merge block + son = _test_merge_next_block(ircfg, loc_key) + if son is not None and son not in heads: + _do_merge_blocks(ircfg, loc_key, son) + todo.add(loc_key) + modified = True + continue + + # Test jmp only block + son = _test_jmp_only(ircfg, loc_key, heads) + if son is not None and loc_key not in heads: + ret = _remove_to_son(ircfg, loc_key, son) + modified |= ret + if ret: + todo.add(loc_key) + continue + + # Test head jmp only block + if (son is not None and + son not in heads and + son in ircfg.blocks): + # jmp only test done previously + ret = _remove_to_parent(ircfg, loc_key, son) + modified |= ret + if ret: + todo.add(loc_key) + continue + + + return modified + + +def remove_empty_assignblks(ircfg): + """ + Remove empty assignblks in irblocks of @ircfg + Return True if at least an irblock has been modified + + @ircfg: IRCFG instance + """ + modified = False + for loc_key, block in list(viewitems(ircfg.blocks)): + irs = [] + block_modified = False + for assignblk in block: + if len(assignblk): + irs.append(assignblk) + else: + block_modified = True + if block_modified: + new_irblock = IRBlock(ircfg.loc_db, loc_key, irs) + ircfg.blocks[loc_key] = new_irblock + modified = True + return modified + + +class SSADefUse(DiGraph): + """ + Generate DefUse information from SSA transformation + Links are not valid for ExprMem. 
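    Usage sketch (@ssa is an SSADiGraph instance; from_ssa is defined
    below):

        defuse = SSADefUse.from_ssa(ssa)
        for node in defuse.nodes():
            src = defuse.get_node_target(node)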
+ """ + + def add_var_def(self, node, src): + index2dst = self._links.setdefault(node.label, {}) + dst2src = index2dst.setdefault(node.index, {}) + dst2src[node.var] = src + + def add_def_node(self, def_nodes, node, src): + if node.var.is_id(): + def_nodes[node.var] = node + + def add_use_node(self, use_nodes, node, src): + sources = set() + if node.var.is_mem(): + sources.update(node.var.ptr.get_r(mem_read=True)) + sources.update(src.get_r(mem_read=True)) + for source in sources: + if not source.is_mem(): + use_nodes.setdefault(source, set()).add(node) + + def get_node_target(self, node): + return self._links[node.label][node.index][node.var] + + def set_node_target(self, node, src): + self._links[node.label][node.index][node.var] = src + + @classmethod + def from_ssa(cls, ssa): + """ + Return a DefUse DiGraph from a SSA graph + @ssa: SSADiGraph instance + """ + + graph = cls() + # First pass + # Link line to its use and def + def_nodes = {} + use_nodes = {} + graph._links = {} + for lbl in ssa.graph.nodes(): + block = ssa.graph.blocks.get(lbl, None) + if block is None: + continue + for index, assignblk in enumerate(block): + for dst, src in viewitems(assignblk): + node = AssignblkNode(lbl, index, dst) + graph.add_var_def(node, src) + graph.add_def_node(def_nodes, node, src) + graph.add_use_node(use_nodes, node, src) + + for dst, node in viewitems(def_nodes): + graph.add_node(node) + if dst not in use_nodes: + continue + for use in use_nodes[dst]: + graph.add_uniq_edge(node, use) + + return graph + + + +def expr_has_mem(expr): + """ + Return True if expr contains at least one memory access + @expr: Expr instance + """ + + def has_mem(self): + return self.is_mem() + visitor = ExprWalk(has_mem) + return visitor.visit(expr) + + +def stack_to_reg(expr): + if expr.is_mem(): + ptr = expr.arg + SP = lifter.sp + if ptr == SP: + return ExprId("STACK.0", expr.size) + elif (ptr.is_op('+') and + len(ptr.args) == 2 and + ptr.args[0] == SP and + ptr.args[1].is_int()): + diff = int(ptr.args[1]) + assert diff % 4 == 0 + diff = (0 - diff) & 0xFFFFFFFF + return ExprId("STACK.%d" % (diff // 4), expr.size) + return False + + +def is_stack_access(lifter, expr): + if not expr.is_mem(): + return False + ptr = expr.ptr + diff = expr_simp(ptr - lifter.sp) + if not diff.is_int(): + return False + return expr + + +def visitor_get_stack_accesses(lifter, expr, stack_vars): + if is_stack_access(lifter, expr): + stack_vars.add(expr) + return expr + + +def get_stack_accesses(lifter, expr): + result = set() + def get_stack(expr_to_test): + visitor_get_stack_accesses(lifter, expr_to_test, result) + return None + visitor = ExprWalk(get_stack) + visitor.visit(expr) + return result + + +def get_interval_length(interval_in): + length = 0 + for start, stop in interval_in.intervals: + length += stop + 1 - start + return length + + +def check_expr_below_stack(lifter, expr): + """ + Return False if expr pointer is below original stack pointer + @lifter: lifter_model_call instance + @expr: Expression instance + """ + ptr = expr.ptr + diff = expr_simp(ptr - lifter.sp) + if not diff.is_int(): + return True + if int(diff) == 0 or int(expr_simp(diff.msb())) == 0: + return False + return True + + +def retrieve_stack_accesses(lifter, ircfg): + """ + Walk the ssa graph and find stack based variables. 
+ Return a dictionary linking stack base address to its size/name + @lifter: lifter_model_call instance + @ircfg: IRCFG instance + """ + stack_vars = set() + for block in viewvalues(ircfg.blocks): + for assignblk in block: + for dst, src in viewitems(assignblk): + stack_vars.update(get_stack_accesses(lifter, dst)) + stack_vars.update(get_stack_accesses(lifter, src)) + stack_vars = [expr for expr in stack_vars if check_expr_below_stack(lifter, expr)] + + base_to_var = {} + for var in stack_vars: + base_to_var.setdefault(var.ptr, set()).add(var) + + + base_to_interval = {} + for addr, vars in viewitems(base_to_var): + var_interval = interval() + for var in vars: + offset = expr_simp(addr - lifter.sp) + if not offset.is_int(): + # skip non linear stack offset + continue + + start = int(offset) + stop = int(expr_simp(offset + ExprInt(var.size // 8, offset.size))) + mem = interval([(start, stop-1)]) + var_interval += mem + base_to_interval[addr] = var_interval + if not base_to_interval: + return {} + # Check if not intervals overlap + _, tmp = base_to_interval.popitem() + while base_to_interval: + addr, mem = base_to_interval.popitem() + assert (tmp & mem).empty + tmp += mem + + base_to_info = {} + for addr, vars in viewitems(base_to_var): + name = "var_%d" % (len(base_to_info)) + size = max([var.size for var in vars]) + base_to_info[addr] = size, name + return base_to_info + + +def fix_stack_vars(expr, base_to_info): + """ + Replace local stack accesses in expr using information in @base_to_info + @expr: Expression instance + @base_to_info: dictionary linking stack base address to its size/name + """ + if not expr.is_mem(): + return expr + ptr = expr.ptr + if ptr not in base_to_info: + return expr + size, name = base_to_info[ptr] + var = ExprId(name, size) + if size == expr.size: + return var + assert expr.size < size + return var[:expr.size] + + +def replace_mem_stack_vars(expr, base_to_info): + return expr.visit(lambda expr:fix_stack_vars(expr, base_to_info)) + + +def replace_stack_vars(lifter, ircfg): + """ + Try to replace stack based memory accesses by variables. + + Hypothesis: the input ircfg must have all it's accesses to stack explicitly + done through the stack register, ie every aliases on those variables is + resolved. 
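    Example: if the function's only stack slot is a four byte access at
    stack pointer + 0x8, retrieve_stack_accesses names it "var_0" and
    each @32[sp + 8] access is rewritten into ExprId("var_0", 32).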
+ + WARNING: may fail + + @lifter: lifter_model_call instance + @ircfg: IRCFG instance + """ + + base_to_info = retrieve_stack_accesses(lifter, ircfg) + modified = False + for block in list(viewvalues(ircfg.blocks)): + assignblks = [] + for assignblk in block: + out = {} + for dst, src in viewitems(assignblk): + new_dst = dst.visit(lambda expr:replace_mem_stack_vars(expr, base_to_info)) + new_src = src.visit(lambda expr:replace_mem_stack_vars(expr, base_to_info)) + if new_dst != dst or new_src != src: + modified |= True + + out[new_dst] = new_src + + out = AssignBlock(out, assignblk.instr) + assignblks.append(out) + new_block = IRBlock(block.loc_db, block.loc_key, assignblks) + ircfg.blocks[block.loc_key] = new_block + return modified + + +def memlookup_test(expr, bs, is_addr_ro_variable, result): + if expr.is_mem() and expr.ptr.is_int(): + ptr = int(expr.ptr) + if is_addr_ro_variable(bs, ptr, expr.size): + result.add(expr) + return False + return True + + +def memlookup_visit(expr, bs, is_addr_ro_variable): + result = set() + def retrieve_memlookup(expr_to_test): + memlookup_test(expr_to_test, bs, is_addr_ro_variable, result) + return None + visitor = ExprWalk(retrieve_memlookup) + visitor.visit(expr) + return result + +def get_memlookup(expr, bs, is_addr_ro_variable): + return memlookup_visit(expr, bs, is_addr_ro_variable) + + +def read_mem(bs, expr): + ptr = int(expr.ptr) + var_bytes = bs.getbytes(ptr, expr.size // 8)[::-1] + try: + value = int(encode_hex(var_bytes), 16) + except ValueError: + return expr + return ExprInt(value, expr.size) + + +def load_from_int(ircfg, bs, is_addr_ro_variable): + """ + Replace memory read based on constant with static value + @ircfg: IRCFG instance + @bs: binstream instance + @is_addr_ro_variable: callback(addr, size) to test memory candidate + """ + + modified = False + for block in list(viewvalues(ircfg.blocks)): + assignblks = list() + for assignblk in block: + out = {} + for dst, src in viewitems(assignblk): + # Test src + mems = get_memlookup(src, bs, is_addr_ro_variable) + src_new = src + if mems: + replace = {} + for mem in mems: + value = read_mem(bs, mem) + replace[mem] = value + src_new = src.replace_expr(replace) + if src_new != src: + modified = True + # Test dst pointer if dst is mem + if dst.is_mem(): + ptr = dst.ptr + mems = get_memlookup(ptr, bs, is_addr_ro_variable) + if mems: + replace = {} + for mem in mems: + value = read_mem(bs, mem) + replace[mem] = value + ptr_new = ptr.replace_expr(replace) + if ptr_new != ptr: + modified = True + dst = ExprMem(ptr_new, dst.size) + out[dst] = src_new + out = AssignBlock(out, assignblk.instr) + assignblks.append(out) + block = IRBlock(block.loc_db, block.loc_key, assignblks) + ircfg.blocks[block.loc_key] = block + return modified + + +class AssignBlockLivenessInfos(object): + """ + Description of live in / live out of an AssignBlock + """ + + __slots__ = ["gen", "kill", "var_in", "var_out", "live", "assignblk"] + + def __init__(self, assignblk, gen, kill): + self.gen = gen + self.kill = kill + self.var_in = set() + self.var_out = set() + self.live = set() + self.assignblk = assignblk + + def __str__(self): + out = [] + out.append("\tVarIn:" + ", ".join(str(x) for x in self.var_in)) + out.append("\tGen:" + ", ".join(str(x) for x in self.gen)) + out.append("\tKill:" + ", ".join(str(x) for x in self.kill)) + out.append( + '\n'.join( + "\t%s = %s" % (dst, src) + for (dst, src) in viewitems(self.assignblk) + ) + ) + out.append("\tVarOut:" + ", ".join(str(x) for x in self.var_out)) + return 
'\n'.join(out) + + +class IRBlockLivenessInfos(object): + """ + Description of live in / live out of an AssignBlock + """ + __slots__ = ["loc_key", "infos", "assignblks"] + + + def __init__(self, irblock): + self.loc_key = irblock.loc_key + self.infos = [] + self.assignblks = [] + for assignblk in irblock: + gens, kills = set(), set() + for dst, src in viewitems(assignblk): + expr = ExprAssign(dst, src) + read = expr.get_r(mem_read=True) + write = expr.get_w() + gens.update(read) + kills.update(write) + self.infos.append(AssignBlockLivenessInfos(assignblk, gens, kills)) + self.assignblks.append(assignblk) + + def __getitem__(self, index): + """Getitem on assignblks""" + return self.assignblks.__getitem__(index) + + def __str__(self): + out = [] + out.append("%s:" % self.loc_key) + for info in self.infos: + out.append(str(info)) + out.append('') + return "\n".join(out) + + +class DiGraphLiveness(DiGraph): + """ + DiGraph representing variable liveness + """ + + def __init__(self, ircfg): + super(DiGraphLiveness, self).__init__() + self.ircfg = ircfg + self.loc_db = ircfg.loc_db + self._blocks = {} + # Add irblocks gen/kill + for node in ircfg.nodes(): + irblock = ircfg.blocks.get(node, None) + if irblock is None: + continue + irblockinfos = IRBlockLivenessInfos(irblock) + self.add_node(irblockinfos.loc_key) + self.blocks[irblockinfos.loc_key] = irblockinfos + for succ in ircfg.successors(node): + self.add_uniq_edge(node, succ) + for pred in ircfg.predecessors(node): + self.add_uniq_edge(pred, node) + + @property + def blocks(self): + return self._blocks + + def init_var_info(self): + """Add ircfg out regs""" + raise NotImplementedError("Abstract method") + + def node2lines(self, node): + """ + Output liveness information in dot format + """ + names = self.loc_db.get_location_names(node) + if not names: + node_name = self.loc_db.pretty_str(node) + else: + node_name = "".join("%s:\n" % name for name in names) + yield self.DotCellDescription( + text="%s" % node_name, + attr={ + 'align': 'center', + 'colspan': 2, + 'bgcolor': 'grey', + } + ) + if node not in self._blocks: + yield [self.DotCellDescription(text="NOT PRESENT", attr={})] + return + + for i, info in enumerate(self._blocks[node].infos): + var_in = "VarIn:" + ", ".join(str(x) for x in info.var_in) + var_out = "VarOut:" + ", ".join(str(x) for x in info.var_out) + + assignmnts = ["%s = %s" % (dst, src) for (dst, src) in viewitems(info.assignblk)] + + if i == 0: + yield self.DotCellDescription( + text=var_in, + attr={ + 'bgcolor': 'green', + } + ) + + for assign in assignmnts: + yield self.DotCellDescription(text=assign, attr={}) + yield self.DotCellDescription( + text=var_out, + attr={ + 'bgcolor': 'green', + } + ) + yield self.DotCellDescription(text="", attr={}) + + def back_propagate_compute(self, block): + """ + Compute the liveness information in the @block. + @block: AssignBlockLivenessInfos instance + """ + infos = block.infos + modified = False + for i in reversed(range(len(infos))): + new_vars = set(infos[i].gen.union(infos[i].var_out.difference(infos[i].kill))) + if infos[i].var_in != new_vars: + modified = True + infos[i].var_in = new_vars + if i > 0 and infos[i - 1].var_out != set(infos[i].var_in): + modified = True + infos[i - 1].var_out = set(infos[i].var_in) + return modified + + def back_propagate_to_parent(self, todo, node, parent): + """ + Back propagate the liveness information from @node to @parent. 
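        Together with back_propagate_compute above, this implements the
        classic backward liveness equations, per assignblock:

            var_in  = gen U (var_out - kill)
            var_out = union of var_in over successors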
+ @node: loc_key of the source node + @parent: loc_key of the node to update + """ + parent_block = self.blocks[parent] + cur_block = self.blocks[node] + if cur_block.infos[0].var_in == parent_block.infos[-1].var_out: + return + var_info = cur_block.infos[0].var_in.union(parent_block.infos[-1].var_out) + parent_block.infos[-1].var_out = var_info + todo.add(parent) + + def compute_liveness(self): + """ + Compute the liveness information for the digraph. + """ + todo = set(self.leaves()) + while todo: + node = todo.pop() + cur_block = self.blocks.get(node, None) + if cur_block is None: + continue + modified = self.back_propagate_compute(cur_block) + if not modified: + continue + # We modified parent in, propagate to parents + for pred in self.predecessors(node): + self.back_propagate_to_parent(todo, node, pred) + return True + + +class DiGraphLivenessIRA(DiGraphLiveness): + """ + DiGraph representing variable liveness for IRA + """ + + def init_var_info(self, lifter): + """Add ircfg out regs""" + + for node in self.leaves(): + irblock = self.ircfg.blocks.get(node, None) + if irblock is None: + continue + var_out = lifter.get_out_regs(irblock) + irblock_liveness = self.blocks[node] + irblock_liveness.infos[-1].var_out = var_out + + +def discard_phi_sources(ircfg, deleted_vars): + """ + Remove phi sources in @ircfg belonging to @deleted_vars set + @ircfg: IRCFG instance in ssa form + @deleted_vars: unused phi sources + """ + for block in list(viewvalues(ircfg.blocks)): + if not block.assignblks: + continue + assignblk = block[0] + todo = {} + modified = False + for dst, src in viewitems(assignblk): + if not src.is_op('Phi'): + todo[dst] = src + continue + srcs = set(expr for expr in src.args if expr not in deleted_vars) + assert(srcs) + if len(srcs) > 1: + todo[dst] = ExprOp('Phi', *srcs) + continue + todo[dst] = srcs.pop() + modified = True + if not modified: + continue + assignblks = list(block) + assignblk = dict(assignblk) + assignblk.update(todo) + assignblk = AssignBlock(assignblk, assignblks[0].instr) + assignblks[0] = assignblk + new_irblock = IRBlock(block.loc_db, block.loc_key, assignblks) + ircfg.blocks[block.loc_key] = new_irblock + return True + + +def get_unreachable_nodes(ircfg, edges_to_del, heads): + """ + Return the unreachable nodes starting from heads and the associated edges to + be deleted. 
+ + @ircfg: IRCFG instance + @edges_to_del: edges already marked as deleted + heads: locations of graph heads + """ + todo = set(heads) + visited_nodes = set() + new_edges_to_del = set() + while todo: + node = todo.pop() + if node in visited_nodes: + continue + visited_nodes.add(node) + for successor in ircfg.successors(node): + if (node, successor) not in edges_to_del: + todo.add(successor) + all_nodes = set(ircfg.nodes()) + nodes_to_del = all_nodes.difference(visited_nodes) + for node in nodes_to_del: + for successor in ircfg.successors(node): + if successor not in nodes_to_del: + # Frontier: link from a deleted node to a living node + new_edges_to_del.add((node, successor)) + return nodes_to_del, new_edges_to_del + + +def update_phi_with_deleted_edges(ircfg, edges_to_del): + """ + Update phi which have a source present in @edges_to_del + @ssa: IRCFG instance in ssa form + @edges_to_del: edges to delete + """ + + + phi_locs_to_srcs = {} + for loc_src, loc_dst in edges_to_del: + phi_locs_to_srcs.setdefault(loc_dst, set()).add(loc_src) + + modified = False + blocks = dict(ircfg.blocks) + for loc_dst, loc_srcs in viewitems(phi_locs_to_srcs): + if loc_dst not in ircfg.blocks: + continue + block = ircfg.blocks[loc_dst] + if not irblock_has_phi(block): + continue + assignblks = list(block) + assignblk = assignblks[0] + out = {} + for dst, phi_sources in viewitems(assignblk): + if not phi_sources.is_op('Phi'): + out[dst] = phi_sources + continue + var_to_parents = get_phi_sources_parent_block( + ircfg, + loc_dst, + phi_sources.args + ) + to_keep = set(phi_sources.args) + for src in phi_sources.args: + parents = var_to_parents[src] + remaining = parents.difference(loc_srcs) + if not remaining: + to_keep.discard(src) + modified = True + assert to_keep + if len(to_keep) == 1: + out[dst] = to_keep.pop() + else: + out[dst] = ExprOp('Phi', *to_keep) + assignblk = AssignBlock(out, assignblks[0].instr) + assignblks[0] = assignblk + new_irblock = IRBlock(block.loc_db, loc_dst, assignblks) + blocks[block.loc_key] = new_irblock + + for loc_key, block in viewitems(blocks): + ircfg.blocks[loc_key] = block + return modified + + +def del_unused_edges(ircfg, heads): + """ + Delete non accessible edges in the @ircfg graph. 
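    An edge is considered dead when the possible values of a block's
    IRDst no longer cover all recorded successors (e.g. a conditional
    branch whose condition has been folded to a constant); phi nodes
    depending on the deleted edges are then fixed via
    update_phi_with_deleted_edges.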
+ @ircfg: IRCFG instance in ssa form + @heads: location of the heads of the graph + """ + + deleted_vars = set() + modified = False + edges_to_del_1 = set() + for node in ircfg.nodes(): + successors = set(ircfg.successors(node)) + block = ircfg.blocks.get(node, None) + if block is None: + continue + dst = block.dst + possible_dsts = set(solution.value for solution in possible_values(dst)) + if not all(dst.is_loc() for dst in possible_dsts): + continue + possible_dsts = set(dst.loc_key for dst in possible_dsts) + if len(possible_dsts) == len(successors): + continue + dsts_to_del = successors.difference(possible_dsts) + for dst in dsts_to_del: + edges_to_del_1.add((node, dst)) + + # Remove edges and update phi accordingly + # Two cases here: + # - edge is directly linked to a phi node + # - edge is indirect linked to a phi node + nodes_to_del, edges_to_del_2 = get_unreachable_nodes(ircfg, edges_to_del_1, heads) + modified |= update_phi_with_deleted_edges(ircfg, edges_to_del_1.union(edges_to_del_2)) + + for src, dst in edges_to_del_1.union(edges_to_del_2): + ircfg.del_edge(src, dst) + for node in nodes_to_del: + if node not in ircfg.blocks: + continue + block = ircfg.blocks[node] + ircfg.del_node(node) + del ircfg.blocks[node] + + for assignblock in block: + for dst in assignblock: + deleted_vars.add(dst) + + if deleted_vars: + modified |= discard_phi_sources(ircfg, deleted_vars) + + return modified + + +class DiGraphLivenessSSA(DiGraphLivenessIRA): + """ + DiGraph representing variable liveness is a SSA graph + """ + def __init__(self, ircfg): + super(DiGraphLivenessSSA, self).__init__(ircfg) + + self.loc_key_to_phi_parents = {} + for irblock in viewvalues(self.blocks): + if not irblock_has_phi(irblock): + continue + out = {} + for sources in viewvalues(irblock[0]): + if not sources.is_op('Phi'): + # Some phi sources may have already been resolved to an + # expression + continue + var_to_parents = get_phi_sources_parent_block(self, irblock.loc_key, sources.args) + for var, var_parents in viewitems(var_to_parents): + out.setdefault(var, set()).update(var_parents) + self.loc_key_to_phi_parents[irblock.loc_key] = out + + def back_propagate_to_parent(self, todo, node, parent): + if parent not in self.blocks: + return + parent_block = self.blocks[parent] + cur_block = self.blocks[node] + irblock = self.ircfg.blocks[node] + if cur_block.infos[0].var_in == parent_block.infos[-1].var_out: + return + var_info = cur_block.infos[0].var_in.union(parent_block.infos[-1].var_out) + + if irblock_has_phi(irblock): + # Remove phi special case + out = set() + phi_sources = self.loc_key_to_phi_parents[irblock.loc_key] + for var in var_info: + if var not in phi_sources: + out.add(var) + continue + if parent in phi_sources[var]: + out.add(var) + var_info = out + + parent_block.infos[-1].var_out = var_info + todo.add(parent) + + +def get_phi_sources(phi_src, phi_dsts, ids_to_src): + """ + Return False if the @phi_src has more than one non-phi source + Else, return its source + @ids_to_src: Dictionary linking phi source to its definition + """ + true_values = set() + for src in phi_src.args: + if src in phi_dsts: + # Source is phi dst => skip + continue + true_src = ids_to_src[src] + if true_src in phi_dsts: + # Source is phi dst => skip + continue + # Check if src is not also a phi + if true_src.is_op('Phi'): + phi_dsts.add(src) + true_src = get_phi_sources(true_src, phi_dsts, ids_to_src) + if true_src is False: + return False + if true_src is True: + continue + true_values.add(true_src) + if len(true_values) != 
1: + return False + if not true_values: + return True + if len(true_values) != 1: + return False + true_value = true_values.pop() + return true_value + + +class DelDummyPhi(object): + """ + Del dummy phi + Find nodes which are in the same equivalence class and replace phi nodes by + the class representative. + """ + + def src_gen_phi_node_srcs(self, equivalence_graph): + for node in equivalence_graph.nodes(): + if not node.is_op("Phi"): + continue + phi_successors = equivalence_graph.successors(node) + for head in phi_successors: + # Walk from head to find if we have a phi merging node + known = set([node]) + todo = set([head]) + done = set() + while todo: + node = todo.pop() + if node in done: + continue + + known.add(node) + is_ok = True + for parent in equivalence_graph.predecessors(node): + if parent not in known: + is_ok = False + break + if not is_ok: + continue + if node.is_op("Phi"): + successors = equivalence_graph.successors(node) + phi_node = successors.pop() + return set([phi_node]), phi_node, head, equivalence_graph + done.add(node) + for successor in equivalence_graph.successors(node): + todo.add(successor) + return None + + def get_equivalence_class(self, node, ids_to_src): + todo = set([node]) + done = set() + defined = set() + equivalence = set() + src_to_dst = {} + equivalence_graph = DiGraph() + while todo: + dst = todo.pop() + if dst in done: + continue + done.add(dst) + equivalence.add(dst) + src = ids_to_src.get(dst) + if src is None: + # Node is not defined + continue + src_to_dst[src] = dst + defined.add(dst) + if src.is_id(): + equivalence_graph.add_uniq_edge(src, dst) + todo.add(src) + elif src.is_op('Phi'): + equivalence_graph.add_uniq_edge(src, dst) + for arg in src.args: + assert arg.is_id() + equivalence_graph.add_uniq_edge(arg, src) + todo.add(arg) + else: + if src.is_mem() or (src.is_op() and src.op.startswith("call")): + if src in equivalence_graph.nodes(): + return None + equivalence_graph.add_uniq_edge(src, dst) + equivalence.add(src) + + if len(equivalence_graph.heads()) == 0: + raise RuntimeError("Inconsistent graph") + elif len(equivalence_graph.heads()) == 1: + # Every nodes in the equivalence graph may be equivalent to the root + head = equivalence_graph.heads().pop() + successors = equivalence_graph.successors(head) + if len(successors) == 1: + # If successor is an id + successor = successors.pop() + if successor.is_id(): + nodes = equivalence_graph.nodes() + nodes.discard(head) + nodes.discard(successor) + nodes = [node for node in nodes if node.is_id()] + return nodes, successor, head, equivalence_graph + else: + # Walk from head to find if we have a phi merging node + known = set() + todo = set([head]) + done = set() + while todo: + node = todo.pop() + if node in done: + continue + known.add(node) + is_ok = True + for parent in equivalence_graph.predecessors(node): + if parent not in known: + is_ok = False + break + if not is_ok: + continue + if node.is_op("Phi"): + successors = equivalence_graph.successors(node) + assert len(successors) == 1 + phi_node = successors.pop() + return set([phi_node]), phi_node, head, equivalence_graph + done.add(node) + for successor in equivalence_graph.successors(node): + todo.add(successor) + + return self.src_gen_phi_node_srcs(equivalence_graph) + + def del_dummy_phi(self, ssa, head): + ids_to_src = {} + def_to_loc = {} + for block in viewvalues(ssa.graph.blocks): + for index, assignblock in enumerate(block): + for dst, src in viewitems(assignblock): + if not dst.is_id(): + continue + ids_to_src[dst] = src + 
def_to_loc[dst] = block.loc_key
+
+        modified = False
+        for loc_key in ssa.graph.blocks.keys():
+            block = ssa.graph.blocks[loc_key]
+            if not irblock_has_phi(block):
+                continue
+            assignblk = block[0]
+            for dst, phi_src in viewitems(assignblk):
+                assert phi_src.is_op('Phi')
+                result = self.get_equivalence_class(dst, ids_to_src)
+                if result is None:
+                    continue
+                defined, node, true_value, equivalence_graph = result
+                if expr_has_mem(true_value):
+                    # Don't propagate ExprMem
+                    continue
+                if true_value.is_op() and true_value.op.startswith("call"):
+                    # Don't propagate call
+                    continue
+                # We have an equivalence of nodes
+                to_del = set(defined)
+                # Remove all implicated phis
+                for dst in to_del:
+                    loc_key = def_to_loc[dst]
+                    block = ssa.graph.blocks[loc_key]
+
+                    assignblk = block[0]
+                    fixed_phis = {}
+                    for old_dst, old_phi_src in viewitems(assignblk):
+                        if old_dst in defined:
+                            continue
+                        fixed_phis[old_dst] = old_phi_src
+
+                    assignblks = list(block)
+                    assignblks[0] = AssignBlock(fixed_phis, assignblk.instr)
+                    assignblks[1:1] = [AssignBlock({dst: true_value}, assignblk.instr)]
+                    new_irblock = IRBlock(block.loc_db, block.loc_key, assignblks)
+                    ssa.graph.blocks[loc_key] = new_irblock
+                    modified = True
+        return modified
+
+
+def replace_expr_from_bottom(expr_orig, dct):
+    def replace(expr):
+        if expr in dct:
+            return dct[expr]
+        return expr
+    visitor = ExprVisitorCallbackBottomToTop(replace)
+    return visitor.visit(expr_orig)
+
+
+def is_mem_sub_part(needle, mem):
+    """
+    If @needle is a sub part of @mem, return the offset of @needle in @mem
+    Else, return False
+    @needle: ExprMem
+    @mem: ExprMem
+    """
+    ptr_base_a, ptr_offset_a = get_expr_base_offset(needle.ptr)
+    ptr_base_b, ptr_offset_b = get_expr_base_offset(mem.ptr)
+    if ptr_base_a != ptr_base_b:
+        return False
+    # Test that the sub part starts within mem
+    if not (ptr_offset_b <= ptr_offset_a < ptr_offset_b + mem.size // 8):
+        return False
+    # Test that the sub part ends within mem
+    if not (ptr_offset_a + needle.size // 8 <= ptr_offset_b + mem.size // 8):
+        return False
+    return ptr_offset_a - ptr_offset_b
+
+
+class UnionFind(object):
+    """
+    Implementation of the UnionFind structure
+    __classes: a list of sets of equivalent elements
+    node_to_class: Dictionary linking an element to its equivalence class
+    order: Dictionary linking an element to its weight
+
+    The order attribute is used to allow the selection of a representative
+    element of an equivalence class
+    """
+
+    def __init__(self):
+        self.index = 0
+        self.__classes = []
+        self.node_to_class = {}
+        self.order = dict()
+
+    def copy(self):
+        """
+        Return a copy of the object
+        """
+        unionfind = UnionFind()
+        unionfind.index = self.index
+        unionfind.__classes = [set(known_class) for known_class in self.__classes]
+        node_to_class = {}
+        for class_eq in unionfind.__classes:
+            for node in class_eq:
+                node_to_class[node] = class_eq
+        unionfind.node_to_class = node_to_class
+        unionfind.order = dict(self.order)
+        return unionfind
+
+    def replace_node(self, old_node, new_node):
+        """
+        Replace the @old_node by the @new_node
+        """
+        classes = self.get_classes()
+
+        new_classes = []
+        replace_dct = {old_node: new_node}
+        for eq_class in classes:
+            new_class = set()
+            for node in eq_class:
+                new_class.add(replace_expr_from_bottom(node, replace_dct))
+            new_classes.append(new_class)
+
+        node_to_class = {}
+        for class_eq in new_classes:
+            for node in class_eq:
+                node_to_class[node] = class_eq
+        self.__classes = new_classes
+        self.node_to_class = node_to_class
+        new_order = dict()
+        for node, index in 
self.order.items(): + new_node = replace_expr_from_bottom(node, replace_dct) + new_order[new_node] = index + self.order = new_order + + def get_classes(self): + """ + Return a list of the equivalent classes + """ + classes = [] + for class_tmp in self.__classes: + classes.append(set(class_tmp)) + return classes + + def nodes(self): + for known_class in self.__classes: + for node in known_class: + yield node + + def __eq__(self, other): + if self is other: + return True + if self.__class__ is not other.__class__: + return False + + return Counter(frozenset(known_class) for known_class in self.__classes) == Counter(frozenset(known_class) for known_class in other.__classes) + + def __ne__(self, other): + # required Python 2.7.14 + return not self == other + + def __str__(self): + components = self.__classes + out = ['UnionFind<'] + for component in components: + out.append("\t" + (", ".join([str(node) for node in component]))) + out.append('>') + return "\n".join(out) + + def add_equivalence(self, node_a, node_b): + """ + Add the new equivalence @node_a == @node_b + @node_a is equivalent to @node_b, but @node_b is more representative + than @node_a + """ + if node_b not in self.order: + self.order[node_b] = self.index + self.index += 1 + # As node_a is destination, we always replace its index + self.order[node_a] = self.index + self.index += 1 + + if node_a not in self.node_to_class and node_b not in self.node_to_class: + new_class = set([node_a, node_b]) + self.node_to_class[node_a] = new_class + self.node_to_class[node_b] = new_class + self.__classes.append(new_class) + elif node_a in self.node_to_class and node_b not in self.node_to_class: + known_class = self.node_to_class[node_a] + known_class.add(node_b) + self.node_to_class[node_b] = known_class + elif node_a not in self.node_to_class and node_b in self.node_to_class: + known_class = self.node_to_class[node_b] + known_class.add(node_a) + self.node_to_class[node_a] = known_class + else: + raise RuntimeError("Two nodes cannot be in two classes") + + def _get_master(self, node): + if node not in self.node_to_class: + return None + known_class = self.node_to_class[node] + best_node = node + for node in known_class: + if self.order[node] < self.order[best_node]: + best_node = node + return best_node + + def get_master(self, node): + """ + Return the representative element of the equivalence class containing + @node + @node: ExprMem or ExprId + """ + if not node.is_mem(): + return self._get_master(node) + if node in self.node_to_class: + # Full expr mem is known + return self._get_master(node) + # Test if mem is sub part of known node + for expr in self.node_to_class: + if not expr.is_mem(): + continue + ret = is_mem_sub_part(node, expr) + if ret is False: + continue + master = self._get_master(expr) + master = master[ret * 8 : ret * 8 + node.size] + return master + + return self._get_master(node) + + + def del_element(self, node): + """ + Remove @node for the equivalence classes + """ + assert node in self.node_to_class + known_class = self.node_to_class[node] + known_class.discard(node) + del(self.node_to_class[node]) + del(self.order[node]) + + def del_get_new_master(self, node): + """ + Remove @node for the equivalence classes and return it's representative + equivalent element + @node: Element to delete + """ + if node not in self.node_to_class: + return None + known_class = self.node_to_class[node] + known_class.discard(node) + del(self.node_to_class[node]) + del(self.order[node]) + + if not known_class: + return None + best_node = 
list(known_class)[0] + for node in known_class: + if self.order[node] < self.order[best_node]: + best_node = node + return best_node + +class ExprToGraph(ExprWalk): + """ + Transform an Expression into a tree and add link nodes to an existing tree + """ + def __init__(self, graph): + super(ExprToGraph, self).__init__(self.link_nodes) + self.graph = graph + + def link_nodes(self, expr, *args, **kwargs): + """ + Transform an Expression @expr into a tree and add link nodes to the + current tree + @expr: Expression + """ + if expr in self.graph.nodes(): + return None + self.graph.add_node(expr) + if expr.is_mem(): + self.graph.add_uniq_edge(expr, expr.ptr) + elif expr.is_slice(): + self.graph.add_uniq_edge(expr, expr.arg) + elif expr.is_cond(): + self.graph.add_uniq_edge(expr, expr.cond) + self.graph.add_uniq_edge(expr, expr.src1) + self.graph.add_uniq_edge(expr, expr.src2) + elif expr.is_compose(): + for arg in expr.args: + self.graph.add_uniq_edge(expr, arg) + elif expr.is_op(): + for arg in expr.args: + self.graph.add_uniq_edge(expr, arg) + return None + +class State(object): + """ + Object representing the state of a program at a given point + The state is represented using equivalence classes + + Each assignment can create/destroy equivalence classes. Interferences + between expression is computed using `may_interfer` function + """ + + def __init__(self): + self.equivalence_classes = UnionFind() + self.undefined = set() + + def __str__(self): + return "{0.equivalence_classes}\n{0.undefined}".format(self) + + def copy(self): + state = self.__class__() + state.equivalence_classes = self.equivalence_classes.copy() + state.undefined = self.undefined.copy() + return state + + def __eq__(self, other): + if self is other: + return True + if self.__class__ is not other.__class__: + return False + return ( + set(self.equivalence_classes.nodes()) == set(other.equivalence_classes.nodes()) and + sorted(self.equivalence_classes.edges()) == sorted(other.equivalence_classes.edges()) and + self.undefined == other.undefined + ) + + def __ne__(self, other): + # required Python 2.7.14 + return not self == other + + def may_interfer(self, dsts, src): + """ + Return True if @src may interfere with expressions in @dsts + @dsts: Set of Expressions + @src: expression to test + """ + + srcs = src.get_r() + for src in srcs: + for dst in dsts: + if dst in src: + return True + if dst.is_mem() and src.is_mem(): + dst_base, dst_offset = get_expr_base_offset(dst.ptr) + src_base, src_offset = get_expr_base_offset(src.ptr) + if dst_base != src_base: + return True + dst_size = dst.size // 8 + src_size = src.size // 8 + # Special case: + # @32[ESP + 0xFFFFFFFE], @32[ESP] + # Both memories alias + if dst_offset + dst_size <= int(dst_base.mask) + 1: + # @32[ESP + 0xFFFFFFFC] => [0xFFFFFFFC, 0xFFFFFFFF] + interval1 = interval([(dst_offset, dst_offset + dst.size // 8 - 1)]) + else: + # @32[ESP + 0xFFFFFFFE] => [0x0, 0x1] U [0xFFFFFFFE, 0xFFFFFFFF] + interval1 = interval([(dst_offset, int(dst_base.mask))]) + interval1 += interval([(0, dst_size - (int(dst_base.mask) + 1 - dst_offset) - 1 )]) + if src_offset + src_size <= int(src_base.mask) + 1: + # @32[ESP + 0xFFFFFFFC] => [0xFFFFFFFC, 0xFFFFFFFF] + interval2 = interval([(src_offset, src_offset + src.size // 8 - 1)]) + else: + # @32[ESP + 0xFFFFFFFE] => [0x0, 0x1] U [0xFFFFFFFE, 0xFFFFFFFF] + interval2 = interval([(src_offset, int(src_base.mask))]) + interval2 += interval([(0, src_size - (int(src_base.mask) + 1 - src_offset) - 1)]) + if (interval1 & interval2).empty: + continue 
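+                    # The byte intervals overlap (modulo the address mask):
+                    # the two accesses may alias, hence they interfere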
+                    return True
+        return False
+
+    def _get_representative_expr(self, expr):
+        representative = self.equivalence_classes.get_master(expr)
+        if representative is None:
+            return expr
+        return representative
+
+    def get_representative_expr(self, expr):
+        """
+        Replace each sub expression of @expr by its representative element
+        @expr: Expression to analyse
+        """
+        new_expr = expr.visit(self._get_representative_expr)
+        return new_expr
+
+    def propagation_allowed(self, expr):
+        """
+        Return True if @expr can be propagated
+        Don't propagate:
+        - Phi nodes
+        - call_func_ret / call_func_stack operands
+        """
+
+        if (
+                expr.is_op('Phi') or
+                (expr.is_op() and expr.op.startswith("call_func"))
+        ):
+            return False
+        return True
+
+    def eval_assignblock(self, assignblock):
+        """
+        Evaluate the @assignblock on the current state
+        @assignblock: AssignBlock instance
+        """
+
+        out = dict(assignblock.items())
+        new_out = dict()
+        # Replace sub expressions by their equivalence class representative
+        for dst, src in out.items():
+            if src.is_op('Phi'):
+                # Don't replace in phi
+                new_src = src
+            else:
+                new_src = self.get_representative_expr(src)
+            if dst.is_mem():
+                new_ptr = self.get_representative_expr(dst.ptr)
+                new_dst = ExprMem(new_ptr, dst.size)
+            else:
+                new_dst = dst
+            new_dst = expr_simp(new_dst)
+            new_src = expr_simp(new_src)
+            new_out[new_dst] = new_src
+
+        # For each destination, update (or delete) dependent nodes according
+        # to the equivalence classes
+        classes = self.equivalence_classes
+
+        for dst in new_out:
+
+            replacement = classes.del_get_new_master(dst)
+            if replacement is None:
+                to_del = set([dst])
+                to_replace = {}
+            else:
+                to_del = set()
+                to_replace = {dst: replacement}
+
+            graph = DiGraph()
+            # Build an expression graph linking all classes
+            has_parents = False
+            for node in classes.nodes():
+                if dst in node:
+                    # Only dependent nodes are interesting here
+                    has_parents = True
+                    expr_to_graph = ExprToGraph(graph)
+                    expr_to_graph.visit(node)
+
+            if not has_parents:
+                continue
+
+            todo = graph.leaves()
+            done = set()
+
+            while todo:
+                node = todo.pop(0)
+                if node in done:
+                    continue
+                # If at least one son is not done, redo later
+                if [son for son in graph.successors(node) if son not in done]:
+                    todo.append(node)
+                    continue
+                done.add(node)
+
+                # If at least one son cannot be replaced (deleted), our last
+                # chance is to have an equivalence
+                if any(son in to_del for son in graph.successors(node)):
+                    # One son has been deleted!
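+                    # (its equivalence class vanished, so the node can no
+                    # longer be rebuilt from its sons)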
+                    # Try to find a replacement of the whole expression
+                    replacement = classes.del_get_new_master(node)
+                    if replacement is None:
+                        to_del.add(node)
+                        for predecessor in graph.predecessors(node):
+                            if predecessor not in todo:
+                                todo.append(predecessor)
+                        continue
+                    else:
+                        to_replace[node] = replacement
+                        # Continue with the replacement
+
+                # Every son is live or has been replaced
+                new_node = node.replace_expr(to_replace)
+
+                if new_node == node:
+                    # The node is untouched (e.g. a leaf node)
+                    for predecessor in graph.predecessors(node):
+                        if predecessor not in todo:
+                            todo.append(predecessor)
+                    continue
+
+                # The node has been modified, update the equivalence classes
+                classes.replace_node(node, new_node)
+                to_replace[node] = new_node
+
+                for predecessor in graph.predecessors(node):
+                    if predecessor not in todo:
+                        todo.append(predecessor)
+
+                continue
+
+        new_assignblk = AssignBlock(new_out, assignblock.instr)
+        dsts = new_out.keys()
+
+        # Remove interfering known classes
+        to_del = set()
+        for node in list(classes.nodes()):
+            if self.may_interfer(dsts, node):
+                # Interferes with a known equivalence class
+                self.equivalence_classes.del_element(node)
+                if node.is_id() or node.is_mem():
+                    self.undefined.add(node)
+
+        # Update equivalence classes
+        for dst, src in new_out.items():
+            # Delete equivalence classes interfering with dst
+            to_del = set()
+            classes = self.equivalence_classes
+            for node in classes.nodes():
+                if dst in node:
+                    to_del.add(node)
+            for node in to_del:
+                self.equivalence_classes.del_element(node)
+                if node.is_id() or node.is_mem():
+                    self.undefined.add(node)
+
+            # Don't create an equivalence if the source interferes with the
+            # new destinations
+            if self.may_interfer(dsts, src):
+                if dst in self.equivalence_classes.nodes():
+                    self.equivalence_classes.del_element(dst)
+                    if dst.is_id() or dst.is_mem():
+                        self.undefined.add(dst)
+                continue
+
+            if not self.propagation_allowed(src):
+                continue
+
+            self.undefined.discard(dst)
+            if dst in self.equivalence_classes.nodes():
+                self.equivalence_classes.del_element(dst)
+            self.equivalence_classes.add_equivalence(dst, src)
+
+        return new_assignblk
+
+    def merge(self, other):
+        """
+        Merge the current state with @other
+        Merge rules:
+        - if two nodes are equal in both states => same equivalence class
+        - if a node's value differs or is missing in the other state => undefined
+        @other: State instance
+        """
+        classes1 = self.equivalence_classes
+        classes2 = other.equivalence_classes
+
+        undefined = set(node for node in self.undefined if node.is_id() or node.is_mem())
+        undefined.update(set(node for node in other.undefined if node.is_id() or node.is_mem()))
+        # Should we compute interference between srcs and undefined?
+        # No => they should already interfere in the other state
+        components1 = classes1.get_classes()
+        components2 = classes2.get_classes()
+
+        node_to_component2 = {}
+        for component in components2:
+            for node in component:
+                node_to_component2[node] = component
+
+        # Compute the intersection of the equivalence classes of both states
+        out = []
+        nodes_ok = set()
+        while components1:
+            component1 = components1.pop()
+            for node in component1:
+                if node in undefined:
+                    continue
+                component2 = node_to_component2.get(node)
+                if component2 is None:
+                    if node.is_id() or node.is_mem():
+                        assert node not in nodes_ok
+                        undefined.add(node)
+                    continue
+                if node not in component2:
+                    continue
+                # Found two classes containing the node
+                common = component1.intersection(component2)
+                if len(common) == 1:
+                    # The intersection contains only one node => undefine it
+                    if node.is_id() or node.is_mem():
+                        assert node not in nodes_ok
+                        undefined.add(node)
+                    component2.discard(common.pop())
+                    continue
+                if common:
+                    # The intersection contains multiple nodes
+                    # Here, common nodes don't interfere with any undefined
+                    nodes_ok.update(common)
+                    out.append(common)
+                    diff = component1.difference(common)
+                    if diff:
+                        components1.append(diff)
+                    component2.difference_update(common)
+                break
+
+        # Discard the remaining components2 elements
+        for component in components2:
+            for node in component:
+                if node.is_id() or node.is_mem():
+                    assert node not in nodes_ok
+                    undefined.add(node)
+
+        all_nodes = set()
+        for common in out:
+            all_nodes.update(common)
+
+        new_order = dict(
+            (node, index) for (node, index) in classes1.order.items()
+            if node in all_nodes
+        )
+
+        unionfind = UnionFind()
+        new_classes = []
+        global_max_index = 0
+        for common in out:
+            min_index = None
+            master = None
+            for node in common:
+                index = new_order[node]
+                global_max_index = max(index, global_max_index)
+                if min_index is None or min_index > index:
+                    min_index = index
+                    master = node
+            for node in common:
+                if node == master:
+                    continue
+                unionfind.add_equivalence(node, master)
+
+        unionfind.index = global_max_index
+        unionfind.order = new_order
+        state = self.__class__()
+        state.equivalence_classes = unionfind
+        state.undefined = undefined
+
+        return state
+
+
+class PropagateExpressions(object):
+    """
+    Propagate expressions
+
+    The algorithm propagates equivalence class expressions from the entry
+    point. During the analysis, source nodes are replaced by their
+    equivalence class representative. Equivalence classes can be modified
+    during the analysis due to memory aliasing.
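+
+    Typical invocation (an illustrative sketch; `ssa` stands for an SSA form
+    exposing `.ircfg`, `head` for its entry LocKey):
+
+    modified = PropagateExpressions().propagate(ssa, head)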
+ + For example: + B = A+1 + C = A + A = 6 + D = [B] + + Will result in: + B = A+1 + C = A + A = 6 + D = [C+1] + """ + + @staticmethod + def new_state(): + return State() + + def merge_prev_states(self, ircfg, states, loc_key): + """ + Merge predecessors states of irblock at location @loc_key + @ircfg: IRCfg instance + @states: Dictionary linking locations to state + @loc_key: location of the current irblock + """ + + prev_states = [] + for predecessor in ircfg.predecessors(loc_key): + prev_states.append((predecessor, states[predecessor])) + + filtered_prev_states = [] + for (_, prev_state) in prev_states: + if prev_state is not None: + filtered_prev_states.append(prev_state) + + prev_states = filtered_prev_states + if not prev_states: + state = self.new_state() + elif len(prev_states) == 1: + state = prev_states[0].copy() + else: + while prev_states: + state = prev_states.pop() + if state is not None: + break + for prev_state in prev_states: + state = state.merge(prev_state) + + return state + + def update_state(self, irblock, state): + """ + Propagate the @state through the @irblock + @irblock: IRBlock instance + @state: State instance + """ + new_assignblocks = [] + modified = False + + for assignblock in irblock: + if not assignblock.items(): + continue + new_assignblk = state.eval_assignblock(assignblock) + new_assignblocks.append(new_assignblk) + if new_assignblk != assignblock: + modified = True + + new_irblock = IRBlock(irblock.loc_db, irblock.loc_key, new_assignblocks) + + return new_irblock, modified + + def propagate(self, ssa, head, max_expr_depth=None): + """ + Apply algorithm on the @ssa graph + """ + ircfg = ssa.ircfg + self.loc_db = ircfg.loc_db + irblocks = ssa.ircfg.blocks + states = {} + for loc_key, irblock in irblocks.items(): + states[loc_key] = None + + todo = deque([head]) + while todo: + loc_key = todo.popleft() + irblock = irblocks.get(loc_key) + if irblock is None: + continue + + state_orig = states[loc_key] + state = self.merge_prev_states(ircfg, states, loc_key) + state = state.copy() + + new_irblock, modified_irblock = self.update_state(irblock, state) + if state_orig is not None: + # Merge current and previous state + state = state.merge(state_orig) + if (state.equivalence_classes == state_orig.equivalence_classes and + state.undefined == state_orig.undefined + ): + continue + + states[loc_key] = state + # Propagate to sons + for successor in ircfg.successors(loc_key): + todo.append(successor) + + # Update blocks + todo = set(loc_key for loc_key in irblocks) + modified = False + while todo: + loc_key = todo.pop() + irblock = irblocks.get(loc_key) + if irblock is None: + continue + + state = self.merge_prev_states(ircfg, states, loc_key) + new_irblock, modified_irblock = self.update_state(irblock, state) + modified |= modified_irblock + irblocks[new_irblock.loc_key] = new_irblock + + return modified diff --git a/src/miasm/analysis/debugging.py b/src/miasm/analysis/debugging.py new file mode 100644 index 00000000..d5f59d49 --- /dev/null +++ b/src/miasm/analysis/debugging.py @@ -0,0 +1,557 @@ +from __future__ import print_function +from builtins import map +from builtins import range +import cmd +from future.utils import viewitems + +from miasm.core.utils import hexdump +from miasm.core.interval import interval +import miasm.jitter.csts as csts +from miasm.jitter.jitload import ExceptionHandle + + +class DebugBreakpoint(object): + + "Debug Breakpoint parent class" + pass + + +class DebugBreakpointSoft(DebugBreakpoint): + + "Stand for software breakpoint" + + def 
__init__(self, addr): + self.addr = addr + + def __str__(self): + return "Soft BP @0x%08x" % self.addr + + +class DebugBreakpointTerminate(DebugBreakpoint): + "Stand for an execution termination" + + def __init__(self, status): + self.status = status + + def __str__(self): + return "Terminate with %s" % self.status + + +class DebugBreakpointMemory(DebugBreakpoint): + + "Stand for memory breakpoint" + + type2str = {csts.BREAKPOINT_READ: "R", + csts.BREAKPOINT_WRITE: "W"} + + def __init__(self, addr, size, access_type): + self.addr = addr + self.access_type = access_type + self.size = size + + def __str__(self): + bp_type = "" + for k, v in viewitems(self.type2str): + if k & self.access_type != 0: + bp_type += v + return "Memory BP @0x%08x, Size 0x%08x, Type %s" % ( + self.addr, + self.size, + bp_type + ) + + @classmethod + def get_access_type(cls, read=False, write=False): + value = 0 + for k, v in viewitems(cls.type2str): + if v == "R" and read is True: + value += k + if v == "W" and write is True: + value += k + return value + + +class Debugguer(object): + + "Debugguer linked with a Jitter instance" + + def __init__(self, myjit): + "myjit : jitter instance" + self.myjit = myjit + self.bp_list = [] # DebugBreakpointSoft list + self.mem_bp_list = [] # DebugBreakpointMemory list + self.mem_watched = [] # Memory areas watched + self.init_memory_breakpoint() + + def init_run(self, addr): + self.myjit.init_run(addr) + + def add_breakpoint(self, addr): + "Add bp @addr" + bp = DebugBreakpointSoft(addr) + func = lambda x: bp + bp.func = func + self.bp_list.append(bp) + self.myjit.add_breakpoint(addr, func) + + def init_memory_breakpoint(self): + "Set exception handler on EXCEPT_BREAKPOINT_MEMORY" + def exception_memory_breakpoint(jitter): + "Stop the execution and return an identifier" + return ExceptionHandle.memoryBreakpoint() + + self.myjit.add_exception_handler(csts.EXCEPT_BREAKPOINT_MEMORY, + exception_memory_breakpoint) + + + def add_memory_breakpoint(self, addr, size, read=False, write=False): + "add mem bp @[addr, addr + size], on read/write/both" + access_type = DebugBreakpointMemory.get_access_type(read=read, + write=write) + dbm = DebugBreakpointMemory(addr, size, access_type) + self.mem_bp_list.append(dbm) + self.myjit.vm.add_memory_breakpoint(addr, size, access_type) + + def remove_breakpoint(self, dbs): + "remove the DebugBreakpointSoft instance" + self.bp_list.remove(dbs) + self.myjit.remove_breakpoints_by_callback(dbs.func) + + def remove_breakpoint_by_addr(self, addr): + "remove breakpoints @ addr" + for bp in self.get_breakpoint_by_addr(addr): + self.remove_breakpoint(bp) + + def remove_memory_breakpoint(self, dbm): + "remove the DebugBreakpointMemory instance" + self.mem_bp_list.remove(dbm) + self.myjit.vm.remove_memory_breakpoint(dbm.addr, dbm.access_type) + + def remove_memory_breakpoint_by_addr_access(self, addr, read=False, + write=False): + "remove breakpoints @ addr" + access_type = DebugBreakpointMemory.get_access_type(read=read, + write=write) + for bp in self.mem_bp_list: + if bp.addr == addr and bp.access_type == access_type: + self.remove_memory_breakpoint(bp) + + def get_breakpoint_by_addr(self, addr): + ret = [] + for dbgsoft in self.bp_list: + if dbgsoft.addr == addr: + ret.append(dbgsoft) + return ret + + def get_breakpoints(self): + return self.bp_list + + def active_trace(self, mn=None, regs=None, newbloc=None): + if mn is not None: + self.myjit.jit.log_mn = mn + if regs is not None: + self.myjit.jit.log_regs = regs + if newbloc is not None: + 
self.myjit.jit.log_newbloc = newbloc + + def handle_exception(self, res): + if not res: + # A breakpoint has stopped the execution + return DebugBreakpointTerminate(res) + + if isinstance(res, DebugBreakpointSoft): + print("Breakpoint reached @0x%08x" % res.addr) + elif isinstance(res, ExceptionHandle): + if res == ExceptionHandle.memoryBreakpoint(): + print("Memory breakpoint reached @0x%08x" % self.myjit.pc) + + memory_read = self.myjit.vm.get_memory_read() + if len(memory_read) > 0: + print("Read:") + for start_address, end_address in memory_read: + print("- from 0x%08x to 0x%08x" % (start_address, end_address)) + memory_write = self.myjit.vm.get_memory_write() + if len(memory_write) > 0: + print("Write:") + for start_address, end_address in memory_write: + print("- from 0x%08x to 0x%08x" % (start_address, end_address)) + + # Remove flag + except_flag = self.myjit.vm.get_exception() + self.myjit.vm.set_exception(except_flag ^ res.except_flag) + # Clean memory access data + self.myjit.vm.reset_memory_access() + else: + raise NotImplementedError("Unknown Except") + else: + raise NotImplementedError("type res") + + # Repropagate res + return res + + def step(self): + "Step in jit" + + self.myjit.jit.set_options(jit_maxline=1) + # Reset all jitted blocks + self.myjit.jit.clear_jitted_blocks() + + res = self.myjit.continue_run(step=True) + self.handle_exception(res) + + self.myjit.jit.set_options(jit_maxline=50) + self.on_step() + + return res + + def run(self): + status = self.myjit.continue_run() + return self.handle_exception(status) + + def get_mem(self, addr, size=0xF): + "hexdump @addr, size" + + hexdump(self.myjit.vm.get_mem(addr, size)) + + def get_mem_raw(self, addr, size=0xF): + "hexdump @addr, size" + return self.myjit.vm.get_mem(addr, size) + + def watch_mem(self, addr, size=0xF): + self.mem_watched.append((addr, size)) + + def on_step(self): + for addr, size in self.mem_watched: + print("@0x%08x:" % addr) + self.get_mem(addr, size) + + def get_reg_value(self, reg_name): + return getattr(self.myjit.cpu, reg_name) + + def set_reg_value(self, reg_name, value): + + # Handle PC case + if reg_name == self.myjit.lifter.pc.name: + self.init_run(value) + + setattr(self.myjit.cpu, reg_name, value) + + def get_gpreg_all(self): + "Return general purposes registers" + return self.myjit.cpu.get_gpreg() + + +class DebugCmd(cmd.Cmd, object): + + "CommandLineInterpreter for Debugguer instance" + + color_g = '\033[92m' + color_e = '\033[0m' + color_b = '\033[94m' + color_r = '\033[91m' + + intro = color_g + "=== Miasm2 Debugging shell ===\nIf you need help, " + intro += "type 'help' or '?'" + color_e + prompt = color_b + "$> " + color_e + + def __init__(self, dbg): + "dbg : Debugguer" + self.dbg = dbg + super(DebugCmd, self).__init__() + + # Debug methods + + def print_breakpoints(self): + bp_list = self.dbg.bp_list + if len(bp_list) == 0: + print("No breakpoints.") + else: + for i, b in enumerate(bp_list): + print("%d\t0x%08x" % (i, b.addr)) + + def print_memory_breakpoints(self): + bp_list = self.dbg.mem_bp_list + if len(bp_list) == 0: + print("No memory breakpoints.") + else: + for _, bp in enumerate(bp_list): + print(str(bp)) + + def print_watchmems(self): + watch_list = self.dbg.mem_watched + if len(watch_list) == 0: + print("No memory watchpoints.") + else: + print("Num\tAddress \tSize") + for i, w in enumerate(watch_list): + addr, size = w + print("%d\t0x%08x\t0x%08x" % (i, addr, size)) + + def print_registers(self): + regs = self.dbg.get_gpreg_all() + + # Display settings + title1 = 
"Registers" + title2 = "Values" + max_name_len = max(map(len, list(regs) + [title1])) + + # Print value table + s = "%s%s | %s" % ( + title1, " " * (max_name_len - len(title1)), title2) + print(s) + print("-" * len(s)) + for name, value in sorted(viewitems(regs), key=lambda x: x[0]): + print( + "%s%s | %s" % ( + name, + " " * (max_name_len - len(name)), + hex(value).replace("L", "") + ) + ) + + def add_breakpoints(self, bp_addr): + for addr in bp_addr: + addr = int(addr, 0) + + good = True + for i, dbg_obj in enumerate(self.dbg.bp_list): + if dbg_obj.addr == addr: + good = False + break + if good is False: + print("Breakpoint 0x%08x already set (%d)" % (addr, i)) + else: + l = len(self.dbg.bp_list) + self.dbg.add_breakpoint(addr) + print("Breakpoint 0x%08x successfully added ! (%d)" % (addr, l)) + + display_mode = { + "mn": None, + "regs": None, + "newbloc": None + } + + def update_display_mode(self): + self.display_mode = { + "mn": self.dbg.myjit.jit.log_mn, + "regs": self.dbg.myjit.jit.log_regs, + "newbloc": self.dbg.myjit.jit.log_newbloc + } + + # Command line methods + def print_warning(self, s): + print(self.color_r + s + self.color_e) + + def onecmd(self, line): + cmd_translate = { + "h": "help", + "q": "exit", + "e": "exit", + "!": "exec", + "r": "run", + "i": "info", + "b": "breakpoint", + "m": "memory_breakpoint", + "s": "step", + "d": "dump" + } + + if len(line) >= 2 and \ + line[1] == " " and \ + line[:1] in cmd_translate: + line = cmd_translate[line[:1]] + line[1:] + + if len(line) == 1 and line in cmd_translate: + line = cmd_translate[line] + + r = super(DebugCmd, self).onecmd(line) + return r + + def can_exit(self): + return True + + def do_display(self, arg): + if arg == "": + self.help_display() + return + + args = arg.split(" ") + if args[-1].lower() not in ["on", "off"]: + self.print_warning("[!] %s not in 'on' / 'off'" % args[-1]) + return + mode = args[-1].lower() == "on" + d = {} + for a in args[:-1]: + d[a] = mode + self.dbg.active_trace(**d) + self.update_display_mode() + + def help_display(self): + print("Enable/Disable tracing.") + print("Usage: display <mode1> <mode2> ... on|off") + print("Available modes are:") + for k in self.display_mode: + print("\t%s" % k) + print("Use 'info display' to get current values") + + def do_watchmem(self, arg): + if arg == "": + self.help_watchmem() + return + + args = arg.split(" ") + if len(args) >= 2: + size = int(args[1], 0) + else: + size = 0xF + + addr = int(args[0], 0) + + self.dbg.watch_mem(addr, size) + + def help_watchmem(self): + print("Add a memory watcher.") + print("Usage: watchmem <addr> [size]") + print("Use 'info watchmem' to get current memory watchers") + + def do_info(self, arg): + av_info = [ + "registers", + "display", + "breakpoints", + "memory_breakpoint", + "watchmem" + ] + + if arg == "": + print("'info' must be followed by the name of an info command.") + print("List of info subcommands:") + for k in av_info: + print("\t%s" % k) + + if arg.startswith("b"): + # Breakpoint + self.print_breakpoints() + + if arg.startswith("m"): + # Memory breakpoints + self.print_memory_breakpoints() + + if arg.startswith("d"): + # Display + self.update_display_mode() + for k, v in viewitems(self.display_mode): + print("%s\t\t%s" % (k, v)) + + if arg.startswith("w"): + # Watchmem + self.print_watchmems() + + if arg.startswith("r"): + # Registers + self.print_registers() + + def help_info(self): + print("Generic command for showing things about the program being") + print("debugged. 
Use 'info' without arguments to get the list of") + print("available subcommands.") + + def do_breakpoint(self, arg): + if arg == "": + self.help_breakpoint() + else: + addrs = arg.split(" ") + self.add_breakpoints(addrs) + + def help_breakpoint(self): + print("Add breakpoints to argument addresses.") + print("Example:") + print("\tbreakpoint 0x11223344") + print("\tbreakpoint 1122 0xabcd") + + def do_memory_breakpoint(self, arg): + if arg == "": + self.help_memory_breakpoint() + return + args = arg.split(" ") + if len(args) > 3 or len(args) <= 1: + self.help_memory_breakpoint() + return + address = int(args[0], 0) + size = int(args[1], 0) + if len(args) == 2: + self.dbg.add_memory_breakpoint(address, size, read=True, write=True) + else: + self.dbg.add_memory_breakpoint(address, + size, + read=('r' in args[2]), + write=('w' in args[2])) + + def help_memory_breakpoint(self): + print("Add memory breakpoints to memory space defined by a starting") + print("address and a size on specified access type (default is 'rw').") + print("Example:") + print("\tmemory_breakpoint 0x11223344 0x100 r") + print("\tmemory_breakpoint 1122 10") + + def do_step(self, arg): + if arg == "": + nb = 1 + else: + nb = int(arg) + for _ in range(nb): + self.dbg.step() + + def help_step(self): + print("Step program until it reaches a different source line.") + print("Argument N means do this N times (or till program stops") + print("for another reason).") + + def do_dump(self, arg): + if arg == "": + self.help_dump() + else: + args = arg.split(" ") + if len(args) >= 2: + size = int(args[1], 0) + else: + size = 0xF + addr = int(args[0], 0) + + self.dbg.get_mem(addr, size) + + def help_dump(self): + print("Dump <addr> [size]. Dump size bytes at addr.") + + def do_run(self, _): + self.dbg.run() + + def help_run(self): + print("Launch or continue the current program") + + def do_exit(self, _): + return True + + def do_exec(self, line): + try: + print(eval(line)) + except Exception as error: + print("*** Error: %s" % error) + + def help_exec(self): + print("Exec a python command.") + print("You can also use '!' shortcut.") + + def help_exit(self): + print("Exit the interpreter.") + print("You can also use the Ctrl-D shortcut.") + + def help_help(self): + print("Print help") + + def postloop(self): + print('\nGoodbye !') + super(DebugCmd, self).postloop() + + do_EOF = do_exit + help_EOF = help_exit diff --git a/src/miasm/analysis/depgraph.py b/src/miasm/analysis/depgraph.py new file mode 100644 index 00000000..436e5354 --- /dev/null +++ b/src/miasm/analysis/depgraph.py @@ -0,0 +1,659 @@ +"""Provide dependency graph""" + +from functools import total_ordering + +from future.utils import viewitems + +from miasm.expression.expression import ExprInt, ExprLoc, ExprAssign, \ + ExprWalk, canonize_to_exprloc +from miasm.core.graph import DiGraph +from miasm.expression.simplifications import expr_simp_explicit +from miasm.ir.symbexec import SymbolicExecutionEngine +from miasm.ir.ir import IRBlock, AssignBlock +from miasm.ir.translators import Translator +from miasm.expression.expression_helper import possible_values + +try: + import z3 +except: + pass + +@total_ordering +class DependencyNode(object): + + """Node elements of a DependencyGraph + + A dependency node stands for the dependency on the @element at line number + @line_nb in the IRblock named @loc_key, *before* the evaluation of this + line. 
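+
+    For instance (illustrative values), the dependency on EAX before line 2
+    of the IRBlock at loc_key_1 is:
+
+    DependencyNode(loc_key_1, ExprId("EAX", 32), 2)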
+ """ + + __slots__ = ["_loc_key", "_element", "_line_nb", "_hash"] + + def __init__(self, loc_key, element, line_nb): + """Create a dependency node with: + @loc_key: LocKey instance + @element: Expr instance + @line_nb: int + """ + self._loc_key = loc_key + self._element = element + self._line_nb = line_nb + self._hash = hash( + (self._loc_key, self._element, self._line_nb)) + + def __hash__(self): + """Returns a hash of @self to uniquely identify @self""" + return self._hash + + def __eq__(self, depnode): + """Returns True if @self and @depnode are equals.""" + if not isinstance(depnode, self.__class__): + return False + return (self.loc_key == depnode.loc_key and + self.element == depnode.element and + self.line_nb == depnode.line_nb) + + def __ne__(self, depnode): + # required Python 2.7.14 + return not self == depnode + + def __lt__(self, node): + """Compares @self with @node.""" + if not isinstance(node, self.__class__): + return NotImplemented + + return ((self.loc_key, self.element, self.line_nb) < + (node.loc_key, node.element, node.line_nb)) + + def __str__(self): + """Returns a string representation of DependencyNode""" + return "<%s %s %s %s>" % (self.__class__.__name__, + self.loc_key, self.element, + self.line_nb) + + def __repr__(self): + """Returns a string representation of DependencyNode""" + return self.__str__() + + @property + def loc_key(self): + "Name of the current IRBlock" + return self._loc_key + + @property + def element(self): + "Current tracked Expr" + return self._element + + @property + def line_nb(self): + "Line in the current IRBlock" + return self._line_nb + + +class DependencyState(object): + + """ + Store intermediate depnodes states during dependencygraph analysis + """ + + def __init__(self, loc_key, pending, line_nb=None): + self.loc_key = loc_key + self.history = [loc_key] + self.pending = {k: set(v) for k, v in viewitems(pending)} + self.line_nb = line_nb + self.links = set() + + # Init lazy elements + self._graph = None + + def __repr__(self): + return "<State: %r (%r) (%r)>" % ( + self.loc_key, + self.pending, + self.links + ) + + def extend(self, loc_key): + """Return a copy of itself, with itself in history + @loc_key: LocKey instance for the new DependencyState's loc_key + """ + new_state = self.__class__(loc_key, self.pending) + new_state.links = set(self.links) + new_state.history = self.history + [loc_key] + return new_state + + def get_done_state(self): + """Returns immutable object representing current state""" + return (self.loc_key, frozenset(self.links)) + + def as_graph(self): + """Generates a Digraph of dependencies""" + graph = DiGraph() + for node_a, node_b in self.links: + if not node_b: + graph.add_node(node_a) + else: + graph.add_edge(node_a, node_b) + for parent, sons in viewitems(self.pending): + for son in sons: + graph.add_edge(parent, son) + return graph + + @property + def graph(self): + """Returns a DiGraph instance representing the DependencyGraph""" + if self._graph is None: + self._graph = self.as_graph() + return self._graph + + def remove_pendings(self, nodes): + """Remove resolved @nodes""" + for node in nodes: + del self.pending[node] + + def add_pendings(self, future_pending): + """Add @future_pending to the state""" + for node, depnodes in viewitems(future_pending): + if node not in self.pending: + self.pending[node] = depnodes + else: + self.pending[node].update(depnodes) + + def link_element(self, element, line_nb): + """Link element to its dependencies + @element: the element to link + @line_nb: the element's 
line + """ + + depnode = DependencyNode(self.loc_key, element, line_nb) + if not self.pending[element]: + # Create start node + self.links.add((depnode, None)) + else: + # Link element to its known dependencies + for node_son in self.pending[element]: + self.links.add((depnode, node_son)) + + def link_dependencies(self, element, line_nb, dependencies, + future_pending): + """Link unfollowed dependencies and create remaining pending elements. + @element: the element to link + @line_nb: the element's line + @dependencies: the element's dependencies + @future_pending: the future dependencies + """ + + depnode = DependencyNode(self.loc_key, element, line_nb) + + # Update pending, add link to unfollowed nodes + for dependency in dependencies: + if not dependency.follow: + # Add non followed dependencies to the dependency graph + parent = DependencyNode( + self.loc_key, dependency.element, line_nb) + self.links.add((parent, depnode)) + continue + # Create future pending between new dependency and the current + # element + future_pending.setdefault(dependency.element, set()).add(depnode) + + +class DependencyResult(DependencyState): + + """Container and methods for DependencyGraph results""" + + def __init__(self, ircfg, initial_state, state, inputs): + + super(DependencyResult, self).__init__(state.loc_key, state.pending) + self.initial_state = initial_state + self.history = state.history + self.pending = state.pending + self.line_nb = state.line_nb + self.inputs = inputs + self.links = state.links + self._ircfg = ircfg + + # Init lazy elements + self._has_loop = None + + @property + def unresolved(self): + """Set of nodes whose dependencies weren't found""" + return set(element for element in self.pending + if element != self._ircfg.IRDst) + + @property + def relevant_nodes(self): + """Set of nodes directly and indirectly influencing inputs""" + output = set() + for node_a, node_b in self.links: + output.add(node_a) + if node_b is not None: + output.add(node_b) + return output + + @property + def relevant_loc_keys(self): + """List of loc_keys containing nodes influencing inputs. + The history order is preserved.""" + # Get used loc_keys + used_loc_keys = set(depnode.loc_key for depnode in self.relevant_nodes) + + # Keep history order + output = [] + for loc_key in self.history: + if loc_key in used_loc_keys: + output.append(loc_key) + + return output + + @property + def has_loop(self): + """True iff there is at least one data dependencies cycle (regarding + the associated depgraph)""" + if self._has_loop is None: + self._has_loop = self.graph.has_loop() + return self._has_loop + + def irblock_slice(self, irb, max_line=None): + """Slice of the dependency nodes on the irblock @irb + @irb: irbloc instance + """ + + assignblks = [] + line2elements = {} + for depnode in self.relevant_nodes: + if depnode.loc_key != irb.loc_key: + continue + line2elements.setdefault(depnode.line_nb, + set()).add(depnode.element) + + for line_nb, elements in sorted(viewitems(line2elements)): + if max_line is not None and line_nb >= max_line: + break + assignmnts = {} + for element in elements: + if element in irb[line_nb]: + # constants, loc_key, ... 
are not in destination + assignmnts[element] = irb[line_nb][element] + assignblks.append(AssignBlock(assignmnts)) + + return IRBlock(irb.loc_db, irb.loc_key, assignblks) + + def emul(self, lifter, ctx=None, step=False): + """Symbolic execution of relevant nodes according to the history + Return the values of inputs nodes' elements + @lifter: Lifter instance + @ctx: (optional) Initial context as dictionary + @step: (optional) Verbose execution + Warning: The emulation is not sound if the inputs nodes depend on loop + variant. + """ + # Init + ctx_init = {} + if ctx is not None: + ctx_init.update(ctx) + assignblks = [] + + # Build a single assignment block according to history + last_index = len(self.relevant_loc_keys) + for index, loc_key in enumerate(reversed(self.relevant_loc_keys), 1): + if index == last_index and loc_key == self.initial_state.loc_key: + line_nb = self.initial_state.line_nb + else: + line_nb = None + assignblks += self.irblock_slice(self._ircfg.blocks[loc_key], + line_nb).assignblks + + # Eval the block + loc_db = lifter.loc_db + temp_loc = loc_db.get_or_create_name_location("Temp") + symb_exec = SymbolicExecutionEngine(lifter, ctx_init) + symb_exec.eval_updt_irblock(IRBlock(loc_db, temp_loc, assignblks), step=step) + + # Return only inputs values (others could be wrongs) + return {element: symb_exec.symbols[element] + for element in self.inputs} + + +class DependencyResultImplicit(DependencyResult): + + """Stand for a result of a DependencyGraph with implicit option + + Provide path constraints using the z3 solver""" + # Z3 Solver instance + _solver = None + + unsat_expr = ExprAssign(ExprInt(0, 1), ExprInt(1, 1)) + + def _gen_path_constraints(self, translator, expr, expected): + """Generate path constraint from @expr. Handle special case with + generated loc_keys + """ + out = [] + expected = canonize_to_exprloc(self._ircfg.loc_db, expected) + expected_is_loc_key = expected.is_loc() + for consval in possible_values(expr): + value = canonize_to_exprloc(self._ircfg.loc_db, consval.value) + if expected_is_loc_key and value != expected: + continue + if not expected_is_loc_key and value.is_loc_key(): + continue + + conds = z3.And(*[translator.from_expr(cond.to_constraint()) + for cond in consval.constraints]) + if expected != value: + conds = z3.And( + conds, + translator.from_expr( + ExprAssign(value, + expected)) + ) + out.append(conds) + + if out: + conds = z3.Or(*out) + else: + # Ex: expr: lblgen1, expected: 0x1234 + # -> Avoid inconsistent solution lblgen1 = 0x1234 + conds = translator.from_expr(self.unsat_expr) + return conds + + def emul(self, lifter, ctx=None, step=False): + # Init + ctx_init = {} + if ctx is not None: + ctx_init.update(ctx) + solver = z3.Solver() + symb_exec = SymbolicExecutionEngine(lifter, ctx_init) + history = self.history[::-1] + history_size = len(history) + translator = Translator.to_language("z3") + size = self._ircfg.IRDst.size + + for hist_nb, loc_key in enumerate(history, 1): + if hist_nb == history_size and loc_key == self.initial_state.loc_key: + line_nb = self.initial_state.line_nb + else: + line_nb = None + irb = self.irblock_slice(self._ircfg.blocks[loc_key], line_nb) + + # Emul the block and get back destination + dst = symb_exec.eval_updt_irblock(irb, step=step) + + # Add constraint + if hist_nb < history_size: + next_loc_key = history[hist_nb] + expected = symb_exec.eval_expr(ExprLoc(next_loc_key, size)) + solver.add(self._gen_path_constraints(translator, dst, expected)) + # Save the solver + self._solver = solver + + # Return only 
inputs values (others could be wrongs) + return { + element: symb_exec.eval_expr(element) + for element in self.inputs + } + + @property + def is_satisfiable(self): + """Return True iff the solution path admits at least one solution + PRE: 'emul' + """ + return self._solver.check() == z3.sat + + @property + def constraints(self): + """If satisfiable, return a valid solution as a Z3 Model instance""" + if not self.is_satisfiable: + raise ValueError("Unsatisfiable") + return self._solver.model() + + +class FollowExpr(object): + + "Stand for an element (expression, depnode, ...) to follow or not" + __slots__ = ["follow", "element"] + + def __init__(self, follow, element): + self.follow = follow + self.element = element + + def __repr__(self): + return '%s(%r, %r)' % (self.__class__.__name__, self.follow, self.element) + + @staticmethod + def to_depnodes(follow_exprs, loc_key, line): + """Build a set of FollowExpr(DependencyNode) from the @follow_exprs set + of FollowExpr + @follow_exprs: set of FollowExpr + @loc_key: LocKey instance + @line: integer + """ + dependencies = set() + for follow_expr in follow_exprs: + dependencies.add(FollowExpr(follow_expr.follow, + DependencyNode(loc_key, + follow_expr.element, + line))) + return dependencies + + @staticmethod + def extract_depnodes(follow_exprs, only_follow=False): + """Extract depnodes from a set of FollowExpr(Depnodes) + @only_follow: (optional) extract only elements to follow""" + return set(follow_expr.element + for follow_expr in follow_exprs + if not(only_follow) or follow_expr.follow) + + +class FilterExprSources(ExprWalk): + """ + Walk Expression to find sources to track + @follow_mem: (optional) Track memory syntactically + @follow_call: (optional) Track through "call" + """ + def __init__(self, follow_mem, follow_call): + super(FilterExprSources, self).__init__(lambda x:None) + self.follow_mem = follow_mem + self.follow_call = follow_call + self.nofollow = set() + self.follow = set() + + def visit(self, expr, *args, **kwargs): + if expr in self.cache: + return None + ret = self.visit_inner(expr, *args, **kwargs) + self.cache.add(expr) + return ret + + def visit_inner(self, expr, *args, **kwargs): + if expr.is_id(): + self.follow.add(expr) + elif expr.is_int(): + self.nofollow.add(expr) + elif expr.is_loc(): + self.nofollow.add(expr) + elif expr.is_mem(): + if self.follow_mem: + self.follow.add(expr) + else: + self.nofollow.add(expr) + return None + elif expr.is_function_call(): + if self.follow_call: + self.follow.add(expr) + else: + self.nofollow.add(expr) + return None + + ret = super(FilterExprSources, self).visit(expr, *args, **kwargs) + return ret + + +class DependencyGraph(object): + + """Implementation of a dependency graph + + A dependency graph contains DependencyNode as nodes. The oriented edges + stand for a dependency. + The dependency graph is made of the lines of a group of IRblock + *explicitly* or *implicitly* involved in the equation of given element. + """ + + def __init__(self, ircfg, + implicit=False, apply_simp=True, follow_mem=True, + follow_call=True): + """Create a DependencyGraph linked to @ircfg + + @ircfg: IRCFG instance + @implicit: (optional) Track IRDst for each block in the resulting path + + Following arguments define filters used to generate dependencies + @apply_simp: (optional) Apply expr_simp_explicit + @follow_mem: (optional) Track memory syntactically + @follow_call: (optional) Track through "call" + """ + # Init + self._ircfg = ircfg + self._implicit = implicit + + # Create callback filters. 
The order is relevant. + self._cb_follow = [] + if apply_simp: + self._cb_follow.append(self._follow_simp_expr) + self._cb_follow.append(lambda exprs: self.do_follow(exprs, follow_mem, follow_call)) + + @staticmethod + def do_follow(exprs, follow_mem, follow_call): + visitor = FilterExprSources(follow_mem, follow_call) + for expr in exprs: + visitor.visit(expr) + return visitor.follow, visitor.nofollow + + @staticmethod + def _follow_simp_expr(exprs): + """Simplify expression so avoid tracking useless elements, + as: XOR EAX, EAX + """ + follow = set() + for expr in exprs: + follow.add(expr_simp_explicit(expr)) + return follow, set() + + def _follow_apply_cb(self, expr): + """Apply callback functions to @expr + @expr : FollowExpr instance""" + follow = set([expr]) + nofollow = set() + + for callback in self._cb_follow: + follow, nofollow_tmp = callback(follow) + nofollow.update(nofollow_tmp) + + out = set(FollowExpr(True, expr) for expr in follow) + out.update(set(FollowExpr(False, expr) for expr in nofollow)) + return out + + def _track_exprs(self, state, assignblk, line_nb): + """Track pending expression in an assignblock""" + future_pending = {} + node_resolved = set() + for dst, src in viewitems(assignblk): + # Only track pending + if dst not in state.pending: + continue + # Track IRDst in implicit mode only + if dst == self._ircfg.IRDst and not self._implicit: + continue + assert dst not in node_resolved + node_resolved.add(dst) + dependencies = self._follow_apply_cb(src) + + state.link_element(dst, line_nb) + state.link_dependencies(dst, line_nb, + dependencies, future_pending) + + # Update pending nodes + state.remove_pendings(node_resolved) + state.add_pendings(future_pending) + + def _compute_intrablock(self, state): + """Follow dependencies tracked in @state in the current irbloc + @state: instance of DependencyState""" + + irb = self._ircfg.blocks[state.loc_key] + line_nb = len(irb) if state.line_nb is None else state.line_nb + + for cur_line_nb, assignblk in reversed(list(enumerate(irb[:line_nb]))): + self._track_exprs(state, assignblk, cur_line_nb) + + def get(self, loc_key, elements, line_nb, heads): + """Compute the dependencies of @elements at line number @line_nb in + the block named @loc_key in the current IRCFG, before the execution of + this line. Dependency check stop if one of @heads is reached + @loc_key: LocKey instance + @element: set of Expr instances + @line_nb: int + @heads: set of LocKey instances + Return an iterator on DiGraph(DependencyNode) + """ + # Init the algorithm + inputs = {element: set() for element in elements} + initial_state = DependencyState(loc_key, inputs, line_nb) + todo = set([initial_state]) + done = set() + dpResultcls = DependencyResultImplicit if self._implicit else DependencyResult + + while todo: + state = todo.pop() + self._compute_intrablock(state) + done_state = state.get_done_state() + if done_state in done: + continue + done.add(done_state) + if (not state.pending or + state.loc_key in heads or + not self._ircfg.predecessors(state.loc_key)): + yield dpResultcls(self._ircfg, initial_state, state, elements) + if not state.pending: + continue + + if self._implicit: + # Force IRDst to be tracked, except in the input block + state.pending[self._ircfg.IRDst] = set() + + # Propagate state to parents + for pred in self._ircfg.predecessors_iter(state.loc_key): + todo.add(state.extend(pred)) + + def get_from_depnodes(self, depnodes, heads): + """Alias for the get() method. Use the attributes of @depnodes as + argument. 
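+        Illustrative call (`dg` and `depnodes` are placeholders):
+
+            results = dg.get_from_depnodes(depnodes, heads)
+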
+        PRE: Loc_Keys and lines of depnodes have to be equal
+        @depnodes: set of DependencyNode instances
+        @heads: set of LocKey instances
+        """
+        lead = list(depnodes)[0]
+        elements = set(depnode.element for depnode in depnodes)
+        return self.get(lead.loc_key, elements, lead.line_nb, heads)
+
+    def address_to_location(self, address):
+        """Helper to retrieve the .get() arguments, i.e.
+        assembly address -> irblock's location key and line number
+        """
+        current_loc_key = next(iter(self._ircfg.getby_offset(address)))
+        assignblk_index = 0
+        current_block = self._ircfg.get_block(current_loc_key)
+        for assignblk_index, assignblk in enumerate(current_block):
+            if assignblk.instr.offset == address:
+                break
+        else:
+            return None
+
+        return {
+            "loc_key": current_block.loc_key,
+            "line_nb": assignblk_index,
+        }
diff --git a/src/miasm/analysis/disasm_cb.py b/src/miasm/analysis/disasm_cb.py
new file mode 100644
index 00000000..f180f0a2
--- /dev/null
+++ b/src/miasm/analysis/disasm_cb.py
@@ -0,0 +1,127 @@
+#-*- coding:utf-8 -*-
+
+from __future__ import print_function
+
+from future.utils import viewvalues
+
+from miasm.expression.expression import ExprInt, ExprId, ExprMem, match_expr
+from miasm.expression.simplifications import expr_simp
+from miasm.core.asmblock import AsmConstraintNext, AsmConstraintTo
+from miasm.core.locationdb import LocationDB
+from miasm.core.utils import upck32
+
+
+def get_lifter_model_call(arch, attrib):
+    arch = arch.name, attrib
+    if arch == ("arm", "arm"):
+        from miasm.arch.arm.lifter_model_call import LifterModelCallArmlBase as lifter_model_call
+    elif arch == ("x86", 32):
+        from miasm.arch.x86.lifter_model_call import LifterModelCall_x86_32 as lifter_model_call
+    elif arch == ("x86", 64):
+        from miasm.arch.x86.lifter_model_call import LifterModelCall_x86_64 as lifter_model_call
+    else:
+        # 'arch' has been rebound to a (name, attrib) tuple above
+        raise ValueError('unknown architecture: %s' % str(arch))
+    return lifter_model_call
+
+
+def arm_guess_subcall(dis_engine, cur_block, offsets_to_dis):
+    arch = dis_engine.arch
+    loc_db = dis_engine.loc_db
+    lifter_model_call = get_lifter_model_call(arch, dis_engine.attrib)
+
+    lifter = lifter_model_call(loc_db)
+    ircfg = lifter_model_call.new_ircfg()
+    print('###')
+    print(cur_block)
+    lifter.add_asmblock_to_ircfg(cur_block, ircfg)
+
+    to_add = set()
+    for irblock in viewvalues(ircfg.blocks):
+        pc_val = None
+        lr_val = None
+        for exprs in irblock:
+            for e in exprs:
+                if e.dst == lifter.pc:
+                    pc_val = e.src
+                if e.dst == arch.regs.LR:
+                    lr_val = e.src
+        if pc_val is None or lr_val is None:
+            continue
+        if not isinstance(lr_val, ExprInt):
+            continue
+
+        l = cur_block.lines[-1]
+        if lr_val.arg != l.offset + l.l:
+            continue
+        l = loc_db.get_or_create_offset_location(int(lr_val))
+        c = AsmConstraintNext(l)
+
+        to_add.add(c)
+        offsets_to_dis.add(int(lr_val))
+
+    for c in to_add:
+        cur_block.addto(c)
+
+
+def arm_guess_jump_table(dis_engine, cur_block, offsets_to_dis):
+    arch = dis_engine.arch
+    loc_db = dis_engine.loc_db
+    lifter_model_call = get_lifter_model_call(arch, dis_engine.attrib)
+
+    jra = ExprId('jra')
+    jrb = ExprId('jrb')
+
+    lifter = lifter_model_call(loc_db)
+    ircfg = lifter_model_call.new_ircfg()
+    lifter.add_asmblock_to_ircfg(cur_block, ircfg)
+
+    for irblock in viewvalues(ircfg.blocks):
+        pc_val = None
+        for exprs in irblock:
+            for e in exprs:
+                if e.dst == lifter.pc:
+                    pc_val = e.src
+        if pc_val is None:
+            continue
+        if not isinstance(pc_val, ExprMem):
+            continue
+        assert(pc_val.size == 32)
+        print(pc_val)
+        ad = pc_val.arg
+        ad = expr_simp(ad)
+        print(ad)
+        res = match_expr(ad, jra + jrb, set([jra, jrb]))
+        if res is False:
+            raise NotImplementedError('not fully functional')
+        print(res)
+        if not isinstance(res[jrb], ExprInt):
+            raise NotImplementedError('not fully functional')
+        base_ad = int(res[jrb])
+        print(base_ad)
+        addrs = set()
+        i = -1
+        max_table_entry = 10000
+        max_diff_addr = 0x100000  # heuristic
+        while i < max_table_entry:
+            i += 1
+            try:
+                ad = upck32(dis_engine.bin_stream.getbytes(base_ad + 4 * i, 4))
+            except Exception:
+                break
+            if abs(ad - base_ad) > max_diff_addr:
+                break
+            addrs.add(ad)
+        print([hex(x) for x in addrs])
+
+        for ad in addrs:
+            offsets_to_dis.add(ad)
+            l = loc_db.get_or_create_offset_location(ad)
+            c = AsmConstraintTo(l)
+            cur_block.addto(c)
+
+guess_funcs = []
+
+
+def guess_multi_cb(dis_engine, cur_block, offsets_to_dis):
+    for f in guess_funcs:
+        f(dis_engine, cur_block, offsets_to_dis)
diff --git a/src/miasm/analysis/dse.py b/src/miasm/analysis/dse.py
new file mode 100644
index 00000000..11674734
--- /dev/null
+++ b/src/miasm/analysis/dse.py
@@ -0,0 +1,717 @@
+"""Dynamic symbolic execution module.
+
+Offers a way to run a symbolic execution alongside a concrete one.
+Basically, this is done through the DSEEngine class, with the scheme:
+
+dse = DSEEngine(Machine("x86_32"))
+dse.attach(jitter)
+
+The DSE state can be updated through:
+
+ - .update_state_from_concrete: update the values from the CPU, so the symbolic
+   execution will be completely concrete from this point (until changes)
+ - .update_state: inject information, for instance RAX = symbolic_RAX
+ - .symbolize_memory: symbolize (using .memory_to_expr) memory areas (ie,
+   reading from an address in one of these areas yields a symbol)
+
+The DSE run can be instrumented through:
+ - .add_handler: register a handler, modifying the state instead of the current
+   execution. Can be used to stub external APIs
+ - .add_lib_handler: register handlers for libraries
+ - .add_instrumentation: register a handler, modifying the state but continuing
+   the current execution. Can be used for logging facilities
+
+
+On branches, if the decision is symbolic, one can also collect "path
+constraints" and invert them to produce new inputs potentially reaching new
+paths.
+
+Basically, this is done through DSEPathConstraint. In order to produce a new
+solution, one can extend this class, and override 'handle_solution' to produce
+a solution which fits its needs. Computing new solutions can be avoided by
+overriding 'produce_solution'.
+
+If one is only interested in the constraints associated to its path, the option
+"produce_solution" should be set to PRODUCE_NO_SOLUTION, to speed up emulation.
+The constraints are accumulated in the .cur_solver z3.Solver object.
+
+Here are a few remaining TODOs:
+ - handle endianness in check_state / atomic read: currently (but this is also
+   true for the other Miasm2 symbolic engines) the endianness is not taken into
+   account, and is assumed to be little endian
+
+ - too many memory dependencies in constraint tracking: in order to let z3 find
+   new solutions, it does need information on memory values (for instance, a
+   lookup in a table with a symbolic index). The estimated set of possibly
+   involved memory locations could be too large to pass to the solver
+   (threshold named MAX_MEMORY_INJECT). One possible solution, not yet
+   implemented, is to call the solver to reduce the possible values thanks to
+   its accumulated constraints.
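+
+A minimal, illustrative scheme for the path-constraint variant (the 'machine',
+'loc_db' and 'jitter' objects are assumed to be built by the caller, as above):
+
+dse = DSEPathConstraint(machine, loc_db,
+                        produce_solution=DSEPathConstraint.PRODUCE_SOLUTION_CODE_COV)
+dse.attach(jitter)
+dse.update_state_from_concrete()
+# ... run the jitter, then inspect dse.new_solutions ...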
+""" +from builtins import range +from collections import namedtuple +import warnings + +try: + import z3 +except: + z3 = None + +from future.utils import viewitems + +from miasm.core.utils import encode_hex, force_bytes +from miasm.expression.expression import ExprMem, ExprInt, ExprCompose, \ + ExprAssign, ExprId, ExprLoc, LocKey, canonize_to_exprloc +from miasm.core.bin_stream import bin_stream_vm +from miasm.jitter.emulatedsymbexec import EmulatedSymbExec +from miasm.expression.expression_helper import possible_values +from miasm.ir.translators import Translator +from miasm.analysis.expression_range import expr_range +from miasm.analysis.modularintervals import ModularIntervals + +DriftInfo = namedtuple("DriftInfo", ["symbol", "computed", "expected"]) + +class DriftException(Exception): + """Raised when the emulation drift from the reference engine""" + + def __init__(self, info): + super(DriftException, self).__init__() + self.info = info + + def __str__(self): + if len(self.info) == 1: + return "Drift of %s: %s instead of %s" % ( + self.info[0].symbol, + self.info[0].computed, + self.info[0].expected, + ) + else: + return "Drift of:\n\t" + "\n\t".join("%s: %s instead of %s" % ( + dinfo.symbol, + dinfo.computed, + dinfo.expected) + for dinfo in self.info) + + +class ESETrackModif(EmulatedSymbExec): + """Extension of EmulatedSymbExec to be used by DSE engines + + Add the tracking of modified expressions, and the ability to symbolize + memory areas + """ + + def __init__(self, *args, **kwargs): + super(ESETrackModif, self).__init__(*args, **kwargs) + self.modified_expr = set() # Expr modified since the last reset + self.dse_memory_range = [] # List/Intervals of memory addresses to + # symbolize + self.dse_memory_to_expr = None # function(addr) -> Expr used to + # symbolize + + def mem_read(self, expr_mem): + if not expr_mem.ptr.is_int(): + return super(ESETrackModif, self).mem_read(expr_mem) + dst_addr = int(expr_mem.ptr) + + # Split access in atomic accesses + out = [] + for addr in range(dst_addr, dst_addr + expr_mem.size // 8): + if addr in self.dse_memory_range: + # Symbolize memory access + out.append(self.dse_memory_to_expr(addr)) + continue + atomic_access = ExprMem(ExprInt(addr, expr_mem.ptr.size), 8) + if atomic_access in self.symbols: + out.append( super(EmulatedSymbExec, self).mem_read(atomic_access)) + else: + # Get concrete value + atomic_access = ExprMem(ExprInt(addr, expr_mem.ptr.size), 8) + out.append(super(ESETrackModif, self).mem_read(atomic_access)) + + if len(out) == 1: + # Trivial case (optimization) + return out[0] + + # Simplify for constant merging (ex: {ExprInt(1, 8), ExprInt(2, 8)}) + return self.expr_simp(ExprCompose(*out)) + + def mem_write(self, expr, data): + # Call Symbolic mem_write (avoid side effects on vm) + return super(EmulatedSymbExec, self).mem_write(expr, data) + + def reset_modified(self): + """Reset modified expression tracker""" + self.modified_expr.clear() + + def apply_change(self, dst, src): + super(ESETrackModif, self).apply_change(dst, src) + self.modified_expr.add(dst) + + +class ESENoVMSideEffects(EmulatedSymbExec): + """ + Do EmulatedSymbExec without modifying memory + """ + def mem_write(self, expr, data): + return super(EmulatedSymbExec, self).mem_write(expr, data) + + +class DSEEngine(object): + """Dynamic Symbolic Execution Engine + + This class aims to be overridden for each specific purpose + """ + SYMB_ENGINE = ESETrackModif + + def __init__(self, machine, loc_db): + self.machine = machine + self.loc_db = loc_db + self.handler = 
{}  # addr -> callback(DSEEngine instance)
+        self.instrumentation = {}  # addr -> callback(DSEEngine instance)
+        self.addr_to_cacheblocks = {}  # addr -> {label -> IRBlock}
+        self.lifter = self.machine.lifter(loc_db=self.loc_db)  # corresponding lifter
+        self.ircfg = self.lifter.new_ircfg()  # corresponding IR CFG
+
+        # Defined after attachment
+        self.jitter = None  # Jitload (concrete execution)
+        self.symb = None  # SymbolicExecutionEngine
+        self.symb_concrete = None  # Concrete SymbExec for path disambiguation
+        self.mdis = None  # DisasmEngine
+
+    def prepare(self):
+        """Prepare the environment for attachment with a jitter"""
+        # Disassembler
+        self.mdis = self.machine.dis_engine(bin_stream_vm(self.jitter.vm),
+                                            lines_wd=1,
+                                            loc_db=self.loc_db)
+
+        # Symbexec engine
+        ## Prepare symbexec engines
+        self.symb = self.SYMB_ENGINE(self.jitter.cpu, self.jitter.vm,
+                                     self.lifter, {})
+        self.symb.enable_emulated_simplifications()
+        self.symb_concrete = ESENoVMSideEffects(
+            self.jitter.cpu, self.jitter.vm,
+            self.lifter, {}
+        )
+
+        ## Update registers' values
+        self.symb.symbols[self.lifter.IRDst] = ExprInt(
+            getattr(self.jitter.cpu, self.lifter.pc.name),
+            self.lifter.IRDst.size
+        )
+
+        # Activate callback on each instr
+        self.jitter.jit.set_options(max_exec_per_call=1, jit_maxline=1)
+        self.jitter.exec_cb = self.callback
+
+        # Clean jit cache to avoid multi-line basic blocks already jitted
+        self.jitter.jit.clear_jitted_blocks()
+
+    def attach(self, emulator):
+        """Attach the DSE to @emulator
+        @emulator: jitload (or API equivalent) instance
+
+        To attach *DURING A BREAKPOINT*, one may consider using the following
+        snippet:
+
+        def breakpoint(self, jitter):
+            ...
+            dse.attach(jitter)
+            dse.update...
+            ...
+            # An additional call to the exec callback is necessary, as
+            # breakpoints are honored AFTER the exec callback
+            jitter.exec_cb(jitter)
+
+            return True
+
+        Without it, one may encounter a DriftException error due to a
+        "desynchronization" between the jitter and dse states. Indeed, on a
+        'handle' call, the jitter must be one instruction AFTER the dse.
+        """
+        self.jitter = emulator
+        self.prepare()
+
+    def handle(self, cur_addr):
+        r"""Handle destination
+        @cur_addr: Expr of the next address in concrete execution
+        [!] cur_addr may be a loc_key
+
+        In this method, self.symb is in the "just before branching" state
+        """
+        pass
+
+    def add_handler(self, addr, callback):
+        """Add a @callback for address @addr before any state update.
+ The state IS NOT updated after returning from the callback + @addr: int + @callback: func(dse instance)""" + self.handler[addr] = callback + + def add_lib_handler(self, libimp, namespace): + """Add search for handler based on a @libimp libimp instance + + Known functions will be looked by {name}_symb or {name}_{ord}_symb in the @namespace + """ + namespace = dict( + (force_bytes(name), func) for name, func in viewitems(namespace) + ) + + # lambda cannot contain statement + def default_func(dse): + fname = libimp.fad2cname[dse.jitter.pc] + if isinstance(fname, tuple): + fname = b"%s_%d_symb" % (force_bytes(fname[0]), fname[1]) + else: + fname = b"%s_symb" % force_bytes(fname) + raise RuntimeError("Symbolic stub '%s' not found" % fname) + + for addr, fname in viewitems(libimp.fad2cname): + if isinstance(fname, tuple): + fname = b"%s_%d_symb" % (force_bytes(fname[0]), fname[1]) + else: + fname = b"%s_symb" % force_bytes(fname) + func = namespace.get(fname, None) + if func is not None: + self.add_handler(addr, func) + else: + self.add_handler(addr, default_func) + + def add_instrumentation(self, addr, callback): + """Add a @callback for address @addr before any state update. + The state IS updated after returning from the callback + @addr: int + @callback: func(dse instance)""" + self.instrumentation[addr] = callback + + def _check_state(self): + """Check the current state against the concrete one""" + errors = [] # List of DriftInfo + + for symbol in self.symb.modified_expr: + # Do not consider PC + if symbol in [self.lifter.pc, self.lifter.IRDst]: + continue + + # Consider only concrete values + symb_value = self.eval_expr(symbol) + if not symb_value.is_int(): + continue + symb_value = int(symb_value) + + # Check computed values against real ones + if symbol.is_id(): + if hasattr(self.jitter.cpu, symbol.name): + value = getattr(self.jitter.cpu, symbol.name) + if value != symb_value: + errors.append(DriftInfo(symbol, symb_value, value)) + elif symbol.is_mem() and symbol.ptr.is_int(): + value_chr = self.jitter.vm.get_mem( + int(symbol.ptr), + symbol.size // 8 + ) + exp_value = int(encode_hex(value_chr[::-1]), 16) + if exp_value != symb_value: + errors.append(DriftInfo(symbol, symb_value, exp_value)) + + # Check for drift, and act accordingly + if errors: + raise DriftException(errors) + + def callback(self, _): + """Called before each instruction""" + # Assert synchronization with concrete execution + self._check_state() + + # Call callbacks associated to the current address + cur_addr = self.jitter.pc + if isinstance(cur_addr, LocKey): + lbl = self.lifter.loc_db.loc_key_to_label(cur_addr) + cur_addr = lbl.offset + + if cur_addr in self.handler: + self.handler[cur_addr](self) + return True + + if cur_addr in self.instrumentation: + self.instrumentation[cur_addr](self) + + # Handle current address + self.handle(ExprInt(cur_addr, self.lifter.IRDst.size)) + + # Avoid memory issue in ExpressionSimplifier + if len(self.symb.expr_simp.cache) > 100000: + self.symb.expr_simp.cache.clear() + + # Get IR blocks + if cur_addr in self.addr_to_cacheblocks: + self.ircfg.blocks.clear() + self.ircfg.blocks.update(self.addr_to_cacheblocks[cur_addr]) + else: + + ## Reset cache structures + self.ircfg.blocks.clear()# = {} + + ## Update current state + asm_block = self.mdis.dis_block(cur_addr) + self.lifter.add_asmblock_to_ircfg(asm_block, self.ircfg) + self.addr_to_cacheblocks[cur_addr] = dict(self.ircfg.blocks) + + # Emulate the current instruction + self.symb.reset_modified() + + # Is the symbolic execution 
going (potentially) to jump on a lbl_gen?
+        if len(self.ircfg.blocks) == 1:
+            self.symb.run_at(self.ircfg, cur_addr)
+        else:
+            # Emulation could get stuck in generated IR blocks
+            # But the concrete execution callback is not precise enough to
+            # obtain the full IR block path
+            # -> Use a fully concrete execution to recover the path
+
+            # Update the concrete execution
+            self._update_state_from_concrete_symb(
+                self.symb_concrete, cpu=True, mem=True
+            )
+            while True:
+
+                next_addr_concrete = self.symb_concrete.run_block_at(
+                    self.ircfg, cur_addr
+                )
+                self.symb.run_block_at(self.ircfg, cur_addr)
+
+                if not (isinstance(next_addr_concrete, ExprLoc) and
+                        self.lifter.loc_db.get_location_offset(
+                            next_addr_concrete.loc_key
+                        ) is None):
+                    # Not a lbl_gen, exit
+                    break
+
+                # Call handle with the lbl_gen state
+                self.handle(next_addr_concrete)
+                cur_addr = next_addr_concrete
+
+
+        # At this stage, the symbolic engine is one instruction after the
+        # concrete engine
+
+        return True
+
+    def _get_gpregs(self):
+        """Return a dict of regs: value from the jitter
+        This version uses the regs associated to the attrib (!= cpu.get_gpreg())
+        """
+        out = {}
+        regs = self.lifter.arch.regs.attrib_to_regs[self.lifter.attrib]
+        for reg in regs:
+            if hasattr(self.jitter.cpu, reg.name):
+                out[reg.name] = getattr(self.jitter.cpu, reg.name)
+        return out
+
+    def take_snapshot(self):
+        """Return a snapshot of the current state (including the jitter state)"""
+        snapshot = {
+            "mem": self.jitter.vm.get_all_memory(),
+            "regs": self._get_gpregs(),
+            "symb": self.symb.symbols.copy(),
+        }
+        return snapshot
+
+    def restore_snapshot(self, snapshot, memory=True):
+        """Restore a @snapshot taken with .take_snapshot
+        @snapshot: .take_snapshot output
+        @memory: (optional) if set, also restore the memory
+        """
+        # Restore memory
+        if memory:
+            self.jitter.vm.reset_memory_page_pool()
+            self.jitter.vm.reset_code_bloc_pool()
+            for addr, metadata in viewitems(snapshot["mem"]):
+                self.jitter.vm.add_memory_page(
+                    addr,
+                    metadata["access"],
+                    metadata["data"]
+                )
+
+        # Restore registers
+        self.jitter.pc = snapshot["regs"][self.lifter.pc.name]
+        for reg, value in viewitems(snapshot["regs"]):
+            setattr(self.jitter.cpu, reg, value)
+
+        # Reset internal elements
+        self.jitter.vm.set_exception(0)
+        self.jitter.cpu.set_exception(0)
+        self.jitter.bs._atomic_mode = False
+
+        # Reset symb exec
+        for key, _ in list(viewitems(self.symb.symbols)):
+            del self.symb.symbols[key]
+        for expr, value in viewitems(snapshot["symb"]):
+            self.symb.symbols[expr] = value
+
+    def update_state(self, assignblk):
+        """From this point, assume @assignblk in the symbolic execution
+        @assignblk: AssignBlock/{dst -> src}
+        """
+        for dst, src in viewitems(assignblk):
+            self.symb.apply_change(dst, src)
+
+    def _update_state_from_concrete_symb(self, symbexec, cpu=True, mem=False):
+        if mem:
+            # Values will be retrieved from the concrete execution if they are
+            # not present
+            symbexec.symbols.symbols_mem.base_to_memarray.clear()
+        if cpu:
+            regs = self.lifter.arch.regs.attrib_to_regs[self.lifter.attrib]
+            for reg in regs:
+                if hasattr(self.jitter.cpu, reg.name):
+                    value = ExprInt(getattr(self.jitter.cpu, reg.name),
+                                    size=reg.size)
+                    symbexec.symbols[reg] = value
+
+    def update_state_from_concrete(self, cpu=True, mem=False):
+        r"""Update the symbolic state with concrete values from the concrete
+        engine
+
+        @cpu: (optional) if set, update registers' value
+        @mem: (optional) if set, update memory value
+
+        [!] all current states will be lost.
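+
+        For instance (sketch), to fully resynchronize with the jitter:
+
+            dse.update_state_from_concrete(cpu=True, mem=True)
+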
+        This function is usually called when states are no longer synchronized
+        (at the beginning, returning from an unstubbed syscall, ...)
+        """
+        self._update_state_from_concrete_symb(self.symb, cpu, mem)
+
+    def eval_expr(self, expr):
+        """Return the evaluation of @expr:
+        @expr: Expr instance"""
+        return self.symb.eval_expr(expr)
+
+    @staticmethod
+    def memory_to_expr(addr):
+        """Translate an address to its corresponding symbolic ID (8 bits)
+        @addr: int"""
+        return ExprId("MEM_0x%x" % int(addr), 8)
+
+    def symbolize_memory(self, memory_range):
+        """Register a range of memory addresses to symbolize
+        @memory_range: object with support of the __in__ operation (intervals,
+        list, ...)
+        """
+        self.symb.dse_memory_range = memory_range
+        self.symb.dse_memory_to_expr = self.memory_to_expr
+
+
+class DSEPathConstraint(DSEEngine):
+    """Dynamic Symbolic Execution Engine keeping the path constraints
+
+    Possible new "solutions" are produced along the path, by inverting concrete
+    path constraints. Thus, a "solution" is a potential initial context leading
+    to a new path.
+
+    In order to produce a new solution, one can extend this class, and override
+    'handle_solution' to produce a solution which fits its needs. Computing new
+    solutions can be avoided by overriding 'produce_solution'.
+
+    If one is only interested in the constraints associated to its path, the
+    option "produce_solution" should be set to PRODUCE_NO_SOLUTION, to speed up
+    emulation.
+    The constraints are accumulated in the .cur_solver z3.Solver object.
+
+    """
+
+    # Maximum memory size to inject in constraints solving
+    MAX_MEMORY_INJECT = 0x10000
+
+    # Produce solution strategies
+    PRODUCE_NO_SOLUTION = 0
+    PRODUCE_SOLUTION_CODE_COV = 1
+    PRODUCE_SOLUTION_BRANCH_COV = 2
+    PRODUCE_SOLUTION_PATH_COV = 3
+
+    def __init__(self, machine, loc_db, produce_solution=PRODUCE_SOLUTION_CODE_COV,
+                 known_solutions=None,
+                 **kwargs):
+        """Init a DSEPathConstraint
+        @machine: Machine instance of the targeted architecture
+        @produce_solution: (optional) strategy used to produce new solutions,
+        one of the PRODUCE_* constants"""
+        super(DSEPathConstraint, self).__init__(machine, loc_db, **kwargs)
+
+        # Dependency check
+        assert z3 is not None
+
+        # Init PathConstraint specific structures
+        self.cur_solver = z3.Solver()
+        self.new_solutions = {}  # solution identifier -> solution's model
+        self._known_solutions = set()  # set of solution identifiers
+        self.z3_trans = Translator.to_language("z3")
+        self._produce_solution_strategy = produce_solution
+        self._previous_addr = None
+        self._history = None
+        if produce_solution == self.PRODUCE_SOLUTION_PATH_COV:
+            self._history = []  # List of addresses in the current path
+
+    @property
+    def ir_arch(self):
+        warnings.warn('DEPRECATION WARNING: use ".lifter" instead of ".ir_arch"')
+        return self.lifter
+
+    def take_snapshot(self, *args, **kwargs):
+        snap = super(DSEPathConstraint, self).take_snapshot(*args, **kwargs)
+        snap["new_solutions"] = {
+            dst: src.copy
+            for dst, src in viewitems(self.new_solutions)
+        }
+        snap["cur_constraints"] = self.cur_solver.assertions()
+        if self._produce_solution_strategy == self.PRODUCE_SOLUTION_PATH_COV:
+            snap["_history"] = list(self._history)
+        elif self._produce_solution_strategy == self.PRODUCE_SOLUTION_BRANCH_COV:
+            snap["_previous_addr"] = self._previous_addr
+        return snap
+
+    def restore_snapshot(self, snapshot, keep_known_solutions=True, **kwargs):
+        """Restore a DSEPathConstraint snapshot
+        @keep_known_solutions: if set, do not forget solutions already found.
+        -> They will not appear in 'new_solutions'
+        """
+        super(DSEPathConstraint, self).restore_snapshot(snapshot, **kwargs)
+        self.new_solutions.clear()
+        self.new_solutions.update(snapshot["new_solutions"])
+        self.cur_solver = z3.Solver()
+        self.cur_solver.add(snapshot["cur_constraints"])
+        if not keep_known_solutions:
+            self._known_solutions.clear()
+        if self._produce_solution_strategy == self.PRODUCE_SOLUTION_PATH_COV:
+            self._history = list(snapshot["_history"])
+        elif self._produce_solution_strategy == self.PRODUCE_SOLUTION_BRANCH_COV:
+            self._previous_addr = snapshot["_previous_addr"]
+
+    def _key_for_solution_strategy(self, destination):
+        """Return the associated identifier for the current solution strategy"""
+        if self._produce_solution_strategy == self.PRODUCE_NO_SOLUTION:
+            # Never produce a solution
+            return None
+        elif self._produce_solution_strategy == self.PRODUCE_SOLUTION_CODE_COV:
+            # Decision based on code coverage
+            # -> produce a solution if the destination has never been seen
+            key = destination
+
+        elif self._produce_solution_strategy == self.PRODUCE_SOLUTION_BRANCH_COV:
+            # Decision based on branch coverage
+            # -> produce a solution if the current branch has never been taken
+            key = (self._previous_addr, destination)
+
+        elif self._produce_solution_strategy == self.PRODUCE_SOLUTION_PATH_COV:
+            # Decision based on path coverage
+            # -> produce a solution if the current path has never been taken
+            key = tuple(self._history + [destination])
+        else:
+            raise ValueError("Unknown produce solution strategy")
+
+        return key
+
+    def produce_solution(self, destination):
+        """Called to determine if a solution for @destination should be tested
+        for satisfiability and computed
+        @destination: Expr instance of the target destination
+        """
+        key = self._key_for_solution_strategy(destination)
+        if key is None:
+            return False
+        return key not in self._known_solutions
+
+    def handle_solution(self, model, destination):
+        """Called when a new solution for destination @destination is found
+        @model: z3 model instance
+        @destination: Expr instance for an addr which is not on the DSE path
+        """
+        key = self._key_for_solution_strategy(destination)
+        assert key is not None
+        self.new_solutions[key] = model
+        self._known_solutions.add(key)
+
+    def handle_correct_destination(self, destination, path_constraints):
+        """[DEV] Called by handle() to update internal structures giving the
+        correct destination (the concrete execution one).
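+
+        A subclass may extend it, e.g. to record the explored path (sketch;
+        'MyDSE' and its 'trace' attribute are hypothetical):
+
+            def handle_correct_destination(self, destination, path_constraints):
+                super(MyDSE, self).handle_correct_destination(
+                    destination, path_constraints
+                )
+                self.trace.append(destination)  # hypothetical attribute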
+        """
+
+        # Update structures used by produce_solution()
+        if self._produce_solution_strategy == self.PRODUCE_SOLUTION_PATH_COV:
+            self._history.append(destination)
+        elif self._produce_solution_strategy == self.PRODUCE_SOLUTION_BRANCH_COV:
+            self._previous_addr = destination
+
+        # Update current solver
+        for cons in path_constraints:
+            self.cur_solver.add(self.z3_trans.from_expr(cons))
+
+    def handle(self, cur_addr):
+        cur_addr = canonize_to_exprloc(self.lifter.loc_db, cur_addr)
+        symb_pc = self.eval_expr(self.lifter.IRDst)
+        possibilities = possible_values(symb_pc)
+        cur_path_constraint = set()  # path_constraint for the concrete path
+        if len(possibilities) == 1:
+            dst = next(iter(possibilities)).value
+            dst = canonize_to_exprloc(self.lifter.loc_db, dst)
+            assert dst == cur_addr
+        else:
+            for possibility in possibilities:
+                target_addr = canonize_to_exprloc(self.lifter.loc_db, possibility.value)
+                path_constraint = set()  # Set of ExprAssign for the possible path
+
+                # Get constraints associated to the possible path
+                memory_to_add = ModularIntervals(symb_pc.size)
+                for cons in possibility.constraints:
+                    eaff = cons.to_constraint()
+                    # eaff.get_r(mem_read=True) is not enough
+                    # ExprAssign considers a memory access in dst as a write
+                    mem = eaff.dst.get_r(mem_read=True)
+                    mem.update(eaff.src.get_r(mem_read=True))
+                    for expr in mem:
+                        if expr.is_mem():
+                            addr_range = expr_range(expr.ptr)
+                            # At upper bounds, add the size of the memory access
+                            # if addr is in [a, b], then the values reachable
+                            # through @size[addr] are in @8[a, b + size[
+                            for start, stop in addr_range:
+                                stop += expr.size // 8 - 1
+                                full_range = ModularIntervals(
+                                    symb_pc.size,
+                                    [(start, stop)]
+                                )
+                                memory_to_add.update(full_range)
+                    path_constraint.add(eaff)
+
+                if memory_to_add.length > self.MAX_MEMORY_INJECT:
+                    # TODO: re-concretize the constraint, or try z3 anyway
+                    raise RuntimeError("Not implemented: memory area too large")
+
+                # Inject memory
+                for start, stop in memory_to_add:
+                    for address in range(start, stop + 1):
+                        expr_mem = ExprMem(ExprInt(address,
+                                                   self.lifter.pc.size),
+                                           8)
+                        value = self.eval_expr(expr_mem)
+                        if not value.is_int():
+                            raise TypeError("Relies on a symbolic memory case, "
+                                            "address 0x%x" % address)
+                        path_constraint.add(ExprAssign(expr_mem, value))
+
+                if target_addr == cur_addr:
+                    # Add path constraint
+                    cur_path_constraint = path_constraint
+
+                elif self.produce_solution(target_addr):
+                    # Looking for a new solution
+                    self.cur_solver.push()
+                    for cons in path_constraint:
+                        trans = self.z3_trans.from_expr(cons)
+                        trans = z3.simplify(trans)
+                        self.cur_solver.add(trans)
+
+                    result = self.cur_solver.check()
+                    if result == z3.sat:
+                        model = self.cur_solver.model()
+                        self.handle_solution(model, target_addr)
+                    self.cur_solver.pop()
+
+        self.handle_correct_destination(cur_addr, cur_path_constraint)
diff --git a/src/miasm/analysis/expression_range.py b/src/miasm/analysis/expression_range.py
new file mode 100644
index 00000000..5a31873a
--- /dev/null
+++ b/src/miasm/analysis/expression_range.py
@@ -0,0 +1,70 @@
+"""Naive range analysis for expressions"""
+
+from future.builtins import zip
+from functools import reduce
+
+from miasm.analysis.modularintervals import ModularIntervals
+
+_op_range_handler = {
+    "+": lambda x, y: x + y,
+    "&": lambda x, y: x & y,
+    "|": lambda x, y: x | y,
+    "^": lambda x, y: x ^ y,
+    "*": lambda x, y: x * y,
+    "a>>": lambda x, y: x.arithmetic_shift_right(y),
+    "<<": lambda x, y: x << y,
+    ">>": lambda x, y: x >> y,
+    ">>>": lambda x, y: x.rotation_right(y),
+    "<<<": lambda x, y: 
x.rotation_left(y),
+}
+
+def expr_range(expr):
+    """Return a ModularIntervals containing the range of possible values of
+    @expr"""
+    max_bound = (1 << expr.size) - 1
+    if expr.is_int():
+        return ModularIntervals(expr.size, [(int(expr), int(expr))])
+    elif expr.is_id() or expr.is_mem():
+        return ModularIntervals(expr.size, [(0, max_bound)])
+    elif expr.is_slice():
+        interval_mask = ((1 << expr.start) - 1) ^ ((1 << expr.stop) - 1)
+        arg = expr_range(expr.arg)
+        # Mask for the possible range, and shift the range
+        return ((arg & interval_mask) >> expr.start).size_update(expr.size)
+    elif expr.is_compose():
+        sub_ranges = [expr_range(arg) for arg in expr.args]
+        args_idx = [info[0] for info in expr.iter_args()]
+
+        # No shift for the first one
+        ret = sub_ranges[0].size_update(expr.size)
+
+        # Doing it progressively (2 by 2)
+        for shift, sub_range in zip(args_idx[1:], sub_ranges[1:]):
+            ret |= sub_range.size_update(expr.size) << shift
+        return ret
+    elif expr.is_op():
+        # A few operations are handled with care
+        # Otherwise, overapproximate (ie. full range interval)
+        if expr.op in _op_range_handler:
+            sub_ranges = [expr_range(arg) for arg in expr.args]
+            return reduce(
+                _op_range_handler[expr.op],
+                (sub_range for sub_range in sub_ranges[1:]),
+                sub_ranges[0]
+            )
+        elif expr.op == "-":
+            assert len(expr.args) == 1
+            return - expr_range(expr.args[0])
+        elif expr.op == "%":
+            assert len(expr.args) == 2
+            op, mod = [expr_range(arg) for arg in expr.args]
+            if mod.intervals.length == 1:
+                # Modulo by a non-singleton interval is not supported
+                return op % mod.intervals.hull()[0]
+
+        # Operator not handled, return the full domain
+        return ModularIntervals(expr.size, [(0, max_bound)])
+    elif expr.is_cond():
+        return expr_range(expr.src1).union(expr_range(expr.src2))
+    else:
+        raise TypeError("Unsupported type: %s" % expr.__class__)
diff --git a/src/miasm/analysis/gdbserver.py b/src/miasm/analysis/gdbserver.py
new file mode 100644
index 00000000..b45e9f35
--- /dev/null
+++ b/src/miasm/analysis/gdbserver.py
@@ -0,0 +1,453 @@
+#-*- coding:utf-8 -*-
+
+from __future__ import print_function
+from future.builtins import map, range
+
+from miasm.core.utils import decode_hex, encode_hex, int_to_byte
+
+import socket
+import struct
+import time
+import logging
+from io import BytesIO
+import miasm.analysis.debugging as debugging
+from miasm.jitter.jitload import ExceptionHandle
+
+
+class GdbServer(object):
+
+    "Debugger binding for the GDBServer protocol"
+
+    general_registers_order = []
+    general_registers_size = {}  # RegName: size in bytes
+    status = b"S05"
+
+    def __init__(self, dbg, port=4455):
+        server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+        server.bind(('localhost', port))
+        server.listen(1)
+        self.server = server
+        self.dbg = dbg
+
+    # Communication methods
+
+    def compute_checksum(self, data):
+        # bytearray iteration yields integers on both Python 2 and 3
+        return encode_hex(int_to_byte(sum(bytearray(data)) % 256))
+
+    def get_messages(self):
+        all_data = b""
+        while True:
+            data = self.sock.recv(4096)
+            if not data:
+                break
+            all_data += data
+
+        logging.debug("<- %r", all_data)
+        self.recv_queue += self.parse_messages(all_data)
+
+    def parse_messages(self, data):
+        buf = BytesIO(data)
+        msgs = []
+
+        # io.BytesIO has no 'len' attribute: keep the total size around
+        data_len = len(data)
+        while buf.tell() < data_len:
+            token = buf.read(1)
+            if token == b"+":
+                continue
+            if token == b"-":
+                raise NotImplementedError("Resend packet")
+            if token == b"$":
+                packet_data = b""
+                c = buf.read(1)
+                while c != b"#":
+                    packet_data += c
+                    c = buf.read(1)
+                checksum = buf.read(2)
+                if checksum != self.compute_checksum(packet_data):
+                    raise ValueError("Incorrect checksum")
+                msgs.append(packet_data)
+
+        return msgs
+
+    def send_string(self, s):
+        self.send_queue.append(b"O" + encode_hex(s))
+
+    def process_messages(self):
+
+        while self.recv_queue:
+            msg = self.recv_queue.pop(0)
+            buf = BytesIO(msg)
+            msg_type = buf.read(1)
+
+            self.send_queue.append(b"+")
+
+            if msg_type == b"q":
+                if msg.startswith(b"qSupported"):
+                    self.send_queue.append(b"PacketSize=3fff")
+                elif msg.startswith(b"qC"):
+                    # Current thread
+                    self.send_queue.append(b"")
+                elif msg.startswith(b"qAttached"):
+                    # Not supported
+                    self.send_queue.append(b"")
+                elif msg.startswith(b"qTStatus"):
+                    # Not supported
+                    self.send_queue.append(b"")
+                elif msg.startswith(b"qfThreadInfo"):
+                    # Not supported
+                    self.send_queue.append(b"")
+                else:
+                    raise NotImplementedError()
+
+            elif msg_type == b"H":
+                # Set current thread
+                self.send_queue.append(b"OK")
+
+            elif msg_type == b"?":
+                # Report why the target halted
+                self.send_queue.append(self.status)  # TRAP signal
+
+            elif msg_type == b"g":
+                # Report all general register values
+                self.send_queue.append(self.report_general_register_values())
+
+            elif msg_type == b"p":
+                # Read a specific register
+                reg_num = int(buf.read(), 16)
+                self.send_queue.append(self.read_register(reg_num))
+
+            elif msg_type == b"P":
+                # Set a specific register
+                reg_num, value = buf.read().split(b"=")
+                reg_num = int(reg_num, 16)
+                value = int(encode_hex(decode_hex(value)[::-1]), 16)
+                self.set_register(reg_num, value)
+                self.send_queue.append(b"OK")
+
+            elif msg_type == b"m":
+                # Read memory
+                addr, size = (int(x, 16) for x in buf.read().split(b",", 1))
+                self.send_queue.append(self.read_memory(addr, size))
+
+            elif msg_type == b"k":
+                # Kill
+                self.sock.close()
+                self.send_queue = []
+                self.sock = None
+
+            elif msg_type == b"!":
+                # Extended debugging will be used
+                self.send_queue.append(b"OK")
+
+            elif msg_type == b"v":
+                if msg == b"vCont?":
+                    # Is vCont supported?
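+                    # An empty reply means "unsupported" in the GDB remote
+                    # protocol; GDB then falls back to plain 's'/'c' packets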
+                    self.send_queue.append(b"")
+
+            elif msg_type == b"s":
+                # Step
+                self.dbg.step()
+                self.send_queue.append(b"S05")  # TRAP signal
+
+            elif msg_type == b"Z":
+                # Add breakpoint or watchpoint
+                bp_type = buf.read(1)
+                if bp_type == b"0":
+                    # Exec breakpoint
+                    assert(buf.read(1) == b",")
+                    addr, size = (int(x, 16) for x in buf.read().split(b",", 1))
+
+                    if size != 1:
+                        raise NotImplementedError("Bigger size")
+                    self.dbg.add_breakpoint(addr)
+                    self.send_queue.append(b"OK")
+
+                elif bp_type == b"1":
+                    # Hardware BP
+                    assert(buf.read(1) == b",")
+                    addr, size = (int(x, 16) for x in buf.read().split(b",", 1))
+
+                    self.dbg.add_memory_breakpoint(
+                        addr,
+                        size,
+                        read=True,
+                        write=True
+                    )
+                    self.send_queue.append(b"OK")
+
+                elif bp_type in [b"2", b"3", b"4"]:
+                    # Memory breakpoint
+                    assert(buf.read(1) == b",")
+                    read = bp_type in [b"3", b"4"]
+                    write = bp_type in [b"2", b"4"]
+                    addr, size = (int(x, 16) for x in buf.read().split(b",", 1))
+
+                    self.dbg.add_memory_breakpoint(
+                        addr,
+                        size,
+                        read=read,
+                        write=write
+                    )
+                    self.send_queue.append(b"OK")
+
+                else:
+                    raise ValueError("Impossible value")
+
+            elif msg_type == b"z":
+                # Remove breakpoint or watchpoint
+                bp_type = buf.read(1)
+                if bp_type == b"0":
+                    # Exec breakpoint
+                    assert(buf.read(1) == b",")
+                    addr, size = (int(x, 16) for x in buf.read().split(b",", 1))
+
+                    if size != 1:
+                        raise NotImplementedError("Bigger size")
+                    dbgsoft = self.dbg.get_breakpoint_by_addr(addr)
+                    assert(len(dbgsoft) == 1)
+                    self.dbg.remove_breakpoint(dbgsoft[0])
+                    self.send_queue.append(b"OK")
+
+                elif bp_type == b"1":
+                    # Hardware BP
+                    assert(buf.read(1) == b",")
+                    addr, size = (int(x, 16) for x in buf.read().split(b",", 1))
+                    self.dbg.remove_memory_breakpoint_by_addr_access(
+                        addr,
+                        read=True,
+                        write=True
+                    )
+                    self.send_queue.append(b"OK")
+
+                elif bp_type in [b"2", b"3", b"4"]:
+                    # Memory breakpoint
+                    assert(buf.read(1) == b",")
+                    read = bp_type in [b"3", b"4"]
+                    write = bp_type in [b"2", b"4"]
+                    addr, size = (int(x, 16) for x in buf.read().split(b",", 1))
+
+                    self.dbg.remove_memory_breakpoint_by_addr_access(
+                        addr,
+                        read=read,
+                        write=write
+                    )
+                    self.send_queue.append(b"OK")
+
+                else:
+                    raise ValueError("Impossible value")
+
+            elif msg_type == b"c":
+                # Continue
+                self.status = b""
+                self.send_messages()
+                ret = self.dbg.run()
+                if isinstance(ret, debugging.DebugBreakpointSoft):
+                    self.status = b"S05"
+                    self.send_queue.append(b"S05")  # TRAP signal
+                elif isinstance(ret, ExceptionHandle):
+                    if ret == ExceptionHandle.memoryBreakpoint():
+                        self.status = b"S05"
+                        self.send_queue.append(b"S05")
+                    else:
+                        raise NotImplementedError("Unknown exception")
+                elif isinstance(ret, debugging.DebugBreakpointTerminate):
+                    # The connection should close, but keep it running as a TRAP
+                    # The connection will be closed on instance destruction
+                    print(ret)
+                    self.status = b"S05"
+                    self.send_queue.append(b"S05")
+                else:
+                    raise NotImplementedError()
+
+            else:
+                raise NotImplementedError(
+                    "Not implemented: message type %r" % msg_type
+                )
+
+    def send_messages(self):
+        for msg in self.send_queue:
+            if msg == b"+":
+                data = b"+"
+            else:
+                data = b"$%s#%s" % (msg, self.compute_checksum(msg))
+            logging.debug("-> %r", data)
+            self.sock.send(data)
+        self.send_queue = []
+
+    def main_loop(self):
+        self.recv_queue = []
+        self.send_queue = []
+
+        self.send_string(b"Test\n")
+
+        while self.sock:
+            self.get_messages()
+            self.process_messages()
+            self.send_messages()
+
+    def run(self):
+        self.sock, self.address = self.server.accept()
+        self.main_loop()
+
+    # Debugger processing methods
+    def 
report_general_register_values(self): + s = b"" + for i in range(len(self.general_registers_order)): + s += self.read_register(i) + return s + + def read_register(self, reg_num): + reg_name = self.general_registers_order[reg_num] + reg_value = self.read_register_by_name(reg_name) + size = self.general_registers_size[reg_name] + + pack_token = "" + if size == 1: + pack_token = "<B" + elif size == 2: + pack_token = "<H" + elif size == 4: + pack_token = "<I" + elif size == 8: + pack_token = "<Q" + else: + raise NotImplementedError("Unknown size") + + return encode_hex(struct.pack(pack_token, reg_value)) + + def set_register(self, reg_num, value): + reg_name = self.general_registers_order[reg_num] + self.dbg.set_reg_value(reg_name, value) + + def read_register_by_name(self, reg_name): + return self.dbg.get_reg_value(reg_name) + + def read_memory(self, addr, size): + except_flag_vm = self.dbg.myjit.vm.get_exception() + try: + return encode_hex(self.dbg.get_mem_raw(addr, size)) + except RuntimeError: + self.dbg.myjit.vm.set_exception(except_flag_vm) + return b"00" * size + + +class GdbServer_x86_32(GdbServer): + + "Extend GdbServer for x86 32bits purposes" + + general_registers_order = [ + "EAX", "ECX", "EDX", "EBX", "ESP", "EBP", "ESI", + "EDI", "EIP", "EFLAGS", "CS", "SS", "DS", "ES", + "FS", "GS" + ] + + general_registers_size = { + "EAX": 4, + "ECX": 4, + "EDX": 4, + "EBX": 4, + "ESP": 4, + "EBP": 4, + "ESI": 4, + "EDI": 4, + "EIP": 4, + "EFLAGS": 2, + "CS": 2, + "SS": 2, + "DS": 2, + "ES": 2, + "FS": 2, + "GS": 2 + } + + register_ignore = [ + "tf", "i_f", "nt", "rf", "vm", "ac", "vif", "vip", "i_d" + ] + + def read_register_by_name(self, reg_name): + sup_func = super(GdbServer_x86_32, self).read_register_by_name + + # Assert EIP on pc jitter + if reg_name == "EIP": + return self.dbg.myjit.pc + + # EFLAGS case + if reg_name == "EFLAGS": + val = 0 + eflags_args = [ + "cf", 1, "pf", 0, "af", 0, "zf", "nf", "tf", "i_f", "df", "of" + ] + eflags_args += ["nt", 0, "rf", "vm", "ac", "vif", "vip", "i_d"] + eflags_args += [0] * 10 + + for i, arg in enumerate(eflags_args): + if isinstance(arg, str): + if arg not in self.register_ignore: + to_add = sup_func(arg) + else: + to_add = 0 + else: + to_add = arg + + val |= (to_add << i) + return val + else: + return sup_func(reg_name) + + +class GdbServer_msp430(GdbServer): + + "Extend GdbServer for msp430 purposes" + + general_registers_order = [ + "PC", "SP", "SR", "R3", "R4", "R5", "R6", "R7", + "R8", "R9", "R10", "R11", "R12", "R13", "R14", + "R15" + ] + + general_registers_size = { + "PC": 2, + "SP": 2, + "SR": 2, + "R3": 2, + "R4": 2, + "R5": 2, + "R6": 2, + "R7": 2, + "R8": 2, + "R9": 2, + "R10": 2, + "R11": 2, + "R12": 2, + "R13": 2, + "R14": 2, + "R15": 2 + } + + def read_register_by_name(self, reg_name): + sup_func = super(GdbServer_msp430, self).read_register_by_name + if reg_name == "SR": + o = sup_func('res') + o <<= 1 + o |= sup_func('of') + o <<= 1 + o |= sup_func('scg1') + o <<= 1 + o |= sup_func('scg0') + o <<= 1 + o |= sup_func('osc') + o <<= 1 + o |= sup_func('cpuoff') + o <<= 1 + o |= sup_func('gie') + o <<= 1 + o |= sup_func('nf') + o <<= 1 + o |= sup_func('zf') + o <<= 1 + o |= sup_func('cf') + + return o + else: + return sup_func(reg_name) + diff --git a/src/miasm/analysis/machine.py b/src/miasm/analysis/machine.py new file mode 100644 index 00000000..cc86d753 --- /dev/null +++ b/src/miasm/analysis/machine.py @@ -0,0 +1,279 @@ +#-*- coding:utf-8 -*- +import warnings + + +class Machine(object): + """Abstract machine architecture to 
restrict architecture dependent code""" + + __dis_engine = None # Disassembly engine + __mn = None # Machine instance + __lifter_model_call = None # IR analyser + __jitter = None # Jit engine + __gdbserver = None # GdbServer handler + + __available = ["arml", "armb", "armtl", "armtb", "sh4", "x86_16", "x86_32", + "x86_64", "msp430", "mips32b", "mips32l", + "aarch64l", "aarch64b", "ppc32b", "mepl", "mepb"] + + + def __init__(self, machine_name): + + dis_engine = None + mn = None + lifter_model_call = None + ir = None + jitter = None + gdbserver = None + jit = None + log_jit = None + log_arch = None + + # Import on runtime for performance issue + if machine_name == "arml": + from miasm.arch.arm.disasm import dis_arml as dis_engine + from miasm.arch.arm import arch + try: + from miasm.arch.arm import jit + jitter = jit.jitter_arml + except ImportError: + pass + mn = arch.mn_arm + from miasm.arch.arm.lifter_model_call import LifterModelCallArml as lifter_model_call + from miasm.arch.arm.sem import Lifter_Arml as lifter + elif machine_name == "armb": + from miasm.arch.arm.disasm import dis_armb as dis_engine + from miasm.arch.arm import arch + try: + from miasm.arch.arm import jit + jitter = jit.jitter_armb + except ImportError: + pass + mn = arch.mn_arm + from miasm.arch.arm.lifter_model_call import LifterModelCallArmb as lifter_model_call + from miasm.arch.arm.sem import Lifter_Armb as lifter + elif machine_name == "aarch64l": + from miasm.arch.aarch64.disasm import dis_aarch64l as dis_engine + from miasm.arch.aarch64 import arch + try: + from miasm.arch.aarch64 import jit + jitter = jit.jitter_aarch64l + except ImportError: + pass + mn = arch.mn_aarch64 + from miasm.arch.aarch64.lifter_model_call import LifterModelCallAarch64l as lifter_model_call + from miasm.arch.aarch64.sem import Lifter_Aarch64l as lifter + elif machine_name == "aarch64b": + from miasm.arch.aarch64.disasm import dis_aarch64b as dis_engine + from miasm.arch.aarch64 import arch + try: + from miasm.arch.aarch64 import jit + jitter = jit.jitter_aarch64b + except ImportError: + pass + mn = arch.mn_aarch64 + from miasm.arch.aarch64.lifter_model_call import LifterModelCallAarch64b as lifter_model_call + from miasm.arch.aarch64.sem import Lifter_Aarch64b as lifter + elif machine_name == "armtl": + from miasm.arch.arm.disasm import dis_armtl as dis_engine + from miasm.arch.arm import arch + mn = arch.mn_armt + from miasm.arch.arm.lifter_model_call import LifterModelCallArmtl as lifter_model_call + from miasm.arch.arm.sem import Lifter_Armtl as lifter + try: + from miasm.arch.arm import jit + jitter = jit.jitter_armtl + except ImportError: + pass + elif machine_name == "armtb": + from miasm.arch.arm.disasm import dis_armtb as dis_engine + from miasm.arch.arm import arch + mn = arch.mn_armt + from miasm.arch.arm.lifter_model_call import LifterModelCallArmtb as lifter_model_call + from miasm.arch.arm.sem import Lifter_Armtb as lifter + elif machine_name == "sh4": + from miasm.arch.sh4 import arch + mn = arch.mn_sh4 + elif machine_name == "x86_16": + from miasm.arch.x86.disasm import dis_x86_16 as dis_engine + from miasm.arch.x86 import arch + try: + from miasm.arch.x86 import jit + jitter = jit.jitter_x86_16 + except ImportError: + pass + mn = arch.mn_x86 + from miasm.arch.x86.lifter_model_call import LifterModelCall_x86_16 as lifter_model_call + from miasm.arch.x86.sem import Lifter_X86_16 as lifter + elif machine_name == "x86_32": + from miasm.arch.x86.disasm import dis_x86_32 as dis_engine + from miasm.arch.x86 import arch + try: 
+ from miasm.arch.x86 import jit + jitter = jit.jitter_x86_32 + except ImportError: + pass + mn = arch.mn_x86 + from miasm.arch.x86.lifter_model_call import LifterModelCall_x86_32 as lifter_model_call + from miasm.arch.x86.sem import Lifter_X86_32 as lifter + try: + from miasm.analysis.gdbserver import GdbServer_x86_32 as gdbserver + except ImportError: + pass + elif machine_name == "x86_64": + from miasm.arch.x86.disasm import dis_x86_64 as dis_engine + from miasm.arch.x86 import arch + try: + from miasm.arch.x86 import jit + jitter = jit.jitter_x86_64 + except ImportError: + pass + mn = arch.mn_x86 + from miasm.arch.x86.lifter_model_call import LifterModelCall_x86_64 as lifter_model_call + from miasm.arch.x86.sem import Lifter_X86_64 as lifter + elif machine_name == "msp430": + from miasm.arch.msp430.disasm import dis_msp430 as dis_engine + from miasm.arch.msp430 import arch + try: + from miasm.arch.msp430 import jit + jitter = jit.jitter_msp430 + except ImportError: + pass + mn = arch.mn_msp430 + from miasm.arch.msp430.lifter_model_call import LifterModelCallMsp430 as lifter_model_call + from miasm.arch.msp430.sem import Lifter_MSP430 as lifter + try: + from miasm.analysis.gdbserver import GdbServer_msp430 as gdbserver + except ImportError: + pass + elif machine_name == "mips32b": + from miasm.arch.mips32.disasm import dis_mips32b as dis_engine + from miasm.arch.mips32 import arch + try: + from miasm.arch.mips32 import jit + jitter = jit.jitter_mips32b + except ImportError: + pass + mn = arch.mn_mips32 + from miasm.arch.mips32.lifter_model_call import LifterModelCallMips32b as lifter_model_call + from miasm.arch.mips32.sem import Lifter_Mips32b as lifter + elif machine_name == "mips32l": + from miasm.arch.mips32.disasm import dis_mips32l as dis_engine + from miasm.arch.mips32 import arch + try: + from miasm.arch.mips32 import jit + jitter = jit.jitter_mips32l + except ImportError: + pass + mn = arch.mn_mips32 + from miasm.arch.mips32.lifter_model_call import LifterModelCallMips32l as lifter_model_call + from miasm.arch.mips32.sem import Lifter_Mips32l as lifter + elif machine_name == "ppc32b": + from miasm.arch.ppc.disasm import dis_ppc32b as dis_engine + from miasm.arch.ppc import arch + try: + from miasm.arch.ppc import jit + jitter = jit.jitter_ppc32b + except ImportError: + pass + mn = arch.mn_ppc + from miasm.arch.ppc.lifter_model_call import LifterModelCallPpc32b as lifter_model_call + from miasm.arch.ppc.sem import Lifter_PPC32b as lifter + elif machine_name == "mepb": + from miasm.arch.mep.disasm import dis_mepb as dis_engine + from miasm.arch.mep import arch + try: + from miasm.arch.mep import jit + jitter = jit.jitter_mepb + except ImportError: + pass + mn = arch.mn_mep + from miasm.arch.mep.lifter_model_call import LifterModelCallMepb as lifter_model_call + from miasm.arch.mep.sem import Lifter_MEPb as lifter + elif machine_name == "mepl": + from miasm.arch.mep.disasm import dis_mepl as dis_engine + from miasm.arch.mep import arch + try: + from miasm.arch.mep import jit + jitter = jit.jitter_mepl + except ImportError: + pass + mn = arch.mn_mep + from miasm.arch.mep.lifter_model_call import LifterModelCallMepl as lifter_model_call + from miasm.arch.mep.sem import Lifter_MEPl as lifter + else: + raise ValueError('Unknown machine: %s' % machine_name) + + # Loggers + if jit is not None: + log_jit = jit.log + log_arch = arch.log + + self.__dis_engine = dis_engine + self.__mn = mn + self.__lifter_model_call = lifter_model_call + self.__jitter = jitter + self.__gdbserver = gdbserver 
+        self.__log_jit = log_jit
+        self.__log_arch = log_arch
+        self.__base_expr = arch.base_expr
+        self.__lifter = lifter
+        self.__name = machine_name
+
+    @property
+    def dis_engine(self):
+        return self.__dis_engine
+
+    @property
+    def mn(self):
+        return self.__mn
+
+    @property
+    def lifter(self):
+        return self.__lifter
+
+    @property
+    def lifter_model_call(self):
+        return self.__lifter_model_call
+
+    @property
+    def jitter(self):
+        return self.__jitter
+
+    @property
+    def gdbserver(self):
+        return self.__gdbserver
+
+    @property
+    def log_jit(self):
+        return self.__log_jit
+
+    @property
+    def log_arch(self):
+        return self.__log_arch
+
+    @property
+    def base_expr(self):
+        return self.__base_expr
+
+    @property
+    def name(self):
+        return self.__name
+
+    @classmethod
+    def available_machine(cls):
+        "Return a list of supported machines"
+        return cls.__available
+
+    @property
+    def ira(self):
+        warnings.warn('DEPRECATION WARNING: use ".lifter_model_call" instead of ".ira"')
+        return self.lifter_model_call
+
+    @property
+    def ir(self):
+        warnings.warn('DEPRECATION WARNING: use ".lifter" instead of ".ir"')
+        return self.lifter
diff --git a/src/miasm/analysis/modularintervals.py b/src/miasm/analysis/modularintervals.py
new file mode 100644
index 00000000..67eda9dc
--- /dev/null
+++ b/src/miasm/analysis/modularintervals.py
@@ -0,0 +1,525 @@
+"""Intervals with a maximum size, supporting modular arithmetic"""
+
+from future.builtins import range
+from builtins import int as int_types
+from itertools import product
+
+from miasm.core.interval import interval
+from miasm.core.utils import size2mask
+
+class ModularIntervals(object):
+    """Intervals with a maximum size, supporting modular arithmetic"""
+
+    def __init__(self, size, intervals=None):
+        """Instantiate a ModularIntervals of size @size
+        @size: maximum size of elements
+        @intervals: (optional) interval instance, or any type supported by
+        interval initialisation; elements of the current instance
+        """
+        # Create or cast the @intervals argument
+        if intervals is None:
+            intervals = interval()
+        if not isinstance(intervals, interval):
+            intervals = interval(intervals)
+        self.intervals = intervals
+        self.size = size
+
+        # Sanity check
+        start, end = intervals.hull()
+        if start is not None:
+            assert start >= 0
+        if end is not None:
+            assert end <= self.mask
+
+    # Helpers
+    def _range2interval(func):
+        """Convert a function taking 2 ranges to a function taking a
+        ModularIntervals and applying to the current instance"""
+        def ret_func(self, target):
+            ret = interval()
+            for left_i, right_i in product(self.intervals, target.intervals):
+                ret += func(self, left_i[0], left_i[1], right_i[0],
+                            right_i[1])
+            return self.__class__(self.size, ret)
+        return ret_func
+
+    def _range2integer(func):
+        """Convert a function taking 1 range and optional arguments to a
+        function applying to the current instance"""
+        def ret_func(self, *args):
+            ret = interval()
+            for x_min, x_max in self.intervals:
+                ret += func(self, x_min, x_max, *args)
+            return self.__class__(self.size, ret)
+        return ret_func
+
+    def _promote(func):
+        """Check and promote the second argument from integer to
+        ModularIntervals with one value"""
+        def ret_func(self, target):
+            if isinstance(target, int_types):
+                target = ModularIntervals(self.size, interval([(target, target)]))
+            if not isinstance(target, ModularIntervals):
+                raise TypeError("Unsupported operation with %s" % target.__class__)
+            if target.size != self.size:
+                raise TypeError("Sizes are not the 
same: %s vs %s" % (self.size, + target.size)) + return func(self, target) + return ret_func + + def _unsigned2signed(self, value): + """Return the signed value of @value, based on self.size""" + if (value & (1 << (self.size - 1))): + return -(self.mask ^ value) - 1 + else: + return value + + def _signed2unsigned(self, value): + """Return the unsigned value of @value, based on self.size""" + return value & self.mask + + # Operation internals + # + # Naming convention: + # _range_{op}: takes 2 interval bounds and apply op + # _range_{op}_uniq: takes 1 interval bounds and apply op + # _interval_{op}: apply op on an ModularIntervals + # _integer_{op}: apply op on itself with possible arguments + + def _range_add(self, x_min, x_max, y_min, y_max): + """Bounds interval for x + y, with + - x, y of size 'self.size' + - @x_min <= x <= @x_max + - @y_min <= y <= @y_max + - operations are considered unsigned + From Hacker's Delight: Chapter 4 + """ + max_bound = self.mask + if (x_min + y_min <= max_bound and + x_max + y_max >= max_bound + 1): + # HD returns 0, max_bound; but this is because it cannot handle multiple + # interval. + # x_max + y_max can only overflow once, so returns + # [result_min, overflow] U [0, overflow_rest] + return interval([(x_min + y_min, max_bound), + (0, (x_max + y_max) & max_bound)]) + else: + return interval([((x_min + y_min) & max_bound, + (x_max + y_max) & max_bound)]) + + _interval_add = _range2interval(_range_add) + + def _range_minus_uniq(self, x_min, x_max): + """Bounds interval for -x, with + - x of size self.size + - @x_min <= x <= @x_max + - operations are considered unsigned + From Hacker's Delight: Chapter 4 + """ + max_bound = self.mask + if (x_min == 0 and x_max != 0): + # HD returns 0, max_bound; see _range_add + return interval([(0, 0), ((- x_max) & max_bound, max_bound)]) + else: + return interval([((- x_max) & max_bound, (- x_min) & max_bound)]) + + _interval_minus = _range2integer(_range_minus_uniq) + + def _range_or_min(self, x_min, x_max, y_min, y_max): + """Interval min for x | y, with + - x, y of size self.size + - @x_min <= x <= @x_max + - @y_min <= y <= @y_max + - operations are considered unsigned + From Hacker's Delight: Chapter 4 + """ + max_bit = 1 << (self.size - 1) + while max_bit: + if ~x_min & y_min & max_bit: + temp = (x_min | max_bit) & - max_bit + if temp <= x_max: + x_min = temp + break + elif x_min & ~y_min & max_bit: + temp = (y_min | max_bit) & - max_bit + if temp <= y_max: + y_min = temp + break + max_bit >>= 1 + return x_min | y_min + + def _range_or_max(self, x_min, x_max, y_min, y_max): + """Interval max for x | y, with + - x, y of size self.size + - @x_min <= x <= @x_max + - @y_min <= y <= @y_max + - operations are considered unsigned + From Hacker's Delight: Chapter 4 + """ + max_bit = 1 << (self.size - 1) + while max_bit: + if x_max & y_max & max_bit: + temp = (x_max - max_bit) | (max_bit - 1) + if temp >= x_min: + x_max = temp + break + temp = (y_max - max_bit) | (max_bit - 1) + if temp >= y_min: + y_max = temp + break + max_bit >>= 1 + return x_max | y_max + + def _range_or(self, x_min, x_max, y_min, y_max): + """Interval bounds for x | y, with + - x, y of size self.size + - @x_min <= x <= @x_max + - @y_min <= y <= @y_max + - operations are considered unsigned + From Hacker's Delight: Chapter 4 + """ + return interval([(self._range_or_min(x_min, x_max, y_min, y_max), + self._range_or_max(x_min, x_max, y_min, y_max))]) + + _interval_or = _range2interval(_range_or) + + def _range_and_min(self, x_min, x_max, y_min, y_max): + 
"""Interval min for x & y, with + - x, y of size self.size + - @x_min <= x <= @x_max + - @y_min <= y <= @y_max + - operations are considered unsigned + From Hacker's Delight: Chapter 4 + """ + max_bit = (1 << (self.size - 1)) + while max_bit: + if ~x_min & ~y_min & max_bit: + temp = (x_min | max_bit) & - max_bit + if temp <= x_max: + x_min = temp + break + temp = (y_min | max_bit) & - max_bit + if temp <= y_max: + y_min = temp + break + max_bit >>= 1 + return x_min & y_min + + def _range_and_max(self, x_min, x_max, y_min, y_max): + """Interval max for x & y, with + - x, y of size self.size + - @x_min <= x <= @x_max + - @y_min <= y <= @y_max + - operations are considered unsigned + From Hacker's Delight: Chapter 4 + """ + max_bit = (1 << (self.size - 1)) + while max_bit: + if x_max & ~y_max & max_bit: + temp = (x_max & ~max_bit) | (max_bit - 1) + if temp >= x_min: + x_max = temp + break + elif ~x_max & y_max & max_bit: + temp = (y_max & ~max_bit) | (max_bit - 1) + if temp >= y_min: + y_max = temp + break + max_bit >>= 1 + return x_max & y_max + + def _range_and(self, x_min, x_max, y_min, y_max): + """Interval bounds for x & y, with + - x, y of size @size + - @x_min <= x <= @x_max + - @y_min <= y <= @y_max + - operations are considered unsigned + From Hacker's Delight: Chapter 4 + """ + return interval([(self._range_and_min(x_min, x_max, y_min, y_max), + self._range_and_max(x_min, x_max, y_min, y_max))]) + + _interval_and = _range2interval(_range_and) + + def _range_xor(self, x_min, x_max, y_min, y_max): + """Interval bounds for x ^ y, with + - x, y of size self.size + - @x_min <= x <= @x_max + - @y_min <= y <= @y_max + - operations are considered unsigned + From Hacker's Delight: Chapter 4 + """ + not_size = lambda x: x ^ self.mask + min_xor = self._range_and_min(x_min, x_max, not_size(y_max), not_size(y_min)) | self._range_and_min(not_size(x_max), not_size(x_min), y_min, y_max) + max_xor = self._range_or_max(0, + self._range_and_max(x_min, x_max, not_size(y_max), not_size(y_min)), + 0, + self._range_and_max(not_size(x_max), not_size(x_min), y_min, y_max)) + return interval([(min_xor, max_xor)]) + + _interval_xor = _range2interval(_range_xor) + + def _range_mul(self, x_min, x_max, y_min, y_max): + """Interval bounds for x * y, with + - x, y of size self.size + - @x_min <= x <= @x_max + - @y_min <= y <= @y_max + - operations are considered unsigned + This is a naive version, going to TOP on overflow""" + max_bound = self.mask + if y_max * x_max > max_bound: + return interval([(0, max_bound)]) + else: + return interval([(x_min * y_min, x_max * y_max)]) + + _interval_mul = _range2interval(_range_mul) + + def _range_mod_uniq(self, x_min, x_max, mod): + """Interval bounds for x % @mod, with + - x, @mod of size self.size + - @x_min <= x <= @x_max + - operations are considered unsigned + """ + if (x_max - x_min) >= mod: + return interval([(0, mod - 1)]) + x_max = x_max % mod + x_min = x_min % mod + if x_max < x_min: + return interval([(0, x_max), (x_min, mod - 1)]) + else: + return interval([(x_min, x_max)]) + + _integer_modulo = _range2integer(_range_mod_uniq) + + def _range_shift_uniq(self, x_min, x_max, shift, op): + """Bounds interval for x @op @shift with + - x of size self.size + - @x_min <= x <= @x_max + - operations are considered unsigned + - shift <= self.size + """ + assert shift <= self.size + # Shift operations are monotonic, and overflow results in 0 + max_bound = self.mask + + if op == "<<": + obtain_max = x_max << shift + if obtain_max > max_bound: + # Overflow at least on max, 
best-effort + # result '0' often happen, include it + return interval([(0, 0), ((1 << shift) - 1, max_bound)]) + else: + return interval([(x_min << shift, obtain_max)]) + elif op == ">>": + return interval([((x_min >> shift) & max_bound, + (x_max >> shift) & max_bound)]) + elif op == "a>>": + # The Miasm2 version (Expr or ModInt) could have been used, but + # introduce unnecessary dependencies for this module + # Python >> is the arithmetic one + ashr = lambda x, y: self._signed2unsigned(self._unsigned2signed(x) >> y) + end_min, end_max = ashr(x_min, shift), ashr(x_max, shift) + end_min, end_max = min(end_min, end_max), max(end_min, end_max) + return interval([(end_min, end_max)]) + else: + raise ValueError("%s is not a shifter" % op) + + def _interval_shift(self, operation, shifter): + """Apply the shifting operation @operation with a shifting + ModularIntervals @shifter on the current instance""" + # Work on a copy of shifter intervals + shifter = interval(shifter.intervals) + if (shifter.hull()[1] >= self.size): + shifter += interval([(self.size, self.size)]) + shifter &= interval([(0, self.size)]) + ret = interval() + for shift_range in shifter: + for shift in range(shift_range[0], shift_range[1] + 1): + for x_min, x_max in self.intervals: + ret += self._range_shift_uniq(x_min, x_max, shift, operation) + return self.__class__(self.size, ret) + + def _range_rotate_uniq(self, x_min, x_max, shift, op): + """Bounds interval for x @op @shift with + - x of size self.size + - @x_min <= x <= @x_max + - operations are considered unsigned + - shift <= self.size + """ + assert shift <= self.size + # Divide in sub-operations: a op b: a left b | a right (size - b) + if op == ">>>": + left, right = ">>", "<<" + elif op == "<<<": + left, right = "<<", ">>" + else: + raise ValueError("Not a rotator: %s" % op) + + left_intervals = self._range_shift_uniq(x_min, x_max, shift, left) + right_intervals = self._range_shift_uniq(x_min, x_max, + self.size - shift, right) + + result = self.__class__(self.size, left_intervals) | self.__class__(self.size, right_intervals) + return result.intervals + + def _interval_rotate(self, operation, shifter): + """Apply the rotate operation @operation with a shifting + ModularIntervals @shifter on the current instance""" + # Consider only rotation without repetition, and enumerate + # -> apply a '% size' on shifter + shifter %= self.size + ret = interval() + for shift_range in shifter: + for shift in range(shift_range[0], shift_range[1] + 1): + for x_min, x_max in self.intervals: + ret += self._range_rotate_uniq(x_min, x_max, shift, + operation) + + return self.__class__(self.size, ret) + + # Operation wrappers + + @_promote + def __add__(self, to_add): + """Add @to_add to the current intervals + @to_add: ModularInstances or integer + """ + return self._interval_add(to_add) + + @_promote + def __or__(self, to_or): + """Bitwise OR @to_or to the current intervals + @to_or: ModularInstances or integer + """ + return self._interval_or(to_or) + + @_promote + def __and__(self, to_and): + """Bitwise AND @to_and to the current intervals + @to_and: ModularInstances or integer + """ + return self._interval_and(to_and) + + @_promote + def __xor__(self, to_xor): + """Bitwise XOR @to_xor to the current intervals + @to_xor: ModularInstances or integer + """ + return self._interval_xor(to_xor) + + @_promote + def __mul__(self, to_mul): + """Multiply @to_mul to the current intervals + @to_mul: ModularInstances or integer + """ + return self._interval_mul(to_mul) + + @_promote + def 
__rshift__(self, to_shift):
+ """Logical shift right the current intervals of @to_shift
+ @to_shift: ModularIntervals or integer
+ """
+ return self._interval_shift('>>', to_shift)
+
+ @_promote
+ def __lshift__(self, to_shift):
+ """Logical shift left the current intervals of @to_shift
+ @to_shift: ModularIntervals or integer
+ """
+ return self._interval_shift('<<', to_shift)
+
+ @_promote
+ def arithmetic_shift_right(self, to_shift):
+ """Arithmetic shift right the current intervals of @to_shift
+ @to_shift: ModularIntervals or integer
+ """
+ return self._interval_shift('a>>', to_shift)
+
+ def __neg__(self):
+ """Negate the current intervals"""
+ return self._interval_minus()
+
+ def __mod__(self, modulo):
+ """Apply % @modulo on the current intervals
+ @modulo: integer
+ """
+
+ if not isinstance(modulo, int_types):
+ raise TypeError("Modulo with %s is not supported" % modulo.__class__)
+ return self._integer_modulo(modulo)
+
+ @_promote
+ def rotation_right(self, to_rotate):
+ """Right rotate the current intervals of @to_rotate
+ @to_rotate: ModularIntervals or integer
+ """
+ return self._interval_rotate('>>>', to_rotate)
+
+ @_promote
+ def rotation_left(self, to_rotate):
+ """Left rotate the current intervals of @to_rotate
+ @to_rotate: ModularIntervals or integer
+ """
+ return self._interval_rotate('<<<', to_rotate)
+
+ # Instance operations
+
+ @property
+ def mask(self):
+ """Return the mask corresponding to the instance size"""
+ return size2mask(self.size)
+
+ def __iter__(self):
+ return iter(self.intervals)
+
+ @property
+ def length(self):
+ return self.intervals.length
+
+ def __contains__(self, other):
+ if isinstance(other, ModularIntervals):
+ other = other.intervals
+ return other in self.intervals
+
+ def __str__(self):
+ return "%s (Size: %s)" % (self.intervals, self.size)
+
+ def size_update(self, new_size):
+ """Update the instance size to @new_size
+ The size of elements must be <= @new_size"""
+
+ # Increasing size is always safe
+ if new_size < self.size:
+ # Check that current values are indeed included in the new range
+ assert self.intervals.hull()[1] <= size2mask(new_size)
+
+ self.size = new_size
+
+ # For easy chaining
+ return self
+
+ # Mimic Python's set operations
+
+ @_promote
+ def union(self, to_union):
+ """Union set operation with @to_union
+ @to_union: ModularIntervals instance"""
+ return ModularIntervals(self.size, self.intervals + to_union.intervals)
+
+ @_promote
+ def update(self, to_union):
+ """Union set operation in-place with @to_union
+ @to_union: ModularIntervals instance"""
+ self.intervals += to_union.intervals
+
+ @_promote
+ def intersection(self, to_intersect):
+ """Intersection set operation with @to_intersect
+ @to_intersect: ModularIntervals instance"""
+ return ModularIntervals(self.size, self.intervals & to_intersect.intervals)
+
+ @_promote
+ def intersection_update(self, to_intersect):
+ """Intersection set operation in-place with @to_intersect
+ @to_intersect: ModularIntervals instance"""
+ self.intervals &= to_intersect.intervals
diff --git a/src/miasm/analysis/outofssa.py b/src/miasm/analysis/outofssa.py
new file mode 100644
index 00000000..2f2b185c
--- /dev/null
+++ b/src/miasm/analysis/outofssa.py
@@ -0,0 +1,415 @@
+from future.utils import viewitems, viewvalues
+
+from miasm.expression.expression import ExprId
+from miasm.ir.ir import IRBlock, AssignBlock
+from miasm.analysis.ssa import get_phi_sources_parent_block, \
+ irblock_has_phi
+
+
+class Varinfo(object):
+ """Store liveness information for a variable"""
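(Aside on the ModularIntervals class whose diff ends just above: the wrappers compose as ordinary Python operators, with plain integers promoted by @_promote. A minimal usage sketch, assuming the interval helper from miasm.core.interval and the ModularIntervals(size, intervals) constructor visible in this file; the printed forms are illustrative, not guaranteed output:)

    from miasm.core.interval import interval
    from miasm.analysis.modularintervals import ModularIntervals

    # 8-bit unsigned values known to lie in [0x10, 0x20]
    x = ModularIntervals(8, interval([(0x10, 0x20)]))

    # Arithmetic wraps modulo 2**8: [0x100, 0x110] becomes [0x00, 0x10]
    y = x + 0xF0
    # The span (0x20 - 0x10) is >= 6, so the modulo covers all of [0, 5]
    z = x % 6

    print(y)    # e.g. "[0x0 0x10] (Size: 8)"
    print(z)    # e.g. "[0x0 0x5] (Size: 8)"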
__slots__ = ["live_index", "loc_key", "index"] + + def __init__(self, live_index, loc_key, index): + self.live_index = live_index + self.loc_key = loc_key + self.index = index + + +class UnSSADiGraph(object): + """ + Implements unssa algorithm + Revisiting Out-of-SSA Translation for Correctness, Code Quality, and + Efficiency + """ + + def __init__(self, ssa, head, cfg_liveness): + self.cfg_liveness = cfg_liveness + self.ssa = ssa + self.head = head + + # Set of created variables + self.copy_vars = set() + # Virtual parallel copies + + # On loc_key's Phi node dst -> set((parent, src)) + self.phi_parent_sources = {} + # On loc_key's Phi node, loc_key -> set(Phi dsts) + self.phi_destinations = {} + # Phi's dst -> new var + self.phi_new_var = {} + # For a new_var representing dst: + # new_var -> set(parents of Phi's src in dst = Phi(src,...)) + self.new_var_to_srcs_parents = {} + # new_var -> set(variables to be coalesced with, named "merge_set") + self.merge_state = {} + + # Launch the algorithm in several steps + self.isolate_phi_nodes_block() + self.init_phis_merge_state() + self.order_ssa_var_dom() + self.aggressive_coalesce_block() + self.insert_parallel_copy() + self.replace_merge_sets() + self.remove_assign_eq() + + def insert_parallel_copy(self): + """ + Naive Out-of-SSA from CSSA (without coalescing for now) + - Replace Phi + - Create room for parallel copies in Phi's parents + """ + ircfg = self.ssa.graph + + for irblock in list(viewvalues(ircfg.blocks)): + if not irblock_has_phi(irblock): + continue + + # Replace Phi with Phi's dst = new_var + parallel_copies = {} + for dst in self.phi_destinations[irblock.loc_key]: + new_var = self.phi_new_var[dst] + parallel_copies[dst] = new_var + + assignblks = list(irblock) + assignblks[0] = AssignBlock(parallel_copies, irblock[0].instr) + new_irblock = IRBlock(irblock.loc_db, irblock.loc_key, assignblks) + ircfg.blocks[irblock.loc_key] = new_irblock + + # Insert new_var = src in each Phi's parent, at the end of the block + parent_to_parallel_copies = {} + parallel_copies = {} + for dst in irblock[0]: + new_var = self.phi_new_var[dst] + for parent, src in self.phi_parent_sources[dst]: + parent_to_parallel_copies.setdefault(parent, {})[new_var] = src + + for parent, parallel_copies in viewitems(parent_to_parallel_copies): + parent = ircfg.blocks[parent] + assignblks = list(parent) + assignblks.append(AssignBlock(parallel_copies, parent[-1].instr)) + new_irblock = IRBlock(parent.loc_db, parent.loc_key, assignblks) + ircfg.blocks[parent.loc_key] = new_irblock + + def create_copy_var(self, var): + """ + Generate a new var standing for @var + @var: variable to replace + """ + new_var = ExprId('var%d' % len(self.copy_vars), var.size) + self.copy_vars.add(new_var) + return new_var + + def isolate_phi_nodes_block(self): + """ + Init structures and virtually insert parallel copy before/after each phi + node + """ + ircfg = self.ssa.graph + for irblock in viewvalues(ircfg.blocks): + if not irblock_has_phi(irblock): + continue + for dst, sources in viewitems(irblock[0]): + assert sources.is_op('Phi') + new_var = self.create_copy_var(dst) + self.phi_new_var[dst] = new_var + + var_to_parents = get_phi_sources_parent_block( + self.ssa.graph, + irblock.loc_key, + sources.args + ) + + for src in sources.args: + parents = var_to_parents[src] + self.new_var_to_srcs_parents.setdefault(new_var, set()).update(parents) + for parent in parents: + self.phi_parent_sources.setdefault(dst, set()).add((parent, src)) + + self.phi_destinations[irblock.loc_key] = 
set(irblock[0]) + + def init_phis_merge_state(self): + """ + Generate trivial coalescing of phi variable and itself + """ + for phi_new_var in viewvalues(self.phi_new_var): + self.merge_state.setdefault(phi_new_var, set([phi_new_var])) + + def order_ssa_var_dom(self): + """Compute dominance order of each ssa variable""" + ircfg = self.ssa.graph + + # compute dominator tree + dominator_tree = ircfg.compute_dominator_tree(self.head) + + # variable -> Varinfo + self.var_to_varinfo = {} + # live_index can later be used to compare dominance of AssignBlocks + live_index = 0 + + # walk in DFS over the dominator tree + for loc_key in dominator_tree.walk_depth_first_forward(self.head): + irblock = ircfg.blocks.get(loc_key, None) + if irblock is None: + continue + + # Create live index for phi new vars + # They do not exist in the graph yet, so index is set to None + if irblock_has_phi(irblock): + for dst in irblock[0]: + if not dst.is_id(): + continue + new_var = self.phi_new_var[dst] + self.var_to_varinfo[new_var] = Varinfo(live_index, loc_key, None) + + live_index += 1 + + # Create live index for remaining assignments + for index, assignblk in enumerate(irblock): + used = False + for dst in assignblk: + if not dst.is_id(): + continue + if dst in self.ssa.immutable_ids: + # Will not be considered by the current algo, ignore it + # (for instance, IRDst) + continue + + assert dst not in self.var_to_varinfo + self.var_to_varinfo[dst] = Varinfo(live_index, loc_key, index) + used = True + if used: + live_index += 1 + + + def ssa_def_dominates(self, node_a, node_b): + """ + Return living index order of @node_a and @node_b + @node_a: Varinfo instance + @node_b: Varinfo instance + """ + ret = self.var_to_varinfo[node_a].live_index <= self.var_to_varinfo[node_b].live_index + return ret + + def merge_set_sort(self, merge_set): + """ + Return a sorted list of (live_index, var) from @merge_set in dominance + order + @merge_set: set of coalescing variables + """ + return sorted( + (self.var_to_varinfo[var].live_index, var) + for var in merge_set + ) + + def ssa_def_is_live_at(self, node_a, node_b, parent): + """ + Return True if @node_a is live during @node_b definition + If @parent is None, this is a liveness test for a post phi variable; + Else, it is a liveness test for a variable source of the phi node + + @node_a: Varinfo instance + @node_b: Varinfo instance + @parent: Optional parent location of the phi source + """ + loc_key_b, index_b = self.var_to_varinfo[node_b].loc_key, self.var_to_varinfo[node_b].index + if parent and index_b is None: + index_b = 0 + if node_a not in self.new_var_to_srcs_parents: + # node_a is not a new var (it is a "classic" var) + # -> use a basic liveness test + liveness_b = self.cfg_liveness.blocks[loc_key_b].infos[index_b] + return node_a in liveness_b.var_out + + for def_loc_key in self.new_var_to_srcs_parents[node_a]: + # Consider node_a as defined at the end of its parents blocks + # and compute liveness check accordingly + + if def_loc_key == parent: + # Same path as node_a definition, so SSA ensure b cannot be live + # on this path (otherwise, a Phi would already happen earlier) + continue + liveness_end_block = self.cfg_liveness.blocks[def_loc_key].infos[-1] + if node_b in liveness_end_block.var_out: + return True + return False + + def merge_nodes_interfere(self, node_a, node_b, parent): + """ + Return True if @node_a and @node_b interfere + @node_a: variable + @node_b: variable + @parent: Optional parent location of the phi source for liveness tests + + Interference 
check is: is x live at y definition (or reverse)
+ TODO: add Value-based interference improvement
+ """
+ if self.var_to_varinfo[node_a].live_index == self.var_to_varinfo[node_b].live_index:
+ # Defined in the same AssignBlock -> interfere
+ return True
+
+ if self.var_to_varinfo[node_a].live_index < self.var_to_varinfo[node_b].live_index:
+ return self.ssa_def_is_live_at(node_a, node_b, parent)
+ return self.ssa_def_is_live_at(node_b, node_a, parent)
+
+ def merge_sets_interfere(self, merge_a, merge_b, parent):
+ """
+ Return True if any variable in @merge_a interferes with a variable in
+ @merge_b
+
+ Implementation of "Algorithm 2: Check intersection in a set of variables"
+
+ @merge_a: a dom ordered list of equivalent variables
+ @merge_b: a dom ordered list of equivalent variables
+ @parent: Optional parent location of the phi source for liveness tests
+ """
+ if merge_a == merge_b:
+ # No need to consider interference if equal
+ return False
+
+ merge_a_list = self.merge_set_sort(merge_a)
+ merge_b_list = self.merge_set_sort(merge_b)
+ dom = []
+ while merge_a_list or merge_b_list:
+ if not merge_a_list:
+ _, current = merge_b_list.pop(0)
+ elif not merge_b_list:
+ _, current = merge_a_list.pop(0)
+ else:
+ # compare live_indexes (standing for dominance)
+ if merge_a_list[-1] < merge_b_list[-1]:
+ _, current = merge_a_list.pop(0)
+ else:
+ _, current = merge_b_list.pop(0)
+ while dom and not self.ssa_def_dominates(dom[-1], current):
+ dom.pop()
+
+ # Don't test node in same merge_set
+ if (
+ # Is stack not empty?
+ dom and
+ # Trivial non-interference if dom.top() and current come
+ # from the same merge set
+ not (dom[-1] in merge_a and current in merge_a) and
+ not (dom[-1] in merge_b and current in merge_b) and
+ # Actually test for interference
+ self.merge_nodes_interfere(current, dom[-1], parent)
+ ):
+ return True
+ dom.append(current)
+ return False
+
+ def aggressive_coalesce_parallel_copy(self, parallel_copies, parent):
+ """
+ Try to coalesce each dst/src couple of variables from
+ @parallel_copies
+
+ @parallel_copies: a dictionary representing dst/src parallel
+ assignments.
+ @parent: Optional parent location of the phi source for liveness tests
+ """
+ for dst, src in viewitems(parallel_copies):
+ dst_merge = self.merge_state.setdefault(dst, set([dst]))
+ src_merge = self.merge_state.setdefault(src, set([src]))
+ if not self.merge_sets_interfere(dst_merge, src_merge, parent):
+ dst_merge.update(src_merge)
+ for node in dst_merge:
+ self.merge_state[node] = dst_merge
+
+ def aggressive_coalesce_block(self):
+ """Try to coalesce phi vars with their pre/post variables"""
+
+ ircfg = self.ssa.graph
+
+ # Run coalesce on the post phi parallel copy
+ for irblock in viewvalues(ircfg.blocks):
+ if not irblock_has_phi(irblock):
+ continue
+ parallel_copies = {}
+ for dst in self.phi_destinations[irblock.loc_key]:
+ parallel_copies[dst] = self.phi_new_var[dst]
+ self.aggressive_coalesce_parallel_copy(parallel_copies, None)
+
+ # Run coalesce on the pre phi parallel copy
+
+ # Stand for the virtual parallel copies at the end of Phi's block
+ # parents
+ parent_to_parallel_copies = {}
+ for dst in irblock[0]:
+ new_var = self.phi_new_var[dst]
+ for parent, src in self.phi_parent_sources[dst]:
+ parent_to_parallel_copies.setdefault(parent, {})[new_var] = src
+
+ for parent, parallel_copies in viewitems(parent_to_parallel_copies):
+ self.aggressive_coalesce_parallel_copy(parallel_copies, parent)
+
+ def get_best_merge_set_name(self, merge_set):
+ """
+ For a given @merge_set, prefer an original SSA variable instead of a
+ created copy. Otherwise, take an arbitrary name.
+ @merge_set: set of equivalent expressions
+ """
+ if not merge_set:
+ raise RuntimeError("Merge set should not be empty")
+ for var in merge_set:
+ if var not in self.copy_vars:
+ return var
+ # Fallback: arbitrary name
+ return var
+
+
+ def replace_merge_sets(self):
+ """
+ In the graph, replace all variables from merge state by their
+ representative variable
+ """
+ replace = {}
+ merge_sets = set()
+
+ # Elect representative for merge sets
+ merge_set_to_name = {}
+ for merge_set in viewvalues(self.merge_state):
+ frozen_merge_set = frozenset(merge_set)
+ merge_sets.add(frozen_merge_set)
+ var_name = self.get_best_merge_set_name(merge_set)
+ merge_set_to_name[frozen_merge_set] = var_name
+
+ # Generate replacement of variables by their representative
+ for merge_set in merge_sets:
+ var_name = merge_set_to_name[merge_set]
+ merge_set = list(merge_set)
+ for var in merge_set:
+ replace[var] = var_name
+
+ self.ssa.graph.simplify(lambda x: x.replace_expr(replace))
+
+ def remove_phi(self):
+ """
+ Remove phi operators from the ssa graph (self.ssa.graph)
+ """
+
+ for irblock in list(viewvalues(self.ssa.graph.blocks)):
+ assignblks = list(irblock)
+ out = {}
+ for dst, src in viewitems(assignblks[0]):
+ if src.is_op('Phi'):
+ assert set([dst]) == set(src.args)
+ continue
+ out[dst] = src
+ assignblks[0] = AssignBlock(out, assignblks[0].instr)
+ self.ssa.graph.blocks[irblock.loc_key] = IRBlock(irblock.loc_db, irblock.loc_key, assignblks)
+
+ def remove_assign_eq(self):
+ """
+ Remove trivial expressions (a=a) in the current graph
+ """
+ for irblock in list(viewvalues(self.ssa.graph.blocks)):
+ assignblks = list(irblock)
+ for i, assignblk in enumerate(assignblks):
+ out = {}
+ for dst, src in viewitems(assignblk):
+ if dst == src:
+ continue
+ out[dst] = src
+ assignblks[i] = AssignBlock(out, assignblk.instr)
+ self.ssa.graph.blocks[irblock.loc_key] = IRBlock(irblock.loc_db, irblock.loc_key, assignblks)
diff --git a/src/miasm/analysis/sandbox.py b/src/miasm/analysis/sandbox.py
new file mode 100644
index 
00000000..e51fd45a --- /dev/null +++ b/src/miasm/analysis/sandbox.py @@ -0,0 +1,1033 @@ +from __future__ import print_function +from builtins import range + +import os +import logging +from argparse import ArgumentParser + +from future.utils import viewitems, viewvalues +from past.builtins import basestring + +from miasm.analysis.machine import Machine +from miasm.jitter.csts import PAGE_READ, PAGE_WRITE +from miasm.analysis import debugging +from miasm.jitter.jitload import log_func +from miasm.core.utils import force_bytes + + +class Sandbox(object): + + """ + Parent class for Sandbox abstraction + """ + + CALL_FINISH_ADDR = 0x13371acc + + @staticmethod + def code_sentinelle(jitter): + jitter.running = False + return False + + @classmethod + def _classes_(cls): + """ + Iterator on parent classes except Sanbox + """ + for base_cls in cls.__bases__: + # Avoid infinite loop + if base_cls == Sandbox: + continue + + yield base_cls + + classes = property(lambda x: x.__class__._classes_()) + + def __init__(self, loc_db, fname, options, custom_methods=None, **kwargs): + """ + Initialize a sandbox + @fname: str file name + @options: namespace instance of specific options + @custom_methods: { str => func } for custom API implementations + """ + + # Initialize + assert isinstance(fname, basestring) + self.fname = fname + self.options = options + self.loc_db = loc_db + if custom_methods is None: + custom_methods = {} + kwargs["loc_db"] = loc_db + for cls in self.classes: + if cls == Sandbox: + continue + if issubclass(cls, OS): + cls.__init__(self, custom_methods, **kwargs) + else: + cls.__init__(self, **kwargs) + + # Logging options + self.jitter.set_trace_log( + trace_instr=self.options.singlestep, + trace_regs=self.options.singlestep, + trace_new_blocks=self.options.dumpblocs + ) + + if not self.options.quiet_function_calls: + log_func.setLevel(logging.INFO) + + @classmethod + def parser(cls, *args, **kwargs): + """ + Return instance of instance parser with expecting options. + Extra parameters are passed to parser initialisation. + """ + + parser = ArgumentParser(*args, **kwargs) + parser.add_argument('-a', "--address", + help="Force entry point address", default=None) + parser.add_argument('-b', "--dumpblocs", action="store_true", + help="Log disasm blocks") + parser.add_argument('-z', "--singlestep", action="store_true", + help="Log single step") + parser.add_argument('-d', "--debugging", action="store_true", + help="Debug shell") + parser.add_argument('-g', "--gdbserver", type=int, + help="Listen on port @port") + parser.add_argument("-j", "--jitter", + help="Jitter engine. Possible values are: gcc (default), llvm, python", + default="gcc") + parser.add_argument( + '-q', "--quiet-function-calls", action="store_true", + help="Don't log function calls") + parser.add_argument('-i', "--dependencies", action="store_true", + help="Load PE and its dependencies") + + for base_cls in cls._classes_(): + base_cls.update_parser(parser) + return parser + + def run(self, addr=None): + """ + Launch emulation (gdbserver, debugging, basic JIT). 
+ @addr: (int) start address + """ + if addr is None and self.options.address is not None: + addr = int(self.options.address, 0) + + if any([self.options.debugging, self.options.gdbserver]): + dbg = debugging.Debugguer(self.jitter) + self.dbg = dbg + dbg.init_run(addr) + + if self.options.gdbserver: + port = self.options.gdbserver + print("Listen on port %d" % port) + gdb = self.machine.gdbserver(dbg, port) + self.gdb = gdb + gdb.run() + else: + cmd = debugging.DebugCmd(dbg) + self.cmd = cmd + cmd.cmdloop() + + else: + self.jitter.init_run(addr) + self.jitter.continue_run() + + def call(self, prepare_cb, addr, *args): + """ + Direct call of the function at @addr, with arguments @args prepare in + calling convention implemented by @prepare_cb + @prepare_cb: func(ret_addr, *args) + @addr: address of the target function + @args: arguments + """ + self.jitter.init_run(addr) + self.jitter.add_breakpoint(self.CALL_FINISH_ADDR, self.code_sentinelle) + prepare_cb(self.CALL_FINISH_ADDR, *args) + self.jitter.continue_run() + + + +class OS(object): + + """ + Parent class for OS abstraction + """ + + def __init__(self, custom_methods, **kwargs): + pass + + @classmethod + def update_parser(cls, parser): + pass + + +class Arch(object): + + """ + Parent class for Arch abstraction + """ + + # Architecture name + _ARCH_ = None + + def __init__(self, loc_db, **kwargs): + self.machine = Machine(self._ARCH_) + self.jitter = self.machine.jitter(loc_db, self.options.jitter) + + @classmethod + def update_parser(cls, parser): + pass + + +class OS_Win(OS): + # DLL to import + ALL_IMP_DLL = [ + "ntdll.dll", "kernel32.dll", "user32.dll", + "ole32.dll", "urlmon.dll", + "ws2_32.dll", 'advapi32.dll', "psapi.dll", + ] + modules_path = "win_dll" + + def __init__(self, custom_methods, *args, **kwargs): + from miasm.jitter.loader.pe import vm_load_pe, vm_load_pe_libs,\ + preload_pe, libimp_pe, vm_load_pe_and_dependencies + from miasm.os_dep import win_api_x86_32, win_api_x86_32_seh + methods = dict((name, func) for name, func in viewitems(win_api_x86_32.__dict__)) + methods.update(custom_methods) + + super(OS_Win, self).__init__(methods, *args, **kwargs) + + # Import manager + libs = libimp_pe() + self.libs = libs + win_api_x86_32.winobjs.runtime_dll = libs + + self.name2module = {} + fname_basename = os.path.basename(self.fname).lower() + + # Load main pe + with open(self.fname, "rb") as fstream: + self.pe = vm_load_pe( + self.jitter.vm, + fstream.read(), + load_hdr=self.options.load_hdr, + name=self.fname, + winobjs=win_api_x86_32.winobjs, + **kwargs + ) + self.name2module[fname_basename] = self.pe + + # Load library + if self.options.loadbasedll: + + # Load libs in memory + self.name2module.update( + vm_load_pe_libs( + self.jitter.vm, + self.ALL_IMP_DLL, + libs, + self.modules_path, + winobjs=win_api_x86_32.winobjs, + **kwargs + ) + ) + + # Patch libs imports + for pe in viewvalues(self.name2module): + preload_pe(self.jitter.vm, pe, libs) + + if self.options.dependencies: + vm_load_pe_and_dependencies( + self.jitter.vm, + fname_basename, + self.name2module, + libs, + self.modules_path, + winobjs=win_api_x86_32.winobjs, + **kwargs + ) + + win_api_x86_32.winobjs.current_pe = self.pe + + # Fix pe imports + preload_pe(self.jitter.vm, self.pe, libs) + + # Library calls handler + self.jitter.add_lib_handler(libs, methods) + + # Manage SEH + if self.options.use_windows_structs: + win_api_x86_32_seh.main_pe_name = fname_basename + win_api_x86_32_seh.main_pe = self.pe + win_api_x86_32.winobjs.hcurmodule = self.pe.NThdr.ImageBase 
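(For orientation while reading this loader code: the Sandbox/Arch/OS mixins are assembled below into concrete classes such as Sandbox_Win_x86_32, which are normally driven from a small script. A minimal sketch mirroring the parser()/run() flow defined above; the sample file name is hypothetical:)

    from miasm.core.locationdb import LocationDB
    from miasm.analysis.sandbox import Sandbox_Win_x86_32

    # The parser aggregates options from Sandbox, Arch_x86_32 and OS_Win
    parser = Sandbox_Win_x86_32.parser(description="PE sandboxer")
    parser.add_argument("filename", help="PE file to emulate")
    options = parser.parse_args()

    loc_db = LocationDB()
    sb = Sandbox_Win_x86_32(loc_db, options.filename, options)

    # Runs from the PE entry point unless -a/--address overrides it
    sb.run()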
+ win_api_x86_32_seh.name2module = self.name2module + win_api_x86_32_seh.set_win_fs_0(self.jitter) + win_api_x86_32_seh.init_seh(self.jitter) + + self.entry_point = self.pe.rva2virt( + self.pe.Opthdr.AddressOfEntryPoint) + + @classmethod + def update_parser(cls, parser): + parser.add_argument('-o', "--load-hdr", action="store_true", + help="Load pe hdr") + parser.add_argument('-y', "--use-windows-structs", action="store_true", + help="Create and use windows structures (peb, ldr, seh, ...)") + parser.add_argument('-l', "--loadbasedll", action="store_true", + help="Load base dll (path './win_dll')") + parser.add_argument('-r', "--parse-resources", + action="store_true", help="Load resources") + + +class OS_Linux(OS): + + PROGRAM_PATH = "./program" + + def __init__(self, custom_methods, *args, **kwargs): + from miasm.jitter.loader.elf import vm_load_elf, preload_elf, libimp_elf + from miasm.os_dep import linux_stdlib + methods = linux_stdlib.__dict__ + methods.update(custom_methods) + + super(OS_Linux, self).__init__(methods, *args, **kwargs) + + # Import manager + self.libs = libimp_elf() + + with open(self.fname, "rb") as fstream: + self.elf = vm_load_elf( + self.jitter.vm, + fstream.read(), + name=self.fname, + **kwargs + ) + preload_elf(self.jitter.vm, self.elf, self.libs) + + self.entry_point = self.elf.Ehdr.entry + + # Library calls handler + self.jitter.add_lib_handler(self.libs, methods) + linux_stdlib.ABORT_ADDR = self.CALL_FINISH_ADDR + + # Arguments + self.argv = [self.PROGRAM_PATH] + if self.options.command_line: + self.argv += self.options.command_line + self.envp = self.options.environment_vars + + @classmethod + def update_parser(cls, parser): + parser.add_argument('-c', '--command-line', + action="append", + default=[], + help="Command line arguments") + parser.add_argument('--environment-vars', + action="append", + default=[], + help="Environment variables arguments") + parser.add_argument('--mimic-env', + action="store_true", + help="Mimic the environment of a starting executable") + +class OS_Linux_str(OS): + + PROGRAM_PATH = "./program" + + def __init__(self, custom_methods, *args, **kwargs): + from miasm.jitter.loader.elf import libimp_elf + from miasm.os_dep import linux_stdlib + methods = linux_stdlib.__dict__ + methods.update(custom_methods) + + super(OS_Linux_str, self).__init__(methods, *args, **kwargs) + + # Import manager + libs = libimp_elf() + self.libs = libs + + data = open(self.fname, "rb").read() + self.options.load_base_addr = int(self.options.load_base_addr, 0) + self.jitter.vm.add_memory_page( + self.options.load_base_addr, PAGE_READ | PAGE_WRITE, data, + "Initial Str" + ) + + # Library calls handler + self.jitter.add_lib_handler(libs, methods) + linux_stdlib.ABORT_ADDR = self.CALL_FINISH_ADDR + + # Arguments + self.argv = [self.PROGRAM_PATH] + if self.options.command_line: + self.argv += self.options.command_line + self.envp = self.options.environment_vars + + @classmethod + def update_parser(cls, parser): + parser.add_argument('-c', '--command-line', + action="append", + default=[], + help="Command line arguments") + parser.add_argument('--environment-vars', + action="append", + default=[], + help="Environment variables arguments") + parser.add_argument('--mimic-env', + action="store_true", + help="Mimic the environment of a starting executable") + parser.add_argument("load_base_addr", help="load base address") + + +class Arch_x86(Arch): + _ARCH_ = None # Arch name + STACK_SIZE = 0x10000 + STACK_BASE = 0x130000 + + def __init__(self, loc_db, **kwargs): + 
super(Arch_x86, self).__init__(loc_db, **kwargs) + + if self.options.usesegm: + self.jitter.lifter.do_stk_segm = True + self.jitter.lifter.do_ds_segm = True + self.jitter.lifter.do_str_segm = True + self.jitter.lifter.do_all_segm = True + + # Init stack + self.jitter.stack_size = self.STACK_SIZE + self.jitter.stack_base = self.STACK_BASE + self.jitter.init_stack() + + @classmethod + def update_parser(cls, parser): + parser.add_argument('-s', "--usesegm", action="store_true", + help="Use segments") + + +class Arch_x86_32(Arch_x86): + _ARCH_ = "x86_32" + + +class Arch_x86_64(Arch_x86): + _ARCH_ = "x86_64" + + +class Arch_arml(Arch): + _ARCH_ = "arml" + STACK_SIZE = 0x100000 + STACK_BASE = 0x100000 + + def __init__(self, loc_db, **kwargs): + super(Arch_arml, self).__init__(loc_db, **kwargs) + + # Init stack + self.jitter.stack_size = self.STACK_SIZE + self.jitter.stack_base = self.STACK_BASE + self.jitter.init_stack() + + +class Arch_armb(Arch): + _ARCH_ = "armb" + STACK_SIZE = 0x100000 + STACK_BASE = 0x100000 + + def __init__(self, loc_db, **kwargs): + super(Arch_armb, self).__init__(loc_db, **kwargs) + + # Init stack + self.jitter.stack_size = self.STACK_SIZE + self.jitter.stack_base = self.STACK_BASE + self.jitter.init_stack() + + +class Arch_armtl(Arch): + _ARCH_ = "armtl" + STACK_SIZE = 0x100000 + STACK_BASE = 0x100000 + + def __init__(self, loc_db, **kwargs): + super(Arch_armtl, self).__init__(loc_db, **kwargs) + + # Init stack + self.jitter.stack_size = self.STACK_SIZE + self.jitter.stack_base = self.STACK_BASE + self.jitter.init_stack() + + +class Arch_mips32b(Arch): + _ARCH_ = "mips32b" + STACK_SIZE = 0x100000 + STACK_BASE = 0x100000 + + def __init__(self, loc_db, **kwargs): + super(Arch_mips32b, self).__init__(loc_db, **kwargs) + + # Init stack + self.jitter.stack_size = self.STACK_SIZE + self.jitter.stack_base = self.STACK_BASE + self.jitter.init_stack() + + +class Arch_aarch64l(Arch): + _ARCH_ = "aarch64l" + STACK_SIZE = 0x100000 + STACK_BASE = 0x100000 + + def __init__(self, loc_db, **kwargs): + super(Arch_aarch64l, self).__init__(loc_db, **kwargs) + + # Init stack + self.jitter.stack_size = self.STACK_SIZE + self.jitter.stack_base = self.STACK_BASE + self.jitter.init_stack() + + +class Arch_aarch64b(Arch): + _ARCH_ = "aarch64b" + STACK_SIZE = 0x100000 + STACK_BASE = 0x100000 + + def __init__(self, loc_db, **kwargs): + super(Arch_aarch64b, self).__init__(loc_db, **kwargs) + + # Init stack + self.jitter.stack_size = self.STACK_SIZE + self.jitter.stack_base = self.STACK_BASE + self.jitter.init_stack() + +class Arch_ppc(Arch): + _ARCH_ = None + +class Arch_ppc32(Arch): + _ARCH_ = None + +class Arch_ppc32b(Arch_ppc32): + _ARCH_ = "ppc32b" + +class Sandbox_Win_x86_32(Sandbox, Arch_x86_32, OS_Win): + + def __init__(self, loc_db, *args, **kwargs): + Sandbox.__init__(self, loc_db, *args, **kwargs) + + # Pre-stack some arguments + self.jitter.push_uint32_t(2) + self.jitter.push_uint32_t(1) + self.jitter.push_uint32_t(0) + self.jitter.push_uint32_t(self.CALL_FINISH_ADDR) + + # Set the runtime guard + self.jitter.add_breakpoint(self.CALL_FINISH_ADDR, self.__class__.code_sentinelle) + + def run(self, addr=None): + """ + If addr is not set, use entrypoint + """ + if addr is None and self.options.address is None: + addr = self.entry_point + super(Sandbox_Win_x86_32, self).run(addr) + + def call(self, addr, *args, **kwargs): + """ + Direct call of the function at @addr, with arguments @args + @addr: address of the target function + @args: arguments + """ + prepare_cb = kwargs.pop('prepare_cb', 
self.jitter.func_prepare_stdcall) + super(self.__class__, self).call(prepare_cb, addr, *args) + + +class Sandbox_Win_x86_64(Sandbox, Arch_x86_64, OS_Win): + + def __init__(self, loc_db, *args, **kwargs): + Sandbox.__init__(self, loc_db, *args, **kwargs) + + # reserve stack for local reg + for _ in range(0x4): + self.jitter.push_uint64_t(0) + + # Pre-stack return address + self.jitter.push_uint64_t(self.CALL_FINISH_ADDR) + + # Set the runtime guard + self.jitter.add_breakpoint( + self.CALL_FINISH_ADDR, + self.__class__.code_sentinelle + ) + + def run(self, addr=None): + """ + If addr is not set, use entrypoint + """ + if addr is None and self.options.address is None: + addr = self.entry_point + super(Sandbox_Win_x86_64, self).run(addr) + + def call(self, addr, *args, **kwargs): + """ + Direct call of the function at @addr, with arguments @args + @addr: address of the target function + @args: arguments + """ + prepare_cb = kwargs.pop('prepare_cb', self.jitter.func_prepare_stdcall) + super(self.__class__, self).call(prepare_cb, addr, *args) + + +class Sandbox_Linux_x86_32(Sandbox, Arch_x86_32, OS_Linux): + + def __init__(self, loc_db, *args, **kwargs): + Sandbox.__init__(self, loc_db, *args, **kwargs) + + # Pre-stack some arguments + if self.options.mimic_env: + env_ptrs = [] + for env in self.envp: + env = force_bytes(env) + env += b"\x00" + self.jitter.cpu.ESP -= len(env) + ptr = self.jitter.cpu.ESP + self.jitter.vm.set_mem(ptr, env) + env_ptrs.append(ptr) + argv_ptrs = [] + for arg in self.argv: + arg = force_bytes(arg) + arg += b"\x00" + self.jitter.cpu.ESP -= len(arg) + ptr = self.jitter.cpu.ESP + self.jitter.vm.set_mem(ptr, arg) + argv_ptrs.append(ptr) + + self.jitter.push_uint32_t(self.CALL_FINISH_ADDR) + self.jitter.push_uint32_t(0) + for ptr in reversed(env_ptrs): + self.jitter.push_uint32_t(ptr) + self.jitter.push_uint32_t(0) + for ptr in reversed(argv_ptrs): + self.jitter.push_uint32_t(ptr) + self.jitter.push_uint32_t(len(self.argv)) + else: + self.jitter.push_uint32_t(self.CALL_FINISH_ADDR) + + # Set the runtime guard + self.jitter.add_breakpoint( + self.CALL_FINISH_ADDR, + self.__class__.code_sentinelle + ) + + def run(self, addr=None): + """ + If addr is not set, use entrypoint + """ + if addr is None and self.options.address is None: + addr = self.entry_point + super(Sandbox_Linux_x86_32, self).run(addr) + + def call(self, addr, *args, **kwargs): + """ + Direct call of the function at @addr, with arguments @args + @addr: address of the target function + @args: arguments + """ + prepare_cb = kwargs.pop('prepare_cb', self.jitter.func_prepare_systemv) + super(self.__class__, self).call(prepare_cb, addr, *args) + + + +class Sandbox_Linux_x86_64(Sandbox, Arch_x86_64, OS_Linux): + + def __init__(self, loc_db, *args, **kwargs): + Sandbox.__init__(self, loc_db, *args, **kwargs) + + # Pre-stack some arguments + if self.options.mimic_env: + env_ptrs = [] + for env in self.envp: + env = force_bytes(env) + env += b"\x00" + self.jitter.cpu.RSP -= len(env) + ptr = self.jitter.cpu.RSP + self.jitter.vm.set_mem(ptr, env) + env_ptrs.append(ptr) + argv_ptrs = [] + for arg in self.argv: + arg = force_bytes(arg) + arg += b"\x00" + self.jitter.cpu.RSP -= len(arg) + ptr = self.jitter.cpu.RSP + self.jitter.vm.set_mem(ptr, arg) + argv_ptrs.append(ptr) + + self.jitter.push_uint64_t(self.CALL_FINISH_ADDR) + self.jitter.push_uint64_t(0) + for ptr in reversed(env_ptrs): + self.jitter.push_uint64_t(ptr) + self.jitter.push_uint64_t(0) + for ptr in reversed(argv_ptrs): + self.jitter.push_uint64_t(ptr) + 
self.jitter.push_uint64_t(len(self.argv)) + else: + self.jitter.push_uint64_t(self.CALL_FINISH_ADDR) + + # Set the runtime guard + self.jitter.add_breakpoint( + self.CALL_FINISH_ADDR, + self.__class__.code_sentinelle + ) + + def run(self, addr=None): + """ + If addr is not set, use entrypoint + """ + if addr is None and self.options.address is None: + addr = self.entry_point + super(Sandbox_Linux_x86_64, self).run(addr) + + def call(self, addr, *args, **kwargs): + """ + Direct call of the function at @addr, with arguments @args + @addr: address of the target function + @args: arguments + """ + prepare_cb = kwargs.pop('prepare_cb', self.jitter.func_prepare_systemv) + super(self.__class__, self).call(prepare_cb, addr, *args) + + +class Sandbox_Linux_arml(Sandbox, Arch_arml, OS_Linux): + + def __init__(self, loc_db, *args, **kwargs): + Sandbox.__init__(self, loc_db, *args, **kwargs) + + # Pre-stack some arguments + if self.options.mimic_env: + env_ptrs = [] + for env in self.envp: + env = force_bytes(env) + env += b"\x00" + self.jitter.cpu.SP -= len(env) + ptr = self.jitter.cpu.SP + self.jitter.vm.set_mem(ptr, env) + env_ptrs.append(ptr) + argv_ptrs = [] + for arg in self.argv: + arg = force_bytes(arg) + arg += b"\x00" + self.jitter.cpu.SP -= len(arg) + ptr = self.jitter.cpu.SP + self.jitter.vm.set_mem(ptr, arg) + argv_ptrs.append(ptr) + + # Round SP to 4 + self.jitter.cpu.SP = self.jitter.cpu.SP & ~ 3 + + self.jitter.push_uint32_t(0) + for ptr in reversed(env_ptrs): + self.jitter.push_uint32_t(ptr) + self.jitter.push_uint32_t(0) + for ptr in reversed(argv_ptrs): + self.jitter.push_uint32_t(ptr) + self.jitter.push_uint32_t(len(self.argv)) + + self.jitter.cpu.LR = self.CALL_FINISH_ADDR + + # Set the runtime guard + self.jitter.add_breakpoint( + self.CALL_FINISH_ADDR, + self.__class__.code_sentinelle + ) + + def run(self, addr=None): + if addr is None and self.options.address is None: + addr = self.entry_point + super(Sandbox_Linux_arml, self).run(addr) + + def call(self, addr, *args, **kwargs): + """ + Direct call of the function at @addr, with arguments @args + @addr: address of the target function + @args: arguments + """ + prepare_cb = kwargs.pop('prepare_cb', self.jitter.func_prepare_systemv) + super(self.__class__, self).call(prepare_cb, addr, *args) + + +class Sandbox_Linux_armtl(Sandbox, Arch_armtl, OS_Linux): + + def __init__(self, loc_db, *args, **kwargs): + Sandbox.__init__(self, loc_db, *args, **kwargs) + + # Pre-stack some arguments + if self.options.mimic_env: + env_ptrs = [] + for env in self.envp: + env = force_bytes(env) + env += b"\x00" + self.jitter.cpu.SP -= len(env) + ptr = self.jitter.cpu.SP + self.jitter.vm.set_mem(ptr, env) + env_ptrs.append(ptr) + argv_ptrs = [] + for arg in self.argv: + arg = force_bytes(arg) + arg += b"\x00" + self.jitter.cpu.SP -= len(arg) + ptr = self.jitter.cpu.SP + self.jitter.vm.set_mem(ptr, arg) + argv_ptrs.append(ptr) + + # Round SP to 4 + self.jitter.cpu.SP = self.jitter.cpu.SP & ~ 3 + + self.jitter.push_uint32_t(0) + for ptr in reversed(env_ptrs): + self.jitter.push_uint32_t(ptr) + self.jitter.push_uint32_t(0) + for ptr in reversed(argv_ptrs): + self.jitter.push_uint32_t(ptr) + self.jitter.push_uint32_t(len(self.argv)) + + self.jitter.cpu.LR = self.CALL_FINISH_ADDR + + # Set the runtime guard + self.jitter.add_breakpoint( + self.CALL_FINISH_ADDR, + self.__class__.code_sentinelle + ) + + def run(self, addr=None): + if addr is None and self.options.address is None: + addr = self.entry_point + super(Sandbox_Linux_armtl, self).run(addr) + + def 
call(self, addr, *args, **kwargs): + """ + Direct call of the function at @addr, with arguments @args + @addr: address of the target function + @args: arguments + """ + prepare_cb = kwargs.pop('prepare_cb', self.jitter.func_prepare_systemv) + super(self.__class__, self).call(prepare_cb, addr, *args) + + + +class Sandbox_Linux_mips32b(Sandbox, Arch_mips32b, OS_Linux): + + def __init__(self, loc_db, *args, **kwargs): + Sandbox.__init__(self, loc_db, *args, **kwargs) + + # Pre-stack some arguments + if self.options.mimic_env: + env_ptrs = [] + for env in self.envp: + env = force_bytes(env) + env += b"\x00" + self.jitter.cpu.SP -= len(env) + ptr = self.jitter.cpu.SP + self.jitter.vm.set_mem(ptr, env) + env_ptrs.append(ptr) + argv_ptrs = [] + for arg in self.argv: + arg = force_bytes(arg) + arg += b"\x00" + self.jitter.cpu.SP -= len(arg) + ptr = self.jitter.cpu.SP + self.jitter.vm.set_mem(ptr, arg) + argv_ptrs.append(ptr) + + self.jitter.push_uint32_t(0) + for ptr in reversed(env_ptrs): + self.jitter.push_uint32_t(ptr) + self.jitter.push_uint32_t(0) + for ptr in reversed(argv_ptrs): + self.jitter.push_uint32_t(ptr) + self.jitter.push_uint32_t(len(self.argv)) + + self.jitter.cpu.RA = 0x1337beef + + # Set the runtime guard + self.jitter.add_breakpoint( + 0x1337beef, + self.__class__.code_sentinelle + ) + + def run(self, addr=None): + if addr is None and self.options.address is None: + addr = self.entry_point + super(Sandbox_Linux_mips32b, self).run(addr) + + def call(self, addr, *args, **kwargs): + """ + Direct call of the function at @addr, with arguments @args + @addr: address of the target function + @args: arguments + """ + prepare_cb = kwargs.pop('prepare_cb', self.jitter.func_prepare_systemv) + super(self.__class__, self).call(prepare_cb, addr, *args) + + +class Sandbox_Linux_armb_str(Sandbox, Arch_armb, OS_Linux_str): + + def __init__(self, loc_db, *args, **kwargs): + Sandbox.__init__(self, loc_db, *args, **kwargs) + + self.jitter.cpu.LR = self.CALL_FINISH_ADDR + + # Set the runtime guard + self.jitter.add_breakpoint(self.CALL_FINISH_ADDR, self.__class__.code_sentinelle) + + def run(self, addr=None): + if addr is None and self.options.address is not None: + addr = int(self.options.address, 0) + super(Sandbox_Linux_armb_str, self).run(addr) + + +class Sandbox_Linux_arml_str(Sandbox, Arch_arml, OS_Linux_str): + + def __init__(self, loc_db, *args, **kwargs): + Sandbox.__init__(self, loc_db, *args, **kwargs) + + self.jitter.cpu.LR = self.CALL_FINISH_ADDR + + # Set the runtime guard + self.jitter.add_breakpoint(self.CALL_FINISH_ADDR, self.__class__.code_sentinelle) + + def run(self, addr=None): + if addr is None and self.options.address is not None: + addr = int(self.options.address, 0) + super(Sandbox_Linux_arml_str, self).run(addr) + + +class Sandbox_Linux_aarch64l(Sandbox, Arch_aarch64l, OS_Linux): + + def __init__(self, loc_db, *args, **kwargs): + Sandbox.__init__(self, loc_db, *args, **kwargs) + + # Pre-stack some arguments + if self.options.mimic_env: + env_ptrs = [] + for env in self.envp: + env = force_bytes(env) + env += b"\x00" + self.jitter.cpu.SP -= len(env) + ptr = self.jitter.cpu.SP + self.jitter.vm.set_mem(ptr, env) + env_ptrs.append(ptr) + argv_ptrs = [] + for arg in self.argv: + arg = force_bytes(arg) + arg += b"\x00" + self.jitter.cpu.SP -= len(arg) + ptr = self.jitter.cpu.SP + self.jitter.vm.set_mem(ptr, arg) + argv_ptrs.append(ptr) + + self.jitter.push_uint64_t(0) + for ptr in reversed(env_ptrs): + self.jitter.push_uint64_t(ptr) + self.jitter.push_uint64_t(0) + for ptr in 
reversed(argv_ptrs): + self.jitter.push_uint64_t(ptr) + self.jitter.push_uint64_t(len(self.argv)) + + self.jitter.cpu.LR = self.CALL_FINISH_ADDR + + # Set the runtime guard + self.jitter.add_breakpoint( + self.CALL_FINISH_ADDR, + self.__class__.code_sentinelle + ) + + def run(self, addr=None): + if addr is None and self.options.address is None: + addr = self.entry_point + super(Sandbox_Linux_aarch64l, self).run(addr) + + def call(self, addr, *args, **kwargs): + """ + Direct call of the function at @addr, with arguments @args + @addr: address of the target function + @args: arguments + """ + prepare_cb = kwargs.pop('prepare_cb', self.jitter.func_prepare_systemv) + super(self.__class__, self).call(prepare_cb, addr, *args) + +class Sandbox_Linux_ppc32b(Sandbox, Arch_ppc32b, OS_Linux): + + STACK_SIZE = 0x10000 + STACK_BASE = 0xbfce0000 + + # The glue between the kernel and the ELF ABI on Linux/PowerPC is + # implemented in glibc/sysdeps/powerpc/powerpc32/dl-start.S, so we + # have to play the role of ld.so here. + def __init__(self, loc_db, *args, **kwargs): + super(Sandbox_Linux_ppc32b, self).__init__(loc_db, *args, **kwargs) + + # Init stack + self.jitter.stack_size = self.STACK_SIZE + self.jitter.stack_base = self.STACK_BASE + self.jitter.init_stack() + self.jitter.cpu.R1 -= 8 + + # Pre-stack some arguments + if self.options.mimic_env: + env_ptrs = [] + for env in self.envp: + env = force_bytes(env) + env += b"\x00" + self.jitter.cpu.R1 -= len(env) + ptr = self.jitter.cpu.R1 + self.jitter.vm.set_mem(ptr, env) + env_ptrs.append(ptr) + argv_ptrs = [] + for arg in self.argv: + arg = force_bytes(arg) + arg += b"\x00" + self.jitter.cpu.R1 -= len(arg) + ptr = self.jitter.cpu.R1 + self.jitter.vm.set_mem(ptr, arg) + argv_ptrs.append(ptr) + + self.jitter.push_uint32_t(0) + for ptr in reversed(env_ptrs): + self.jitter.push_uint32_t(ptr) + self.jitter.cpu.R5 = self.jitter.cpu.R1 # envp + self.jitter.push_uint32_t(0) + for ptr in reversed(argv_ptrs): + self.jitter.push_uint32_t(ptr) + self.jitter.cpu.R4 = self.jitter.cpu.R1 # argv + self.jitter.cpu.R3 = len(self.argv) # argc + self.jitter.push_uint32_t(self.jitter.cpu.R3) + + self.jitter.cpu.R6 = 0 # auxp + self.jitter.cpu.R7 = 0 # termination function + + # From the glibc, we should push a 0 here to distinguish a + # dynamically linked executable from a statically linked one. + # We actually do not do it and attempt to be somehow compatible + # with both types of executables. 
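(Reference note for the --mimic-env branches used throughout this file, on x86, ARM, MIPS and PowerPC alike: they build a System V style process stack. At the emulated entry point the layout is roughly, from the stack pointer upwards:)

    SP -> argc
          argv[0] ... argv[argc-1]
          NULL
          envp[0] ... envp[n-1]
          NULL
          ... argument/environment strings (copied first, at higher addresses)

(On ppc32b the same values are additionally mirrored into R3/R4/R5 as argc/argv/envp, as the surrounding code shows.)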
+ #self.jitter.push_uint32_t(0) + + self.jitter.cpu.LR = self.CALL_FINISH_ADDR + + # Set the runtime guard + self.jitter.add_breakpoint( + self.CALL_FINISH_ADDR, + self.__class__.code_sentinelle + ) + + def run(self, addr=None): + """ + If addr is not set, use entrypoint + """ + if addr is None and self.options.address is None: + addr = self.entry_point + super(Sandbox_Linux_ppc32b, self).run(addr) + + def call(self, addr, *args, **kwargs): + """ + Direct call of the function at @addr, with arguments @args + @addr: address of the target function + @args: arguments + """ + prepare_cb = kwargs.pop('prepare_cb', self.jitter.func_prepare_systemv) + super(self.__class__, self).call(prepare_cb, addr, *args) diff --git a/src/miasm/analysis/simplifier.py b/src/miasm/analysis/simplifier.py new file mode 100644 index 00000000..a7c29b06 --- /dev/null +++ b/src/miasm/analysis/simplifier.py @@ -0,0 +1,325 @@ +""" +Apply simplification passes to an IR cfg +""" + +import logging +import warnings +from functools import wraps +from miasm.analysis.ssa import SSADiGraph +from miasm.analysis.outofssa import UnSSADiGraph +from miasm.analysis.data_flow import DiGraphLivenessSSA +from miasm.expression.simplifications import expr_simp +from miasm.ir.ir import AssignBlock, IRBlock +from miasm.analysis.data_flow import DeadRemoval, \ + merge_blocks, remove_empty_assignblks, \ + del_unused_edges, \ + PropagateExpressions, DelDummyPhi + + +log = logging.getLogger("simplifier") +console_handler = logging.StreamHandler() +console_handler.setFormatter(logging.Formatter("[%(levelname)-8s]: %(message)s")) +log.addHandler(console_handler) +log.setLevel(logging.WARNING) + + +def fix_point(func): + @wraps(func) + def ret_func(self, ircfg, head): + log.debug('[%s]: start', func.__name__) + has_been_modified = False + modified = True + while modified: + modified = func(self, ircfg, head) + has_been_modified |= modified + log.debug( + '[%s]: stop %r', + func.__name__, + has_been_modified + ) + return has_been_modified + return ret_func + + +class IRCFGSimplifier(object): + """ + Simplify an IRCFG + This class applies passes until reaching a fix point + """ + + def __init__(self, lifter): + self.lifter = lifter + self.init_passes() + + @property + def ir_arch(self): + warnings.warn('DEPRECATION WARNING: use ".lifter" instead of ".ir_arch"') + return self.lifter + + def init_passes(self): + """ + Init the array of simplification passes + """ + self.passes = [] + + @fix_point + def simplify(self, ircfg, head): + """ + Apply passes until reaching a fix point + Return True if the graph has been modified + + @ircfg: IRCFG instance to simplify + @head: Location instance of the ircfg head + """ + modified = False + for simplify_pass in self.passes: + modified |= simplify_pass(ircfg, head) + return modified + + def __call__(self, ircfg, head): + return self.simplify(ircfg, head) + + +class IRCFGSimplifierCommon(IRCFGSimplifier): + """ + Simplify an IRCFG + This class applies following passes until reaching a fix point: + - simplify_ircfg + - do_dead_simp_ircfg + """ + def __init__(self, lifter, expr_simp=expr_simp): + self.expr_simp = expr_simp + super(IRCFGSimplifierCommon, self).__init__(lifter) + self.deadremoval = DeadRemoval(self.lifter) + + def init_passes(self): + self.passes = [ + self.simplify_ircfg, + self.do_dead_simp_ircfg, + ] + + @fix_point + def simplify_ircfg(self, ircfg, _head): + """ + Apply self.expr_simp on the @ircfg until reaching fix point + Return True if the graph has been modified + + @ircfg: IRCFG instance to 
simplify
+ """
+ modified = ircfg.simplify(self.expr_simp)
+ return modified
+
+ @fix_point
+ def do_dead_simp_ircfg(self, ircfg, head):
+ """
+ Apply:
+ - dead_simp
+ - remove_empty_assignblks
+ - merge_blocks
+ on the @ircfg until reaching fix point
+ Return True if the graph has been modified
+
+ @ircfg: IRCFG instance to simplify
+ @head: Location instance of the ircfg head
+ """
+ modified = self.deadremoval(ircfg)
+ modified |= remove_empty_assignblks(ircfg)
+ modified |= merge_blocks(ircfg, set([head]))
+ return modified
+
+
+class IRCFGSimplifierSSA(IRCFGSimplifierCommon):
+ """
+ Simplify an IRCFG.
+ The IRCFG is first transformed into SSA form, transformation passes are
+ applied, then the out-of-SSA translation is run. Final passes of
+ IRCFGSimplifierCommon are applied.
+
+ This class applies the following passes until reaching a fix point:
+ - do_propagate_expressions
+ - do_dead_simp_ssa
+ """
+
+ def __init__(self, lifter, expr_simp=expr_simp):
+ super(IRCFGSimplifierSSA, self).__init__(lifter, expr_simp)
+
+ self.lifter.ssa_var = {}
+ self.all_ssa_vars = {}
+
+ self.ssa_forbidden_regs = self.get_forbidden_regs()
+
+ self.propag_expressions = PropagateExpressions()
+ self.del_dummy_phi = DelDummyPhi()
+
+ self.deadremoval = DeadRemoval(self.lifter, self.all_ssa_vars)
+
+ def get_forbidden_regs(self):
+ """
+ Return the set of registers kept immutable during the SSA transformation
+ """
+ regs = set(
+ [
+ self.lifter.pc,
+ self.lifter.IRDst,
+ self.lifter.arch.regs.exception_flags
+ ]
+ )
+ return regs
+
+ def init_passes(self):
+ """
+ Init the array of simplification passes
+ """
+ self.passes = [
+ self.simplify_ssa,
+ self.do_propagate_expressions,
+ self.do_del_dummy_phi,
+ self.do_dead_simp_ssa,
+ self.do_remove_empty_assignblks,
+ self.do_del_unused_edges,
+ self.do_merge_blocks,
+ ]
+
+
+
+ def ircfg_to_ssa(self, ircfg, head):
+ """
+ Apply the SSA transformation to @ircfg using its @head
+
+ @ircfg: IRCFG instance to simplify
+ @head: Location instance of the ircfg head
+ """
+ ssa = SSADiGraph(ircfg)
+ ssa.immutable_ids.update(self.ssa_forbidden_regs)
+ ssa.ssa_variable_to_expr.update(self.all_ssa_vars)
+ ssa.transform(head)
+ self.all_ssa_vars.update(ssa.ssa_variable_to_expr)
+ self.lifter.ssa_var.update(ssa.ssa_variable_to_expr)
+ return ssa
+
+ def ssa_to_unssa(self, ssa, head):
+ """
+ Apply the out-of-SSA transformation to @ssa using its @head
+
+ @ssa: SSADiGraph instance
+ @head: Location instance of the graph head
+ """
+ cfg_liveness = DiGraphLivenessSSA(ssa.graph)
+ cfg_liveness.init_var_info(self.lifter)
+ cfg_liveness.compute_liveness()
+
+ UnSSADiGraph(ssa, head, cfg_liveness)
+ return ssa.graph
+
+ @fix_point
+ def simplify_ssa(self, ssa, _head):
+ """
+ Apply self.expr_simp on the @ssa.graph until reaching fix point
+ Return True if the graph has been modified
+
+ @ssa: SSADiGraph instance
+ """
+ modified = ssa.graph.simplify(self.expr_simp)
+ return modified
+
+ @fix_point
+ def do_del_unused_edges(self, ssa, head):
+ """
+ Delete unused edges of the SSA graph
+ @head: Location instance of the graph head
+ """
+ modified = del_unused_edges(ssa.graph, set([head]))
+ return modified
+
+ def do_propagate_expressions(self, ssa, head):
+ """
+ Propagate expressions through ExprId in the @ssa graph
+ @head: Location instance of the graph head
+ """
+ modified = self.propag_expressions.propagate(ssa, head)
+ return modified
+
+ @fix_point
+ def do_del_dummy_phi(self, ssa, head):
+ """
+ Delete dummy phi nodes
+ @head: Location instance of the graph head
+ """
+ modified = self.del_dummy_phi.del_dummy_phi(ssa, head)
return modified + + @fix_point + def do_remove_empty_assignblks(self, ssa, head): + """ + Remove empty assignblks + @head: Location instance of the graph head + """ + modified = remove_empty_assignblks(ssa.graph) + return modified + + @fix_point + def do_merge_blocks(self, ssa, head): + """ + Merge blocks with one parent/son + @head: Location instance of the graph head + """ + modified = merge_blocks(ssa.graph, set([head])) + return modified + + @fix_point + def do_dead_simp_ssa(self, ssa, head): + """ + Apply: + - deadrm + - remove_empty_assignblks + - del_unused_edges + - merge_blocks + on the @ircfg until reaching fix point + Return True if the graph has been modified + + @ircfg: IRCFG instance to simplify + @head: Location instance of the ircfg head + """ + modified = self.deadremoval(ssa.graph) + return modified + + def do_simplify(self, ssa, head): + """ + Apply passes until reaching a fix point + Return True if the graph has been modified + """ + return super(IRCFGSimplifierSSA, self).simplify(ssa, head) + + def do_simplify_loop(self, ssa, head): + """ + Apply do_simplify until reaching a fix point + SSA is updated between each do_simplify + Return True if the graph has been modified + """ + modified = True + while modified: + modified = self.do_simplify(ssa, head) + # Update ssa structs + ssa = self.ircfg_to_ssa(ssa.graph, head) + return ssa + + def simplify(self, ircfg, head): + """ + Add access to "abi out regs" in each leaf block + Apply SSA transformation to @ircfg + Apply passes until reaching a fix point + Apply out-of-ssa transformation + Apply post simplification passes + + Updated simplified IRCFG instance and return it + + @ircfg: IRCFG instance to simplify + @head: Location instance of the ircfg head + """ + + ssa = self.ircfg_to_ssa(ircfg, head) + ssa = self.do_simplify_loop(ssa, head) + ircfg = self.ssa_to_unssa(ssa, head) + ircfg_simplifier = IRCFGSimplifierCommon(self.lifter) + ircfg_simplifier.deadremoval.add_expr_to_original_expr(self.all_ssa_vars) + ircfg_simplifier.simplify(ircfg, head) + return ircfg diff --git a/src/miasm/analysis/ssa.py b/src/miasm/analysis/ssa.py new file mode 100644 index 00000000..5c1964ef --- /dev/null +++ b/src/miasm/analysis/ssa.py @@ -0,0 +1,731 @@ +from collections import deque +from future.utils import viewitems, viewvalues + +from miasm.expression.expression import ExprId, ExprAssign, ExprOp, \ + ExprLoc, get_expr_ids +from miasm.ir.ir import AssignBlock, IRBlock + + +def sanitize_graph_head(ircfg, head): + """ + In multiple algorithm, the @head of the ircfg may not have predecessors. 
+ The function transform the @ircfg in order to ensure this property + @ircfg: IRCFG instance + @head: the location of the graph's head + """ + + if not ircfg.predecessors(head): + return + original_edges = ircfg.predecessors(head) + sub_head = ircfg.loc_db.add_location() + + # Duplicate graph, replacing references to head by sub_head + replaced_expr = { + ExprLoc(head, ircfg.IRDst.size): + ExprLoc(sub_head, ircfg.IRDst.size) + } + ircfg.simplify( + lambda expr:expr.replace_expr(replaced_expr) + ) + # Duplicate head block + ircfg.add_irblock(IRBlock(ircfg.loc_db, sub_head, list(ircfg.blocks[head]))) + + # Remove original head block + ircfg.del_node(head) + + for src in original_edges: + ircfg.add_edge(src, sub_head) + + # Create new head, jumping to sub_head + assignblk = AssignBlock({ircfg.IRDst:ExprLoc(sub_head, ircfg.IRDst.size)}) + new_irblock = IRBlock(ircfg.loc_db, head, [assignblk]) + ircfg.add_irblock(new_irblock) + + +class SSA(object): + """ + Generic class for static single assignment (SSA) transformation + + Handling of + - variable generation + - variable renaming + - conversion of an IRCFG block into SSA + + Variables will be renamed to <variable>.<index>, whereby the + index will be increased in every definition of <variable>. + + Memory expressions are stateless. The addresses are in SSA form, + but memory aliasing will occur. For instance, if it holds + that RAX == RBX.0 + (-0x8) and + + @64[RBX.0 + (-0x8)] = RDX + RCX.0 = @64[RAX], + + then it cannot be tracked that RCX.0 == RDX. + """ + + + def __init__(self, ircfg): + """ + Initialises generic class for SSA + :param ircfg: instance of IRCFG + """ + # IRCFG instance + self.ircfg = ircfg + + # stack for RHS + self._stack_rhs = {} + # stack for LHS + self._stack_lhs = {} + + self.ssa_variable_to_expr = {} + + # dict of SSA expressions + self.expressions = {} + + # dict of SSA to original location + self.ssa_to_location = {} + + # Don't SSA IRDst + self.immutable_ids = set([self.ircfg.IRDst]) + + def get_regs(self, expr): + return get_expr_ids(expr) + + def transform(self, *args, **kwargs): + """Transforms into SSA""" + raise NotImplementedError("Abstract method") + + def get_block(self, loc_key): + """ + Returns an IRBlock + :param loc_key: LocKey instance + :return: IRBlock + """ + irblock = self.ircfg.blocks.get(loc_key, None) + + return irblock + + def reverse_variable(self, ssa_var): + """ + Transforms a variable in SSA form into non-SSA form + :param ssa_var: ExprId, variable in SSA form + :return: ExprId, variable in non-SSA form + """ + expr = self.ssa_variable_to_expr.get(ssa_var, ssa_var) + return expr + + def reset(self): + """Resets SSA transformation""" + self.expressions = {} + self._stack_rhs = {} + self._stack_lhs = {} + self.ssa_to_location = {} + + def _gen_var_expr(self, expr, stack): + """ + Generates a variable expression in SSA form + :param expr: variable expression which will be translated + :param stack: self._stack_rhs or self._stack_lhs + :return: variable expression in SSA form + """ + index = stack[expr] + name = "%s.%d" % (expr.name, index) + ssa_var = ExprId(name, expr.size) + self.ssa_variable_to_expr[ssa_var] = expr + + return ssa_var + + def _transform_var_rhs(self, ssa_var): + """ + Transforms a variable on the right hand side into SSA + :param ssa_var: variable + :return: transformed variable + """ + # variable has never been on the LHS + if ssa_var not in self._stack_rhs: + return ssa_var + # variable has been on the LHS + stack = self._stack_rhs + return self._gen_var_expr(ssa_var, 
stack) + + def _transform_var_lhs(self, expr): + """ + Transforms a variable on the left hand side into SSA + :param expr: variable + :return: transformed variable + """ + # check if variable has already been on the LHS + if expr not in self._stack_lhs: + self._stack_lhs[expr] = 0 + # save last value for RHS transformation + self._stack_rhs[expr] = self._stack_lhs[expr] + + # generate SSA expression + stack = self._stack_lhs + ssa_var = self._gen_var_expr(expr, stack) + + return ssa_var + + def _transform_expression_lhs(self, dst): + """ + Transforms an expression on the left hand side into SSA + :param dst: expression + :return: expression in SSA form + """ + if dst.is_mem(): + # transform with last RHS instance + ssa_var = self._transform_expression_rhs(dst) + else: + # transform LHS + ssa_var = self._transform_var_lhs(dst) + + # increase SSA variable counter + self._stack_lhs[dst] += 1 + + return ssa_var + + def _transform_expression_rhs(self, src): + """ + Transforms an expression on the right hand side into SSA + :param src: expression + :return: expression in SSA form + """ + # dissect expression in variables + variables = self.get_regs(src) + src_ssa = src + # transform variables + to_replace = {} + for expr in variables: + ssa_var = self._transform_var_rhs(expr) + to_replace[expr] = ssa_var + src_ssa = src_ssa.replace_expr(to_replace) + + return src_ssa + + @staticmethod + def _parallel_instructions(assignblk): + """ + Extracts the instruction from a AssignBlock. + + Since instructions in a AssignBlock are evaluated + in parallel, memory instructions on the left hand + side will be inserted into the start of the list. + Then, memory instruction on the LHS will be + transformed firstly. + + :param assignblk: assignblock + :return: sorted list of expressions + """ + instructions = [] + for dst in assignblk: + # dst = src + aff = assignblk.dst2ExprAssign(dst) + # insert memory expression into start of list + if dst.is_mem(): + instructions.insert(0, aff) + else: + instructions.append(aff) + + return instructions + + @staticmethod + def _convert_block(irblock, ssa_list): + """ + Transforms an IRBlock inplace into SSA + :param irblock: IRBlock to be transformed + :param ssa_list: list of SSA expressions + """ + # iterator over SSA expressions + ssa_iter = iter(ssa_list) + new_irs = [] + # walk over IR blocks' assignblocks + for assignblk in irblock.assignblks: + # list of instructions + instructions = [] + # insert SSA instructions + for _ in assignblk: + instructions.append(next(ssa_iter)) + # replace instructions of assignblock in IRBlock + new_irs.append(AssignBlock(instructions, assignblk.instr)) + return IRBlock(irblock.loc_db, irblock.loc_key, new_irs) + + def _rename_expressions(self, loc_key): + """ + Transforms variables and expressions + of an IRBlock into SSA. + + IR representations of an assembly instruction are evaluated + in parallel. Thus, RHS and LHS instructions will be performed + separately. 
+
+    def _rename_expressions(self, loc_key):
+        """
+        Transforms the variables and expressions
+        of an IRBlock into SSA.
+
+        IR representations of an assembly instruction are evaluated
+        in parallel. Thus, RHS and LHS instructions are processed
+        separately.
+
+        :param loc_key: IRBlock loc_key
+        """
+        # list of the IRBlock's SSA expressions
+        ssa_expressions_block = []
+
+        # retrieve IRBlock
+        irblock = self.get_block(loc_key)
+        if irblock is None:
+            # Incomplete graph
+            return
+
+        # iterate over the block's IR expressions
+        for index, assignblk in enumerate(irblock.assignblks):
+            # list of parallel instructions
+            instructions = self._parallel_instructions(assignblk)
+            # list for transformed RHS expressions
+            rhs = deque()
+
+            # transform RHS
+            for expr in instructions:
+                src = expr.src
+                src_ssa = self._transform_expression_rhs(src)
+                # save transformed RHS
+                rhs.append(src_ssa)
+
+            # transform LHS
+            for expr in instructions:
+                if expr.dst in self.immutable_ids or expr.dst in self.ssa_variable_to_expr:
+                    dst_ssa = expr.dst
+                else:
+                    dst_ssa = self._transform_expression_lhs(expr.dst)
+
+                # retrieve the corresponding RHS expression
+                src_ssa = rhs.popleft()
+
+                # rebuild the SSA expression
+                expr = ExprAssign(dst_ssa, src_ssa)
+                self.expressions[dst_ssa] = src_ssa
+                self.ssa_to_location[dst_ssa] = (loc_key, index)
+
+                # append the SSA expression to the list
+                ssa_expressions_block.append(expr)
+
+        # replace the block's IR expressions with the corresponding SSA transformations
+        new_irblock = self._convert_block(irblock, ssa_expressions_block)
+        self.ircfg.blocks[loc_key] = new_irblock
+
+
+class SSABlock(SSA):
+    """
+    SSA transformation at block level
+
+    It handles
+    - transformation of a single IRBlock into SSA
+    - reassembling an SSA expression into a non-SSA
+      expression through iterative resolving of the RHS
+    """
+
+    def transform(self, loc_key):
+        """
+        Transforms a block into SSA form
+        :param loc_key: IRBlock loc_key
+        """
+        self._rename_expressions(loc_key)
+
+    def reassemble_expr(self, expr):
+        """
+        Reassembles an expression in SSA form into a purely non-SSA expression
+        :param expr: expression
+        :return: non-SSA expression
+        """
+        # worklist
+        todo = {expr.copy()}
+
+        while todo:
+            # current expression
+            cur = todo.pop()
+            # RHS of the current expression
+            cur_rhs = self.expressions[cur]
+
+            # replace cur with its RHS in expr
+            expr = expr.replace_expr({cur: cur_rhs})
+
+            # parse ExprIds on the RHS
+            ids_rhs = self.get_regs(cur_rhs)
+
+            # add RHS ids to the worklist
+            for id_rhs in ids_rhs:
+                if id_rhs in self.expressions:
+                    todo.add(id_rhs)
+        return expr
+
+
+class SSAPath(SSABlock):
+    """
+    SSA transformation at path level
+
+    It handles
+    - transformation of a path of IRBlocks into SSA
+    """
+
+    def transform(self, path):
+        """
+        Transforms a path into SSA
+        :param path: list of IRBlock loc_keys
+        """
+        for block in path:
+            self._rename_expressions(block)
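+
+
+# Usage sketch (hedged: `ircfg` is an IRCFG built elsewhere, and
+# `lbl0`/`lbl1` are illustrative loc_keys of consecutive blocks on one path):
+#
+#     ssa = SSAPath(ircfg)
+#     ssa.transform([lbl0, lbl1])          # rename along the path
+#     # resolve an SSA expression back to the path's inputs
+#     resolved = ssa.reassemble_expr(some_ssa_expr)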
+ """ + + PHI_STR = 'Phi' + + + def __init__(self, ircfg): + """ + Initialises SSA class for directed graphs + :param ircfg: instance of IRCFG + """ + super(SSADiGraph, self).__init__(ircfg) + + # variable definitions + self.defs = {} + + # dict of blocks' phi nodes + self._phinodes = {} + + # IRCFG control flow graph + self.graph = ircfg + + + def transform(self, head): + """Transforms into SSA""" + sanitize_graph_head(self.graph, head) + self._init_variable_defs(head) + self._place_phi(head) + self._rename(head) + self._insert_phi() + self._convert_phi() + self._fix_no_def_var(head) + + def reset(self): + """Resets SSA transformation""" + super(SSADiGraph, self).reset() + self.defs = {} + self._phinodes = {} + + def _init_variable_defs(self, head): + """ + Initialises all variable definitions and + assigns the corresponding IRBlocks. + + All variable definitions in self.defs contain + a set of IRBlocks in which the variable gets assigned + """ + + visited_loc = set() + for loc_key in self.graph.walk_depth_first_forward(head): + irblock = self.get_block(loc_key) + if irblock is None: + # Incomplete graph + continue + visited_loc.add(loc_key) + # search for block's IR definitions/destinations + for assignblk in irblock.assignblks: + for dst in assignblk: + # enforce ExprId + if dst.is_id(): + # exclude immutable ids + if dst in self.immutable_ids or dst in self.ssa_variable_to_expr: + continue + # map variable definition to blocks + self.defs.setdefault(dst, set()).add(irblock.loc_key) + if visited_loc != set(self.graph.blocks): + raise RuntimeError("Cannot operate on a non connected graph") + + def _place_phi(self, head): + """ + For all blocks, empty phi functions will be placed for every + variable in the block's dominance frontier. + + self.phinodes contains a dict for every block in the + dominance frontier. In this dict, each variable + definition maps to its corresponding phi function. + + Source: Cytron, Ron, et al. + "An efficient method of computing static single assignment form" + Proceedings of the 16th ACM SIGPLAN-SIGACT symposium on + Principles of programming languages (1989), p. 30 + """ + # dominance frontier + frontier = self.graph.compute_dominance_frontier(head) + + for variable in self.defs: + done = set() + todo = set() + intodo = set() + + for loc_key in self.defs[variable]: + todo.add(loc_key) + intodo.add(loc_key) + + while todo: + loc_key = todo.pop() + + # walk through block's dominance frontier + for node in frontier.get(loc_key, []): + if node in done: + continue + # place empty phi functions for a variable + empty_phi = self._gen_empty_phi(variable) + + # add empty phi node for variable in node + self._phinodes.setdefault(node, {})[variable] = empty_phi.src + done.add(node) + + if node not in intodo: + intodo.add(node) + todo.add(node) + + def _gen_empty_phi(self, expr): + """ + Generates an empty phi function for a variable + :param expr: variable + :return: ExprAssign, empty phi function for expr + """ + phi = ExprId(self.PHI_STR, expr.size) + return ExprAssign(expr, phi) + + def _fill_phi(self, *args): + """ + Fills a phi function with variables. + + phi(x.1, x.5, x.6) + + :param args: list of ExprId + :return: ExprOp + """ + return ExprOp(self.PHI_STR, *set(args)) + + def _rename(self, head): + """ + Transforms each variable expression in the CFG into SSA + by traversing the dominator tree in depth-first search. + + 1. Transform variables of phi functions on LHS into SSA + 2. Transform all non-phi expressions into SSA + 3. 
+
+    def _rename(self, head):
+        """
+        Transforms each variable expression in the CFG into SSA
+        by traversing the dominator tree in depth-first search.
+
+        1. Transform variables of phi functions on the LHS into SSA
+        2. Transform all non-phi expressions into SSA
+        3. Update the successors' phi functions' RHS with current SSA variables
+        4. Save the current SSA variable stack for the successors in the dominator tree
+
+        Source: Cytron, Ron, et al.
+        "An efficient method of computing static single assignment form"
+        Proceedings of the 16th ACM SIGPLAN-SIGACT symposium on
+        Principles of programming languages (1989), p. 31
+        """
+        # compute dominator tree
+        dominator_tree = self.graph.compute_dominator_tree(head)
+
+        # init SSA variable stack
+        stack = [self._stack_rhs]
+
+        # walk in DFS over the dominator tree
+        for loc_key in dominator_tree.walk_depth_first_forward(head):
+            # restore the SSA variable stack of the predecessor in the dominator tree
+            self._stack_rhs = stack.pop().copy()
+
+            # Transform variables of phi functions on the LHS into SSA
+            self._rename_phi_lhs(loc_key)
+
+            # Transform all non-phi expressions into SSA
+            self._rename_expressions(loc_key)
+
+            # Update the successors' phi functions' RHS with current SSA variables
+            # walk over the block's successors in the CFG
+            for successor in self.graph.successors_iter(loc_key):
+                self._rename_phi_rhs(successor)
+
+            # Save the current SSA variable stack for the successors in the dominator tree
+            for _ in dominator_tree.successors_iter(loc_key):
+                stack.append(self._stack_rhs)
+
+    def _rename_phi_lhs(self, loc_key):
+        """
+        Transforms the phi function expressions of an IRBlock
+        on the left hand side into SSA
+        :param loc_key: IRBlock loc_key
+        """
+        if loc_key in self._phinodes:
+            # create a temporary list of phi function assignments for inplace renaming
+            tmp = list(self._phinodes[loc_key])
+
+            # iterate over all of the block's phi nodes
+            for dst in tmp:
+                # transform the variable on the LHS inplace
+                self._phinodes[loc_key][self._transform_expression_lhs(dst)] = self._phinodes[loc_key].pop(dst)
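+
+    # Stack discipline sketch: each node visited in the dominator-tree DFS
+    # pops a *copy* of the _stack_rhs its dominator pushed (once per child),
+    # so sibling subtrees each start from their parent's renaming state and
+    # never observe each other's indices.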
+
+    def _rename_phi_rhs(self, successor):
+        """
+        Transforms the right hand side of each successor's phi function
+        into SSA. Each transformed expression of a phi function's
+        right hand side is of the form
+
+        phi(<var>.<index 1>, <var>.<index 2>, ..., <var>.<index n>)
+
+        :param successor: loc_key of the block's direct successor in the CFG
+        """
+        # if the successor is in the block's dominance frontier
+        if successor in self._phinodes:
+            # walk over all variables on the LHS
+            for dst, src in list(viewitems(self._phinodes[successor])):
+                # transform the variable on the RHS into non-SSA form
+                expr = self.reverse_variable(dst)
+                # transform expr into its SSA form using the current stack
+                src_ssa = self._transform_expression_rhs(expr)
+
+                # add src_ssa to the phi args
+                if src.is_id(self.PHI_STR):
+                    # phi function is empty
+                    expr = self._fill_phi(src_ssa)
+                else:
+                    # phi function contains at least one value
+                    expr = self._fill_phi(src_ssa, *src.args)
+
+                # update the phi function
+                self._phinodes[successor][dst] = expr
+
+    def _insert_phi(self):
+        """Inserts phi functions into the list of SSA expressions"""
+        for loc_key in self._phinodes:
+            for dst in self._phinodes[loc_key]:
+                self.expressions[dst] = self._phinodes[loc_key][dst]
+
+    def _convert_phi(self):
+        """Inserts the corresponding phi functions inplace
+        at the beginning of each IRBlock"""
+        for loc_key in self._phinodes:
+            irblock = self.get_block(loc_key)
+            if irblock is None:
+                continue
+            assignblk = AssignBlock(self._phinodes[loc_key])
+            if irblock_has_phi(irblock):
+                # If the block already starts with phi nodes, we are updating
+                # an existing SSA form, so update the phi nodes
+                assignblks = list(irblock.assignblks)
+                out = dict(assignblks[0])
+                out.update(dict(assignblk))
+                assignblks[0] = AssignBlock(out, assignblk.instr)
+                new_irblock = IRBlock(self.ircfg.loc_db, loc_key, assignblks)
+            else:
+                # insert at the beginning
+                new_irblock = IRBlock(self.ircfg.loc_db, loc_key, [assignblk] + list(irblock.assignblks))
+            self.ircfg.blocks[loc_key] = new_irblock
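+
+    # Result sketch (illustrative): after renaming, a diamond join block
+    # whose queued node started as "RAX = Phi" ends up with the AssignBlock
+    # {RAX.2: Phi(RAX.0, RAX.1)} prepended to its original assignblks by
+    # _convert_phi.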
+
+    def _fix_no_def_var(self, head):
+        """
+        Replace phi source variables which are not SSA variables with SSA variables.
+        @head: loc_key of the graph head
+        """
+        var_to_insert = set()
+        for loc_key in self._phinodes:
+            for dst, sources in viewitems(self._phinodes[loc_key]):
+                for src in sources.args:
+                    if src in self.ssa_variable_to_expr:
+                        continue
+                    var_to_insert.add(src)
+        var_to_newname = {}
+        newname_to_var = {}
+        for var in var_to_insert:
+            new_var = self._transform_var_lhs(var)
+            var_to_newname[var] = new_var
+            newname_to_var[new_var] = var
+
+        # Replace non-modified nodes used in phi with the new variables
+        self.ircfg.simplify(lambda expr: expr.replace_expr(var_to_newname))
+
+        if newname_to_var:
+            irblock = self.ircfg.blocks[head]
+            assignblks = list(irblock)
+            assignblks[0:0] = [AssignBlock(newname_to_var, assignblks[0].instr)]
+            self.ircfg.blocks[head] = IRBlock(self.ircfg.loc_db, head, assignblks)
+
+        # Update structure
+        for loc_key in self._phinodes:
+            for dst, sources in viewitems(self._phinodes[loc_key]):
+                self._phinodes[loc_key][dst] = sources.replace_expr(var_to_newname)
+
+        for var, (loc_key, index) in list(viewitems(self.ssa_to_location)):
+            if loc_key == head:
+                self.ssa_to_location[var] = loc_key, index + 1
+
+        for newname, var in viewitems(newname_to_var):
+            self.ssa_to_location[newname] = head, 0
+            self.ssa_variable_to_expr[newname] = var
+            self.expressions[newname] = var
+
+
+def irblock_has_phi(irblock):
+    """
+    Return True if @irblock has Phi assignments
+    @irblock: IRBlock instance
+    """
+    if not irblock.assignblks:
+        return False
+    for src in viewvalues(irblock[0]):
+        return src.is_op('Phi')
+    return False
+
+
+class Varinfo(object):
+    """Store liveness information for a variable"""
+    __slots__ = ["live_index", "loc_key", "index"]
+
+    def __init__(self, live_index, loc_key, index):
+        self.live_index = live_index
+        self.loc_key = loc_key
+        self.index = index
+
+
+def get_var_assignment_src(ircfg, node, variables):
+    """
+    Return the variable of @variables which is written by the irblock at @node
+    @node: Location
+    @variables: a set of variables to test
+    """
+    irblock = ircfg.blocks[node]
+    for assignblk in irblock:
+        result = set(assignblk).intersection(variables)
+        if not result:
+            continue
+        assert len(result) == 1
+        return list(result)[0]
+    return None
+
+
+def get_phi_sources_parent_block(ircfg, loc_key, sources):
+    """
+    Return a dictionary linking a variable to its direct parent labels
+    which belong to a path affecting the node.
+    @loc_key: the starting node
+    @sources: set of variables to resolve
+    """
+    source_to_parent = {}
+    for parent in ircfg.predecessors(loc_key):
+        done = set()
+        todo = set([parent])
+        found = False
+        while todo:
+            node = todo.pop()
+            if node in done:
+                continue
+            done.add(node)
+            ret = get_var_assignment_src(ircfg, node, sources)
+            if ret:
+                source_to_parent.setdefault(ret, set()).add(parent)
+                found = True
+                break
+            for pred in ircfg.predecessors(node):
+                todo.add(pred)
+        assert found
+    return source_to_parent
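+
+
+# End-to-end usage sketch (hedged: assumes the usual miasm lifting pipeline;
+# `loc_db` and `asmcfg` are built elsewhere and are illustrative here):
+#
+#     from miasm.analysis.machine import Machine
+#     lifter = Machine("x86_64").lifter_model_call(loc_db)
+#     ircfg = lifter.new_ircfg_from_asmcfg(asmcfg)
+#     head = asmcfg.heads()[0]
+#     ssa = SSADiGraph(ircfg)
+#     ssa.transform(head)
+#     for loc_key in ircfg.blocks:
+#         print(ircfg.blocks[loc_key])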