aboutsummaryrefslogtreecommitdiff
path: root/common/connection.c
diff options
context:
space:
mode:
Diffstat (limited to 'common/connection.c')
-rw-r--r--common/connection.c312
1 files changed, 312 insertions, 0 deletions
diff --git a/common/connection.c b/common/connection.c
new file mode 100644
index 0000000..2545e5d
--- /dev/null
+++ b/common/connection.c
@@ -0,0 +1,312 @@
+#include <errno.h>
+#include <stddef.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <sys/socket.h>
+#include <unistd.h>
+
+#include "compiler.h"
+#include "connection.h"
+
+#define CLEN (CMSG_LEN(MAX_FDS_OUT * sizeof(int)))
+
+ALWAYS_INLINE static uint32_t connection_buffer_mask(const uint32_t idx) {
+ return idx & (CONNECTION_BUFFER_SIZE - 1);
+}
+
+ALWAYS_INLINE static uint32_t connection_buffer_size(const struct connection_buffer *b) {
+ return b->head - b->tail;
+}
+
+ALWAYS_INLINE static void connection_buffer_consume(struct connection_buffer *b, const size_t size) {
+ b->tail += size;
+}
+
+ALWAYS_INLINE static void connection_buffer_restore(struct connection_buffer *b, const size_t size) {
+ b->tail -= size;
+}
+
+/*
+ * connection_buffer_get_iov prepares I/O vectors pointing to our ring buffer.
+ * Two may be used if the buffer has wrapped around.
+ */
+static void connection_buffer_get_iov(struct connection_buffer *b, struct iovec *iov, int *count) {
+ uint32_t head = connection_buffer_mask(b->head);
+ uint32_t tail = connection_buffer_mask(b->tail);
+ if (tail < head) {
+ iov[0].iov_base = b->data + tail;
+ iov[0].iov_len = head - tail;
+ *count = 1;
+ } else if (head == 0) {
+ iov[0].iov_base = b->data + tail;
+ iov[0].iov_len = sizeof b->data - tail;
+ *count = 1;
+ } else {
+ iov[0].iov_base = b->data + tail;
+ iov[0].iov_len = sizeof b->data - tail;
+ iov[1].iov_base = b->data;
+ iov[1].iov_len = head;
+ *count = 2;
+ }
+}
+
+/*
+ * connection_buffer_put_iov prepares I/O vectors pointing to our ring buffer.
+ * Two may be used if the buffer has wrapped around.
+ */
+static void connection_buffer_put_iov(struct connection_buffer *b, struct iovec *iov, int *count) {
+ uint32_t head = connection_buffer_mask(b->head);
+ uint32_t tail = connection_buffer_mask(b->tail);
+ if (head < tail) {
+ iov[0].iov_base = b->data + head;
+ iov[0].iov_len = tail - head;
+ *count = 1;
+ } else if (tail == 0) {
+ iov[0].iov_base = b->data + head;
+ iov[0].iov_len = sizeof b->data - head;
+ *count = 1;
+ } else {
+ iov[0].iov_base = b->data + head;
+ iov[0].iov_len = sizeof b->data - head;
+ iov[1].iov_base = b->data;
+ iov[1].iov_len = tail;
+ *count = 2;
+ }
+}
+
+/*
+ * connection_buffer_copy copies from our ring buffer into a linear buffer.
+ */
+static void connection_buffer_copy(const struct connection_buffer *b, void *data, const size_t count) {
+ uint32_t tail = connection_buffer_mask(b->tail);
+ if (tail + count <= sizeof b->data) {
+ memcpy(data, b->data + tail, count);
+ return;
+ }
+
+ uint32_t size = sizeof b->data - tail;
+ memcpy(data, b->data + tail, size);
+ memcpy((char *)data + size, b->data, count - size);
+}
+/*
+ * connection_buffer_copy copies from a linear buffer into our ring buffer.
+ */
+static int connection_buffer_put(struct connection_buffer *b, const void *data, const size_t count) {
+ if (count > sizeof(b->data)) {
+ errno = EOVERFLOW;
+ return -1;
+ }
+
+ uint32_t head = connection_buffer_mask(b->head);
+ if (head + count <= sizeof b->data) {
+ memcpy(b->data + head, data, count);
+ } else {
+ uint32_t size = sizeof b->data - head;
+ memcpy(b->data + head, data, size);
+ memcpy(b->data, (const char *)data + size, count - size);
+ }
+
+ b->head += count;
+ return 0;
+}
+
+/*
+ * close_fds closes all fds within a connection_buffer
+ */
+static void connection_buffer_close_fds(struct connection_buffer *buffer) {
+ size_t size = connection_buffer_size(buffer);
+ if (size == 0) {
+ return;
+ }
+ int fds[sizeof(buffer->data) / sizeof(int)];
+ connection_buffer_copy(buffer, fds, size);
+ int count = size / sizeof fds[0];
+ size = count * sizeof fds[0];
+ for (int idx = 0; idx < count; idx++) {
+ close(fds[idx]);
+ }
+ connection_buffer_consume(buffer, size);
+}
+
+/*
+ * build_cmsg prepares a cmsg from a buffer full of fds
+ */
+static void build_cmsg(struct connection_buffer *buffer, char *data, int *clen) {
+ size_t size = connection_buffer_size(buffer);
+ if (size > MAX_FDS_OUT * sizeof(int)) {
+ size = MAX_FDS_OUT * sizeof(int);
+ }
+
+ if (size <= 0) {
+ *clen = 0;
+ return;
+ }
+
+ struct cmsghdr *cmsg = (struct cmsghdr *)data;
+ cmsg->cmsg_level = SOL_SOCKET;
+ cmsg->cmsg_type = SCM_RIGHTS;
+ cmsg->cmsg_len = CMSG_LEN(size);
+ connection_buffer_copy(buffer, CMSG_DATA(cmsg), size);
+ *clen = cmsg->cmsg_len;
+}
+
+static int decode_cmsg(struct connection_buffer *buffer, struct msghdr *msg) {
+ bool overflow = false;
+ struct cmsghdr *cmsg;
+ for (cmsg = CMSG_FIRSTHDR(msg); cmsg != NULL; cmsg = CMSG_NXTHDR(msg, cmsg)) {
+ if (cmsg->cmsg_level != SOL_SOCKET || cmsg->cmsg_type != SCM_RIGHTS) {
+ continue;
+ }
+
+ size_t size = cmsg->cmsg_len - CMSG_LEN(0);
+ size_t max = sizeof(buffer->data) - connection_buffer_size(buffer);
+ if (size > max || overflow) {
+ overflow = true;
+ size /= sizeof(int);
+ for (size_t idx = 0; idx < size; idx++) {
+ close(((int *)CMSG_DATA(cmsg))[idx]);
+ }
+ } else if (connection_buffer_put(buffer, CMSG_DATA(cmsg), size) < 0) {
+ return -1;
+ }
+ }
+
+ if (overflow) {
+ errno = EOVERFLOW;
+ return -1;
+ }
+ return 0;
+}
+
+int connection_read(struct connection *connection) {
+ if (connection_buffer_size(&connection->in) >= sizeof(connection->in.data)) {
+ errno = EOVERFLOW;
+ return -1;
+ }
+
+ int count;
+ struct iovec iov[2];
+ connection_buffer_put_iov(&connection->in, iov, &count);
+
+ char cmsg[CLEN];
+ struct msghdr msg = {
+ .msg_name = NULL,
+ .msg_namelen = 0,
+ .msg_iov = iov,
+ .msg_iovlen = count,
+ .msg_control = cmsg,
+ .msg_controllen = sizeof cmsg,
+ .msg_flags = 0,
+ };
+
+ int len;
+ do {
+ len = recvmsg(connection->fd, &msg, MSG_DONTWAIT | MSG_CMSG_CLOEXEC);
+ if (len == -1 && errno != EINTR)
+ return -1;
+ } while (len == -1);
+
+ if (decode_cmsg(&connection->fds_in, &msg) != 0) {
+ return -1;
+ }
+ connection->in.head += len;
+
+ return connection_buffer_size(&connection->in);
+}
+
+int connection_flush(struct connection *connection) {
+ if (!connection->want_flush) {
+ return 0;
+ }
+
+ uint32_t tail = connection->out.tail;
+ while (connection->out.head - connection->out.tail > 0) {
+ int count;
+ struct iovec iov[2];
+ connection_buffer_get_iov(&connection->out, iov, &count);
+
+ int clen;
+ char cmsg[CLEN];
+ build_cmsg(&connection->fds_out, cmsg, &clen);
+ struct msghdr msg = {
+ .msg_name = NULL,
+ .msg_namelen = 0,
+ .msg_iov = iov,
+ .msg_iovlen = count,
+ .msg_control = (clen > 0) ? cmsg : NULL,
+ .msg_controllen = clen,
+ .msg_flags = 0,
+ };
+
+ int len;
+ do {
+ len = sendmsg(connection->fd, &msg, MSG_NOSIGNAL | MSG_DONTWAIT);
+ if (len == -1 && errno != EINTR)
+ return -1;
+ } while (len == -1);
+ connection_buffer_close_fds(&connection->fds_out);
+ connection->out.tail += len;
+ }
+ connection->want_flush = 0;
+ return connection->out.head - tail;
+}
+
+int connection_put(struct connection *connection, const void *data, size_t count) {
+ if (connection_buffer_size(&connection->out) + count > CONNECTION_BUFFER_SIZE) {
+ connection->want_flush = 1;
+ if (connection_flush(connection) == -1) {
+ return -1;
+ }
+ }
+
+ if (connection_buffer_put(&connection->out, data, count) == -1) {
+ return -1;
+ }
+
+ connection->want_flush = 1;
+ return 0;
+}
+
+int connection_put_fd(struct connection *connection, int fd) {
+ if (connection_buffer_size(&connection->fds_out) == MAX_FDS_OUT * sizeof fd) {
+ errno = EOVERFLOW;
+ return -1;
+ }
+
+ return connection_buffer_put(&connection->fds_out, &fd, sizeof fd);
+}
+
+int connection_get(struct connection *connection, void *dst, size_t count) {
+ if (count > connection_buffer_size(&connection->in)) {
+ errno = EAGAIN;
+ return -1;
+ }
+ connection_buffer_copy(&connection->in, dst, count);
+ connection_buffer_consume(&connection->in, count);
+ return count;
+}
+
+int connection_get_fd(struct connection *connection) {
+ int fd;
+ if (sizeof fd > connection_buffer_size(&connection->fds_in)) {
+ errno = EAGAIN;
+ return -1;
+ }
+ connection_buffer_copy(&connection->fds_in, &fd, sizeof fd);
+ connection_buffer_consume(&connection->fds_in, sizeof fd);
+ return fd;
+}
+
+void connection_close_fds(struct connection *connection) {
+ connection_buffer_close_fds(&connection->fds_in);
+ connection_buffer_close_fds(&connection->fds_out);
+}
+
+size_t connection_pending(struct connection *connection) {
+ return connection_buffer_size(&connection->in);
+}
+
+void connection_restore(struct connection *connection, size_t count) {
+ connection_buffer_restore(&connection->in, count);
+}