about summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--.gitignore2
-rw-r--r--arch/__init__.py7
-rw-r--r--arch/arch.py5
-rw-r--r--arch/x86.py2
-rw-r--r--gen_trace.py50
-rw-r--r--lldb_target.py123
-rwxr-xr-xmain.py42
-rw-r--r--run.py102
-rw-r--r--test.py180
9 files changed, 421 insertions, 92 deletions
diff --git a/.gitignore b/.gitignore
index 4586156..94631a0 100644
--- a/.gitignore
+++ b/.gitignore
@@ -9,3 +9,5 @@ build*/
 out-*/
 __pycache__/
 
+# Dev environment
+.gdbinit
diff --git a/arch/__init__.py b/arch/__init__.py
new file mode 100644
index 0000000..4943749
--- /dev/null
+++ b/arch/__init__.py
@@ -0,0 +1,7 @@
+from .arch import Arch
+from . import x86
+
+"""A dictionary containing all supported architectures at their names."""
+supported_architectures: dict[str, Arch] = {
+    "X86": x86.ArchX86(),
+}
diff --git a/arch/arch.py b/arch/arch.py
index 36a4e3f..a46439e 100644
--- a/arch/arch.py
+++ b/arch/arch.py
@@ -1,6 +1,7 @@
 class Arch():
-    def __init__(self, regnames: list[str]):
-        self.regnames = regnames
+    def __init__(self, archname: str, regnames: list[str]):
+        self.archname = archname
+        self.regnames = set(regnames)
 
     def __eq__(self, other):
         return self.regnames == other.regnames
