about summary refs log tree commit diff stats
diff options
context:
space:
mode:
authorTheofilos Augoustis <37243696+taugoust@users.noreply.github.com>2025-11-11 09:53:46 +0100
committerGitHub <noreply@github.com>2025-11-11 09:53:46 +0100
commit61d5d618756d11a02e621ee03835f65941c44645 (patch)
tree13b13f8107a00e5ed02923245ee0ed19fd960319
parentb4114d32eab237c1b4f3e96544a40243097bf632 (diff)
parent08afe9decbd6047e06133394c7c408d6f2c092db (diff)
downloadfocaccia-61d5d618756d11a02e621ee03835f65941c44645.tar.gz
focaccia-61d5d618756d11a02e621ee03835f65941c44645.zip
Merge pull request #22 from TUM-DSE/ta/rr-mmaps
Implement complete RR log support
-rw-r--r--src/focaccia/deterministic.py525
1 files changed, 473 insertions, 52 deletions
diff --git a/src/focaccia/deterministic.py b/src/focaccia/deterministic.py
index 5a2b411..e7914a3 100644
--- a/src/focaccia/deterministic.py
+++ b/src/focaccia/deterministic.py
@@ -1,8 +1,9 @@
 """Parsing of JSON files containing snapshot data."""
 
 import os
-import itertools
-from typing import Union, Iterable
+import io
+import struct
+from typing import Union, Optional
 
 import brotli
 
@@ -22,6 +23,80 @@ TaskEvent = rr_trace.TaskEvent
 MMap = rr_trace.MMap
 SerializedObject = Union[Frame, TaskEvent, MMap]
 
+class DeterministicLogReader(io.RawIOBase):
+    """
+    File-like reader for rr trace files.
+
+    Each block in the file:
+      uint32_t uncompressed_size
+      uint32_t compressed_size
+      [compressed_data...]
+    Presents the concatenated uncompressed data as a sequential byte stream.
+    """
+
+    _HDR = struct.Struct("<II")
+
+    def __init__(self, filename: str):
+        super().__init__()
+        self._f = open(filename, "rb", buffering=0)
+        self._data_buffer = memoryview(b"")
+        self._pos = 0
+        self._eof = False
+
+    def _load_chunk(self) -> None:
+        """Load and decompress the next Brotli block."""
+        header = self._f.read(self._HDR.size)
+        if not header:
+            self._eof = True
+            self._data_buffer = memoryview(b"")
+            return
+        if len(header) != self._HDR.size:
+            raise EOFError("Incomplete RR data block header")
+
+        compressed_length, uncompressed_length = self._HDR.unpack(header)
+        chunk = self._f.read(compressed_length)
+        if len(chunk) != compressed_length:
+            raise EOFError("Incomplete RR data block")
+
+        chunk = brotli.decompress(chunk)
+        if len(chunk) != uncompressed_length:
+            raise Exception(f'Malformed deterministic log: uncompressed chunk is not equal'
+                            f'to reported length {hex(uncompressed_length)}')
+
+        self._data_buffer = memoryview(chunk)
+        self._pos = 0
+
+    def read(self, n: Optional[int] = -1) -> bytes:
+        """Read up to n bytes from the uncompressed stream."""
+        if n == 0:
+            return b""
+
+        chunks = bytearray()
+        remaining = n if n is not None and n >= 0 else None
+
+        while not self._eof and (remaining is None or remaining > 0):
+            if self._pos >= len(self._data_buffer):
+                self._load_chunk()
+                if self._eof:
+                    break
+
+            available = len(self._data_buffer) - self._pos
+            take = available if remaining is None else min(available, remaining)
+            chunks += self._data_buffer[self._pos:self._pos + take]
+            self._pos += take
+            if remaining is not None:
+                remaining -= take
+
+        return bytes(chunks)
+
+    def readable(self) -> bool:
+        return True
+
+    def close(self) -> None:
+        if not self.closed:
+            self._f.close()
+        super().close()
+
 def parse_x64_registers(enc_regs: bytes, signed: bool=False) -> dict[str, int]:
     idx = 0
     def parse_reg():
