about summary refs log tree commit diff stats
path: root/run.py
diff options
context:
space:
mode:
Diffstat (limited to 'run.py')
-rw-r--r--[-rwxr-xr-x]run.py100
1 files changed, 27 insertions, 73 deletions
diff --git a/run.py b/run.py
index f1f1060..9b51fb5 100755..100644
--- a/run.py
+++ b/run.py
@@ -1,39 +1,23 @@
-#! /bin/python3
-import os
+"""Functionality to execute native programs and collect snapshots via lldb."""
+
 import re
 import sys
 import lldb
-import shutil
-import argparse
+from typing import Callable
 
+# TODO: The debugger callback is currently specific to a single architexture.
+#       We should make it generic.
+from arch import x86
 from utils import print_separator
 
 verbose = False
 
-regnames = ['PC',
-            'RAX',
-            'RBX',
-            'RCX',
-            'RDX',
-            'RSI',
-            'RDI',
-            'RBP',
-            'RSP',
-            'R8',
-            'R9',
-            'R10',
-            'R11',
-            'R12',
-            'R13',
-            'R14',
-            'R15',
-            'RFLAGS']
-
 class DebuggerCallback:
-    def __init__(self, ostream=sys.stdout, skiplist: set = {}):
+    """At every breakpoint, writes register contents to a stream."""
+
+    def __init__(self, ostream=sys.stdout):
         self.stream = ostream
-        self.regex = re.compile('(' + '|'.join(regnames) + ')$')
-        self.skiplist = skiplist
+        self.regex = re.compile('(' + '|'.join(x86.regnames) + ')$')
 
     @staticmethod
     def parse_flags(flag_reg: int):
@@ -60,7 +44,6 @@ class DebuggerCallback:
         flags['PF'] = int(0 != flag_reg & (1 << 1))
         return flags
 
-
     def print_regs(self, frame):
         for reg in frame.GetRegisters():
             for sub_reg in reg:
@@ -93,11 +76,6 @@ class DebuggerCallback:
     def __call__(self, frame):
         pc = frame.GetPC()
 
-        # Skip this PC
-        if pc in self.skiplist:
-            self.skiplist.discard(pc)
-            return False
-
         print_separator('=', stream=self.stream, count=20)
         print(f'INVOKE PC={hex(pc)}', file=self.stream)
         print_separator('=', stream=self.stream, count=20)
@@ -109,8 +87,6 @@ class DebuggerCallback:
         print("STACK:", file=self.stream)
         self.print_stack(frame, 20)
 
-        return True  # Continue execution
-
 class Debugger:
     def __init__(self, program):
         self.debugger = lldb.SBDebugger.Create()
@@ -131,7 +107,7 @@ class Debugger:
     def get_breakpoints_count(self):
         return self.target.GetNumBreakpoints()
 
-    def execute(self, callback: callable):
+    def execute(self, callback: Callable):
         error = lldb.SBError()
         listener = self.debugger.GetListener()
         process = self.target.Launch(listener, None, None, None, None, None, None, 0,
@@ -169,46 +145,24 @@ class ListWriter:
     def __str__(self):
         return "".join(self.data)
 
-class Runner:
-    def __init__(self, dbt_log: list, oracle_program: str):
-        self.log = dbt_log
-        self.program = oracle_program
-        self.debugger = Debugger(self.program)
-        self.writer = ListWriter()
-
-    @staticmethod
-    def get_addresses(lines: list):
-        addresses = []
-
-        backlist = []
-        backlist_regex = re.compile(r'^\s\s\d*:')
-
-        skiplist = set()
-        for l in lines:
-            if l.startswith('INVOKE'):
-                addresses.append(int(l.split('=')[1].strip(), base=16))
-
-                if addresses[-1] in backlist:
-                    skiplist.add(addresses[-1])
-                backlist = []
-
-            if backlist_regex.match(l):
-                backlist.append(int(l.split()[0].split(':')[0], base=16))
-
-        return set(addresses), skiplist
-
-    def run(self):
-        # Get all addresses to stop at
-        addresses, skiplist = Runner.get_addresses(self.log)
+def run_native_execution(oracle_program: str, breakpoints: set[int]):
+    """Gather snapshots from a native execution via an external debugger.
 
-        # Set breakpoints
-        for address in addresses:
-            self.debugger.set_breakpoint_by_addr(address)
+    :param oracle_program: Program to execute.
+    :param breakpoints: List of addresses at which to break and record the
+                        program's state.
 
-        # Sanity check
-        assert(self.debugger.get_breakpoints_count() == len(addresses))
+    :return: A textual log of the program's execution in arancini's log format.
+    """
+    debugger = Debugger(oracle_program)
+    writer = ListWriter()
 
-        self.debugger.execute(DebuggerCallback(self.writer, skiplist))
+    # Set breakpoints
+    for address in breakpoints:
+        debugger.set_breakpoint_by_addr(address)
+    assert(debugger.get_breakpoints_count() == len(breakpoints))
 
-        return self.writer.data
+    # Execute the native program
+    debugger.execute(DebuggerCallback(writer))
 
+    return writer.data