diff options
| -rw-r--r-- | .gitignore | 1 | ||||
| -rw-r--r-- | .gitmodules | 3 | ||||
| -rw-r--r-- | README.md | 52 | ||||
| -rw-r--r-- | arancini.py | 14 | ||||
| -rw-r--r-- | arch/__init__.py | 8 | ||||
| -rw-r--r-- | arch/x86.py | 1 | ||||
| -rw-r--r-- | compare.py | 78 | ||||
| m--------- | cpuid | 0 | ||||
| -rw-r--r-- | gen_trace.py | 63 | ||||
| -rw-r--r-- | interpreter.py | 125 | ||||
| -rw-r--r-- | lldb_target.py | 72 | ||||
| -rwxr-xr-x | main.py | 45 | ||||
| -rw-r--r-- | miasm_util.py | 51 | ||||
| -rw-r--r-- | requirements.txt | 1 | ||||
| -rw-r--r-- | run.py | 105 | ||||
| -rw-r--r-- | snapshot.py | 10 | ||||
| -rw-r--r-- | symbolic.py | 96 | ||||
| -rw-r--r-- | trace_symbols.py | 167 |
18 files changed, 278 insertions, 614 deletions
diff --git a/.gitignore b/.gitignore index 94631a0..39b5bdb 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ build* *.md *.out *.txt +!requirements.txt *.bin *.dot build*/ diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..a6d7f14 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "cpuid"] + path = cpuid + url = https://github.com/flababah/cpuid.py.git diff --git a/README.md b/README.md index 65fe4ce..fcdbe90 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,12 @@ -# DBT Testing +# Focaccia This repository contains initial code for comprehensive testing of binary translators. ## Requirements -We require at least LLDB version 17 for `fs_base`/`gs_base` register support. +For Python dependencies, see the `requirements.txt`. We also require at least LLDB version 17 for `fs_base`/`gs_base` +register support. I had to compile LLDB myself; these are the steps I had to take (you also need swig version >= 4): @@ -25,50 +26,33 @@ It will take a while to compile. The following files belong to a rough framework for the snapshot comparison engine: - - `main.py`: Entry point to the tool. Handling of command line arguments, pre-processing of input -logs, etc. + - `main.py`: Entry point to the tool. Handling of command line arguments, pre-processing of input logs, etc. - - `snapshot.py`: Internal structures used to work with snapshots. Contains the previous -`ContextBlock` class, which has been renamed to `ProgramState` to make its purpose as a snapshot of -the program state clearer. + - `snapshot.py`: Structures used to work with snapshots. The `ProgramState` class is our primary representation of +program snapshots. - `compare.py`: The central algorithms that work on snapshots. - - `run.py`: Tools to execute native programs and capture their state via an external debugger. + - `arancini.py`: Functionality specific to working with arancini. Parsing of arancini's logs into our snapshot +structures. - - `arancini.py`: Functionality specific to working with arancini. Parsing of arancini's logs into our -snapshot structures. + - `arch/`: Abstractions over different processor architectures. Will be used to integrate support for more +architectures later. Currently, we only have X86. - - `arch/`: Abstractions over different processor architectures. Will be used to integrate support for -more architectures later. Currently, we only have X86. - -## Symbolic execution +## Concolic execution The following files belong to a prototype of a data-dependency generator based on symbolic execution: - - `symbolic.py`: Algorithms and data structures to compute and manipulate symbolic program -transformations. - - - `gen_trace.py`: An invokable tool that generates an instruction trace for an executable's native -execution. Is imported into `trace_symbols.py`, which uses the core function that records a trace. - - - `trace_symbols.py`: A simple proof of concept for symbolic data-dependency tracking. Takes an -executable as an argument and does the following: - - 1. Executes the program natively (starting at `main`) and records a trace of every instruction -executed, stopping when exiting `main`. - - 2. Tries to follow this trace of instructions concolically (keeps a concrete program state from -a native execution in parallel to a symbolic program state), recording after each instruction the -changes it has made to the program state before that instruction. + - `symbolic.py`: Algorithms and data structures to compute and manipulate symbolic program transformations. This +handles the symbolic part of "concolic" execution. - 3. Writes the program state at each instruction to log files; writes the concrete state of the -real execution to 'concrete.log' and the symbolic difference to 'symbolic.log'. + - `lldb_target.py`: Tools for executing a program concretely and tracking its execution using +[LLDB](https://lldb.llvm.org/). This handles the concrete part of "concolic" execution. - - `interpreter.py`: Contains an algorithm that evaluates a symbolic expression to a concrete value, -using a reference state as input. + - `miasm_util.py`: Tools to evaluate Miasm's symbolic expressions based on a concrete state. Ties the symbolic and +concrete parts together into "concolic" execution. ## Helpers - - `lldb_target.py`: Implements angr's `ConcreteTarget` interface for [LLDB](https://lldb.llvm.org/). + - `miasm_test.py`: A test script that traces a program concolically. diff --git a/arancini.py b/arancini.py index adbe6c8..71f45d7 100644 --- a/arancini.py +++ b/arancini.py @@ -45,7 +45,7 @@ def parse(lines: list[str], arch: Arch) -> list[ProgramState]: snapshots = [] for line in lines: if 'Backwards' in line and len(snapshots) > 0: - snapshots[-1].set_backwards() + # snapshots[-1].set_backwards() continue match = try_parse_line(line) @@ -85,10 +85,10 @@ def get_labels(): 'R13': ('R13', split_first), 'R14': ('R14', split_first), 'R15': ('R15', split_first), - 'flag ZF': ('flag ZF', split_second), - 'flag CF': ('flag CF', split_second), - 'flag OF': ('flag OF', split_second), - 'flag SF': ('flag SF', split_second), - 'flag PF': ('flag PF', split_second), - 'flag DF': ('flag DF', split_second)} + 'flag ZF': ('ZF', split_second), + 'flag CF': ('CF', split_second), + 'flag OF': ('OF', split_second), + 'flag SF': ('SF', split_second), + 'flag PF': ('PF', split_second), + 'flag DF': ('DF', split_second)} return labels diff --git a/arch/__init__.py b/arch/__init__.py index 4943749..2926d20 100644 --- a/arch/__init__.py +++ b/arch/__init__.py @@ -1,7 +1,11 @@ from .arch import Arch from . import x86 -"""A dictionary containing all supported architectures at their names.""" +"""A dictionary containing all supported architectures at their names. + +The arch names (keys) should be compatible with the string returned from +`platform.machine()`. +""" supported_architectures: dict[str, Arch] = { - "X86": x86.ArchX86(), + "x86_64": x86.ArchX86(), } diff --git a/arch/x86.py b/arch/x86.py index 88e6d1a..776291d 100644 --- a/arch/x86.py +++ b/arch/x86.py @@ -34,6 +34,7 @@ regnames = [ # A dictionary mapping aliases to standard register names. regname_aliases = { 'PC': 'RIP', + 'NF': 'SF', # negative flag == sign flag in Miasm? } def decompose_rflags(rflags: int) -> dict[str, int]: diff --git a/compare.py b/compare.py index 8a25d8a..0f144bf 100644 --- a/compare.py +++ b/compare.py @@ -1,8 +1,7 @@ -from snapshot import ProgramState, SnapshotSymbolResolver +from snapshot import ProgramState from symbolic import SymbolicTransform -from utils import print_separator -def calc_transformation(previous: ProgramState, current: ProgramState): +def _calc_transformation(previous: ProgramState, current: ProgramState): """Calculate the difference between two context blocks. :return: A context block that contains in its registers the difference @@ -13,14 +12,18 @@ def calc_transformation(previous: ProgramState, current: ProgramState): arch = previous.arch transformation = ProgramState(arch) for reg in arch.regnames: - prev_val, cur_val = previous.regs[reg], current.regs[reg] - if prev_val is not None and cur_val is not None: - transformation.regs[reg] = cur_val - prev_val + try: + prev_val, cur_val = previous.read(reg), current.read(reg) + if prev_val is not None and cur_val is not None: + transformation.set(reg, cur_val - prev_val) + except ValueError: + # Register is not set in either state + pass return transformation -def find_errors(txl_state: ProgramState, prev_txl_state: ProgramState, - truth_state: ProgramState, prev_truth_state: ProgramState) \ +def _find_errors(txl_state: ProgramState, prev_txl_state: ProgramState, + truth_state: ProgramState, prev_truth_state: ProgramState) \ -> list[dict]: """Find possible errors between a reference and a tested state. @@ -38,11 +41,16 @@ def find_errors(txl_state: ProgramState, prev_txl_state: ProgramState, arch = txl_state.arch errors = [] - transform_truth = calc_transformation(prev_truth_state, truth_state) - transform_txl = calc_transformation(prev_txl_state, txl_state) + transform_truth = _calc_transformation(prev_truth_state, truth_state) + transform_txl = _calc_transformation(prev_txl_state, txl_state) for reg in arch.regnames: - diff_txl = transform_txl.regs[reg] - diff_truth = transform_truth.regs[reg] + try: + diff_txl = transform_txl.read(reg) + diff_truth = transform_truth.read(reg) + except ValueError: + # Register is not set in either state + continue + if diff_txl == diff_truth: # The register contains a value that is expected # by the transformation. @@ -80,7 +88,7 @@ def compare_simple(test_states: list[ProgramState], # No errors in initial snapshot because we can't perform difference # calculations on it result = [{ - 'pc': test_states[0].regs[PC_REGNAME], + 'pc': test_states[0].read(PC_REGNAME), 'txl': test_states[0], 'ref': truth_states[0], 'errors': [] }] @@ -91,21 +99,19 @@ def compare_simple(test_states: list[ProgramState], for txl, truth in it_cur: prev_txl, prev_truth = next(it_prev) - pc_txl = txl.regs[PC_REGNAME] - pc_truth = truth.regs[PC_REGNAME] + pc_txl = txl.read(PC_REGNAME) + pc_truth = truth.read(PC_REGNAME) # The program counter should always be set on a snapshot assert(pc_truth is not None) assert(pc_txl is not None) if pc_txl != pc_truth: - print(f'Unmatched program counter {txl.as_repr(PC_REGNAME)}' + print(f'Unmatched program counter {hex(txl.read(PC_REGNAME))}' f' in translated code!') continue - else: - txl.matched = True - errors = find_errors(txl, prev_txl, truth, prev_truth) + errors = _find_errors(txl, prev_txl, truth, prev_truth) result.append({ 'pc': pc_txl, 'txl': txl, 'ref': truth, @@ -113,20 +119,19 @@ def compare_simple(test_states: list[ProgramState], }) # TODO: Why do we skip backward branches? - if txl.has_backwards: - print(f' -- Encountered backward branch. Don\'t skip.') + #if txl.has_backwards: + # print(f' -- Encountered backward branch. Don\'t skip.') return result -def find_errors_symbolic(txl_from: ProgramState, - txl_to: ProgramState, - transform_truth: SymbolicTransform) \ +def _find_errors_symbolic(txl_from: ProgramState, + txl_to: ProgramState, + transform_truth: SymbolicTransform) \ -> list[dict]: arch = txl_from.arch - resolver = SnapshotSymbolResolver(txl_from) - assert(txl_from.read('PC') == transform_truth.start_addr) - assert(txl_to.read('PC') == transform_truth.end_addr) + assert(txl_from.read('PC') == transform_truth.range[0]) + assert(txl_to.read('PC') == transform_truth.range[1]) errors = [] for reg in arch.regnames: @@ -137,14 +142,14 @@ def find_errors_symbolic(txl_from: ProgramState, txl_val = txl_to.read(reg) try: - truth = transform_truth.eval_register_transform(reg.lower(), resolver) + truth = transform_truth.calc_register_transform(txl_from) print(f'Evaluated symbolic formula to {hex(txl_val)} vs. txl {hex(txl_val)}') if txl_val != truth: errors.append({ 'reg': reg, 'expected': truth, 'actual': txl_val, - 'equation': transform_truth.state.regs.get(reg), + 'equation': transform_truth.regs_diff[reg], }) except AttributeError: print(f'Register {reg} does not exist.') @@ -157,7 +162,7 @@ def compare_symbolic(test_states: list[ProgramState], PC_REGNAME = 'PC' result = [{ - 'pc': test_states[0].regs[PC_REGNAME], + 'pc': test_states[0].read(PC_REGNAME), 'txl': test_states[0], 'ref': transforms[0], 'errors': [] @@ -171,21 +176,22 @@ def compare_symbolic(test_states: list[ProgramState], # The program counter should always be set on a snapshot assert(pc_cur is not None and pc_next is not None) - if pc_cur != transform.start_addr: + start_addr, end_addr = transform.range + if pc_cur != start_addr: print(f'Program counter {hex(pc_cur)} in translated code has no' f' corresponding reference state! Skipping.' - f' (reference: {hex(transform.start_addr)})') + f' (reference: {hex(start_addr)})') continue - if pc_next != transform.end_addr: + if pc_next != end_addr: print(f'Tested state transformation is {hex(pc_cur)} ->' f' {hex(pc_next)}, but reference transform is' - f' {hex(transform.start_addr)} -> {hex(transform.end_addr)}!' + f' {hex(start_addr)} -> {hex(end_addr)}!' f' Skipping.') - errors = find_errors_symbolic(cur_state, next_state, transform) + errors = _find_errors_symbolic(cur_state, next_state, transform) result.append({ 'pc': pc_cur, - 'txl': calc_transformation(cur_state, next_state), + 'txl': _calc_transformation(cur_state, next_state), 'ref': transform, 'errors': errors }) diff --git a/cpuid b/cpuid new file mode 160000 +Subproject 335f97a08af46dda14a09f2e825dddbbe7e8177 diff --git a/gen_trace.py b/gen_trace.py deleted file mode 100644 index ec5cb86..0000000 --- a/gen_trace.py +++ /dev/null @@ -1,63 +0,0 @@ -import argparse -import lldb -import lldb_target - -def record_trace(binary: str, - args: list[str] = [], - func_name: str | None = 'main') -> list[int]: - """ - :param binary: The binary file to execute. - :param args: Arguments to the program. Should *not* include the - executable's location as the usual first argument. - :param func_name: Only record trace of a specific function. - """ - # Set up LLDB target - target = lldb_target.LLDBConcreteTarget(binary, args) - - # Skip to first instruction in `main` - if func_name is not None: - result = lldb.SBCommandReturnObject() - break_at_func = f'b -b {func_name} -s {target.module.GetFileSpec().GetFilename()}' - target.interpreter.HandleCommand(break_at_func, result) - target.run() - - # Run until main function is exited - trace = [] - while not target.is_exited(): - thread = target.process.GetThreadAtIndex(0) - - # Break if the traced function is exited - if func_name is not None: - func_names = [thread.GetFrameAtIndex(i).GetFunctionName() \ - for i in range(0, thread.GetNumFrames())] - if func_name not in func_names: - break - trace.append(target.read_register('pc')) - thread.StepInstruction(False) - - return trace - -def parse_args(): - prog = argparse.ArgumentParser() - prog.add_argument('binary', - help='The executable to trace.') - prog.add_argument('-o', '--output', - default='breakpoints', - type=str, - help='File to which the recorded trace is written.') - prog.add_argument('--args', - default=[], - nargs='+', - help='Arguments to the executable.') - return prog.parse_args() - -def main(): - args = parse_args() - trace = record_trace(args.binary, args.args) - with open(args.output, 'w') as file: - for addr in trace: - print(hex(addr), file=file) - print(f'Generated a trace of {len(trace)} instructions.') - -if __name__ == '__main__': - main() diff --git a/interpreter.py b/interpreter.py deleted file mode 100644 index 8e876f5..0000000 --- a/interpreter.py +++ /dev/null @@ -1,125 +0,0 @@ -"""Interpreter for claripy ASTs""" - -from inspect import signature -from logging import debug - -import claripy as cp - -class SymbolResolver: - def __init__(self): - pass - - def resolve(self, symbol_name: str) -> cp.ast.Base: - raise NotImplementedError() - -class SymbolResolveError(Exception): - def __init__(self, symbol, reason: str = ""): - super().__init__(f'Unable to resolve symbol name \"{symbol}\" to a' - ' concrete value' - + f': {reason}' if len(reason) > 0 else '.') - -def eval(resolver: SymbolResolver, expr) -> int: - """Evaluate a claripy expression to a concrete value. - - :param resolver: A `SymbolResolver` implementation that can resolve symbol - names to concrete values. - :param expr: The claripy AST to evaluate. Should be a subclass of - `claripy.ast.Base`. - - :return: A concrete value if the expression was resolved successfully. - If `expr` is not a claripy AST, `expr` is returned immediately. - :raise NotImplementedError: - :raise SymbolResolveError: If `resolver` is not able to resolve a symbol. - """ - if not issubclass(type(expr), cp.ast.Base): - return expr - - if expr.depth == 1: - if expr.symbolic: - name = expr._encoded_name.decode() - val = resolver.resolve(name) - if val is None: - raise SymbolResolveError(name) - return eval(resolver, val) - else: # if expr.concrete - assert(expr.concrete) - return expr.v - - # Expression is a non-trivial AST, i.e. a function - return _eval_op(resolver, expr.op, *expr.args) - -def _eval_op(resolver: SymbolResolver, op, *args) -> int: - """Evaluate a claripy operator expression. - - :param *args: Arguments to the function `op`. These are NOT evaluated yet! - """ - assert(type(op) is str) - - def concat(*vals): - res = 0 - for val in vals: - assert(type(val) is cp.ast.BV) - res = res << val.length - res = res | eval(resolver, val) - return res - - # Handle claripy's operators - if op == 'Concat': - res = concat(*args) - debug(f'Concatenating {args} to {hex(res)}') - return res - if op == 'Extract': - assert(len(args) == 3) - start, end, val = (eval(resolver, arg) for arg in args) - size = start - end + 1 - res = (val >> end) & ((1 << size) - 1) - debug(f'Extracing range [{start}, {end}] from {hex(val)}: {hex(res)}') - return res - if op == 'If': - assert(len(args) == 3) - cond, iftrue, iffalse = (eval(resolver, arg) for arg in args) - debug(f'Evaluated branch condition {args[0]} to {cond}') - return iftrue if bool(cond) else iffalse - if op == 'Reverse': - assert(len(args) == 1) - return concat(*reversed(args[0].chop(8))) - - # `op` is not one of claripy's special operators, so treat it as the name - # of a python operator function (because that is how claripy names its OR, - # EQ, etc.) - - # Convert some of the non-python names to magic names - # NOTE: We use python's signed comparison operators for unsigned - # comparisons. I'm not sure that this is legal. - if op in ['SGE', 'SGT', 'SLE', 'SLT', 'UGE', 'UGT', 'ULE', 'ULT']: - op = '__' + op[1:].lower() + '__' - - if op in ['And', 'Or']: - op = '__' + op.lower() + '__' - - resolved_args = [eval(resolver, arg) for arg in args] - try: - func = getattr(int, op) - except AttributeError: - raise NotImplementedError(op) - - # Sometimes claripy doesn't build its AST in an arity-respecting way if - # adjacent operations are associative. For example, it might pass five - # arguments to an XOR function instead of nesting the AST deeper. - # - # That's why we have to check with the python function's signature for its - # number of arguments and manually apply parentheses. - sig = signature(func) - assert(len(args) >= len(sig.parameters)) - - debug(f'Trying to evaluate function {func} with arguments {resolved_args}') - if len(sig.parameters) == len(args): - return func(*resolved_args) - else: - # Fold parameters from left by successively applying `op` to a - # subset of them - return _eval_op(resolver, - op, - func(*resolved_args[0:len(sig.parameters)]), - *resolved_args[len(sig.parameters):] - ) diff --git a/lldb_target.py b/lldb_target.py index e016005..f587b37 100644 --- a/lldb_target.py +++ b/lldb_target.py @@ -1,14 +1,32 @@ import lldb -from angr.errors import SimConcreteMemoryError, \ - SimConcreteRegisterError -from angr_targets.concrete import ConcreteTarget -from angr_targets.memory_map import MemoryMap - from arch import x86 from snapshot import ProgramState -class LLDBConcreteTarget(ConcreteTarget): +class MemoryMap: + """Description of a range of mapped memory. + + Inspired by https://github.com/angr/angr-targets/blob/master/angr_targets/memory_map.py, + meaning we initially used angr and I wanted to keep the interface when we + switched to a different tool. + """ + def __init__(self, start_address, end_address, name, perms): + self.start_address = start_address + self.end_address = end_address + self.name = name + self.perms = perms + + def __str__(self): + return f'MemoryMap[0x{self.start_address:x}, 0x{self.end_address:x}]' \ + f': {self.name}' + +class ConcreteRegisterError(Exception): + pass + +class ConcreteMemoryError(Exception): + pass + +class LLDBConcreteTarget: def __init__(self, executable: str, argv: list[str] = []): """Construct an LLDB concrete target. Stop at entry. @@ -35,36 +53,25 @@ class LLDBConcreteTarget(ConcreteTarget): raise RuntimeError(f'[In LLDBConcreteTarget.__init__]: Failed to' f' launch process.') - def set_breakpoint(self, addr, **kwargs): + def set_breakpoint(self, addr): command = f'b -a {addr} -s {self.module.GetFileSpec().GetFilename()}' result = lldb.SBCommandReturnObject() self.interpreter.HandleCommand(command, result) - def remove_breakpoint(self, addr, **kwargs): + def remove_breakpoint(self, addr): command = f'breakpoint delete {addr}' result = lldb.SBCommandReturnObject() self.interpreter.HandleCommand(command, result) - def is_running(self): - return self.process.GetState() == lldb.eStateRunning - def is_exited(self): - """Not part of the angr interface, but much more useful than - `is_running`. + """Signals whether the concrete process has exited. :return: True if the process has exited. False otherwise. """ return self.process.GetState() == lldb.eStateExited - def wait_for_running(self): - while self.process.GetState() != lldb.eStateRunning: - pass - - def wait_for_halt(self): - while self.process.GetState() != lldb.eStateStopped: - pass - def run(self): + """Continue execution of the concrete process.""" state = self.process.GetState() if state == lldb.eStateExited: raise RuntimeError(f'Tried to resume process execution, but the' @@ -73,16 +80,10 @@ class LLDBConcreteTarget(ConcreteTarget): self.process.Continue() def step(self): + """Step forward by a single instruction.""" thread: lldb.SBThread = self.process.GetThreadAtIndex(0) thread.StepInstruction(False) - def stop(self): - self.process.Stop() - - def exit(self): - self.debugger.Terminate() - print(f'Program exited with status {self.process.GetState()}') - def _get_register(self, regname: str) -> lldb.SBValue: """Find a register by name. @@ -92,7 +93,7 @@ class LLDBConcreteTarget(ConcreteTarget): frame = self.process.GetThreadAtIndex(0).GetFrameAtIndex(0) reg = frame.FindRegister(regname) if reg is None: - raise SimConcreteRegisterError( + raise ConcreteRegisterError( f'[In LLDBConcreteTarget._get_register]: Register {regname}' f' not found.') return reg @@ -101,7 +102,7 @@ class LLDBConcreteTarget(ConcreteTarget): reg = self._get_register(regname) val = reg.GetValue() if val is None: - raise SimConcreteRegisterError( + raise ConcreteRegisterError( f'[In LLDBConcreteTarget.read_register]: Register has an' f' invalid value of {val}.') @@ -112,7 +113,7 @@ class LLDBConcreteTarget(ConcreteTarget): error = lldb.SBError() reg.SetValueFromCString(hex(value), error) if not error.success: - raise SimConcreteRegisterError( + raise ConcreteRegisterError( f'[In LLDBConcreteTarget.write_register]: Unable to set' f' {regname} to value {hex(value)}!') @@ -120,7 +121,7 @@ class LLDBConcreteTarget(ConcreteTarget): err = lldb.SBError() content = self.process.ReadMemory(addr, size, err) if not err.success: - raise SimConcreteMemoryError(f'Error when reading {size} bytes at' + raise ConcreteMemoryError(f'Error when reading {size} bytes at' f' address {hex(addr)}: {err}') return content @@ -128,7 +129,7 @@ class LLDBConcreteTarget(ConcreteTarget): err = lldb.SBError() res = self.process.WriteMemory(addr, value, err) if not err.success or res != len(value): - raise SimConcreteMemoryError(f'Error when writing to address' + raise ConcreteMemoryError(f'Error when writing to address' f' {hex(addr)}: {err}') def get_mappings(self) -> list[MemoryMap]: @@ -146,7 +147,6 @@ class LLDBConcreteTarget(ConcreteTarget): mmap.append(MemoryMap(region.GetRegionBase(), region.GetRegionEnd(), - 0, # offset? name if name is not None else '<none>', perms)) return mmap @@ -167,7 +167,7 @@ def record_snapshot(target: LLDBConcreteTarget) -> ProgramState: state.set(regname, conc_val) except KeyError: pass - except SimConcreteRegisterError: + except ConcreteRegisterError: if regname in rflags: state.set(regname, rflags[regname]) @@ -178,7 +178,7 @@ def record_snapshot(target: LLDBConcreteTarget) -> ProgramState: try: data = target.read_memory(mapping.start_address, size) state.write_memory(mapping.start_address, data) - except SimConcreteMemoryError: + except ConcreteMemoryError: # Unable to read memory from mapping pass diff --git a/main.py b/main.py index b0aeb36..fabb05b 100755 --- a/main.py +++ b/main.py @@ -1,15 +1,41 @@ #! /bin/python3 import argparse +import platform +from typing import Iterable import arancini from arch import x86 from compare import compare_simple, compare_symbolic -from gen_trace import record_trace -from run import run_native_execution +from lldb_target import LLDBConcreteTarget, record_snapshot from symbolic import collect_symbolic_trace from utils import check_version, print_separator +def run_native_execution(oracle_program: str, breakpoints: Iterable[int]): + """Gather snapshots from a native execution via an external debugger. + + :param oracle_program: Program to execute. + :param breakpoints: List of addresses at which to break and record the + program's state. + + :return: A list of snapshots gathered from the execution. + """ + assert(platform.machine() == "x86_64") + + target = LLDBConcreteTarget(oracle_program) + + # Set breakpoints + for address in breakpoints: + target.set_breakpoint(address) + + # Execute the native program + snapshots = [] + while not target.is_exited(): + snapshots.append(record_snapshot(target)) + target.run() + + return snapshots + def parse_inputs(txl_path, ref_path, program): # Our architecture arch = x86.ArchX86() @@ -90,8 +116,15 @@ def main(): if args.symbolic: assert(program is not None) - full_trace = record_trace(program, args=[]) - transforms = collect_symbolic_trace(program, full_trace) + transforms = collect_symbolic_trace(program, [program]) + + new = transforms[0] \ + .concat(transforms[1]) \ + .concat(transforms[2]) \ + .concat(transforms[3]) \ + .concat(transforms[4]) + print(f'New transform: {new}') + exit(0) # TODO: Transform the traces so that the states match result = compare_symbolic(txl, transforms) @@ -114,8 +147,8 @@ def main(): print(f'Content of register {reg} is possibly false.' f' Expected difference: {err["expected"]}, actual difference' f' in the translation: {err["actual"]}.\n' - f' (txl) {reg}: {hex(txl.regs[reg])}\n' - f' (ref) {reg}: {hex(ref.regs[reg])}') + f' (txl) {reg}: {hex(txl.read(reg))}\n' + f' (ref) {reg}: {hex(ref.read(reg))}') print() print('#' * 60) diff --git a/miasm_util.py b/miasm_util.py index 3ceebea..0d3ab3d 100644 --- a/miasm_util.py +++ b/miasm_util.py @@ -41,13 +41,13 @@ class MiasmConcreteState: self.state = state self.loc_db = loc_db - def resolve_register(self, regname: str) -> int: + def resolve_register(self, regname: str) -> int | None: regname = regname.upper() if regname in self.miasm_flag_aliases: regname = self.miasm_flag_aliases[regname] return self.state.read(regname) - def resolve_memory(self, addr: int, size: int) -> bytes: + def resolve_memory(self, addr: int, size: int) -> bytes | None: return self.state.read_memory(addr, size) def resolve_location(self, loc: LocKey) -> int | None: @@ -95,14 +95,18 @@ def _eval_exprint(expr: ExprInt, _): def _eval_exprid(expr: ExprId, state: MiasmConcreteState): """Evaluate an ExprId using the current state""" val = state.resolve_register(expr.name) - return ExprInt(val, expr.size) + if val is None: + return expr + if isinstance(val, int): + return ExprInt(val, expr.size) + return val def _eval_exprloc(expr: ExprLoc, state: MiasmConcreteState): """Evaluate an ExprLoc using the current state""" offset = state.resolve_location(expr.loc_key) - if offset is not None: - return ExprInt(offset, expr.size) - return expr + if offset is None: + return expr + return ExprInt(offset, expr.size) def _eval_exprmem(expr: ExprMem, state: MiasmConcreteState): """Evaluate an ExprMem using the current state. @@ -116,10 +120,15 @@ def _eval_exprmem(expr: ExprMem, state: MiasmConcreteState): assert(expr.size % 8 == 0) addr = eval_expr(expr.ptr, state) - ret = state.resolve_memory(int(addr), int(expr.size / 8)) - assert(len(ret) * 8 == expr.size) - ival = ExprInt(int.from_bytes(ret, byteorder='little'), expr.size) - return ExprSlice(ival, 0, len(ret) * 8) + if not addr.is_int(): + return expr + + mem = state.resolve_memory(int(addr), int(expr.size / 8)) + if mem is None: + return expr + + assert(len(mem) * 8 == expr.size) + return ExprInt(int.from_bytes(mem, byteorder='little'), expr.size) def _eval_exprcond(expr, state: MiasmConcreteState): """Evaluate an ExprCond using the current state""" @@ -133,12 +142,34 @@ def _eval_exprslice(expr, state: MiasmConcreteState): arg = eval_expr(expr.arg, state) return ExprSlice(arg, expr.start, expr.stop) +def _eval_cpuid(rax: ExprInt, out_reg: ExprInt): + """Evaluate the `x86_cpuid` operator by performing a real invocation of + the CPUID instruction. + + :param rax: The current value of RAX. Must be concrete. + :param out_reg: An index in `[0, 4)` signaling which register's value + shall be returned. Must be concrete. + """ + from cpuid import cpuid + + regs = cpuid.CPUID()(int(rax)) + + if int(out_reg) >= len(regs): + raise ValueError(f'Output register may not be {out_reg}.') + return ExprInt(regs[int(out_reg)], out_reg.size) + def _eval_exprop(expr, state: MiasmConcreteState): """Evaluate an ExprOp using the current state""" args = [] for oarg in expr.args: arg = eval_expr(oarg, state) args.append(arg) + + if expr.op == 'x86_cpuid': + # Can't do this in an expression simplifier plugin because the + # arguments must be concrete. + assert(len(expr.args) == 2) + return _eval_cpuid(args[0], args[1]) return ExprOp(expr.op, *args) def _eval_exprcompose(expr, state: MiasmConcreteState): diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..220ba8b --- /dev/null +++ b/requirements.txt @@ -0,0 +1 @@ +miasm diff --git a/run.py b/run.py deleted file mode 100644 index 768a73d..0000000 --- a/run.py +++ /dev/null @@ -1,105 +0,0 @@ -"""Functionality to execute native programs and collect snapshots via lldb.""" - -import platform -import sys -import lldb -from typing import Callable - -# TODO: The debugger callback is currently specific to a single architecture. -# We should make it generic. -from arch import Arch, x86 -from snapshot import ProgramState - -class SnapshotBuilder: - """At every breakpoint, writes register contents to a stream. - - Generated snapshots are stored in and can be read from `self.states`. - """ - def __init__(self, arch: Arch): - self.arch = arch - self.states = [] - self.regnames = set(arch.regnames) - - def create_snapshot(self, frame: lldb.SBFrame): - state = ProgramState(self.arch) - state.set('PC', frame.GetPC()) - for regname in self.arch.regnames: - reg = frame.FindRegister(regname) - regval = int(reg.GetValue(), base=16) - state.set(regname, regval) - if regname == 'RFLAGS': - flags = x86.decompose_rflags(regval) - for flag_name, val in flags.items(): - state.set(flag_name, val) - return state - - def __call__(self, frame): - snapshot = self.create_snapshot(frame) - self.states.append(snapshot) - -class Debugger: - def __init__(self, program): - self.debugger = lldb.SBDebugger.Create() - self.debugger.SetAsync(False) - self.target = self.debugger.CreateTargetWithFileAndArch(program, - lldb.LLDB_ARCH_DEFAULT) - self.module = self.target.FindModule(self.target.GetExecutable()) - self.interpreter = self.debugger.GetCommandInterpreter() - - def set_breakpoint_by_addr(self, address: int): - command = f"b -a {address} -s {self.module.GetFileSpec().GetFilename()}" - result = lldb.SBCommandReturnObject() - self.interpreter.HandleCommand(command, result) - - def get_breakpoints_count(self): - return self.target.GetNumBreakpoints() - - def execute(self, callback: Callable): - error = lldb.SBError() - listener = self.debugger.GetListener() - process = self.target.Launch(listener, None, None, None, None, None, None, 0, - True, error) - - # Check if the process has launched successfully - if process.IsValid(): - print(f'Launched process: {process}') - else: - print('Failed to launch process', file=sys.stderr) - - while True: - state = process.GetState() - if state == lldb.eStateStopped: - for thread in process: - callback(thread.GetFrameAtIndex(0)) - process.Continue() - if state == lldb.eStateExited: - break - - print(f'Process state: {process.GetState()}') - print('Program output:') - print(process.GetSTDOUT(1024)) - print(process.GetSTDERR(1024)) - -def run_native_execution(oracle_program: str, breakpoints: set[int]): - """Gather snapshots from a native execution via an external debugger. - - :param oracle_program: Program to execute. - :param breakpoints: List of addresses at which to break and record the - program's state. - - :return: A list of snapshots gathered from the execution. - """ - assert(platform.machine() == "x86_64") - - debugger = Debugger(oracle_program) - - # Set breakpoints - for address in breakpoints: - debugger.set_breakpoint_by_addr(address) - assert(debugger.get_breakpoints_count() == len(breakpoints)) - - # Execute the native program - builder = SnapshotBuilder(x86.ArchX86()) - debugger.execute(builder) - - return builder.states diff --git a/snapshot.py b/snapshot.py index a4bfb0f..80c1ac5 100644 --- a/snapshot.py +++ b/snapshot.py @@ -1,5 +1,4 @@ from arch.arch import Arch -from interpreter import SymbolResolver, SymbolResolveError class MemoryAccessError(Exception): def __init__(self, msg: str): @@ -112,12 +111,3 @@ class ProgramState: def __repr__(self): return repr(self.regs) - -class SnapshotSymbolResolver(SymbolResolver): - def __init__(self, snapshot: ProgramState): - self._state = snapshot - - def resolve(self, symbol: str): - if symbol not in self._state.arch.regnames: - raise SymbolResolveError(symbol, 'Symbol is not a register name.') - return self._state.read(symbol) diff --git a/symbolic.py b/symbolic.py index a5e1f70..2a328fd 100644 --- a/symbolic.py +++ b/symbolic.py @@ -1,21 +1,33 @@ """Tools and utilities for symbolic execution with Miasm.""" +from __future__ import annotations +from typing import Self + from miasm.analysis.binary import ContainerELF from miasm.analysis.machine import Machine from miasm.core.asmblock import AsmCFG from miasm.core.locationdb import LocationDB from miasm.ir.symbexec import SymbolicExecutionEngine -from miasm.expression.expression import Expr, ExprId, ExprMem +from miasm.expression.expression import Expr, ExprId, ExprMem, ExprInt from lldb_target import LLDBConcreteTarget, record_snapshot from miasm_util import MiasmConcreteState, eval_expr from snapshot import ProgramState +from arch import Arch, supported_architectures class SymbolicTransform: def __init__(self, from_addr: int, to_addr: int): self.addr = from_addr self.range = (from_addr, to_addr) + def concat(self, other: Self) -> Self: + """Concatenate another transform to this transform. + + The symbolic transform on which `concat` is called is the transform + that is applied first, meaning: `(a.concat(b))(state) == b(a(state))`. + """ + raise NotImplementedError('concat is abstract.') + def calc_register_transform(self, conc_state: ProgramState) \ -> dict[str, int]: raise NotImplementedError('calc_register_transform is abstract.') @@ -27,6 +39,7 @@ class SymbolicTransform: class MiasmSymbolicTransform(SymbolicTransform): def __init__(self, transform: dict[ExprId, Expr], + arch: Arch, loc_db: LocationDB, start_addr: int, end_addr: int): @@ -41,22 +54,67 @@ class MiasmSymbolicTransform(SymbolicTransform): self.regs_diff: dict[str, Expr] = {} self.mem_diff: dict[ExprMem, Expr] = {} - for id, expr in transform.items(): - if isinstance(id, ExprMem): - self.mem_diff[id.ptr] = expr - elif id.name != 'IRDst': - assert(isinstance(id, ExprId)) - self.regs_diff[id.name] = expr - + for dst, expr in transform.items(): + if isinstance(dst, ExprMem): + self.mem_diff[dst] = expr + else: + assert(isinstance(dst, ExprId)) + regname = arch.to_regname(dst.name) + if regname is not None: + self.regs_diff[regname] = expr + + self.arch = arch self.loc_db = loc_db + def concat(self, other: MiasmSymbolicTransform) -> Self: + class MiasmSymbolicState: + """Drop-in replacement for MiasmConcreteState in eval_expr that + returns the current transform's symbolic equations instead of + symbolic values. Calling eval_expr with this effectively nests the + transformation into the concatenated transformation. + """ + def __init__(self, transform: MiasmSymbolicTransform): + self.transform = transform + + def resolve_register(self, regname: str): + return self.transform.regs_diff.get(regname, None) + + def resolve_memory(self, addr: int, size: int): + mem = ExprMem(ExprInt(addr, 64), size) + return self.transform.mem_diff.get(mem, None) + + def resolve_location(self, _): + return None + + if self.range[1] != other.range[0]: + raise ValueError(f'The concatenated transformations must span a' + f' contiguous range of instructions.') + + ref_state = MiasmSymbolicState(self) + for reg, expr in other.regs_diff.items(): + if reg not in self.regs_diff: + self.regs_diff[reg] = expr + else: + self.regs_diff[reg] = eval_expr(expr, ref_state) + + for dst, expr in other.mem_diff.items(): + dst = eval_expr(dst, ref_state) + if dst not in self.mem_diff: + self.mem_diff[dst] = expr + else: + self.mem_diff[dst] = eval_expr(expr, ref_state) + + self.range = (self.range[0], other.range[1]) + + return self + def calc_register_transform(self, conc_state: ProgramState) \ -> dict[str, int]: ref_state = MiasmConcreteState(conc_state, self.loc_db) res = {} for regname, expr in self.regs_diff.items(): - res[regname] = eval_expr(expr, ref_state) + res[regname] = int(eval_expr(expr, ref_state)) return res def calc_memory_transform(self, conc_state: ProgramState) \ @@ -71,8 +129,14 @@ class MiasmSymbolicTransform(SymbolicTransform): return res def __repr__(self) -> str: - return f'Symbolic state transformation for instruction \ - {hex(self.addr)}.' + start, end = self.range + res = f'Symbolic state transformation {hex(start)} -> {hex(end)}:\n' + for reg, expr in self.regs_diff.items(): + res += f' {reg:6s} = {expr}\n' + for mem, expr in self.mem_diff.items(): + res += f' {mem} = {expr}\n' + + return res def _step_until(target: LLDBConcreteTarget, addr: int) -> list[int]: """Step a concrete target to a specific instruction. @@ -181,6 +245,12 @@ def collect_symbolic_trace(binary: str, cont = ContainerELF.from_stream(bin_file, loc_db) machine = Machine(cont.arch) + # Find corresponding architecture + if machine.name not in supported_architectures: + print(f'[ERROR] {machine.name} is not supported. Returning.') + return [] + arch = supported_architectures[machine.name] + # Create disassembly/lifting context mdis = machine.dis_engine(cont.bin_stream, loc_db=loc_db) mdis.follow_call = True @@ -243,8 +313,8 @@ def collect_symbolic_trace(binary: str, res = [] for (start, diff), (end, _) in zip(symb_trace[:-1], symb_trace[1:]): - res.append(MiasmSymbolicTransform(diff, loc_db, start, end)) + res.append(MiasmSymbolicTransform(diff, arch, loc_db, start, end)) start, diff = symb_trace[-1] - res.append(MiasmSymbolicTransform(diff, loc_db, start, start)) + res.append(MiasmSymbolicTransform(diff, arch, loc_db, start, start)) return res diff --git a/trace_symbols.py b/trace_symbols.py deleted file mode 100644 index 6e7cb3b..0000000 --- a/trace_symbols.py +++ /dev/null @@ -1,167 +0,0 @@ -import argparse -import sys - -import angr -import claripy as cp -from angr.exploration_techniques import Symbion - -from arch import x86 -from gen_trace import record_trace -from interpreter import eval, SymbolResolver, SymbolResolveError -from lldb_target import LLDBConcreteTarget -from symbolic import collect_symbolic_trace - -# Size of the memory region on the stack that is tracked symbolically -# We track [rbp - STACK_SIZE, rbp). -STACK_SIZE = 0x10 - -STACK_SYMBOL_NAME = 'stack' - -class SimStateResolver(SymbolResolver): - """A symbol resolver that resolves symbol names to program state in - `angr.SimState` objects. - """ - def __init__(self, state: angr.SimState): - self._state = state - - def resolve(self, symbol_name: str) -> cp.ast.Base: - # Process special (non-register) symbol names - if symbol_name == STACK_SYMBOL_NAME: - assert(self._state.regs.rbp.concrete) - assert(type(self._state.regs.rbp.v) is int) - rbp = self._state.regs.rbp.v - return self._state.memory.load(rbp - STACK_SIZE, STACK_SIZE) - - # Try to interpret the symbol as a register name - try: - return self._state.regs.get(symbol_name.lower()) - except AttributeError: - raise SymbolResolveError(symbol_name, - f'[SimStateResolver]: No attribute' - f' {symbol_name} in program state.') - -def print_state(state: angr.SimState, file=sys.stdout, conc_state=None): - """Print a program state in a fancy way. - - :param conc_state: Provide a concrete program state as a reference to - evaluate all symbolic values in `state` and print their - concrete values in addition to the symbolic expression. - """ - if conc_state is not None: - resolver = SimStateResolver(conc_state) - else: - resolver = None - - print('-' * 80, file=file) - print(f'State at {hex(state.addr)}:', file=file) - print('-' * 80, file=file) - for reg in x86.regnames: - try: - val = state.regs.get(reg.lower()) - except angr.SimConcreteRegisterError: val = '<inaccessible>' - except angr.SimConcreteMemoryError: val = '<inaccessible>' - except AttributeError: val = '<inaccessible>' - except KeyError: val = '<inaccessible>' - if resolver is not None: - concrete_value = eval(resolver, val) - if type(concrete_value) is int: - concrete_value = hex(concrete_value) - print(f'{reg} = {val} ({concrete_value})', file=file) - else: - print(f'{reg} = {val}', file=file) - - # Print some of the stack - print('\nStack:', file=file) - try: - # Ensure that the base pointer is concrete - rbp = state.regs.rbp - if not rbp.concrete: - if resolver is None: - raise SymbolResolveError(rbp, - '[In print_state]: rbp is symbolic,' - ' but no resolver is defined. Can\'t' - ' print stack.') - else: - rbp = eval(resolver, rbp) - - stack_mem = state.memory.load(rbp - STACK_SIZE, STACK_SIZE) - - if resolver is not None: - print(hex(eval(resolver, stack_mem)), file=file) - print(stack_mem, file=file) - stack = state.solver.eval(stack_mem, cast_to=bytes) - print(' '.join(f'{b:02x}' for b in stack[::-1]), file=file) - except angr.SimConcreteMemoryError: - print('<unable to read stack memory>', file=file) - print('-' * 80, file=file) - -def collect_concrete_trace(binary: str, trace: list[int]) -> list[angr.SimState]: - target = LLDBConcreteTarget(binary) - proj = angr.Project(binary, - concrete_target=target, - use_sim_procedures=False) - - state = proj.factory.entry_state() - state.options.add(angr.options.SYMBION_KEEP_STUBS_ON_SYNC) - state.options.add(angr.options.SYMBION_SYNC_CLE) - - # Remove first address from trace if it is the entry point. - # Symbion doesn't find an address if it's the current state. - if len(trace) > 0 and trace[0] == state.addr: - trace = trace[1:] - - result = [] - - for inst in trace: - symbion = proj.factory.simgr(state) - symbion.use_technique(Symbion(find=[inst])) - - try: - conc_exploration = symbion.run() - except angr.AngrError: - assert(target.is_exited()) - break - state = conc_exploration.found[0] - result.append(state.copy()) - - return result - -def parse_args(): - prog = argparse.ArgumentParser() - prog.add_argument('binary', type=str) - prog.add_argument('--only-main', action='store_true', default=False) - return prog.parse_args() - -def main(): - args = parse_args() - binary = args.binary - only_main = args.only_main - - # Generate a program trace from a real execution - print('Collecting a program trace from a concrete execution...') - trace = record_trace(binary, [], - func_name='main' if only_main else None) - print(f'Found {len(trace)} trace points.') - - print('Executing the trace to collect concrete program states...') - concrete_trace = collect_concrete_trace(binary, trace) - - print('Re-tracing symbolically...') - try: - symbolic_trace = collect_symbolic_trace(binary, trace) - except KeyboardInterrupt: - print('Keyboard interrupt. Exiting.') - exit(0) - - with open('concrete.log', 'w') as conc_log: - for state in concrete_trace: - print_state(state, file=conc_log) - with open('symbolic.log', 'w') as symb_log: - for conc, symb in zip(concrete_trace, symbolic_trace): - print_state(symb.state, file=symb_log, conc_state=conc) - - print('Written symbolic trace to "symbolic.log".') - -if __name__ == "__main__": - main() - print('\nDone.') |