about summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--src/focaccia/deterministic.py182
-rw-r--r--src/focaccia/native/tracer.py48
-rw-r--r--src/focaccia/qemu/_qemu_tool.py57
3 files changed, 176 insertions, 111 deletions
diff --git a/src/focaccia/deterministic.py b/src/focaccia/deterministic.py
index 6d76457..41d49b8 100644
--- a/src/focaccia/deterministic.py
+++ b/src/focaccia/deterministic.py
@@ -2,7 +2,7 @@ from .arch import Arch
 from .snapshot import ReadableProgramState
 
 from reprlib import repr as alt_repr
-from typing import Callable
+from typing import Callable, Tuple, Optional
 
 class MemoryWriteHole:
     def __init__(self, offset: int, size: int):
@@ -285,60 +285,136 @@ except Exception:
         def tasks(self) -> list[Task]: return []
         def mmaps(self) -> list[MemoryMapping]: return []
 finally:
-    class DeterministicEventIterator:
-        def __init__(self, deterministic_log: DeterministicLog, match_fn: Callable):
-            self._detlog = deterministic_log
-            self._events = self._detlog.events()
-            self._pc_to_event = {}
-            self._match = match_fn
-            self._idx: int | None = None # None represents no current event
-            self._in_event: bool = False
-
-            idx = 0
-            for event in self._events:
-                self._pc_to_event.setdefault(event.pc, []).append((event, idx))
+    class EventMatcher:
+        def __init__(self, 
+                     events: list[Event], 
+                     match_fn: Callable,
+                     from_state: ReadableProgramState | None = None):
+            self.events = events
+            self.matcher = match_fn
+
+            self.matched_count = None
+            if from_state:
+                self.match(from_state)
+
+        def match(self, state: ReadableProgramState) -> Event | None:
+            if self.matched_count is None:
+                # Need to synchronize
+                # Search for match
+                for idx in range(len(self.events)):
+                    event = self.events[idx]
+                    if self.matcher(event, state):
+                        self.matched_count = idx + 1
+                        return event
+
+                if self.matched_count is None:
+                    return None
+
+            event = self.events[self.matched_count]
+            if self.matcher(event, state):
+                self.matched_count += 1 # proceed to next
+                return event
+            
+            return None
+
+        def next(self):
+            if self.matched_count is None:
+                raise ValueError('Cannot get next event with unsynchronized event matcher')
+            if self.matched_count < len(self.events):
+                return self.events[self.matched_count]
+            return None
+
+        def match_pair(self, state: ReadableProgramState):
+            event = self.match(state)
+            if event is None:
+                return None, None
+            if isinstance(event, SyscallEvent) and event.syscall_state == 'exiting':
+                self.matched_count = None
+                return None, None
+            assert(self.matched_count is not None)
+            post_event = self.events[self.matched_count]
+            self.matched_count += 1
+            return event, post_event
+
+        def __bool__(self) -> bool:
+            return len(self.events) > 0
+
+    class MappingMatcher:
+        def __init__(self, memory_mappings: list[MemoryMapping]):
+            self.memory_mappings = memory_mappings
+            self.matched_count = None
+
+        def match(self, event_count: int) -> MemoryMapping | None:
+            if self.matched_count is None:
+                # Need to synchronize
+                # Search for match
+                for idx in range(len(self.memory_mappings)):
+                    mapping = self.memory_mappings[idx]
+                    if mapping.event_count == event_count:
+                        self.matched_count = idx + 1
+                        return mapping
+
+                if self.matched_count is None:
+                    return None
+
+            mapping = self.memory_mappings[self.matched_count]
+            if mapping.event_count == event_count:
+                self.matched_count += 1 # proceed to next
+                return mapping
+            
+            return None
+
+        def next(self):
+            if self.matched_count is None:
+                raise ValueError('Cannot get next mapping with unsynchronized mapping matcher')
+            if self.matched_count < len(self.memory_mappings):
+                return self.memory_mappings[self.matched_count]
+            return None
+
+        def __bool__(self) -> bool:
+            return len(self.memory_mappings) > 0
+
+    class LogStateMatcher:
+        def __init__(self, 
+                     events: list[Event], 
+                     memory_mappings: list[MemoryMapping],
+                     event_match_fn: Callable,
+                     from_state: ReadableProgramState | None = None):
+            self.event_matcher = EventMatcher(events, event_match_fn, from_state)
+            self.mapping_matcher = MappingMatcher(memory_mappings)
 
         def events(self) -> list[Event]:
