about summary refs log tree commit diff stats
path: root/compare.py
diff options
context:
space:
mode:
authorTheofilos Augoustis <theofilos.augoustis@gmail.com>2023-10-09 15:45:00 +0200
committerTheofilos Augoustis <theofilos.augoustis@gmail.com>2023-10-09 15:45:00 +0200
commitb9c08cadc158b18d7cab14a830a9e11f590ec7bd (patch)
tree8f9fe22f20572dc5e6cec3bd4e731d63da1a7bfe /compare.py
parent8c9c319c35d9f26ab65a591eed3cc7671159a5a2 (diff)
downloadfocaccia-b9c08cadc158b18d7cab14a830a9e11f590ec7bd.tar.gz
focaccia-b9c08cadc158b18d7cab14a830a9e11f590ec7bd.zip
Simplify log file parsing
Employ some refactorings to make the parsing code simpler and faster.

Co-authored-by: Theofilos Augoustis <theofilos.augoustis@gmail.com>
Co-authored-by: Nicola Crivellin <nicola.crivellin98@gmail.com>
Diffstat (limited to 'compare.py')
-rwxr-xr-xcompare.py186
1 files changed, 94 insertions, 92 deletions
diff --git a/compare.py b/compare.py
index ffd1e93..f4576dd 100755
--- a/compare.py
+++ b/compare.py
@@ -3,7 +3,7 @@ import re
 import sys
 import shutil
 import argparse
-from typing import List
+from typing import List, Callable
 from functools import partial as bind
 
 from utils import check_version
@@ -39,73 +39,66 @@ class ContextBlock:
                 'flag DF']
 
     def __init__(self):
-        self.regs = {reg: None for reg in ContextBlock.regnames}
+        dict_type = dict[str, int|None]  # A register may not have a value
+        self.regs = dict_type({reg: None for reg in ContextBlock.regnames})
         self.has_backwards = False
         self.matched = False
 
     def set_backwards(self):
         self.has_backwards = True
 
-    def set(self, idx: int, value: int):
-        self.regs[list(self.regs.keys())[idx]] = value
+    def set(self, reg: str, value: int):
+        """Assign a value to a register.
+
+        :raises RuntimeError: if the register already has a value.
+        """
+        if self.regs[reg] != None:
+            raise RuntimeError("Reassigning register")
+        self.regs[reg] = value
 
     def __repr__(self):
         return self.regs.__repr__()
 
 class Constructor:
-    def __init__(self, structure: dict):
-        self.cblocks = []
-        self.structure = structure
-        self.patterns = list(self.structure.keys())
-
-    def match(self, line: str):
-        # find patterns that match it
-        regex = re.compile("|".join(self.patterns))
-        match = regex.match(line)
+    """Builds a list of context blocks."""
+    def __init__(self, structure: dict[str, tuple[str, Callable[[str], int]]]):
+        self.cblocks = list[ContextBlock]()
+        self.labels = structure
+        self.regex = re.compile("|".join(structure.keys()))
 
-        idx = self.patterns.index(match.group(0)) if match else 0
+    def match(self, line: str) -> (tuple[str, int] | None):
+        """Find a register name and that register's value in a line.
 
-        pattern = self.patterns[idx]
-        register = ContextBlock.regnames[idx]
+        :return: A register name and a register value.
+        """
+        match = self.regex.match(line)
+        if match:
+            label = match.group(0)
+            register, get_reg_value = self.labels[label]
+            return register, get_reg_value(line)
 
-        return register, self.structure[pattern](line)
+        return None
 
     def add_backwards(self):
         self.cblocks[-1].set_backwards()
 
-    def add(self, key: str, value: int):
-        if key == 'PC':
+    def add(self, reg: str, value: int):
+        if reg == 'PC':
             self.cblocks.append(ContextBlock())
+        self.cblocks[-1].set(reg, value)
 
-        if self.cblocks[-1].regs[key] != None:
-            raise RuntimeError("Reassigning register")
-
-        self.cblocks[-1].regs[key] = value
-
-class Transformations:
-    def __init__(self, previous: ContextBlock, current: ContextBlock):
-        self.transformation = ContextBlock()
-        for el1 in current.regs.keys():
-            for el2 in previous.regs.keys():
-                if el1 != el2:
-                    continue
-                self.transformation.regs[el1] = current.regs[el1] - previous.regs[el2]
-
-def parse(lines: list, labels: list):
+def parse(lines: list[str], labels: dict):
+    """Parse a list of lines into a list of cblocks."""
     ctor = Constructor(labels)
-
-    patterns = ctor.patterns.copy()
-    patterns.append('Backwards')
-    regex = re.compile("|".join(patterns))
-    lines = [l for l in lines if regex.match(l) is not None]
-
     for line in lines:
         if 'Backwards' in line:
             ctor.add_backwards()
             continue
 
-        key, value = ctor.match(line)
-        ctor.add(key, value)
+        match = ctor.match(line)
+        if match:
+            key, value = match
+            ctor.add(key, value)
 
     return ctor.cblocks
 
@@ -116,31 +109,48 @@ def get_labels():
     split_second = bind(split_value, i=2)
 
     split_equal = lambda x,i: int(x.split('=')[i], 16)
