about summary refs log tree commit diff stats
path: root/compare.py
diff options
context:
space:
mode:
authorTheofilos Augoustis <theofilos.augoustis@gmail.com>2023-12-08 16:17:35 +0100
committerTheofilos Augoustis <theofilos.augoustis@gmail.com>2023-12-08 16:17:35 +0100
commit4a5584d8f69d8ff511285387971d8cbf803f16b7 (patch)
tree11c9e104fadc9b47f3f423f4be3bf0be34edf4f8 /compare.py
parent0cf4f736fd5d7cd99f00d6c5896af9a608d2df8b (diff)
downloadfocaccia-4a5584d8f69d8ff511285387971d8cbf803f16b7.tar.gz
focaccia-4a5584d8f69d8ff511285387971d8cbf803f16b7.zip
Adapt symbolic compare to new transform interface
Also implement a `MiasmSymbolicTransform.concat` function that
concatenates two transformations. Some minor adaptions to the eval_expr
code was necessary to remove some assumptions that don't work if the
resolver state returns symbols instead of concrete values.

Remove obsolete utilities that were used for angr.

Co-authored-by: Theofilos Augoustis <theofilos.augoustis@gmail.com>
Co-authored-by: Nicola Crivellin <nicola.crivellin98@gmail.com>
Diffstat (limited to 'compare.py')
-rw-r--r--compare.py78
1 files changed, 42 insertions, 36 deletions
diff --git a/compare.py b/compare.py
index 8a25d8a..0f144bf 100644
--- a/compare.py
+++ b/compare.py
@@ -1,8 +1,7 @@
-from snapshot import ProgramState, SnapshotSymbolResolver
+from snapshot import ProgramState
 from symbolic import SymbolicTransform
-from utils import print_separator
 
-def calc_transformation(previous: ProgramState, current: ProgramState):
+def _calc_transformation(previous: ProgramState, current: ProgramState):
     """Calculate the difference between two context blocks.
 
     :return: A context block that contains in its registers the difference
@@ -13,14 +12,18 @@ def calc_transformation(previous: ProgramState, current: ProgramState):
     arch = previous.arch
     transformation = ProgramState(arch)
     for reg in arch.regnames:
-        prev_val, cur_val = previous.regs[reg], current.regs[reg]
-        if prev_val is not None and cur_val is not None:
-            transformation.regs[reg] = cur_val - prev_val
+        try:
+            prev_val, cur_val = previous.read(reg), current.read(reg)
+            if prev_val is not None and cur_val is not None:
+                transformation.set(reg, cur_val - prev_val)
+        except ValueError:
+            # Register is not set in either state
+            pass
 
     return transformation
 
-def find_errors(txl_state: ProgramState, prev_txl_state: ProgramState,
-                truth_state: ProgramState, prev_truth_state: ProgramState) \
+def _find_errors(txl_state: ProgramState, prev_txl_state: ProgramState,
+                 truth_state: ProgramState, prev_truth_state: ProgramState) \
         -> list[dict]:
     """Find possible errors between a reference and a tested state.
 
