about summary refs log tree commit diff stats
diff options
context:
space:
mode:
authorTheofilos Augoustis <theofilos.augoustis@gmail.com>2023-12-27 14:41:01 +0100
committerTheofilos Augoustis <theofilos.augoustis@gmail.com>2023-12-27 14:41:01 +0100
commitf2246e641d494d5df76458db4fb4928f5c2cfc7f (patch)
treeb2a0c2a1493dadb002f90f2932e22d89f659f3ea
parent2ddf26ab93c5c625c468c7d554b995e5d6b04d3a (diff)
downloadfocaccia-f2246e641d494d5df76458db4fb4928f5c2cfc7f.tar.gz
focaccia-f2246e641d494d5df76458db4fb4928f5c2cfc7f.zip
Extend error reporting system
Add error severities and the ability to filter for them. Include more
information in comparison error messages.
Diffstat (limited to '')
-rw-r--r--compare.py173
-rwxr-xr-xmain.py78
-rw-r--r--snapshot.py7
-rw-r--r--symbolic.py3
4 files changed, 169 insertions, 92 deletions
diff --git a/compare.py b/compare.py
index e5ac244..dfa5dbd 100644
--- a/compare.py
+++ b/compare.py
@@ -1,6 +1,50 @@
+from functools import total_ordering
+from typing import Self
+
 from snapshot import ProgramState, MemoryAccessError
 from symbolic import SymbolicTransform
 
+@total_ordering
+class ErrorSeverity:
+    def __init__(self, num: int, name: str):
+        """Construct an error severity.
+
+        :param num:  A numerical value that orders the severity with respect
+                     to other `ErrorSeverity` objects. Smaller values are less
+                     severe.
+        :param name: A descriptive name for the error severity, e.g. 'fatal'
+                     or 'info'.
+        """
+        self._numeral = num
+        self.name = name
+
+    def __repr__(self) -> str:
+        return f'[{self.name}]'
+
+    def __eq__(self, other: Self) -> bool:
+        return self._numeral == other._numeral
+
+    def __lt__(self, other: Self) -> bool:
+        return self._numeral < other._numeral
+
+    def __hash__(self) -> int:
+        return hash(self._numeral)
+
+class ErrorTypes:
+    INFO       = ErrorSeverity(0, 'INFO')
+    INCOMPLETE = ErrorSeverity(2, 'INCOMPLETE DATA')
+    POSSIBLE   = ErrorSeverity(4, 'UNCONFIRMED ERROR')
+    CONFIRMED  = ErrorSeverity(5, 'ERROR')
+
+class Error:
+    """A state comparison error."""
+    def __init__(self, severity: ErrorSeverity, msg: str):
+        self.severity = severity
+        self.error_msg = msg
+
+    def __repr__(self) -> str:
+        return f'{self.severity} {self.error_msg}'
+
 def _calc_transformation(previous: ProgramState, current: ProgramState):
     """Calculate the difference between two context blocks.
 
@@ -22,9 +66,8 @@ def _calc_transformation(previous: ProgramState, current: ProgramState):
 
     return transformation
 
-def _find_errors(txl_state: ProgramState, prev_txl_state: ProgramState,
-                 truth_state: ProgramState, prev_truth_state: ProgramState) \
-        -> list[dict]:
+def _find_errors(transform_txl: ProgramState, transform_truth: ProgramState) \
+        -> list[Error]:
     """Find possible errors between a reference and a tested state.
 
     :param txl_state: The translated state to check for errors.
