about summary refs log tree commit diff stats
path: root/compare.py
diff options
context:
space:
mode:
authorTheofilos Augoustis <theofilos.augoustis@gmail.com>2023-08-24 11:19:37 +0200
committerTheofilos Augoustis <theofilos.augoustis@gmail.com>2023-08-24 11:19:37 +0200
commit511d9fd98319b48cd3d0dc08a4108f90747af5b3 (patch)
treef5f911dea8885bf62ac7009300a127b803760af1 /compare.py
parent96094b79c8e3cc2583272a9daf06a19189c67176 (diff)
downloadfocaccia-511d9fd98319b48cd3d0dc08a4108f90747af5b3.tar.gz
focaccia-511d9fd98319b48cd3d0dc08a4108f90747af5b3.zip
Add progressive search over basic blocks
Diffstat (limited to '')
-rwxr-xr-xcompare.py87
1 files changed, 71 insertions, 16 deletions
diff --git a/compare.py b/compare.py
index b5bc87e..45c2f36 100755
--- a/compare.py
+++ b/compare.py
@@ -11,6 +11,8 @@ from utils import print_separator
 
 from run import Runner
 
+progressive = False
+
 class ContextBlock:
     regnames = ['PC',
                 'RAX',
@@ -38,6 +40,10 @@ class ContextBlock:
 
     def __init__(self):
         self.regs = {reg: None for reg in ContextBlock.regnames}
+        self.has_backwards = False
+
+    def set_backwards(self):
+        self.has_backwards = True
 
     def set(self, idx: int, value: int):
         self.regs[list(self.regs.keys())[idx]] = value
@@ -63,6 +69,9 @@ class Constructor:
 
         return register, self.structure[pattern](line)
 
+    def add_backwards(self):
+        self.cblocks[-1].set_backwards()
+
     def add(self, key: str, value: int):
         if key == 'PC':
             self.cblocks.append(ContextBlock())
@@ -84,10 +93,16 @@ class Transformations:
 def parse(lines: list, labels: list):
     ctor = Constructor(labels)
 
-    regex = re.compile("|".join(ctor.patterns))
+    patterns = ctor.patterns.copy()
+    patterns.append('Backwards')
+    regex = re.compile("|".join(patterns))
     lines = [l for l in lines if regex.match(l) is not None]
 
     for line in lines:
+        if 'Backwards' in line:
+            ctor.add_backwards()
+            continue
+
         key, value = ctor.match(line)
         ctor.add(key, value)
 
@@ -164,24 +179,58 @@ def compare(txl: List[ContextBlock], native: List[ContextBlock], stats: bool = F
 
     if len(txl) != len(native):
         print(f'Different number of blocks discovered translation: {len(txl)} vs. '
-              f'reference: {len(native)}', file=sys.stderr)
+              f'reference: {len(native)}', file=sys.stdout)
 
     previous_reference = native[0]
     previous_translation = txl[0]
 
     unmatched_pcs = {}
-    for translation, reference in zip(txl, native):
-        transformations = Transformations(previous_reference, reference)
-        if verify(translation, reference, transformations.transformation, previous_translation) == 1:
-            # TODO: add verbose output
-            print_separator(stream=sys.stderr)
-            print(f'No match for PC {hex(translation.regs["PC"])}', file=sys.stderr)
-            if translation.regs['PC'] not in unmatched_pcs:
-                unmatched_pcs[translation.regs['PC']] = 0
-            unmatched_pcs[translation.regs['PC']] += 1
-
-        previous_reference = reference
-        previous_translation = translation
+    if progressive:
+        i = 0
+        for translation in txl:
+            previous = i
+
+            while i < len(native):
+                reference = native[i]
+                transformations = Transformations(previous_reference, reference)
+                if verify(translation, reference, transformations.transformation, previous_translation) == 0:
+                    break
+
+                previous_reference = reference
+                previous_translation = translation
+                i += 1
+
+            # Didn't find anything
+            if i == len(native):
+                # TODO: add verbose output
+                print_separator(stream=sys.stdout)
+                print(f'No match for PC {hex(translation.regs["PC"])}', file=sys.stdout)
+                if translation.regs['PC'] not in unmatched_pcs:
+                    unmatched_pcs[translation.regs['PC']] = 0
+                unmatched_pcs[translation.regs['PC']] += 1
+
+                i = previous + 1
+
+            if translation.has_backwards:
+                i += 1
+    else:
+        txl = iter(txl)
+        native = iter(native)
+        for translation, reference in zip(txl, native):
+            transformations = Transformations(previous_reference, reference)
+            if verify(translation, reference, transformations.transformation, previous_translation) == 1:
+                # TODO: add verbose output
+                print_separator(stream=sys.stdout)
+                print(f'No match for PC {hex(translation.regs["PC"])}', file=sys.stdout)
+                if translation.regs['PC'] not in unmatched_pcs:
+                    unmatched_pcs[translation.regs['PC']] = 0
+                unmatched_pcs[translation.regs['PC']] += 1
+
+            if translation.has_backwards:
+                next(native)
+
+            previous_reference = reference
+            previous_translation = translation
 
     if stats:
         print_separator()
@@ -222,12 +271,17 @@ def parse_arguments():
                         help='Path to the translation log (gathered via Arancini)')
     parser.add_argument('-s', '--stats',
                         action='store_true',
-                        default='store_false',
+                        default=False,
                         help='Run statistics on comparisons')
     parser.add_argument('-v', '--verbose',
                         action='store_true',
-                        default='store_true',
+                        default=True,
                         help='Path to oracle program')
+    parser.add_argument('--progressive',
+                        action='store_true',
+                        default=False,
+                        help='Try to match exhaustively before declaring \
+                        mismatch')
     args = parser.parse_args()
     return args
 
@@ -242,6 +296,7 @@ if __name__ == "__main__":
 
     stats = args.stats
     verbose = args.verbose
+    progressive = args.progressive
 
     if program is None and native_path is None:
         raise ValueError('Either program or path to native file must be'