summary refs log tree commit diff stats
path: root/hw/virtio/vhost-user.c
diff options
context:
space:
mode:
Diffstat (limited to 'hw/virtio/vhost-user.c')
-rw-r--r--hw/virtio/vhost-user.c217
1 files changed, 138 insertions, 79 deletions
diff --git a/hw/virtio/vhost-user.c b/hw/virtio/vhost-user.c
index 2fdd5daf74..ded0c10453 100644
--- a/hw/virtio/vhost-user.c
+++ b/hw/virtio/vhost-user.c
@@ -16,6 +16,7 @@
 #include "hw/virtio/virtio.h"
 #include "hw/virtio/virtio-net.h"
 #include "chardev/char-fe.h"
+#include "io/channel-socket.h"
 #include "sysemu/kvm.h"
 #include "qemu/error-report.h"
 #include "qemu/main-loop.h"
@@ -237,7 +238,8 @@ struct vhost_user {
     struct vhost_dev *dev;
     /* Shared between vhost devs of the same virtio device */
     VhostUserState *user;
-    int slave_fd;
+    QIOChannel *slave_ioc;
+    GSource *slave_src;
     NotifierWithReturn postcopy_notifier;
     struct PostCopyFD  postcopy_fd;
     uint64_t           postcopy_client_bases[VHOST_USER_MAX_RAM_SLOTS];
@@ -294,15 +296,27 @@ static int vhost_user_read_header(struct vhost_dev *dev, VhostUserMsg *msg)
     return 0;
 }
 
-static int vhost_user_read(struct vhost_dev *dev, VhostUserMsg *msg)
+struct vhost_user_read_cb_data {
+    struct vhost_dev *dev;
+    VhostUserMsg *msg;
+    GMainLoop *loop;
+    int ret;
+};
+
+static gboolean vhost_user_read_cb(GIOChannel *source, GIOCondition condition,
+                                   gpointer opaque)
 {
+    struct vhost_user_read_cb_data *data = opaque;
+    struct vhost_dev *dev = data->dev;
+    VhostUserMsg *msg = data->msg;
     struct vhost_user *u = dev->opaque;
     CharBackend *chr = u->user->chr;
     uint8_t *p = (uint8_t *) msg;
     int r, size;
 
     if (vhost_user_read_header(dev, msg) < 0) {
-        return -1;
+        data->ret = -1;
+        goto end;
     }
 
     /* validate message size is sane */
@@ -310,7 +324,8 @@ static int vhost_user_read(struct vhost_dev *dev, VhostUserMsg *msg)
         error_report("Failed to read msg header."
                 " Size %d exceeds the maximum %zu.", msg->hdr.size,
                 VHOST_USER_PAYLOAD_SIZE);
-        return -1;
+        data->ret = -1;
+        goto end;
     }
 
     if (msg->hdr.size) {
@@ -320,11 +335,84 @@ static int vhost_user_read(struct vhost_dev *dev, VhostUserMsg *msg)
         if (r != size) {
             error_report("Failed to read msg payload."
                          " Read %d instead of %d.", r, msg->hdr.size);
-            return -1;
+            data->ret = -1;
+            goto end;
         }
     }
 