@@ -38,33 +81,25 @@ def _find_errors(txl_state: ProgramState, prev_txl_state: ProgramState,
     :return: A list of errors; one entry for each register that may have
              faulty contents. Is empty if no errors were found.
     """
-    arch = txl_state.arch
-    errors = []
+    assert(transform_truth.arch == transform_txl.arch)
 
-    transform_truth = _calc_transformation(prev_truth_state, truth_state)
-    transform_txl = _calc_transformation(prev_txl_state, txl_state)
-    for reg in arch.regnames:
+    errors = []
+    for reg in transform_truth.arch.regnames:
         try:
             diff_txl = transform_txl.read(reg)
             diff_truth = transform_truth.read(reg)
         except ValueError:
-            # Register is not set in either state
+            errors.append(Error(ErrorTypes.INFO,
+                                f'Value for register {reg} is not set in'
+                                f' either the tested or the reference state.'))
             continue
 
-        if diff_txl == diff_truth:
-            # The register contains a value that is expected
-            # by the transformation.
-            continue
-        if diff_truth is not None:
-            if diff_txl is None:
-                print(f'[WARNING] Expected the value of register {reg} to be'
-                      f' defined, but it is undefined in the translation.'
-                      f' This might hint at an error in the input data.')
-            else:
-                errors.append({
-                    'reg': reg,
-                    'expected': diff_truth, 'actual': diff_txl,
-                })
+        if diff_txl != diff_truth:
+            errors.append(Error(
+                ErrorTypes.CONFIRMED,
+                f'Transformation of register {reg} is false.'
+                f' Expected difference: {hex(diff_truth)},'
+                f' actual difference in the translation: {hex(diff_txl)}.'))
 
     return errors
 
@@ -111,23 +146,21 @@ def compare_simple(test_states: list[ProgramState],
                   f' in translated code!')
             continue
 
-        errors = _find_errors(txl, prev_txl, truth, prev_truth)
+        transform_truth = _calc_transformation(prev_truth, truth)
+        transform_txl = _calc_transformation(prev_txl, txl)
+        errors = _find_errors(transform_txl, transform_truth)
         result.append({
             'pc': pc_txl,
-            'txl': txl, 'ref': truth,
+            'txl': transform_txl, 'ref': transform_truth,
             'errors': errors
         })
 
-        # TODO: Why do we skip backward branches?
-        #if txl.has_backwards:
-        #    print(f' -- Encountered backward branch. Don\'t skip.')
-
     return result
 
 def _find_register_errors(txl_from: ProgramState,
                           txl_to: ProgramState,
                           transform_truth: SymbolicTransform) \
-        -> list[str]:
+        -> list[Error]:
     """Find errors in register values.
 
     Errors might be:
@@ -139,10 +172,15 @@ def _find_register_errors(txl_from: ProgramState,
     # Calculate expected register values
     try:
         truth = transform_truth.calc_register_transform(txl_from)
-    except MemoryAccessError:
-        print(f'Transformation at {hex(transform_truth.addr)} depends on'
-              f' memory that is not set in the tested state. Skipping.')
-        return []
+    except MemoryAccessError as err:
+        s, e = transform_truth.range
+        return [Error(
+            ErrorTypes.INCOMPLETE,
+            f'Register transformations {hex(s)} -> {hex(e)} depend on'
+            f' {err.mem_size} bytes at memory address {hex(err.mem_addr)}'
+            f' that are not entirely present in the tested state'
+            f' {hex(txl_from.read("pc"))}. Skipping.',
+        )]
 
     # Compare expected values to actual values in the tested state
     errors = []
@@ -150,23 +188,25 @@ def _find_register_errors(txl_from: ProgramState,
         try:
             txl_val = txl_to.read(regname)
         except ValueError:
-            errors.append(f'Value of register {regname} has changed, but is'
-                          f' not set in the tested state. Skipping.')
+            errors.append(Error(ErrorTypes.INCOMPLETE,
+                                f'Value of register {regname} has changed, but'
+                                f' is not set in the tested state. Skipping.'))
             continue
         except KeyError as err:
             print(f'[WARNING] {err}')
             continue
 
         if txl_val != truth_val:
-            errors.append(f'Content of register {regname} is possibly false.'
-                          f' Expected value: {hex(truth_val)}, actual'
-                          f' value in the translation: {hex(txl_val)}.')
+            errors.append(Error(ErrorTypes.CONFIRMED,
+                                f'Content of register {regname} is false.'
+                                f' Expected value: {hex(truth_val)}, actual'
+                                f' value in the translation: {hex(txl_val)}.'))
     return errors
 
 def _find_memory_errors(txl_from: ProgramState,
                         txl_to: ProgramState,
                         transform_truth: SymbolicTransform) \
-        -> list[str]:
+        -> list[Error]:
     """Find errors in memory values.
 
     Errors might be:
@@ -178,31 +218,37 @@ def _find_memory_errors(txl_from: ProgramState,
     # Calculate expected register values
     try:
         truth = transform_truth.calc_memory_transform(txl_from)
-    except MemoryAccessError:
-        print(f'Transformation at {hex(transform_truth.addr)} depends on'
-              f' memory that is not set in the tested state. Skipping.')
-        return []
+    except MemoryAccessError as err:
+        s, e = transform_truth.range
+        return [Error(ErrorTypes.INCOMPLETE,
+                      f'Memory transformations {hex(s)} -> {hex(e)} depend on'
+                      f' {err.mem_size} bytes at memory address {hex(err.mem_addr)}'
+                      f' that are not entirely present in the tested state'
+                      f' {hex(txl_from.read("pc"))}. Skipping.')]
 
     # Compare expected values to actual values in the tested state
     errors = []
     for addr, truth_bytes in truth.items():
+        size = len(truth_bytes)
         try:
-            txl_bytes = txl_to.read_memory(addr, len(truth_bytes))
+            txl_bytes = txl_to.read_memory(addr, size)
         except MemoryAccessError:
-            errors.append(f'Memory range [{addr}, {addr + len(truth_bytes)})'
-                          f' is not set in the test-result state. Skipping.')
+            errors.append(Error(ErrorTypes.POSSIBLE,
+                                f'Memory range [{addr}, {addr + size}) is not'
+                                f' set in the tested result state. Skipping.'))
             continue
 
         if txl_bytes != truth_bytes:
-            errors.append(f'Content of memory at {addr} is possibly false.'
-                          f' Expected content: {truth_bytes.hex()}, actual'
-                          f' content in the translation: {txl_bytes.hex()}.')
+            errors.append(Error(ErrorTypes.CONFIRMED,
+                                f'Content of memory at {addr} is false.'
+                                f' Expected content: {truth_bytes.hex()}, actual'
+                                f' content in the translation: {txl_bytes.hex()}.'))
     return errors
 
 def _find_errors_symbolic(txl_from: ProgramState,
                           txl_to: ProgramState,
                           transform_truth: SymbolicTransform) \
-        -> list[str]:
+        -> list[Error]:
     """Tries to find errors in transformations between tested states.
 
     Applies a transformation to a source state and tests whether the result
@@ -220,12 +266,12 @@ def _find_errors_symbolic(txl_from: ProgramState,
     if (txl_from.read('PC') != transform_truth.range[0]) \
             or (txl_to.read('PC') != transform_truth.range[1]):
         tstart, tend = transform_truth.range
-        print(f'[WARNING] Program counters of the tested transformation do not'
-              f' match the truth transformation:'
-              f' {hex(txl_from.read("PC"))} -> {hex(txl_to.read("PC"))} (test)'
-              f' vs. {hex(tstart)} -> {hex(tend)} (truth).'
-              f' Skipping with no errors.')
-        return []
+        return [Error(ErrorTypes.POSSIBLE,
+                      f'Program counters of the tested transformation'
+                      f' do not match the truth transformation:'
+                      f' {hex(txl_from.read("PC"))} -> {hex(txl_to.read("PC"))}'
+                      f' (test) vs. {hex(tstart)} -> {hex(tend)} (truth).'
+                      f' Skipping with no errors.')]
 
     errors = []
     errors.extend(_find_register_errors(txl_from, txl_to, transform_truth))
@@ -234,12 +280,12 @@ def _find_errors_symbolic(txl_from: ProgramState,
     return errors
 
 def compare_symbolic(test_states: list[ProgramState],
-                     transforms: list[SymbolicTransform]):
+                     transforms: list[SymbolicTransform]) \
+        -> list[dict]:
     #assert(len(test_states) == len(transforms) - 1)
-    PC_REGNAME = 'PC'
 
     result = [{
-        'pc': test_states[0].read(PC_REGNAME),
+        'pc': test_states[0].read('PC'),
         'txl': test_states[0],
         'ref': transforms[0],
         'errors': []
@@ -247,11 +293,8 @@ def compare_symbolic(test_states: list[ProgramState],
 
     _list = zip(test_states[:-1], test_states[1:], transforms)
     for cur_state, next_state, transform in _list:
-        pc_cur = cur_state.read(PC_REGNAME)
-        pc_next = next_state.read(PC_REGNAME)
-
-        # The program counter should always be set on a snapshot
-        assert(pc_cur is not None and pc_next is not None)
+        pc_cur = cur_state.read('PC')
+        pc_next = next_state.read('PC')
 
         start_addr, end_addr = transform.range
         if pc_cur != start_addr:
diff --git a/main.py b/main.py
index a51ecf7..3167bbd 100755
--- a/main.py
+++ b/main.py
@@ -5,7 +5,8 @@ import platform
 from typing import Iterable
 
 from arch import x86
-from compare import compare_simple, compare_symbolic
+from compare import compare_simple, compare_symbolic, \
+                    ErrorSeverity, ErrorTypes
 from lldb_target import LLDBConcreteTarget
 from parser import parse_arancini
 from snapshot import ProgramState
@@ -108,11 +109,61 @@ def parse_arguments():
                         default=False,
                         help='Use an advanced algorithm that uses symbolic'
                              ' execution to determine accurate data'
-                             ' transformations')
+                             ' transformations. This improves the quality of'
+                             ' generated errors significantly, but may take'
+                             ' more time to run.')
+    parser.add_argument('--error-level',
+                        type=str,
+                        default='verbose',
+                        choices=['verbose', 'errors', 'restricted'],
+                        help='Verbosity of reported errors. \'errors\' reports'
+                             ' everything that might be an error in the'
+                             ' translation, while \'verbose\' may report'
+                             ' additional errors from incomplete input'
+                             ' data, etc. [Default: verbose]')
     args = parser.parse_args()
     return args
 
+def print_result(result, min_severity: ErrorSeverity):
+    shown = 0
+    suppressed = 0
+
+    for res in result:
+        pc = res['pc']
+        print_separator()
+        print(f'For PC={hex(pc)}')
+        print_separator()
+
+        # Filter errors by severity
+        errs = [e for e in res['errors'] if e.severity >= min_severity]
+        suppressed += len(res['errors']) - len(errs)
+        shown += len(errs)
+
+        # Print all non-suppressed errors
+        for n, err in enumerate(errs, start=1):
+            print(f' {n:2}. {err}')
+
+        if errs:
+            print()
+            print(f'Expected transformation: {res["ref"]}')
+            print(f'Actual transformation:   {res["txl"]}')
+        else:
+            print('No errors found.')
+
+    print()
+    print('#' * 60)
+    print(f'Found {shown} errors.')
+    print(f'Suppressed {suppressed} low-priority errors'
+          f' (showing {min_severity} and higher).')
+    print('#' * 60)
+    print()
+
 def main():
+    verbosity = {
+        'verbose': ErrorTypes.INFO,
+        'errors': ErrorTypes.POSSIBLE,
+        'restricted': ErrorTypes.CONFIRMED,
+    }
     args = parse_arguments()
 
     txl_path = args.txl
@@ -123,33 +174,14 @@ def main():
     if args.symbolic:
         assert(program is not None)
 
-        print(f'Tracing {program} with arguments {prog_args}...')
+        print(f'Tracing {program} symbolically with arguments {prog_args}...')
         transforms = collect_symbolic_trace(program, [program, *prog_args])
         txl, transforms = match_traces(txl, transforms)
         result = compare_symbolic(txl, transforms)
     else:
         result = compare_simple(txl, ref)
 
-    # Print results
-    for res in result:
-        pc = res['pc']
-        print_separator()
-        print(f'For PC={hex(pc)}')
-        print_separator()
-
-        ref = res['ref']
-        for err in res['errors']:
-            print(f' - {err}')
-        if res['errors']:
-            print(ref)
-        else:
-            print('No errors found.')
-
-    print()
-    print('#' * 60)
-    print(f'Found {sum(len(res["errors"]) for res in result)} errors.')
-    print('#' * 60)
-    print()
+    print_result(result, verbosity[args.error_level])
 
 if __name__ == "__main__":
     check_version('3.7')
diff --git a/snapshot.py b/snapshot.py
index be18af3..9c9e4b3 100644
--- a/snapshot.py
+++ b/snapshot.py
@@ -1,8 +1,10 @@
 from arch.arch import Arch
 
 class MemoryAccessError(Exception):
-    def __init__(self, msg: str):
+    def __init__(self, addr: int, size: int, msg: str):
         super().__init__(msg)
+        self.mem_addr = addr
+        self.mem_size = size
 
 class SparseMemory:
     """Sparse memory.
@@ -35,7 +37,8 @@ class SparseMemory:
         while size > 0:
             page_addr, off = self._to_page_addr_and_offset(addr)
             if page_addr not in self._pages:
-                raise MemoryAccessError(f'Address {addr} is not contained in'
+                raise MemoryAccessError(addr, size,
+                                        f'Address {addr} is not contained in'
                                         f' the sparse memory.')
             data = self._pages[page_addr]
             assert(len(data) == self.page_size)
diff --git a/symbolic.py b/symbolic.py
index b005c5e..6e70bc9 100644
--- a/symbolic.py
+++ b/symbolic.py
@@ -147,8 +147,7 @@ class MiasmSymbolicTransform(SymbolicTransform):
             res += f'   {reg:6s} = {expr}\n'
         for mem, expr in self.mem_diff.items():
             res += f'   {mem} = {expr}\n'
-
-        return res
+        return res[:-2]  # Remove trailing newline
 
 def _step_until(target: LLDBConcreteTarget, addr: int) -> list[int]:
     """Step a concrete target to a specific instruction.