@@ -93,21 +168,51 @@ def parse_aarch64_registers(enc_regs: bytes, order: str='little', signed: bool=F
     
     return regs
 
+class MemoryWriteHole:
+    def __init__(self, offset: int, size: int):
+        self.offset = offset
+        self.size = size
+        if self.size <= 0:
+            raise ValueError(f'Write hole cannot have size {size}')
+
+    def __repr__(self) -> str:
+        return f'hole at {hex(self.offset)}:{hex(self.offset+self.size)}'
+
+class MemoryWrite:
+    def __init__(self,
+                 tid: int,
+                 address: int,
+                 size: int,
+                 holes: list[MemoryWriteHole],
+                 is_conservative: bool,
+                 data: bytes | None = None):
+        self.tid = tid
+        self.address = address
+        self.size = size
+        self.holes = holes
+        self.is_conservative = is_conservative
+        self.data = data
+
+    def __repr__(self) -> str:
+        return f'{{ tid: {hex(self.tid)}, addr: {hex(self.address)}:{hex(self.address+self.size)}\n' \
+               f'   conservative? {self.is_conservative}, holes: {self.holes}\n' \
+               f'   data: {self.data} }}'
+
 class Event:
     def __init__(self,
                  pc: int,
                  tid: int,
                  arch: Arch,
-                 event_type: str,
                  registers: dict[str, int],
-                 memory_writes: dict[int, int]):
+                 memory_writes: list[MemoryWrite],
+                 event_type: str):
         self.pc = pc
         self.tid = tid
         self.arch = arch
-        self.event_type = event_type
 
         self.registers = registers
         self.mem_writes = memory_writes
+        self.event_type = event_type
 
     def match(self, pc: int, target: ReadableProgramState) -> bool:
         # TODO: match the rest of the state to be sure
@@ -122,13 +227,13 @@ class Event:
         return False
 
     def __repr__(self) -> str:
-        reg_repr = ''
+        reg_repr = f'{self.event_type} event\n'
         for reg, value in self.registers.items():
             reg_repr += f'{reg} = {hex(value)}\n'
 
         mem_write_repr = ''
-        for addr, size in self.mem_writes.items():
-            mem_write_repr += f'{hex(addr)}:{hex(addr+size)}\n'
+        for mem_write in self.mem_writes:
+            mem_write_repr += f'{mem_write}\n'
 
         repr_str = f'Thread {hex(self.tid)} executed event {self.event_type} at {hex(self.pc)}\n'
         repr_str += f'Register set:\n{reg_repr}'
@@ -138,6 +243,210 @@ class Event:
 
         return repr_str
 
