about summary refs log tree commit diff stats
path: root/compare.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--compare.py173
1 files changed, 108 insertions, 65 deletions
diff --git a/compare.py b/compare.py
index e5ac244..dfa5dbd 100644
--- a/compare.py
+++ b/compare.py
@@ -1,6 +1,50 @@
+from functools import total_ordering
+from typing import Self
+
 from snapshot import ProgramState, MemoryAccessError
 from symbolic import SymbolicTransform
 
+@total_ordering
+class ErrorSeverity:
+    def __init__(self, num: int, name: str):
+        """Construct an error severity.
+
+        :param num:  A numerical value that orders the severity with respect
+                     to other `ErrorSeverity` objects. Smaller values are less
+                     severe.
+        :param name: A descriptive name for the error severity, e.g. 'fatal'
+                     or 'info'.
+        """
+        self._numeral = num
+        self.name = name
+
+    def __repr__(self) -> str:
+        return f'[{self.name}]'
+
+    def __eq__(self, other: Self) -> bool:
+        return self._numeral == other._numeral
+
+    def __lt__(self, other: Self) -> bool:
+        return self._numeral < other._numeral
+
+    def __hash__(self) -> int:
+        return hash(self._numeral)
+
+class ErrorTypes:
+    INFO       = ErrorSeverity(0, 'INFO')
+    INCOMPLETE = ErrorSeverity(2, 'INCOMPLETE DATA')
+    POSSIBLE   = ErrorSeverity(4, 'UNCONFIRMED ERROR')
+    CONFIRMED  = ErrorSeverity(5, 'ERROR')
+
+class Error:
+    """A state comparison error."""
+    def __init__(self, severity: ErrorSeverity, msg: str):
+        self.severity = severity
+        self.error_msg = msg
+
+    def __repr__(self) -> str:
+        return f'{self.severity} {self.error_msg}'
+
 def _calc_transformation(previous: ProgramState, current: ProgramState):
     """Calculate the difference between two context blocks.
 
@@ -22,9 +66,8 @@ def _calc_transformation(previous: ProgramState, current: ProgramState):
 
     return transformation
 
-def _find_errors(txl_state: ProgramState, prev_txl_state: ProgramState,
-                 truth_state: ProgramState, prev_truth_state: ProgramState) \
-        -> list[dict]:
+def _find_errors(transform_txl: ProgramState, transform_truth: ProgramState) \
+        -> list[Error]:
     """Find possible errors between a reference and a tested state.
 
     :param txl_state: The translated state to check for errors.
