summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--nbd.c33
-rw-r--r--nbd.h1
2 files changed, 28 insertions, 6 deletions
diff --git a/nbd.c b/nbd.c
index 4aeb80ae18..eb72f4a6e5 100644
--- a/nbd.c
+++ b/nbd.c
@@ -109,6 +109,7 @@ struct NBDClient {
     Coroutine *send_coroutine;
 
     int nb_requests;
+    bool closing;
 };
 
 /* That's all folks */
@@ -655,19 +656,35 @@ void nbd_client_get(NBDClient *client)
 void nbd_client_put(NBDClient *client)
 {
     if (--client->refcount == 0) {
+        /* The last reference should be dropped by client->close,
+         * which is called by nbd_client_close.
+         */
+        assert(client->closing);
+
+        qemu_set_fd_handler2(client->sock, NULL, NULL, NULL, NULL);
+        close(client->sock);
+        client->sock = -1;
         g_free(client);
     }
 }
 
-static void nbd_client_close(NBDClient *client)
+void nbd_client_close(NBDClient *client)
 {
-    qemu_set_fd_handler2(client->sock, NULL, NULL, NULL, NULL);
-    close(client->sock);
-    client->sock = -1;
+    if (client->closing) {
+        return;
+    }
+
+    client->closing = true;
+
+    /* Force requests to finish.  They will drop their own references,
+     * then we'll close the socket and free the NBDClient.
+     */
+    shutdown(client->sock, 2);
+
+    /* Also tell the client, so that they release their reference.  */
     if (client->close) {
         client->close(client);
     }
-    nbd_client_put(client);
 }
 
 static NBDRequest *nbd_request_get(NBDClient *client)
@@ -810,14 +827,18 @@ out:
 static void nbd_trip(void *opaque)
 {
     NBDClient *client = opaque;
-    NBDRequest *req = nbd_request_get(client);
     NBDExport *exp = client->exp;
+    NBDRequest *req;
     struct nbd_request request;
     struct nbd_reply reply;
     ssize_t ret;
 
     TRACE("Reading request.");
+    if (client->closing) {
+        return;
+    }
 
+    req = nbd_request_get(client);
     ret = nbd_co_receive_request(req, &request);
     if (ret == -EAGAIN) {
         goto done;
diff --git a/nbd.h b/nbd.h
index a9038dc196..8b84a50ed4 100644
--- a/nbd.h
+++ b/nbd.h
@@ -84,6 +84,7 @@ void nbd_export_close(NBDExport *exp);
 
 NBDClient *nbd_client_new(NBDExport *exp, int csock,
                           void (*close)(NBDClient *));
+void nbd_client_close(NBDClient *client);
 void nbd_client_get(NBDClient *client);
 void nbd_client_put(NBDClient *client);