about summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--src/focaccia/parser.py22
-rw-r--r--src/focaccia/qemu/_qemu_tool.py6
-rw-r--r--src/focaccia/trace.py25
3 files changed, 30 insertions, 23 deletions
diff --git a/src/focaccia/parser.py b/src/focaccia/parser.py
index 5602be6..60bf2af 100644
--- a/src/focaccia/parser.py
+++ b/src/focaccia/parser.py
@@ -8,7 +8,7 @@ from typing import TextIO
 from .arch import supported_architectures, Arch
 from .snapshot import ProgramState
 from .symbolic import SymbolicTransform
-from .trace import Trace, TraceEnvironment
+from .trace import Trace, TraceContainer, TraceEnvironment
 
 class ParseError(Exception):
     """A parse error."""
@@ -20,7 +20,7 @@ def _get_or_throw(obj: dict, key: str):
         return val
     raise ParseError(f'Expected value at key {key}, but found none.')
 
-def parse_transformations(json_stream: TextIO) -> Trace[SymbolicTransform]:
+def parse_transformations(json_stream: TextIO) -> TraceContainer[SymbolicTransform]:
     """Parse symbolic transformations from a text stream."""
     data = json.loads(json_stream.read())
 
@@ -28,7 +28,7 @@ def parse_transformations(json_stream: TextIO) -> Trace[SymbolicTransform]:
     strace = [SymbolicTransform.from_json(item) \
               for item in _get_or_throw(data, 'states')]
 
-    return Trace(strace, env)
+    return TraceContainer(strace, env)
 
 def serialize_transformations(transforms: Trace[SymbolicTransform],
                               out_stream: TextIO):
@@ -39,7 +39,7 @@ def serialize_transformations(transforms: Trace[SymbolicTransform],
     }, option=json.OPT_INDENT_2).decode()
     out_stream.write(data)
 
-def parse_snapshots(json_stream: TextIO) -> Trace[ProgramState]:
+def parse_snapshots(json_stream: TextIO) -> TraceContainer[ProgramState]:
     """Parse snapshots from our JSON format."""
     json_data = json.loads(json_stream.read())
 
@@ -58,7 +58,7 @@ def parse_snapshots(json_stream: TextIO) -> Trace[ProgramState]:
 
         snapshots.append(state)
 
-    return Trace(snapshots, env)
+    return TraceContainer(snapshots, env)
 
 def serialize_snapshots(snapshots: Trace[ProgramState], out_stream: TextIO):
     """Serialize a list of snapshots to out JSON format."""
@@ -89,7 +89,7 @@ def serialize_snapshots(snapshots: Trace[ProgramState], out_stream: TextIO):
 def _make_unknown_env() -> TraceEnvironment:
     return TraceEnvironment('', [], False, [], '?')
 
-def parse_qemu(stream: TextIO, arch: Arch) -> Trace[ProgramState]:
+def parse_qemu(stream: TextIO, arch: Arch) -> TraceContainer[ProgramState]:
     """Parse a QEMU log from a stream.
 
     Recommended QEMU log option: `qemu -d exec,cpu,fpu,vpu,nochain`. The `exec`
@@ -106,7 +106,7 @@ def parse_qemu(stream: TextIO, arch: Arch) -> Trace[ProgramState]:
         if states:
             _parse_qemu_line(line, states[-1])
 
-    return Trace(states, _make_unknown_env())
+    return TraceContainer(states, _make_unknown_env())
 
 def _parse_qemu_line(line: str, cur_state: ProgramState):
     """Try to parse a single register-assignment line from a QEMU log.
@@ -147,7 +147,7 @@ def _parse_qemu_line(line: str, cur_state: ProgramState):
             if regname is not None:
                 cur_state.set_register(regname, int(value, 16))
 
