summary refs log tree commit diff stats
path: root/python/qemu/aqmp
diff options
context:
space:
mode:
Diffstat (limited to 'python/qemu/aqmp')
-rw-r--r--python/qemu/aqmp/legacy.py7
-rw-r--r--python/qemu/aqmp/protocol.py381
2 files changed, 238 insertions, 150 deletions
diff --git a/python/qemu/aqmp/legacy.py b/python/qemu/aqmp/legacy.py
index 6baa5f3409..46026e9fdc 100644
--- a/python/qemu/aqmp/legacy.py
+++ b/python/qemu/aqmp/legacy.py
@@ -57,7 +57,7 @@ class QEMUMonitorProtocol(qemu.qmp.QEMUMonitorProtocol):
         self._timeout: Optional[float] = None
 
         if server:
-            self._aqmp._bind_hack(address)  # pylint: disable=protected-access
+            self._sync(self._aqmp.start_server(self._address))
 
     _T = TypeVar('_T')
 
@@ -90,10 +90,7 @@ class QEMUMonitorProtocol(qemu.qmp.QEMUMonitorProtocol):
         self._aqmp.await_greeting = True
         self._aqmp.negotiate = True
 
-        self._sync(
-            self._aqmp.accept(self._address),
-            timeout
-        )
+        self._sync(self._aqmp.accept(), timeout)
 
         ret = self._get_greeting()
         assert ret is not None
diff --git a/python/qemu/aqmp/protocol.py b/python/qemu/aqmp/protocol.py
index 33358f5cd7..36fae57f27 100644
--- a/python/qemu/aqmp/protocol.py
+++ b/python/qemu/aqmp/protocol.py
@@ -10,12 +10,14 @@ In this package, it is used as the implementation for the `QMPClient`
 class.
 """
 
+# It's all the docstrings ... ! It's long for a good reason ^_^;
+# pylint: disable=too-many-lines
+
 import asyncio
 from asyncio import StreamReader, StreamWriter
 from enum import Enum
 from functools import wraps
 import logging
-import socket
 from ssl import SSLContext
 from typing import (
     Any,
@@ -239,8 +241,9 @@ class AsyncProtocol(Generic[T]):
         self._runstate = Runstate.IDLE
         self._runstate_changed: Optional[asyncio.Event] = None
 
-        # Workaround for bind()
-        self._sock: Optional[socket.socket] = None
+        # Server state for start_server() and _incoming()
+        self._server: Optional[asyncio.AbstractServer] = None
+        self._accepted: Optional[asyncio.Event] = None
 
     def __repr__(self) -> str:
         cls_name = type(self).__name__
@@ -265,21 +268,90 @@ class AsyncProtocol(Generic[T]):
 
     @upper_half
     @require(Runstate.IDLE)
-    async def accept(self, address: SocketAddrT,
-                     ssl: Optional[SSLContext] = None) -> None:
+    async def start_server_and_accept(
+            self, address: SocketAddrT,
+            ssl: Optional[SSLContext] = None
+    ) -> None:
         """
         Accept a connection and begin processing message queues.
 
         If this call fails, `runstate` is guaranteed to be set back to `IDLE`.
+        This method is precisely equivalent to calling `start_server()`
+        followed by `accept()`.
+
+        :param address:
+            Address to listen on; UNIX socket path or TCP address/port.
+        :param ssl: SSL context to use, if any.
+
+        :raise StateError: When the `Runstate` is not `IDLE`.
+        :raise ConnectError:
+            When a connection or session cannot be established.
+
+            This exception will wrap a more concrete one. In most cases,
+            the wrapped exception will be `OSError` or `EOFError`. If a
+            protocol-level failure occurs while establishing a new
+            session, the wrapped error may also be an `QMPError`.
+        """
+        await self.start_server(address, ssl)
+        await self.accept()
+        assert self.runstate == Runstate.RUNNING
+
+    @upper_half
+    @require(Runstate.IDLE)
+    async def start_server(self, address: SocketAddrT,
+                           ssl: Optional[SSLContext] = None) -> None:
+        """
+        Start listening for an incoming connection, but do not wait for a peer.
+
+        This method starts listening for an incoming connection, but
+        does not block waiting for a peer. This call will return
+        immediately after binding and listening on a socket. A later
+        call to `accept()` must be made in order to finalize the
+        incoming connection.
 
         :param address:
-            Address to listen to; UNIX socket path or TCP address/port.
+            Address to listen on; UNIX socket path or TCP address/port.
         :param ssl: SSL context to use, if any.
 
         :raise StateError: When the `Runstate` is not `IDLE`.
-        :raise ConnectError: If a connection could not be accepted.
+        :raise ConnectError:
+            When the server could not start listening on this address.
+
+            This exception will wrap a more concrete one. In most cases,
+            the wrapped exception will be `OSError`.
+        """
+        await self._session_guard(
+            self._do_start_server(address, ssl),
+            'Failed to establish connection')
+        assert self.runstate == Runstate.CONNECTING
+
+    @upper_half
+    @require(Runstate.CONNECTING)
+    async def accept(self) -> None:
+        """
+        Accept an incoming connection and begin processing message queues.
+
+        If this call fails, `runstate` is guaranteed to be set back to `IDLE`.
+
+        :raise StateError: When the `Runstate` is not `CONNECTING`.
+        :raise QMPError: When `start_server()` was not called yet.
+        :raise ConnectError:
+            When a connection or session cannot be established.
+
+            This exception will wrap a more concrete one. In most cases,
+            the wrapped exception will be `OSError` or `EOFError`. If a
+            protocol-level failure occurs while establishing a new
+            session, the wrapped error may also be an `QMPError`.
         """
-        await self._new_session(address, ssl, accept=True)
+        if self._accepted is None:
+            raise QMPError("Cannot call accept() before start_server().")
+        await self._session_guard(
+            self._do_accept(),
+            'Failed to establish connection')
+        await self._session_guard(
+            self._establish_session(),
+            'Failed to establish session')
+        assert self.runstate == Runstate.RUNNING
 
     @upper_half
     @require(Runstate.IDLE)
@@ -295,9 +367,21 @@ class AsyncProtocol(Generic[T]):
         :param ssl: SSL context to use, if any.
 
         :raise StateError: When the `Runstate` is not `IDLE`.
-        :raise ConnectError: If a connection cannot be made to the server.
+        :raise ConnectError:
+            When a connection or session cannot be established.
+
+            This exception will wrap a more concrete one. In most cases,
+            the wrapped exception will be `OSError` or `EOFError`. If a
+            protocol-level failure occurs while establishing a new
+            session, the wrapped error may also be an `QMPError`.
         """
-        await self._new_session(address, ssl)
+        await self._session_guard(
+            self._do_connect(address, ssl),
+            'Failed to establish connection')
+        await self._session_guard(
+            self._establish_session(),
+            'Failed to establish session')
+        assert self.runstate == Runstate.RUNNING
 
     @upper_half
     async def disconnect(self) -> None:
@@ -317,153 +401,146 @@ class AsyncProtocol(Generic[T]):
     # Section: Session machinery
     # --------------------------
 
