about summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--src/focaccia/_deterministic_impl.py30
-rw-r--r--src/focaccia/deterministic.py52
-rw-r--r--src/focaccia/native/tracer.py52
-rw-r--r--src/focaccia/snapshot.py7
4 files changed, 95 insertions, 46 deletions
diff --git a/src/focaccia/_deterministic_impl.py b/src/focaccia/_deterministic_impl.py
index 1d784cb..fc85b9a 100644
--- a/src/focaccia/_deterministic_impl.py
+++ b/src/focaccia/_deterministic_impl.py
@@ -3,7 +3,7 @@
 import os
 import io
 import struct
-from typing import Union, Optional
+from typing import Union, Tuple, Optional
 
 import brotli
 
@@ -217,14 +217,14 @@ class DeterministicLog:
         return self._read_structure(self.mmaps_file(), MMap)
 
     def events(self) -> list[Event]:
-        def parse_registers(event: Frame) -> Union[int, dict[str, int]]:
+        def parse_registers(event: Frame) -> Tuple[str, dict[str, int]]:
             arch = event.arch
             if arch == rr_trace.Arch.x8664:
                 regs = parse_x64_registers(event.registers.raw)
-                return regs['rip'], regs
+                return 'rip', regs
             if arch == rr_trace.Arch.aarch64:
                 regs = parse_aarch64_registers(event.registers.raw)
-                return regs['pc'], regs
+                return 'pc', regs
             raise NotImplementedError(f'Unable to parse registers for architecture {arch}')
 
         def parse_memory_writes(event: Frame, reader: io.RawIOBase) -> list[MemoryWrite]:
@@ -274,10 +274,18 @@ class DeterministicLog:
             if event_type == 'syscall': 
                 if raw_event.arch == rr_trace.Arch.x8664:
                     # On entry: substitute orig_rax for RAX
-                    if raw_event.event.syscall.state == rr_trace.SyscallState.entering:
+                    syscall = raw_event.event.syscall
+                    if syscall.state == rr_trace.SyscallState.entering:
                         registers['rax'] = registers['orig_rax']
+                        if syscall.number != 59:
+                            registers[pc] -= 2
                     del registers['orig_rax']
