summary refs log tree commit diff stats
path: root/python/qemu/console_socket.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/qemu/console_socket.py')
-rw-r--r--python/qemu/console_socket.py54
1 files changed, 23 insertions, 31 deletions
diff --git a/python/qemu/console_socket.py b/python/qemu/console_socket.py
index 70869fbbdc..f060d79e06 100644
--- a/python/qemu/console_socket.py
+++ b/python/qemu/console_socket.py
@@ -13,10 +13,11 @@ which can drain a socket and optionally dump the bytes to file.
 # the COPYING file in the top-level directory.
 #
 
+from collections import deque
 import socket
 import threading
-from collections import deque
 import time
+from typing import Deque, Optional
 
 
 class ConsoleSocket(socket.socket):
@@ -29,22 +30,22 @@ class ConsoleSocket(socket.socket):
     Optionally a file path can be passed in and we will also
     dump the characters to this file for debugging purposes.
     """
-    def __init__(self, address, file=None, drain=False):
-        self._recv_timeout_sec = 300
+    def __init__(self, address: str, file: Optional[str] = None,
+                 drain: bool = False):
+        self._recv_timeout_sec = 300.0
         self._sleep_time = 0.5
-        self._buffer = deque()
+        self._buffer: Deque[int] = deque()
         socket.socket.__init__(self, socket.AF_UNIX, socket.SOCK_STREAM)
         self.connect(address)
         self._logfile = None
         if file:
-            self._logfile = open(file, "w")
+            self._logfile = open(file, "bw")
         self._open = True
+        self._drain_thread = None
         if drain:
             self._drain_thread = self._thread_start()
-        else:
-            self._drain_thread = None
 
-    def _drain_fn(self):
+    def _drain_fn(self) -> None:
         """Drains the socket and runs while the socket is open."""
         while self._open:
             try:
@@ -55,7 +56,7 @@ class ConsoleSocket(socket.socket):
                 # self._open is set to False.
                 time.sleep(self._sleep_time)
 
-    def _thread_start(self):
+    def _thread_start(self) -> threading.Thread:
         """Kick off a thread to drain the socket."""
         # Configure socket to not block and timeout.
         # This allows our drain thread to not block
@@ -67,7 +68,7 @@ class ConsoleSocket(socket.socket):
         drain_thread.start()
         return drain_thread
 
-    def close(self):
+    def close(self) -> None:
         """Close the base object and wait for the thread to terminate"""
         if self._open:
             self._open = False
@@ -79,51 +80,42 @@ class ConsoleSocket(socket.socket):
                 self._logfile.close()
                 self._logfile = None
 
-    def _drain_socket(self):
+    def _drain_socket(self) -> None:
         """process arriving characters into in memory _buffer"""
         data = socket.socket.recv(self, 1)
-        # latin1 is needed since there are some chars
-        # we are receiving that cannot be encoded to utf-8
-        # such as 0xe2, 0x80, 0xA6.
-        string = data.decode("latin1")
         if self._logfile:
-            self._logfile.write("{}".format(string))
+            self._logfile.write(data)
             self._logfile.flush()
-        for c in string:
-            self._buffer.extend(c)
+        self._buffer.extend(data)
 
-    def recv(self, bufsize=1):
+    def recv(self, bufsize: int = 1, flags: int = 0) -> bytes:
         """Return chars from in memory buffer.
            Maintains the same API as socket.socket.recv.
         """
         if self._drain_thread is None:
             # Not buffering the socket, pass thru to socket.
-            return socket.socket.recv(self, bufsize)
+            return socket.socket.recv(self, bufsize, flags)
+        assert not flags, "Cannot pass flags to recv() in drained mode"
         start_time = time.time()
         while len(self._buffer) < bufsize:
             time.sleep(self._sleep_time)
             elapsed_sec = time.time() - start_time
             if elapsed_sec > self._recv_timeout_sec:
                 raise socket.timeout
-        chars = ''.join([self._buffer.popleft() for i in range(bufsize)])
-        # We choose to use latin1 to remain consistent with
-        # handle_read() and give back the same data as the user would
-        # receive if they were reading directly from the
-        # socket w/o our intervention.
-        return chars.encode("latin1")
+        return bytes((self._buffer.popleft() for i in range(bufsize)))
 
-    def setblocking(self, value):
+    def setblocking(self, value: bool) -> None:
         """When not draining we pass thru to the socket,
            since when draining we control socket blocking.
         """
         if self._drain_thread is None:
             socket.socket.setblocking(self, value)
 
-    def settimeout(self, seconds):
+    def settimeout(self, value: Optional[float]) -> None:
         """When not draining we pass thru to the socket,
            since when draining we control the timeout.
         """
-        if seconds is not None:
-            self._recv_timeout_sec = seconds
+        if value is not None:
+            self._recv_timeout_sec = value
         if self._drain_thread is None:
-            socket.socket.settimeout(self, seconds)
+            socket.socket.settimeout(self, value)