about summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--run.py329
1 files changed, 148 insertions, 181 deletions
diff --git a/run.py b/run.py
index bae909d..8a5bac2 100644
--- a/run.py
+++ b/run.py
@@ -6,220 +6,158 @@ import select
 
 import ptrace.debugger
 from ptrace.debugger import (
+    ProcessExit,
     ProcessEvent,
     ProcessSignal,
-    ProcessExit,
     NewProcessEvent,
+    ProcessExecution,
 )
 
-
-# ----------------------------------------------------------------------
-# Configuration
-# ----------------------------------------------------------------------
-
 # If scheduler does not provide input within this time (seconds),
 # continue running the last chosen thread.
 SCHED_TIMEOUT = 0
 
-# ----------------------------------------------------------------------
-# Scheduler (non-blocking)
-# ----------------------------------------------------------------------
+class Scheduler:
+    def __init__(self, sched_socket_path: str = '/tmp/memcached_scheduler.sock'):
+        self.debugger = ptrace.debugger.PtraceDebugger()
+        self.debugger.traceClone()
+        self.debugger.traceFork()
+        self.debugger.traceExec()
 
-def schedule_next_nonblocking(sock, processes, current_proc):
-    """
-    processes: dict[tid] -> PtraceProcess
-    current_proc: PtraceProcess or None
-    """
-    timeout = SCHED_TIMEOUT if SCHED_TIMEOUT > 0 else 0
+        self._first_clone_ignored = False
+        self._ignored_tid = None
 
-    r, _, _ = select.select([sock], [], [], timeout)
-    if not r:
-        return current_proc  # no input → continue with current
+        if os.path.exists(sched_socket_path):
+            os.unlink(sched_socket_path)
 
-    data = sock.recv(64)
-    if not data:
-        return current_proc
+        self.srv = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+        self.srv.bind(sched_socket_path)
+        self.srv.listen(1)
 
-    try:
-        tid = int(data.strip())
-    except ValueError:
-        print(f"Scheduler: invalid data {data!r}")
-        return current_proc
+        print(f"Waiting for scheduler connection on {sched_socket_path}")
+        self.conn, _ = self.srv.accept()
+        print("Scheduler connected")
 
-    proc = processes.get(tid)
-    if proc is not None:
-        print(f"Scheduler picked TID {tid}")
-        return proc
 
-    print(f"Scheduler sent inactive TID {tid}, ignoring")
-    return current_proc
+    def _next(self, processes, current_proc):
+        """
+        processes: dict[tid] -> PtraceProcess
+        current_proc: PtraceProcess or None
+        """
+        timeout = SCHED_TIMEOUT if SCHED_TIMEOUT > 0 else 0
 
+        r, _, _ = select.select([self.conn], [], [], timeout)
+        if not r:
+            return current_proc  # no input → continue with current
 
-# ----------------------------------------------------------------------
-# Main tracing logic
-# ----------------------------------------------------------------------
+        data = self.conn.recv(8)
+        if not data:
+            return current_proc
 
-def trace(pid, sched_socket_path):
-    debugger = ptrace.debugger.PtraceDebugger()
-    debugger.traceClone()
-    debugger.traceFork()
-    debugger.traceExec()
-
-    print(f"Attach process {pid}")
-    proc0 = debugger.addProcess(pid, False)
-
-    # ------------------------------------------------------------------
-    # Create scheduler socket
-    # ------------------------------------------------------------------
-    if os.path.exists(sched_socket_path):
-        os.unlink(sched_socket_path)
-
-    srv = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
-    srv.bind(sched_socket_path)
-    srv.listen(1)
-
-    print(f"Waiting for scheduler connection on {sched_socket_path}")
-    conn, _ = srv.accept()
-    print("Scheduler connected")
-
-    # ------------------------------------------------------------------
-    # Initial state
-    # ------------------------------------------------------------------
-    current_proc = proc0
-    first_clone_ignored = False
-    ignored_tid = None
-
-    # Arm the first process: run until its first event/syscall
-    current_proc.syscall()
-
-    # ------------------------------------------------------------------
-    # Global event loop
-    # ------------------------------------------------------------------
-    while debugger.list:
         try:
-            event: ProcessEvent = debugger.waitSyscall()
-
-        # --------------------------------------------------------------
-        # New process / thread via clone/fork/vfork
-        # --------------------------------------------------------------
-        except NewProcessEvent as ev:
-            child = ev.process
-            parent = child.parent
-            child_tid = child.pid
+            tid = int.from_bytes(data, byteorder='little', signed=False)
+        except ValueError:
+            print(f"Scheduler: invalid data {data!r}")
+            return current_proc
 
-            if not first_clone_ignored:
-                # FIRST clone is ignored
-                first_clone_ignored = True
-                ignored_tid = child_tid
+        proc = processes.get(tid)
+        if proc is not None:
+            print(f"Scheduler picked TID {tid}")
+            return proc
 
-                print(f"First clone: created TID {child_tid} — IGNORING it")
-
-                # Detach ignored child so it runs untraced
-                try:
-                    child.detach()
-                except Exception:
-                    pass
+        print(f"Scheduler sent inactive TID {tid}, ignoring")
+        return current_proc
 
-                # Remove from debugger
-                try:
-                    debugger.deleteProcess(child)
-                except Exception:
-                    pass
+    def _handle_signal(self, event: ProcessSignal):
+        proc: PtraceProcess = event.process
+        try:
+            proc.syscall(event.signum)
+        except Exception as e:
+            print(f"Error arming TID {proc.pid} after signal {event.signum}: {e}")
+            try:
+                self.debugger.deleteProcess(proc)
+            except Exception:
+                pass
 
-                # Resume parent so clone() completes
-                try:
-                    parent.syscall()
-                except Exception as e:
-                    print(f"Error resuming parent {parent.pid} after ignored clone: {e}")
-            else:
-                # LATER clones are traced
-                print(f"New traced thread {child_tid} (parent {parent.pid})")
-
-                # Arm both child and parent again
-                try:
-                    child.syscall()
-                except Exception as e:
-                    print(f"Error arming child {child_tid}: {e}")
-                    try:
-                        debugger.deleteProcess(child)
-                    except Exception:
-                        pass
+    def _handle_clone(self, event: NewProcessEvent):
+        child = event.process
+        parent = child.parent
+        child_tid = child.pid
 
-                try:
-                    parent.syscall()
-                except Exception as e:
-                    print(f"Error arming parent {parent.pid}: {e}")
-                    try:
-                        debugger.deleteProcess(parent)
-                    except Exception:
-                        pass
-
-            continue
-
-        # --------------------------------------------------------------
-        # Signal delivered to a traced task
-        # --------------------------------------------------------------
-        except ProcessSignal as ev:
-            proc: PtraceProcess = ev.process
-            try:
-                proc.syscall(ev.signum)
-            except Exception as e:
-                print(f"Error arming TID {proc.pid} after signal {ev.signum}: {e}")
-                try:
-                    debugger.deleteProcess(proc)
-                except Exception:
-                    pass
-            continue
+        if not self._first_clone_ignored:
+            # FIRST clone is ignored
+            self._first_clone_ignored = True
+            self._ignored_tid = child_tid
 
-        # --------------------------------------------------------------
-        # A traced task exited
-        # --------------------------------------------------------------
-        except ProcessExit as ev:
-            dead_proc: PtraceProcess = ev.process
-            tid: int = dead_proc.pid
-            print(f"TID {tid} exited (exitcode={ev.exitcode})")
+            print(f"First clone: created TID {child_tid} — IGNORING it")
 
+            # Detach ignored child so it runs untraced
             try:
-                dead_proc.detach()
+                child.detach()
             except Exception:
                 pass
 
+            # Remove from debugger
             try:
-                debugger.deleteProcess(dead_proc)
+                self.debugger.deleteProcess(child)
             except Exception:
                 pass
 
