about summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--src/focaccia/qemu/_qemu_tool.py59
1 files changed, 34 insertions, 25 deletions
diff --git a/src/focaccia/qemu/_qemu_tool.py b/src/focaccia/qemu/_qemu_tool.py
index 07a6981..42f1628 100644
--- a/src/focaccia/qemu/_qemu_tool.py
+++ b/src/focaccia/qemu/_qemu_tool.py
@@ -8,7 +8,7 @@ work to do.
 
 import logging
 import traceback
-from typing import Iterable
+from typing import Iterable, Optional
 
 import focaccia.parser as parser
 from focaccia.compare import compare_symbolic, Error, ErrorTypes
@@ -111,7 +111,7 @@ def record_minimal_snapshot(prev_state: ReadableProgramState,
                state)
     return state
 
-def collect_conc_trace(gdb: GDBServerStateIterator, strace: TraceContainer) \
+def collect_conc_trace(gdb: GDBServerStateIterator, strace: Trace) \
         -> tuple[list[ProgramState], list[SymbolicTransform]]:
     """Collect a trace of concrete states from GDB.
 
@@ -127,9 +127,9 @@ def collect_conc_trace(gdb: GDBServerStateIterator, strace: TraceContainer) \
     :return: A list of concrete states and a list of corresponding symbolic
              transformations. The lists are guaranteed to have the same length.
     """
-    def find_index(seq, target, access=lambda el: el):
+    def find_index(seq, target):
         for i, el in enumerate(seq):
-            if access(el) == target:
+            if el == target:
                 return i
         return None
 
@@ -143,6 +143,8 @@ def collect_conc_trace(gdb: GDBServerStateIterator, strace: TraceContainer) \
     cur_state = next(state_iter)
     symb_i = 0
 
+    trace = iter(strace)
+
     if logger.isEnabledFor(logging.DEBUG):
         debug('Tracing program with the following non-deterministic events:')
         for event in gdb._events.events:
@@ -162,16 +164,24 @@ def collect_conc_trace(gdb: GDBServerStateIterator, strace: TraceContainer) \
 
     # An online trace matching algorithm.
     info(f'Tracing QEMU between {hex(start_addr)}:{hex(strace.env.stop_address) if strace.env.stop_address else "end"}')
+
+    transform: Optional[SymbolicTransform] = None
     while True:
         try:
             pc = cur_state.read_pc()
             if strace.env.stop_address and pc == strace.env.stop_address:
                 break
 
-            while pc != strace[symb_i].addr:
-                warn(f'PC {hex(pc)} does not match next symbolic reference {hex(strace[symb_i].addr)}')
+            try:
+                symb_i += 1
+                transform = next(trace)
+            except StopIteration:
+                break
+
+            while pc != transform.addr:
+                warn(f'PC {hex(pc)} does not match next symbolic reference {hex(transform.addr)}')
 
-                next_i = find_index(strace[symb_i+1:], pc, lambda t: t.addr)
+                next_i = find_index(strace.addresses[symb_i:], pc)
 
                 # Drop the concrete state if no address in the symbolic trace
                 # matches
@@ -184,41 +194,40 @@ def collect_conc_trace(gdb: GDBServerStateIterator, strace: TraceContainer) \
                     continue
 
                 # Otherwise, jump to the next matching symbolic state
-                symb_i += next_i + 1
-                if symb_i >= len(strace):
-                    break
-
-            assert(cur_state.read_pc() == strace[symb_i].addr)
+                for _ in range(next_i+1):
+                    try:
+                        symb_i += 1
+                        transform = next(trace)
+                    except StopIteration:
+                        warn(f'QEMU executed more states than native execution: {symb_i} vs {len(strace.addresses)-1}')
+                        break
+
+            assert(cur_state.read_pc() == transform.addr)
             info(f'Validating instruction at address {hex(pc)}')
             states.append(record_minimal_snapshot(
                 states[-1] if states else cur_state,
                 cur_state,
-                matched_transforms[-1] if matched_transforms else strace[symb_i],
-                strace[symb_i]))
-            matched_transforms.append(strace[symb_i])
+                matched_transforms[-1] if matched_transforms else transform,
+                transform))
+            matched_transforms.append(transform)
             cur_state = next(state_iter)
-            symb_i += 1
-            if symb_i >= len(strace):
-                break
         except StopIteration:
             # TODO: The conditions may test for the same
             if strace.env.stop_address and pc != strace.env.stop_address:
                 raise Exception(f'QEMU stopped at {hex(pc)} before reaching the stop address'
                                 f' {hex(strace.env.stop_address)}')
-            if symb_i+1 < len(strace):
+
+            assert(transform is not None)
+            if symb_i+1 < len(strace.addresses):
                 qemu_crash["crashed"] = True
-                qemu_crash["pc"] = strace[symb_i].addr
-                qemu_crash["ref"] = strace[symb_i]
+                qemu_crash["pc"] = transform.addr
+                qemu_crash["ref"] = transform
                 qemu_crash["snap"] = states[-1]
             break
         except Exception as e:
             print(traceback.format_exc())
             raise e
 
-    # Note: this may occur when symbolic traces were gathered with a stop address
-    if symb_i >= len(strace):
-        warn(f'QEMU executed more states than native execution: {symb_i} vs {len(strace)-1}')
-
     return states, matched_transforms
 
 def main():