diff options
| -rwxr-xr-x | compare.py | 186 |
1 files changed, 94 insertions, 92 deletions
diff --git a/compare.py b/compare.py index ffd1e93..f4576dd 100755 --- a/compare.py +++ b/compare.py @@ -3,7 +3,7 @@ import re import sys import shutil import argparse -from typing import List +from typing import List, Callable from functools import partial as bind from utils import check_version @@ -39,73 +39,66 @@ class ContextBlock: 'flag DF'] def __init__(self): - self.regs = {reg: None for reg in ContextBlock.regnames} + dict_type = dict[str, int|None] # A register may not have a value + self.regs = dict_type({reg: None for reg in ContextBlock.regnames}) self.has_backwards = False self.matched = False def set_backwards(self): self.has_backwards = True - def set(self, idx: int, value: int): - self.regs[list(self.regs.keys())[idx]] = value + def set(self, reg: str, value: int): + """Assign a value to a register. + + :raises RuntimeError: if the register already has a value. + """ + if self.regs[reg] != None: + raise RuntimeError("Reassigning register") + self.regs[reg] = value def __repr__(self): return self.regs.__repr__() class Constructor: - def __init__(self, structure: dict): - self.cblocks = [] - self.structure = structure - self.patterns = list(self.structure.keys()) - - def match(self, line: str): - # find patterns that match it - regex = re.compile("|".join(self.patterns)) - match = regex.match(line) + """Builds a list of context blocks.""" + def __init__(self, structure: dict[str, tuple[str, Callable[[str], int]]]): + self.cblocks = list[ContextBlock]() + self.labels = structure + self.regex = re.compile("|".join(structure.keys())) - idx = self.patterns.index(match.group(0)) if match else 0 + def match(self, line: str) -> (tuple[str, int] | None): + """Find a register name and that register's value in a line. - pattern = self.patterns[idx] - register = ContextBlock.regnames[idx] + :return: A register name and a register value. + """ + match = self.regex.match(line) + if match: + label = match.group(0) + register, get_reg_value = self.labels[label] + return register, get_reg_value(line) - return register, self.structure[pattern](line) + return None def add_backwards(self): self.cblocks[-1].set_backwards() - def add(self, key: str, value: int): - if key == 'PC': + def add(self, reg: str, value: int): + if reg == 'PC': self.cblocks.append(ContextBlock()) + self.cblocks[-1].set(reg, value) - if self.cblocks[-1].regs[key] != None: - raise RuntimeError("Reassigning register") - - self.cblocks[-1].regs[key] = value - -class Transformations: - def __init__(self, previous: ContextBlock, current: ContextBlock): - self.transformation = ContextBlock() - for el1 in current.regs.keys(): - for el2 in previous.regs.keys(): - if el1 != el2: - continue - self.transformation.regs[el1] = current.regs[el1] - previous.regs[el2] - -def parse(lines: list, labels: list): +def parse(lines: list[str], labels: dict): + """Parse a list of lines into a list of cblocks.""" ctor = Constructor(labels) - - patterns = ctor.patterns.copy() - patterns.append('Backwards') - regex = re.compile("|".join(patterns)) - lines = [l for l in lines if regex.match(l) is not None] - for line in lines: if 'Backwards' in line: ctor.add_backwards() continue - key, value = ctor.match(line) - ctor.add(key, value) + match = ctor.match(line) + if match: + key, value = match + ctor.add(key, value) return ctor.cblocks @@ -116,31 +109,48 @@ def get_labels(): split_second = bind(split_value, i=2) split_equal = lambda x,i: int(x.split('=')[i], 16) - labels = {'INVOKE': bind(split_equal, i=1), - 'RAX': split_first, - 'RBX': split_first, - 'RCX': split_first, - 'RDX': split_first, - 'RSI': split_first, - 'RDI': split_first, - 'RBP': split_first, - 'RSP': split_first, - 'R8': split_first, - 'R9': split_first, - 'R10': split_first, - 'R11': split_first, - 'R12': split_first, - 'R13': split_first, - 'R14': split_first, - 'R15': split_first, - 'flag ZF': split_second, - 'flag CF': split_second, - 'flag OF': split_second, - 'flag SF': split_second, - 'flag PF': split_second, - 'flag DF': split_second} + + # A mapping from regex patterns to the register name and a + # function that extracts that register's value from the line + labels = {'INVOKE': ('PC', bind(split_equal, i=1)), + 'RAX': ('RAX', split_first), + 'RBX': ('RBX', split_first), + 'RCX': ('RCX', split_first), + 'RDX': ('RDX', split_first), + 'RSI': ('RSI', split_first), + 'RDI': ('RDI', split_first), + 'RBP': ('RBP', split_first), + 'RSP': ('RSP', split_first), + 'R8': ('R8', split_first), + 'R9': ('R9', split_first), + 'R10': ('R10', split_first), + 'R11': ('R11', split_first), + 'R12': ('R12', split_first), + 'R13': ('R13', split_first), + 'R14': ('R14', split_first), + 'R15': ('R15', split_first), + 'flag ZF': ('flag ZF', split_second), + 'flag CF': ('flag CF', split_second), + 'flag OF': ('flag OF', split_second), + 'flag SF': ('flag SF', split_second), + 'flag PF': ('flag PF', split_second), + 'flag DF': ('flag DF', split_second)} return labels +def calc_transformation(previous: ContextBlock, current: ContextBlock): + """Calculate the difference between two context blocks. + + :return: A context block that contains in its registers the difference + between the corresponding input blocks' register values. + """ + transformation = ContextBlock() + for reg in ContextBlock.regnames: + prev_val, cur_val = previous.regs[reg], current.regs[reg] + if prev_val is not None and cur_val is not None: + transformation.regs[reg] = cur_val - prev_val + + return transformation + def equivalent(val1, val2, transformation, previous_translation): if val1 == val2: return True @@ -149,35 +159,28 @@ def equivalent(val1, val2, transformation, previous_translation): return val1 - previous_translation == transformation def verify(translation: ContextBlock, reference: ContextBlock, - transformation: Transformations, previous_translation: ContextBlock): + transformation: ContextBlock, previous_translation: ContextBlock): if translation.regs["PC"] != reference.regs["PC"]: return 1 print_separator() print(f'For PC={hex(translation.regs["PC"])}') print_separator() - for el1 in translation.regs.keys(): - for el2 in reference.regs.keys(): - if el1 != el2: - continue - - if translation.regs[el1] is None: - print(f'Element not available in translation: {el1}') - continue - - if reference.regs[el2] is None: - print(f'Element not available in reference: {el2}') - continue - - if not equivalent(translation.regs[el1], reference.regs[el2], - transformation.regs[el1], previous_translation.regs[el1]): - print(f'Difference for {el1}: {hex(translation.regs[el1])} != {hex(reference.regs[el2])}') + for reg in ContextBlock.regnames: + if translation.regs[reg] is None: + print(f'Element not available in translation: {reg}') + elif reference.regs[reg] is None: + print(f'Element not available in reference: {reg}') + elif not equivalent(translation.regs[reg], reference.regs[reg], + transformation.regs[reg], + previous_translation.regs[reg]): + txl = hex(translation.regs[reg]) + ref = hex(reference.regs[reg]) + print(f'Difference for {reg}: {txl} != {ref}') + return 0 def compare(txl: List[ContextBlock], native: List[ContextBlock], stats: bool = False): - txl = parse(txl, get_labels()) - native = parse(native, get_labels()) - if len(txl) != len(native): print(f'Different number of blocks discovered translation: {len(txl)} vs. ' f'reference: {len(native)}', file=sys.stdout) @@ -194,8 +197,8 @@ def compare(txl: List[ContextBlock], native: List[ContextBlock], stats: bool = F while i < len(native): reference = native[i] - transformations = Transformations(previous_reference, reference) - if verify(translation, reference, transformations.transformation, previous_translation) == 0: + transformation = calc_transformation(previous_reference, reference) + if verify(translation, reference, transformation, previous_translation) == 0: reference.matched = True break @@ -235,11 +238,9 @@ def compare(txl: List[ContextBlock], native: List[ContextBlock], stats: bool = F if matched: i += 1 else: - txl = iter(txl) - native = iter(native) for translation, reference in zip(txl, native): - transformations = Transformations(previous_reference, reference) - if verify(translation, reference, transformations.transformation, previous_translation) == 1: + transformation = calc_transformation(previous_reference, reference) + if verify(translation, reference, transformation, previous_translation) == 1: # TODO: add verbose output print_separator(stream=sys.stdout) print(f'No match for PC {hex(translation.regs["PC"])}', file=sys.stdout) @@ -350,5 +351,6 @@ if __name__ == "__main__": with open(native_path, 'w') as w: w.write(''.join(native)) + txl = parse(txl, get_labels()) + native = parse(native, get_labels()) compare(txl, native, stats) - |