summary refs log tree commit diff stats
path: root/contrib/libvhost-user/libvhost-user.c
diff options
context:
space:
mode:
Diffstat (limited to 'contrib/libvhost-user/libvhost-user.c')
-rw-r--r--contrib/libvhost-user/libvhost-user.c139
1 files changed, 126 insertions, 13 deletions
diff --git a/contrib/libvhost-user/libvhost-user.c b/contrib/libvhost-user/libvhost-user.c
index b89bf18501..3bca996c62 100644
--- a/contrib/libvhost-user/libvhost-user.c
+++ b/contrib/libvhost-user/libvhost-user.c
@@ -136,6 +136,7 @@ vu_request_to_string(unsigned int req)
         REQ(VHOST_USER_GET_INFLIGHT_FD),
         REQ(VHOST_USER_SET_INFLIGHT_FD),
         REQ(VHOST_USER_GPU_SET_SOCKET),
+        REQ(VHOST_USER_VRING_KICK),
         REQ(VHOST_USER_MAX),
     };
 #undef REQ
@@ -163,7 +164,10 @@ vu_panic(VuDev *dev, const char *msg, ...)
     dev->panic(dev, buf);
     free(buf);
 
-    /* FIXME: find a way to call virtio_error? */
+    /*
+     * FIXME:
+     * find a way to call virtio_error, or perhaps close the connection?
+     */
 }
 
 /* Translate guest physical address to our virtual address.  */
@@ -948,6 +952,7 @@ static bool
 vu_check_queue_msg_file(VuDev *dev, VhostUserMsg *vmsg)
 {
     int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
+    bool nofd = vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK;
 
     if (index >= dev->max_queues) {
         vmsg_close_fds(vmsg);
@@ -955,8 +960,12 @@ vu_check_queue_msg_file(VuDev *dev, VhostUserMsg *vmsg)
         return false;
     }
 
-    if (vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK ||
-        vmsg->fd_num != 1) {
+    if (nofd) {
+        vmsg_close_fds(vmsg);
+        return true;
+    }
+
+    if (vmsg->fd_num != 1) {
         vmsg_close_fds(vmsg);
         vu_panic(dev, "Invalid fds in request: %d", vmsg->request);
         return false;
@@ -1053,6 +1062,7 @@ static bool
 vu_set_vring_kick_exec(VuDev *dev, VhostUserMsg *vmsg)
 {
     int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
+    bool nofd = vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK;
 
     DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);
 
@@ -1066,8 +1076,8 @@ vu_set_vring_kick_exec(VuDev *dev, VhostUserMsg *vmsg)
         dev->vq[index].kick_fd = -1;
     }
 
-    dev->vq[index].kick_fd = vmsg->fds[0];
-    DPRINT("Got kick_fd: %d for vq: %d\n", vmsg->fds[0], index);
+    dev->vq[index].kick_fd = nofd ? -1 : vmsg->fds[0];
+    DPRINT("Got kick_fd: %d for vq: %d\n", dev->vq[index].kick_fd, index);
 
     dev->vq[index].started = true;
     if (dev->iface->queue_set_started) {
@@ -1147,6 +1157,7 @@ static bool
 vu_set_vring_call_exec(VuDev *dev, VhostUserMsg *vmsg)
 {
     int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
+    bool nofd = vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK;
 
     DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);
 
@@ -1159,14 +1170,14 @@ vu_set_vring_call_exec(VuDev *dev, VhostUserMsg *vmsg)
         dev->vq[index].call_fd = -1;
     }
 
-    dev->vq[index].call_fd = vmsg->fds[0];
+    dev->vq[index].call_fd = nofd ? -1 : vmsg->fds[0];
 
     /* in case of I/O hang after reconnecting */
-    if (eventfd_write(vmsg->fds[0], 1)) {
+    if (dev->vq[index].call_fd != -1 && eventfd_write(vmsg->fds[0], 1)) {
         return -1;
     }
 
-    DPRINT("Got call_fd: %d for vq: %d\n", vmsg->fds[0], index);
+    DPRINT("Got call_fd: %d for vq: %d\n", dev->vq[index].call_fd, index);
 
     return false;
 }
@@ -1175,6 +1186,7 @@ static bool
 vu_set_vring_err_exec(VuDev *dev, VhostUserMsg *vmsg)
 {
     int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
+    bool nofd = vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK;
 
     DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);
 
@@ -1187,7 +1199,7 @@ vu_set_vring_err_exec(VuDev *dev, VhostUserMsg *vmsg)
         dev->vq[index].err_fd = -1;
     }
 
-    dev->vq[index].err_fd = vmsg->fds[0];
+    dev->vq[index].err_fd = nofd ? -1 : vmsg->fds[0];
 
     return false;
 }
@@ -1195,11 +1207,20 @@ vu_set_vring_err_exec(VuDev *dev, VhostUserMsg *vmsg)
 static bool
 vu_get_protocol_features_exec(VuDev *dev, VhostUserMsg *vmsg)
 {
+    /*
+     * Note that we support, but intentionally do not set,
+     * VHOST_USER_PROTOCOL_F_INBAND_NOTIFICATIONS. This means that
+     * a device implementation can return it in its callback
+     * (get_protocol_features) if it wants to use this for
+     * simulation, but it is otherwise not desirable (if even
+     * implemented by the master.)
+     */
     uint64_t features = 1ULL << VHOST_USER_PROTOCOL_F_MQ |
                         1ULL << VHOST_USER_PROTOCOL_F_LOG_SHMFD |
                         1ULL << VHOST_USER_PROTOCOL_F_SLAVE_REQ |
                         1ULL << VHOST_USER_PROTOCOL_F_HOST_NOTIFIER |
-                        1ULL << VHOST_USER_PROTOCOL_F_SLAVE_SEND_FD;
+                        1ULL << VHOST_USER_PROTOCOL_F_SLAVE_SEND_FD |
+                        1ULL << VHOST_USER_PROTOCOL_F_REPLY_ACK;
 
     if (have_userfault()) {
         features |= 1ULL << VHOST_USER_PROTOCOL_F_PAGEFAULT;
@@ -1226,6 +1247,25 @@ vu_set_protocol_features_exec(VuDev *dev, VhostUserMsg *vmsg)
 
     dev->protocol_features = vmsg->payload.u64;
 
+    if (vu_has_protocol_feature(dev,
+                                VHOST_USER_PROTOCOL_F_INBAND_NOTIFICATIONS) &&
+        (!vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_SLAVE_REQ) ||
+         !vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_REPLY_ACK))) {
+        /*
+         * The use case for using messages for kick/call is simulation, to make
+         * the kick and call synchronous. To actually get that behaviour, both
+         * of the other features are required.
+         * Theoretically, one could use only kick messages, or do them without
+         * having F_REPLY_ACK, but too many (possibly pending) messages on the
+         * socket will eventually cause the master to hang, to avoid this in
+         * scenarios where not desired enforce that the settings are in a way
+         * that actually enables the simulation case.
+         */
+        vu_panic(dev,
+                 "F_IN_BAND_NOTIFICATIONS requires F_SLAVE_REQ && F_REPLY_ACK");
+        return false;
+    }
+
     if (dev->iface->set_protocol_features) {
         dev->iface->set_protocol_features(dev, features);
     }
@@ -1487,6 +1527,34 @@ vu_set_inflight_fd(VuDev *dev, VhostUserMsg *vmsg)
 }
 
 static bool
