diff options
| author | Theofilos Augoustis <theofilos.augoustis@gmail.com> | 2023-10-21 16:39:49 +0200 |
|---|---|---|
| committer | Theofilos Augoustis <theofilos.augoustis@gmail.com> | 2023-10-21 16:39:49 +0200 |
| commit | 6f01367f9c8ad4c3d641cc63dbb1a3977ff4ec56 (patch) | |
| tree | 5b9677b9a5cca449497cea4418b2bb10e2ab0509 | |
| parent | 83d4b4dbe6f20c2fa7865e4888b89e888d3509f9 (diff) | |
| download | focaccia-6f01367f9c8ad4c3d641cc63dbb1a3977ff4ec56.tar.gz focaccia-6f01367f9c8ad4c3d641cc63dbb1a3977ff4ec56.zip | |
Support for testing concrete and emulated execution with angr
| -rw-r--r-- | .gitignore | 2 | ||||
| -rw-r--r-- | arch/__init__.py | 7 | ||||
| -rw-r--r-- | arch/arch.py | 5 | ||||
| -rw-r--r-- | arch/x86.py | 2 | ||||
| -rw-r--r-- | gen_trace.py | 50 | ||||
| -rw-r--r-- | lldb_target.py | 123 | ||||
| -rwxr-xr-x | main.py | 42 | ||||
| -rw-r--r-- | run.py | 102 | ||||
| -rw-r--r-- | test.py | 180 |
9 files changed, 421 insertions, 92 deletions
diff --git a/.gitignore b/.gitignore index 4586156..94631a0 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,5 @@ build*/ out-*/ __pycache__/ +# Dev environment +.gdbinit diff --git a/arch/__init__.py b/arch/__init__.py new file mode 100644 index 0000000..4943749 --- /dev/null +++ b/arch/__init__.py @@ -0,0 +1,7 @@ +from .arch import Arch +from . import x86 + +"""A dictionary containing all supported architectures at their names.""" +supported_architectures: dict[str, Arch] = { + "X86": x86.ArchX86(), +} diff --git a/arch/arch.py b/arch/arch.py index 36a4e3f..a46439e 100644 --- a/arch/arch.py +++ b/arch/arch.py @@ -1,6 +1,7 @@ class Arch(): - def __init__(self, regnames: list[str]): - self.regnames = regnames + def __init__(self, archname: str, regnames: list[str]): + self.archname = archname + self.regnames = set(regnames) def __eq__(self, other): return self.regnames == other.regnames diff --git a/arch/x86.py b/arch/x86.py index 0f60457..2b27315 100644 --- a/arch/x86.py +++ b/arch/x86.py @@ -30,4 +30,4 @@ regnames = ['PC', class ArchX86(Arch): def __init__(self): - super().__init__(regnames) + super().__init__("X86", regnames) diff --git a/gen_trace.py b/gen_trace.py new file mode 100644 index 0000000..64fcf8f --- /dev/null +++ b/gen_trace.py @@ -0,0 +1,50 @@ +import argparse +import lldb +import lldb_target + +def parse_args(): + prog = argparse.ArgumentParser() + prog.add_argument('binary', + help='The executable to trace.') + prog.add_argument('-o', '--output', + default='breakpoints', + type=str, + help='File to which the recorded trace is written.') + prog.add_argument('--args', + default=[], + nargs='+', + help='Arguments to the executable.') + return prog.parse_args() + +def record_trace(binary: str, args: list[str] = []) -> list[int]: + # Set up LLDB target + target = lldb_target.LLDBConcreteTarget(binary, args) + + # Skip to first instruction in `main` + result = lldb.SBCommandReturnObject() + break_at_main = f'b -b main -s {target.module.GetFileSpec().GetFilename()}' + target.interpreter.HandleCommand(break_at_main, result) + target.run() + + # Run until main function is exited + trace = [] + while not target.is_exited(): + thread = target.process.GetThreadAtIndex(0) + func_names = [thread.GetFrameAtIndex(i).GetFunctionName() for i in range(0, thread.GetNumFrames())] + if 'main' not in func_names: + break + trace.append(target.read_register('pc')) + thread.StepInstruction(False) + + return trace + +def main(): + args = parse_args() + trace = record_trace(args.binary, args.args) + with open(args.output, 'w') as file: + for addr in trace: + print(hex(addr), file=file) + print(f'Generated a trace of {len(trace)} instructions.') + +if __name__ == '__main__': + main() diff --git a/lldb_target.py b/lldb_target.py new file mode 100644 index 0000000..1ff9f53 --- /dev/null +++ b/lldb_target.py @@ -0,0 +1,123 @@ +import lldb + +from angr.errors import SimConcreteMemoryError, \ + SimConcreteRegisterError +from angr_targets.concrete import ConcreteTarget +from angr_targets.memory_map import MemoryMap + +class LLDBConcreteTarget(ConcreteTarget): + def __init__(self, executable: str, args: list[str] = []): + # Prepend the executable's path to argv, as is convention + args.insert(0, executable) + + self.debugger = lldb.SBDebugger.Create() + self.debugger.SetAsync(False) + self.target = self.debugger.CreateTargetWithFileAndArch(executable, + lldb.LLDB_ARCH_DEFAULT) + self.module = self.target.FindModule(self.target.GetExecutable()) + self.interpreter = self.debugger.GetCommandInterpreter() + + # Set up objects for process execution + self.error = lldb.SBError() + self.listener = self.debugger.GetListener() + self.process = self.target.Launch(self.listener, + args, None, None, + None, None, None, 0, + True, self.error) + if not self.process.IsValid(): + raise RuntimeError(f'[In LLDBConcreteTarget.__init__]: Failed to' + f' launch process.') + + def set_breakpoint(self, addr, **kwargs): + command = f'b -a {addr} -s {self.module.GetFileSpec().GetFilename()}' + result = lldb.SBCommandReturnObject() + self.interpreter.HandleCommand(command, result) + + def remove_breakpoint(self, addr, **kwargs): + command = f'breakpoint delete {addr}' + result = lldb.SBCommandReturnObject() + self.interpreter.HandleCommand(command, result) + + def is_running(self): + return self.process.GetState() == lldb.eStateRunning + + def is_exited(self): + """Not part of the angr interface, but much more useful than + `is_running`. + + :return: True if the process has exited. False otherwise. + """ + return self.process.GetState() == lldb.eStateExited + + def wait_for_running(self): + while self.process.GetState() != lldb.eStateRunning: + pass + + def wait_for_halt(self): + while self.process.GetState() != lldb.eStateStopped: + pass + + def run(self): + state = self.process.GetState() + if state == lldb.eStateExited: + raise RuntimeError(f'Tried to resume process execution, but the' + f' process has already exited.') + assert(state == lldb.eStateStopped) + self.process.Continue() + + def stop(self): + self.process.Stop() + + def exit(self): + self.debugger.Terminate() + print(f'Program exited with status {self.process.GetState()}') + + def read_register(self, regname: str) -> int: + frame = self.process.GetThreadAtIndex(0).GetFrameAtIndex(0) + reg = frame.FindRegister(regname) + if reg is None: + raise SimConcreteRegisterError( + f'[In LLDBConcreteTarget.read_register]: Register {regname}' + f' not found.') + + val = reg.GetValue() + if val is None: + raise SimConcreteRegisterError( + f'[In LLDBConcreteTarget.read_register]: Register has an' + f' invalid value of {val}.') + + return int(val, 16) + + def read_memory(self, addr, size): + err = lldb.SBError() + content = self.process.ReadMemory(addr, size, err) + if not err.success: + raise SimConcreteMemoryError(f'Error when reading {size} bytes at' + f' address {hex(addr)}: {err}') + return content + + def write_memory(self, addr, value): + err = lldb.SBError() + res = self.process.WriteMemory(addr, value, err) + if not err.success or res != len(value): + raise SimConcreteMemoryError(f'Error when writing to address' + f' {hex(addr)}: {err}') + + def get_mappings(self): + mmap = [] + + region_list = self.process.GetMemoryRegions() + for i in range(region_list.GetSize()): + region = lldb.SBMemoryRegionInfo() + region_list.GetMemoryRegionAtIndex(i, region) + + perms = f'{"r" if region.IsReadable() else "-"}' \ + f'{"w" if region.IsWritable() else "-"}' \ + f'{"x" if region.IsExecutable() else "-"}' \ + + mmap.append(MemoryMap(region.GetRegionBase(), + region.GetRegionEnd(), + 0, # offset? + "<no-name>", # name? + perms)) + return mmap diff --git a/main.py b/main.py index d97a54d..9451e42 100755 --- a/main.py +++ b/main.py @@ -8,21 +8,25 @@ from compare import compare_simple from run import run_native_execution from utils import check_version, print_separator -def read_logs(txl_path, native_path, program): +def parse_inputs(txl_path, ref_path, program): + # Our architecture + arch = x86.ArchX86() + txl = [] with open(txl_path, "r") as txl_file: - txl = txl_file.readlines() + txl = arancini.parse(txl_file.readlines(), arch) - native = [] + ref = [] if program is not None: - breakpoints = arancini.parse_break_addresses(txl) - native = run_native_execution(program, breakpoints) + with open(txl_path, "r") as txl_file: + breakpoints = arancini.parse_break_addresses(txl_file.readlines()) + ref = run_native_execution(program, breakpoints) else: - assert(native_path is not None) - with open(native_path, "r") as native_file: - native = native_file.readlines() + assert(ref_path is not None) + with open(ref_path, "r") as native_file: + ref = arancini.parse(native_file.readlines(), arch) - return txl, native + return txl, ref def parse_arguments(): parser = argparse.ArgumentParser(description='Comparator for emulator logs to reference') @@ -57,35 +61,31 @@ def main(): args = parse_arguments() txl_path = args.txl - native_path = args.ref + reference_path = args.ref program = args.program stats = args.stats verbose = args.verbose progressive = args.progressive - # Our architexture - arch = x86.ArchX86() - if verbose: print("Enabling verbose program output") print(f"Verbose: {verbose}") print(f"Statistics: {stats}") print(f"Progressive: {progressive}") - if program is None and native_path is None: + if program is None and reference_path is None: raise ValueError('Either program or path to native file must be' 'provided') - txl, native = read_logs(txl_path, native_path, program) + txl, ref = parse_inputs(txl_path, reference_path, program) - if program != None and native_path != None: - with open(native_path, 'w') as w: - w.write(''.join(native)) + if program != None and reference_path != None: + with open(reference_path, 'w') as w: + for snapshot in ref: + print(snapshot, file=w) - txl = arancini.parse(txl, arch) - native = arancini.parse(native, arch) - result = compare_simple(txl, native) + result = compare_simple(txl, ref) # Print results for res in result: diff --git a/run.py b/run.py index 9b51fb5..6aca4d2 100644 --- a/run.py +++ b/run.py @@ -1,23 +1,24 @@ """Functionality to execute native programs and collect snapshots via lldb.""" -import re +import platform import sys import lldb from typing import Callable -# TODO: The debugger callback is currently specific to a single architexture. +# TODO: The debugger callback is currently specific to a single architecture. # We should make it generic. -from arch import x86 -from utils import print_separator +from arch import Arch, x86 +from snapshot import ProgramState -verbose = False +class SnapshotBuilder: + """At every breakpoint, writes register contents to a stream. -class DebuggerCallback: - """At every breakpoint, writes register contents to a stream.""" - - def __init__(self, ostream=sys.stdout): - self.stream = ostream - self.regex = re.compile('(' + '|'.join(x86.regnames) + ')$') + Generated snapshots are stored in and can be read from `self.states`. + """ + def __init__(self, arch: Arch): + self.arch = arch + self.states = [] + self.regnames = set(arch.regnames) @staticmethod def parse_flags(flag_reg: int): @@ -44,48 +45,26 @@ class DebuggerCallback: flags['PF'] = int(0 != flag_reg & (1 << 1)) return flags - def print_regs(self, frame): + def create_snapshot(self, frame): + state = ProgramState(self.arch) + state.set('PC', frame.GetPC()) for reg in frame.GetRegisters(): for sub_reg in reg: - match = self.regex.match(sub_reg.GetName().upper()) - if match and match.group() == 'RFLAGS': - flags = DebuggerCallback.parse_flags(int(sub_reg.GetValue(), - base=16)) - for flag in flags: - print(f'flag {flag}:\t{hex(flags[flag])}', - file=self.stream) - elif match: - print(f"{sub_reg.GetName().upper()}:\t\t {hex(int(sub_reg.GetValue(), base=16))}", - file=self.stream) - - def print_stack(self, frame, element_count: int): - first = True - for i in range(element_count): - addr = frame.GetSP() + i * frame.GetThread().GetProcess().GetAddressByteSize() - error = lldb.SBError() - stack_value = int(frame.GetThread().GetProcess().ReadPointerFromMemory(addr, error)) - if error.Success() and not first: - print(f'{hex(stack_value)}', file=self.stream) - elif error.Success(): - print(f'{hex(stack_value)}\t\t<- rsp', file=self.stream) - else: - print(f"Error reading memory at address 0x{addr:x}", - file=self.stream) - first=False + # Set the register's value in the current snapshot + regname = sub_reg.GetName().upper() + if regname in self.regnames: + regval = int(sub_reg.GetValue(), base=16) + if regname == 'RFLAGS': + flags = SnapshotBuilder.parse_flags(regval) + for flag, val in flags.items(): + state.set(f'flag {flag}', val) + else: + state.set(regname, regval) + return state def __call__(self, frame): - pc = frame.GetPC() - - print_separator('=', stream=self.stream, count=20) - print(f'INVOKE PC={hex(pc)}', file=self.stream) - print_separator('=', stream=self.stream, count=20) - - print("Register values:", file=self.stream) - self.print_regs(frame) - print_separator(stream=self.stream) - - print("STACK:", file=self.stream) - self.print_stack(frame, 20) + snapshot = self.create_snapshot(frame) + self.states.append(snapshot) class Debugger: def __init__(self, program): @@ -101,9 +80,6 @@ class Debugger: result = lldb.SBCommandReturnObject() self.interpreter.HandleCommand(command, result) - if verbose: - print(f'Set breakpoint at address {hex(address)}') - def get_breakpoints_count(self): return self.target.GetNumBreakpoints() @@ -128,23 +104,11 @@ class Debugger: if state == lldb.eStateExited: break - self.debugger.Terminate() - print(f'Process state: {process.GetState()}') print('Program output:') print(process.GetSTDOUT(1024)) print(process.GetSTDERR(1024)) -class ListWriter: - def __init__(self): - self.data = [] - - def write(self, s): - self.data.append(s) - - def __str__(self): - return "".join(self.data) - def run_native_execution(oracle_program: str, breakpoints: set[int]): """Gather snapshots from a native execution via an external debugger. @@ -152,10 +116,11 @@ def run_native_execution(oracle_program: str, breakpoints: set[int]): :param breakpoints: List of addresses at which to break and record the program's state. - :return: A textual log of the program's execution in arancini's log format. + :return: A list of snapshots gathered from the execution. """ + assert(platform.machine() == "x86_64") + debugger = Debugger(oracle_program) - writer = ListWriter() # Set breakpoints for address in breakpoints: @@ -163,6 +128,7 @@ def run_native_execution(oracle_program: str, breakpoints: set[int]): assert(debugger.get_breakpoints_count() == len(breakpoints)) # Execute the native program - debugger.execute(DebuggerCallback(writer)) + builder = SnapshotBuilder(x86.ArchX86()) + debugger.execute(builder) - return writer.data + return builder.states diff --git a/test.py b/test.py new file mode 100644 index 0000000..72a1438 --- /dev/null +++ b/test.py @@ -0,0 +1,180 @@ +import angr +import angr_targets +import claripy as cp +import sys + +from lldb_target import LLDBConcreteTarget + +from arancini import parse_break_addresses +from arch import x86 + +def print_state(state, file=sys.stdout): + for reg in x86.regnames: + try: + val = state.regs.__getattr__(reg.lower()) + print(f'{reg} = {val}', file=file) + except angr.SimConcreteRegisterError: + print(f'Unable to read value of register {reg}: register error', + file=file) + except angr.SimConcreteMemoryError: + print(f'Unable to read value of register {reg}: memory error', + file=file) + except AttributeError: + print(f'Unable to read value of register {reg}: AttributeError', + file=file) + except KeyError: + print(f'Unable to read value of register {reg}: KeyError', + file=file) + +def copy_state(src: angr_targets.ConcreteTarget, dst: angr.SimState): + """Copy a concrete program state to an `angr.SimState` object.""" + # Copy register contents + for reg in x86.regnames: + regname = reg.lower() + try: + dst.regs.__setattr__(regname, src.read_register(regname)) + except angr.SimConcreteRegisterError: + # Register does not exist (i.e. "flag ZF") + pass + + # Copy memory contents + for mapping in src.get_mappings(): + addr = mapping.start_address + size = mapping.end_address - mapping.start_address + try: + dst.memory.store(addr, src.read_memory(addr, size), size) + except angr.SimConcreteMemoryError: + # Invalid memory access + pass + +def symbolize_state(state: angr.SimState): + for reg in x86.regnames: + if reg != 'PC': + symb_val = cp.BVS(reg, 64) + try: + state.regs.__setattr__(reg.lower(), symb_val) + except AttributeError: + pass + +def output_truth(breakpoints: set[int]): + import run + res = run.run_native_execution(BINARY, breakpoints) + with open('truth.log', 'w') as file: + for snapshot in res: + print(cp.BVV(snapshot.regs['PC'], 64), file=file) + +BINARY = "hello-static-musl" +BREAKPOINT_LOG = "emulator-log.txt" + +# Read breakpoint addresses from a file +with open(BREAKPOINT_LOG, "r") as file: + breakpoints = parse_break_addresses(file.readlines()) + +print(f'Found {len(breakpoints)} breakpoints.') + +class ConcreteExecution: + def __init__(self, executable: str, breakpoints: list[int]): + self.target = LLDBConcreteTarget(executable) + self.proj = angr.Project(executable, + concrete_target=self.target, + use_sim_procedures=False) + + # Set the initial state + state = self.proj.factory.entry_state() + state.options.add(angr.options.SYMBION_SYNC_CLE) + state.options.add(angr.options.SYMBION_KEEP_STUBS_ON_SYNC) + self.simgr = self.proj.factory.simgr(state) + self.simgr.use_technique( + angr.exploration_techniques.Symbion(find=breakpoints)) + + def is_running(self): + return not self.target.is_exited() + + def step(self) -> angr.SimState | None: + self.simgr.run() + self.simgr.unstash(to_stash='active', from_stash='found') + if len(self.simgr.active) > 0: + state = self.simgr.active[0] + print(f'-- Concrete execution hit a breakpoint at {state.regs.pc}!') + return state + return None + +class SymbolicExecution: + def __init__(self, executable: str): + self.proj = angr.Project(executable, use_sim_procedures=False) + self.simgr = self.proj.factory.simgr(self.proj.factory.entry_state()) + + def is_running(self): + return len(self.simgr.active) > 0 + + def step(self, find) -> angr.SimState | None: + self.simgr.explore(find=find) + self.simgr.unstash(to_stash='active', from_stash='found') + if len(self.simgr.active) == 0: + print(f'No states found. Stashes: {self.simgr.stashes}') + return None + + state = self.simgr.active[0] + assert(len(self.simgr.active) == 1) + print(f'-- Symbolic execution stopped at {state.regs.pc}!') + print(f' Found the following stashes: {self.simgr.stashes}') + + return state + +output_truth(breakpoints) + +conc = ConcreteExecution(BINARY, list(breakpoints)) +symb = SymbolicExecution(BINARY) + +conc_log = open('concrete.log', 'w') +symb_log = open('symbolic.log', 'w') + +while True: + if not (conc.is_running() and symb.is_running()): + assert(not conc.is_running() and not symb.is_running()) + print(f'Execution has exited.') + exit(0) + + # It seems that we have to copy the program's state manually to the state + # handed to the symbolic engine, otherwise the program emulation is + # incorrect. Something in angr's emulation is scuffed. + copy_state(conc.target, symb.simgr.active[0]) + + # angr performs a sanity check to ensure that the address at which the + # concrete engine stops actually is one of the breakpoints specified by + # the user. This sanity check is faulty because it is performed before the + # user has a chance determine whether the program has exited. If the + # program counter is read after the concrete execution has exited, LLDB + # returns a null value and the check fails, resulting in a crash. This + # try/catch block prevents that. + # + # As of angr commit `cbeace5d7`, this faulty read of the program counter + # can be found at `angr/engines/concrete.py:148`. + try: + conc_state = conc.step() + if conc_state is None: + print(f'Execution has exited: ConcreteExecution.step() returned null.') + exit(0) + except angr.SimConcreteRegisterError: + print(f'Done.') + exit(0) + + pc = conc_state.solver.eval(conc_state.regs.pc) + print(f'-- Trying to find address {hex(pc)} with symbolic execution...') + + # TODO: + #symbolize_state(symb.simgr.active[0]) + symb_state = symb.step(pc) + + # Check exit conditions + if symb_state is None: + print(f'Execution has exited: SymbolicExecution.step() returned null.') + exit(0) + assert(pc == symb_state.solver.eval(symb_state.regs.pc)) + + # Log some stuff + print(f'-- Concrete breakpoint {conc_state.regs.pc}' + f' vs symbolic breakpoint {symb_state.regs.pc}') + + print(conc_state.regs.pc, file=conc_log) + print(symb_state.regs.pc, file=symb_log) |