about summary refs log tree commit diff stats
path: root/run.py
diff options
context:
space:
mode:
authorTheofilos Augoustis <theofilos.augoustis@gmail.com>2023-10-21 16:39:49 +0200
committerTheofilos Augoustis <theofilos.augoustis@gmail.com>2023-10-21 16:39:49 +0200
commit6f01367f9c8ad4c3d641cc63dbb1a3977ff4ec56 (patch)
tree5b9677b9a5cca449497cea4418b2bb10e2ab0509 /run.py
parent83d4b4dbe6f20c2fa7865e4888b89e888d3509f9 (diff)
downloadfocaccia-6f01367f9c8ad4c3d641cc63dbb1a3977ff4ec56.tar.gz
focaccia-6f01367f9c8ad4c3d641cc63dbb1a3977ff4ec56.zip
Support for testing concrete and emulated execution with angr
Diffstat (limited to 'run.py')
-rw-r--r--run.py102
1 files changed, 34 insertions, 68 deletions
diff --git a/run.py b/run.py
index 9b51fb5..6aca4d2 100644
--- a/run.py
+++ b/run.py
@@ -1,23 +1,24 @@
 """Functionality to execute native programs and collect snapshots via lldb."""
 
-import re
+import platform
 import sys
 import lldb
 from typing import Callable
 
-# TODO: The debugger callback is currently specific to a single architexture.
+# TODO: The debugger callback is currently specific to a single architecture.
 #       We should make it generic.
-from arch import x86
-from utils import print_separator
+from arch import Arch, x86
+from snapshot import ProgramState
 
-verbose = False
+class SnapshotBuilder:
+    """At every breakpoint, writes register contents to a stream.
 
-class DebuggerCallback:
-    """At every breakpoint, writes register contents to a stream."""
-
-    def __init__(self, ostream=sys.stdout):
-        self.stream = ostream
-        self.regex = re.compile('(' + '|'.join(x86.regnames) + ')$')
+    Generated snapshots are stored in and can be read from `self.states`.
+    """
+    def __init__(self, arch: Arch):
+        self.arch = arch
+        self.states = []
+        self.regnames = set(arch.regnames)
 
     @staticmethod
     def parse_flags(flag_reg: int):
@@ -44,48 +45,26 @@ class DebuggerCallback:
         flags['PF'] = int(0 != flag_reg & (1 << 1))
         return flags
 
-    def print_regs(self, frame):
+    def create_snapshot(self, frame):
+        state = ProgramState(self.arch)
+        state.set('PC', frame.GetPC())
         for reg in frame.GetRegisters():
             for sub_reg in reg:
-                match = self.regex.match(sub_reg.GetName().upper())
-                if match and match.group() == 'RFLAGS':
-                    flags = DebuggerCallback.parse_flags(int(sub_reg.GetValue(),
-                                                             base=16))
-                    for flag in flags:
-                        print(f'flag {flag}:\t{hex(flags[flag])}',
-                                                     file=self.stream)
-                elif match:
-                    print(f"{sub_reg.GetName().upper()}:\t\t {hex(int(sub_reg.GetValue(), base=16))}",
-                          file=self.stream)
-
-    def print_stack(self, frame, element_count: int):
-        first = True
-        for i in range(element_count):
-            addr = frame.GetSP() + i * frame.GetThread().GetProcess().GetAddressByteSize()
-            error = lldb.SBError()
-            stack_value = int(frame.GetThread().GetProcess().ReadPointerFromMemory(addr, error))
-            if error.Success() and not first:
-                print(f'{hex(stack_value)}', file=self.stream)
-            elif error.Success():
-                print(f'{hex(stack_value)}\t\t<- rsp', file=self.stream)
-            else:
-                print(f"Error reading memory at address 0x{addr:x}",
-                      file=self.stream)
-            first=False
+                # Set the register's value in the current snapshot
+                regname = sub_reg.GetName().upper()
+                if regname in self.regnames:
+                    regval = int(sub_reg.GetValue(), base=16)
+                    if regname == 'RFLAGS':
+                        flags = SnapshotBuilder.parse_flags(regval)
+                        for flag, val in flags.items():
+                            state.set(f'flag {flag}', val)
+                    else:
+                        state.set(regname, regval)
+        return state
 
     def __call__(self, frame):
-        pc = frame.GetPC()
-
-        print_separator('=', stream=self.stream, count=20)
-        print(f'INVOKE PC={hex(pc)}', file=self.stream)
-        print_separator('=', stream=self.stream, count=20)
-
-        print("Register values:", file=self.stream)
-        self.print_regs(frame)
-        print_separator(stream=self.stream)
-
-        print("STACK:", file=self.stream)
-        self.print_stack(frame, 20)
+        snapshot = self.create_snapshot(frame)
+        self.states.append(snapshot)
 
 class Debugger:
     def __init__(self, program):
@@ -101,9 +80,6 @@ class Debugger:
         result = lldb.SBCommandReturnObject()
         self.interpreter.HandleCommand(command, result)
 
-        if verbose:
-            print(f'Set breakpoint at address {hex(address)}')
-
     def get_breakpoints_count(self):
         return self.target.GetNumBreakpoints()
 
@@ -128,23 +104,11 @@ class Debugger:
             if state == lldb.eStateExited:
                 break
 
-        self.debugger.Terminate()
-
         print(f'Process state: {process.GetState()}')
         print('Program output:')
         print(process.GetSTDOUT(1024))
         print(process.GetSTDERR(1024))
 
-class ListWriter:
-    def __init__(self):
-        self.data = []
-
-    def write(self, s):
-        self.data.append(s)
-
-    def __str__(self):
-        return "".join(self.data)
-
 def run_native_execution(oracle_program: str, breakpoints: set[int]):
     """Gather snapshots from a native execution via an external debugger.
 
@@ -152,10 +116,11 @@ def run_native_execution(oracle_program: str, breakpoints: set[int]):
     :param breakpoints: List of addresses at which to break and record the
                         program's state.
 
-    :return: A textual log of the program's execution in arancini's log format.
+    :return: A list of snapshots gathered from the execution.
     """
+    assert(platform.machine() == "x86_64")
+
     debugger = Debugger(oracle_program)
-    writer = ListWriter()
 
     # Set breakpoints
     for address in breakpoints:
@@ -163,6 +128,7 @@ def run_native_execution(oracle_program: str, breakpoints: set[int]):
     assert(debugger.get_breakpoints_count() == len(breakpoints))
 
     # Execute the native program
-    debugger.execute(DebuggerCallback(writer))
+    builder = SnapshotBuilder(x86.ArchX86())
+    debugger.execute(builder)
 
-    return writer.data
+    return builder.states