-    return 0;
+end:
+    g_main_loop_quit(data->loop);
+    return G_SOURCE_REMOVE;
+}
+
+static gboolean slave_read(QIOChannel *ioc, GIOCondition condition,
+                           gpointer opaque);
+
+/*
+ * This updates the read handler to use a new event loop context.
+ * Event sources are removed from the previous context : this ensures
+ * that events detected in the previous context are purged. They will
+ * be re-detected and processed in the new context.
+ */
+static void slave_update_read_handler(struct vhost_dev *dev,
+                                      GMainContext *ctxt)
+{
+    struct vhost_user *u = dev->opaque;
+
+    if (!u->slave_ioc) {
+        return;
+    }
+
+    if (u->slave_src) {
+        g_source_destroy(u->slave_src);
+        g_source_unref(u->slave_src);
+    }
+
+    u->slave_src = qio_channel_add_watch_source(u->slave_ioc,
+                                                G_IO_IN | G_IO_HUP,
+                                                slave_read, dev, NULL,
+                                                ctxt);
+}
+
+static int vhost_user_read(struct vhost_dev *dev, VhostUserMsg *msg)
+{
+    struct vhost_user *u = dev->opaque;
+    CharBackend *chr = u->user->chr;
+    GMainContext *prev_ctxt = chr->chr->gcontext;
+    GMainContext *ctxt = g_main_context_new();
+    GMainLoop *loop = g_main_loop_new(ctxt, FALSE);
+    struct vhost_user_read_cb_data data = {
+        .dev = dev,
+        .loop = loop,
+        .msg = msg,
+        .ret = 0
+    };
+
+    /*
+     * We want to be able to monitor the slave channel fd while waiting
+     * for chr I/O. This requires an event loop, but we can't nest the
+     * one to which chr is currently attached : its fd handlers might not
+     * be prepared for re-entrancy. So we create a new one and switch chr
+     * to use it.
+     */
+    slave_update_read_handler(dev, ctxt);
+    qemu_chr_be_update_read_handlers(chr->chr, ctxt);
+    qemu_chr_fe_add_watch(chr, G_IO_IN | G_IO_HUP, vhost_user_read_cb, &data);
+
+    g_main_loop_run(loop);
+
+    /*
+     * Restore the previous event loop context. This also destroys/recreates
+     * event sources : this guarantees that all pending events in the original
+     * context that have been processed by the nested loop are purged.
+     */
+    qemu_chr_be_update_read_handlers(chr->chr, prev_ctxt);
+    slave_update_read_handler(dev, NULL);
+
+    g_main_loop_unref(loop);
+    g_main_context_unref(ctxt);
+
+    return data.ret;
 }
 
 static int process_message_reply(struct vhost_dev *dev,
@@ -1392,56 +1480,39 @@ static int vhost_user_slave_handle_vring_host_notifier(struct vhost_dev *dev,
     return 0;
 }
 
-static void slave_read(void *opaque)
+static void close_slave_channel(struct vhost_user *u)
+{
+    g_source_destroy(u->slave_src);
+    g_source_unref(u->slave_src);
+    u->slave_src = NULL;
+    object_unref(OBJECT(u->slave_ioc));
+    u->slave_ioc = NULL;
+}
+
+static gboolean slave_read(QIOChannel *ioc, GIOCondition condition,
+                           gpointer opaque)
 {
     struct vhost_dev *dev = opaque;
     struct vhost_user *u = dev->opaque;
     VhostUserHeader hdr = { 0, };
     VhostUserPayload payload = { 0, };
-    int size, ret = 0;
+    Error *local_err = NULL;
+    gboolean rc = G_SOURCE_CONTINUE;
+    int ret = 0;
     struct iovec iov;
-    struct msghdr msgh;
-    int fd[VHOST_USER_SLAVE_MAX_FDS];
-    char control[CMSG_SPACE(sizeof(fd))];
-    struct cmsghdr *cmsg;
-    int i, fdsize = 0;
-
-    memset(&msgh, 0, sizeof(msgh));
-    msgh.msg_iov = &iov;
-    msgh.msg_iovlen = 1;
-    msgh.msg_control = control;
-    msgh.msg_controllen = sizeof(control);
-
-    memset(fd, -1, sizeof(fd));
+    g_autofree int *fd = NULL;
+    size_t fdsize = 0;
+    int i;
 
     /* Read header */
     iov.iov_base = &hdr;
     iov.iov_len = VHOST_USER_HDR_SIZE;
 
-    do {
-        size = recvmsg(u->slave_fd, &msgh, 0);
-    } while (size < 0 && (errno == EINTR || errno == EAGAIN));
-
-    if (size != VHOST_USER_HDR_SIZE) {
-        error_report("Failed to read from slave.");
+    if (qio_channel_readv_full_all(ioc, &iov, 1, &fd, &fdsize, &local_err)) {
+        error_report_err(local_err);
         goto err;
     }
 
-    if (msgh.msg_flags & MSG_CTRUNC) {
-        error_report("Truncated message.");
-        goto err;
-    }
-
-    for (cmsg = CMSG_FIRSTHDR(&msgh); cmsg != NULL;
-         cmsg = CMSG_NXTHDR(&msgh, cmsg)) {
-            if (cmsg->cmsg_level == SOL_SOCKET &&
-                cmsg->cmsg_type == SCM_RIGHTS) {
-                    fdsize = cmsg->cmsg_len - CMSG_LEN(0);
-                    memcpy(fd, CMSG_DATA(cmsg), fdsize);
-                    break;
-            }
-    }
-
     if (hdr.size > VHOST_USER_PAYLOAD_SIZE) {
         error_report("Failed to read msg header."
                 " Size %d exceeds the maximum %zu.", hdr.size,
@@ -1450,12 +1521,8 @@ static void slave_read(void *opaque)
     }
 
     /* Read payload */
-    do {
-        size = read(u->slave_fd, &payload, hdr.size);
-    } while (size < 0 && (errno == EINTR || errno == EAGAIN));
-
-    if (size != hdr.size) {
-        error_report("Failed to read payload from slave.");
+    if (qio_channel_read_all(ioc, (char *) &payload, hdr.size, &local_err)) {
+        error_report_err(local_err);
         goto err;
     }
 
@@ -1468,20 +1535,13 @@ static void slave_read(void *opaque)
         break;
     case VHOST_USER_SLAVE_VRING_HOST_NOTIFIER_MSG:
         ret = vhost_user_slave_handle_vring_host_notifier(dev, &payload.area,
-                                                          fd[0]);
+                                                          fd ? fd[0] : -1);
         break;
     default:
         error_report("Received unexpected msg type: %d.", hdr.request);
         ret = -EINVAL;
     }
 
-    /* Close the remaining file descriptors. */
-    for (i = 0; i < fdsize; i++) {
-        if (fd[i] != -1) {
-            close(fd[i]);
-        }
-    }
-
     /*
      * REPLY_ACK feature handling. Other reply types has to be managed
      * directly in their request handlers.
@@ -1501,28 +1561,25 @@ static void slave_read(void *opaque)
         iovec[1].iov_base = &payload;
         iovec[1].iov_len = hdr.size;
 
-        do {
-            size = writev(u->slave_fd, iovec, ARRAY_SIZE(iovec));
-        } while (size < 0 && (errno == EINTR || errno == EAGAIN));
-
-        if (size != VHOST_USER_HDR_SIZE + hdr.size) {
-            error_report("Failed to send msg reply to slave.");
+        if (qio_channel_writev_all(ioc, iovec, ARRAY_SIZE(iovec), &local_err)) {
+            error_report_err(local_err);
             goto err;
         }
     }
 
-    return;
+    goto fdcleanup;
 
 err:
-    qemu_set_fd_handler(u->slave_fd, NULL, NULL, NULL);
-    close(u->slave_fd);
-    u->slave_fd = -1;
-    for (i = 0; i < fdsize; i++) {
-        if (fd[i] != -1) {
+    close_slave_channel(u);
+    rc = G_SOURCE_REMOVE;
+
+fdcleanup:
+    if (fd) {
+        for (i = 0; i < fdsize; i++) {
             close(fd[i]);
         }
     }
-    return;
+    return rc;
 }
 
 static int vhost_setup_slave_channel(struct vhost_dev *dev)
@@ -1535,6 +1592,8 @@ static int vhost_setup_slave_channel(struct vhost_dev *dev)
     int sv[2], ret = 0;
     bool reply_supported = virtio_has_feature(dev->protocol_features,
                                               VHOST_USER_PROTOCOL_F_REPLY_ACK);
+    Error *local_err = NULL;
+    QIOChannel *ioc;
 
     if (!virtio_has_feature(dev->protocol_features,
                             VHOST_USER_PROTOCOL_F_SLAVE_REQ)) {
@@ -1546,8 +1605,13 @@ static int vhost_setup_slave_channel(struct vhost_dev *dev)
         return -1;
     }
 
-    u->slave_fd = sv[0];
-    qemu_set_fd_handler(u->slave_fd, slave_read, NULL, dev);
+    ioc = QIO_CHANNEL(qio_channel_socket_new_fd(sv[0], &local_err));
+    if (!ioc) {
+        error_report_err(local_err);
+        return -1;
+    }
+    u->slave_ioc = ioc;
+    slave_update_read_handler(dev, NULL);
 
     if (reply_supported) {
         msg.hdr.flags |= VHOST_USER_NEED_REPLY_MASK;
@@ -1565,9 +1629,7 @@ static int vhost_setup_slave_channel(struct vhost_dev *dev)
 out:
     close(sv[1]);
     if (ret) {
-        qemu_set_fd_handler(u->slave_fd, NULL, NULL, NULL);
-        close(u->slave_fd);
-        u->slave_fd = -1;
+        close_slave_channel(u);
     }
 
     return ret;
@@ -1804,7 +1866,6 @@ static int vhost_user_backend_init(struct vhost_dev *dev, void *opaque)
 
     u = g_new0(struct vhost_user, 1);
     u->user = opaque;
-    u->slave_fd = -1;
     u->dev = dev;
     dev->opaque = u;
 
@@ -1919,10 +1980,8 @@ static int vhost_user_backend_cleanup(struct vhost_dev *dev)
         close(u->postcopy_fd.fd);
         u->postcopy_fd.handler = NULL;
     }
-    if (u->slave_fd >= 0) {
-        qemu_set_fd_handler(u->slave_fd, NULL, NULL, NULL);
-        close(u->slave_fd);
-        u->slave_fd = -1;
+    if (u->slave_ioc) {
+        close_slave_channel(u);
     }
     g_free(u->region_rb);
     u->region_rb = NULL;