summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--nbd.c74
1 files changed, 53 insertions, 21 deletions
diff --git a/nbd.c b/nbd.c
index ca18c10a19..6d7d1f8d59 100644
--- a/nbd.c
+++ b/nbd.c
@@ -20,6 +20,8 @@
 #include "block.h"
 #include "block_int.h"
 
+#include "qemu-coroutine.h"
+
 #include <errno.h>
 #include <string.h>
 #ifndef _WIN32
@@ -607,6 +609,11 @@ struct NBDClient {
 
     NBDExport *exp;
     int sock;
+
+    Coroutine *recv_coroutine;
+
+    CoMutex send_lock;
+    Coroutine *send_coroutine;
 };
 
 static void nbd_client_get(NBDClient *client)
@@ -681,13 +688,20 @@ void nbd_export_close(NBDExport *exp)
     g_free(exp);
 }
 
-static int nbd_do_send_reply(NBDRequest *req, struct nbd_reply *reply,
+static void nbd_read(void *opaque);
+static void nbd_restart_write(void *opaque);
+
+static int nbd_co_send_reply(NBDRequest *req, struct nbd_reply *reply,
                              int len)
 {
     NBDClient *client = req->client;
     int csock = client->sock;
     int rc, ret;
 
+    qemu_co_mutex_lock(&client->send_lock);
+    qemu_set_fd_handler2(csock, NULL, nbd_read, nbd_restart_write, client);
+    client->send_coroutine = qemu_coroutine_self();
+
     if (!len) {
         rc = nbd_send_reply(csock, reply);
         if (rc == -1) {
@@ -697,7 +711,7 @@ static int nbd_do_send_reply(NBDRequest *req, struct nbd_reply *reply,
         socket_set_cork(csock, 1);
         rc = nbd_send_reply(csock, reply);
         if (rc != -1) {
-            ret = write_sync(csock, req->data, len);
+            ret = qemu_co_send(csock, req->data, len);
             if (ret != len) {
                 errno = EIO;
                 rc = -1;
@@ -708,15 +722,20 @@ static int nbd_do_send_reply(NBDRequest *req, struct nbd_reply *reply,
         }
         socket_set_cork(csock, 0);
     }
+
+    client->send_coroutine = NULL;
+    qemu_set_fd_handler2(csock, NULL, nbd_read, NULL, client);
+    qemu_co_mutex_unlock(&client->send_lock);
     return rc;
 }
 
-static int nbd_do_receive_request(NBDRequest *req, struct nbd_request *request)
+static int nbd_co_receive_request(NBDRequest *req, struct nbd_request *request)
 {
     NBDClient *client = req->client;
     int csock = client->sock;
     int rc;
 
+    client->recv_coroutine = qemu_coroutine_self();
     if (nbd_receive_request(csock, request) == -1) {
         rc = -EIO;
         goto out;
@@ -741,7 +760,7 @@ static int nbd_do_receive_request(NBDRequest *req, struct nbd_request *request)
     if ((request->type & NBD_CMD_MASK_COMMAND) == NBD_CMD_WRITE) {
         TRACE("Reading %u byte(s)", request->len);
 
-        if (read_sync(csock, req->data, request->len) != request->len) {
+        if (qemu_co_recv(csock, req->data, request->len) != request->len) {
             LOG("reading from socket failed");
             rc = -EIO;
             goto out;
@@ -750,21 +769,22 @@ static int nbd_do_receive_request(NBDRequest *req, struct nbd_request *request)
     rc = 0;
 
 out:
+    client->recv_coroutine = NULL;
     return rc;
 }
 
-static int nbd_trip(NBDClient *client)
+static void nbd_trip(void *opaque)
 {
+    NBDClient *client = opaque;
     NBDRequest *req = nbd_request_get(client);
     NBDExport *exp = client->exp;
     struct nbd_request request;
     struct nbd_reply reply;
-    int rc = -1;
     int ret;
 
     TRACE("Reading request.");
 
-    ret = nbd_do_receive_request(req, &request);
+    ret = nbd_co_receive_request(req, &request);
     if (ret == -EIO) {
         goto out;
     }
@@ -799,7 +819,7 @@ static int nbd_trip(NBDClient *client)
         }
 
         TRACE("Read %u byte(s)", request.len);
-        if (nbd_do_send_reply(req, &reply, request.len) < 0)
+        if (nbd_co_send_reply(req, &reply, request.len) < 0)
             goto out;
         break;
     case NBD_CMD_WRITE:
@@ -822,7 +842,7 @@ static int nbd_trip(NBDClient *client)
         }
 
         if (request.type & NBD_CMD_FLAG_FUA) {
-            ret = bdrv_flush(exp->bs);
+            ret = bdrv_co_flush(exp->bs);
             if (ret < 0) {
                 LOG("flush failed");
                 reply.error = -ret;
@@ -830,34 +850,34 @@ static int nbd_trip(NBDClient *client)
             }
         }
 
-        if (nbd_do_send_reply(req, &reply, 0) < 0)
+        if (nbd_co_send_reply(req, &reply, 0) < 0)
             goto out;
         break;
     case NBD_CMD_DISC:
         TRACE("Request type is DISCONNECT");
         errno = 0;
-        return 1;
+        goto out;
     case NBD_CMD_FLUSH:
         TRACE("Request type is FLUSH");
 
-        ret = bdrv_flush(exp->bs);
+        ret = bdrv_co_flush(exp->bs);
         if (ret < 0) {
             LOG("flush failed");
             reply.error = -ret;
         }
 
-        if (nbd_do_send_reply(req, &reply, 0) < 0)
+        if (nbd_co_send_reply(req, &reply, 0) < 0)
             goto out;
         break;
     case NBD_CMD_TRIM:
         TRACE("Request type is TRIM");
-        ret = bdrv_discard(exp->bs, (request.from + exp->dev_offset) / 512,
-                           request.len / 512);
+        ret = bdrv_co_discard(exp->bs, (request.from + exp->dev_offset) / 512,
+                              request.len / 512);
         if (ret < 0) {
             LOG("discard failed");
             reply.error = -ret;
         }
-        if (nbd_do_send_reply(req, &reply, 0) < 0)
+        if (nbd_co_send_reply(req, &reply, 0) < 0)
             goto out;
         break;
     default:
@@ -865,28 +885,39 @@ static int nbd_trip(NBDClient *client)
     invalid_request:
         reply.error = -EINVAL;
     error_reply:
-        if (nbd_do_send_reply(req, &reply, 0) == -1)
+        if (nbd_co_send_reply(req, &reply, 0) == -1)
             goto out;
         break;
     }
 
     TRACE("Request/Reply complete");
 
-    rc = 0;
+    nbd_request_put(req);
+    return;
+
 out:
     nbd_request_put(req);
-    return rc;
+    nbd_client_close(client);
 }
 
 static void nbd_read(void *opaque)
 {
     NBDClient *client = opaque;
 
-    if (nbd_trip(client) != 0) {
-        nbd_client_close(client);
+    if (client->recv_coroutine) {
+        qemu_coroutine_enter(client->recv_coroutine, NULL);
+    } else {
+        qemu_coroutine_enter(qemu_coroutine_create(nbd_trip), client);
     }
 }
 
+static void nbd_restart_write(void *opaque)
+{
+    NBDClient *client = opaque;
+
+    qemu_coroutine_enter(client->send_coroutine, NULL);
+}
+
 NBDClient *nbd_client_new(NBDExport *exp, int csock,
                           void (*close)(NBDClient *))
 {
@@ -899,6 +930,7 @@ NBDClient *nbd_client_new(NBDExport *exp, int csock,
     client->exp = exp;
     client->sock = csock;
     client->close = close;
+    qemu_co_mutex_init(&client->send_lock);
     qemu_set_fd_handler2(csock, NULL, nbd_read, NULL, client);
     return client;
 }