about summary refs log tree commit diff stats
path: root/compare.py
diff options
context:
space:
mode:
Diffstat (limited to 'compare.py')
-rw-r--r--compare.py215
1 files changed, 74 insertions, 141 deletions
diff --git a/compare.py b/compare.py
index a191025..8a25d8a 100644
--- a/compare.py
+++ b/compare.py
@@ -1,4 +1,5 @@
-from snapshot import ProgramState
+from snapshot import ProgramState, SnapshotSymbolResolver
+from symbolic import SymbolicTransform
 from utils import print_separator
 
 def calc_transformation(previous: ProgramState, current: ProgramState):
@@ -117,144 +118,76 @@ def compare_simple(test_states: list[ProgramState],
 
     return result
 
-def equivalent(val1, val2, transformation, previous_translation):
-    if val1 == val2:
-        return True
-
-    # TODO: maybe incorrect
-    return val1 - previous_translation == transformation
-
-def verify(translation: ProgramState, reference: ProgramState,
-           transformation: ProgramState, previous_translation: ProgramState):
-    assert(translation.arch == reference.arch)
-
-    if translation.regs["PC"] != reference.regs["PC"]:
-        return 1
-
-    print_separator()
-    print(f'For PC={translation.as_repr("PC")}')
-    print_separator()
-    for reg in translation.arch.regnames:
-        if translation.regs[reg] is None:
-            print(f'Element not available in translation: {reg}')
-        elif reference.regs[reg] is None:
-            print(f'Element not available in reference: {reg}')
-        elif not equivalent(translation.regs[reg], reference.regs[reg],
-                            transformation.regs[reg],
-                            previous_translation.regs[reg]):
-            txl = translation.as_repr(reg)
-            ref = reference.as_repr(reg)
-            print(f'Difference for {reg}: {txl} != {ref}')
-
-    return 0
-
-def compare(txl: list[ProgramState],
-            native: list[ProgramState],
-            progressive: bool = False,
-            stats: bool = False):
-    """Compare two lists of snapshots and output the differences.
-
-    :param txl: The translated, and possibly faulty, state of the program.
-    :param native: The 'correct' reference state of the program.
-    :param progressive:
-    :param stats:
-    """
+def find_errors_symbolic(txl_from: ProgramState,
+                         txl_to: ProgramState,
+                         transform_truth: SymbolicTransform) \
+        -> list[dict]:
+    arch = txl_from.arch
+    resolver = SnapshotSymbolResolver(txl_from)
 
-    if len(txl) != len(native):
-        print(f'Different numbers of blocks discovered: '
-              f'{len(txl)} in translation vs. {len(native)} in reference.')
-
-    previous_reference = native[0]
-    previous_translation = txl[0]
-
-    unmatched_pcs = {}
-    pc_to_skip = ""
-    if progressive:
-        i = 0
-        for translation in txl:
-            previous = i
-
-            while i < len(native):
-                reference = native[i]
-                transformation = calc_transformation(previous_reference, reference)
-                if verify(translation, reference, transformation, previous_translation) == 0:
-                    reference.matched = True
-                    break
-
-                i += 1
-
-            matched = True
-
-            # Didn't find anything
-            if i == len(native):
-                matched = False
-                # TODO: add verbose output
-                print_separator()
-                print(f'No match for PC {hex(translation.regs["PC"])}')
-                if translation.regs['PC'] not in unmatched_pcs:
-                    unmatched_pcs[translation.regs['PC']] = 0
-                unmatched_pcs[translation.regs['PC']] += 1
-
-                i = previous
-
-            # Necessary since we may have run out of native BBs to check and
-            # previous becomes len(native)
-            #
-            # We continue checking to report unmatched translation PCs
-            if i < len(native):
-                previous_reference = native[i]
-
-            previous_translation = translation
-
-            # Skip next reference when there is a backwards branch
-            # NOTE: if a reference was skipped, don't skip it again
-            #       necessary for loops which may have multiple backwards
-            #       branches
-            if translation.has_backwards and translation.regs['PC'] != pc_to_skip:
-                pc_to_skip = translation.regs['PC']
-                i += 1
-
-            if matched:
-                i += 1
-    else:
-        txl = iter(txl)
-        native = iter(native)
-        for translation, reference in zip(txl, native):
-            transformation = calc_transformation(previous_reference, reference)
-            if verify(translation, reference, transformation, previous_translation) == 1:
-                # TODO: add verbose output
-                print_separator()
-                print(f'No match for PC {hex(translation.regs["PC"])}')
-                if translation.regs['PC'] not in unmatched_pcs:
-                    unmatched_pcs[translation.regs['PC']] = 0
-                unmatched_pcs[translation.regs['PC']] += 1
-            else:
-                reference.matched = True
-
-            if translation.has_backwards:
-                next(native)
-
-            previous_reference = reference
-            previous_translation = translation
-
-    if stats:
-        print_separator()
-        print('Statistics:')
-        print_separator()
-
-        for pc in unmatched_pcs:
-            print(f'PC {hex(pc)} unmatched {unmatched_pcs[pc]} times')
-
-        # NOTE: currently doesn't handle mismatched due backward branches
-        current = ""
-        unmatched_count = 0
-        for ref in native:
-            ref_pc = ref.regs['PC']
-            if ref_pc != current:
-                if unmatched_count:
-                    print(f'Reference PC {hex(current)} unmatched {unmatched_count} times')
-                current = ref_pc
-
-            if ref.matched == False:
-                unmatched_count += 1
-    return 0
+    assert(txl_from.read('PC') == transform_truth.start_addr)
+    assert(txl_to.read('PC') == transform_truth.end_addr)
+
+    errors = []
+    for reg in arch.regnames:
+        if txl_from.read(reg) is None or txl_to.read(reg) is None:
+            print(f'A value for {reg} must be set in all translated states.'
+                  ' Skipping.')
+            continue
+
+        txl_val = txl_to.read(reg)
+        try:
+            truth = transform_truth.eval_register_transform(reg.lower(), resolver)
+            print(f'Evaluated symbolic formula to {hex(txl_val)} vs. txl {hex(txl_val)}')
+            if txl_val != truth:
+                errors.append({
+                    'reg': reg,
+                    'expected': truth,
+                    'actual': txl_val,
+                    'equation': transform_truth.state.regs.get(reg),
+                })
+        except AttributeError:
+            print(f'Register {reg} does not exist.')
+
+    return errors
+
+def compare_symbolic(test_states: list[ProgramState],
+                     transforms: list[SymbolicTransform]):
+    #assert(len(test_states) == len(transforms) - 1)
+    PC_REGNAME = 'PC'
+
+    result = [{
+        'pc': test_states[0].regs[PC_REGNAME],
+        'txl': test_states[0],
+        'ref': transforms[0],
+        'errors': []
+    }]
+
+    _list = zip(test_states[:-1], test_states[1:], transforms)
+    for cur_state, next_state, transform in _list:
+        pc_cur = cur_state.read(PC_REGNAME)
+        pc_next = next_state.read(PC_REGNAME)
+
+        # The program counter should always be set on a snapshot
+        assert(pc_cur is not None and pc_next is not None)
+
+        if pc_cur != transform.start_addr:
+            print(f'Program counter {hex(pc_cur)} in translated code has no'
+                  f' corresponding reference state! Skipping.'
+                  f' (reference: {hex(transform.start_addr)})')
+            continue
+        if pc_next != transform.end_addr:
+            print(f'Tested state transformation is {hex(pc_cur)} ->'
+                  f' {hex(pc_next)}, but reference transform is'
+                  f' {hex(transform.start_addr)} -> {hex(transform.end_addr)}!'
+                  f' Skipping.')
+
+        errors = find_errors_symbolic(cur_state, next_state, transform)
+        result.append({
+            'pc': pc_cur,
+            'txl': calc_transformation(cur_state, next_state),
+            'ref': transform,
+            'errors': errors
+        })
+
+    return result