summary refs log tree commit diff stats
path: root/io
diff options
context:
space:
mode:
Diffstat (limited to 'io')
-rw-r--r--io/channel-tls.c92
-rw-r--r--io/channel.c9
-rw-r--r--io/trace-events5
3 files changed, 100 insertions, 6 deletions
diff --git a/io/channel-tls.c b/io/channel-tls.c
index aab630e5ae..caf8301a9e 100644
--- a/io/channel-tls.c
+++ b/io/channel-tls.c
@@ -162,16 +162,17 @@ static void qio_channel_tls_handshake_task(QIOChannelTLS *ioc,
                                            GMainContext *context)
 {
     Error *err = NULL;
-    QCryptoTLSSessionHandshakeStatus status;
+    int status;
 
-    if (qcrypto_tls_session_handshake(ioc->session, &err) < 0) {
+    status = qcrypto_tls_session_handshake(ioc->session, &err);
+
+    if (status < 0) {
         trace_qio_channel_tls_handshake_fail(ioc);
         qio_task_set_error(task, err);
         qio_task_complete(task);
         return;
     }
 
-    status = qcrypto_tls_session_get_handshake_status(ioc->session);
     if (status == QCRYPTO_TLS_HANDSHAKE_COMPLETE) {
         trace_qio_channel_tls_handshake_complete(ioc);
         if (qcrypto_tls_session_check_credentials(ioc->session,
@@ -247,6 +248,85 @@ void qio_channel_tls_handshake(QIOChannelTLS *ioc,
     qio_channel_tls_handshake_task(ioc, task, context);
 }
 
+static gboolean qio_channel_tls_bye_io(QIOChannel *ioc, GIOCondition condition,
+                                       gpointer user_data);
+
+static void qio_channel_tls_bye_task(QIOChannelTLS *ioc, QIOTask *task,
+                                     GMainContext *context)
+{
+    GIOCondition condition;
+    QIOChannelTLSData *data;
+    int status;
+    Error *err = NULL;
+
+    status = qcrypto_tls_session_bye(ioc->session, &err);
+
+    if (status < 0) {
+        trace_qio_channel_tls_bye_fail(ioc);
+        qio_task_set_error(task, err);
+        qio_task_complete(task);
+        return;
+    }
+
+    if (status == QCRYPTO_TLS_BYE_COMPLETE) {
+        qio_task_complete(task);
+        return;
+    }
+
+    data = g_new0(typeof(*data), 1);
+    data->task = task;
+    data->context = context;
+
+    if (context) {
+        g_main_context_ref(context);
+    }
+
+    if (status == QCRYPTO_TLS_BYE_SENDING) {
+        condition = G_IO_OUT;
+    } else {
+        condition = G_IO_IN;
+    }
+
+    trace_qio_channel_tls_bye_pending(ioc, status);
+    ioc->bye_ioc_tag = qio_channel_add_watch_full(ioc->master, condition,
+                                                  qio_channel_tls_bye_io,
+                                                  data, NULL, context);
+}
+
+
+static gboolean qio_channel_tls_bye_io(QIOChannel *ioc, GIOCondition condition,
+                                       gpointer user_data)
+{
+    QIOChannelTLSData *data = user_data;
+    QIOTask *task = data->task;
+    GMainContext *context = data->context;
+    QIOChannelTLS *tioc = QIO_CHANNEL_TLS(qio_task_get_source(task));
+
+    tioc->bye_ioc_tag = 0;
+    g_free(data);
+    qio_channel_tls_bye_task(tioc, task, context);
+
+    if (context) {
+        g_main_context_unref(context);
+    }
+
+    return FALSE;
+}
+
+static void propagate_error(QIOTask *task, gpointer opaque)
+{
+    qio_task_propagate_error(task, opaque);
+}
+
+void qio_channel_tls_bye(QIOChannelTLS *ioc, Error **errp)
+{
+    QIOTask *task;
+
+    task = qio_task_new(OBJECT(ioc), propagate_error, errp, NULL);
+
+    trace_qio_channel_tls_bye_start(ioc);
+    qio_channel_tls_bye_task(ioc, task, NULL);
+}
 
 static void qio_channel_tls_init(Object *obj G_GNUC_UNUSED)
 {
@@ -279,6 +359,7 @@ static ssize_t qio_channel_tls_readv(QIOChannel *ioc,
             tioc->session,
             iov[i].iov_base,
             iov[i].iov_len,
+            flags & QIO_CHANNEL_READ_FLAG_RELAXED_EOF ||
             qatomic_load_acquire(&tioc->shutdown) & QIO_CHANNEL_SHUTDOWN_READ,
             errp);
         if (ret == QCRYPTO_TLS_SESSION_ERR_BLOCK) {
@@ -379,6 +460,11 @@ static int qio_channel_tls_close(QIOChannel *ioc,
         g_clear_handle_id(&tioc->hs_ioc_tag, g_source_remove);
     }
 
+    if (tioc->bye_ioc_tag) {
+        trace_qio_channel_tls_bye_cancel(ioc);
+        g_clear_handle_id(&tioc->bye_ioc_tag, g_source_remove);
+    }
+
     return qio_channel_close(tioc->master, errp);
 }
 
diff --git a/io/channel.c b/io/channel.c
index e3f17c24a0..ebd9322765 100644
--- a/io/channel.c
+++ b/io/channel.c
@@ -115,7 +115,8 @@ int coroutine_mixed_fn qio_channel_readv_all_eof(QIOChannel *ioc,
                                                  size_t niov,
                                                  Error **errp)
 {
-    return qio_channel_readv_full_all_eof(ioc, iov, niov, NULL, NULL, errp);
+    return qio_channel_readv_full_all_eof(ioc, iov, niov, NULL, NULL, 0,
+                                          errp);
 }
 
 int coroutine_mixed_fn qio_channel_readv_all(QIOChannel *ioc,
@@ -130,6 +131,7 @@ int coroutine_mixed_fn qio_channel_readv_full_all_eof(QIOChannel *ioc,
                                                       const struct iovec *iov,
                                                       size_t niov,
                                                       int **fds, size_t *nfds,
+                                                      int flags,
                                                       Error **errp)
 {
     int ret = -1;
@@ -155,7 +157,7 @@ int coroutine_mixed_fn qio_channel_readv_full_all_eof(QIOChannel *ioc,
     while ((nlocal_iov > 0) || local_fds) {
         ssize_t len;
         len = qio_channel_readv_full(ioc, local_iov, nlocal_iov, local_fds,
-                                     local_nfds, 0, errp);
+                                     local_nfds, flags, errp);
         if (len == QIO_CHANNEL_ERR_BLOCK) {
             if (qemu_in_coroutine()) {
                 qio_channel_yield(ioc, G_IO_IN);
@@ -222,7 +224,8 @@ int coroutine_mixed_fn qio_channel_readv_full_all(QIOChannel *ioc,
                                                   int **fds, size_t *nfds,
                                                   Error **errp)
 {
-    int ret = qio_channel_readv_full_all_eof(ioc, iov, niov, fds, nfds, errp);
+    int ret = qio_channel_readv_full_all_eof(ioc, iov, niov, fds, nfds, 0,
+                                             errp);
 
     if (ret == 0) {
         error_setg(errp, "Unexpected end-of-file before all data were read");
diff --git a/io/trace-events b/io/trace-events
index d4c0f84a9a..dc3a63ba1f 100644
--- a/io/trace-events
+++ b/io/trace-events
@@ -44,6 +44,11 @@ qio_channel_tls_handshake_pending(void *ioc, int status) "TLS handshake pending
 qio_channel_tls_handshake_fail(void *ioc) "TLS handshake fail ioc=%p"
 qio_channel_tls_handshake_complete(void *ioc) "TLS handshake complete ioc=%p"
 qio_channel_tls_handshake_cancel(void *ioc) "TLS handshake cancel ioc=%p"
+qio_channel_tls_bye_start(void *ioc) "TLS termination start ioc=%p"
+qio_channel_tls_bye_pending(void *ioc, int status) "TLS termination pending ioc=%p status=%d"
+qio_channel_tls_bye_fail(void *ioc) "TLS termination fail ioc=%p"
+qio_channel_tls_bye_complete(void *ioc) "TLS termination complete ioc=%p"
+qio_channel_tls_bye_cancel(void *ioc) "TLS termination cancel ioc=%p"
 qio_channel_tls_credentials_allow(void *ioc) "TLS credentials allow ioc=%p"
 qio_channel_tls_credentials_deny(void *ioc) "TLS credentials deny ioc=%p"