-                event = SyscallEvent(pc,
+                if raw_event.arch == rr_trace.Arch.aarch64:
+                    syscall = raw_event.event.syscall
+                    if syscall.state == rr_trace.SyscallState.entering and syscall.number != 221:
+                        registers[pc] -= 4
+
+                event = SyscallEvent(registers[pc],
                                      tid,
                                      arch,
                                      registers,
@@ -288,7 +296,7 @@ class DeterministicLog:
                                      raw_event.event.syscall.failedDuringPreparation)
 
             if event_type == 'syscallbufFlush':
-                event = SyscallBufferFlushEvent(pc,
+                event = SyscallBufferFlushEvent(registers[pc],
                                                 tid,
                                                 arch,
                                                 registers,
@@ -301,7 +309,7 @@ class DeterministicLog:
                                                      signal.siginfo,
                                                      signal.deterministic,
                                                      signal.disposition)
-                event = SignalEvent(pc, tid, arch, registers, mem_writes, 
+                event = SignalEvent(registers[pc], tid, arch, registers, mem_writes, 
                                     signal_number=signal_descriptor)
 
             if event_type == 'signalDelivery':
@@ -310,7 +318,7 @@ class DeterministicLog:
                                                      signal.siginfo,
                                                      signal.deterministic,
                                                      signal.disposition)
-                event = SignalEvent(pc, tid, arch, registers, mem_writes, 
+                event = SignalEvent(registers[pc], tid, arch, registers, mem_writes, 
                                     signal_delivery=signal_descriptor)
 
             if event_type == 'signalHandler':
@@ -319,11 +327,11 @@ class DeterministicLog:
                                                      signal.siginfo,
                                                      signal.deterministic,
                                                      signal.disposition)
-                event = SignalEvent(pc, tid, arch, registers, mem_writes, 
+                event = SignalEvent(registers[pc], tid, arch, registers, mem_writes, 
                                     signal_handler=signal_descriptor)
 
             if event is None:
-                event = Event(pc, tid, arch, registers, mem_writes, event_type)
+                event = Event(registers[pc], tid, arch, registers, mem_writes, event_type)
 
             events.append(event)
 
diff --git a/src/focaccia/deterministic.py b/src/focaccia/deterministic.py
index 2a15430..91afd4e 100644
--- a/src/focaccia/deterministic.py
+++ b/src/focaccia/deterministic.py
@@ -1,4 +1,7 @@
 from .arch import Arch
+from .snapshot import ReadableProgramState
+
+from typing import Callable
 
 class MemoryWriteHole:
     def __init__(self, offset: int, size: int):
@@ -280,4 +283,53 @@ except Exception:
         def events(self) -> list[Event]: return []
         def tasks(self) -> list[Task]: return []
         def mmaps(self) -> list[MemoryMapping]: return []
+finally:
+    class DeterministicEventIterator:
+        """Cursor over a log's events that synchronizes to a target via `match_fn`."""
+        def __init__(self, deterministic_log: DeterministicLog, match_fn: Callable):
+            self._detlog = deterministic_log
+            self._events = self._detlog.events()
+            self._pc_to_event = {}
+            self._match = match_fn
+            self._idx: int | None = None # None represents no current event
+            self._in_event: bool = False # NOTE(review): only set True on initial sync - confirm
+
+            # Map each PC to its (event, log index) pairs, in log order
+            for idx, event in enumerate(self._events):
+                self._pc_to_event.setdefault(event.pc, []).append((event, idx))
+
+        def events(self) -> list[Event]:
+            return self._events
+
+        def current_event(self) -> Event | None:
+            # No event when not synchronized, or once the cursor ran past the log's end
+            if self._idx is None or not self._in_event or self._idx >= len(self._events):
+                return None
+            return self._events[self._idx]
+
+        def update(self, target: ReadableProgramState) -> Event | None:
+            # Quick check
+            candidates = self._pc_to_event.get(target.read_pc(), [])
+            if len(candidates) == 0:
+                self._in_event = False
+                return None
+
+            # Find synchronization point
+            if self._idx is None:
+                for event, idx in candidates:
+                    if self._match(event, target):
+                        self._idx = idx
+                        self._in_event = True
+                        return self.current_event()
+                return None # Known PC but no matching event: stay unsynchronized
+            return self.next()
+
+        def next(self) -> Event | None:
+            if self._idx is None:
+                raise ValueError('Attempted to get next event without synchronizing')
+            self._idx += 1
+            return self.current_event()
+
+        def __bool__(self) -> bool:
+            return len(self.events()) > 0
 
diff --git a/src/focaccia/native/tracer.py b/src/focaccia/native/tracer.py
index 47ac7e2..b369b22 100644
--- a/src/focaccia/native/tracer.py
+++ b/src/focaccia/native/tracer.py
@@ -12,7 +12,7 @@ from focaccia.trace import Trace, TraceEnvironment
 from focaccia.miasm_util import MiasmSymbolResolver
 from focaccia.snapshot import ReadableProgramState, RegisterAccessError
 from focaccia.symbolic import SymbolicTransform, DisassemblyContext, run_instruction
-from focaccia.deterministic import Event
+from focaccia.deterministic import Event, DeterministicEventIterator
 
 from .lldb_target import LLDBConcreteTarget, LLDBLocalTarget, LLDBRemoteTarget
 
@@ -27,9 +27,9 @@ logging.getLogger('asmblock').setLevel(logging.CRITICAL)
 class ValidationError(Exception):
     pass
 
-def match_event(event: Event, pc: int, target: ReadableProgramState) -> bool:
+def match_event(event: Event, target: ReadableProgramState) -> bool:
     # TODO: match the rest of the state to be sure
-    if event.pc == pc:
+    if event.pc == target.read_pc():
         for reg, value in event.registers.items():
             if value == event.pc:
                 continue
@@ -154,8 +154,7 @@ class SymbolicTracer:
         self.cross_validate = cross_validate
         self.target = SpeculativeTracer(self.create_debug_target())
 
-        self.nondet_events = self.env.detlog.events()
-        self.next_event: int | None = None
+        self.nondet_events = DeterministicEventIterator(self.env.detlog, match_event)
 
     def create_debug_target(self) -> LLDBConcreteTarget:
         binary = self.env.binary_name
@@ -209,30 +208,22 @@ class SymbolicTracer:
                                       f' mem[{hex(addr)}:{hex(addr+len(data))}] = {conc_data}.'
                                       f'\nFaulty transformation: {transform}')
 
-    def progress_event(self) -> None:
-        if (self.next_event + 1) < len(self.nondet_events):
-            self.next_event += 1
-            debug(f'Next event to handle at index {self.next_event}')
-        else:
-            self.next_event = None
-
     def post_event(self) -> None:
-        if self.next_event:
-            if self.nondet_events[self.next_event].pc == 0:
+        current_event = self.nondet_events.current_event()
+        if current_event:
+            if current_event.pc == 0:
                 # Exit sequence
                 debug('Completed exit event')
                 self.target.run()
 
-            debug(f'Completed handling event at index {self.next_event}')
-            self.progress_event()
+            debug(f'Completed handling event: {current_event}')
+            self.nondet_events.next()
 
-    def is_stepping_instr(self, pc: int, instruction: Instruction) -> bool:
-        if self.nondet_events:
-            pc = pc + instruction.length # detlog reports next pc for each event
-            if self.next_event and match_event(self.nondet_events[self.next_event], pc, self.target):
-                debug('Current instruction matches next event; stepping through it')
-                self.progress_event()
-                return True
+    def is_stepping_instr(self, instruction: Instruction) -> bool:
+        if self.nondet_events.current_event():
+            debug('Current instruction matches next event; stepping through it')
+            self.nondet_events.next()
+            return True
         else:
             if self.target.arch.is_instr_syscall(str(instruction)):
                 return True
@@ -257,21 +248,12 @@ class SymbolicTracer:
         if self.env.start_address is not None:
             self.target.run_until(self.env.start_address)
 
-        for i in range(len(self.nondet_events)):
-            if self.nondet_events[i].pc == self.target.read_pc():
-                self.next_event = i+1
-                if self.next_event >= len(self.nondet_events):
-                    break
-
-                debug(f'Starting from event {self.nondet_events[i]} onwards')
-                break
-
         ctx = DisassemblyContext(self.target)
         arch = ctx.arch
 
         if logger.isEnabledFor(logging.DEBUG):
             debug('Tracing program with the following non-deterministic events')
-            for event in self.nondet_events:
+            for event in self.nondet_events.events():
                 debug(event)
 
         # Trace concolically
@@ -282,7 +264,7 @@ class SymbolicTracer:
             if self.env.stop_address is not None and pc == self.env.stop_address:
                 break
 
-            assert(pc != 0)
+            self.nondet_events.update(self.target)
 
             # Disassemble instruction at the current PC
             tid = self.target.get_current_tid()
@@ -310,7 +292,7 @@ class SymbolicTracer:
                         continue
                     raise # forward exception
 
-            is_event = self.is_stepping_instr(pc, instruction)
+            is_event = self.is_stepping_instr(instruction)
 
             # Run instruction
             conc_state = MiasmSymbolResolver(self.target, ctx.loc_db)
diff --git a/src/focaccia/snapshot.py b/src/focaccia/snapshot.py
index 03a03cd..f40ac5a 100644
--- a/src/focaccia/snapshot.py
+++ b/src/focaccia/snapshot.py
@@ -92,6 +92,13 @@ class ReadableProgramState:
         self.arch = arch
         self.strict = True
 
+    def read_pc(self) -> int:
+        """Read the PC value.
+
+        :raise RegisterAccessError: If the register has no value.
+        """
+        return self.read_register('pc')
+
     def read_register(self, reg: str) -> int:
         """Read a register's value.