From 69c55d68d68c00007afa1af76a1d06f74ee72fe6 Mon Sep 17 00:00:00 2001 From: Theofilos Augoustis Date: Wed, 11 Oct 2023 16:21:21 +0200 Subject: Refactor file structure - main.py: focaccia user-interface - snapshot.py: state trace snapshots handling - compare.py: snapshot comparison algorithms - run.py: native execution tracer - arancini.py: Arancini log handling - arch/: per-architecture abstractions Co-authored-by: Theofilos Augoustis Co-authored-by: Nicola Crivellin --- run.py | 100 ++++++++++++++++++----------------------------------------------- 1 file changed, 27 insertions(+), 73 deletions(-) mode change 100755 => 100644 run.py (limited to 'run.py') diff --git a/run.py b/run.py old mode 100755 new mode 100644 index f1f1060..9b51fb5 --- a/run.py +++ b/run.py @@ -1,39 +1,23 @@ -#! /bin/python3 -import os +"""Functionality to execute native programs and collect snapshots via lldb.""" + import re import sys import lldb -import shutil -import argparse +from typing import Callable +# TODO: The debugger callback is currently specific to a single architexture. +# We should make it generic. +from arch import x86 from utils import print_separator verbose = False -regnames = ['PC', - 'RAX', - 'RBX', - 'RCX', - 'RDX', - 'RSI', - 'RDI', - 'RBP', - 'RSP', - 'R8', - 'R9', - 'R10', - 'R11', - 'R12', - 'R13', - 'R14', - 'R15', - 'RFLAGS'] - class DebuggerCallback: - def __init__(self, ostream=sys.stdout, skiplist: set = {}): + """At every breakpoint, writes register contents to a stream.""" + + def __init__(self, ostream=sys.stdout): self.stream = ostream - self.regex = re.compile('(' + '|'.join(regnames) + ')$') - self.skiplist = skiplist + self.regex = re.compile('(' + '|'.join(x86.regnames) + ')$') @staticmethod def parse_flags(flag_reg: int): @@ -60,7 +44,6 @@ class DebuggerCallback: flags['PF'] = int(0 != flag_reg & (1 << 1)) return flags - def print_regs(self, frame): for reg in frame.GetRegisters(): for sub_reg in reg: @@ -93,11 +76,6 @@ class DebuggerCallback: def __call__(self, frame): pc = frame.GetPC() - # Skip this PC - if pc in self.skiplist: - self.skiplist.discard(pc) - return False - print_separator('=', stream=self.stream, count=20) print(f'INVOKE PC={hex(pc)}', file=self.stream) print_separator('=', stream=self.stream, count=20) @@ -109,8 +87,6 @@ class DebuggerCallback: print("STACK:", file=self.stream) self.print_stack(frame, 20) - return True # Continue execution - class Debugger: def __init__(self, program): self.debugger = lldb.SBDebugger.Create() @@ -131,7 +107,7 @@ class Debugger: def get_breakpoints_count(self): return self.target.GetNumBreakpoints() - def execute(self, callback: callable): + def execute(self, callback: Callable): error = lldb.SBError() listener = self.debugger.GetListener() process = self.target.Launch(listener, None, None, None, None, None, None, 0, @@ -169,46 +145,24 @@ class ListWriter: def __str__(self): return "".join(self.data) -class Runner: - def __init__(self, dbt_log: list, oracle_program: str): - self.log = dbt_log - self.program = oracle_program - self.debugger = Debugger(self.program) - self.writer = ListWriter() - - @staticmethod - def get_addresses(lines: list): - addresses = [] - - backlist = [] - backlist_regex = re.compile(r'^\s\s\d*:') - - skiplist = set() - for l in lines: - if l.startswith('INVOKE'): - addresses.append(int(l.split('=')[1].strip(), base=16)) - - if addresses[-1] in backlist: - skiplist.add(addresses[-1]) - backlist = [] - - if backlist_regex.match(l): - backlist.append(int(l.split()[0].split(':')[0], base=16)) - - return set(addresses), skiplist - - def run(self): - # Get all addresses to stop at - addresses, skiplist = Runner.get_addresses(self.log) +def run_native_execution(oracle_program: str, breakpoints: set[int]): + """Gather snapshots from a native execution via an external debugger. - # Set breakpoints - for address in addresses: - self.debugger.set_breakpoint_by_addr(address) + :param oracle_program: Program to execute. + :param breakpoints: List of addresses at which to break and record the + program's state. - # Sanity check - assert(self.debugger.get_breakpoints_count() == len(addresses)) + :return: A textual log of the program's execution in arancini's log format. + """ + debugger = Debugger(oracle_program) + writer = ListWriter() - self.debugger.execute(DebuggerCallback(self.writer, skiplist)) + # Set breakpoints + for address in breakpoints: + debugger.set_breakpoint_by_addr(address) + assert(debugger.get_breakpoints_count() == len(breakpoints)) - return self.writer.data + # Execute the native program + debugger.execute(DebuggerCallback(writer)) + return writer.data -- cgit 1.4.1