diff options
| author | Theofilos Augoustis <theofilos.augoustis@gmail.com> | 2023-11-27 13:22:01 +0100 |
|---|---|---|
| committer | Theofilos Augoustis <theofilos.augoustis@gmail.com> | 2023-11-27 13:22:01 +0100 |
| commit | 5d51b4fe0bb41bc9e86c5775de35a9aef023fec5 (patch) | |
| tree | 09d1f87c8a3964f72b71b7a04945a7f5e7e12abe | |
| parent | 47894bb5d2e425f28d992aee6331b89b85b2058d (diff) | |
| download | focaccia-5d51b4fe0bb41bc9e86c5775de35a9aef023fec5.tar.gz focaccia-5d51b4fe0bb41bc9e86c5775de35a9aef023fec5.zip | |
Implement symbolic state comparison algorithm
This is the first draft of a `compare` algorithm that uses recorded symbolic transformations. It is currently based on angr, so it will probably be reworked to work with states generated by Miasm. Co-authored-by: Theofilos Augoustis <theofilos.augoustis@gmail.com> Co-authored-by: Nicola Crivellin <nicola.crivellin98@gmail.com>
| -rw-r--r-- | arch/x86.py | 2 | ||||
| -rw-r--r-- | compare.py | 215 | ||||
| -rw-r--r-- | gen_trace.py | 57 | ||||
| -rw-r--r-- | lldb_target.py | 9 | ||||
| -rwxr-xr-x | main.py | 33 | ||||
| -rw-r--r-- | snapshot.py | 9 | ||||
| -rw-r--r-- | symbolic.py | 31 | ||||
| -rw-r--r-- | trace_symbols.py | 45 |
8 files changed, 210 insertions, 191 deletions
diff --git a/arch/x86.py b/arch/x86.py index 01c1631..25213a0 100644 --- a/arch/x86.py +++ b/arch/x86.py @@ -22,6 +22,8 @@ regnames = [ 'R14', 'R15', 'RFLAGS', + # Segment registers + 'CS', 'DS', 'SS', 'ES', 'FS', 'GS', # FLAGS 'CF', 'PF', 'AF', 'ZF', 'SF', 'TF', 'IF', 'DF', 'OF', 'IOPL', 'NT', # EFLAGS diff --git a/compare.py b/compare.py index a191025..8a25d8a 100644 --- a/compare.py +++ b/compare.py @@ -1,4 +1,5 @@ -from snapshot import ProgramState +from snapshot import ProgramState, SnapshotSymbolResolver +from symbolic import SymbolicTransform from utils import print_separator def calc_transformation(previous: ProgramState, current: ProgramState): @@ -117,144 +118,76 @@ def compare_simple(test_states: list[ProgramState], return result -def equivalent(val1, val2, transformation, previous_translation): - if val1 == val2: - return True - - # TODO: maybe incorrect - return val1 - previous_translation == transformation - -def verify(translation: ProgramState, reference: ProgramState, - transformation: ProgramState, previous_translation: ProgramState): - assert(translation.arch == reference.arch) - - if translation.regs["PC"] != reference.regs["PC"]: - return 1 - - print_separator() - print(f'For PC={translation.as_repr("PC")}') - print_separator() - for reg in translation.arch.regnames: - if translation.regs[reg] is None: - print(f'Element not available in translation: {reg}') - elif reference.regs[reg] is None: - print(f'Element not available in reference: {reg}') - elif not equivalent(translation.regs[reg], reference.regs[reg], - transformation.regs[reg], - previous_translation.regs[reg]): - txl = translation.as_repr(reg) - ref = reference.as_repr(reg) - print(f'Difference for {reg}: {txl} != {ref}') - - return 0 - -def compare(txl: list[ProgramState], - native: list[ProgramState], - progressive: bool = False, - stats: bool = False): - """Compare two lists of snapshots and output the differences. 
- - :param txl: The translated, and possibly faulty, state of the program. - :param native: The 'correct' reference state of the program. - :param progressive: - :param stats: - """ +def find_errors_symbolic(txl_from: ProgramState, + txl_to: ProgramState, + transform_truth: SymbolicTransform) \ + -> list[dict]: + arch = txl_from.arch + resolver = SnapshotSymbolResolver(txl_from) - if len(txl) != len(native): - print(f'Different numbers of blocks discovered: ' - f'{len(txl)} in translation vs. {len(native)} in reference.') - - previous_reference = native[0] - previous_translation = txl[0] - - unmatched_pcs = {} - pc_to_skip = "" - if progressive: - i = 0 - for translation in txl: - previous = i - - while i < len(native): - reference = native[i] - transformation = calc_transformation(previous_reference, reference) - if verify(translation, reference, transformation, previous_translation) == 0: - reference.matched = True - break - - i += 1 - - matched = True - - # Didn't find anything - if i == len(native): - matched = False - # TODO: add verbose output - print_separator() - print(f'No match for PC {hex(translation.regs["PC"])}') - if translation.regs['PC'] not in unmatched_pcs: - unmatched_pcs[translation.regs['PC']] = 0 - unmatched_pcs[translation.regs['PC']] += 1 - - i = previous - - # Necessary since we may have run out of native BBs to check and - # previous becomes len(native) - # - # We continue checking to report unmatched translation PCs - if i < len(native): - previous_reference = native[i] - - previous_translation = translation - - # Skip next reference when there is a backwards branch - # NOTE: if a reference was skipped, don't skip it again - # necessary for loops which may have multiple backwards - # branches - if translation.has_backwards and translation.regs['PC'] != pc_to_skip: - pc_to_skip = translation.regs['PC'] - i += 1 - - if matched: - i += 1 - else: - txl = iter(txl) - native = iter(native) - for translation, reference in zip(txl, native): - 
transformation = calc_transformation(previous_reference, reference) - if verify(translation, reference, transformation, previous_translation) == 1: - # TODO: add verbose output - print_separator() - print(f'No match for PC {hex(translation.regs["PC"])}') - if translation.regs['PC'] not in unmatched_pcs: - unmatched_pcs[translation.regs['PC']] = 0 - unmatched_pcs[translation.regs['PC']] += 1 - else: - reference.matched = True - - if translation.has_backwards: - next(native) - - previous_reference = reference - previous_translation = translation - - if stats: - print_separator() - print('Statistics:') - print_separator() - - for pc in unmatched_pcs: - print(f'PC {hex(pc)} unmatched {unmatched_pcs[pc]} times') - - # NOTE: currently doesn't handle mismatched due backward branches - current = "" - unmatched_count = 0 - for ref in native: - ref_pc = ref.regs['PC'] - if ref_pc != current: - if unmatched_count: - print(f'Reference PC {hex(current)} unmatched {unmatched_count} times') - current = ref_pc - - if ref.matched == False: - unmatched_count += 1 - return 0 + assert(txl_from.read('PC') == transform_truth.start_addr) + assert(txl_to.read('PC') == transform_truth.end_addr) + + errors = [] + for reg in arch.regnames: + if txl_from.read(reg) is None or txl_to.read(reg) is None: + print(f'A value for {reg} must be set in all translated states.' + ' Skipping.') + continue + + txl_val = txl_to.read(reg) + try: + truth = transform_truth.eval_register_transform(reg.lower(), resolver) + print(f'Evaluated symbolic formula to {hex(txl_val)} vs. 
txl {hex(txl_val)}') + if txl_val != truth: + errors.append({ + 'reg': reg, + 'expected': truth, + 'actual': txl_val, + 'equation': transform_truth.state.regs.get(reg), + }) + except AttributeError: + print(f'Register {reg} does not exist.') + + return errors + +def compare_symbolic(test_states: list[ProgramState], + transforms: list[SymbolicTransform]): + #assert(len(test_states) == len(transforms) - 1) + PC_REGNAME = 'PC' + + result = [{ + 'pc': test_states[0].regs[PC_REGNAME], + 'txl': test_states[0], + 'ref': transforms[0], + 'errors': [] + }] + + _list = zip(test_states[:-1], test_states[1:], transforms) + for cur_state, next_state, transform in _list: + pc_cur = cur_state.read(PC_REGNAME) + pc_next = next_state.read(PC_REGNAME) + + # The program counter should always be set on a snapshot + assert(pc_cur is not None and pc_next is not None) + + if pc_cur != transform.start_addr: + print(f'Program counter {hex(pc_cur)} in translated code has no' + f' corresponding reference state! Skipping.' + f' (reference: {hex(transform.start_addr)})') + continue + if pc_next != transform.end_addr: + print(f'Tested state transformation is {hex(pc_cur)} ->' + f' {hex(pc_next)}, but reference transform is' + f' {hex(transform.start_addr)} -> {hex(transform.end_addr)}!' 
+ f' Skipping.') + + errors = find_errors_symbolic(cur_state, next_state, transform) + result.append({ + 'pc': pc_cur, + 'txl': calc_transformation(cur_state, next_state), + 'ref': transform, + 'errors': errors + }) + + return result diff --git a/gen_trace.py b/gen_trace.py index 64fcf8f..ec5cb86 100644 --- a/gen_trace.py +++ b/gen_trace.py @@ -2,42 +2,55 @@ import argparse import lldb import lldb_target -def parse_args(): - prog = argparse.ArgumentParser() - prog.add_argument('binary', - help='The executable to trace.') - prog.add_argument('-o', '--output', - default='breakpoints', - type=str, - help='File to which the recorded trace is written.') - prog.add_argument('--args', - default=[], - nargs='+', - help='Arguments to the executable.') - return prog.parse_args() - -def record_trace(binary: str, args: list[str] = []) -> list[int]: +def record_trace(binary: str, + args: list[str] = [], + func_name: str | None = 'main') -> list[int]: + """ + :param binary: The binary file to execute. + :param args: Arguments to the program. Should *not* include the + executable's location as the usual first argument. + :param func_name: Only record trace of a specific function. 
+ """ # Set up LLDB target target = lldb_target.LLDBConcreteTarget(binary, args) # Skip to first instruction in `main` - result = lldb.SBCommandReturnObject() - break_at_main = f'b -b main -s {target.module.GetFileSpec().GetFilename()}' - target.interpreter.HandleCommand(break_at_main, result) - target.run() + if func_name is not None: + result = lldb.SBCommandReturnObject() + break_at_func = f'b -b {func_name} -s {target.module.GetFileSpec().GetFilename()}' + target.interpreter.HandleCommand(break_at_func, result) + target.run() # Run until main function is exited trace = [] while not target.is_exited(): thread = target.process.GetThreadAtIndex(0) - func_names = [thread.GetFrameAtIndex(i).GetFunctionName() for i in range(0, thread.GetNumFrames())] - if 'main' not in func_names: - break + + # Break if the traced function is exited + if func_name is not None: + func_names = [thread.GetFrameAtIndex(i).GetFunctionName() \ + for i in range(0, thread.GetNumFrames())] + if func_name not in func_names: + break trace.append(target.read_register('pc')) thread.StepInstruction(False) return trace +def parse_args(): + prog = argparse.ArgumentParser() + prog.add_argument('binary', + help='The executable to trace.') + prog.add_argument('-o', '--output', + default='breakpoints', + type=str, + help='File to which the recorded trace is written.') + prog.add_argument('--args', + default=[], + nargs='+', + help='Arguments to the executable.') + return prog.parse_args() + def main(): args = parse_args() trace = record_trace(args.binary, args.args) diff --git a/lldb_target.py b/lldb_target.py index 5477ab7..dd0d543 100644 --- a/lldb_target.py +++ b/lldb_target.py @@ -124,7 +124,7 @@ class LLDBConcreteTarget(ConcreteTarget): raise SimConcreteMemoryError(f'Error when writing to address' f' {hex(addr)}: {err}') - def get_mappings(self): + def get_mappings(self) -> list[MemoryMap]: mmap = [] region_list = self.process.GetMemoryRegions() @@ -134,11 +134,12 @@ class 
LLDBConcreteTarget(ConcreteTarget): perms = f'{"r" if region.IsReadable() else "-"}' \ f'{"w" if region.IsWritable() else "-"}' \ - f'{"x" if region.IsExecutable() else "-"}' \ + f'{"x" if region.IsExecutable() else "-"}' + name = region.GetName() mmap.append(MemoryMap(region.GetRegionBase(), region.GetRegionEnd(), - 0, # offset? - "<no-name>", # name? + 0, # offset? + name if name is not None else '<none>', perms)) return mmap diff --git a/main.py b/main.py index 9451e42..b0aeb36 100755 --- a/main.py +++ b/main.py @@ -4,8 +4,10 @@ import argparse import arancini from arch import x86 -from compare import compare_simple +from compare import compare_simple, compare_symbolic +from gen_trace import record_trace from run import run_native_execution +from symbolic import collect_symbolic_trace from utils import check_version, print_separator def parse_inputs(txl_path, ref_path, program): @@ -49,11 +51,12 @@ def parse_arguments(): action='store_true', default=True, help='Path to oracle program') - parser.add_argument('--progressive', + parser.add_argument('--symbolic', action='store_true', default=False, - help='Try to match exhaustively before declaring \ - mismatch') + help='Use an advanced algorithm that uses symbolic' + ' execution to determine accurate data' + ' transformations') args = parser.parse_args() return args @@ -66,13 +69,12 @@ def main(): stats = args.stats verbose = args.verbose - progressive = args.progressive if verbose: print("Enabling verbose program output") print(f"Verbose: {verbose}") print(f"Statistics: {stats}") - print(f"Progressive: {progressive}") + print(f"Symbolic: {args.symbolic}") if program is None and reference_path is None: raise ValueError('Either program or path to native file must be' @@ -85,7 +87,18 @@ def main(): for snapshot in ref: print(snapshot, file=w) - result = compare_simple(txl, ref) + if args.symbolic: + assert(program is not None) + + full_trace = record_trace(program, args=[]) + transforms = 
collect_symbolic_trace(program, full_trace) + # TODO: Transform the traces so that the states match + result = compare_symbolic(txl, transforms) + + raise NotImplementedError('The symbolic comparison algorithm is not' + ' supported yet.') + else: + result = compare_simple(txl, ref) # Print results for res in result: @@ -104,6 +117,12 @@ def main(): f' (txl) {reg}: {hex(txl.regs[reg])}\n' f' (ref) {reg}: {hex(ref.regs[reg])}') + print() + print('#' * 60) + print(f'Found {sum(len(res["errors"]) for res in result)} errors.') + print('#' * 60) + print() + if __name__ == "__main__": check_version('3.7') main() diff --git a/snapshot.py b/snapshot.py index 01c6446..3170649 100644 --- a/snapshot.py +++ b/snapshot.py @@ -38,3 +38,12 @@ class ProgramState: def __repr__(self): return repr(self.regs) + +class SnapshotSymbolResolver(SymbolResolver): + def __init__(self, snapshot: ProgramState): + self._state = snapshot + + def resolve(self, symbol: str): + if symbol not in self._state.arch.regnames: + raise SymbolResolveError(symbol, 'Symbol is not a register name.') + return self._state.read(symbol) diff --git a/symbolic.py b/symbolic.py index 53e1bbf..56857d7 100644 --- a/symbolic.py +++ b/symbolic.py @@ -5,7 +5,7 @@ import claripy as cp from angr.exploration_techniques import Symbion from arch import Arch, x86 -from interpreter import SymbolResolver +from interpreter import eval as eval_symbol, SymbolResolver from lldb_target import LLDBConcreteTarget def symbolize_state(state: angr.SimState, @@ -28,7 +28,7 @@ def symbolize_state(state: angr.SimState, if stack_name not in _exclude: symb_stack = cp.BVS(stack_name, stack_size * 8, explicit_name=True) - state.memory.store(state.regs.rbp - stack_size, symb_stack) + state.memory.store(state.regs.rsp - stack_size, symb_stack) for reg in arch.regnames: if reg not in _exclude: @@ -68,7 +68,15 @@ class SymbolicTransform: self.end_addr = end_inst def eval_register_transform(self, regname: str, resolver: SymbolResolver): - raise 
NotImplementedError('TODO') + """ + :param regname: Name of the register to evaluate. + :param resolver: A provider for the values to be plugged into the + symbolic equation. + + :raise angr.SimConcreteRegisterError: If the state contains no register + named `regname`. + """ + return eval_symbol(resolver, self.state.regs.get(regname)) def __repr__(self) -> str: return f'Symbolic state transformation: \ @@ -87,7 +95,7 @@ def collect_symbolic_trace(binary: str, trace: list[int]) \ concrete_target=target, use_sim_procedures=False) - entry_state = proj.factory.entry_state() + entry_state = proj.factory.entry_state(addr=trace[0]) entry_state.options.add(angr.options.SYMBION_KEEP_STUBS_ON_SYNC) entry_state.options.add(angr.options.SYMBION_SYNC_CLE) @@ -105,13 +113,26 @@ def collect_symbolic_trace(binary: str, trace: list[int]) \ symbion = proj.factory.simgr(entry_state) symbion.use_technique(Symbion(find=[cur_inst])) - conc_exploration = symbion.run() + try: + if cur_inst != entry_state.addr: + conc_exploration = symbion.run() + else: + symbion.move('active', 'found') + conc_exploration = symbion + except angr.AngrError as err: + print(f'Angr error: {err} Returning partial result.') + return result conc_state = conc_exploration.found[0] + entry_state = conc_state concrete_states[conc_state.addr] = conc_state.copy() # Start symbolic execution with the concrete ('truth') state and try # to reach the next instruction in the trace + # + # -- Notes -- + # It does not even work when I feed the entirely concrete state + # `conc_state` that I receive from Symbion into the symbolic simgr. 
simgr = proj.factory.simgr(symbolize_state(conc_state)) symb_exploration = simgr.explore(find=next_inst) diff --git a/trace_symbols.py b/trace_symbols.py index e529522..6e7cb3b 100644 --- a/trace_symbols.py +++ b/trace_symbols.py @@ -9,7 +9,7 @@ from arch import x86 from gen_trace import record_trace from interpreter import eval, SymbolResolver, SymbolResolveError from lldb_target import LLDBConcreteTarget -from symbolic import symbolize_state, collect_symbolic_trace +from symbolic import collect_symbolic_trace # Size of the memory region on the stack that is tracked symbolically # We track [rbp - STACK_SIZE, rbp). @@ -95,12 +95,7 @@ def print_state(state: angr.SimState, file=sys.stdout, conc_state=None): print('<unable to read stack memory>', file=file) print('-' * 80, file=file) -def parse_args(): - prog = argparse.ArgumentParser() - prog.add_argument('binary', type=str) - return prog.parse_args() - -def collect_concrete_trace(binary: str) -> list[angr.SimState]: +def collect_concrete_trace(binary: str, trace: list[int]) -> list[angr.SimState]: target = LLDBConcreteTarget(binary) proj = angr.Project(binary, concrete_target=target, @@ -110,29 +105,53 @@ def collect_concrete_trace(binary: str) -> list[angr.SimState]: state.options.add(angr.options.SYMBION_KEEP_STUBS_ON_SYNC) state.options.add(angr.options.SYMBION_SYNC_CLE) + # Remove first address from trace if it is the entry point. + # Symbion doesn't find an address if it's the current state. 
+ if len(trace) > 0 and trace[0] == state.addr: + trace = trace[1:] + result = [] - trace = record_trace(binary) for inst in trace: symbion = proj.factory.simgr(state) symbion.use_technique(Symbion(find=[inst])) - conc_exploration = symbion.run() + try: + conc_exploration = symbion.run() + except angr.AngrError: + assert(target.is_exited()) + break state = conc_exploration.found[0] result.append(state.copy()) return result +def parse_args(): + prog = argparse.ArgumentParser() + prog.add_argument('binary', type=str) + prog.add_argument('--only-main', action='store_true', default=False) + return prog.parse_args() + def main(): args = parse_args() binary = args.binary + only_main = args.only_main # Generate a program trace from a real execution - concrete_trace = collect_concrete_trace(binary) - trace = [int(state.addr) for state in concrete_trace] + print('Collecting a program trace from a concrete execution...') + trace = record_trace(binary, [], + func_name='main' if only_main else None) print(f'Found {len(trace)} trace points.') - symbolic_trace = collect_symbolic_trace(binary, trace) + print('Executing the trace to collect concrete program states...') + concrete_trace = collect_concrete_trace(binary, trace) + + print('Re-tracing symbolically...') + try: + symbolic_trace = collect_symbolic_trace(binary, trace) + except KeyboardInterrupt: + print('Keyboard interrupt. Exiting.') + exit(0) with open('concrete.log', 'w') as conc_log: for state in concrete_trace: @@ -141,6 +160,8 @@ def main(): for conc, symb in zip(concrete_trace, symbolic_trace): print_state(symb.state, file=symb_log, conc_state=conc) + print('Written symbolic trace to "symbolic.log".') + if __name__ == "__main__": main() print('\nDone.') |