+class SyscallBufferFlushEvent(Event):
+    def __init__(self,
+                 pc: int,
+                 tid: int,
+                 arch: Arch,
+                 registers: dict[str, int],
+                 memory_writes: list[MemoryWrite],
+                 mprotect_records: bytes):
+        super().__init__(pc, tid, arch, registers, memory_writes, 'syscallBufFlush')
+        self.mprotect_records = mprotect_records
+
+    def __repr__(self):
+        return f'{super().__repr__()}\nmprotect_records = {self.mprotect_records}'
+
+class SyscallExtra:
+    def __init__(self,
+                 write_offset: int | None,
+                 exec_fds_to_close: list[int] | None,
+                 opened_fds: list[int] | None,
+                 socket_local_address: bytes,
+                 socket_remote_address: bytes):
+        self.write_offset = write_offset
+        self.exec_fds_to_close = exec_fds_to_close
+        self.opened_fds = opened_fds
+        self.socket_local_address = socket_local_address
+        self.socket_remote_address = socket_remote_address
+
+class SyscallEvent(Event):
+    def __init__(self,
+                 pc: int,
+                 tid: int,
+                 arch: Arch,
+                 registers: dict[str, int],
+                 memory_writes: list[MemoryWrite],
+                 syscall_arch: Arch,
+                 syscall_number: int,
+                 syscall_state: str,
+                 failed_during_preparation: bool,
+                 syscall_extras: SyscallExtra | None = None):
+        super().__init__(pc, tid, arch, registers, memory_writes, 'syscall')
+        self.syscall_arch = syscall_arch
+        self.syscall_number = syscall_number
+        self.syscall_state = syscall_state
+        self.failed_during_preparation = failed_during_preparation
+        self.syscall_extras = syscall_extras
+
+        if syscall_state not in ['entering', 'exiting', 'enteringPtrace']:
+            raise NotImplementedError(f'Cannot handle system call state of type: {syscall_state}')
+
+    def __repr__(self) -> str:
+        return f'{super().__repr__()}\n' \
+               f'system call architecture = {self.syscall_arch}\n' \
+               f'system call number = {hex(self.syscall_number)}\n' \
+               f'system call state = {self.syscall_state}\n' \
+               f'failed during preparation? {self.failed_during_preparation}\n' \
+               f'syscall extras: {self.syscall_extras}\n'
+
+class SignalDescriptor:
+    def __init__(self,
+                 arch: Arch,
+                 siginfo: bytes,
+                 deterministic: bool,
+                 disposition: str):
+        self.arch = arch
+        self.siginfo = siginfo
+        self.deterministic = deterministic
+        self.disposition = disposition
+
+        if self.disposition not in ['fatal', 'userHandler', 'ignored']:
+            raise NotImplementedError(f'Canot handle signal dispositions of type'
+                                      f' {self.disposition}')
+
+    def __repr__(self) -> str:
+        return f'signal architecture: {self.arch}\n' \
+               f'siginfo data:\n{self.siginfo}\n' \
+               f'deterministic? {self.deterministic}\n' \
+               f'disposition: {self.disposition}\n'
+
+class SignalEvent(Event):
+    def __init__(self,
+                 pc: int,
+                 tid: int,
+                 arch: Arch,
+                 registers: dict[str, int],
+                 memory_writes: list[MemoryWrite],
+                 signal_number: SignalDescriptor | None = None,
+                 signal_delivery: SignalDescriptor | None = None,
+                 signal_handler: SignalDescriptor | None = None):
+        super().__init__(pc, tid, arch, registers, memory_writes, 'signal')
+        self.signal_number = signal_number
+        self.signal_delivery = signal_delivery
+        self.signal_handler = signal_handler
+
+        if [self.signal_number, self.signal_delivery, self.signal_handler].count(None) != 1:
+            raise ValueError(f'A signal event may be either a signal number, delivery or handler event')
+
+    def __repr__(self) -> str:
+        repr_str = f'{super().__repr__()}\n'
+        if self.signal_number:
+            return repr_str + '{self.signal_number}'
+        if self.signal_delivery:
+            return repr_str + '{self.signal_delivery}'
+        if self.signal_handler:
+            return repr_str + '{self.signal_handler}'
+
+class MemoryMapping:
+    def __init__(self,
+                 event_count: int,
+                 start_address: int,
+                 end_address: int,
+                 source: str,
+                 offset: int,
+                 mmap_prot: int,
+                 mmap_flags: int):
+        self.event_count = event_count
+        self.start_address = start_address
+        self.length = end_address - self.start_address
+        self.source = source
+        self.offset = offset
+        self.mmap_prot = mmap_prot
+        self.mmap_flags = mmap_flags
+
+    def __repr__(self) -> str:
+        return f'Memory mapping at event {self.event_count}\n' \
+               f'start = {hex(self.start_address)}\n' \
+               f'length = {self.length}\n' \
+               f'source = {self.source}\n' \
+               f'offset = {self.offset}\n' \
+               f'mmap_prot = {hex(self.mmap_prot)}\n' \
+               f'mmap_flags = {hex(self.mmap_flags)}'
+
+class Task:
+    def __init__(self,
+                 event_count: int,
+                 tid: int):
+        self.event_count = event_count
+        self.tid = tid
+
+    def __repr__(self) -> str:
+        return f'For event index {self.event_count} at tid = {hex(self.tid)}'
+
+class CloneTask(Task):
+    def __init__(self,
+                 event_count: int,
+                 tid: int,
+                 parent_tid: int,
+                 clone_flags: int,
+                 own_namespace_tid: int):
+        super().__init__(event_count, tid)
+        self.parent_tid = parent_tid
+        self.clone_flags = clone_flags
+        self.own_namespace_tid = own_namespace_tid
+
+    def __repr__(self) -> str:
+        return f'Clone task\n{super().__repr__()}\n' \
+               f'parent tid = {hex(self.parent_tid)}\n' \
+               f'clone flags = {hex(self.clone_flags)}\n' \
+               f'own namespace tid = {hex(self.own_namespace_tid)}'
+
+class ExecTask(Task):
+    def __init__(self,
+                 event_count: int,
+                 tid: int,
+                 filename: str,
+                 commandline: list[str],
+                 execution_base_address: int,
+                 interpreter_base_address: int,
+                 interpreter_name: str):
+        super().__init__(event_count, tid)
+        self.filename = filename
+        self.commandline = commandline
+        self.execution_base_address = execution_base_address
+        self.interpreter_base_address = interpreter_base_address
+        self.interpreter_name = interpreter_name
+
+    def __repr__(self) -> str:
+        return f'Exec task\n{super().__repr__()}\n' \
+               f'filename = {self.filename}\n' \
+               f'command-line = {self.commandline}\n' \
+               f'execution base address = {hex(self.execution_base_address)}\n' \
+               f'interpereter base address = {hex(self.interpreter_base_address)}\n' \
+               f'interpreter name = {self.interpreter_name}'
+
+class ExitTask(Task):
+    def __init__(self,
+                 event_count: int,
+                 tid: int,
+                 exit_status: int):
+        super().__init__(event_count, tid)
+        self.exit_status = exit_status
+
+    def __repr__(self) -> str:
+        return f'Exit task\n{super().__repr__()}\n' \
+               f'exit status = {hex(self.exit_status)}'
+
+class DetachTask(Task):
+    def __init__(self,
+                 event_count: int,
+                 tid: int):
+        super().__init__(event_count, tid)
+
+    def __repr__(self) -> str:
+        return f'Detach task\n{super().__repr__()}'
+
 class DeterministicLog:
     def __init__(self, log_dir: str):
         self.base_directory = log_dir