-def parse_arancini(stream: TextIO, arch: Arch) -> Trace[ProgramState]:
+def parse_arancini(stream: TextIO, arch: Arch) -> TraceContainer[ProgramState]:
     aliases = {
         'Program counter': 'RIP',
         'flag ZF': 'ZF',
@@ -172,9 +172,9 @@ def parse_arancini(stream: TextIO, arch: Arch) -> Trace[ProgramState]:
             if regname is not None:
                 states[-1].set_register(regname, int(value, 16))
 
-    return Trace(states, _make_unknown_env())
+    return TraceContainer(states, _make_unknown_env())
 
-def parse_box64(stream: TextIO, arch: Arch) -> Trace[ProgramState]:
+def parse_box64(stream: TextIO, arch: Arch) -> TraceContainer[ProgramState]:
     def parse_box64_flags(state: ProgramState, flags_dump: str):
         flags = ['O', 'D', 'S', 'Z', 'A', 'P', 'C']
         for i, flag in enumerate(flags):
@@ -203,5 +203,5 @@ def parse_box64(stream: TextIO, arch: Arch) -> Trace[ProgramState]:
             if regname is not None:
                 states[-1].set_register(regname, int(value, 16))
 
-    return Trace(states, _make_unknown_env())
+    return TraceContainer(states, _make_unknown_env())
 
diff --git a/src/focaccia/qemu/_qemu_tool.py b/src/focaccia/qemu/_qemu_tool.py
index c534f7b..64a2949 100644
--- a/src/focaccia/qemu/_qemu_tool.py
+++ b/src/focaccia/qemu/_qemu_tool.py
@@ -19,7 +19,7 @@ from focaccia.snapshot import (
     MemoryAccessError,
 )
 from focaccia.symbolic import SymbolicTransform, eval_symbol, ExprMem
-from focaccia.trace import Trace, TraceEnvironment
+from focaccia.trace import Trace, TraceContainer, TraceEnvironment
 from focaccia.utils import print_result
 from focaccia.deterministic import DeterministicLog, Event
 
@@ -112,7 +112,7 @@ def record_minimal_snapshot(prev_state: ReadableProgramState,
     return state
 
 def collect_conc_trace(gdb: GDBServerStateIterator, \
-                       strace: list[SymbolicTransform],
+                       strace: TraceContainer,
                        start_addr: int | None = None,
                        stop_addr: int | None = None) \
         -> tuple[list[ProgramState], list[SymbolicTransform]]:
@@ -264,7 +264,7 @@ def main():
     try:
         conc_states, matched_transforms = collect_conc_trace(
             gdb_server,
-            symb_transforms.states,
+            symb_transforms,
             symb_transforms.env.start_address,
             symb_transforms.env.stop_address)
     except Exception as e:
diff --git a/src/focaccia/trace.py b/src/focaccia/trace.py
index 14c475b..ffae8d5 100644
--- a/src/focaccia/trace.py
+++ b/src/focaccia/trace.py
@@ -1,5 +1,5 @@
 from __future__ import annotations
-from typing import Generic, TypeVar
+from typing import Generic, TypeVar, Iterable
 
 from .utils import file_hash
 
@@ -68,21 +68,28 @@ class TraceEnvironment:
 
 class Trace(Generic[T]):
     def __init__(self,
-                 trace_states: list[T],
+                 states: Iterable[T],
                  env: TraceEnvironment):
         self.env = env
-        self.states = trace_states
+        self._iter = states
+
+    def __iter__(self):
+        return iter(self._iter)
+
+class TraceContainer(Trace[T]):
+    def __init__(self,
+                 states: list[T],
+                 env: TraceEnvironment):
+        self._state_list = states
+        super().__init__(iter(states), env)
 
     def __len__(self) -> int:
-        return len(self.states)
+        return len(self._state_list)
 
     def __getitem__(self, i: int) -> T:
-        return self.states[i]
-
-    def __iter__(self):
-        return iter(self.states)
+        return self._state_list[i]
 
     def __repr__(self) -> str:
-        return f'Trace with {len(self.states)} trace points.' \
+        return f'Trace with {len(self._state_list)} trace points.' \
                f' Environment: {repr(self.env)}'