about summary refs log tree commit diff stats
diff options
context:
space:
mode:
authorTheofilos Augoustis <theofilos.augoustis@gmail.com>2025-11-20 13:17:38 +0000
committerTheofilos Augoustis <theofilos.augoustis@gmail.com>2025-11-20 17:26:00 +0000
commit0036438d2aae660d035770b41bcd996ce3e6738f (patch)
tree3643c376661995511f517ed5294263d1c20f62a6
parent8ce65a76eb158470285065d21f47a3dd1db3db59 (diff)
downloadfocaccia-0036438d2aae660d035770b41bcd996ce3e6738f.tar.gz
focaccia-0036438d2aae660d035770b41bcd996ce3e6738f.zip
Match and context switch
-rw-r--r--src/focaccia/deterministic.py17
-rw-r--r--src/focaccia/qemu/_qemu_tool.py87
2 files changed, 63 insertions, 41 deletions
diff --git a/src/focaccia/deterministic.py b/src/focaccia/deterministic.py
index 4fcc222..58d9fd9 100644
--- a/src/focaccia/deterministic.py
+++ b/src/focaccia/deterministic.py
@@ -337,17 +337,18 @@ finally:
                 return self.events[self.matched_count]
             return None
 
-        def match_pair(self, state: ReadableProgramState):
-            event = self.match(state)
-            if event is None:
-                return None, None
-            if isinstance(event, SyscallEvent) and event.syscall_state == 'exiting':
-                self.matched_count = None
-                return None, None
+        def match_pair(self, event: Event | None):
+            if event is None or not isinstance(event, SyscallEvent):
+                return None
             assert(self.matched_count is not None)
             post_event = self.events[self.matched_count]
             self.matched_count += 1
-            return event, post_event
+            return post_event
+
+        def unmatch(self, count: int = 1) -> None:
+            if self.matched_count is None:
+                raise ValueError('Cannot get unmatch event with unsynchronized event matcher')
+            self.matched_count -= count
 
         def __bool__(self) -> bool:
             return len(self.events) > 0
diff --git a/src/focaccia/qemu/_qemu_tool.py b/src/focaccia/qemu/_qemu_tool.py
index 7c3ccc2..5dd76b5 100644
--- a/src/focaccia/qemu/_qemu_tool.py
+++ b/src/focaccia/qemu/_qemu_tool.py
@@ -141,8 +141,8 @@ class GDBServerStateIterator:
         gdb.execute('set pagination 0')
         gdb.execute('set sysroot')
         gdb.execute('set python print-stack full') # enable complete Python tracebacks
-        gdb.execute('set scheduler-locking on')
         gdb.execute(f'target remote {remote}')
+        gdb.execute('set scheduler-locking on')
         self._deterministic_log = deterministic_log
         self._process = gdb.selected_inferior()
         self._first_next = True
@@ -179,9 +179,12 @@ class GDBServerStateIterator:
                                     skipped_events=skipped_events)
         event = self._events.match(first_state)
         
+        self._thread_count = 1
         self._current_event_id = event.tid
         self._thread_map = {
-            self._current_event_id: self.current_tid()
+            self._current_event_id: (self.current_tid(), self._thread_count)
+        }
+        self._thread_context = {
         }
         info(f'Synchronized at PC={hex(first_state.read_pc())} to event:\n{event}')
         debug(f'Thread mapping at this point: {hex(event.tid)}: {hex(self.current_tid())}')
@@ -189,7 +192,7 @@ class GDBServerStateIterator:
     def current_state(self) -> ReadableProgramState:
         return GDBProgramState(self._process, gdb.selected_frame(), self.arch)
 
-    def _handle_syscall(self, event: Event, post_event: Event) -> GDBProgramState:
+    def _handle_syscall(self, event: Event, post_event: Event) -> ReadableProgramState:
         call = event.registers.get(self.arch.get_syscall_reg())
         next_state = None
 
@@ -229,10 +232,11 @@ class GDBServerStateIterator:
             if syscall.creates_thread:
                 new_tid = self.current_state().read_register(self.arch.get_syscall_reg())
                 event_new_tid = post_event.registers[self.arch.get_syscall_reg()]
