diff options
| author | Theofilos Augoustis <theofilos.augoustis@gmail.com> | 2025-11-17 19:06:25 +0000 |
|---|---|---|
| committer | Theofilos Augoustis <theofilos.augoustis@gmail.com> | 2025-11-18 18:19:38 +0000 |
| commit | fe56719f06e6fd53ae0d897cf29cee6456a0e1db (patch) | |
| tree | 26eef0271e3c319c612c9ae435f3d8bb7078eb50 | |
| parent | 6247fb055a5c1e3eddb9948b3abf4c1e766edc08 (diff) | |
| download | focaccia-fe56719f06e6fd53ae0d897cf29cee6456a0e1db.tar.gz focaccia-fe56719f06e6fd53ae0d897cf29cee6456a0e1db.zip | |
Refactor iteration interface over events
| -rw-r--r-- | src/focaccia/deterministic.py | 182 | ||||
| -rw-r--r-- | src/focaccia/native/tracer.py | 48 | ||||
| -rw-r--r-- | src/focaccia/qemu/_qemu_tool.py | 57 |
3 files changed, 176 insertions, 111 deletions
diff --git a/src/focaccia/deterministic.py b/src/focaccia/deterministic.py index 6d76457..41d49b8 100644 --- a/src/focaccia/deterministic.py +++ b/src/focaccia/deterministic.py @@ -2,7 +2,7 @@ from .arch import Arch from .snapshot import ReadableProgramState from reprlib import repr as alt_repr -from typing import Callable +from typing import Callable, Tuple, Optional class MemoryWriteHole: def __init__(self, offset: int, size: int): @@ -285,60 +285,136 @@ except Exception: def tasks(self) -> list[Task]: return [] def mmaps(self) -> list[MemoryMapping]: return [] finally: - class DeterministicEventIterator: - def __init__(self, deterministic_log: DeterministicLog, match_fn: Callable): - self._detlog = deterministic_log - self._events = self._detlog.events() - self._pc_to_event = {} - self._match = match_fn - self._idx: int | None = None # None represents no current event - self._in_event: bool = False - - idx = 0 - for event in self._events: - self._pc_to_event.setdefault(event.pc, []).append((event, idx)) + class EventMatcher: + def __init__(self, + events: list[Event], + match_fn: Callable, + from_state: ReadableProgramState | None = None): + self.events = events + self.matcher = match_fn + + self.matched_count = None + if from_state: + self.match(from_state) + + def match(self, state: ReadableProgramState) -> Event | None: + if self.matched_count is None: + # Need to synchronize + # Search for match + for idx in range(len(self.events)): + event = self.events[idx] + if self.matcher(event, state): + self.matched_count = idx + 1 + return event + + if self.matched_count is None: + return None + + event = self.events[self.matched_count] + if self.matcher(event, state): + self.matched_count += 1 # proceed to next + return event + + return None + + def next(self): + if self.matched_count is None: + raise ValueError('Cannot get next event with unsynchronized event matcher') + if self.matched_count < len(self.events): + return self.events[self.matched_count] + return None + + def match_pair(self, state: ReadableProgramState): + event = self.match(state) + if event is None: + return None, None + if isinstance(event, SyscallEvent) and event.syscall_state == 'exiting': + self.matched_count = None + return None, None + assert(self.matched_count is not None) + post_event = self.events[self.matched_count] + self.matched_count += 1 + return event, post_event + + def __bool__(self) -> bool: + return len(self.events) > 0 + + class MappingMatcher: + def __init__(self, memory_mappings: list[MemoryMapping]): + self.memory_mappings = memory_mappings + self.matched_count = None + + def match(self, event_count: int) -> MemoryMapping | None: + if self.matched_count is None: + # Need to synchronize + # Search for match + for idx in range(len(self.memory_mappings)): + mapping = self.memory_mappings[idx] + if mapping.event_count == event_count: + self.matched_count = idx + 1 + return mapping + + if self.matched_count is None: + return None + + mapping = self.memory_mappings[self.matched_count] + if mapping.event_count == event_count: + self.matched_count += 1 # proceed to next + return mapping + + return None + + def next(self): + if self.matched_count is None: + raise ValueError('Cannot get next mapping with unsynchronized mapping matcher') + if self.matched_count < len(self.memory_mappings): + return self.memory_mappings[self.matched_count] + return None + + def __bool__(self) -> bool: + return len(self.memory_mappings) > 0 + + class LogStateMatcher: + def __init__(self, + events: list[Event], + memory_mappings: list[MemoryMapping], + event_match_fn: Callable, + from_state: ReadableProgramState | None = None): + self.event_matcher = EventMatcher(events, event_match_fn, from_state) + self.mapping_matcher = MappingMatcher(memory_mappings) def events(self) -> list[Event]: - return self._events - - def current_event(self) -> Event | None: - # No event when not synchronized - if self._idx is None or not self._in_event: - return None - return self._events[self._idx] - - def next_event(self) -> Event | None: - if self._idx is None: - raise ValueError('Attempted to get next event without synchronizing') - if self._idx + 1 >= len(self._events): - return None - return self._events[self._idx+1] - - def update(self, target: ReadableProgramState) -> Event | None: - # Quick check - candidates = self._pc_to_event.get(target.read_pc(), []) - if len(candidates) == 0: - self._in_event = False - return None - - # Find synchronization point - if self._idx is None: - for event, idx in candidates: - if self._match(event, target): - self._idx = idx - self._in_event = True - return self.current_event() - - return self.update_to_next() - - def update_to_next(self, count: int = 1) -> Event | None: - if self._idx is None: - raise ValueError('Attempted to get next event without synchronizing') - - self._in_event = True - self._idx += count - return self.current_event() + return self.event_matcher.events + + def mappings(self) -> list[MemoryMapping]: + return self.mapping_matcher.memory_mappings + + def matched_events(self) -> Optional[int]: + return self.event_matcher.matched_count + + def match(self, state: ReadableProgramState) -> Tuple[Optional[Event], Optional[MemoryMapping]]: + event = self.event_matcher.match(state) + if not event: + return None, None + assert(self.event_matcher.matched_count is not None) + mapping = self.mapping_matcher.match(self.event_matcher.matched_count) + return event, mapping + + def match_pair(self, state: ReadableProgramState) -> Tuple[Optional[Event], Optional[Event], Optional[MemoryMapping]]: + event, post_event = self.event_matcher.match_pair(state) + if not event: + return None, None, None + assert(self.event_matcher.matched_count is not None) + mapping = self.mapping_matcher.match(self.event_matcher.matched_count-1) + return event, post_event, mapping + + def next(self) -> Tuple[Optional[Event], Optional[MemoryMapping]]: + next_event = self.event_matcher.next() + if not next_event: + return None, None + assert(self.event_matcher.matched_count is not None) + return next_event, self.mapping_matcher.match(self.event_matcher.matched_count) def __bool__(self) -> bool: - return len(self.events()) > 0 + return bool(self.event_matcher) diff --git a/src/focaccia/native/tracer.py b/src/focaccia/native/tracer.py index 4376f41..b2ca0d8 100644 --- a/src/focaccia/native/tracer.py +++ b/src/focaccia/native/tracer.py @@ -12,7 +12,7 @@ from focaccia.trace import Trace, TraceEnvironment from focaccia.miasm_util import MiasmSymbolResolver from focaccia.snapshot import ReadableProgramState, RegisterAccessError from focaccia.symbolic import SymbolicTransform, DisassemblyContext, run_instruction -from focaccia.deterministic import Event, DeterministicEventIterator +from focaccia.deterministic import Event, EventMatcher from .lldb_target import LLDBConcreteTarget, LLDBLocalTarget, LLDBRemoteTarget @@ -154,8 +154,6 @@ class SymbolicTracer: self.cross_validate = cross_validate self.target = SpeculativeTracer(self.create_debug_target()) - self.nondet_events = DeterministicEventIterator(self.env.detlog, match_event) - def create_debug_target(self) -> LLDBConcreteTarget: binary = self.env.binary_name if self.remote is False: @@ -208,30 +206,10 @@ class SymbolicTracer: f' mem[{hex(addr)}:{hex(addr+len(data))}] = {conc_data}.' f'\nFaulty transformation: {transform}') - def post_event(self) -> None: - current_event = self.nondet_events.current_event() - if current_event: - if current_event.pc == 0: - # Exit sequence - debug('Completed exit event') - self.target.run() - - debug(f'Completed handling event: {current_event}') - self.nondet_events.update_to_next() - - def is_stepping_instr(self, instruction: Instruction) -> bool: - if self.nondet_events.current_event(): - debug('Current instruction matches next event; stepping through it') - self.nondet_events.update_to_next() - return True - else: - if self.target.arch.is_instr_syscall(str(instruction)): - return True - return False - def progress(self, new_pc, step: bool = False) -> int | None: self.target.speculate(new_pc) if step: + info(f'Stepping through event at {hex(self.target.read_pc())}') self.target.progress_execution() if self.target.is_exited(): return None @@ -251,9 +229,10 @@ class SymbolicTracer: ctx = DisassemblyContext(self.target) arch = ctx.arch + event_matcher = EventMatcher(self.env.detlog.events(), match_event, self.target) if logger.isEnabledFor(logging.DEBUG): debug('Tracing program with the following non-deterministic events') - for event in self.nondet_events.events(): + for event in event_matcher.events: debug(event) # Trace concolically @@ -262,10 +241,9 @@ class SymbolicTracer: pc = self.target.read_pc() if self.env.stop_address is not None and pc == self.env.stop_address: + info(f'Reached stop address at {hex(pc)}') break - self.nondet_events.update(self.target) - # Disassemble instruction at the current PC tid = self.target.get_current_tid() try: @@ -292,7 +270,8 @@ class SymbolicTracer: continue raise # forward exception - is_event = self.is_stepping_instr(instruction) + event, post_event = event_matcher.match_pair(self.target) + in_event = (event and event_matcher) or self.target.arch.is_instr_syscall(str(instruction)) # Run instruction conc_state = MiasmSymbolResolver(self.target, ctx.loc_db) @@ -311,7 +290,7 @@ class SymbolicTracer: new_pc = int(new_pc) transform = SymbolicTransform(tid, modified, [instruction], arch, pc, new_pc) pred_regs, pred_mems = self.predict_next_state(instruction, transform) - self.progress(new_pc, step=is_event) + self.progress(new_pc, step=in_event) try: self.validate(instruction, transform, pred_regs, pred_mems) @@ -321,7 +300,7 @@ class SymbolicTracer: continue raise else: - new_pc = self.progress(new_pc, step=is_event) + new_pc = self.progress(new_pc, step=in_event) if new_pc is None: transform = SymbolicTransform(tid, modified, [instruction], arch, pc, 0) strace.append(transform) @@ -330,8 +309,13 @@ class SymbolicTracer: strace.append(transform) - if is_event: - self.post_event() + if post_event: + if post_event.pc == 0: + # Exit sequence + debug('Completed exit event') + self.target.run() + + debug(f'Completed handling event: {post_event}') return Trace(strace, self.env) diff --git a/src/focaccia/qemu/_qemu_tool.py b/src/focaccia/qemu/_qemu_tool.py index 75b142e..188ecf2 100644 --- a/src/focaccia/qemu/_qemu_tool.py +++ b/src/focaccia/qemu/_qemu_tool.py @@ -9,7 +9,7 @@ work to do. import gdb import logging import traceback -from typing import Iterable +from typing import Iterable, Optional import focaccia.parser as parser from focaccia.arch import supported_architectures, Arch @@ -19,7 +19,13 @@ from focaccia.snapshot import ProgramState, ReadableProgramState, \ from focaccia.symbolic import SymbolicTransform, eval_symbol, ExprMem from focaccia.trace import Trace, TraceEnvironment from focaccia.utils import print_result -from focaccia.deterministic import DeterministicLog, DeterministicEventIterator, Event, SyscallEvent +from focaccia.deterministic import ( + DeterministicLog, + LogStateMatcher, + Event, + SyscallEvent, + MemoryMapping, +) from focaccia.qemu.deterministic import emulated_system_calls from focaccia.tools.validate_qemu import make_argparser, verbosity @@ -40,6 +46,7 @@ qemu_crash = { def match_event(event: Event, target: ReadableProgramState) -> bool: # Match just on PC + debug(f'Matching for PC {hex(target.read_pc())} with event {hex(event.pc)}') if event.pc == target.read_pc(): return True return False @@ -149,23 +156,20 @@ class GDBServerStateIterator: self.arch = supported_architectures[archname] self.binary = self._process.progspace.filename - self._deterministic_events = DeterministicEventIterator(self._deterministic_log, match_event) - - # Filter non-deterministic events for event after start - self._deterministic_events.update(self.current_state()) - self._deterministic_events.update_to_next() + self._log_matcher = LogStateMatcher(self._deterministic_log.events(), + self._deterministic_log.mmaps(), + match_event, + from_state=self.current_state()) + info(f'Synchronizing at PC {hex(self.current_state().read_pc())} with {self._log_matcher.matched_events()}') def current_state(self) -> ReadableProgramState: return GDBProgramState(self._process, gdb.selected_frame(), self.arch) - def _handle_syscall(self) -> GDBProgramState: - cur_event = self._deterministic_events.current_event() - call = cur_event.registers.get(self.arch.get_syscall_reg()) + def _handle_syscall(self, event: Event, post_event: Event) -> GDBProgramState: + call = event.registers.get(self.arch.get_syscall_reg()) - post_event = self._deterministic_events.update_to_next() syscall = emulated_system_calls[self.arch.archname].get(call, None) - debug(f'Handling event:\n{cur_event}') - if syscall is not None: + if syscall is not None and False: info(f'Replaying system call number {hex(call)}') self.skip(post_event.pc) @@ -189,15 +193,14 @@ class GDBServerStateIterator: raise StopIteration return GDBProgramState(self._process, gdb.selected_frame(), self.arch) - def _handle_event(self) -> GDBProgramState: - current_event = self._deterministic_events.current_event() - if not current_event: + def _handle_event(self, event: Event | None, post_event: Event | None) -> GDBProgramState: + if not event: return self.current_state() - if isinstance(current_event, SyscallEvent): - return self._handle_syscall() + if isinstance(event, SyscallEvent): + return self._handle_syscall(event, post_event) - warn(f'Event handling for events of type {current_event.event_type} not implemented') + warn(f'Event handling for events of type {event.event_type} not implemented') return self.current_state() def _is_exited(self) -> bool: @@ -213,11 +216,12 @@ class GDBServerStateIterator: self._first_next = False return GDBProgramState(self._process, gdb.selected_frame(), self.arch) - if match_event(self._deterministic_events.current_event(), self.current_state()): - state = self._handle_event() + event, post_event, _ = self._log_matcher.match_pair(self.current_state()) + if event: + state = self._handle_event(event, post_event) if self._is_exited(): raise StopIteration - self._deterministic_events.update_to_next() + return state # Step @@ -233,7 +237,7 @@ class GDBServerStateIterator: def run_until(self, addr: int) -> ReadableProgramState: events_handled = 0 - event = self._deterministic_events.current_event() + event, _ = self._log_matcher.next() while event: state = self._run_until_any([addr, event.pc]) if state.read_pc() == addr: @@ -241,9 +245,10 @@ class GDBServerStateIterator: self._first_next = events_handled == 0 return state - self._handle_event() + event, post_event, _ = self._log_matcher.match_pair(self.current_state()) + self._handle_event(event, post_event) - event = self._deterministic_events.update_to_next() + event, _ = self._log_matcher.next() events_handled += 1 return self._run_until_any([addr]) @@ -375,7 +380,7 @@ def collect_conc_trace(gdb: GDBServerStateIterator, \ if logger.isEnabledFor(logging.DEBUG): debug('Tracing program with the following non-deterministic events:') - for event in gdb._deterministic_events.events(): + for event in gdb._log_matcher.events(): debug(event) # Skip to start |