@@ -38,33 +81,25 @@ def _find_errors(txl_state: ProgramState, prev_txl_state: ProgramState,
     :return: A list of errors; one entry for each register that may have
              faulty contents. Is empty if no errors were found.
     """
-    arch = txl_state.arch
-    errors = []
+    assert(transform_truth.arch == transform_txl.arch)
 
-    transform_truth = _calc_transformation(prev_truth_state, truth_state)
-    transform_txl = _calc_transformation(prev_txl_state, txl_state)
-    for reg in arch.regnames:
+    errors = []
+    for reg in transform_truth.arch.regnames:
         try:
             diff_txl = transform_txl.read(reg)
             diff_truth = transform_truth.read(reg)
         except ValueError:
-            # Register is not set in either state
+            errors.append(Error(ErrorTypes.INFO,
+                                f'Value for register {reg} is not set in'
+                                f' either the tested or the reference state.'))
             continue
 
-        if diff_txl == diff_truth:
-            # The register contains a value that is expected
-            # by the transformation.
-            continue
-        if diff_truth is not None:
-            if diff_txl is None:
-                print(f'[WARNING] Expected the value of register {reg} to be'
-                      f' defined, but it is undefined in the translation.'
-                      f' This might hint at an error in the input data.')
-            else:
-                errors.append({
-                    'reg': reg,
-                    'expected': diff_truth, 'actual': diff_txl,
-                })
+        if diff_txl != diff_truth:
+            errors.append(Error(
+                ErrorTypes.CONFIRMED,
+                f'Transformation of register {reg} is false.'
+                f' Expected difference: {hex(diff_truth)},'
+                f' actual difference in the translation: {hex(diff_txl)}.'))
 
     return errors
 
@@ -111,23 +146,21 @@ def compare_simple(test_states: list[ProgramState],
                   f' in translated code!')
             continue
 
-        errors = _find_errors(txl, prev_txl, truth, prev_truth)
+        transform_truth = _calc_transformation(prev_truth, truth)
+        transform_txl = _calc_transformation(prev_txl, txl)
+        errors = _find_errors(transform_txl, transform_truth)
         result.append({
             'pc': pc_txl,
-            'txl': txl, 'ref': truth,
+            'txl': transform_txl, 'ref': transform_truth,
             'errors': errors
         })
 
-        # TODO: Why do we skip backward branches?
-        #if txl.has_backwards:
-        #    print(f' -- Encountered backward branch. Don\'t skip.')
-
     return result
 
 def _find_register_errors(txl_from: ProgramState,
                           txl_to: ProgramState,
                           transform_truth: SymbolicTransform) \
-        -> list[str]:
+        -> list[Error]:
     """Find errors in register values.
 
     Errors might be:
@@ -139,10 +172,15 @@ def _find_register_errors(txl_from: ProgramState,
     # Calculate expected register values
     try:
         truth = transform_truth.calc_register_transform(txl_from)
-    except MemoryAccessError:
-        print(f'Transformation at {hex(transform_truth.addr)} depends on'
-              f' memory that is not set in the tested state. Skipping.')
-        return []
+    except MemoryAccessError as err:
+        s, e = transform_truth.range
+        return [Error(
+            ErrorTypes.INCOMPLETE,
+            f'Register transformations {hex(s)} -> {hex(e)} depend on'
+            f' {err.mem_size} bytes at memory address {hex(err.mem_addr)}'
+            f' that are not entirely present in the tested state'
+            f' {hex(txl_from.read("pc"))}. Skipping.',
+        )]
 
     # Compare expected values to actual values in the tested state
     errors = []
@@ -150,23 +188,25 @@ def _find_register_errors(txl_from: ProgramState,
         try:
             txl_val = txl_to.read(regname)
         except ValueError:
-            errors.append(f'Value of register {regname} has changed, but is'
-                          f' not set in the tested state. Skipping.')
+            errors.append(Error(ErrorTypes.INCOMPLETE,
+                                f'Value of register {regname} has changed, but'
+                                f' is not set in the tested state. Skipping.'))
             continue
         except KeyError as err:
             print(f'[WARNING] {err}')
             continue
 
         if txl_val != truth_val:
-            errors.append(f'Content of register {regname} is possibly false.'
-                          f' Expected value: {hex(truth_val)}, actual'
-                          f' value in the translation: {hex(txl_val)}.')
+            errors.append(Error(ErrorTypes.CONFIRMED,
+                                f'Content of register {regname} is false.'
+                                f' Expected value: {hex(truth_val)}, actual'
+                                f' value in the translation: {hex(txl_val)}.'))
     return errors
 
 def _find_memory_errors(txl_from: ProgramState,
                         txl_to: ProgramState,
                         transform_truth: SymbolicTransform) \
-        -> list[str]:
+        -> list[Error]:
     """Find errors in memory values.
 
     Errors might be:
@@ -178,31 +218,37 @@ def _find_memory_errors(txl_from: ProgramState,
     # Calculate expected register values
     try:
         truth = transform_truth.calc_memory_transform(txl_from)
-    except MemoryAccessError:
-        print(f'Transformation at {hex(transform_truth.addr)} depends on'
-              f' memory that is not set in the tested state. Skipping.')
-        return []
+    except MemoryAccessError as err:
+        s, e = transform_truth.range
+        return [Error(ErrorTypes.INCOMPLETE,
+                      f'Memory transformations {hex(s)} -> {hex(e)} depend on'
+                      f' {err.mem_size} bytes at memory address {hex(err.mem_addr)}'
+                      f' that are not entirely present in the tested state'
+                      f' {hex(txl_from.read("pc"))}. Skipping.')]
 
     # Compare expected values to actual values in the tested state
     errors = []
     for addr, truth_bytes in truth.items():
+        size = len(truth_bytes)
         try:
-            txl_bytes = txl_to.read_memory(addr, len(truth_bytes))
+            txl_bytes = txl_to.read_memory(addr, size)
         except MemoryAccessError:
-            errors.append(f'Memory range [{addr}, {addr + len(truth_bytes)})'
-                          f' is not set in the test-result state. Skipping.')
+            errors.append(Error(ErrorTypes.POSSIBLE,
+                                f'Memory range [{addr}, {addr + size}) is not'
+                                f' set in the tested result state. Skipping.'))
             continue
 
         if txl_bytes != truth_bytes:
-            errors.append(f'Content of memory at {addr} is possibly false.'
-                          f' Expected content: {truth_bytes.hex()}, actual'
-                          f' content in the translation: {txl_bytes.hex()}.')
+            errors.append(Error(ErrorTypes.CONFIRMED,
+                                f'Content of memory at {addr} is false.'
+                                f' Expected content: {truth_bytes.hex()}, actual'
+                                f' content in the translation: {txl_bytes.hex()}.'))
     return errors
 
 def _find_errors_symbolic(txl_from: ProgramState,
                           txl_to: ProgramState,
                           transform_truth: SymbolicTransform) \
-        -> list[str]:
+        -> list[Error]:
     """Tries to find errors in transformations between tested states.
 
     Applies a transformation to a source state and tests whether the result
@@ -220,12 +266,12 @@ def _find_errors_symbolic(txl_from: ProgramState,
     if (txl_from.read('PC') != transform_truth.range[0]) \
             or (txl_to.read('PC') != transform_truth.range[1]):
         tstart, tend = transform_truth.range
-        print(f'[WARNING] Program counters of the tested transformation do not'
-              f' match the truth transformation:'
-              f' {hex(txl_from.read("PC"))} -> {hex(txl_to.read("PC"))} (test)'
-              f' vs. {hex(tstart)} -> {hex(tend)} (truth).'
-              f' Skipping with no errors.')
-        return []
+        return [Error(ErrorTypes.POSSIBLE,
+                      f'Program counters of the tested transformation'
+                      f' do not match the truth transformation:'
+                      f' {hex(txl_from.read("PC"))} -> {hex(txl_to.read("PC"))}'
+                      f' (test) vs. {hex(tstart)} -> {hex(tend)} (truth).'
+                      f' Skipping with no errors.')]
 
     errors = []
     errors.extend(_find_register_errors(txl_from, txl_to, transform_truth))
@@ -234,12 +280,12 @@ def _find_errors_symbolic(txl_from: ProgramState,
     return errors
 
 def compare_symbolic(test_states: list[ProgramState],
-                     transforms: list[SymbolicTransform]):
+                     transforms: list[SymbolicTransform]) \
+        -> list[dict]:
     #assert(len(test_states) == len(transforms) - 1)
-    PC_REGNAME = 'PC'
 
     result = [{
-        'pc': test_states[0].read(PC_REGNAME),
+        'pc': test_states[0].read('PC'),
         'txl': test_states[0],
         'ref': transforms[0],
         'errors': []
@@ -247,11 +293,8 @@ def compare_symbolic(test_states: list[ProgramState],
 
     _list = zip(test_states[:-1], test_states[1:], transforms)
     for cur_state, next_state, transform in _list:
-        pc_cur = cur_state.read(PC_REGNAME)
-        pc_next = next_state.read(PC_REGNAME)
-
-        # The program counter should always be set on a snapshot
-        assert(pc_cur is not None and pc_next is not None)
+        pc_cur = cur_state.read('PC')
+        pc_next = next_state.read('PC')
 
         start_addr, end_addr = transform.range
         if pc_cur != start_addr: