summary refs log tree commit diff stats
path: root/python/qemu
diff options
context:
space:
mode:
Diffstat (limited to 'python/qemu')
-rw-r--r--python/qemu/machine/machine.py17
-rw-r--r--python/qemu/qmp/legacy.py26
-rw-r--r--python/qemu/qmp/protocol.py45
3 files changed, 41 insertions, 47 deletions
diff --git a/python/qemu/machine/machine.py b/python/qemu/machine/machine.py
index e57c254484..c16a0b6fed 100644
--- a/python/qemu/machine/machine.py
+++ b/python/qemu/machine/machine.py
@@ -337,18 +337,18 @@ class QEMUMachine:
             self._remove_files.append(self._console_address)
 
         if self._qmp_set:
-            monitor_address = None
-            sock = None
             if self._monitor_address is None:
                 self._sock_pair = socket.socketpair()
                 sock = self._sock_pair[1]
             if isinstance(self._monitor_address, str):
                 self._remove_files.append(self._monitor_address)
-                monitor_address = self._monitor_address
+
+            sock_or_addr = self._monitor_address or sock
+            assert sock_or_addr is not None
+
             self._qmp_connection = QEMUMonitorProtocol(
-                address=monitor_address,
-                sock=sock,
-                server=True,
+                sock_or_addr,
+                server=bool(self._monitor_address),
                 nickname=self._name
             )
 
@@ -370,7 +370,10 @@ class QEMUMachine:
         if self._sock_pair:
             self._sock_pair[0].close()
         if self._qmp_connection:
-            self._qmp.accept(self._qmp_timer)
+            if self._sock_pair:
+                self._qmp.connect()
+            else:
+                self._qmp.accept(self._qmp_timer)
 
     def _close_qemu_log_file(self) -> None:
         if self._qemu_log_file is not None:
diff --git a/python/qemu/qmp/legacy.py b/python/qemu/qmp/legacy.py
index 8b09ee7dbb..e1e9383978 100644
--- a/python/qemu/qmp/legacy.py
+++ b/python/qemu/qmp/legacy.py
@@ -68,34 +68,31 @@ class QEMUMonitorProtocol:
     Provide an API to connect to QEMU via QEMU Monitor Protocol (QMP)
     and then allow to handle commands and events.
 
-    :param address:  QEMU address, can be either a unix socket path (string)
-                     or a tuple in the form ( address, port ) for a TCP
-                     connection or None
-    :param sock:     a socket or None
+    :param address:  QEMU address, can be a unix socket path (string), a tuple
+                     in the form ( address, port ) for a TCP connection, or an
+                     existing `socket.socket` object.
     :param server:   Act as the socket server. (See 'accept')
+                     Not applicable when passing a socket directly.
     :param nickname: Optional nickname used for logging.
     """
 
     def __init__(self,
-                 address: Optional[SocketAddrT] = None,
-                 sock: Optional[socket.socket] = None,
+                 address: Union[SocketAddrT, socket.socket],
                  server: bool = False,
                  nickname: Optional[str] = None):
 
-        assert address or sock
+        if server and isinstance(address, socket.socket):
+            raise ValueError(
+                "server argument should be False when passing a socket")
+
         self._qmp = QMPClient(nickname)
         self._aloop = asyncio.get_event_loop()
         self._address = address
-        self._sock = sock
         self._timeout: Optional[float] = None
 
         if server:
-            if sock:
-                assert self._sock is not None
-                self._sync(self._qmp.open_with_socket(self._sock))
-            else:
-                assert self._address is not None
-                self._sync(self._qmp.start_server(self._address))
+            assert not isinstance(self._address, socket.socket)
+            self._sync(self._qmp.start_server(self._address))
 
     _T = TypeVar('_T')
 
@@ -150,7 +147,6 @@ class QEMUMonitorProtocol:
         :return: QMP greeting dict, or None if negotiate is false
         :raise ConnectError: on connection errors
         """
-        assert self._address is not None
         self._qmp.await_greeting = negotiate
         self._qmp.negotiate = negotiate
 
diff --git a/python/qemu/qmp/protocol.py b/python/qemu/qmp/protocol.py
index 22e60298d2..753182131f 100644
--- a/python/qemu/qmp/protocol.py
+++ b/python/qemu/qmp/protocol.py
@@ -299,19 +299,6 @@ class AsyncProtocol(Generic[T]):
 
     @upper_half
     @require(Runstate.IDLE)
-    async def open_with_socket(self, sock: socket.socket) -> None:
-        """
-        Start connection with given socket.
-
-        :param sock: A socket.
-
-        :raise StateError: When the `Runstate` is not `IDLE`.
-        """
-        self._reader, self._writer = await asyncio.open_connection(sock=sock)
-        self._set_state(Runstate.CONNECTING)
-
-    @upper_half
-    @require(Runstate.IDLE)
     async def start_server(self, address: SocketAddrT,
                            ssl: Optional[SSLContext] = None) -> None:
         """
@@ -357,12 +344,11 @@ class AsyncProtocol(Generic[T]):
             protocol-level failure occurs while establishing a new
             session, the wrapped error may also be an `QMPError`.
         """
-        if not self._reader:
-            if self._accepted is None:
-                raise QMPError("Cannot call accept() before start_server().")
-            await self._session_guard(
-                self._do_accept(),
-                'Failed to establish connection')
+        if self._accepted is None:
+            raise QMPError("Cannot call accept() before start_server().")
+        await self._session_guard(
+            self._do_accept(),
+            'Failed to establish connection')
         await self._session_guard(
             self._establish_session(),
             'Failed to establish session')
@@ -370,7 +356,7 @@ class AsyncProtocol(Generic[T]):
 
     @upper_half
     @require(Runstate.IDLE)
-    async def connect(self, address: SocketAddrT,
+    async def connect(self, address: Union[SocketAddrT, socket.socket],
                       ssl: Optional[SSLContext] = None) -> None:
         """
         Connect to the server and begin processing message queues.
@@ -615,7 +601,7 @@ class AsyncProtocol(Generic[T]):
         self.logger.debug("Connection accepted.")
 
     @upper_half
-    async def _do_connect(self, address: SocketAddrT,
+    async def _do_connect(self, address: Union[SocketAddrT, socket.socket],
                           ssl: Optional[SSLContext] = None) -> None:
         """
         Acting as the transport client, initiate a connection to a server.
@@ -634,9 +620,17 @@ class AsyncProtocol(Generic[T]):
         # otherwise yield.
         await asyncio.sleep(0)
 
-        self.logger.debug("Connecting to %s ...", address)
-
-        if isinstance(address, tuple):
+        if isinstance(address, socket.socket):
+            self.logger.debug("Connecting with existing socket: "
+                              "fd=%d, family=%r, type=%r",
+                              address.fileno(), address.family, address.type)
+            connect = asyncio.open_connection(
+                limit=self._limit,
+                ssl=ssl,
+                sock=address,
+            )
+        elif isinstance(address, tuple):
+            self.logger.debug("Connecting to %s ...", address)
             connect = asyncio.open_connection(
                 address[0],
                 address[1],
@@ -644,13 +638,14 @@ class AsyncProtocol(Generic[T]):
                 limit=self._limit,
             )
         else:
+            self.logger.debug("Connecting to file://%s ...", address)
             connect = asyncio.open_unix_connection(
                 path=address,
                 ssl=ssl,
                 limit=self._limit,
             )
-        self._reader, self._writer = await connect
 
+        self._reader, self._writer = await connect
         self.logger.debug("Connected.")
 
     @upper_half