summary refs log tree commit diff stats
path: root/python/qemu/qmp/protocol.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/qemu/qmp/protocol.py')
-rw-r--r--python/qemu/qmp/protocol.py21
1 files changed, 15 insertions, 6 deletions
diff --git a/python/qemu/qmp/protocol.py b/python/qemu/qmp/protocol.py
index 22e60298d2..d534db4631 100644
--- a/python/qemu/qmp/protocol.py
+++ b/python/qemu/qmp/protocol.py
@@ -370,7 +370,7 @@ class AsyncProtocol(Generic[T]):
 
     @upper_half
     @require(Runstate.IDLE)
-    async def connect(self, address: SocketAddrT,
+    async def connect(self, address: Union[SocketAddrT, socket.socket],
                       ssl: Optional[SSLContext] = None) -> None:
         """
         Connect to the server and begin processing message queues.
@@ -615,7 +615,7 @@ class AsyncProtocol(Generic[T]):
         self.logger.debug("Connection accepted.")
 
     @upper_half
-    async def _do_connect(self, address: SocketAddrT,
+    async def _do_connect(self, address: Union[SocketAddrT, socket.socket],
                           ssl: Optional[SSLContext] = None) -> None:
         """
         Acting as the transport client, initiate a connection to a server.
@@ -634,9 +634,17 @@ class AsyncProtocol(Generic[T]):
         # otherwise yield.
         await asyncio.sleep(0)
 
-        self.logger.debug("Connecting to %s ...", address)
-
-        if isinstance(address, tuple):
+        if isinstance(address, socket.socket):
+            self.logger.debug("Connecting with existing socket: "
+                              "fd=%d, family=%r, type=%r",
+                              address.fileno(), address.family, address.type)
+            connect = asyncio.open_connection(
+                limit=self._limit,
+                ssl=ssl,
+                sock=address,
+            )
+        elif isinstance(address, tuple):
+            self.logger.debug("Connecting to %s ...", address)
             connect = asyncio.open_connection(
                 address[0],
                 address[1],
@@ -644,13 +652,14 @@ class AsyncProtocol(Generic[T]):
                 limit=self._limit,
             )
         else:
+            self.logger.debug("Connecting to file://%s ...", address)
             connect = asyncio.open_unix_connection(
                 path=address,
                 ssl=ssl,
                 limit=self._limit,
             )
-        self._reader, self._writer = await connect
 
+        self._reader, self._writer = await connect
         self.logger.debug("Connected.")
 
     @upper_half