summary refs log tree commit diff stats
path: root/io/channel.c
diff options
context:
space:
mode:
Diffstat (limited to 'io/channel.c')
-rw-r--r--io/channel.c118
1 files changed, 87 insertions, 31 deletions
diff --git a/io/channel.c b/io/channel.c
index 72f0066af5..86c5834510 100644
--- a/io/channel.c
+++ b/io/channel.c
@@ -365,6 +365,12 @@ int qio_channel_set_blocking(QIOChannel *ioc,
 }
 
 
+void qio_channel_set_follow_coroutine_ctx(QIOChannel *ioc, bool enabled)
+{
+    ioc->follow_coroutine_ctx = enabled;
+}
+
+
 int qio_channel_close(QIOChannel *ioc,
                       Error **errp)
 {
@@ -388,14 +394,16 @@ GSource *qio_channel_create_watch(QIOChannel *ioc,
 
 
 void qio_channel_set_aio_fd_handler(QIOChannel *ioc,
-                                    AioContext *ctx,
+                                    AioContext *read_ctx,
                                     IOHandler *io_read,
+                                    AioContext *write_ctx,
                                     IOHandler *io_write,
                                     void *opaque)
 {
     QIOChannelClass *klass = QIO_CHANNEL_GET_CLASS(ioc);
 
-    klass->io_set_aio_fd_handler(ioc, ctx, io_read, io_write, opaque);
+    klass->io_set_aio_fd_handler(ioc, read_ctx, io_read, write_ctx, io_write,
+            opaque);
 }
 
 guint qio_channel_add_watch_full(QIOChannel *ioc,
@@ -542,56 +550,101 @@ static void qio_channel_restart_write(void *opaque)
     aio_co_wake(co);
 }
 
-static void qio_channel_set_aio_fd_handlers(QIOChannel *ioc)
+static void coroutine_fn
+qio_channel_set_fd_handlers(QIOChannel *ioc, GIOCondition condition)
 {
-    IOHandler *rd_handler = NULL, *wr_handler = NULL;
-    AioContext *ctx;
+    AioContext *ctx = ioc->follow_coroutine_ctx ?
+        qemu_coroutine_get_aio_context(qemu_coroutine_self()) :
+        iohandler_get_aio_context();
+    AioContext *read_ctx = NULL;
+    IOHandler *io_read = NULL;
+    AioContext *write_ctx = NULL;
+    IOHandler *io_write = NULL;
 
-    if (ioc->read_coroutine) {
-        rd_handler = qio_channel_restart_read;
-    }
-    if (ioc->write_coroutine) {
-        wr_handler = qio_channel_restart_write;
+    if (condition == G_IO_IN) {
+        ioc->read_coroutine = qemu_coroutine_self();
+        ioc->read_ctx = ctx;
+        read_ctx = ctx;
+        io_read = qio_channel_restart_read;
+
+        /*
+         * Thread safety: if the other coroutine is set and its AioContext
+         * matches ours, then there is mutual exclusion between read and write
+         * because they share a single thread and it's safe to set both read
+         * and write fd handlers here. If the AioContext does not match ours,
+         * then both threads may run in parallel but there is no shared state
+         * to worry about.
+         */
+        if (ioc->write_coroutine && ioc->write_ctx == ctx) {
+            write_ctx = ctx;
+            io_write = qio_channel_restart_write;
+        }
+    } else if (condition == G_IO_OUT) {
+        ioc->write_coroutine = qemu_coroutine_self();
+        ioc->write_ctx = ctx;
+        write_ctx = ctx;
+        io_write = qio_channel_restart_write;
+        if (ioc->read_coroutine && ioc->read_ctx == ctx) {
+            read_ctx = ctx;
+            io_read = qio_channel_restart_read;
+        }
+    } else {
+        abort();
     }
 
-    ctx = ioc->ctx ? ioc->ctx : iohandler_get_aio_context();
-    qio_channel_set_aio_fd_handler(ioc, ctx, rd_handler, wr_handler, ioc);
+    qio_channel_set_aio_fd_handler(ioc, read_ctx, io_read,
+            write_ctx, io_write, ioc);
 }
 
-void qio_channel_attach_aio_context(QIOChannel *ioc,
-                                    AioContext *ctx)
+static void coroutine_fn
+qio_channel_clear_fd_handlers(QIOChannel *ioc, GIOCondition condition)
 {
-    assert(!ioc->read_coroutine);
-    assert(!ioc->write_coroutine);
-    ioc->ctx = ctx;
-}
+    AioContext *read_ctx = NULL;
+    IOHandler *io_read = NULL;
+    AioContext *write_ctx = NULL;
+    IOHandler *io_write = NULL;
+    AioContext *ctx;
 
-void qio_channel_detach_aio_context(QIOChannel *ioc)
-{
-    ioc->read_coroutine = NULL;
-    ioc->write_coroutine = NULL;
-    qio_channel_set_aio_fd_handlers(ioc);
-    ioc->ctx = NULL;
+    if (condition == G_IO_IN) {
+        ctx = ioc->read_ctx;
+        read_ctx = ctx;
+        io_read = NULL;
+        if (ioc->write_coroutine && ioc->write_ctx == ctx) {
+            write_ctx = ctx;
+            io_write = qio_channel_restart_write;
+        }
+    } else if (condition == G_IO_OUT) {
+        ctx = ioc->write_ctx;
+        write_ctx = ctx;
+        io_write = NULL;
+        if (ioc->read_coroutine && ioc->read_ctx == ctx) {
+            read_ctx = ctx;
+            io_read = qio_channel_restart_read;
+        }
+    } else {
+        abort();
+    }
+
+    qio_channel_set_aio_fd_handler(ioc, read_ctx, io_read,
+            write_ctx, io_write, ioc);
 }
 
 void coroutine_fn qio_channel_yield(QIOChannel *ioc,
                                     GIOCondition condition)
 {
-    AioContext *ioc_ctx = ioc->ctx ?: qemu_get_aio_context();
+    AioContext *ioc_ctx;
 
     assert(qemu_in_coroutine());
-    assert(in_aio_context_home_thread(ioc_ctx));
+    ioc_ctx = qemu_coroutine_get_aio_context(qemu_coroutine_self());
 
     if (condition == G_IO_IN) {
         assert(!ioc->read_coroutine);
-        ioc->read_coroutine = qemu_coroutine_self();
     } else if (condition == G_IO_OUT) {
         assert(!ioc->write_coroutine);
-        ioc->write_coroutine = qemu_coroutine_self();
     } else {
         abort();
     }
-    qio_channel_set_aio_fd_handlers(ioc);
+    qio_channel_set_fd_handlers(ioc, condition);
     qemu_coroutine_yield();
     assert(in_aio_context_home_thread(ioc_ctx));
 
@@ -599,11 +652,10 @@ void coroutine_fn qio_channel_yield(QIOChannel *ioc,
      * through the aio_fd_handlers. */
     if (condition == G_IO_IN) {
         assert(ioc->read_coroutine == NULL);
-        qio_channel_set_aio_fd_handlers(ioc);
     } else if (condition == G_IO_OUT) {
         assert(ioc->write_coroutine == NULL);
-        qio_channel_set_aio_fd_handlers(ioc);
     }
+    qio_channel_clear_fd_handlers(ioc, condition);
 }
 
 void qio_channel_wake_read(QIOChannel *ioc)
@@ -653,6 +705,10 @@ static void qio_channel_finalize(Object *obj)
 {
     QIOChannel *ioc = QIO_CHANNEL(obj);
 
+    /* Must not have coroutines in qio_channel_yield() */
+    assert(!ioc->read_coroutine);
+    assert(!ioc->write_coroutine);
+
     g_free(ioc->name);
 
 #ifdef _WIN32