-            return self._events
-
-        def current_event(self) -> Event | None:
-            # No event when not synchronized
-            if self._idx is None or not self._in_event:
-                return None
-            return self._events[self._idx]
-
-        def next_event(self) -> Event | None:
-            if self._idx is None:
-                raise ValueError('Attempted to get next event without synchronizing')
-            if self._idx + 1 >= len(self._events):
-                return None
-            return self._events[self._idx+1]
-
-        def update(self, target: ReadableProgramState) -> Event | None:
-            # Quick check
-            candidates = self._pc_to_event.get(target.read_pc(), [])
-            if len(candidates) == 0:
-                self._in_event = False
-                return None
-
-            # Find synchronization point
-            if self._idx is None:
-                for event, idx in candidates:
-                    if self._match(event, target):
-                        self._idx = idx
-                        self._in_event = True
-                        return self.current_event()
-
-            return self.update_to_next()
-
-        def update_to_next(self, count: int = 1) -> Event | None:
-            if self._idx is None:
-                raise ValueError('Attempted to get next event without synchronizing')
-
-            self._in_event = True
-            self._idx += count
-            return self.current_event()
+            return self.event_matcher.events
+
+        def mappings(self) -> list[MemoryMapping]:
+            return self.mapping_matcher.memory_mappings
+
+        def matched_events(self) -> Optional[int]:
+            return self.event_matcher.matched_count
+
+        def match(self, state: ReadableProgramState) -> Tuple[Optional[Event], Optional[MemoryMapping]]:
+            event = self.event_matcher.match(state)
+            if not event:
+                return None, None
+            assert(self.event_matcher.matched_count is not None)
+            mapping = self.mapping_matcher.match(self.event_matcher.matched_count)
+            return event, mapping
+
+        def match_pair(self, state: ReadableProgramState) -> Tuple[Optional[Event], Optional[Event], Optional[MemoryMapping]]:
+            event, post_event = self.event_matcher.match_pair(state)
+            if not event:
+                return None, None, None
+            assert(self.event_matcher.matched_count is not None)
+            mapping = self.mapping_matcher.match(self.event_matcher.matched_count-1)
+            return event, post_event, mapping
+
+        def next(self) -> Tuple[Optional[Event], Optional[MemoryMapping]]:
+            next_event = self.event_matcher.next()
+            if not next_event:
+                return None, None
+            assert(self.event_matcher.matched_count is not None)
+            return next_event, self.mapping_matcher.match(self.event_matcher.matched_count)
 
         def __bool__(self) -> bool:
-            return len(self.events()) > 0
+            return bool(self.event_matcher)
 
diff --git a/src/focaccia/native/tracer.py b/src/focaccia/native/tracer.py
index 4376f41..b2ca0d8 100644
--- a/src/focaccia/native/tracer.py
+++ b/src/focaccia/native/tracer.py
@@ -12,7 +12,7 @@ from focaccia.trace import Trace, TraceEnvironment
 from focaccia.miasm_util import MiasmSymbolResolver
 from focaccia.snapshot import ReadableProgramState, RegisterAccessError
 from focaccia.symbolic import SymbolicTransform, DisassemblyContext, run_instruction
-from focaccia.deterministic import Event, DeterministicEventIterator
+from focaccia.deterministic import Event, EventMatcher
 
 from .lldb_target import LLDBConcreteTarget, LLDBLocalTarget, LLDBRemoteTarget
 
@@ -154,8 +154,6 @@ class SymbolicTracer:
         self.cross_validate = cross_validate
         self.target = SpeculativeTracer(self.create_debug_target())
 
-        self.nondet_events = DeterministicEventIterator(self.env.detlog, match_event)
-
     def create_debug_target(self) -> LLDBConcreteTarget:
         binary = self.env.binary_name
         if self.remote is False:
@@ -208,30 +206,10 @@ class SymbolicTracer:
                                       f' mem[{hex(addr)}:{hex(addr+len(data))}] = {conc_data}.'
                                       f'\nFaulty transformation: {transform}')
 