-    labels = {'INVOKE': bind(split_equal, i=1),
-              'RAX': split_first,
-              'RBX': split_first,
-              'RCX': split_first,
-              'RDX': split_first,
-              'RSI': split_first,
-              'RDI': split_first,
-              'RBP': split_first,
-              'RSP': split_first,
-              'R8':  split_first,
-              'R9':  split_first,
-              'R10': split_first,
-              'R11': split_first,
-              'R12': split_first,
-              'R13': split_first,
-              'R14': split_first,
-              'R15': split_first,
-              'flag ZF': split_second,
-              'flag CF': split_second,
-              'flag OF': split_second,
-              'flag SF': split_second,
-              'flag PF': split_second,
-              'flag DF': split_second}
+
+    # A mapping from regex patterns to the register name and a
+    # function that extracts that register's value from the line
+    labels = {'INVOKE':  ('PC',      bind(split_equal, i=1)),
+              'RAX':     ('RAX',     split_first),
+              'RBX':     ('RBX',     split_first),
+              'RCX':     ('RCX',     split_first),
+              'RDX':     ('RDX',     split_first),
+              'RSI':     ('RSI',     split_first),
+              'RDI':     ('RDI',     split_first),
+              'RBP':     ('RBP',     split_first),
+              'RSP':     ('RSP',     split_first),
+              'R8':      ('R8',      split_first),
+              'R9':      ('R9',      split_first),
+              'R10':     ('R10',     split_first),
+              'R11':     ('R11',     split_first),
+              'R12':     ('R12',     split_first),
+              'R13':     ('R13',     split_first),
+              'R14':     ('R14',     split_first),
+              'R15':     ('R15',     split_first),
+              'flag ZF': ('flag ZF', split_second),
+              'flag CF': ('flag CF', split_second),
+              'flag OF': ('flag OF', split_second),
+              'flag SF': ('flag SF', split_second),
+              'flag PF': ('flag PF', split_second),
+              'flag DF': ('flag DF', split_second)}
     return labels
 
+def calc_transformation(previous: ContextBlock, current: ContextBlock):
+    """Calculate the difference between two context blocks.
+
+    :return: A context block that contains in its registers the difference
+             between the corresponding input blocks' register values.
+    """
+    transformation = ContextBlock()
+    for reg in ContextBlock.regnames:
+        prev_val, cur_val = previous.regs[reg], current.regs[reg]
+        if prev_val is not None and cur_val is not None:
+            transformation.regs[reg] = cur_val - prev_val
+
+    return transformation
+
 def equivalent(val1, val2, transformation, previous_translation):
     if val1 == val2:
         return True
@@ -149,35 +159,28 @@ def equivalent(val1, val2, transformation, previous_translation):
     return val1 - previous_translation == transformation
 
 def verify(translation: ContextBlock, reference: ContextBlock,
-           transformation: Transformations, previous_translation: ContextBlock):
+           transformation: ContextBlock, previous_translation: ContextBlock):
     if translation.regs["PC"] != reference.regs["PC"]:
         return 1
 
     print_separator()
     print(f'For PC={hex(translation.regs["PC"])}')
     print_separator()
-    for el1 in translation.regs.keys():
-        for el2 in reference.regs.keys():
-            if el1 != el2:
-                continue
-
-            if translation.regs[el1] is None:
-                print(f'Element not available in translation: {el1}')
-                continue
-
-            if reference.regs[el2] is None:
-                print(f'Element not available in reference: {el2}')
-                continue
-
-            if not equivalent(translation.regs[el1], reference.regs[el2],
-                              transformation.regs[el1], previous_translation.regs[el1]):
-                print(f'Difference for {el1}: {hex(translation.regs[el1])} != {hex(reference.regs[el2])}')
+    for reg in ContextBlock.regnames:
+        if translation.regs[reg] is None:
+            print(f'Element not available in translation: {reg}')
+        elif reference.regs[reg] is None:
+            print(f'Element not available in reference: {reg}')
+        elif not equivalent(translation.regs[reg], reference.regs[reg],
+                            transformation.regs[reg],
+                            previous_translation.regs[reg]):
+            txl = hex(translation.regs[reg])
+            ref = hex(reference.regs[reg])
+            print(f'Difference for {reg}: {txl} != {ref}')
+
     return 0
 
 def compare(txl: List[ContextBlock], native: List[ContextBlock], stats: bool = False):
-    txl = parse(txl, get_labels())
-    native = parse(native, get_labels())
-
     if len(txl) != len(native):
         print(f'Different number of blocks discovered translation: {len(txl)} vs. '
               f'reference: {len(native)}', file=sys.stdout)
@@ -194,8 +197,8 @@ def compare(txl: List[ContextBlock], native: List[ContextBlock], stats: bool = F
 
             while i < len(native):
                 reference = native[i]
-                transformations = Transformations(previous_reference, reference)
-                if verify(translation, reference, transformations.transformation, previous_translation) == 0:
+                transformation = calc_transformation(previous_reference, reference)
+                if verify(translation, reference, transformation, previous_translation) == 0:
                     reference.matched = True
                     break
 
@@ -235,11 +238,9 @@ def compare(txl: List[ContextBlock], native: List[ContextBlock], stats: bool = F
             if matched:
                 i += 1
     else:
-        txl = iter(txl)
-        native = iter(native)
         for translation, reference in zip(txl, native):
-            transformations = Transformations(previous_reference, reference)
-            if verify(translation, reference, transformations.transformation, previous_translation) == 1:
+            transformation = calc_transformation(previous_reference, reference)
+            if verify(translation, reference, transformation, previous_translation) == 1:
                 # TODO: add verbose output
                 print_separator(stream=sys.stdout)
                 print(f'No match for PC {hex(translation.regs["PC"])}', file=sys.stdout)
@@ -350,5 +351,6 @@ if __name__ == "__main__":
         with open(native_path, 'w') as w:
             w.write(''.join(native))
 
+    txl = parse(txl, get_labels())
+    native = parse(native, get_labels())
     compare(txl, native, stats)
-