+vu_handle_vring_kick(VuDev *dev, VhostUserMsg *vmsg)
+{
+    unsigned int index = vmsg->payload.state.index;
+
+    if (index >= dev->max_queues) {
+        vu_panic(dev, "Invalid queue index: %u", index);
+        return false;
+    }
+
+    DPRINT("Got kick message: handler:%p idx:%d\n",
+           dev->vq[index].handler, index);
+
+    if (!dev->vq[index].started) {
+        dev->vq[index].started = true;
+
+        if (dev->iface->queue_set_started) {
+            dev->iface->queue_set_started(dev, index, true);
+        }
+    }
+
+    if (dev->vq[index].handler) {
+        dev->vq[index].handler(dev, index);
+    }
+
+    return false;
+}
+
+static bool
 vu_process_message(VuDev *dev, VhostUserMsg *vmsg)
 {
     int do_reply = 0;
@@ -1568,6 +1636,8 @@ vu_process_message(VuDev *dev, VhostUserMsg *vmsg)
         return vu_get_inflight_fd(dev, vmsg);
     case VHOST_USER_SET_INFLIGHT_FD:
         return vu_set_inflight_fd(dev, vmsg);
+    case VHOST_USER_VRING_KICK:
+        return vu_handle_vring_kick(dev, vmsg);
     default:
         vmsg_close_fds(vmsg);
         vu_panic(dev, "Unhandled request: %d", vmsg->request);
@@ -1581,13 +1651,20 @@ vu_dispatch(VuDev *dev)
 {
     VhostUserMsg vmsg = { 0, };
     int reply_requested;
-    bool success = false;
+    bool need_reply, success = false;
 
     if (!vu_message_read(dev, dev->sock, &vmsg)) {
         goto end;
     }
 
+    need_reply = vmsg.flags & VHOST_USER_NEED_REPLY_MASK;
+
     reply_requested = vu_process_message(dev, &vmsg);
+    if (!reply_requested && need_reply) {
+        vmsg_set_reply_u64(&vmsg, 0);
+        reply_requested = 1;
+    }
+
     if (!reply_requested) {
         success = true;
         goto end;
@@ -2022,8 +2099,7 @@ vring_notify(VuDev *dev, VuVirtq *vq)
     return !v || vring_need_event(vring_get_used_event(vq), new, old);
 }
 
-void
-vu_queue_notify(VuDev *dev, VuVirtq *vq)
+static void _vu_queue_notify(VuDev *dev, VuVirtq *vq, bool sync)
 {
     if (unlikely(dev->broken) ||
         unlikely(!vq->vring.avail)) {
@@ -2035,11 +2111,48 @@ vu_queue_notify(VuDev *dev, VuVirtq *vq)
         return;
     }
 
+    if (vq->call_fd < 0 &&
+        vu_has_protocol_feature(dev,
+                                VHOST_USER_PROTOCOL_F_INBAND_NOTIFICATIONS) &&
+        vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_SLAVE_REQ)) {
+        VhostUserMsg vmsg = {
+            .request = VHOST_USER_SLAVE_VRING_CALL,
+            .flags = VHOST_USER_VERSION,
+            .size = sizeof(vmsg.payload.state),
+            .payload.state = {
+                .index = vq - dev->vq,
+            },
+        };
+        bool ack = sync &&
+                   vu_has_protocol_feature(dev,
+                                           VHOST_USER_PROTOCOL_F_REPLY_ACK);
+
+        if (ack) {
+            vmsg.flags |= VHOST_USER_NEED_REPLY_MASK;
+        }
+
+        vu_message_write(dev, dev->slave_fd, &vmsg);
+        if (ack) {
+            vu_message_read(dev, dev->slave_fd, &vmsg);
+        }
+        return;
+    }
+
     if (eventfd_write(vq->call_fd, 1) < 0) {
         vu_panic(dev, "Error writing eventfd: %s", strerror(errno));
     }
 }
 
+void vu_queue_notify(VuDev *dev, VuVirtq *vq)
+{
+    _vu_queue_notify(dev, vq, false);
+}
+
+void vu_queue_notify_sync(VuDev *dev, VuVirtq *vq)
+{
+    _vu_queue_notify(dev, vq, true);
+}
+
 static inline void
 vring_used_flags_set_bit(VuVirtq *vq, int mask)
 {