-    def post_event(self) -> None:
-        current_event = self.nondet_events.current_event()
-        if current_event:
-            if current_event.pc == 0:
-                # Exit sequence
-                debug('Completed exit event')
-                self.target.run()
-
-            debug(f'Completed handling event: {current_event}')
-            self.nondet_events.update_to_next()
-
-    def is_stepping_instr(self, instruction: Instruction) -> bool:
-        if self.nondet_events.current_event():
-            debug('Current instruction matches next event; stepping through it')
-            self.nondet_events.update_to_next()
-            return True
-        else:
-            if self.target.arch.is_instr_syscall(str(instruction)):
-                return True
-        return False
-
     def progress(self, new_pc, step: bool = False) -> int | None:
         self.target.speculate(new_pc)
         if step:
+            info(f'Stepping through event at {hex(self.target.read_pc())}')
             self.target.progress_execution()
             if self.target.is_exited():
                 return None
@@ -251,9 +229,10 @@ class SymbolicTracer:
         ctx = DisassemblyContext(self.target)
         arch = ctx.arch
 
+        event_matcher = EventMatcher(self.env.detlog.events(), match_event, self.target)
         if logger.isEnabledFor(logging.DEBUG):
             debug('Tracing program with the following non-deterministic events')
-            for event in self.nondet_events.events():
+            for event in event_matcher.events:
                 debug(event)
 
         # Trace concolically
@@ -262,10 +241,9 @@ class SymbolicTracer:
             pc = self.target.read_pc()
 
             if self.env.stop_address is not None and pc == self.env.stop_address:
+                info(f'Reached stop address at {hex(pc)}')
                 break
 
-            self.nondet_events.update(self.target)
-
             # Disassemble instruction at the current PC
             tid = self.target.get_current_tid()
             try:
@@ -292,7 +270,8 @@ class SymbolicTracer:
                         continue
                     raise # forward exception
 
-            is_event = self.is_stepping_instr(instruction)
+            event, post_event = event_matcher.match_pair(self.target)
+            in_event = (event and event_matcher) or self.target.arch.is_instr_syscall(str(instruction))
 
             # Run instruction
             conc_state = MiasmSymbolResolver(self.target, ctx.loc_db)
@@ -311,7 +290,7 @@ class SymbolicTracer:
                 new_pc = int(new_pc)
                 transform = SymbolicTransform(tid, modified, [instruction], arch, pc, new_pc)
                 pred_regs, pred_mems = self.predict_next_state(instruction, transform)
-                self.progress(new_pc, step=is_event)
+                self.progress(new_pc, step=in_event)
 
                 try:
                     self.validate(instruction, transform, pred_regs, pred_mems)
@@ -321,7 +300,7 @@ class SymbolicTracer:
                         continue
                     raise
             else:
-                new_pc = self.progress(new_pc, step=is_event)
+                new_pc = self.progress(new_pc, step=in_event)
                 if new_pc is None:
                     transform = SymbolicTransform(tid, modified, [instruction], arch, pc, 0)
                     strace.append(transform)
@@ -330,8 +309,13 @@ class SymbolicTracer:
 
             strace.append(transform)
 
-            if is_event:
-                self.post_event()
+            if post_event:
+                if post_event.pc == 0:
+                    # Exit sequence
+                    debug('Completed exit event')
+                    self.target.run()
+
+                debug(f'Completed handling event: {post_event}')
 
         return Trace(strace, self.env)
 
diff --git a/src/focaccia/qemu/_qemu_tool.py b/src/focaccia/qemu/_qemu_tool.py
index 75b142e..188ecf2 100644
--- a/src/focaccia/qemu/_qemu_tool.py
+++ b/src/focaccia/qemu/_qemu_tool.py
@@ -9,7 +9,7 @@ work to do.
 import gdb
 import logging
 import traceback
-from typing import Iterable
+from typing import Iterable, Optional
 
 import focaccia.parser as parser
 from focaccia.arch import supported_architectures, Arch
@@ -19,7 +19,13 @@ from focaccia.snapshot import ProgramState, ReadableProgramState, \
 from focaccia.symbolic import SymbolicTransform, eval_symbol, ExprMem
 from focaccia.trace import Trace, TraceEnvironment
 from focaccia.utils import print_result
-from focaccia.deterministic import DeterministicLog, DeterministicEventIterator, Event, SyscallEvent
+from focaccia.deterministic import (
+    DeterministicLog,
+    LogStateMatcher,
+    Event,
+    SyscallEvent,
+    MemoryMapping,
+)
 from focaccia.qemu.deterministic import emulated_system_calls
 
 from focaccia.tools.validate_qemu import make_argparser, verbosity
