summary refs log tree commit diff stats
path: root/python/tests
diff options
context:
space:
mode:
Diffstat (limited to 'python/tests')
-rw-r--r--python/tests/protocol.py45
1 files changed, 29 insertions, 16 deletions
diff --git a/python/tests/protocol.py b/python/tests/protocol.py
index 5cd7938be3..d6849ad306 100644
--- a/python/tests/protocol.py
+++ b/python/tests/protocol.py
@@ -41,12 +41,25 @@ class NullProtocol(AsyncProtocol[None]):
         self.trigger_input = asyncio.Event()
         await super()._establish_session()
 
-    async def _do_accept(self, address, ssl=None):
-        if not self.fake_session:
-            await super()._do_accept(address, ssl)
+    async def _do_start_server(self, address, ssl=None):
+        if self.fake_session:
+            self._accepted = asyncio.Event()
+            self._set_state(Runstate.CONNECTING)
+            await asyncio.sleep(0)
+        else:
+            await super()._do_start_server(address, ssl)
+
+    async def _do_accept(self):
+        if self.fake_session:
+            self._accepted = None
+        else:
+            await super()._do_accept()
 
     async def _do_connect(self, address, ssl=None):
-        if not self.fake_session:
+        if self.fake_session:
+            self._set_state(Runstate.CONNECTING)
+            await asyncio.sleep(0)
+        else:
             await super()._do_connect(address, ssl)
 
     async def _do_recv(self) -> None:
@@ -413,14 +426,14 @@ class Accept(Connect):
         assert family in ('INET', 'UNIX')
 
         if family == 'INET':
-            await self.proto.accept(('example.com', 1))
+            await self.proto.start_server_and_accept(('example.com', 1))
         elif family == 'UNIX':
-            await self.proto.accept('/dev/null')
+            await self.proto.start_server_and_accept('/dev/null')
 
     async def _hanging_connection(self):
         with TemporaryDirectory(suffix='.aqmp') as tmpdir:
             sock = os.path.join(tmpdir, type(self.proto).__name__ + ".sock")
-            await self.proto.accept(sock)
+            await self.proto.start_server_and_accept(sock)
 
 
 class FakeSession(TestBase):
@@ -449,13 +462,13 @@ class FakeSession(TestBase):
     @TestBase.async_test
     async def testFakeAccept(self):
         """Test the full state lifecycle (via accept) with a no-op session."""
-        await self.proto.accept('/not/a/real/path')
+        await self.proto.start_server_and_accept('/not/a/real/path')
         self.assertEqual(self.proto.runstate, Runstate.RUNNING)
 
     @TestBase.async_test
     async def testFakeRecv(self):
         """Test receiving a fake/null message."""
-        await self.proto.accept('/not/a/real/path')
+        await self.proto.start_server_and_accept('/not/a/real/path')
 
         logname = self.proto.logger.name
         with self.assertLogs(logname, level='DEBUG') as context:
@@ -471,7 +484,7 @@ class FakeSession(TestBase):
     @TestBase.async_test
     async def testFakeSend(self):
         """Test sending a fake/null message."""
-        await self.proto.accept('/not/a/real/path')
+        await self.proto.start_server_and_accept('/not/a/real/path')
 
         logname = self.proto.logger.name
         with self.assertLogs(logname, level='DEBUG') as context:
@@ -493,7 +506,7 @@ class FakeSession(TestBase):
     ):
         with self.assertRaises(StateError) as context:
             if accept:
-                await self.proto.accept('/not/a/real/path')
+                await self.proto.start_server_and_accept('/not/a/real/path')
             else:
                 await self.proto.connect('/not/a/real/path')
 
@@ -504,7 +517,7 @@ class FakeSession(TestBase):
     @TestBase.async_test
     async def testAcceptRequireRunning(self):
         """Test that accept() cannot be called when Runstate=RUNNING"""
-        await self.proto.accept('/not/a/real/path')
+        await self.proto.start_server_and_accept('/not/a/real/path')
 
         await self._prod_session_api(
             Runstate.RUNNING,
@@ -515,7 +528,7 @@ class FakeSession(TestBase):
     @TestBase.async_test
     async def testConnectRequireRunning(self):
         """Test that connect() cannot be called when Runstate=RUNNING"""
-        await self.proto.accept('/not/a/real/path')
+        await self.proto.start_server_and_accept('/not/a/real/path')
 
         await self._prod_session_api(
             Runstate.RUNNING,
@@ -526,7 +539,7 @@ class FakeSession(TestBase):
     @TestBase.async_test
     async def testAcceptRequireDisconnecting(self):
         """Test that accept() cannot be called when Runstate=DISCONNECTING"""
-        await self.proto.accept('/not/a/real/path')
+        await self.proto.start_server_and_accept('/not/a/real/path')
 
         # Cheat: force a disconnect.
         await self.proto.simulate_disconnect()
@@ -541,7 +554,7 @@ class FakeSession(TestBase):
     @TestBase.async_test
     async def testConnectRequireDisconnecting(self):
         """Test that connect() cannot be called when Runstate=DISCONNECTING"""
-        await self.proto.accept('/not/a/real/path')
+        await self.proto.start_server_and_accept('/not/a/real/path')
 
         # Cheat: force a disconnect.
         await self.proto.simulate_disconnect()
@@ -576,7 +589,7 @@ class SimpleSession(TestBase):
     async def testSmoke(self):
         with TemporaryDirectory(suffix='.aqmp') as tmpdir:
             sock = os.path.join(tmpdir, type(self.proto).__name__ + ".sock")
-            server_task = create_task(self.server.accept(sock))
+            server_task = create_task(self.server.start_server_and_accept(sock))
 
             # give the server a chance to start listening [...]
             await asyncio.sleep(0)