diff options
| author | Theofilos Augoustis <theofilos.augoustis@gmail.com> | 2023-07-12 14:02:44 +0200 |
|---|---|---|
| committer | Theofilos Augoustis <theofilos.augoustis@gmail.com> | 2023-07-12 14:02:44 +0200 |
| commit | 9e88720e1fbccc4ccb96f8a4770169c36037ef94 (patch) | |
| tree | 59c64320067e83d73584f917045da7ffb2dabbcd | |
| parent | 594ad72157c8b8a232e14d55f27e2fdb4881887d (diff) | |
| download | focaccia-9e88720e1fbccc4ccb96f8a4770169c36037ef94.tar.gz focaccia-9e88720e1fbccc4ccb96f8a4770169c36037ef94.zip | |
Add development scripts for comparing Arancini dumps
| -rwxr-xr-x | compare.py | 234 | ||||
| -rwxr-xr-x | run.py | 214 | ||||
| -rw-r--r-- | utils.py | 18 |
3 files changed, 466 insertions, 0 deletions
diff --git a/compare.py b/compare.py new file mode 100755 index 0000000..36cc9c7 --- /dev/null +++ b/compare.py @@ -0,0 +1,234 @@ +#! /bin/python3 +import re +import sys +import shutil +import argparse +from functools import partial as bind + +from utils import check_version +from utils import print_separator + +from run import Runner + +class ContextBlock: + regnames = ['PC', + 'RAX', + 'RBX', + 'RCX', + 'RDX', + 'RSI', + 'RDI', + 'RBP', + 'RSP', + 'R8', + 'R9', + 'R10', + 'R11', + 'R12', + 'R13', + 'R14', + 'R15', + 'flag ZF', + 'flag CF', + 'flag OF', + 'flag SF', + 'flag PF', + 'flag DF'] + + def __init__(self): + self.regs = {reg: None for reg in ContextBlock.regnames} + + def set(self, idx: int, value: int): + self.regs[list(self.regs.keys())[idx]] = value + + def __repr__(self): + return self.regs.__repr__() + +class Constructor: + def __init__(self, structure: dict): + self.cblocks = [] + self.structure = structure + self.patterns = list(self.structure.keys()) + + def match(self, line: str): + # find patterns that match it + regex = re.compile("|".join(self.patterns)) + match = regex.match(line) + + idx = self.patterns.index(match.group(0)) if match else 0 + + pattern = self.patterns[idx] + register = ContextBlock.regnames[idx] + + return register, self.structure[pattern](line) + + def add(self, key: str, value: int): + if key == 'PC': + self.cblocks.append(ContextBlock()) + + if self.cblocks[-1].regs[key] != None: + raise RuntimeError("Reassigning register") + + self.cblocks[-1].regs[key] = value + +def parse(lines: list, labels: list): + ctor = Constructor(labels) + + regex = re.compile("|".join(ctor.patterns)) + lines = [l for l in lines if regex.match(l) is not None] + + for line in lines: + key, value = ctor.match(line) + ctor.add(key, value) + + return ctor.cblocks + +def get_labels(): + split_value = lambda x,i: int(x.split()[i], 16) + + split_first = bind(split_value, i=1) + split_second = bind(split_value, i=2) + + split_equal = lambda x,i: int(x.split('=')[i], 16) + labels = {'INVOKE': bind(split_equal, i=1), + 'RAX': split_first, + 'RBX': split_first, + 'RCX': split_first, + 'RDX': split_first, + 'RSI': split_first, + 'RDI': split_first, + 'RBP': split_first, + 'RSP': split_first, + 'R8': split_first, + 'R9': split_first, + 'R10': split_first, + 'R11': split_first, + 'R12': split_first, + 'R13': split_first, + 'R14': split_first, + 'R15': split_first, + 'flag ZF': split_second, + 'flag CF': split_second, + 'flag OF': split_second, + 'flag SF': split_second, + 'flag PF': split_second, + 'flag DF': split_second} + return labels + +def equivalent(val1, val2): + return val1 == val2 + +def verify(translation: ContextBlock, reference: ContextBlock): + if translation.regs["PC"] != reference.regs["PC"]: + return 1 + + print_separator() + print(f'For PC={hex(translation.regs["PC"])}') + print_separator() + for el1 in translation.regs.keys(): + for el2 in reference.regs.keys(): + if el1 != el2: + continue + + if translation.regs[el1] is None: + print(f'Element not available in translation: {el1}') + continue + + if reference.regs[el2] is None: + print(f'Element not available in reference: {el2}') + continue + + if not equivalent(translation.regs[el1], reference.regs[el2]): + print(f'Difference for {el1}: {hex(translation.regs[el1])} != {hex(reference.regs[el2])}') + return 0 + +def compare(txl: list, native: list, stats: bool = False): + txl = parse(txl, get_labels()) + native = parse(native, get_labels()) + + if len(txl) != len(native): + print(f'Different number of blocks discovered translation: {len(txl)} vs. ' + f'reference: {len(native)}', file=sys.stderr) + + unmatched_pcs = {} + for translation, reference in zip(txl, native): + if verify(translation, reference) == 1: + # TODO: add verbose output + print_separator(stream=sys.stderr) + print(f'No match for PC {hex(translation.regs["PC"])}', file=sys.stderr) + if translation.regs['PC'] not in unmatched_pcs: + unmatched_pcs[translation.regs['PC']] = 0 + unmatched_pcs[translation.regs['PC']] += 1 + + if stats: + print_separator() + print('Statistics:') + print_separator() + + for pc in unmatched_pcs: + print(f'PC {hex(pc)} unmatched {unmatched_pcs[pc]} times') + return 0 + +def read_logs(txl_path, native_path, program): + txl = [] + with open(txl_path, "r") as txl_file: + txl = txl_file.readlines() + + native = [] + if program is not None: + runner = Runner(txl, program) + native = runner.run() + else: + with open(native_path, "r") as native_file: + native = native_file.readlines() + + return txl, native + +def parse_arguments(): + parser = argparse.ArgumentParser(description='Comparator for emulator logs to reference') + parser.add_argument('-p', '--program', + type=str, + help='Path to oracle program') + parser.add_argument('-r', '--ref', + type=str, + required=True, + help='Path to the reference log (gathered with run.sh)') + parser.add_argument('-t', '--txl', + type=str, + required=True, + help='Path to the translation log (gathered via Arancini)') + parser.add_argument('-s', '--stats', + action='store_true', + default='store_false', + help='Run statistics on comparisons') + parser.add_argument('-v', '--verbose', + action='store_true', + default='store_true', + help='Path to oracle program') + args = parser.parse_args() + return args + +if __name__ == "__main__": + check_version('3.7') + + args = parse_arguments() + + txl_path = args.txl + native_path = args.ref + program = args.program + + stats = args.stats + verbose = args.verbose + + if program is None and native_path is None: + raise ValueError('Either program or path to native file must be' + 'provided') + + txl, native = read_logs(txl_path, native_path, program) + + if program != None and native_path != None: + with open(native_path, 'w') as w: + w.write(''.join(native)) + + compare(txl, native, stats) + diff --git a/run.py b/run.py new file mode 100755 index 0000000..f1f1060 --- /dev/null +++ b/run.py @@ -0,0 +1,214 @@ +#! /bin/python3 +import os +import re +import sys +import lldb +import shutil +import argparse + +from utils import print_separator + +verbose = False + +regnames = ['PC', + 'RAX', + 'RBX', + 'RCX', + 'RDX', + 'RSI', + 'RDI', + 'RBP', + 'RSP', + 'R8', + 'R9', + 'R10', + 'R11', + 'R12', + 'R13', + 'R14', + 'R15', + 'RFLAGS'] + +class DebuggerCallback: + def __init__(self, ostream=sys.stdout, skiplist: set = {}): + self.stream = ostream + self.regex = re.compile('(' + '|'.join(regnames) + ')$') + self.skiplist = skiplist + + @staticmethod + def parse_flags(flag_reg: int): + flags = {'ZF': 0, + 'CF': 0, + 'OF': 0, + 'SF': 0, + 'PF': 0, + 'DF': 0} + + # CF (Carry flag) Bit 0 + # PF (Parity flag) Bit 2 + # ZF (Zero flag) Bit 6 + # SF (Sign flag) Bit 7 + # TF (Trap flag) Bit 8 + # IF (Interrupt enable flag) Bit 9 + # DF (Direction flag) Bit 10 + # OF (Overflow flag) Bit 11 + flags['CF'] = int(0 != flag_reg & 1) + flags['ZF'] = int(0 != flag_reg & (1 << 6)) + flags['OF'] = int(0 != flag_reg & (1 << 11)) + flags['SF'] = int(0 != flag_reg & (1 << 7)) + flags['DF'] = int(0 != flag_reg & (1 << 10)) + flags['PF'] = int(0 != flag_reg & (1 << 1)) + return flags + + + def print_regs(self, frame): + for reg in frame.GetRegisters(): + for sub_reg in reg: + match = self.regex.match(sub_reg.GetName().upper()) + if match and match.group() == 'RFLAGS': + flags = DebuggerCallback.parse_flags(int(sub_reg.GetValue(), + base=16)) + for flag in flags: + print(f'flag {flag}:\t{hex(flags[flag])}', + file=self.stream) + elif match: + print(f"{sub_reg.GetName().upper()}:\t\t {hex(int(sub_reg.GetValue(), base=16))}", + file=self.stream) + + def print_stack(self, frame, element_count: int): + first = True + for i in range(element_count): + addr = frame.GetSP() + i * frame.GetThread().GetProcess().GetAddressByteSize() + error = lldb.SBError() + stack_value = int(frame.GetThread().GetProcess().ReadPointerFromMemory(addr, error)) + if error.Success() and not first: + print(f'{hex(stack_value)}', file=self.stream) + elif error.Success(): + print(f'{hex(stack_value)}\t\t<- rsp', file=self.stream) + else: + print(f"Error reading memory at address 0x{addr:x}", + file=self.stream) + first=False + + def __call__(self, frame): + pc = frame.GetPC() + + # Skip this PC + if pc in self.skiplist: + self.skiplist.discard(pc) + return False + + print_separator('=', stream=self.stream, count=20) + print(f'INVOKE PC={hex(pc)}', file=self.stream) + print_separator('=', stream=self.stream, count=20) + + print("Register values:", file=self.stream) + self.print_regs(frame) + print_separator(stream=self.stream) + + print("STACK:", file=self.stream) + self.print_stack(frame, 20) + + return True # Continue execution + +class Debugger: + def __init__(self, program): + self.debugger = lldb.SBDebugger.Create() + self.debugger.SetAsync(False) + self.target = self.debugger.CreateTargetWithFileAndArch(program, + lldb.LLDB_ARCH_DEFAULT) + self.module = self.target.FindModule(self.target.GetExecutable()) + self.interpreter = self.debugger.GetCommandInterpreter() + + def set_breakpoint_by_addr(self, address: int): + command = f"b -a {address} -s {self.module.GetFileSpec().GetFilename()}" + result = lldb.SBCommandReturnObject() + self.interpreter.HandleCommand(command, result) + + if verbose: + print(f'Set breakpoint at address {hex(address)}') + + def get_breakpoints_count(self): + return self.target.GetNumBreakpoints() + + def execute(self, callback: callable): + error = lldb.SBError() + listener = self.debugger.GetListener() + process = self.target.Launch(listener, None, None, None, None, None, None, 0, + True, error) + + # Check if the process has launched successfully + if process.IsValid(): + print(f'Launched process: {process}') + else: + print('Failed to launch process', file=sys.stderr) + + while True: + state = process.GetState() + if state == lldb.eStateStopped: + for thread in process: + callback(thread.GetFrameAtIndex(0)) + process.Continue() + if state == lldb.eStateExited: + break + + self.debugger.Terminate() + + print(f'Process state: {process.GetState()}') + print('Program output:') + print(process.GetSTDOUT(1024)) + print(process.GetSTDERR(1024)) + +class ListWriter: + def __init__(self): + self.data = [] + + def write(self, s): + self.data.append(s) + + def __str__(self): + return "".join(self.data) + +class Runner: + def __init__(self, dbt_log: list, oracle_program: str): + self.log = dbt_log + self.program = oracle_program + self.debugger = Debugger(self.program) + self.writer = ListWriter() + + @staticmethod + def get_addresses(lines: list): + addresses = [] + + backlist = [] + backlist_regex = re.compile(r'^\s\s\d*:') + + skiplist = set() + for l in lines: + if l.startswith('INVOKE'): + addresses.append(int(l.split('=')[1].strip(), base=16)) + + if addresses[-1] in backlist: + skiplist.add(addresses[-1]) + backlist = [] + + if backlist_regex.match(l): + backlist.append(int(l.split()[0].split(':')[0], base=16)) + + return set(addresses), skiplist + + def run(self): + # Get all addresses to stop at + addresses, skiplist = Runner.get_addresses(self.log) + + # Set breakpoints + for address in addresses: + self.debugger.set_breakpoint_by_addr(address) + + # Sanity check + assert(self.debugger.get_breakpoints_count() == len(addresses)) + + self.debugger.execute(DebuggerCallback(self.writer, skiplist)) + + return self.writer.data + diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..d841c7c --- /dev/null +++ b/utils.py @@ -0,0 +1,18 @@ +#! /bin/python3 + +import sys +import shutil + +def print_separator(separator: str = '-', stream=sys.stdout, count: int = 80): + maxtermsize = count + termsize = shutil.get_terminal_size((80, 20)).columns + print(separator * min(termsize, maxtermsize), file=stream) + +def check_version(version: str): + # Script depends on ordered dicts in default dict() + split = version.split('.') + major = int(split[0]) + minor = int(split[1]) + if sys.version_info.major < major and sys.version_info.minor < minor: + raise EnvironmentError("Expected at least Python 3.7") + |