about summary refs log tree commit diff stats
path: root/run.py
diff options
context:
space:
mode:
Diffstat (limited to 'run.py')
-rwxr-xr-xrun.py214
1 files changed, 214 insertions, 0 deletions
diff --git a/run.py b/run.py
new file mode 100755
index 0000000..f1f1060
--- /dev/null
+++ b/run.py
@@ -0,0 +1,214 @@
+#! /bin/python3
+import os
+import re
+import sys
+import lldb
+import shutil
+import argparse
+
+from utils import print_separator
+
+verbose = False
+
+regnames = ['PC',
+            'RAX',
+            'RBX',
+            'RCX',
+            'RDX',
+            'RSI',
+            'RDI',
+            'RBP',
+            'RSP',
+            'R8',
+            'R9',
+            'R10',
+            'R11',
+            'R12',
+            'R13',
+            'R14',
+            'R15',
+            'RFLAGS']
+
+class DebuggerCallback:
+    def __init__(self, ostream=sys.stdout, skiplist: set = {}):
+        self.stream = ostream
+        self.regex = re.compile('(' + '|'.join(regnames) + ')$')
+        self.skiplist = skiplist
+
+    @staticmethod
+    def parse_flags(flag_reg: int):
+        flags = {'ZF': 0,
+                 'CF': 0,
+                 'OF': 0,
+                 'SF': 0,
+                 'PF': 0,
+                 'DF': 0}
+
+        # CF (Carry flag) Bit 0
+        # PF (Parity flag) Bit 2
+        # ZF (Zero flag) Bit 6
+        # SF (Sign flag) Bit 7
+        # TF (Trap flag) Bit 8
+        # IF (Interrupt enable flag) Bit 9
+        # DF (Direction flag) Bit 10
+        # OF (Overflow flag) Bit 11
+        flags['CF'] = int(0 != flag_reg & 1)
+        flags['ZF'] = int(0 != flag_reg & (1 << 6))
+        flags['OF'] = int(0 != flag_reg & (1 << 11))
+        flags['SF'] = int(0 != flag_reg & (1 << 7))
+        flags['DF'] = int(0 != flag_reg & (1 << 10))
+        flags['PF'] = int(0 != flag_reg & (1 << 1))
+        return flags
+
+
+    def print_regs(self, frame):
+        for reg in frame.GetRegisters():
+            for sub_reg in reg:
+                match = self.regex.match(sub_reg.GetName().upper())
+                if match and match.group() == 'RFLAGS':
+                    flags = DebuggerCallback.parse_flags(int(sub_reg.GetValue(),
+                                                             base=16))
+                    for flag in flags:
+                        print(f'flag {flag}:\t{hex(flags[flag])}',
+                                                     file=self.stream)
+                elif match:
+                    print(f"{sub_reg.GetName().upper()}:\t\t {hex(int(sub_reg.GetValue(), base=16))}",
+                          file=self.stream)
+
+    def print_stack(self, frame, element_count: int):
+        first = True
+        for i in range(element_count):
+            addr = frame.GetSP() + i * frame.GetThread().GetProcess().GetAddressByteSize()
+            error = lldb.SBError()
+            stack_value = int(frame.GetThread().GetProcess().ReadPointerFromMemory(addr, error))
+            if error.Success() and not first:
+                print(f'{hex(stack_value)}', file=self.stream)
+            elif error.Success():
+                print(f'{hex(stack_value)}\t\t<- rsp', file=self.stream)
+            else:
+                print(f"Error reading memory at address 0x{addr:x}",
+                      file=self.stream)
+            first=False
+
+    def __call__(self, frame):
+        pc = frame.GetPC()
+
+        # Skip this PC
+        if pc in self.skiplist:
+            self.skiplist.discard(pc)
+            return False
+
+        print_separator('=', stream=self.stream, count=20)
+        print(f'INVOKE PC={hex(pc)}', file=self.stream)
+        print_separator('=', stream=self.stream, count=20)
+
+        print("Register values:", file=self.stream)
+        self.print_regs(frame)
+        print_separator(stream=self.stream)
+
+        print("STACK:", file=self.stream)
+        self.print_stack(frame, 20)
+
+        return True  # Continue execution
+
+class Debugger:
+    def __init__(self, program):
+        self.debugger = lldb.SBDebugger.Create()
+        self.debugger.SetAsync(False)
+        self.target = self.debugger.CreateTargetWithFileAndArch(program,
+                                                                lldb.LLDB_ARCH_DEFAULT)
+        self.module = self.target.FindModule(self.target.GetExecutable())
+        self.interpreter = self.debugger.GetCommandInterpreter()
+
+    def set_breakpoint_by_addr(self, address: int):
+        command = f"b -a {address} -s {self.module.GetFileSpec().GetFilename()}"
+        result = lldb.SBCommandReturnObject()
+        self.interpreter.HandleCommand(command, result)
+
+        if verbose:
+            print(f'Set breakpoint at address {hex(address)}')
+
+    def get_breakpoints_count(self):
+        return self.target.GetNumBreakpoints()
+
+    def execute(self, callback: callable):
+        error = lldb.SBError()
+        listener = self.debugger.GetListener()
+        process = self.target.Launch(listener, None, None, None, None, None, None, 0,
+                                     True, error)
+
+        # Check if the process has launched successfully
+        if process.IsValid():
+            print(f'Launched process: {process}')
+        else:
+            print('Failed to launch process', file=sys.stderr)
+
+        while True:
+            state = process.GetState()
+            if state == lldb.eStateStopped:
+                 for thread in process:
+                    callback(thread.GetFrameAtIndex(0))
+                 process.Continue()
+            if state == lldb.eStateExited:
+                break
+
+        self.debugger.Terminate()
+
+        print(f'Process state: {process.GetState()}')
+        print('Program output:')
+        print(process.GetSTDOUT(1024))
+        print(process.GetSTDERR(1024))
+
+class ListWriter:
+    def __init__(self):
+        self.data = []
+
+    def write(self, s):
+        self.data.append(s)
+
+    def __str__(self):
+        return "".join(self.data)
+
+class Runner:
+    def __init__(self, dbt_log: list, oracle_program: str):
+        self.log = dbt_log
+        self.program = oracle_program
+        self.debugger = Debugger(self.program)
+        self.writer = ListWriter()
+
+    @staticmethod
+    def get_addresses(lines: list):
+        addresses = []
+
+        backlist = []
+        backlist_regex = re.compile(r'^\s\s\d*:')
+
+        skiplist = set()
+        for l in lines:
+            if l.startswith('INVOKE'):
+                addresses.append(int(l.split('=')[1].strip(), base=16))
+
+                if addresses[-1] in backlist:
+                    skiplist.add(addresses[-1])
+                backlist = []
+
+            if backlist_regex.match(l):
+                backlist.append(int(l.split()[0].split(':')[0], base=16))
+
+        return set(addresses), skiplist
+
+    def run(self):
+        # Get all addresses to stop at
+        addresses, skiplist = Runner.get_addresses(self.log)
+
+        # Set breakpoints
+        for address in addresses:
+            self.debugger.set_breakpoint_by_addr(address)
+
+        # Sanity check
+        assert(self.debugger.get_breakpoints_count() == len(addresses))
+
+        self.debugger.execute(DebuggerCallback(self.writer, skiplist))
+
+        return self.writer.data
+