-    @property
-    def _runstate_event(self) -> asyncio.Event:
-        # asyncio.Event() objects should not be created prior to entrance into
-        # an event loop, so we can ensure we create it in the correct context.
-        # Create it on-demand *only* at the behest of an 'async def' method.
-        if not self._runstate_changed:
-            self._runstate_changed = asyncio.Event()
-        return self._runstate_changed
-
-    @upper_half
-    @bottom_half
-    def _set_state(self, state: Runstate) -> None:
-        """
-        Change the `Runstate` of the protocol connection.
-
-        Signals the `runstate_changed` event.
-        """
-        if state == self._runstate:
-            return
-
-        self.logger.debug("Transitioning from '%s' to '%s'.",
-                          str(self._runstate), str(state))
-        self._runstate = state
-        self._runstate_event.set()
-        self._runstate_event.clear()
-
-    @upper_half
-    async def _new_session(self,
-                           address: SocketAddrT,
-                           ssl: Optional[SSLContext] = None,
-                           accept: bool = False) -> None:
+    async def _session_guard(self, coro: Awaitable[None], emsg: str) -> None:
         """
-        Establish a new connection and initialize the session.
+        Async guard function used to roll back to `IDLE` on any error.
 
-        Connect or accept a new connection, then begin the protocol
-        session machinery. If this call fails, `runstate` is guaranteed
-        to be set back to `IDLE`.
+        On any Exception, the state machine will be reset back to
+        `IDLE`. Most Exceptions will be wrapped with `ConnectError`, but
+        `BaseException` events will be left alone (This includes
+        asyncio.CancelledError, even prior to Python 3.8).
 
-        :param address:
-            Address to connect to/listen on;
-            UNIX socket path or TCP address/port.
-        :param ssl: SSL context to use, if any.
-        :param accept: Accept a connection instead of connecting when `True`.
+        :param error_message:
+            Human-readable string describing what connection phase failed.
 
+        :raise BaseException:
+            When `BaseException` occurs in the guarded block.
         :raise ConnectError:
-            When a connection or session cannot be established.
-
-            This exception will wrap a more concrete one. In most cases,
-            the wrapped exception will be `OSError` or `EOFError`. If a
-            protocol-level failure occurs while establishing a new
-            session, the wrapped error may also be an `QMPError`.
+            When any other error is encountered in the guarded block.
         """
-        assert self.runstate == Runstate.IDLE
-
+        # Note: After Python 3.6 support is removed, this should be an
+        # @asynccontextmanager instead of accepting a callback.
         try:
-            phase = "connection"
-            await self._establish_connection(address, ssl, accept)
-
-            phase = "session"
-            await self._establish_session()
-
+            await coro
         except BaseException as err:
-            emsg = f"Failed to establish {phase}"
             self.logger.error("%s: %s", emsg, exception_summary(err))
             self.logger.debug("%s:\n%s\n", emsg, pretty_traceback())
             try:
-                # Reset from CONNECTING back to IDLE.
+                # Reset the runstate back to IDLE.
                 await self.disconnect()
             except:
-                emsg = "Unexpected bottom half exception"
+                # We don't expect any Exceptions from the disconnect function
+                # here, because we failed to connect in the first place.
+                # The disconnect() function is intended to perform
+                # only cannot-fail cleanup here, but you never know.
+                emsg = (
+                    "Unexpected bottom half exception. "
+                    "This is a bug in the QMP library. "
+                    "Please report it to <qemu-devel@nongnu.org> and "
+                    "CC: John Snow <jsnow@redhat.com>."
+                )
                 self.logger.critical("%s:\n%s\n", emsg, pretty_traceback())
                 raise
 
+            # CancelledError is an Exception with special semantic meaning;
+            # We do NOT want to wrap it up under ConnectError.
             # NB: CancelledError is not a BaseException before Python 3.8
             if isinstance(err, asyncio.CancelledError):
                 raise
 
+            # Any other kind of error can be treated as some kind of connection
+            # failure broadly. Inspect the 'exc' field to explore the root
+            # cause in greater detail.
             if isinstance(err, Exception):
                 raise ConnectError(emsg, err) from err
 
             # Raise BaseExceptions un-wrapped, they're more important.
             raise
 
-        assert self.runstate == Runstate.RUNNING
+    @property
+    def _runstate_event(self) -> asyncio.Event:
+        # asyncio.Event() objects should not be created prior to entrance into
+        # an event loop, so we can ensure we create it in the correct context.
+        # Create it on-demand *only* at the behest of an 'async def' method.
+        if not self._runstate_changed:
+            self._runstate_changed = asyncio.Event()
+        return self._runstate_changed
 
     @upper_half
-    async def _establish_connection(
-            self,
-            address: SocketAddrT,
-            ssl: Optional[SSLContext] = None,
-            accept: bool = False
-    ) -> None:
+    @bottom_half
+    def _set_state(self, state: Runstate) -> None:
         """
-        Establish a new connection.
+        Change the `Runstate` of the protocol connection.
 
-        :param address:
-            Address to connect to/listen on;
-            UNIX socket path or TCP address/port.
-        :param ssl: SSL context to use, if any.
-        :param accept: Accept a connection instead of connecting when `True`.
+        Signals the `runstate_changed` event.
         """
-        assert self.runstate == Runstate.IDLE
-        self._set_state(Runstate.CONNECTING)
-
-        # Allow runstate watchers to witness 'CONNECTING' state; some
-        # failures in the streaming layer are synchronous and will not
-        # otherwise yield.
-        await asyncio.sleep(0)
+        if state == self._runstate:
+            return
 
-        if accept:
-            await self._do_accept(address, ssl)
-        else:
-            await self._do_connect(address, ssl)
+        self.logger.debug("Transitioning from '%s' to '%s'.",
+                          str(self._runstate), str(state))
+        self._runstate = state
+        self._runstate_event.set()
+        self._runstate_event.clear()
 
-    def _bind_hack(self, address: Union[str, Tuple[str, int]]) -> None:
+    @bottom_half
+    async def _stop_server(self) -> None:
+        """
+        Stop listening for / accepting new incoming connections.
         """
-        Used to create a socket in advance of accept().
+        if self._server is None:
+            return
 
-        This is a workaround to ensure that we can guarantee timing of
-        precisely when a socket exists to avoid a connection attempt
-        bouncing off of nothing.
+        try:
+            self.logger.debug("Stopping server.")
+            self._server.close()
+            await self._server.wait_closed()
+            self.logger.debug("Server stopped.")
+        finally:
+            self._server = None
 
-        Python 3.7+ adds a feature to separate the server creation and
-        listening phases instead, and should be used instead of this
-        hack.
+    @bottom_half  # However, it does not run from the R/W tasks.
+    async def _incoming(self,
+                        reader: asyncio.StreamReader,
+                        writer: asyncio.StreamWriter) -> None:
         """
-        if isinstance(address, tuple):
-            family = socket.AF_INET
-        else:
-            family = socket.AF_UNIX
+        Accept an incoming connection and signal the upper_half.
 
-        sock = socket.socket(family, socket.SOCK_STREAM)
-        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+        This method does the minimum necessary to accept a single
+        incoming connection. It signals back to the upper_half ASAP so
+        that any errors during session initialization can occur
+        naturally in the caller's stack.
 
-        try:
-            sock.bind(address)
-        except:
-            sock.close()
-            raise
+        :param reader: Incoming `asyncio.StreamReader`
+        :param writer: Incoming `asyncio.StreamWriter`
+        """
+        peer = writer.get_extra_info('peername', 'Unknown peer')
+        self.logger.debug("Incoming connection from %s", peer)
+
+        if self._reader or self._writer:
+            # Sadly, we can have more than one pending connection
+            # because of https://bugs.python.org/issue46715
+            # Close any extra connections we don't actually want.
+            self.logger.warning("Extraneous connection inadvertently accepted")
+            writer.close()
+            return
 
-        self._sock = sock
+        # A connection has been accepted; stop listening for new ones.
+        assert self._accepted is not None
+        await self._stop_server()
+        self._reader, self._writer = (reader, writer)
+        self._accepted.set()
 
     @upper_half
-    async def _do_accept(self, address: SocketAddrT,
-                         ssl: Optional[SSLContext] = None) -> None:
+    async def _do_start_server(self, address: SocketAddrT,
+                               ssl: Optional[SSLContext] = None) -> None:
         """
-        Acting as the transport server, accept a single connection.
+        Start listening for an incoming connection, but do not wait for a peer.
+
+        This method starts listening for an incoming connection, but does not
+        block waiting for a peer. This call will return immediately after
+        binding and listening to a socket. A later call to accept() must be
+        made in order to finalize the incoming connection.
 
         :param address:
             Address to listen on; UNIX socket path or TCP address/port.
@@ -471,52 +548,54 @@ class AsyncProtocol(Generic[T]):
 
         :raise OSError: For stream-related errors.
         """