@@ -38,11 +41,16 @@ def find_errors(txl_state: ProgramState, prev_txl_state: ProgramState,
     arch = txl_state.arch
     errors = []
 
-    transform_truth = calc_transformation(prev_truth_state, truth_state)
-    transform_txl = calc_transformation(prev_txl_state, txl_state)
+    transform_truth = _calc_transformation(prev_truth_state, truth_state)
+    transform_txl = _calc_transformation(prev_txl_state, txl_state)
     for reg in arch.regnames:
-        diff_txl = transform_txl.regs[reg]
-        diff_truth = transform_truth.regs[reg]
+        try:
+            diff_txl = transform_txl.read(reg)
+            diff_truth = transform_truth.read(reg)
+        except ValueError:
+            # Register is not set in either state
+            continue
+
         if diff_txl == diff_truth:
             # The register contains a value that is expected
             # by the transformation.
@@ -80,7 +88,7 @@ def compare_simple(test_states: list[ProgramState],
     # No errors in initial snapshot because we can't perform difference
     # calculations on it
     result = [{
-        'pc': test_states[0].regs[PC_REGNAME],
+        'pc': test_states[0].read(PC_REGNAME),
         'txl': test_states[0], 'ref': truth_states[0],
         'errors': []
     }]
@@ -91,21 +99,19 @@ def compare_simple(test_states: list[ProgramState],
     for txl, truth in it_cur:
         prev_txl, prev_truth = next(it_prev)
 
-        pc_txl = txl.regs[PC_REGNAME]
-        pc_truth = truth.regs[PC_REGNAME]
+        pc_txl = txl.read(PC_REGNAME)
+        pc_truth = truth.read(PC_REGNAME)
 
         # The program counter should always be set on a snapshot
         assert(pc_truth is not None)
         assert(pc_txl is not None)
 
         if pc_txl != pc_truth:
-            print(f'Unmatched program counter {txl.as_repr(PC_REGNAME)}'
+            print(f'Unmatched program counter {hex(txl.read(PC_REGNAME))}'
                   f' in translated code!')
             continue
-        else:
-            txl.matched = True
 
-        errors = find_errors(txl, prev_txl, truth, prev_truth)
+        errors = _find_errors(txl, prev_txl, truth, prev_truth)
         result.append({
             'pc': pc_txl,
             'txl': txl, 'ref': truth,
@@ -113,20 +119,19 @@ def compare_simple(test_states: list[ProgramState],
         })
 
         # TODO: Why do we skip backward branches?
-        if txl.has_backwards:
-            print(f' -- Encountered backward branch. Don\'t skip.')
+        #if txl.has_backwards:
+        #    print(f' -- Encountered backward branch. Don\'t skip.')
 
     return result
 
-def find_errors_symbolic(txl_from: ProgramState,
-                         txl_to: ProgramState,
-                         transform_truth: SymbolicTransform) \
+def _find_errors_symbolic(txl_from: ProgramState,
+                          txl_to: ProgramState,
+                          transform_truth: SymbolicTransform) \
         -> list[dict]:
     arch = txl_from.arch
-    resolver = SnapshotSymbolResolver(txl_from)
 
-    assert(txl_from.read('PC') == transform_truth.start_addr)
-    assert(txl_to.read('PC') == transform_truth.end_addr)
+    assert(txl_from.read('PC') == transform_truth.range[0])
+    assert(txl_to.read('PC') == transform_truth.range[1])
 
     errors = []
     for reg in arch.regnames:
@@ -137,14 +142,14 @@ def find_errors_symbolic(txl_from: ProgramState,
 
         txl_val = txl_to.read(reg)
         try:
-            truth = transform_truth.eval_register_transform(reg.lower(), resolver)
+            truth = transform_truth.calc_register_transform(txl_from)
             print(f'Evaluated symbolic formula to {hex(txl_val)} vs. txl {hex(txl_val)}')
             if txl_val != truth:
                 errors.append({
                     'reg': reg,
                     'expected': truth,
                     'actual': txl_val,
-                    'equation': transform_truth.state.regs.get(reg),
+                    'equation': transform_truth.regs_diff[reg],
                 })
         except AttributeError:
             print(f'Register {reg} does not exist.')
@@ -157,7 +162,7 @@ def compare_symbolic(test_states: list[ProgramState],
     PC_REGNAME = 'PC'
 
     result = [{
-        'pc': test_states[0].regs[PC_REGNAME],
+        'pc': test_states[0].read(PC_REGNAME),
         'txl': test_states[0],
         'ref': transforms[0],
         'errors': []
@@ -171,21 +176,22 @@ def compare_symbolic(test_states: list[ProgramState],
         # The program counter should always be set on a snapshot
         assert(pc_cur is not None and pc_next is not None)
 
-        if pc_cur != transform.start_addr:
+        start_addr, end_addr = transform.range
+        if pc_cur != start_addr:
             print(f'Program counter {hex(pc_cur)} in translated code has no'
                   f' corresponding reference state! Skipping.'
-                  f' (reference: {hex(transform.start_addr)})')
+                  f' (reference: {hex(start_addr)})')
             continue
-        if pc_next != transform.end_addr:
+        if pc_next != end_addr:
             print(f'Tested state transformation is {hex(pc_cur)} ->'
                   f' {hex(pc_next)}, but reference transform is'
-                  f' {hex(transform.start_addr)} -> {hex(transform.end_addr)}!'
+                  f' {hex(start_addr)} -> {hex(end_addr)}!'
                   f' Skipping.')
 
-        errors = find_errors_symbolic(cur_state, next_state, transform)
+        errors = _find_errors_symbolic(cur_state, next_state, transform)
         result.append({
             'pc': pc_cur,
-            'txl': calc_transformation(cur_state, next_state),
+            'txl': _calc_transformation(cur_state, next_state),
             'ref': transform,
             'errors': errors
         })