summary refs log tree commit diff stats
path: root/python
diff options
context:
space:
mode:
Diffstat (limited to 'python')
-rw-r--r--python/qemu/aqmp/protocol.py117
-rw-r--r--python/tests/protocol.py10
2 files changed, 53 insertions, 74 deletions
diff --git a/python/qemu/aqmp/protocol.py b/python/qemu/aqmp/protocol.py
index 73719257e0..b7e5e635d8 100644
--- a/python/qemu/aqmp/protocol.py
+++ b/python/qemu/aqmp/protocol.py
@@ -275,13 +275,25 @@ class AsyncProtocol(Generic[T]):
         If this call fails, `runstate` is guaranteed to be set back to `IDLE`.
 
         :param address:
-            Address to listen to; UNIX socket path or TCP address/port.
+            Address to listen on; UNIX socket path or TCP address/port.
         :param ssl: SSL context to use, if any.
 
         :raise StateError: When the `Runstate` is not `IDLE`.
-        :raise ConnectError: If a connection could not be accepted.
+        :raise ConnectError:
+            When a connection or session cannot be established.
+
+            This exception will wrap a more concrete one. In most cases,
+            the wrapped exception will be `OSError` or `EOFError`. If a
+            protocol-level failure occurs while establishing a new
+            session, the wrapped error may also be an `QMPError`.
         """
-        await self._new_session(address, ssl, accept=True)
+        await self._session_guard(
+            self._do_accept(address, ssl),
+            'Failed to establish connection')
+        await self._session_guard(
+            self._establish_session(),
+            'Failed to establish session')
+        assert self.runstate == Runstate.RUNNING
 
     @upper_half
     @require(Runstate.IDLE)
@@ -297,9 +309,21 @@ class AsyncProtocol(Generic[T]):
         :param ssl: SSL context to use, if any.
 
         :raise StateError: When the `Runstate` is not `IDLE`.
-        :raise ConnectError: If a connection cannot be made to the server.
+        :raise ConnectError:
+            When a connection or session cannot be established.
+
+            This exception will wrap a more concrete one. In most cases,
+            the wrapped exception will be `OSError` or `EOFError`. If a
+            protocol-level failure occurs while establishing a new
+            session, the wrapped error may also be an `QMPError`.
         """
-        await self._new_session(address, ssl)
+        await self._session_guard(
+            self._do_connect(address, ssl),
+            'Failed to establish connection')
+        await self._session_guard(
+            self._establish_session(),
+            'Failed to establish session')
+        assert self.runstate == Runstate.RUNNING
 
     @upper_half
     async def disconnect(self) -> None:
@@ -401,73 +425,6 @@ class AsyncProtocol(Generic[T]):
         self._runstate_event.set()
         self._runstate_event.clear()
 
-    @upper_half
-    async def _new_session(self,
-                           address: SocketAddrT,
-                           ssl: Optional[SSLContext] = None,
-                           accept: bool = False) -> None:
-        """
-        Establish a new connection and initialize the session.
-
-        Connect or accept a new connection, then begin the protocol
-        session machinery. If this call fails, `runstate` is guaranteed
-        to be set back to `IDLE`.
-
-        :param address:
-            Address to connect to/listen on;
-            UNIX socket path or TCP address/port.
-        :param ssl: SSL context to use, if any.
-        :param accept: Accept a connection instead of connecting when `True`.
-
-        :raise ConnectError:
-            When a connection or session cannot be established.
-
-            This exception will wrap a more concrete one. In most cases,
-            the wrapped exception will be `OSError` or `EOFError`. If a
-            protocol-level failure occurs while establishing a new
-            session, the wrapped error may also be an `QMPError`.
-        """
-        assert self.runstate == Runstate.IDLE
-
-        await self._session_guard(
-            self._establish_connection(address, ssl, accept),
-            'Failed to establish connection')
-
-        await self._session_guard(
-            self._establish_session(),
-            'Failed to establish session')
-
-        assert self.runstate == Runstate.RUNNING
-
-    @upper_half
-    async def _establish_connection(
-            self,
-            address: SocketAddrT,
-            ssl: Optional[SSLContext] = None,
-            accept: bool = False
-    ) -> None:
-        """
-        Establish a new connection.
-
-        :param address:
-            Address to connect to/listen on;
-            UNIX socket path or TCP address/port.
-        :param ssl: SSL context to use, if any.
-        :param accept: Accept a connection instead of connecting when `True`.
-        """
-        assert self.runstate == Runstate.IDLE
-        self._set_state(Runstate.CONNECTING)
-
-        # Allow runstate watchers to witness 'CONNECTING' state; some
-        # failures in the streaming layer are synchronous and will not
-        # otherwise yield.
-        await asyncio.sleep(0)
-
-        if accept:
-            await self._do_accept(address, ssl)
-        else:
-            await self._do_connect(address, ssl)
-
     def _bind_hack(self, address: Union[str, Tuple[str, int]]) -> None:
         """
         Used to create a socket in advance of accept().
@@ -508,6 +465,9 @@ class AsyncProtocol(Generic[T]):
 
         :raise OSError: For stream-related errors.
         """
+        assert self.runstate == Runstate.IDLE
+        self._set_state(Runstate.CONNECTING)
+
         self.logger.debug("Awaiting connection on %s ...", address)
         connected = asyncio.Event()
         server: Optional[asyncio.AbstractServer] = None
@@ -550,6 +510,11 @@ class AsyncProtocol(Generic[T]):
                 sock=self._sock,
             )
 
+        # Allow runstate watchers to witness 'CONNECTING' state; some
+        # failures in the streaming layer are synchronous and will not
+        # otherwise yield.
+        await asyncio.sleep(0)
+
         server = await coro     # Starts listening
         await connected.wait()  # Waits for the callback to fire (and finish)
         assert server is None
@@ -569,6 +534,14 @@ class AsyncProtocol(Generic[T]):
 
         :raise OSError: For stream-related errors.
         """
+        assert self.runstate == Runstate.IDLE
+        self._set_state(Runstate.CONNECTING)
+
+        # Allow runstate watchers to witness 'CONNECTING' state; some
+        # failures in the streaming layer are synchronous and will not
+        # otherwise yield.
+        await asyncio.sleep(0)
+
         self.logger.debug("Connecting to %s ...", address)
 
         if isinstance(address, tuple):
diff --git a/python/tests/protocol.py b/python/tests/protocol.py
index 354d6559b9..8dd26c4ed1 100644
--- a/python/tests/protocol.py
+++ b/python/tests/protocol.py
@@ -42,11 +42,17 @@ class NullProtocol(AsyncProtocol[None]):
         await super()._establish_session()
 
     async def _do_accept(self, address, ssl=None):
-        if not self.fake_session:
+        if self.fake_session:
+            self._set_state(Runstate.CONNECTING)
+            await asyncio.sleep(0)
+        else:
             await super()._do_accept(address, ssl)
 
     async def _do_connect(self, address, ssl=None):
-        if not self.fake_session:
+        if self.fake_session:
+            self._set_state(Runstate.CONNECTING)
+            await asyncio.sleep(0)
+        else:
             await super()._do_connect(address, ssl)
 
     async def _do_recv(self) -> None: