about summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--src/focaccia/qemu/_qemu_tool.py31
1 files changed, 20 insertions, 11 deletions
diff --git a/src/focaccia/qemu/_qemu_tool.py b/src/focaccia/qemu/_qemu_tool.py
index b5008e1..7c3ccc2 100644
--- a/src/focaccia/qemu/_qemu_tool.py
+++ b/src/focaccia/qemu/_qemu_tool.py
@@ -141,6 +141,7 @@ class GDBServerStateIterator:
         gdb.execute('set pagination 0')
         gdb.execute('set sysroot')
         gdb.execute('set python print-stack full') # enable complete Python tracebacks
+        gdb.execute('set scheduler-locking on')
         gdb.execute(f'target remote {remote}')
         self._deterministic_log = deterministic_log
         self._process = gdb.selected_inferior()
@@ -178,8 +179,9 @@ class GDBServerStateIterator:
                                     skipped_events=skipped_events)
         event = self._events.match(first_state)
         
+        self._current_event_id = event.tid
         self._thread_map = {
-            event.tid: self.current_tid()
+            self._current_event_id: self.current_tid()
         }
         info(f'Synchronized at PC={hex(first_state.read_pc())} to event:\n{event}')
         debug(f'Thread mapping at this point: {hex(event.tid)}: {hex(self.current_tid())}')
@@ -189,6 +191,7 @@ class GDBServerStateIterator:
 
     def _handle_syscall(self, event: Event, post_event: Event) -> GDBProgramState:
         call = event.registers.get(self.arch.get_syscall_reg())
+        next_state = None
 
         syscall = emulated_system_calls[self.arch.archname].get(call, None)
         if syscall is not None:
@@ -214,9 +217,6 @@ class GDBServerStateIterator:
                 for hole in mem.holes:
                     data[hole.offset:hole.offset] = b'\x00' * hole.size
                 self._process.write_memory(addr, data)
-                return next_state
-
-            return next_state
 
         syscall = passthrough_system_calls[self.arch.archname].get(call, None)
         if syscall is not None:
@@ -235,13 +235,23 @@ class GDBServerStateIterator:
                 for event_tid, tid in self._thread_map.items():
                     debug(f'{hex(event_tid)}: {hex(tid)}')
 
-            return GDBProgramState(self._process, gdb.selected_frame(), self.arch)
+            next_state = GDBProgramState(self._process, gdb.selected_frame(), self.arch)
 
-        info(f'System call number {hex(call)} not replayed')
-        self._step()
-        if self._is_exited():
-            raise StopIteration
-        return GDBProgramState(self._process, gdb.selected_frame(), self.arch)
+        if not next_state:
+            info(f'System call number {hex(call)} not replayed')
+            self._step()
+            if self._is_exited():
+                raise StopIteration
+            next_state = GDBProgramState(self._process, gdb.selected_frame(), self.arch)
+
+        # Context switch
+        if post_event.tid != self._current_event_id:
+            self._current_event_id = post_event.tid
+            tid = self._thread_map[self._current_event_id]
+            self.context_switch(tid)
+            debug(f'Scheduled native TID {post_event.tid} as {tid}')
+
+        return next_state
 
     def _handle_event(self, event: Event | None, post_event: Event | None) -> GDBProgramState:
         if not event:
@@ -327,7 +337,6 @@ class GDBServerStateIterator:
 
     def context_switch(self, thread_number: int) -> None:
         gdb.execute(f'thread {thread_number}')
-        self._thread_num = thread_number
 
     def get_sections(self) -> list[MemoryMapping]:
         mappings = []