@@ -151,39 +460,25 @@ class DeterministicLog:
     def mmaps_file(self) -> str:
         return os.path.join(self.base_directory, 'mmaps')
 
-    def _read(self, file, obj: SerializedObject) -> list[SerializedObject]:
-        data = bytearray()
-        objects = []
-        with open(file, 'rb') as f:
-            while True:
-                try:
-                    compressed_len = int.from_bytes(f.read(4), byteorder='little')
-                    uncompressed_len = int.from_bytes(f.read(4), byteorder='little')
-                except Exception as e:
-                    raise Exception(f'Malformed deterministic log: {e}') from None
-
-                chunk = f.read(compressed_len)
-                if not chunk:
-                    break
+    def data_file(self) -> str:
+        return os.path.join(self.base_directory, 'data')
 
-                chunk = brotli.decompress(chunk)
-                if len(chunk) != uncompressed_len:
-                    raise Exception(f'Malformed deterministic log: uncompressed chunk is not equal'
-                                    f'to reported length {hex(uncompressed_len)}')
-                data.extend(chunk)
+    def _read_structure(self, file, obj: SerializedObject) -> list[SerializedObject]:
+        data = DeterministicLogReader(file).read()
 
-            for deser in obj.read_multiple_bytes_packed(data):
-                objects.append(deser)
-            return objects
+        objects = []
+        for deser in obj.read_multiple_bytes_packed(data):
+            objects.append(deser)
+        return objects
 
-    def raw_events(self) -> list[SerializedObject]:
-        return self._read(self.events_file(), Frame)
+    def raw_events(self) -> list[Frame]:
+        return self._read_structure(self.events_file(), Frame)
 
-    def raw_tasks(self) -> list[SerializedObject]:
-        return self._read(self.tasks_file(), TaskEvent)
+    def raw_tasks(self) -> list[TaskEvent]:
+        return self._read_structure(self.tasks_file(), TaskEvent)
 
-    def raw_mmaps(self) -> list[SerializedObject]:
-        return self._read(self.mmaps_file(), MMap)
+    def raw_mmaps(self) -> list[MMap]:
+        return self._read_structure(self.mmaps_file(), MMap)
 
     def events(self) -> list[Event]:
         def parse_registers(event: Frame) -> Union[int, dict[str, int]]:
@@ -195,32 +490,158 @@ class DeterministicLog:
                 regs = parse_aarch64_registers(event.registers.raw)
                 return regs['pc'], regs
             raise NotImplementedError(f'Unable to parse registers for architecture {arch}')
-    
-        def parse_memory_writes(event: Frame) -> dict[int, int]:
-            writes = {}
+
+        def parse_memory_writes(event: Frame, reader: io.RawIOBase) -> list[MemoryWrite]:
+            writes = []
             for raw_write in event.memWrites:
-                writes[int(raw_write.addr)] = int(raw_write.size)
+                # Skip memory writes with 0 bytes
+                if raw_write.size == 0:
+                    continue
+
+                holes = []
+                for raw_hole in raw_write.holes:
+                    holes.append(MemoryWriteHole(raw_hole.offset, raw_hole.size))
+
+                data = bytearray()
+                for hole in holes:
+                    until_hole = hole.offset - reader.tell()
+                    data.extend(reader.read(until_hole))
+                    data.extend(b'\x00' * hole.size)
+
+                # No holes
+                if len(data) == 0:
+                    data = reader.read(raw_write.size)
+
+                mem_write = MemoryWrite(raw_write.tid,
+                                        raw_write.addr,
+                                        raw_write.size,
+                                        holes,
+                                        raw_write.sizeIsConservative,
+                                        bytes(data))
+                writes.append(mem_write)
             return writes
 
+        data_reader = DeterministicLogReader(self.data_file())
+
         events = []
         raw_events = self.raw_events()
         for raw_event in raw_events:
             pc, registers = parse_registers(raw_event)
-            mem_writes = parse_memory_writes(raw_event)
+            mem_writes = parse_memory_writes(raw_event, data_reader)
+
+            event = None
 
+            tid = raw_event.tid
+            arch = raw_event.arch
             event_type = raw_event.event.which()
