diff options
| -rw-r--r-- | .gitignore | 2 | ||||
| -rw-r--r-- | arancini.py | 94 | ||||
| -rw-r--r-- | arch/arch.py | 6 | ||||
| -rw-r--r-- | arch/x86.py | 33 | ||||
| -rw-r--r--[-rwxr-xr-x] | compare.py | 265 | ||||
| -rwxr-xr-x | main.py | 92 | ||||
| -rw-r--r--[-rwxr-xr-x] | run.py | 100 | ||||
| -rw-r--r-- | snapshot.py | 38 | ||||
| -rw-r--r-- | utils.py | 2 |
9 files changed, 326 insertions, 306 deletions
diff --git a/.gitignore b/.gitignore index ee32d85..4586156 100644 --- a/.gitignore +++ b/.gitignore @@ -7,5 +7,5 @@ build* *.dot build*/ out-*/ -__pycache__/* +__pycache__/ diff --git a/arancini.py b/arancini.py new file mode 100644 index 0000000..adbe6c8 --- /dev/null +++ b/arancini.py @@ -0,0 +1,94 @@ +"""Tools for working with arancini's output.""" + +import re +from functools import partial as bind + +from snapshot import ProgramState +from arch.arch import Arch + +def parse_break_addresses(lines: list[str]) -> set[int]: + """Parse all breakpoint addresses from an arancini log.""" + addresses = set() + for l in lines: + if l.startswith('INVOKE'): + addr = int(l.split('=')[1].strip(), base=16) + addresses.add(addr) + + return addresses + +def parse(lines: list[str], arch: Arch) -> list[ProgramState]: + """Parse an arancini log into a list of snapshots. + + :return: A list of program snapshots. + """ + + labels = get_labels() + + # The regex decides for a line whether it contains a register + # based on a match with that register's label. + regex = re.compile("|".join(labels.keys())) + + def try_parse_line(line: str) -> tuple[str, int] | None: + """Try to parse a register name and that register's value from a line. + + :return: A register name and a register value if the line contains + that information. None if parsing fails. + """ + match = regex.match(line) + if match: + label = match.group(0) + register, get_reg_value = labels[label] + return register, get_reg_value(line) + return None + + # Parse a list of program snapshots + snapshots = [] + for line in lines: + if 'Backwards' in line and len(snapshots) > 0: + snapshots[-1].set_backwards() + continue + + match = try_parse_line(line) + if match: + reg, value = match + if reg == 'PC': + snapshots.append(ProgramState(arch)) + snapshots[-1].set(reg, value) + + return snapshots + +def get_labels(): + """Construct a helper structure for the arancini log parser.""" + split_value = lambda x,i: int(x.split()[i], 16) + + split_first = bind(split_value, i=1) + split_second = bind(split_value, i=2) + + split_equal = lambda x,i: int(x.split('=')[i], 16) + + # A mapping from regex patterns to the register name and a + # function that extracts that register's value from the line + labels = {'INVOKE': ('PC', bind(split_equal, i=1)), + 'RAX': ('RAX', split_first), + 'RBX': ('RBX', split_first), + 'RCX': ('RCX', split_first), + 'RDX': ('RDX', split_first), + 'RSI': ('RSI', split_first), + 'RDI': ('RDI', split_first), + 'RBP': ('RBP', split_first), + 'RSP': ('RSP', split_first), + 'R8': ('R8', split_first), + 'R9': ('R9', split_first), + 'R10': ('R10', split_first), + 'R11': ('R11', split_first), + 'R12': ('R12', split_first), + 'R13': ('R13', split_first), + 'R14': ('R14', split_first), + 'R15': ('R15', split_first), + 'flag ZF': ('flag ZF', split_second), + 'flag CF': ('flag CF', split_second), + 'flag OF': ('flag OF', split_second), + 'flag SF': ('flag SF', split_second), + 'flag PF': ('flag PF', split_second), + 'flag DF': ('flag DF', split_second)} + return labels diff --git a/arch/arch.py b/arch/arch.py new file mode 100644 index 0000000..36a4e3f --- /dev/null +++ b/arch/arch.py @@ -0,0 +1,6 @@ +class Arch(): + def __init__(self, regnames: list[str]): + self.regnames = regnames + + def __eq__(self, other): + return self.regnames == other.regnames diff --git a/arch/x86.py b/arch/x86.py new file mode 100644 index 0000000..0f60457 --- /dev/null +++ b/arch/x86.py @@ -0,0 +1,33 @@ +"""Architexture-specific configuration.""" + +from .arch import Arch + +# Names of registers in the architexture +regnames = ['PC', + 'RAX', + 'RBX', + 'RCX', + 'RDX', + 'RSI', + 'RDI', + 'RBP', + 'RSP', + 'R8', + 'R9', + 'R10', + 'R11', + 'R12', + 'R13', + 'R14', + 'R15', + 'RFLAGS', + 'flag ZF', + 'flag CF', + 'flag OF', + 'flag SF', + 'flag PF', + 'flag DF'] + +class ArchX86(Arch): + def __init__(self): + super().__init__(regnames) diff --git a/compare.py b/compare.py index f4576dd..df8c378 100755..100644 --- a/compare.py +++ b/compare.py @@ -1,150 +1,17 @@ -#! /bin/python3 -import re -import sys -import shutil -import argparse -from typing import List, Callable -from functools import partial as bind - -from utils import check_version +from snapshot import ProgramState from utils import print_separator -from run import Runner - -progressive = False - -class ContextBlock: - regnames = ['PC', - 'RAX', - 'RBX', - 'RCX', - 'RDX', - 'RSI', - 'RDI', - 'RBP', - 'RSP', - 'R8', - 'R9', - 'R10', - 'R11', - 'R12', - 'R13', - 'R14', - 'R15', - 'flag ZF', - 'flag CF', - 'flag OF', - 'flag SF', - 'flag PF', - 'flag DF'] - - def __init__(self): - dict_type = dict[str, int|None] # A register may not have a value - self.regs = dict_type({reg: None for reg in ContextBlock.regnames}) - self.has_backwards = False - self.matched = False - - def set_backwards(self): - self.has_backwards = True - - def set(self, reg: str, value: int): - """Assign a value to a register. - - :raises RuntimeError: if the register already has a value. - """ - if self.regs[reg] != None: - raise RuntimeError("Reassigning register") - self.regs[reg] = value - - def __repr__(self): - return self.regs.__repr__() - -class Constructor: - """Builds a list of context blocks.""" - def __init__(self, structure: dict[str, tuple[str, Callable[[str], int]]]): - self.cblocks = list[ContextBlock]() - self.labels = structure - self.regex = re.compile("|".join(structure.keys())) - - def match(self, line: str) -> (tuple[str, int] | None): - """Find a register name and that register's value in a line. - - :return: A register name and a register value. - """ - match = self.regex.match(line) - if match: - label = match.group(0) - register, get_reg_value = self.labels[label] - return register, get_reg_value(line) - - return None - - def add_backwards(self): - self.cblocks[-1].set_backwards() - - def add(self, reg: str, value: int): - if reg == 'PC': - self.cblocks.append(ContextBlock()) - self.cblocks[-1].set(reg, value) - -def parse(lines: list[str], labels: dict): - """Parse a list of lines into a list of cblocks.""" - ctor = Constructor(labels) - for line in lines: - if 'Backwards' in line: - ctor.add_backwards() - continue - - match = ctor.match(line) - if match: - key, value = match - ctor.add(key, value) - - return ctor.cblocks - -def get_labels(): - split_value = lambda x,i: int(x.split()[i], 16) - - split_first = bind(split_value, i=1) - split_second = bind(split_value, i=2) - - split_equal = lambda x,i: int(x.split('=')[i], 16) - - # A mapping from regex patterns to the register name and a - # function that extracts that register's value from the line - labels = {'INVOKE': ('PC', bind(split_equal, i=1)), - 'RAX': ('RAX', split_first), - 'RBX': ('RBX', split_first), - 'RCX': ('RCX', split_first), - 'RDX': ('RDX', split_first), - 'RSI': ('RSI', split_first), - 'RDI': ('RDI', split_first), - 'RBP': ('RBP', split_first), - 'RSP': ('RSP', split_first), - 'R8': ('R8', split_first), - 'R9': ('R9', split_first), - 'R10': ('R10', split_first), - 'R11': ('R11', split_first), - 'R12': ('R12', split_first), - 'R13': ('R13', split_first), - 'R14': ('R14', split_first), - 'R15': ('R15', split_first), - 'flag ZF': ('flag ZF', split_second), - 'flag CF': ('flag CF', split_second), - 'flag OF': ('flag OF', split_second), - 'flag SF': ('flag SF', split_second), - 'flag PF': ('flag PF', split_second), - 'flag DF': ('flag DF', split_second)} - return labels - -def calc_transformation(previous: ContextBlock, current: ContextBlock): +def calc_transformation(previous: ProgramState, current: ProgramState): """Calculate the difference between two context blocks. :return: A context block that contains in its registers the difference between the corresponding input blocks' register values. """ - transformation = ContextBlock() - for reg in ContextBlock.regnames: + assert(previous.arch == current.arch) + + arch = previous.arch + transformation = ProgramState(arch) + for reg in arch.regnames: prev_val, cur_val = previous.regs[reg], current.regs[reg] if prev_val is not None and cur_val is not None: transformation.regs[reg] = cur_val - prev_val @@ -158,15 +25,17 @@ def equivalent(val1, val2, transformation, previous_translation): # TODO: maybe incorrect return val1 - previous_translation == transformation -def verify(translation: ContextBlock, reference: ContextBlock, - transformation: ContextBlock, previous_translation: ContextBlock): +def verify(translation: ProgramState, reference: ProgramState, + transformation: ProgramState, previous_translation: ProgramState): + assert(translation.arch == reference.arch) + if translation.regs["PC"] != reference.regs["PC"]: return 1 print_separator() - print(f'For PC={hex(translation.regs["PC"])}') + print(f'For PC={translation.as_repr("PC")}') print_separator() - for reg in ContextBlock.regnames: + for reg in translation.arch.regnames: if translation.regs[reg] is None: print(f'Element not available in translation: {reg}') elif reference.regs[reg] is None: @@ -174,16 +43,27 @@ def verify(translation: ContextBlock, reference: ContextBlock, elif not equivalent(translation.regs[reg], reference.regs[reg], transformation.regs[reg], previous_translation.regs[reg]): - txl = hex(translation.regs[reg]) - ref = hex(reference.regs[reg]) + txl = translation.as_repr(reg) + ref = reference.as_repr(reg) print(f'Difference for {reg}: {txl} != {ref}') return 0 -def compare(txl: List[ContextBlock], native: List[ContextBlock], stats: bool = False): +def compare(txl: list[ProgramState], + native: list[ProgramState], + progressive: bool = False, + stats: bool = False): + """Compare two lists of snapshots and output the differences. + + :param txl: The translated, and possibly faulty, state of the program. + :param native: The 'correct' reference state of the program. + :param progressive: + :param stats: + """ + if len(txl) != len(native): - print(f'Different number of blocks discovered translation: {len(txl)} vs. ' - f'reference: {len(native)}', file=sys.stdout) + print(f'Different numbers of blocks discovered: ' + f'{len(txl)} in translation vs. {len(native)} in reference.') previous_reference = native[0] previous_translation = txl[0] @@ -210,8 +90,8 @@ def compare(txl: List[ContextBlock], native: List[ContextBlock], stats: bool = F if i == len(native): matched = False # TODO: add verbose output - print_separator(stream=sys.stdout) - print(f'No match for PC {hex(translation.regs["PC"])}', file=sys.stdout) + print_separator() + print(f'No match for PC {hex(translation.regs["PC"])}') if translation.regs['PC'] not in unmatched_pcs: unmatched_pcs[translation.regs['PC']] = 0 unmatched_pcs[translation.regs['PC']] += 1 @@ -238,12 +118,14 @@ def compare(txl: List[ContextBlock], native: List[ContextBlock], stats: bool = F if matched: i += 1 else: + txl = iter(txl) + native = iter(native) for translation, reference in zip(txl, native): transformation = calc_transformation(previous_reference, reference) if verify(translation, reference, transformation, previous_translation) == 1: # TODO: add verbose output - print_separator(stream=sys.stdout) - print(f'No match for PC {hex(translation.regs["PC"])}', file=sys.stdout) + print_separator() + print(f'No match for PC {hex(translation.regs["PC"])}') if translation.regs['PC'] not in unmatched_pcs: unmatched_pcs[translation.regs['PC']] = 0 unmatched_pcs[translation.regs['PC']] += 1 @@ -277,80 +159,3 @@ def compare(txl: List[ContextBlock], native: List[ContextBlock], stats: bool = F if ref.matched == False: unmatched_count += 1 return 0 - -def read_logs(txl_path, native_path, program): - txl = [] - with open(txl_path, "r") as txl_file: - txl = txl_file.readlines() - - native = [] - if program is not None: - runner = Runner(txl, program) - native = runner.run() - else: - with open(native_path, "r") as native_file: - native = native_file.readlines() - - return txl, native - -def parse_arguments(): - parser = argparse.ArgumentParser(description='Comparator for emulator logs to reference') - parser.add_argument('-p', '--program', - type=str, - help='Path to oracle program') - parser.add_argument('-r', '--ref', - type=str, - required=True, - help='Path to the reference log (gathered with run.sh)') - parser.add_argument('-t', '--txl', - type=str, - required=True, - help='Path to the translation log (gathered via Arancini)') - parser.add_argument('-s', '--stats', - action='store_true', - default=False, - help='Run statistics on comparisons') - parser.add_argument('-v', '--verbose', - action='store_true', - default=True, - help='Path to oracle program') - parser.add_argument('--progressive', - action='store_true', - default=False, - help='Try to match exhaustively before declaring \ - mismatch') - args = parser.parse_args() - return args - -if __name__ == "__main__": - check_version('3.7') - - args = parse_arguments() - - txl_path = args.txl - native_path = args.ref - program = args.program - - stats = args.stats - verbose = args.verbose - progressive = args.progressive - - if verbose: - print("Enabling verbose program output") - print(f"Verbose: {verbose}") - print(f"Statistics: {stats}") - print(f"Progressive: {progressive}") - - if program is None and native_path is None: - raise ValueError('Either program or path to native file must be' - 'provided') - - txl, native = read_logs(txl_path, native_path, program) - - if program != None and native_path != None: - with open(native_path, 'w') as w: - w.write(''.join(native)) - - txl = parse(txl, get_labels()) - native = parse(native, get_labels()) - compare(txl, native, stats) diff --git a/main.py b/main.py new file mode 100755 index 0000000..076dc0e --- /dev/null +++ b/main.py @@ -0,0 +1,92 @@ +#! /bin/python3 + +import argparse + +import arancini +from arch import x86 +from compare import compare +from run import run_native_execution +from utils import check_version + +def read_logs(txl_path, native_path, program): + txl = [] + with open(txl_path, "r") as txl_file: + txl = txl_file.readlines() + + native = [] + if program is not None: + breakpoints = arancini.parse_break_addresses(txl) + native = run_native_execution(program, breakpoints) + else: + assert(native_path is not None) + with open(native_path, "r") as native_file: + native = native_file.readlines() + + return txl, native + +def parse_arguments(): + parser = argparse.ArgumentParser(description='Comparator for emulator logs to reference') + parser.add_argument('-p', '--program', + type=str, + help='Path to oracle program') + parser.add_argument('-r', '--ref', + type=str, + required=True, + help='Path to the reference log (gathered with run.sh)') + parser.add_argument('-t', '--txl', + type=str, + required=True, + help='Path to the translation log (gathered via Arancini)') + parser.add_argument('-s', '--stats', + action='store_true', + default=False, + help='Run statistics on comparisons') + parser.add_argument('-v', '--verbose', + action='store_true', + default=True, + help='Path to oracle program') + parser.add_argument('--progressive', + action='store_true', + default=False, + help='Try to match exhaustively before declaring \ + mismatch') + args = parser.parse_args() + return args + +def main(): + args = parse_arguments() + + txl_path = args.txl + native_path = args.ref + program = args.program + + stats = args.stats + verbose = args.verbose + progressive = args.progressive + + # Our architexture + arch = x86.ArchX86() + + if verbose: + print("Enabling verbose program output") + print(f"Verbose: {verbose}") + print(f"Statistics: {stats}") + print(f"Progressive: {progressive}") + + if program is None and native_path is None: + raise ValueError('Either program or path to native file must be' + 'provided') + + txl, native = read_logs(txl_path, native_path, program) + + if program != None and native_path != None: + with open(native_path, 'w') as w: + w.write(''.join(native)) + + txl = arancini.parse(txl, arch) + native = arancini.parse(native, arch) + compare(txl, native, stats) + +if __name__ == "__main__": + check_version('3.7') + main() diff --git a/run.py b/run.py index f1f1060..9b51fb5 100755..100644 --- a/run.py +++ b/run.py @@ -1,39 +1,23 @@ -#! /bin/python3 -import os +"""Functionality to execute native programs and collect snapshots via lldb.""" + import re import sys import lldb -import shutil -import argparse +from typing import Callable +# TODO: The debugger callback is currently specific to a single architexture. +# We should make it generic. +from arch import x86 from utils import print_separator verbose = False -regnames = ['PC', - 'RAX', - 'RBX', - 'RCX', - 'RDX', - 'RSI', - 'RDI', - 'RBP', - 'RSP', - 'R8', - 'R9', - 'R10', - 'R11', - 'R12', - 'R13', - 'R14', - 'R15', - 'RFLAGS'] - class DebuggerCallback: - def __init__(self, ostream=sys.stdout, skiplist: set = {}): + """At every breakpoint, writes register contents to a stream.""" + + def __init__(self, ostream=sys.stdout): self.stream = ostream - self.regex = re.compile('(' + '|'.join(regnames) + ')$') - self.skiplist = skiplist + self.regex = re.compile('(' + '|'.join(x86.regnames) + ')$') @staticmethod def parse_flags(flag_reg: int): @@ -60,7 +44,6 @@ class DebuggerCallback: flags['PF'] = int(0 != flag_reg & (1 << 1)) return flags - def print_regs(self, frame): for reg in frame.GetRegisters(): for sub_reg in reg: @@ -93,11 +76,6 @@ class DebuggerCallback: def __call__(self, frame): pc = frame.GetPC() - # Skip this PC - if pc in self.skiplist: - self.skiplist.discard(pc) - return False - print_separator('=', stream=self.stream, count=20) print(f'INVOKE PC={hex(pc)}', file=self.stream) print_separator('=', stream=self.stream, count=20) @@ -109,8 +87,6 @@ class DebuggerCallback: print("STACK:", file=self.stream) self.print_stack(frame, 20) - return True # Continue execution - class Debugger: def __init__(self, program): self.debugger = lldb.SBDebugger.Create() @@ -131,7 +107,7 @@ class Debugger: def get_breakpoints_count(self): return self.target.GetNumBreakpoints() - def execute(self, callback: callable): + def execute(self, callback: Callable): error = lldb.SBError() listener = self.debugger.GetListener() process = self.target.Launch(listener, None, None, None, None, None, None, 0, @@ -169,46 +145,24 @@ class ListWriter: def __str__(self): return "".join(self.data) -class Runner: - def __init__(self, dbt_log: list, oracle_program: str): - self.log = dbt_log - self.program = oracle_program - self.debugger = Debugger(self.program) - self.writer = ListWriter() - - @staticmethod - def get_addresses(lines: list): - addresses = [] - - backlist = [] - backlist_regex = re.compile(r'^\s\s\d*:') - - skiplist = set() - for l in lines: - if l.startswith('INVOKE'): - addresses.append(int(l.split('=')[1].strip(), base=16)) - - if addresses[-1] in backlist: - skiplist.add(addresses[-1]) - backlist = [] - - if backlist_regex.match(l): - backlist.append(int(l.split()[0].split(':')[0], base=16)) - - return set(addresses), skiplist - - def run(self): - # Get all addresses to stop at - addresses, skiplist = Runner.get_addresses(self.log) +def run_native_execution(oracle_program: str, breakpoints: set[int]): + """Gather snapshots from a native execution via an external debugger. - # Set breakpoints - for address in addresses: - self.debugger.set_breakpoint_by_addr(address) + :param oracle_program: Program to execute. + :param breakpoints: List of addresses at which to break and record the + program's state. - # Sanity check - assert(self.debugger.get_breakpoints_count() == len(addresses)) + :return: A textual log of the program's execution in arancini's log format. + """ + debugger = Debugger(oracle_program) + writer = ListWriter() - self.debugger.execute(DebuggerCallback(self.writer, skiplist)) + # Set breakpoints + for address in breakpoints: + debugger.set_breakpoint_by_addr(address) + assert(debugger.get_breakpoints_count() == len(breakpoints)) - return self.writer.data + # Execute the native program + debugger.execute(DebuggerCallback(writer)) + return writer.data diff --git a/snapshot.py b/snapshot.py new file mode 100644 index 0000000..d5136ad --- /dev/null +++ b/snapshot.py @@ -0,0 +1,38 @@ +from arch.arch import Arch + +class ProgramState(): + """A snapshot of the program's state.""" + def __init__(self, arch: Arch): + self.arch = arch + + dict_t = dict[str, int] + self.regs = dict_t({ reg: None for reg in arch.regnames }) + self.has_backwards = False + self.matched = False + + def set_backwards(self): + self.has_backwards = True + + def set(self, reg: str, value: int): + """Assign a value to a register. + + :raises RuntimeError: if the register already has a value. + """ + assert(reg in self.arch.regnames) + + if self.regs[reg] != None: + raise RuntimeError("Reassigning register") + self.regs[reg] = value + + def as_repr(self, reg: str): + """Get a representational string of a register's value.""" + assert(reg in self.arch.regnames) + + value = self.regs[reg] + if value is not None: + return hex(value) + else: + return "<none>" + + def __repr__(self): + return self.regs.__repr__() diff --git a/utils.py b/utils.py index d841c7c..1390283 100644 --- a/utils.py +++ b/utils.py @@ -1,5 +1,3 @@ -#! /bin/python3 - import sys import shutil |