-        self.logger.debug("Awaiting connection on %s ...", address)
-        connected = asyncio.Event()
-        server: Optional[asyncio.AbstractServer] = None
-
-        async def _client_connected_cb(reader: asyncio.StreamReader,
-                                       writer: asyncio.StreamWriter) -> None:
-            """Used to accept a single incoming connection, see below."""
-            nonlocal server
-            nonlocal connected
-
-            # A connection has been accepted; stop listening for new ones.
-            assert server is not None
-            server.close()
-            await server.wait_closed()
-            server = None
-
-            # Register this client as being connected
-            self._reader, self._writer = (reader, writer)
+        assert self.runstate == Runstate.IDLE
+        self._set_state(Runstate.CONNECTING)
 
-            # Signal back: We've accepted a client!
-            connected.set()
+        self.logger.debug("Awaiting connection on %s ...", address)
+        self._accepted = asyncio.Event()
 
         if isinstance(address, tuple):
             coro = asyncio.start_server(
-                _client_connected_cb,
-                host=None if self._sock else address[0],
-                port=None if self._sock else address[1],
+                self._incoming,
+                host=address[0],
+                port=address[1],
                 ssl=ssl,
                 backlog=1,
                 limit=self._limit,
-                sock=self._sock,
             )
         else:
             coro = asyncio.start_unix_server(
-                _client_connected_cb,
-                path=None if self._sock else address,
+                self._incoming,
+                path=address,
                 ssl=ssl,
                 backlog=1,
                 limit=self._limit,
-                sock=self._sock,
             )
 
-        server = await coro     # Starts listening
-        await connected.wait()  # Waits for the callback to fire (and finish)
-        assert server is None
-        self._sock = None
+        # Allow runstate watchers to witness 'CONNECTING' state; some
+        # failures in the streaming layer are synchronous and will not
+        # otherwise yield.
+        await asyncio.sleep(0)
+
+        # This will start the server (bind(2), listen(2)). It will also
+        # call accept(2) if we yield, but we don't block on that here.
+        self._server = await coro
+        self.logger.debug("Server listening on %s", address)
+
+    @upper_half
+    async def _do_accept(self) -> None:
+        """
+        Wait for and accept an incoming connection.
+
+        Requires that we have not yet accepted an incoming connection
+        from the upper_half, but it's OK if the server is no longer
+        running because the bottom_half has already accepted the
+        connection.
+        """
+        assert self._accepted is not None
+        await self._accepted.wait()
+        assert self._server is None
+        self._accepted = None
 
         self.logger.debug("Connection accepted.")
 
@@ -532,6 +611,14 @@ class AsyncProtocol(Generic[T]):
 
         :raise OSError: For stream-related errors.
         """
+        assert self.runstate == Runstate.IDLE
+        self._set_state(Runstate.CONNECTING)
+
+        # Allow runstate watchers to witness 'CONNECTING' state; some
+        # failures in the streaming layer are synchronous and will not
+        # otherwise yield.
+        await asyncio.sleep(0)
+
         self.logger.debug("Connecting to %s ...", address)
 
         if isinstance(address, tuple):
@@ -644,6 +731,7 @@ class AsyncProtocol(Generic[T]):
 
         self._reader = None
         self._writer = None
+        self._accepted = None
 
         # NB: _runstate_changed cannot be cleared because we still need it to
         # send the final runstate changed event ...!
@@ -667,6 +755,9 @@ class AsyncProtocol(Generic[T]):
         def _done(task: Optional['asyncio.Future[Any]']) -> bool:
             return task is not None and task.done()
 
+        # If the server is running, stop it.
+        await self._stop_server()
+
         # Are we already in an error pathway? If either of the tasks are
         # already done, or if we have no tasks but a reader/writer; we
         # must be.