-            if not debugger.list:
-                break
+            # Resume parent so clone() completes
+            try:
+                parent.syscall()
+            except Exception as e:
+                print(f"Error resuming parent {parent.pid} after ignored clone: {e}")
+        else:
+            # LATER clones are traced
+            print(f"New traced thread {child_tid} (parent {parent.pid})")
+
+            # Arm both child and parent again
+            try:
+                child.syscall()
+            except Exception as e:
+                print(f"Error arming child {child_tid}: {e}")
+                try:
+                    self.debugger.deleteProcess(child)
+                except Exception:
+                    pass
+
+            try:
+                parent.syscall()
+            except Exception as e:
+                print(f"Error arming parent {parent.pid}: {e}")
+                try:
+                    self.debugger.deleteProcess(parent)
+                except Exception:
+                    pass
+
+    def _handle_exit(self, event: ProcessExit):
+        dead_proc: PtraceProcess = event.process
+        tid: int = dead_proc.pid
+        print(f"TID {tid} exited (exitcode={event.exitcode})")
 
-            # We do NOT arm anything new here; others may already be armed.
-            continue
+        dead_proc.detach()
+        self.debugger.deleteProcess(dead_proc)
 
-        # --------------------------------------------------------------
-        # NORMAL SYSCALL STOP
-        # --------------------------------------------------------------
+    def _handle_syscall(self, event: ProcessExecution):
         proc = event.process
         tid = proc.pid
 
-        if tid == ignored_tid:
-            # Should not happen; just log and continue
-            print(f"WARNING: ignored TID {tid} hit a syscall-stop")
-        else:
-            try:
-                ip = proc.getInstrPointer()
-                print(f"TID {tid} syscall-stop at {hex(ip)}")
-            except Exception as e:
-                print(f"Error reading IP for TID {tid}: {e}")
+        try:
+            ip = proc.getInstrPointer()
+            print(f"TID {tid} syscall-stop at {hex(ip)}")
+        except Exception as e:
+            print(f"Error reading IP for TID {tid}: {e}")
 
         # Build fresh pid->process map from debugger.list
-        processes = {p.pid: p for p in debugger.list}
+        processes = {p.pid: p for p in self.debugger.list}
 
         # Scheduler decides what to run next
-        current_proc = schedule_next_nonblocking(conn, processes, proc)
-        if current_proc is None or current_proc not in debugger.list:
-            if not debugger.list:
-                break
-            current_proc = debugger.list[0]
+        current_proc = self._next(processes, proc)
+        if current_proc is None or current_proc not in self.debugger.list:
+            if self.is_exited():
+                return
+            current_proc = self.debugger.list[0]
 
         # Resume chosen thread
         try:
@@ -227,14 +165,44 @@ def trace(pid, sched_socket_path):
         except Exception as e:
             print(f"Error arming TID {current_proc.pid}: {e}")
             try:
-                debugger.deleteProcess(current_proc)
+                self.debugger.deleteProcess(current_proc)
             except Exception:
                 pass
-            # Don't immediately re-arm; next loop will pick another if any.
 
-    conn.close()
-    srv.close()
-    debugger.quit()
+    def is_exited(self):
+        return len(self.debugger.list) == 0
+
+    def schedule(self, pid: int):
+        print(f"Attach process {pid}")
+        proc0 = self.debugger.addProcess(pid, False)
+
+        # ------------------------------------------------------------------
+        # Initial state
+        # ------------------------------------------------------------------
+        current_proc = proc0
+
+        # Arm the first process: run until its first event/syscall
+        current_proc.syscall()
+
+        while not self.is_exited():
+            try:
+                event: ProcessEvent = self.debugger.waitSyscall()
+            except NewProcessEvent as event:
+                self._handle_clone(event)
+                continue
+            except ProcessSignal as event:
+                self._handle_signal(event)
+                continue
+            except ProcessExit as event:
+                self._handle_exit(event)
+                continue
+
+            self._handle_syscall(event)
+
+    def __del__(self):
+        self.conn.close()
+        self.srv.close()
+        self.debugger.quit()
 
 
 # ----------------------------------------------------------------------
@@ -260,15 +228,14 @@ if __name__ == "__main__":
         "-vv"
     ]
 
-    sched_path = "/tmp/memcached_scheduler.sock"
-
     proc = subprocess.Popen(qemu, env=env)
     try:
-        trace(proc.pid, sched_path)
+        scheduler = Scheduler()
+        scheduler.schedule(proc.pid)
     except Exception as e:
-        print(f"Got exception: {e}")
+        print(f"Scheduling failed: {e}")
         proc.kill()
-        exit(2)
+        raise
 
     exit(0)