@@ -40,6 +46,7 @@ qemu_crash = {
 
 def match_event(event: Event, target: ReadableProgramState) -> bool:
     # Match just on PC
+    debug(f'Matching for PC {hex(target.read_pc())} with event {hex(event.pc)}')
     if event.pc == target.read_pc():
         return True
     return False
@@ -149,23 +156,20 @@ class GDBServerStateIterator:
         self.arch = supported_architectures[archname]
         self.binary = self._process.progspace.filename
 
-        self._deterministic_events = DeterministicEventIterator(self._deterministic_log, match_event)
-
-        # Filter non-deterministic events for event after start
-        self._deterministic_events.update(self.current_state())
-        self._deterministic_events.update_to_next()
+        self._log_matcher = LogStateMatcher(self._deterministic_log.events(),
+                                            self._deterministic_log.mmaps(),
+                                            match_event,
+                                            from_state=self.current_state())
+        info(f'Synchronizing at PC {hex(self.current_state().read_pc())} with {self._log_matcher.matched_events()}')
 
     def current_state(self) -> ReadableProgramState:
         return GDBProgramState(self._process, gdb.selected_frame(), self.arch)
 
-    def _handle_syscall(self) -> GDBProgramState:
-        cur_event = self._deterministic_events.current_event()
-        call = cur_event.registers.get(self.arch.get_syscall_reg())
+    def _handle_syscall(self, event: Event, post_event: Event) -> GDBProgramState:
+        call = event.registers.get(self.arch.get_syscall_reg())
 
-        post_event = self._deterministic_events.update_to_next()
         syscall = emulated_system_calls[self.arch.archname].get(call, None)
-        debug(f'Handling event:\n{cur_event}')
-        if syscall is not None:
+        if syscall is not None and False:
             info(f'Replaying system call number {hex(call)}')
 
             self.skip(post_event.pc)
@@ -189,15 +193,14 @@ class GDBServerStateIterator:
             raise StopIteration
         return GDBProgramState(self._process, gdb.selected_frame(), self.arch)
 
-    def _handle_event(self) -> GDBProgramState:
-        current_event = self._deterministic_events.current_event()
-        if not current_event:
+    def _handle_event(self, event: Event | None, post_event: Event | None) -> GDBProgramState:
+        if not event:
             return self.current_state()
 
-        if isinstance(current_event, SyscallEvent):
-            return self._handle_syscall()
+        if isinstance(event, SyscallEvent):
+            return self._handle_syscall(event, post_event)
 
-        warn(f'Event handling for events of type {current_event.event_type} not implemented')
+        warn(f'Event handling for events of type {event.event_type} not implemented')
         return self.current_state()
 
     def _is_exited(self) -> bool:
@@ -213,11 +216,12 @@ class GDBServerStateIterator:
             self._first_next = False
             return GDBProgramState(self._process, gdb.selected_frame(), self.arch)
 
-        if match_event(self._deterministic_events.current_event(), self.current_state()):
-            state = self._handle_event()
+        event, post_event, _ = self._log_matcher.match_pair(self.current_state())
+        if event:
+            state = self._handle_event(event, post_event)
             if self._is_exited():
                 raise StopIteration
-            self._deterministic_events.update_to_next()
+
             return state
 
         # Step
@@ -233,7 +237,7 @@ class GDBServerStateIterator:
 
     def run_until(self, addr: int) -> ReadableProgramState:
         events_handled = 0
-        event = self._deterministic_events.current_event()
+        event, _ = self._log_matcher.next()
         while event:
             state = self._run_until_any([addr, event.pc])
             if state.read_pc() == addr:
@@ -241,9 +245,10 @@ class GDBServerStateIterator:
                 self._first_next = events_handled == 0
                 return state
 
-            self._handle_event()
+            event, post_event, _ = self._log_matcher.match_pair(self.current_state())
+            self._handle_event(event, post_event)
 
-            event = self._deterministic_events.update_to_next()
+            event, _ = self._log_matcher.next()
             events_handled += 1
         return self._run_until_any([addr])
 
@@ -375,7 +380,7 @@ def collect_conc_trace(gdb: GDBServerStateIterator, \
 
     if logger.isEnabledFor(logging.DEBUG):
         debug('Tracing program with the following non-deterministic events:')
-        for event in gdb._deterministic_events.events():
+        for event in gdb._log_matcher.events():
             debug(event)
 
     # Skip to start