summary refs log tree commit diff stats
path: root/python/qemu/qmp/protocol.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/qemu/qmp/protocol.py')
-rw-r--r--python/qemu/qmp/protocol.py194
1 files changed, 119 insertions, 75 deletions
diff --git a/python/qemu/qmp/protocol.py b/python/qemu/qmp/protocol.py
index a4ffdfad51..219d092a79 100644
--- a/python/qemu/qmp/protocol.py
+++ b/python/qemu/qmp/protocol.py
@@ -15,13 +15,16 @@ class.
 
 import asyncio
 from asyncio import StreamReader, StreamWriter
+from contextlib import asynccontextmanager
 from enum import Enum
 from functools import wraps
+from inspect import iscoroutinefunction
 import logging
 import socket
 from ssl import SSLContext
 from typing import (
     Any,
+    AsyncGenerator,
     Awaitable,
     Callable,
     Generic,
@@ -36,13 +39,10 @@ from typing import (
 from .error import QMPError
 from .util import (
     bottom_half,
-    create_task,
     exception_summary,
     flush,
-    is_closing,
     pretty_traceback,
     upper_half,
-    wait_closed,
 )
 
 
@@ -54,6 +54,9 @@ InternetAddrT = Tuple[str, int]
 UnixAddrT = str
 SocketAddrT = Union[UnixAddrT, InternetAddrT]
 
+# Maximum allowable size of read buffer, default
+_DEFAULT_READBUFLEN = 64 * 1024
+
 
 class Runstate(Enum):
     """Protocol session runstate."""
@@ -76,11 +79,17 @@ class ConnectError(QMPError):
     This Exception always wraps a "root cause" exception that can be
     interrogated for additional information.
 
+    For example, when connecting to a non-existent socket::
+
+        await qmp.connect('not_found.sock')
+        # ConnectError: Failed to establish connection:
+        #               [Errno 2] No such file or directory
+
     :param error_message: Human-readable string describing the error.
     :param exc: The root-cause exception.
     """
     def __init__(self, error_message: str, exc: Exception):
-        super().__init__(error_message)
+        super().__init__(error_message, exc)
         #: Human-readable error string
         self.error_message: str = error_message
         #: Wrapped root cause exception
@@ -99,8 +108,8 @@ class StateError(QMPError):
     An API command (connect, execute, etc) was issued at an inappropriate time.
 
     This error is raised when a command like
-    :py:meth:`~AsyncProtocol.connect()` is issued at an inappropriate
-    time.
+    :py:meth:`~AsyncProtocol.connect()` is called when the client is
+    already connected.
 
     :param error_message: Human-readable string describing the state violation.
     :param state: The actual `Runstate` seen at the time of the violation.
@@ -108,11 +117,14 @@ class StateError(QMPError):
     """
     def __init__(self, error_message: str,
                  state: Runstate, required: Runstate):
-        super().__init__(error_message)
+        super().__init__(error_message, state, required)
         self.error_message = error_message
         self.state = state
         self.required = required
 
+    def __str__(self) -> str:
+        return self.error_message
+
 
 F = TypeVar('F', bound=Callable[..., Any])  # pylint: disable=invalid-name
 
@@ -125,6 +137,25 @@ def require(required_state: Runstate) -> Callable[[F], F]:
     :param required_state: The `Runstate` required to invoke this method.
     :raise StateError: When the required `Runstate` is not met.
     """
+    def _check(proto: 'AsyncProtocol[Any]') -> None:
+        name = type(proto).__name__
+        if proto.runstate == required_state:
+            return
+
+        if proto.runstate == Runstate.CONNECTING:
+            emsg = f"{name} is currently connecting."
+        elif proto.runstate == Runstate.DISCONNECTING:
+            emsg = (f"{name} is disconnecting."
+                    " Call disconnect() to return to IDLE state.")
+        elif proto.runstate == Runstate.RUNNING:
+            emsg = f"{name} is already connected and running."
+        elif proto.runstate == Runstate.IDLE:
+            emsg = f"{name} is disconnected and idle."
+        else:
+            assert False
+
+        raise StateError(emsg, proto.runstate, required_state)
+
     def _decorator(func: F) -> F:
         # _decorator is the decorator that is built by calling the
         # require() decorator factory; e.g.:
@@ -135,29 +166,20 @@ def require(required_state: Runstate) -> Callable[[F], F]:
         @wraps(func)
         def _wrapper(proto: 'AsyncProtocol[Any]',
                      *args: Any, **kwargs: Any) -> Any:
-            # _wrapper is the function that gets executed prior to the
-            # decorated method.
-
-            name = type(proto).__name__
-
-            if proto.runstate != required_state:
-                if proto.runstate == Runstate.CONNECTING:
-                    emsg = f"{name} is currently connecting."
-                elif proto.runstate == Runstate.DISCONNECTING:
-                    emsg = (f"{name} is disconnecting."
-                            " Call disconnect() to return to IDLE state.")
-                elif proto.runstate == Runstate.RUNNING:
-                    emsg = f"{name} is already connected and running."
-                elif proto.runstate == Runstate.IDLE:
-                    emsg = f"{name} is disconnected and idle."
-                else:
-                    assert False
-                raise StateError(emsg, proto.runstate, required_state)
-            # No StateError, so call the wrapped method.
+            _check(proto)
             return func(proto, *args, **kwargs)
 
-        # Return the decorated method;
-        # Transforming Func to Decorated[Func].
+        @wraps(func)
+        async def _async_wrapper(proto: 'AsyncProtocol[Any]',
+                                 *args: Any, **kwargs: Any) -> Any:
+            _check(proto)
+            return await func(proto, *args, **kwargs)
+
+        # Return the decorated method; F => Decorated[F]
+        # Use an async version when applicable, which
+        # preserves async signature generation in sphinx.
+        if iscoroutinefunction(func):
+            return cast(F, _async_wrapper)
         return cast(F, _wrapper)
 
     # Return the decorator instance from the decorator factory. Phew!
@@ -200,24 +222,26 @@ class AsyncProtocol(Generic[T]):
         will log to 'qemu.qmp.protocol', but each individual connection
         can be given its own logger by giving it a name; messages will
         then log to 'qemu.qmp.protocol.${name}'.
+    :param readbuflen:
+        The maximum read buffer length of the underlying StreamReader
+        instance.
     """
     # pylint: disable=too-many-instance-attributes
 
     #: Logger object for debugging messages from this connection.
     logger = logging.getLogger(__name__)
 
-    # Maximum allowable size of read buffer
-    _limit = 64 * 1024
-
     # -------------------------
     # Section: Public interface
     # -------------------------
 
-    def __init__(self, name: Optional[str] = None) -> None:
-        #: The nickname for this connection, if any.
-        self.name: Optional[str] = name
-        if self.name is not None:
-            self.logger = self.logger.getChild(self.name)
+    def __init__(
+        self, name: Optional[str] = None,
+        readbuflen: int = _DEFAULT_READBUFLEN
+    ) -> None:
+        self._name: Optional[str]
+        self.name = name
+        self.readbuflen = readbuflen
 
         # stream I/O
         self._reader: Optional[StreamReader] = None
@@ -254,6 +278,24 @@ class AsyncProtocol(Generic[T]):
         tokens.append(f"runstate={self.runstate.name}")
         return f"<{cls_name} {' '.join(tokens)}>"
 
+    @property
+    def name(self) -> Optional[str]:
+        """
+        The nickname for this connection, if any.
+
+        This name is used for differentiating instances in debug output.
+        """
+        return self._name
+
+    @name.setter
+    def name(self, name: Optional[str]) -> None:
+        logger = logging.getLogger(__name__)
+        if name:
+            self.logger = logger.getChild(name)
+        else:
+            self.logger = logger
+        self._name = name
+
     @property  # @upper_half
     def runstate(self) -> Runstate:
         """The current `Runstate` of the connection."""
@@ -262,7 +304,7 @@ class AsyncProtocol(Generic[T]):
     @upper_half
     async def runstate_changed(self) -> Runstate:
         """
-        Wait for the `runstate` to change, then return that runstate.
+        Wait for the `runstate` to change, then return that `Runstate`.
         """
         await self._runstate_event.wait()
         return self.runstate
@@ -276,9 +318,9 @@ class AsyncProtocol(Generic[T]):
         """
         Accept a connection and begin processing message queues.
 
-        If this call fails, `runstate` is guaranteed to be set back to `IDLE`.
-        This method is precisely equivalent to calling `start_server()`
-        followed by `accept()`.
+        If this call fails, `runstate` is guaranteed to be set back to
+        `IDLE`.  This method is precisely equivalent to calling
+        `start_server()` followed by :py:meth:`~AsyncProtocol.accept()`.
 
         :param address:
             Address to listen on; UNIX socket path or TCP address/port.
@@ -291,7 +333,8 @@ class AsyncProtocol(Generic[T]):
             This exception will wrap a more concrete one. In most cases,
             the wrapped exception will be `OSError` or `EOFError`. If a
             protocol-level failure occurs while establishing a new
-            session, the wrapped error may also be an `QMPError`.
+            session, the wrapped error may also be a `QMPError`.
+
         """
         await self.start_server(address, ssl)
         await self.accept()
@@ -307,8 +350,8 @@ class AsyncProtocol(Generic[T]):
         This method starts listening for an incoming connection, but
         does not block waiting for a peer. This call will return
         immediately after binding and listening on a socket. A later
-        call to `accept()` must be made in order to finalize the
-        incoming connection.
+        call to :py:meth:`~AsyncProtocol.accept()` must be made in order
+        to finalize the incoming connection.
 
         :param address:
             Address to listen on; UNIX socket path or TCP address/port.
@@ -321,9 +364,8 @@ class AsyncProtocol(Generic[T]):
             This exception will wrap a more concrete one. In most cases,
             the wrapped exception will be `OSError`.
         """
-        await self._session_guard(
-            self._do_start_server(address, ssl),
-            'Failed to establish connection')
+        async with self._session_guard('Failed to establish connection'):
+            await self._do_start_server(address, ssl)
         assert self.runstate == Runstate.CONNECTING
 
     @upper_half
@@ -332,10 +374,12 @@ class AsyncProtocol(Generic[T]):
         """
         Accept an incoming connection and begin processing message queues.
 
-        If this call fails, `runstate` is guaranteed to be set back to `IDLE`.
+        Used after a previous call to `start_server()` to accept an
+        incoming connection. If this call fails, `runstate` is
+        guaranteed to be set back to `IDLE`.
 
         :raise StateError: When the `Runstate` is not `CONNECTING`.
-        :raise QMPError: When `start_server()` was not called yet.
+        :raise QMPError: When `start_server()` was not called first.
         :raise ConnectError:
             When a connection or session cannot be established.
 
@@ -346,12 +390,10 @@ class AsyncProtocol(Generic[T]):
         """
         if self._accepted is None:
             raise QMPError("Cannot call accept() before start_server().")
-        await self._session_guard(
-            self._do_accept(),
-            'Failed to establish connection')
-        await self._session_guard(
-            self._establish_session(),
-            'Failed to establish session')
+        async with self._session_guard('Failed to establish connection'):
+            await self._do_accept()
+        async with self._session_guard('Failed to establish session'):
+            await self._establish_session()
         assert self.runstate == Runstate.RUNNING
 
     @upper_half
@@ -376,12 +418,10 @@ class AsyncProtocol(Generic[T]):
             protocol-level failure occurs while establishing a new
             session, the wrapped error may also be an `QMPError`.
         """
-        await self._session_guard(
-            self._do_connect(address, ssl),
-            'Failed to establish connection')
-        await self._session_guard(
-            self._establish_session(),
-            'Failed to establish session')
+        async with self._session_guard('Failed to establish connection'):
+            await self._do_connect(address, ssl)
+        async with self._session_guard('Failed to establish session'):
+            await self._establish_session()
         assert self.runstate == Runstate.RUNNING
 
     @upper_half
@@ -392,7 +432,11 @@ class AsyncProtocol(Generic[T]):
         If there was an exception that caused the reader/writers to
         terminate prematurely, it will be raised here.
 
-        :raise Exception: When the reader or writer terminate unexpectedly.
+        :raise Exception:
+            When the reader or writer terminate unexpectedly. You can
+            expect to see `EOFError` if the server hangs up, or
+            `OSError` for connection-related issues. If there was a QMP
+            protocol-level problem, `ProtocolError` will be seen.
         """
         self.logger.debug("disconnect() called.")
         self._schedule_disconnect()
@@ -402,7 +446,8 @@ class AsyncProtocol(Generic[T]):
     # Section: Session machinery
     # --------------------------
 
-    async def _session_guard(self, coro: Awaitable[None], emsg: str) -> None:
+    @asynccontextmanager
+    async def _session_guard(self, emsg: str) -> AsyncGenerator[None, None]:
         """
         Async guard function used to roll back to `IDLE` on any error.
 
@@ -419,10 +464,9 @@ class AsyncProtocol(Generic[T]):
         :raise ConnectError:
             When any other error is encountered in the guarded block.
         """
-        # Note: After Python 3.6 support is removed, this should be an
-        # @asynccontextmanager instead of accepting a callback.
         try:
-            await coro
+            # Caller's code runs here.
+            yield
         except BaseException as err:
             self.logger.error("%s: %s", emsg, exception_summary(err))
             self.logger.debug("%s:\n%s\n", emsg, pretty_traceback())
@@ -561,7 +605,7 @@ class AsyncProtocol(Generic[T]):
                 port=address[1],
                 ssl=ssl,
                 backlog=1,
-                limit=self._limit,
+                limit=self.readbuflen,
             )
         else:
             coro = asyncio.start_unix_server(
@@ -569,7 +613,7 @@ class AsyncProtocol(Generic[T]):
                 path=address,
                 ssl=ssl,
                 backlog=1,
-                limit=self._limit,
+                limit=self.readbuflen,
             )
 
         # Allow runstate watchers to witness 'CONNECTING' state; some
@@ -624,7 +668,7 @@ class AsyncProtocol(Generic[T]):
                               "fd=%d, family=%r, type=%r",
                               address.fileno(), address.family, address.type)
             connect = asyncio.open_connection(
-                limit=self._limit,
+                limit=self.readbuflen,
                 ssl=ssl,
                 sock=address,
             )
@@ -634,14 +678,14 @@ class AsyncProtocol(Generic[T]):
                 address[0],
                 address[1],
                 ssl=ssl,
-                limit=self._limit,
+                limit=self.readbuflen,
             )
         else:
             self.logger.debug("Connecting to file://%s ...", address)
             connect = asyncio.open_unix_connection(
                 path=address,
                 ssl=ssl,
-                limit=self._limit,
+                limit=self.readbuflen,
             )
 
         self._reader, self._writer = await connect
@@ -663,8 +707,8 @@ class AsyncProtocol(Generic[T]):
         reader_coro = self._bh_loop_forever(self._bh_recv_message, 'Reader')
         writer_coro = self._bh_loop_forever(self._bh_send_message, 'Writer')
 
-        self._reader_task = create_task(reader_coro)
-        self._writer_task = create_task(writer_coro)
+        self._reader_task = asyncio.create_task(reader_coro)
+        self._writer_task = asyncio.create_task(writer_coro)
 
         self._bh_tasks = asyncio.gather(
             self._reader_task,
@@ -689,7 +733,7 @@ class AsyncProtocol(Generic[T]):
         if not self._dc_task:
             self._set_state(Runstate.DISCONNECTING)
             self.logger.debug("Scheduling disconnect.")
-            self._dc_task = create_task(self._bh_disconnect())
+            self._dc_task = asyncio.create_task(self._bh_disconnect())
 
     @upper_half
     async def _wait_disconnect(self) -> None:
@@ -825,13 +869,13 @@ class AsyncProtocol(Generic[T]):
         if not self._writer:
             return
 
-        if not is_closing(self._writer):
+        if not self._writer.is_closing():
             self.logger.debug("Closing StreamWriter.")
             self._writer.close()
 
         self.logger.debug("Waiting for StreamWriter to close ...")
         try:
-            await wait_closed(self._writer)
+            await self._writer.wait_closed()
         except Exception:  # pylint: disable=broad-except
             # It's hard to tell if the Stream is already closed or
             # not. Even if one of the tasks has failed, it may have