about summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rwxr-xr-x compare.py 234
-rwxr-xr-x run.py 214
-rw-r--r-- utils.py 18
3 files changed, 466 insertions, 0 deletions
diff --git a/compare.py b/compare.py
new file mode 100755
index 0000000..36cc9c7
--- /dev/null
+++ b/compare.py
@@ -0,0 +1,234 @@
+#! /bin/python3
+import re
+import sys
+import shutil
+import argparse
+from functools import partial as bind
+
+from utils import check_version
+from utils import print_separator
+
+from run import Runner
+
+class ContextBlock:  # register and flag values recorded for a single PC
+    regnames = ['PC',  # order matters: set() and Constructor.match() look names up by position
+                'RAX',
+                'RBX',
+                'RCX',
+                'RDX',
+                'RSI',
+                'RDI',
+                'RBP',
+                'RSP',
+                'R8',
+                'R9',
+                'R10',
+                'R11',
+                'R12',
+                'R13',
+                'R14',
+                'R15',
+                'flag ZF',
+                'flag CF',
+                'flag OF',
+                'flag SF',
+                'flag PF',
+                'flag DF']
+
+    def __init__(self):
+        self.regs = {reg: None for reg in ContextBlock.regnames}  # None means "no value seen yet"
+
+    def set(self, idx: int, value: int):
+        self.regs[list(self.regs.keys())[idx]] = value  # idx is the position of the key in self.regs
+
+    def __repr__(self):
+        return self.regs.__repr__()
+
+class Constructor:  # turns matched log lines into a list of ContextBlock objects
+    def __init__(self, structure: dict):
+        self.cblocks = []  # add() appends a new ContextBlock for every 'PC' key
+        self.structure = structure  # label pattern -> callable extracting the value from a line
+        self.patterns = list(self.structure.keys())
+        self.regex = re.compile("|".join(self.patterns))  # compiled once instead of per line
+
+    def match(self, line: str):
+        # Returns (register name, parsed value) for a log line.
+        match = self.regex.match(line)
+        # A line matching no pattern falls back to index 0 (first pattern, 'PC').
+        idx = self.patterns.index(match.group(0)) if match else 0
+
+        pattern = self.patterns[idx]
+        register = ContextBlock.regnames[idx]  # assumes patterns and regnames share one ordering
+
+        return register, self.structure[pattern](line)
+
+    def add(self, key: str, value: int):
+        if key == 'PC':
+            self.cblocks.append(ContextBlock())
+
+        if self.cblocks[-1].regs[key] is not None:  # IndexError if no 'PC' was added before
+            raise RuntimeError("Reassigning register")
+
+        self.cblocks[-1].regs[key] = value
+
+def parse(lines: list, labels: dict):  # labels: pattern -> value extractor, handed to Constructor
+    ctor = Constructor(labels)
+
+    regex = re.compile("|".join(ctor.patterns))
+    lines = [l for l in lines if regex.match(l) is not None]  # keep only lines starting with a known label
+
+    for line in lines:
+        key, value = ctor.match(line)
+        ctor.add(key, value)
+
+    return ctor.cblocks  # list of ContextBlock, in the order their 'PC' lines appeared
+
+def get_labels():  # returns a dict: line prefix -> callable that parses the line's value as int
+    split_value = lambda x,i: int(x.split()[i], 16)  # i-th whitespace-separated token, read as hex
+
+    split_first = bind(split_value, i=1)  # value is the second token of the line
+    split_second = bind(split_value, i=2)  # value is the third token (the 'flag XX' label takes two)
+
+    split_equal = lambda x,i: int(x.split('=')[i], 16)  # i-th '='-separated field, read as hex
+    labels = {'INVOKE': bind(split_equal, i=1),  # value is whatever follows the first '='
+              'RAX': split_first,
+              'RBX': split_first,
+              'RCX': split_first,
+              'RDX': split_first,
+              'RSI': split_first,
+              'RDI': split_first,
+              'RBP': split_first,
+              'RSP': split_first,
+              'R8':  split_first,
+              'R9':  split_first,
+              'R10': split_first,
+              'R11': split_first,
+              'R12': split_first,
+              'R13': split_first,
+              'R14': split_first,
+              'R15': split_first,
+              'flag ZF': split_second,
+              'flag CF': split_second,
+              'flag OF': split_second,
+              'flag SF': split_second,
+              'flag PF': split_second,
+              'flag DF': split_second}
+    return labels
+
+def equivalent(val1, val2):
+    return val1 == val2  # plain equality, no tolerance or masking
+
+def verify(translation: ContextBlock, reference: ContextBlock):
+    # Returns 1 on a PC mismatch; otherwise prints a per-register report and returns 0.
+    if translation.regs["PC"] != reference.regs["PC"]:
+        return 1
+
+    print_separator()
+    print(f'For PC={hex(translation.regs["PC"])}')
+    print_separator()
+    for reg, txl_value in translation.regs.items():
+        if reg not in reference.regs:
+            continue
+        ref_value = reference.regs[reg]
+        if txl_value is None:
+            print(f'Element not available in translation: {reg}')
+            continue
+
+        if ref_value is None:
+            print(f'Element not available in reference: {reg}')
+            continue
+
+        if not equivalent(txl_value, ref_value):
+            print(f'Difference for {reg}: {hex(txl_value)} != {hex(ref_value)}')
+    return 0
+
+def compare(txl: list, native: list, stats: bool = False):
+    txl = parse(txl, get_labels())  # raw log lines -> list of ContextBlock
+    native = parse(native, get_labels())
+
+    if len(txl) != len(native):
+        print(f'Different number of blocks discovered translation: {len(txl)} vs. '
+              f'reference: {len(native)}', file=sys.stderr)
+
+    unmatched_pcs = {}  # translation PC -> number of block pairs whose PCs disagreed
+    for translation, reference in zip(txl, native):  # zip() stops at the shorter list
+        if verify(translation, reference) == 1:  # 1 means the two blocks have different PCs
+            # TODO: add verbose output
+            print_separator(stream=sys.stderr)
+            print(f'No match for PC {hex(translation.regs["PC"])}', file=sys.stderr)
+            if translation.regs['PC'] not in unmatched_pcs:
+                unmatched_pcs[translation.regs['PC']] = 0
+            unmatched_pcs[translation.regs['PC']] += 1
+
+    if stats:
+        print_separator()
+        print('Statistics:')
+        print_separator()
+
+        for pc in unmatched_pcs:
+            print(f'PC {hex(pc)} unmatched {unmatched_pcs[pc]} times')
+    return 0  # always 0; mismatches are only reported, never returned
+
+def read_logs(txl_path, native_path, program):
+    txl = []
+    with open(txl_path, "r") as txl_file:
+        txl = txl_file.readlines()
+
+    native = []
+    if program is not None:  # a given program takes precedence over reading native_path
+        runner = Runner(txl, program)
+        native = runner.run()  # reference log generated by Runner instead of read from disk
+    else:
+        with open(native_path, "r") as native_file:
+            native = native_file.readlines()
+
+    return txl, native  # (translation log, reference log)
+
+def parse_arguments():  # defines the command line and returns the parsed argparse namespace
+    parser = argparse.ArgumentParser(description='Comparator for emulator logs to reference')
+    parser.add_argument('-p', '--program',
+                        type=str,
+                        help='Path to oracle program')
+    parser.add_argument('-r', '--ref',
+                        type=str,
+                        required=True,
+                        help='Path to the reference log (gathered with run.sh)')
+    parser.add_argument('-t', '--txl',
+                        type=str,
+                        required=True,
+                        help='Path to the translation log (gathered via Arancini)')
+    parser.add_argument('-s', '--stats',
+                        action='store_true',
+                        default=False,  # must be a bool: a non-empty string default is always truthy
+                        help='Run statistics on comparisons')
+    parser.add_argument('-v', '--verbose',
+                        action='store_true',
+                        default=False,  # must be a bool: a non-empty string default is always truthy
+                        help='Enable verbose output')
+    args = parser.parse_args()
+    return args
+
+if __name__ == "__main__":
+    check_version('3.7')
+
+    args = parse_arguments()
+
+    txl_path = args.txl
+    native_path = args.ref
+    program = args.program
+
+    stats = args.stats
+    verbose = args.verbose
+
+    if program is None and native_path is None:  # defensive: --ref is declared as required
+        raise ValueError('Either program or path to native file must be '
+                         'provided')
+
+    txl, native = read_logs(txl_path, native_path, program)
+
+    if program is not None and native_path is not None:
+        with open(native_path, 'w') as w:  # store the freshly generated reference log
+            w.write(''.join(native))
+
+    compare(txl, native, stats)
+
diff --git a/run.py b/run.py
new file mode 100755
index 0000000..f1f1060
--- /dev/null
+++ b/run.py
@@ -0,0 +1,214 @@
+#! /bin/python3
+import os
+import re
+import sys
+import lldb
+import shutil
+import argparse
+
+from utils import print_separator
+
+verbose = False  # gates the message printed by Debugger.set_breakpoint_by_addr
+
+regnames = ['PC',  # names that DebuggerCallback matches against upper-cased register names
+            'RAX',
+            'RBX',
+            'RCX',
+            'RDX',
+            'RSI',
+            'RDI',
+            'RBP',
+            'RSP',
+            'R8',
+            'R9',
+            'R10',
+            'R11',
+            'R12',
+            'R13',
+            'R14',
+            'R15',
+            'RFLAGS']  # printed as individual 'flag XX' lines rather than as one raw value
+
+class DebuggerCallback:  # callable that dumps PC, registers, flags and stack of a frame to a stream
+    def __init__(self, ostream=sys.stdout, skiplist: set = None):
+        self.stream = ostream
+        self.regex = re.compile('(' + '|'.join(regnames) + ')$')
+        self.skiplist = set() if skiplist is None else skiplist  # fresh set per instance, never a shared default
+
+    @staticmethod
+    def parse_flags(flag_reg: int):
+        flags = {'ZF': 0,
+                 'CF': 0,
+                 'OF': 0,
+                 'SF': 0,
+                 'PF': 0,
+                 'DF': 0}
+
+        # CF (Carry flag) Bit 0
+        # PF (Parity flag) Bit 2
+        # ZF (Zero flag) Bit 6
+        # SF (Sign flag) Bit 7
+        # TF (Trap flag) Bit 8
+        # IF (Interrupt enable flag) Bit 9
+        # DF (Direction flag) Bit 10
+        # OF (Overflow flag) Bit 11
+        flags['CF'] = int(0 != flag_reg & 1)
+        flags['ZF'] = int(0 != flag_reg & (1 << 6))
+        flags['OF'] = int(0 != flag_reg & (1 << 11))
+        flags['SF'] = int(0 != flag_reg & (1 << 7))
+        flags['DF'] = int(0 != flag_reg & (1 << 10))
+        flags['PF'] = int(0 != flag_reg & (1 << 2))  # PF is bit 2; bit 1 is not a status flag
+        return flags
+
+
+    def print_regs(self, frame):
+        for reg in frame.GetRegisters():
+            for sub_reg in reg:
+                match = self.regex.match(sub_reg.GetName().upper())
+                if match and match.group() == 'RFLAGS':  # RFLAGS is split into one line per flag
+                    flags = DebuggerCallback.parse_flags(int(sub_reg.GetValue(),
+                                                             base=16))
+                    for flag in flags:
+                        print(f'flag {flag}:\t{hex(flags[flag])}',
+                                                     file=self.stream)
+                elif match:
+                    print(f"{sub_reg.GetName().upper()}:\t\t {hex(int(sub_reg.GetValue(), base=16))}",
+                          file=self.stream)
+
+    def print_stack(self, frame, element_count: int):
+        first = True  # only the first printed slot is tagged as the stack pointer
+        for i in range(element_count):
+            addr = frame.GetSP() + i * frame.GetThread().GetProcess().GetAddressByteSize()
+            error = lldb.SBError()
+            stack_value = int(frame.GetThread().GetProcess().ReadPointerFromMemory(addr, error))
+            if error.Success() and not first:
+                print(f'{hex(stack_value)}', file=self.stream)
+            elif error.Success():
+                print(f'{hex(stack_value)}\t\t<- rsp', file=self.stream)
+            else:
+                print(f"Error reading memory at address 0x{addr:x}",
+                      file=self.stream)
+            first=False
+
+    def __call__(self, frame):
+        pc = frame.GetPC()
+
+        # Skip this PC
+        if pc in self.skiplist:
+            self.skiplist.discard(pc)  # each skiplist entry suppresses only the first hit
+            return False
+
+        print_separator('=', stream=self.stream, count=20)
+        print(f'INVOKE PC={hex(pc)}', file=self.stream)
+        print_separator('=', stream=self.stream, count=20)
+
+        print("Register values:", file=self.stream)
+        self.print_regs(frame)
+        print_separator(stream=self.stream)
+
+        print("STACK:", file=self.stream)
+        self.print_stack(frame, 20)
+
+        return True  # Continue execution
+
+class Debugger:  # thin wrapper around an LLDB debugger and target for one program
+    def __init__(self, program):
+        self.debugger = lldb.SBDebugger.Create()
+        self.debugger.SetAsync(False)
+        self.target = self.debugger.CreateTargetWithFileAndArch(program,
+                                                                lldb.LLDB_ARCH_DEFAULT)
+        self.module = self.target.FindModule(self.target.GetExecutable())
+        self.interpreter = self.debugger.GetCommandInterpreter()
+
+    def set_breakpoint_by_addr(self, address: int):
+        command = f"b -a {address} -s {self.module.GetFileSpec().GetFilename()}"
+        result = lldb.SBCommandReturnObject()
+        self.interpreter.HandleCommand(command, result)  # NOTE(review): result is never checked
+
+        if verbose:
+            print(f'Set breakpoint at address {hex(address)}')
+
+    def get_breakpoints_count(self):
+        return self.target.GetNumBreakpoints()
+
+    def execute(self, callback: callable):
+        error = lldb.SBError()
+        listener = self.debugger.GetListener()
+        process = self.target.Launch(listener, None, None, None, None, None, None, 0,
+                                     True, error)
+
+        # Check if the process has launched successfully
+        if process.IsValid():
+            print(f'Launched process: {process}')
+        else:  # fail fast: the loop below would never see a stopped or exited state
+            raise RuntimeError(f'Failed to launch process: {error}')
+
+        while True:
+            state = process.GetState()
+            if state == lldb.eStateStopped:
+                for thread in process:
+                    callback(thread.GetFrameAtIndex(0))  # the callback's return value is not used
+                process.Continue()
+            if state == lldb.eStateExited:
+                break
+
+        print(f'Process state: {process.GetState()}')
+        print('Program output:')
+        print(process.GetSTDOUT(1024))
+        print(process.GetSTDERR(1024))
+
+        self.debugger.Terminate()  # tear LLDB down only after the process has been queried
+
+class ListWriter:  # minimal write()-only sink that collects everything written to it
+    def __init__(self):
+        self.data = []  # written chunks, in write order
+
+    def write(self, s):
+        self.data.append(s)  # one entry per write() call
+
+    def __str__(self):
+        return "".join(self.data)
+
+class Runner:  # sets breakpoints at the PCs found in a translation log and runs the oracle program
+    def __init__(self, dbt_log: list, oracle_program: str):
+        self.log = dbt_log
+        self.program = oracle_program
+        self.debugger = Debugger(self.program)
+        self.writer = ListWriter()  # receives everything DebuggerCallback prints
+
+    @staticmethod
+    def get_addresses(lines: list):
+        addresses = []
+
+        backlist = []  # addresses of '  <addr>:' lines seen since the previous INVOKE
+        backlist_regex = re.compile(r'^\s\s\d*:')  # NOTE(review): \d is decimal-only but the value is parsed as hex below - confirm
+
+        skiplist = set()  # INVOKE addresses that also appeared in the preceding backlist
+        for l in lines:
+            if l.startswith('INVOKE'):
+                addresses.append(int(l.split('=')[1].strip(), base=16))
+
+                if addresses[-1] in backlist:
+                    skiplist.add(addresses[-1])
+                backlist = []
+
+            if backlist_regex.match(l):
+                backlist.append(int(l.split()[0].split(':')[0], base=16))
+
+        return set(addresses), skiplist  # addresses are de-duplicated here
+
+    def run(self):
+        # Get all addresses to stop at
+        addresses, skiplist = Runner.get_addresses(self.log)
+
+        # Set breakpoints
+        for address in addresses:
+            self.debugger.set_breakpoint_by_addr(address)
+
+        # Sanity check
+        assert(self.debugger.get_breakpoints_count() == len(addresses))
+
+        self.debugger.execute(DebuggerCallback(self.writer, skiplist))
+
+        return self.writer.data  # list of the chunks written during execution
+
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..d841c7c
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,18 @@
+#! /bin/python3
+
+import sys
+import shutil
+
+def print_separator(separator: str = '-', stream=sys.stdout, count: int = 80):
+    maxtermsize = count  # upper bound on how many times the separator is repeated
+    termsize = shutil.get_terminal_size((80, 20)).columns  # (80, 20) is the fallback when the size is unknown
+    print(separator * min(termsize, maxtermsize), file=stream)
+
+def check_version(version: str):
+    # Script depends on ordered dicts in default dict()
+    split = version.split('.')
+    major = int(split[0])
+    minor = int(split[1])
+    if sys.version_info < (major, minor):  # tuple compare; testing both parts with 'and' let 3.6 pass
+        raise EnvironmentError(f"Expected at least Python {version}")
+