diff options
Diffstat (limited to 'run.py')
| -rw-r--r-- | run.py | 102 |
1 files changed, 34 insertions, 68 deletions
diff --git a/run.py b/run.py index 9b51fb5..6aca4d2 100644 --- a/run.py +++ b/run.py @@ -1,23 +1,24 @@ """Functionality to execute native programs and collect snapshots via lldb.""" -import re +import platform import sys import lldb from typing import Callable -# TODO: The debugger callback is currently specific to a single architexture. +# TODO: The debugger callback is currently specific to a single architecture. # We should make it generic. -from arch import x86 -from utils import print_separator +from arch import Arch, x86 +from snapshot import ProgramState -verbose = False +class SnapshotBuilder: + """At every breakpoint, writes register contents to a stream. -class DebuggerCallback: - """At every breakpoint, writes register contents to a stream.""" - - def __init__(self, ostream=sys.stdout): - self.stream = ostream - self.regex = re.compile('(' + '|'.join(x86.regnames) + ')$') + Generated snapshots are stored in and can be read from `self.states`. + """ + def __init__(self, arch: Arch): + self.arch = arch + self.states = [] + self.regnames = set(arch.regnames) @staticmethod def parse_flags(flag_reg: int): @@ -44,48 +45,26 @@ class DebuggerCallback: flags['PF'] = int(0 != flag_reg & (1 << 1)) return flags - def print_regs(self, frame): + def create_snapshot(self, frame): + state = ProgramState(self.arch) + state.set('PC', frame.GetPC()) for reg in frame.GetRegisters(): for sub_reg in reg: - match = self.regex.match(sub_reg.GetName().upper()) - if match and match.group() == 'RFLAGS': - flags = DebuggerCallback.parse_flags(int(sub_reg.GetValue(), - base=16)) - for flag in flags: - print(f'flag {flag}:\t{hex(flags[flag])}', - file=self.stream) - elif match: - print(f"{sub_reg.GetName().upper()}:\t\t {hex(int(sub_reg.GetValue(), base=16))}", - file=self.stream) - - def print_stack(self, frame, element_count: int): - first = True - for i in range(element_count): - addr = frame.GetSP() + i * frame.GetThread().GetProcess().GetAddressByteSize() - error = lldb.SBError() - stack_value = int(frame.GetThread().GetProcess().ReadPointerFromMemory(addr, error)) - if error.Success() and not first: - print(f'{hex(stack_value)}', file=self.stream) - elif error.Success(): - print(f'{hex(stack_value)}\t\t<- rsp', file=self.stream) - else: - print(f"Error reading memory at address 0x{addr:x}", - file=self.stream) - first=False + # Set the register's value in the current snapshot + regname = sub_reg.GetName().upper() + if regname in self.regnames: + regval = int(sub_reg.GetValue(), base=16) + if regname == 'RFLAGS': + flags = SnapshotBuilder.parse_flags(regval) + for flag, val in flags.items(): + state.set(f'flag {flag}', val) + else: + state.set(regname, regval) + return state def __call__(self, frame): - pc = frame.GetPC() - - print_separator('=', stream=self.stream, count=20) - print(f'INVOKE PC={hex(pc)}', file=self.stream) - print_separator('=', stream=self.stream, count=20) - - print("Register values:", file=self.stream) - self.print_regs(frame) - print_separator(stream=self.stream) - - print("STACK:", file=self.stream) - self.print_stack(frame, 20) + snapshot = self.create_snapshot(frame) + self.states.append(snapshot) class Debugger: def __init__(self, program): @@ -101,9 +80,6 @@ class Debugger: result = lldb.SBCommandReturnObject() self.interpreter.HandleCommand(command, result) - if verbose: - print(f'Set breakpoint at address {hex(address)}') - def get_breakpoints_count(self): return self.target.GetNumBreakpoints() @@ -128,23 +104,11 @@ class Debugger: if state == lldb.eStateExited: break - self.debugger.Terminate() - print(f'Process state: {process.GetState()}') print('Program output:') print(process.GetSTDOUT(1024)) print(process.GetSTDERR(1024)) -class ListWriter: - def __init__(self): - self.data = [] - - def write(self, s): - self.data.append(s) - - def __str__(self): - return "".join(self.data) - def run_native_execution(oracle_program: str, breakpoints: set[int]): """Gather snapshots from a native execution via an external debugger. @@ -152,10 +116,11 @@ def run_native_execution(oracle_program: str, breakpoints: set[int]): :param breakpoints: List of addresses at which to break and record the program's state. - :return: A textual log of the program's execution in arancini's log format. + :return: A list of snapshots gathered from the execution. """ + assert(platform.machine() == "x86_64") + debugger = Debugger(oracle_program) - writer = ListWriter() # Set breakpoints for address in breakpoints: @@ -163,6 +128,7 @@ def run_native_execution(oracle_program: str, breakpoints: set[int]): assert(debugger.get_breakpoints_count() == len(breakpoints)) # Execute the native program - debugger.execute(DebuggerCallback(writer)) + builder = SnapshotBuilder(x86.ArchX86()) + debugger.execute(builder) - return writer.data + return builder.states |