-                self._thread_map[event_new_tid] = new_tid
+                self._thread_count += 1
+                self._thread_map[event_new_tid] = (new_tid, self._thread_count)
                 info(f'New thread created TID={hex(new_tid)} corresponds to native {hex(event.tid)}')
                 debug('Thread mapping at this point:')
-                for event_tid, tid in self._thread_map.items():
+                for event_tid, (tid, num) in self._thread_map.items():
                     debug(f'{hex(event_tid)}: {hex(tid)}')
 
             next_state = GDBProgramState(self._process, gdb.selected_frame(), self.arch)
@@ -244,24 +248,42 @@ class GDBServerStateIterator:
                 raise StopIteration
             next_state = GDBProgramState(self._process, gdb.selected_frame(), self.arch)
 
-        # Context switch
-        if post_event.tid != self._current_event_id:
-            self._current_event_id = post_event.tid
-            tid = self._thread_map[self._current_event_id]
-            self.context_switch(tid)
-            debug(f'Scheduled native TID {post_event.tid} as {tid}')
-
         return next_state
 
-    def _handle_event(self, event: Event | None, post_event: Event | None) -> GDBProgramState:
+    def _handle_event(self) -> ReadableProgramState | None:
+        event = self._events.match(self.current_state())       
+
         if not event:
-            return self.current_state()
+            return None
 
         if isinstance(event, SyscallEvent):
+            post_event = self._events.match_pair(event)
+            assert(post_event is not None)
+
+            # Context switch
+            # TODO: handle return from pre-empt
+            if post_event.tid != self._current_event_id:
+                self._thread_context[self._current_event_id] = event
+                self._current_event_id = post_event.tid
+                tid, num = self._thread_map[self._current_event_id]
+                self.context_switch(num)
+                state = self.current_state()
+                debug(f'Scheduled native TID {post_event.tid} as {tid}')
+
+                if self._current_event_id in self._thread_context:
+                    event = self._thread_context.pop(self._current_event_id)
+                elif match_event(post_event, state):
+                    event = post_event
+                    post_event = self._events.match_pair(event)
+                else:
+                    self._events.unmatch()
+                    self._step()
+                    return self.current_state()
+
             return self._handle_syscall(event, post_event)
 
         warn(f'Event handling for events of type {event.event_type} not implemented')
-        return self.current_state()
+        return None
 
     def _is_exited(self) -> bool:
         return not self._process.is_valid() or len(self._process.threads()) == 0
@@ -276,24 +298,22 @@ class GDBServerStateIterator:
             self._first_next = False
             return GDBProgramState(self._process, gdb.selected_frame(), self.arch)
 
-        event, post_event = self._events.match_pair(self.current_state())
-        if event:
-            state = self._handle_event(event, post_event)
-            if self._is_exited():
-                raise StopIteration
+        state = self._handle_event()
+        if self._is_exited():
+            raise StopIteration
 
-            return state
+        if not state:
+            # Step
+            pc = gdb.selected_frame().read_register('pc')
+            new_pc = pc
+            while pc == new_pc:  # Skip instruction chains from REP STOS etc.
+                self._step()
+                if self._is_exited():
+                    raise StopIteration
+                new_pc = gdb.selected_frame().read_register('pc')
+            state = self.current_state()
 
-        # Step
-        pc = gdb.selected_frame().read_register('pc')
-        new_pc = pc
-        while pc == new_pc:  # Skip instruction chains from REP STOS etc.
-            self._step()
-            if self._is_exited():
-                raise StopIteration
-            new_pc = gdb.selected_frame().read_register('pc')
-
-        return self.current_state()
+        return state
 
     def run_until(self, addr: int) -> ReadableProgramState:
         events_handled = 0
@@ -305,8 +325,9 @@ class GDBServerStateIterator:
                 self._first_next = events_handled == 0
                 return state
 
-            event, post_event = self._events.match_pair(self.current_state())
-            self._handle_event(event, post_event)
+            self._handle_event()
+            if self._is_exited():
+                raise Exception(f'Exited before reaching start address {hex(addr)}')
 
             event = self._events.next()
             events_handled += 1