diff --git a/arch/x86.py b/arch/x86.py
index 0f60457..2b27315 100644
--- a/arch/x86.py
+++ b/arch/x86.py
@@ -30,4 +30,4 @@ regnames = ['PC',
 
 class ArchX86(Arch):
     def __init__(self):
-        super().__init__(regnames)
+        super().__init__("X86", regnames)
diff --git a/gen_trace.py b/gen_trace.py
new file mode 100644
index 0000000..64fcf8f
--- /dev/null
+++ b/gen_trace.py
@@ -0,0 +1,50 @@
+import argparse
+import lldb
+import lldb_target
+
+def parse_args():
+    prog = argparse.ArgumentParser()
+    prog.add_argument('binary',
+                      help='The executable to trace.')
+    prog.add_argument('-o', '--output',
+                      default='breakpoints',
+                      type=str,
+                      help='File to which the recorded trace is written.')
+    prog.add_argument('--args',
+                      default=[],
+                      nargs='+',
+                      help='Arguments to the executable.')
+    return prog.parse_args()
+
+def record_trace(binary: str, args: list[str] = []) -> list[int]:
+    # Set up LLDB target
+    target = lldb_target.LLDBConcreteTarget(binary, args)
+
+    # Skip to first instruction in `main`
+    result = lldb.SBCommandReturnObject()
+    break_at_main = f'b -b main -s {target.module.GetFileSpec().GetFilename()}'
+    target.interpreter.HandleCommand(break_at_main, result)
+    target.run()
+
+    # Run until main function is exited
+    trace = []
+    while not target.is_exited():
+        thread = target.process.GetThreadAtIndex(0)
+        func_names = [thread.GetFrameAtIndex(i).GetFunctionName() for i in range(0, thread.GetNumFrames())]
+        if 'main' not in func_names:
+            break
+        trace.append(target.read_register('pc'))
+        thread.StepInstruction(False)
+
+    return trace
+
+def main():
+    args = parse_args()
+    trace = record_trace(args.binary, args.args)
+    with open(args.output, 'w') as file:
+        for addr in trace:
+            print(hex(addr), file=file)
+    print(f'Generated a trace of {len(trace)} instructions.')
+
+if __name__ == '__main__':
+    main()
diff --git a/lldb_target.py b/lldb_target.py
new file mode 100644
index 0000000..1ff9f53
--- /dev/null
+++ b/lldb_target.py
@@ -0,0 +1,123 @@
+import lldb
+
+from angr.errors import SimConcreteMemoryError, \
+                        SimConcreteRegisterError
+from angr_targets.concrete import ConcreteTarget
+from angr_targets.memory_map import MemoryMap
+
+class LLDBConcreteTarget(ConcreteTarget):
+    def __init__(self, executable: str, args: list[str] = []):
+        # Prepend the executable's path to argv, as is convention
+        args.insert(0, executable)
+
+        self.debugger = lldb.SBDebugger.Create()
+        self.debugger.SetAsync(False)
+        self.target = self.debugger.CreateTargetWithFileAndArch(executable,
+                                                                lldb.LLDB_ARCH_DEFAULT)
+        self.module = self.target.FindModule(self.target.GetExecutable())
+        self.interpreter = self.debugger.GetCommandInterpreter()
+
+        # Set up objects for process execution
+        self.error = lldb.SBError()
+        self.listener = self.debugger.GetListener()
+        self.process = self.target.Launch(self.listener,
+                                          args, None, None,
+                                          None, None, None, 0,
+                                          True, self.error)
+        if not self.process.IsValid():
+            raise RuntimeError(f'[In LLDBConcreteTarget.__init__]: Failed to'
+                               f' launch process.')
+
+    def set_breakpoint(self, addr, **kwargs):
+        command = f'b -a {addr} -s {self.module.GetFileSpec().GetFilename()}'
+        result = lldb.SBCommandReturnObject()
+        self.interpreter.HandleCommand(command, result)
+
+    def remove_breakpoint(self, addr, **kwargs):
+        command = f'breakpoint delete {addr}'
+        result = lldb.SBCommandReturnObject()
+        self.interpreter.HandleCommand(command, result)
+
+    def is_running(self):
+        return self.process.GetState() == lldb.eStateRunning
+
+    def is_exited(self):
+        """Not part of the angr interface, but much more useful than
+        `is_running`.
+
+        :return: True if the process has exited. False otherwise.
+        """
+        return self.process.GetState() == lldb.eStateExited
+
+    def wait_for_running(self):
+        while self.process.GetState() != lldb.eStateRunning:
+            pass
+
+    def wait_for_halt(self):
+        while self.process.GetState() != lldb.eStateStopped:
+            pass
+
+    def run(self):
+        state = self.process.GetState()
+        if state == lldb.eStateExited:
+            raise RuntimeError(f'Tried to resume process execution, but the'
+                               f' process has already exited.')
+        assert(state == lldb.eStateStopped)
+        self.process.Continue()
+
+    def stop(self):
+        self.process.Stop()
+
+    def exit(self):
+        self.debugger.Terminate()
+        print(f'Program exited with status {self.process.GetState()}')
+
+    def read_register(self, regname: str) -> int:
+        frame = self.process.GetThreadAtIndex(0).GetFrameAtIndex(0)
+        reg = frame.FindRegister(regname)
+        if reg is None:
+            raise SimConcreteRegisterError(
+                f'[In LLDBConcreteTarget.read_register]: Register {regname}'
+                f' not found.')
+
+        val = reg.GetValue()
+        if val is None:
+            raise SimConcreteRegisterError(
+                f'[In LLDBConcreteTarget.read_register]: Register has an'
+                f' invalid value of {val}.')
+
+        return int(val, 16)
+
+    def read_memory(self, addr, size):
+        err = lldb.SBError()
+        content = self.process.ReadMemory(addr, size, err)
+        if not err.success:
+            raise SimConcreteMemoryError(f'Error when reading {size} bytes at'
+                                         f' address {hex(addr)}: {err}')
+        return content
+
+    def write_memory(self, addr, value):
+        err = lldb.SBError()
+        res = self.process.WriteMemory(addr, value, err)
+        if not err.success or res != len(value):
+            raise SimConcreteMemoryError(f'Error when writing to address'
+                                         f' {hex(addr)}: {err}')
+
+    def get_mappings(self):
+        mmap = []
+
+        region_list = self.process.GetMemoryRegions()
+        for i in range(region_list.GetSize()):
+            region = lldb.SBMemoryRegionInfo()
+            region_list.GetMemoryRegionAtIndex(i, region)
+
+            perms = f'{"r" if region.IsReadable() else "-"}' \
+                    f'{"w" if region.IsWritable() else "-"}' \
+                    f'{"x" if region.IsExecutable() else "-"}' \
+
+            mmap.append(MemoryMap(region.GetRegionBase(),
+                                  region.GetRegionEnd(),
+                                  0,             # offset?
+                                  "<no-name>",   # name?
+                                  perms))
+        return mmap
diff --git a/main.py b/main.py
index d97a54d..9451e42 100755
--- a/main.py
+++ b/main.py
@@ -8,21 +8,25 @@ from compare import compare_simple
 from run import run_native_execution
 from utils import check_version, print_separator
 
-def read_logs(txl_path, native_path, program):
+def parse_inputs(txl_path, ref_path, program):
+    # Our architecture
+    arch = x86.ArchX86()
+
     txl = []
     with open(txl_path, "r") as txl_file:
-        txl = txl_file.readlines()
+        txl = arancini.parse(txl_file.readlines(), arch)
 
-    native = []
+    ref = []
     if program is not None:
-        breakpoints = arancini.parse_break_addresses(txl)
-        native = run_native_execution(program, breakpoints)
+        with open(txl_path, "r") as txl_file:
+            breakpoints = arancini.parse_break_addresses(txl_file.readlines())
+        ref = run_native_execution(program, breakpoints)
     else:
-        assert(native_path is not None)
-        with open(native_path, "r") as native_file:
-            native = native_file.readlines()
+        assert(ref_path is not None)
+        with open(ref_path, "r") as native_file:
+            ref = arancini.parse(native_file.readlines(), arch)
 
-    return txl, native
+    return txl, ref
 
 def parse_arguments():
     parser = argparse.ArgumentParser(description='Comparator for emulator logs to reference')
@@ -57,35 +61,31 @@ def main():
     args = parse_arguments()
 
     txl_path = args.txl
-    native_path = args.ref
+    reference_path = args.ref
     program = args.program
 
     stats = args.stats
     verbose = args.verbose
     progressive = args.progressive
 
-    # Our architexture
-    arch = x86.ArchX86()
-
     if verbose:
         print("Enabling verbose program output")
         print(f"Verbose: {verbose}")
         print(f"Statistics: {stats}")
         print(f"Progressive: {progressive}")
 
-    if program is None and native_path is None:
+    if program is None and reference_path is None:
         raise ValueError('Either program or path to native file must be'
                          'provided')
 
-    txl, native = read_logs(txl_path, native_path, program)
+    txl, ref = parse_inputs(txl_path, reference_path, program)
 
-    if program != None and native_path != None:
-        with open(native_path, 'w') as w:
-            w.write(''.join(native))
+    if program != None and reference_path != None:
+        with open(reference_path, 'w') as w:
+            for snapshot in ref:
+                print(snapshot, file=w)
 
-    txl = arancini.parse(txl, arch)
-    native = arancini.parse(native, arch)
-    result = compare_simple(txl, native)
+    result = compare_simple(txl, ref)
 
     # Print results
     for res in result:
diff --git a/run.py b/run.py
index 9b51fb5..6aca4d2 100644
--- a/run.py
+++ b/run.py
@@ -1,23 +1,24 @@
 """Functionality to execute native programs and collect snapshots via lldb."""
 
-import re
+import platform
 import sys
 import lldb
 from typing import Callable
 
-# TODO: The debugger callback is currently specific to a single architexture.
+# TODO: The debugger callback is currently specific to a single architecture.
 #       We should make it generic.
-from arch import x86
-from utils import print_separator
+from arch import Arch, x86
+from snapshot import ProgramState
 
-verbose = False
+class SnapshotBuilder:
+    """At every breakpoint, writes register contents to a stream.
 
-class DebuggerCallback:
-    """At every breakpoint, writes register contents to a stream."""
-
-    def __init__(self, ostream=sys.stdout):
-        self.stream = ostream
-        self.regex = re.compile('(' + '|'.join(x86.regnames) + ')$')
+    Generated snapshots are stored in and can be read from `self.states`.
+    """
+    def __init__(self, arch: Arch):
+        self.arch = arch
+        self.states = []
+        self.regnames = set(arch.regnames)
 
     @staticmethod
     def parse_flags(flag_reg: int):
@@ -44,48 +45,26 @@ class DebuggerCallback:
         flags['PF'] = int(0 != flag_reg & (1 << 1))
         return flags
 
-    def print_regs(self, frame):
+    def create_snapshot(self, frame):
+        state = ProgramState(self.arch)
+        state.set('PC', frame.GetPC())
         for reg in frame.GetRegisters():
             for sub_reg in reg:
-                match = self.regex.match(sub_reg.GetName().upper())
-                if match and match.group() == 'RFLAGS':
-                    flags = DebuggerCallback.parse_flags(int(sub_reg.GetValue(),
-                                                             base=16))
-                    for flag in flags:
-                        print(f'flag {flag}:\t{hex(flags[flag])}',
-                                                     file=self.stream)
-                elif match:
-                    print(f"{sub_reg.GetName().upper()}:\t\t {hex(int(sub_reg.GetValue(), base=16))}",
-                          file=self.stream)
-
-    def print_stack(self, frame, element_count: int):
-        first = True
-        for i in range(element_count):
-            addr = frame.GetSP() + i * frame.GetThread().GetProcess().GetAddressByteSize()
-            error = lldb.SBError()
-            stack_value = int(frame.GetThread().GetProcess().ReadPointerFromMemory(addr, error))
-            if error.Success() and not first:
-                print(f'{hex(stack_value)}', file=self.stream)
-            elif error.Success():
-                print(f'{hex(stack_value)}\t\t<- rsp', file=self.stream)
-            else:
-                print(f"Error reading memory at address 0x{addr:x}",
-                      file=self.stream)
-            first=False
+                # Set the register's value in the current snapshot
+                regname = sub_reg.GetName().upper()
+                if regname in self.regnames:
+                    regval = int(sub_reg.GetValue(), base=16)
+                    if regname == 'RFLAGS':
+                        flags = SnapshotBuilder.parse_flags(regval)
+                        for flag, val in flags.items():
+                            state.set(f'flag {flag}', val)
+                    else:
+                        state.set(regname, regval)
+        return state
 
     def __call__(self, frame):
-        pc = frame.GetPC()
-
-        print_separator('=', stream=self.stream, count=20)
-        print(f'INVOKE PC={hex(pc)}', file=self.stream)
-        print_separator('=', stream=self.stream, count=20)
-
-        print("Register values:", file=self.stream)
-        self.print_regs(frame)
-        print_separator(stream=self.stream)
-
-        print("STACK:", file=self.stream)
-        self.print_stack(frame, 20)
+        snapshot = self.create_snapshot(frame)
+        self.states.append(snapshot)
 
 class Debugger:
     def __init__(self, program):
@@ -101,9 +80,6 @@ class Debugger:
         result = lldb.SBCommandReturnObject()
         self.interpreter.HandleCommand(command, result)
 
-        if verbose:
-            print(f'Set breakpoint at address {hex(address)}')
-
     def get_breakpoints_count(self):
         return self.target.GetNumBreakpoints()
 
@@ -128,23 +104,11 @@ class Debugger:
             if state == lldb.eStateExited:
                 break
 
-        self.debugger.Terminate()
-
         print(f'Process state: {process.GetState()}')
         print('Program output:')
         print(process.GetSTDOUT(1024))
         print(process.GetSTDERR(1024))
 
-class ListWriter:
-    def __init__(self):
-        self.data = []
-
-    def write(self, s):
-        self.data.append(s)
-
-    def __str__(self):
-        return "".join(self.data)
-
 def run_native_execution(oracle_program: str, breakpoints: set[int]):
     """Gather snapshots from a native execution via an external debugger.
 
@@ -152,10 +116,11 @@ def run_native_execution(oracle_program: str, breakpoints: set[int]):
     :param breakpoints: List of addresses at which to break and record the
                         program's state.
 
-    :return: A textual log of the program's execution in arancini's log format.
+    :return: A list of snapshots gathered from the execution.
     """
+    assert(platform.machine() == "x86_64")
+
     debugger = Debugger(oracle_program)
-    writer = ListWriter()
 
     # Set breakpoints
     for address in breakpoints:
@@ -163,6 +128,7 @@ def run_native_execution(oracle_program: str, breakpoints: set[int]):
     assert(debugger.get_breakpoints_count() == len(breakpoints))
 
     # Execute the native program
-    debugger.execute(DebuggerCallback(writer))
+    builder = SnapshotBuilder(x86.ArchX86())
+    debugger.execute(builder)
 
-    return writer.data
+    return builder.states
diff --git a/test.py b/test.py
new file mode 100644
index 0000000..72a1438
--- /dev/null
+++ b/test.py
@@ -0,0 +1,180 @@
+import angr
+import angr_targets
+import claripy as cp
+import sys
+
+from lldb_target import LLDBConcreteTarget
+
+from arancini import parse_break_addresses
+from arch import x86
+
+def print_state(state, file=sys.stdout):
+    for reg in x86.regnames:
+        try:
+            val = state.regs.__getattr__(reg.lower())
+            print(f'{reg} = {val}', file=file)
+        except angr.SimConcreteRegisterError:
+            print(f'Unable to read value of register {reg}: register error',
+                  file=file)
+        except angr.SimConcreteMemoryError:
+            print(f'Unable to read value of register {reg}: memory error',
+                  file=file)
+        except AttributeError:
+            print(f'Unable to read value of register {reg}: AttributeError',
+                  file=file)
+        except KeyError:
+            print(f'Unable to read value of register {reg}: KeyError',
+                  file=file)
+
+def copy_state(src: angr_targets.ConcreteTarget, dst: angr.SimState):
+    """Copy a concrete program state to an `angr.SimState` object."""
+    # Copy register contents
+    for reg in x86.regnames:
+        regname = reg.lower()
+        try:
+            dst.regs.__setattr__(regname, src.read_register(regname))
+        except angr.SimConcreteRegisterError:
+            # Register does not exist (i.e. "flag ZF")
+            pass
+
+    # Copy memory contents
+    for mapping in src.get_mappings():
+        addr = mapping.start_address
+        size = mapping.end_address - mapping.start_address
+        try:
+            dst.memory.store(addr, src.read_memory(addr, size), size)
+        except angr.SimConcreteMemoryError:
+            # Invalid memory access
+            pass
+
+def symbolize_state(state: angr.SimState):
+    for reg in x86.regnames:
+        if reg != 'PC':
+            symb_val = cp.BVS(reg, 64)
+            try:
+                state.regs.__setattr__(reg.lower(), symb_val)
+            except AttributeError:
+                pass
+
+def output_truth(breakpoints: set[int]):
+    import run
+    res = run.run_native_execution(BINARY, breakpoints)
+    with open('truth.log', 'w') as file:
+        for snapshot in res:
+            print(cp.BVV(snapshot.regs['PC'], 64), file=file)
+
+BINARY = "hello-static-musl"
+BREAKPOINT_LOG = "emulator-log.txt"
+
+# Read breakpoint addresses from a file
+with open(BREAKPOINT_LOG, "r") as file:
+    breakpoints = parse_break_addresses(file.readlines())
+
+print(f'Found {len(breakpoints)} breakpoints.')
+
+class ConcreteExecution:
+    def __init__(self, executable: str, breakpoints: list[int]):
+        self.target = LLDBConcreteTarget(executable)
+        self.proj = angr.Project(executable,
+                                 concrete_target=self.target,
+                                 use_sim_procedures=False)
+
+        # Set the initial state
+        state = self.proj.factory.entry_state()
+        state.options.add(angr.options.SYMBION_SYNC_CLE)
+        state.options.add(angr.options.SYMBION_KEEP_STUBS_ON_SYNC)
+        self.simgr = self.proj.factory.simgr(state)
+        self.simgr.use_technique(
+            angr.exploration_techniques.Symbion(find=breakpoints))
+
+    def is_running(self):
+        return not self.target.is_exited()
+
+    def step(self) -> angr.SimState | None:
+        self.simgr.run()
+        self.simgr.unstash(to_stash='active', from_stash='found')
+        if len(self.simgr.active) > 0:
+            state = self.simgr.active[0]
+            print(f'-- Concrete execution hit a breakpoint at {state.regs.pc}!')
+            return state
+        return None
+
+class SymbolicExecution:
+    def __init__(self, executable: str):
+        self.proj = angr.Project(executable, use_sim_procedures=False)
+        self.simgr = self.proj.factory.simgr(self.proj.factory.entry_state())
+
+    def is_running(self):
+        return len(self.simgr.active) > 0
+
+    def step(self, find) -> angr.SimState | None:
+        self.simgr.explore(find=find)
+        self.simgr.unstash(to_stash='active', from_stash='found')
+        if len(self.simgr.active) == 0:
+            print(f'No states found. Stashes: {self.simgr.stashes}')
+            return None
+
+        state = self.simgr.active[0]
+        assert(len(self.simgr.active) == 1)
+        print(f'-- Symbolic execution stopped at {state.regs.pc}!')
+        print(f'   Found the following stashes: {self.simgr.stashes}')
+
+        return state
+
+output_truth(breakpoints)
+
+conc = ConcreteExecution(BINARY, list(breakpoints))
+symb = SymbolicExecution(BINARY)
+
+conc_log = open('concrete.log', 'w')
+symb_log = open('symbolic.log', 'w')
+
+while True:
+    if not (conc.is_running() and symb.is_running()):
+        assert(not conc.is_running() and not symb.is_running())
+        print(f'Execution has exited.')
+        exit(0)
+
+    # It seems that we have to copy the program's state manually to the state
+    # handed to the symbolic engine, otherwise the program emulation is
+    # incorrect. Something in angr's emulation is scuffed.
+    copy_state(conc.target, symb.simgr.active[0])
+
+    # angr performs a sanity check to ensure that the address at which the
+    # concrete engine stops actually is one of the breakpoints specified by
+    # the user. This sanity check is faulty because it is performed before the
+    # user has a chance determine whether the program has exited. If the
+    # program counter is read after the concrete execution has exited, LLDB
+    # returns a null value and the check fails, resulting in a crash. This
+    # try/catch block prevents that.
+    #
+    # As of angr commit `cbeace5d7`, this faulty read of the program counter
+    # can be found at `angr/engines/concrete.py:148`.
+    try:
+        conc_state = conc.step()
+        if conc_state is None:
+            print(f'Execution has exited: ConcreteExecution.step() returned null.')
+            exit(0)
+    except angr.SimConcreteRegisterError:
+        print(f'Done.')
+        exit(0)
+
+    pc = conc_state.solver.eval(conc_state.regs.pc)
+    print(f'-- Trying to find address {hex(pc)} with symbolic execution...')
+
+    # TODO:
+    #symbolize_state(symb.simgr.active[0])
+    symb_state = symb.step(pc)
+
+    # Check exit conditions
+    if symb_state is None:
+        print(f'Execution has exited: SymbolicExecution.step() returned null.')
+        exit(0)
+    assert(pc == symb_state.solver.eval(symb_state.regs.pc))
+
+    # Log some stuff
+    print(f'-- Concrete breakpoint {conc_state.regs.pc}'
+          f' vs symbolic breakpoint {symb_state.regs.pc}')
+
+    print(conc_state.regs.pc, file=conc_log)
+    print(symb_state.regs.pc, file=symb_log)