-            if event_type == 'syscall' and raw_event.arch == rr_trace.Arch.x8664:
-                # On entry: substitute orig_rax for RAX
-                if raw_event.event.syscall.state == rr_trace.SyscallState.entering:
-                    registers['rax'] = registers['orig_rax']
-                del registers['orig_rax']
-
-            event = Event(pc,
-                          raw_event.tid,
-                          raw_event.arch,
-                          event_type,
-                          registers, mem_writes)
+
+            if event_type == 'syscall': 
+                if raw_event.arch == rr_trace.Arch.x8664:
+                    # On entry: substitute orig_rax for RAX
+                    if raw_event.event.syscall.state == rr_trace.SyscallState.entering:
+                        registers['rax'] = registers['orig_rax']
+                    del registers['orig_rax']
+                event = SyscallEvent(pc,
+                                     tid,
+                                     arch,
+                                     registers,
+                                     mem_writes,
+                                     raw_event.event.syscall.arch,
+                                     raw_event.event.syscall.number,
+                                     raw_event.event.syscall.state,
+                                     raw_event.event.syscall.failedDuringPreparation)
+
+            if event_type == 'syscallbufFlush':
+                event = SyscallBufferFlushEvent(pc,
+                                                tid,
+                                                arch,
+                                                registers,
+                                                mem_writes,
+                                                raw_event.event.syscallbufFlush.mprotectRecords)
+                raise NotImplementedError(f'Cannot support system call buffer events yet: {event}')
+            if event_type == 'signal':
+                signal = raw_event.event.signal
+                signal_descriptor = SignalDescriptor(signal.arch,
+                                                     signal.siginfo,
+                                                     signal.deterministic,
+                                                     signal.disposition)
+                event = SignalEvent(pc, tid, arch, registers, mem_writes, 
+                                    signal_number=signal_descriptor)
+
+            if event_type == 'signalDelivery':
+                signal = raw_event.event.signalDelivery
+                signal_descriptor = SignalDescriptor(signal.arch,
+                                                     signal.siginfo,
+                                                     signal.deterministic,
+                                                     signal.disposition)
+                event = SignalEvent(pc, tid, arch, registers, mem_writes, 
+                                    signal_delivery=signal_descriptor)
+
+            if event_type == 'signalHandler':
+                signal = raw_event.event.signalHandler
+                signal_descriptor = SignalDescriptor(signal.arch,
+                                                     signal.siginfo,
+                                                     signal.deterministic,
+                                                     signal.disposition)
+                event = SignalEvent(pc, tid, arch, registers, mem_writes, 
+                                    signal_handler=signal_descriptor)
+
+            if event is None:
+                event = Event(pc, tid, arch, registers, mem_writes, event_type)
+
             events.append(event)
 
         return events
 
+    def tasks(self) -> list[Task]:
+        tasks = []
+        raw_tasks = self.raw_tasks()
+        for raw_task in raw_tasks:
+            task_type = raw_task.which()
+
+            task = None
+            if task_type == 'clone':
+                task = CloneTask(raw_task.frameTime,
+                                 raw_task.tid,
+                                 raw_task.clone.parentTid,
+                                 raw_task.clone.flags,
+                                 raw_task.clone.ownNsTid)
+            if task_type == 'exec':
+                task = ExecTask(raw_task.frameTime,
+                                raw_task.tid,
+                                raw_task.exec.fileName,
+                                raw_task.exec.cmdLine,
+                                raw_task.exec.exeBase,
+                                raw_task.exec.interpBase,
+                                raw_task.exec.interpName)
+            if task_type == 'exit':
+                task = ExitTask(raw_task.frameTime, raw_task.tid, raw_task.exit.exitStatus)
+            if task_type == 'detach':
+                task = DetachTask(raw_task.frameTime, raw_task.tid)
+            tasks.append(task)
+        return tasks
+
+    def mmaps(self) -> list[MemoryMapping]:
+        def mapping_source(mmap: MMap) -> str:
+            source_type = mmap.source.which()
+            if source_type == 'zero' or source_type == 'trace':
+                return source_type
+            elif source_type == 'file':
+                return mmap.source.file.backingFileName
+            else:
+                raise NotImplementedError(f'Unable to handle memory mappings from source type:'
+                                          f' {source_type}')
+
+        mmaps = []
+        raw_mmaps = self.raw_mmaps()
+        for raw_mmap in raw_mmaps:
+            mmap = MemoryMapping(raw_mmap.frameTime,
+                                 raw_mmap.start,
+                                 raw_mmap.end,
+                                 mapping_source(raw_mmap),
+                                 raw_mmap.fileOffsetBytes,
+                                 raw_mmap.prot,
+                                 raw_mmap.flags)
+            mmaps.append(mmap)
+        return mmaps
+