diff options
| author | Theofilos Augoustis <theofilos.augoustis@gmail.com> | 2024-01-17 17:08:48 +0100 |
|---|---|---|
| committer | Theofilos Augoustis <theofilos.augoustis@gmail.com> | 2024-01-17 17:08:48 +0100 |
| commit | 605e12fc5cf0fb64e45f68378390a09aa28df2f9 (patch) | |
| tree | b8ba3c3c6f0776b5b2cd99fee7811d3c27af3f1d | |
| parent | eae0b3b08bd078ad2f621ce2ef201e656da3f16a (diff) | |
| download | focaccia-605e12fc5cf0fb64e45f68378390a09aa28df2f9.tar.gz focaccia-605e12fc5cf0fb64e45f68378390a09aa28df2f9.zip | |
Refactor symbolic transformation handling
| -rwxr-xr-x | focaccia.py | 42 | ||||
| -rw-r--r-- | focaccia/compare.py | 148 | ||||
| -rw-r--r-- | focaccia/lldb_target.py | 4 | ||||
| -rw-r--r-- | focaccia/miasm_util.py | 19 | ||||
| -rw-r--r-- | focaccia/parser.py | 21 | ||||
| -rw-r--r-- | focaccia/snapshot.py | 22 | ||||
| -rw-r--r-- | focaccia/symbolic.py | 354 | ||||
| -rw-r--r-- | focaccia/utils.py | 36 | ||||
| -rw-r--r-- | tools/capture_transforms.py | 27 |
9 files changed, 464 insertions, 209 deletions
diff --git a/focaccia.py b/focaccia.py index e140337..bbd1317 100755 --- a/focaccia.py +++ b/focaccia.py @@ -11,7 +11,7 @@ from focaccia.lldb_target import LLDBConcreteTarget from focaccia.parser import parse_arancini from focaccia.snapshot import ProgramState from focaccia.symbolic import SymbolicTransform, collect_symbolic_trace -from focaccia.utils import print_separator +from focaccia.utils import print_result def run_native_execution(oracle_program: str, breakpoints: Iterable[int]): """Gather snapshots from a native execution via an external debugger. @@ -42,7 +42,7 @@ def match_traces(test: list[ProgramState], truth: list[SymbolicTransform]): if not test or not truth: return [], [] - assert(test[0].read('pc') == truth[0].addr) + assert(test[0].read_register('pc') == truth[0].addr) def index(seq, target, access=lambda el: el): for i, el in enumerate(seq): @@ -52,7 +52,7 @@ def match_traces(test: list[ProgramState], truth: list[SymbolicTransform]): i = 0 for next_state in test[1:]: - next_pc = next_state.read('pc') + next_pc = next_state.read_register('pc') index_in_truth = index(truth[i:], next_pc, lambda el: el.range[1]) # If no next element (i.e. no foldable range) is found in the truth @@ -83,7 +83,7 @@ def parse_inputs(txl_path, program): txl = parse_arancini(txl_file, arch) with open(txl_path, "r") as txl_file: - breakpoints = [state.read('PC') for state in txl] + breakpoints = [state.read_register('PC') for state in txl] ref = run_native_execution(program, breakpoints) return txl, ref @@ -124,40 +124,6 @@ def parse_arguments(): args = parser.parse_args() return args -def print_result(result, min_severity: ErrorSeverity): - shown = 0 - suppressed = 0 - - for res in result: - pc = res['pc'] - print_separator() - print(f'For PC={hex(pc)}') - print_separator() - - # Filter errors by severity - errs = [e for e in res['errors'] if e.severity >= min_severity] - suppressed += len(res['errors']) - len(errs) - shown += len(errs) - - # Print all non-suppressed errors - for n, err in enumerate(errs, start=1): - print(f' {n:2}. {err}') - - if errs: - print() - print(f'Expected transformation: {res["ref"]}') - print(f'Actual transformation: {res["txl"]}') - else: - print('No errors found.') - - print() - print('#' * 60) - print(f'Found {shown} errors.') - print(f'Suppressed {suppressed} low-priority errors' - f' (showing {min_severity} and higher).') - print('#' * 60) - print() - def main(): verbosity = { 'verbose': ErrorTypes.INFO, diff --git a/focaccia/compare.py b/focaccia/compare.py index 066127d..36cd54e 100644 --- a/focaccia/compare.py +++ b/focaccia/compare.py @@ -1,7 +1,7 @@ from functools import total_ordering -from typing import Self +from typing import Iterable, Self -from .snapshot import ProgramState, MemoryAccessError +from .snapshot import ProgramState, MemoryAccessError, RegisterAccessError from .symbolic import SymbolicTransform @total_ordering @@ -57,10 +57,10 @@ def _calc_transformation(previous: ProgramState, current: ProgramState): transformation = ProgramState(arch) for reg in arch.regnames: try: - prev_val, cur_val = previous.read(reg), current.read(reg) + prev_val, cur_val = previous.read_register(reg), current.read_register(reg) if prev_val is not None and cur_val is not None: - transformation.set(reg, cur_val - prev_val) - except ValueError: + transformation.set_register(reg, cur_val - prev_val) + except RegisterAccessError: # Register is not set in either state pass @@ -86,11 +86,12 @@ def _find_errors(transform_txl: ProgramState, transform_truth: ProgramState) \ errors = [] for reg in transform_truth.arch.regnames: try: - diff_txl = transform_txl.read(reg) - diff_truth = transform_truth.read(reg) - except ValueError: + diff_txl = transform_txl.read_register(reg) + diff_truth = transform_truth.read_register(reg) + except RegisterAccessError: errors.append(Error(ErrorTypes.INFO, - f'Value for register {reg} is not set in' + f'Unable to calculate difference:' + f' Value for register {reg} is not set in' f' either the tested or the reference state.')) continue @@ -123,7 +124,7 @@ def compare_simple(test_states: list[ProgramState], # No errors in initial snapshot because we can't perform difference # calculations on it result = [{ - 'pc': test_states[0].read(PC_REGNAME), + 'pc': test_states[0].read_register(PC_REGNAME), 'txl': test_states[0], 'ref': truth_states[0], 'errors': [] }] @@ -134,15 +135,15 @@ def compare_simple(test_states: list[ProgramState], for txl, truth in it_cur: prev_txl, prev_truth = next(it_prev) - pc_txl = txl.read(PC_REGNAME) - pc_truth = truth.read(PC_REGNAME) + pc_txl = txl.read_register(PC_REGNAME) + pc_truth = truth.read_register(PC_REGNAME) # The program counter should always be set on a snapshot assert(pc_truth is not None) assert(pc_txl is not None) if pc_txl != pc_truth: - print(f'Unmatched program counter {hex(txl.read(PC_REGNAME))}' + print(f'Unmatched program counter {hex(txl.read_register(PC_REGNAME))}' f' in translated code!') continue @@ -171,7 +172,7 @@ def _find_register_errors(txl_from: ProgramState, """ # Calculate expected register values try: - truth = transform_truth.calc_register_transform(txl_from) + truth = transform_truth.eval_register_transforms(txl_from) except MemoryAccessError as err: s, e = transform_truth.range return [Error( @@ -179,18 +180,24 @@ def _find_register_errors(txl_from: ProgramState, f'Register transformations {hex(s)} -> {hex(e)} depend on' f' {err.mem_size} bytes at memory address {hex(err.mem_addr)}' f' that are not entirely present in the tested state' - f' {hex(txl_from.read("pc"))}. Skipping.', + f' {hex(txl_from.read_register("pc"))}.', )] + except RegisterAccessError as err: + s, e = transform_truth.range + return [Error(ErrorTypes.INCOMPLETE, + f'Register transformations {hex(s)} -> {hex(e)} depend' + f' on the value of register {err.regname}, which is not' + f' set in the tested state.')] # Compare expected values to actual values in the tested state errors = [] for regname, truth_val in truth.items(): try: - txl_val = txl_to.read(regname) - except ValueError: + txl_val = txl_to.read_register(regname) + except RegisterAccessError: errors.append(Error(ErrorTypes.INCOMPLETE, f'Value of register {regname} has changed, but' - f' is not set in the tested state. Skipping.')) + f' is not set in the tested state.')) continue except KeyError as err: print(f'[WARNING] {err}') @@ -217,14 +224,20 @@ def _find_memory_errors(txl_from: ProgramState, """ # Calculate expected register values try: - truth = transform_truth.calc_memory_transform(txl_from) + truth = transform_truth.eval_memory_transforms(txl_from) except MemoryAccessError as err: s, e = transform_truth.range return [Error(ErrorTypes.INCOMPLETE, f'Memory transformations {hex(s)} -> {hex(e)} depend on' f' {err.mem_size} bytes at memory address {hex(err.mem_addr)}' - f' that are not entirely present in the tested state' - f' {hex(txl_from.read("pc"))}. Skipping.')] + f' that are not entirely present in the tested state at' + f' {hex(txl_from.read_register("pc"))}.')] + except RegisterAccessError as err: + s, e = transform_truth.range + return [Error(ErrorTypes.INCOMPLETE, + f'Memory transformations {hex(s)} -> {hex(e)} depend on' + f' the value of register {err.regname}, which is not' + f' set in the tested state.')] # Compare expected values to actual values in the tested state errors = [] @@ -234,15 +247,19 @@ def _find_memory_errors(txl_from: ProgramState, txl_bytes = txl_to.read_memory(addr, size) except MemoryAccessError: errors.append(Error(ErrorTypes.POSSIBLE, - f'Memory range [{addr}, {addr + size}) is not' - f' set in the tested result state. Skipping.')) + f'Memory range [{hex(addr)}, {hex(addr + size)})' + f' is not set in the tested result state at' + f' {hex(txl_to.read_register("pc"))}. This is' + f' either an error in the translation or' + f' the recorded test state is missing data.')) continue if txl_bytes != truth_bytes: errors.append(Error(ErrorTypes.CONFIRMED, - f'Content of memory at {addr} is false.' - f' Expected content: {truth_bytes.hex()}, actual' - f' content in the translation: {txl_bytes.hex()}.')) + f'Content of memory at {hex(addr)} is false.' + f' Expected content: {truth_bytes.hex()},' + f' actual content in the translation:' + f' {txl_bytes.hex()}.')) return errors def _find_errors_symbolic(txl_from: ProgramState, @@ -263,13 +280,13 @@ def _find_errors_symbolic(txl_from: ProgramState, :param transform_truth: The symbolic transformation that maps the source state to the destination state. """ - if (txl_from.read('PC') != transform_truth.range[0]) \ - or (txl_to.read('PC') != transform_truth.range[1]): + if (txl_from.read_register('PC') != transform_truth.range[0]) \ + or (txl_to.read_register('PC') != transform_truth.range[1]): tstart, tend = transform_truth.range return [Error(ErrorTypes.POSSIBLE, f'Program counters of the tested transformation' f' do not match the truth transformation:' - f' {hex(txl_from.read("PC"))} -> {hex(txl_to.read("PC"))}' + f' {hex(txl_from.read_register("PC"))} -> {hex(txl_to.read_register("PC"))}' f' (test) vs. {hex(tstart)} -> {hex(tend)} (truth).' f' Skipping with no errors.')] @@ -279,41 +296,48 @@ def _find_errors_symbolic(txl_from: ProgramState, return errors -def compare_symbolic(test_states: list[ProgramState], - transforms: list[SymbolicTransform]) \ +def compare_symbolic(test_states: Iterable[ProgramState], + transforms: Iterable[SymbolicTransform]) \ -> list[dict]: - #assert(len(test_states) == len(transforms) - 1) + test_states = iter(test_states) + transforms = iter(transforms) - result = [{ - 'pc': test_states[0].read('PC'), - 'txl': test_states[0], - 'ref': transforms[0], - 'errors': [] - }] + result = [] + cur_state = next(test_states) # The state before the transformation + transform = next(transforms) # Operates on `cur_state` - _list = zip(test_states[:-1], test_states[1:], transforms) - for cur_state, next_state, transform in _list: - pc_cur = cur_state.read('PC') - pc_next = next_state.read('PC') - - start_addr, end_addr = transform.range - if pc_cur != start_addr: - print(f'Program counter {hex(pc_cur)} in translated code has no' - f' corresponding reference state! Skipping.' - f' (reference: {hex(start_addr)})') - continue - if pc_next != end_addr: - print(f'Tested state transformation is {hex(pc_cur)} ->' - f' {hex(pc_next)}, but reference transform is' - f' {hex(start_addr)} -> {hex(end_addr)}!' - f' Skipping.') - - errors = _find_errors_symbolic(cur_state, next_state, transform) - result.append({ - 'pc': pc_cur, - 'txl': _calc_transformation(cur_state, next_state), - 'ref': transform, - 'errors': errors - }) + while True: + try: + next_state = next(test_states) # The state after the transformation + + pc_cur = cur_state.read_register('PC') + pc_next = next_state.read_register('PC') + start_addr, end_addr = transform.range + if pc_cur != start_addr: + print(f'Program counter {hex(pc_cur)} in translated code has' + f' no corresponding reference state! Skipping.' + f' (reference: {hex(start_addr)})') + cur_state = next_state + transform = next(transforms) + continue + if pc_next != end_addr: + print(f'Tested state transformation is {hex(pc_cur)} ->' + f' {hex(pc_next)}, but reference transform is' + f' {hex(start_addr)} -> {hex(end_addr)}!' + f' Skipping.') + + errors = _find_errors_symbolic(cur_state, next_state, transform) + result.append({ + 'pc': pc_cur, + 'txl': _calc_transformation(cur_state, next_state), + 'ref': transform, + 'errors': errors + }) + + # Step forward + cur_state = next_state + transform = next(transforms) + except StopIteration: + break return result diff --git a/focaccia/lldb_target.py b/focaccia/lldb_target.py index 146891f..444ab36 100644 --- a/focaccia/lldb_target.py +++ b/focaccia/lldb_target.py @@ -90,7 +90,7 @@ class LLDBConcreteTarget: for regname in arch.regnames: try: conc_val = self.read_register(regname) - state.set(regname, conc_val) + state.set_register(regname, conc_val) except KeyError: pass except ConcreteRegisterError: @@ -98,7 +98,7 @@ class LLDBConcreteTarget: if arch.archname == x86.archname: rflags = x86.decompose_rflags(self.read_register('rflags')) if regname in rflags: - state.set(regname, rflags[regname]) + state.set_register(regname, rflags[regname]) # Query and store memory state for mapping in self.get_mappings(): diff --git a/focaccia/miasm_util.py b/focaccia/miasm_util.py index d9a9936..514390d 100644 --- a/focaccia/miasm_util.py +++ b/focaccia/miasm_util.py @@ -30,6 +30,8 @@ expr_simp = expr_simp_explicit expr_simp.enable_passes({ExprOp: [simp_segm]}) class MiasmConcreteState: + """Resolves atomic symbols to some state.""" + miasm_flag_aliases = { 'NF': 'SF', 'I_F': 'IF', @@ -38,22 +40,22 @@ class MiasmConcreteState: } def __init__(self, state: ProgramState, loc_db: LocationDB): - self.state = state - self.loc_db = loc_db + self._state = state + self._loc_db = loc_db def resolve_register(self, regname: str) -> int | None: regname = regname.upper() if regname in self.miasm_flag_aliases: regname = self.miasm_flag_aliases[regname] - return self.state.read(regname) + return self._state.read_register(regname) def resolve_memory(self, addr: int, size: int) -> bytes | None: - return self.state.read_memory(addr, size) + return self._state.read_memory(addr, size) def resolve_location(self, loc: LocKey) -> int | None: - return self.loc_db.get_location_offset(loc) + return self._loc_db.get_location_offset(loc) -def eval_expr(expr: Expr, conc_state: MiasmConcreteState): +def eval_expr(expr: Expr, conc_state: MiasmConcreteState) -> Expr: """Evaluate a symbolic expression with regard to a concrete reference state. @@ -62,8 +64,9 @@ def eval_expr(expr: Expr, conc_state: MiasmConcreteState): register and memory state is resolved. :return: The most simplified and concrete representation of `expr` that - is possibly producible. May be either an `ExprInt` or an - `ExprLoc`. + is producible with the values from `conc_state`. Is guaranteed to + be either an `ExprInt` or an `ExprLoc` *if* `conc_state` only + returns concrete register- and memory values. """ # Most of these implementation are just copy-pasted members of # `SymbolicExecutionEngine`. diff --git a/focaccia/parser.py b/focaccia/parser.py index 8680eed..a5a1014 100644 --- a/focaccia/parser.py +++ b/focaccia/parser.py @@ -7,6 +7,7 @@ from typing import TextIO from .arch import supported_architectures, Arch from .snapshot import ProgramState +from .symbolic import SymbolicTransform class ParseError(Exception): """A parse error.""" @@ -18,6 +19,20 @@ def _get_or_throw(obj: dict, key: str): return val raise ParseError(f'Expected value at key {key}, but found none.') +def parse_transformations(json_stream: TextIO) -> list[SymbolicTransform]: + """Parse symbolic transformations from a text stream.""" + from .symbolic import parse_symbolic_transform + + json_data = json.load(json_stream) + return [parse_symbolic_transform(item) for item in json_data] + +def serialize_transformations(transforms: list[SymbolicTransform], + out_stream: TextIO): + """Serialize symbolic transformations to a text stream.""" + from .symbolic import serialize_symbolic_transform + + json.dump([serialize_symbolic_transform(t) for t in transforms], out_stream) + def parse_snapshots(json_stream: TextIO) -> list[ProgramState]: """Parse snapshots from our JSON format.""" json_data = json.load(json_stream) @@ -27,7 +42,7 @@ def parse_snapshots(json_stream: TextIO) -> list[ProgramState]: for snapshot in _get_or_throw(json_data, 'snapshots'): state = ProgramState(arch) for reg, val in _get_or_throw(snapshot, 'registers').items(): - state.set(reg, val) + state.set_register(reg, val) for mem in _get_or_throw(snapshot, 'memory'): start, end = _get_or_throw(mem, 'range') data = base64.b64decode(_get_or_throw(mem, 'data')) @@ -111,7 +126,7 @@ def _parse_qemu_line(line: str, cur_state: ProgramState): value = value.replace(' ', '') regname = cur_state.arch.to_regname(regname) if regname is not None: - cur_state.set(regname, int(value, 16)) + cur_state.set_register(regname, int(value, 16)) def parse_arancini(stream: TextIO, arch: Arch) -> list[ProgramState]: aliases = { @@ -136,6 +151,6 @@ def parse_arancini(stream: TextIO, arch: Arch) -> list[ProgramState]: regname, value = split regname = arch.to_regname(aliases.get(regname, regname)) if regname is not None: - states[-1].set(regname, int(value, 16)) + states[-1].set_register(regname, int(value, 16)) return states diff --git a/focaccia/snapshot.py b/focaccia/snapshot.py index 0f10dda..6342660 100644 --- a/focaccia/snapshot.py +++ b/focaccia/snapshot.py @@ -1,6 +1,13 @@ from .arch.arch import Arch +class RegisterAccessError(Exception): + """Raised when a register access fails.""" + def __init__(self, regname: str, msg: str): + super().__init__(msg) + self.regname = regname + class MemoryAccessError(Exception): + """Raised when a memory access fails.""" def __init__(self, addr: int, size: int, msg: str): super().__init__(msg) self.mem_addr = addr @@ -81,11 +88,11 @@ class ProgramState: self.regs: dict_t = { reg: None for reg in arch.regnames } self.mem = SparseMemory() - def read(self, reg: str) -> int: + def read_register(self, reg: str) -> int: """Read a register's value. - :raise KeyError: If `reg` is not a register name. - :raise ValueError: If the register has no value. + :raise KeyError: If `reg` is not a register name. + :raise RegisterAccessError: If the register has no value. """ regname = self.arch.to_regname(reg) if regname is None: @@ -94,11 +101,14 @@ class ProgramState: assert(regname in self.regs) regval = self.regs[regname] if regval is None: - raise ValueError(f'Unable to read value of register {reg} (aka.' - f' {regname}): The register contains no value.') + raise RegisterAccessError( + regname, + f'[In ProgramState.read_register]: Unable to read value of' + f' register {reg} (a.k.a. {regname}): The register is not set.' + f' Full state: {self}') return regval - def set(self, reg: str, value: int): + def set_register(self, reg: str, value: int): """Assign a value to a register. :raise KeyError: If `reg` is not a register name. diff --git a/focaccia/symbolic.py b/focaccia/symbolic.py index e132ebd..7ab0d84 100644 --- a/focaccia/symbolic.py +++ b/focaccia/symbolic.py @@ -1,7 +1,6 @@ """Tools and utilities for symbolic execution with Miasm.""" from __future__ import annotations -from typing import Self from miasm.analysis.binary import ContainerELF from miasm.analysis.machine import Machine @@ -18,37 +17,48 @@ from .lldb_target import LLDBConcreteTarget, \ from .miasm_util import MiasmConcreteState, eval_expr from .snapshot import ProgramState -class SymbolicTransform: - def __init__(self, from_addr: int, to_addr: int): - self.addr = from_addr - self.range = (from_addr, to_addr) - - def concat(self, other: Self) -> Self: - """Concatenate another transform to this transform. - - The symbolic transform on which `concat` is called is the transform - that is applied first, meaning: `(a.concat(b))(state) == b(a(state))`. - """ - raise NotImplementedError('concat is abstract.') - - def calc_register_transform(self, conc_state: ProgramState) \ - -> dict[str, int]: - raise NotImplementedError('calc_register_transform is abstract.') +def eval_symbol(symbol: Expr, conc_state: ProgramState) -> int: + """Evaluate a symbol based on a concrete reference state. - def calc_memory_transform(self, conc_state: ProgramState) \ - -> dict[int, bytes]: - raise NotImplementedError('calc_memory_transform is abstract.') + :param conc_state: A concrete state. + :return: The resolved value. - def __repr__(self) -> str: - start, end = self.range - return f'Symbolic state transformation {hex(start)} -> {hex(end)}' + :raise ValueError: If the concrete state does not contain a register value + that is referenced by the symbolic expression. + :raise MemoryAccessError: If the concrete state does not contain memory + that is referenced by the symbolic expression. + """ + class ConcreteStateWrapper(MiasmConcreteState): + """Extend the state resolver with assumptions about the expressions + that may be resolved with `eval_symbol`.""" + def __init__(self, conc_state: ProgramState): + super().__init__(conc_state, LocationDB()) + + def resolve_register(self, regname: str) -> int: + regname = regname.upper() + regname = self.miasm_flag_aliases.get(regname, regname) + return self._state.read_register(regname) + + def resolve_memory(self, addr: int, size: int) -> bytes: + return self._state.read_memory(addr, size) + + def resolve_location(self, _): + raise ValueError(f'[In eval_symbol]: Unable to evaluate symbols' + f' that contain IR location expressions.') + + res = eval_expr(symbol, ConcreteStateWrapper(conc_state)) + assert(isinstance(res, ExprInt)) # Must be either ExprInt or ExprLoc, + # but ExprLocs are disallowed by the + # ConcreteStateWrapper + return int(res) -class MiasmSymbolicTransform(SymbolicTransform): +class SymbolicTransform: + """A symbolic transformation mapping one program state to another.""" def __init__(self, transform: dict[Expr, Expr], arch: Arch, - start_addr: int, - end_addr: int): + from_addr: int, + to_addr: int): """ :param state: The symbolic transformation in the form of a SimState object. @@ -56,100 +66,264 @@ class MiasmSymbolicTransform(SymbolicTransform): represents the modifications to the program state performed by this instruction. """ - super().__init__(start_addr, end_addr) + self.addr = from_addr + """The instruction address of the program state on which the + transformation operates. Equivalent to `self.range[0]`.""" + + self.range = (from_addr, to_addr) + """The range of addresses that the transformation covers. + The transformation `t` maps the program state at instruction + `t.range[0]` to the program state at instruction `t.range[1]`.""" - self.regs_diff: dict[str, Expr] = {} - self.mem_diff: dict[ExprMem, Expr] = {} + self.changed_regs: dict[str, Expr] = {} + """Maps register names to expressions for the register's content. + + Contains only registers that are changed by the transformation. + Register names are already normalized to a respective architecture's + naming conventions.""" + + self.changed_mem: dict[Expr, Expr] = {} + """Maps memory addresses to memory content. + + For a dict tuple `(addr, value)`, `value.size` is the number of *bits* + written to address `addr`. Memory addresses may depend on other + symbolic values, such as register content, and are therefore symbolic + themselves.""" for dst, expr in transform.items(): assert(isinstance(dst, ExprMem) or isinstance(dst, ExprId)) if isinstance(dst, ExprMem): - self.mem_diff[dst] = expr + assert(dst.size == expr.size) + assert(expr.size % 8 == 0) + self.changed_mem[dst.ptr] = expr else: assert(isinstance(dst, ExprId)) regname = arch.to_regname(dst.name) if regname is not None: - self.regs_diff[regname] = expr + self.changed_regs[regname] = expr - self.arch = arch + def concat(self, other: SymbolicTransform) -> SymbolicTransform: + """Concatenate two transformations. - def concat(self, other: MiasmSymbolicTransform) -> Self: - class MiasmSymbolicState(MiasmConcreteState): - """Drop-in replacement for MiasmConcreteState in eval_expr that - returns the current transform's symbolic equations instead of - concrete values. Calling eval_expr with this effectively nests the - transformation into the concatenated transformation. - - We inherit from `MiasmSymbolicTransform` only for the purpose of - correct type checking. - """ - def __init__(self, transform: MiasmSymbolicTransform): - self.transform = transform + The symbolic transform on which `concat` is called is the transform + that is applied first, meaning: `(a.concat(b))(state) == b(a(state))`. - def resolve_register(self, regname: str): - return self.transform.regs_diff.get(regname, None) + Note that if transformation are concatenated that write to the same + memory location when applied to a specific starting state, the + concatenation may not recognize equivalence of syntactically different + symbolic address expressions. In this case, if you calculate all memory + values and store them at their address, the final result will depend on + the random iteration order over the `changed_mem` dict. - def resolve_memory(self, addr: int, size: int): - mem = ExprMem(ExprInt(addr, 64), size) - return self.transform.mem_diff.get(mem, None) + :param other: The transformation to concatenate to `self`. - def resolve_location(self, _): - return None + :return: Returns `self`. `self` is modified in-place. + :raise ValueError: If the two transformations don't span a contiguous + range of instructions. + """ + from typing import Callable + from miasm.expression.expression import ExprLoc, ExprSlice, ExprCond, \ + ExprOp, ExprCompose + from miasm.expression.simplifications import expr_simp_explicit if self.range[1] != other.range[0]: - raise ValueError(f'The concatenated transformations must span a' - f' contiguous range of instructions.') + repr_range = lambda r: f'[{hex(r[0])} -> {hex(r[1])}]' + raise ValueError( + f'Unable to concatenate transformation' + f' {repr_range(self.range)} with {repr_range(other.range)};' + f' the concatenated transformations must span a' + f' contiguous range of instructions.') + + def _eval_exprslice(expr: ExprSlice): + arg = _concat_to_self(expr.arg) + return ExprSlice(arg, expr.start, expr.stop) + + def _eval_exprcond(expr: ExprCond): + cond = _concat_to_self(expr.cond) + src1 = _concat_to_self(expr.src1) + src2 = _concat_to_self(expr.src2) + return ExprCond(cond, src1, src2) + + def _eval_exprop(expr: ExprOp): + args = [_concat_to_self(arg) for arg in expr.args] + return ExprOp(expr.op, *args) + + def _eval_exprcompose(expr: ExprCompose): + args = [_concat_to_self(arg) for arg in expr.args] + return ExprCompose(*args) + + expr_to_visitor: dict[type[Expr], Callable] = { + ExprInt: lambda e: e, + ExprId: lambda e: self.changed_regs.get(e.name, e), + ExprLoc: lambda e: e, + ExprMem: lambda e: ExprMem(_concat_to_self(e.ptr), e.size), + ExprSlice: _eval_exprslice, + ExprCond: _eval_exprcond, + ExprOp: _eval_exprop, + ExprCompose: _eval_exprcompose, + } + + def _concat_to_self(expr: Expr): + visitor = expr_to_visitor[expr.__class__] + return expr_simp_explicit(visitor(expr)) + + new_regs = self.changed_regs.copy() + for reg, expr in other.changed_regs.items(): + new_regs[reg] = _concat_to_self(expr) + + new_mem = self.changed_mem.copy() + for addr, expr in other.changed_mem.items(): + new_addr = _concat_to_self(addr) + new_expr = _concat_to_self(expr) + new_mem[new_addr] = new_expr + + self.changed_regs = new_regs + self.changed_mem = new_mem + self.range = (self.range[0], other.range[1]) - ref_state = MiasmSymbolicState(self) - for reg, expr in other.regs_diff.items(): - if reg not in self.regs_diff: - self.regs_diff[reg] = expr - else: - self.regs_diff[reg] = eval_expr(expr, ref_state) + return self - for dst, expr in other.mem_diff.items(): - dst = eval_expr(dst, ref_state) - if dst not in self.mem_diff: - self.mem_diff[dst] = expr - else: - self.mem_diff[dst] = eval_expr(expr, ref_state) + def get_used_registers(self) -> list[str]: + """Find all registers used by the transformation as input. - self.range = (self.range[0], other.range[1]) + :return: A list of register names. + """ + accessed_regs = set[str]() - return self + class ConcreteStateWrapper(MiasmConcreteState): + def __init__(self): pass + def resolve_register(self, regname: str) -> int | None: + accessed_regs.add(regname) + return None + def resolve_memory(self, addr: int, size: int): assert(False) + def resolve_location(self, _): assert(False) + + state = ConcreteStateWrapper() + for expr in self.changed_regs.values(): + eval_expr(expr, state) + for addr_expr, mem_expr in self.changed_mem.items(): + eval_expr(addr_expr, state) + eval_expr(mem_expr, state) + + return list(accessed_regs) - def calc_register_transform(self, conc_state: ProgramState) \ + def get_used_memory_addresses(self) -> list[ExprMem]: + """Find all memory addresses used by the transformation as input. + + :return: A list of memory access expressions. + """ + from typing import Callable + from miasm.expression.expression import ExprLoc, ExprSlice, ExprCond, \ + ExprOp, ExprCompose + + accessed_mem = set[ExprMem]() + + def _eval(expr: Expr): + def _eval_exprmem(expr: ExprMem): + accessed_mem.add(expr) # <-- this is the only important line! + _eval(expr.ptr) + def _eval_exprcond(expr: ExprCond): + _eval(expr.cond) + _eval(expr.src1) + _eval(expr.src2) + def _eval_exprop(expr: ExprOp): + for arg in expr.args: + _eval(arg) + def _eval_exprcompose(expr: ExprCompose): + for arg in expr.args: + _eval(arg) + + expr_to_visitor: dict[type[Expr], Callable] = { + ExprInt: lambda e: e, + ExprId: lambda e: e, + ExprLoc: lambda e: e, + ExprMem: _eval_exprmem, + ExprSlice: lambda e: _eval(e.arg), + ExprCond: _eval_exprcond, + ExprOp: _eval_exprop, + ExprCompose: _eval_exprcompose, + } + visitor = expr_to_visitor[expr.__class__] + visitor(expr) + + for expr in self.changed_regs.values(): + _eval(expr) + for addr_expr, mem_expr in self.changed_mem.items(): + _eval(addr_expr) + _eval(mem_expr) + + return list(accessed_mem) + + def eval_register_transforms(self, conc_state: ProgramState) \ -> dict[str, int]: - # Construct a dummy location DB. At this point, expressions should - # never contain IR locations. - ref_state = MiasmConcreteState(conc_state, LocationDB()) + """Calculate register transformations when applied to a concrete state. + + :param conc_state: A concrete program state that serves as the input + state on which the transformation operates. + :return: A map from register names to the register values that were + changed by the transformation. + :raise MemoryError: + :raise ValueError: + """ res = {} - for regname, expr in self.regs_diff.items(): - res[regname] = int(eval_expr(expr, ref_state)) + for regname, expr in self.changed_regs.items(): + res[regname] = eval_symbol(expr, conc_state) return res - def calc_memory_transform(self, conc_state: ProgramState) \ + def eval_memory_transforms(self, conc_state: ProgramState) \ -> dict[int, bytes]: - # Construct a dummy location DB. At this point, expressions should - # never contain IR locations. - ref_state = MiasmConcreteState(conc_state, LocationDB()) + """Calculate memory transformations when applied to a concrete state. + + :param conc_state: A concrete program state that serves as the input + state on which the transformation operates. + :return: A map from memory addresses to the bytes that were changed by + the transformation. + :raise MemoryError: + :raise ValueError: + """ res = {} - for addr, expr in self.mem_diff.items(): - addr = int(eval_expr(addr, ref_state)) + for addr, expr in self.changed_mem.items(): + addr = eval_symbol(addr, conc_state) length = int(expr.size / 8) - res[addr] = int(eval_expr(expr, ref_state)).to_bytes(length) + res[addr] = eval_symbol(expr, conc_state).to_bytes(length) return res def __repr__(self) -> str: start, end = self.range res = f'Symbolic state transformation {hex(start)} -> {hex(end)}:\n' - for reg, expr in self.regs_diff.items(): + for reg, expr in self.changed_regs.items(): res += f' {reg:6s} = {expr}\n' - for mem, expr in self.mem_diff.items(): - res += f' {mem} = {expr}\n' - return res[:-2] # Remove trailing newline + for addr, expr in self.changed_mem.items(): + res += f' {ExprMem(addr, expr.size)} = {expr}\n' + return res[:-1] # Remove trailing newline + +def parse_symbolic_transform(string: str) -> SymbolicTransform: + """Parse a symbolic transformation from a string. + :raise KeyError: if a parse error occurs. + """ + import json + from miasm.expression.parser import str_to_expr as parse + + data = json.loads(string) + + # We can use a None-arch because it's only used when the dict is not empty + t = SymbolicTransform({}, None, int(data['from_addr']), int(data['to_addr'])) + t.changed_regs = { name: parse(val) for name, val in data['regs'].items() } + t.changed_mem = { parse(addr): parse(val) for addr, val in data['mem'].items() } + + return t + +def serialize_symbolic_transform(t: SymbolicTransform) -> str: + """Serialize a symbolic transformation.""" + import json + return json.dumps({ + 'from_addr': t.range[0], + 'to_addr': t.range[1], + 'regs': { name: repr(expr) for name, expr in t.changed_regs.items() }, + 'mem': { repr(addr): repr(val) for addr, val in t.changed_mem.items() }, + }) def _step_until(target: LLDBConcreteTarget, addr: int) -> list[int]: """Step a concrete target to a specific instruction. @@ -205,7 +379,7 @@ class DisassemblyContext: class DisassemblyError(Exception): def __init__(self, - partial_trace: list[tuple[int, MiasmSymbolicTransform]], + partial_trace: list[tuple[int, SymbolicTransform]], faulty_pc: int, err_msg: str): self.partial_trace = partial_trace @@ -277,7 +451,7 @@ def _run_block(pc: int, conc_state: MiasmConcreteState, ctx: DisassemblyContext) # instructions are translated to multiple IR instructions. pass -class LLDBConcreteState: +class _LLDBConcreteState: """A back-end replacement for the `ProgramState` object from which `MiasmConcreteState` reads its values. This reads values directly from an LLDB target instead. This saves us the trouble of recording a full program @@ -287,7 +461,7 @@ class LLDBConcreteState: self._target = target self._arch = arch - def read(self, reg: str) -> int | None: + def read_register(self, reg: str) -> int | None: from focaccia.arch import x86 regname = self._arch.to_regname(reg) @@ -339,7 +513,7 @@ def collect_symbolic_trace(binary: str, target.set_breakpoint(pc) target.run() target.remove_breakpoint(pc) - conc_state = LLDBConcreteState(target, arch) + conc_state = _LLDBConcreteState(target, arch) symb_trace = [] # The resulting list of symbolic transforms per instruction @@ -402,8 +576,8 @@ def collect_symbolic_trace(binary: str, res = [] for (start, diff), (end, _) in zip(symb_trace[:-1], symb_trace[1:]): - res.append(MiasmSymbolicTransform(diff, arch, start, end)) + res.append(SymbolicTransform(diff, arch, start, end)) start, diff = symb_trace[-1] - res.append(MiasmSymbolicTransform(diff, arch, start, start)) + res.append(SymbolicTransform(diff, arch, start, start)) return res diff --git a/focaccia/utils.py b/focaccia/utils.py index 0c6f292..7173169 100644 --- a/focaccia/utils.py +++ b/focaccia/utils.py @@ -1,7 +1,43 @@ import sys import shutil +from .compare import ErrorSeverity + def print_separator(separator: str = '-', stream=sys.stdout, count: int = 80): maxtermsize = count termsize = shutil.get_terminal_size((80, 20)).columns print(separator * min(termsize, maxtermsize), file=stream) + +def print_result(result, min_severity: ErrorSeverity): + """Print a comparison result.""" + shown = 0 + suppressed = 0 + + for res in result: + # Filter errors by severity + errs = [e for e in res['errors'] if e.severity >= min_severity] + suppressed += len(res['errors']) - len(errs) + shown += len(errs) + + if errs: + pc = res['pc'] + print_separator() + print(f'For PC={hex(pc)}') + print_separator() + + # Print all non-suppressed errors + for n, err in enumerate(errs, start=1): + print(f' {n:2}. {err}') + + if errs: + print() + print(f'Expected transformation: {res["ref"]}') + print(f'Actual difference: {res["txl"]}') + + print() + print('#' * 60) + print(f'Found {shown} errors.') + print(f'Suppressed {suppressed} low-priority errors' + f' (showing {min_severity} and higher).') + print('#' * 60) + print() diff --git a/tools/capture_transforms.py b/tools/capture_transforms.py new file mode 100644 index 0000000..de35d86 --- /dev/null +++ b/tools/capture_transforms.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python + +import argparse +import logging + +from focaccia import parser +from focaccia.symbolic import collect_symbolic_trace + +def main(): + prog = argparse.ArgumentParser() + prog.description = 'Trace an executable concolically to capture symbolic' \ + ' transformations among instructions.' + prog.add_argument('binary', help='The program to analyse.') + prog.add_argument('args', action='store', nargs=argparse.REMAINDER, + help='Arguments to the program.') + prog.add_argument('-o', '--output', + default='trace.out', + help='Name of output file. (default: trace.out)') + args = prog.parse_args() + + logging.disable(logging.CRITICAL) + trace = collect_symbolic_trace(args.binary, args.args, None) + with open(args.output, 'w') as file: + parser.serialize_transformations(trace, file) + +if __name__ == "__main__": + main() |