diff options
Diffstat (limited to 'common/connection.c')
-rw-r--r-- | common/connection.c | 312 |
1 files changed, 312 insertions, 0 deletions
diff --git a/common/connection.c b/common/connection.c new file mode 100644 index 0000000..2545e5d --- /dev/null +++ b/common/connection.c @@ -0,0 +1,312 @@ +#include <errno.h> +#include <stddef.h> +#include <stdio.h> +#include <stdlib.h> +#include <string.h> +#include <sys/socket.h> +#include <unistd.h> + +#include "compiler.h" +#include "connection.h" + +#define CLEN (CMSG_LEN(MAX_FDS_OUT * sizeof(int))) + +ALWAYS_INLINE static uint32_t connection_buffer_mask(const uint32_t idx) { + return idx & (CONNECTION_BUFFER_SIZE - 1); +} + +ALWAYS_INLINE static uint32_t connection_buffer_size(const struct connection_buffer *b) { + return b->head - b->tail; +} + +ALWAYS_INLINE static void connection_buffer_consume(struct connection_buffer *b, const size_t size) { + b->tail += size; +} + +ALWAYS_INLINE static void connection_buffer_restore(struct connection_buffer *b, const size_t size) { + b->tail -= size; +} + +/* + * connection_buffer_get_iov prepares I/O vectors pointing to our ring buffer. + * Two may be used if the buffer has wrapped around. + */ +static void connection_buffer_get_iov(struct connection_buffer *b, struct iovec *iov, int *count) { + uint32_t head = connection_buffer_mask(b->head); + uint32_t tail = connection_buffer_mask(b->tail); + if (tail < head) { + iov[0].iov_base = b->data + tail; + iov[0].iov_len = head - tail; + *count = 1; + } else if (head == 0) { + iov[0].iov_base = b->data + tail; + iov[0].iov_len = sizeof b->data - tail; + *count = 1; + } else { + iov[0].iov_base = b->data + tail; + iov[0].iov_len = sizeof b->data - tail; + iov[1].iov_base = b->data; + iov[1].iov_len = head; + *count = 2; + } +} + +/* + * connection_buffer_put_iov prepares I/O vectors pointing to our ring buffer. + * Two may be used if the buffer has wrapped around. + */ +static void connection_buffer_put_iov(struct connection_buffer *b, struct iovec *iov, int *count) { + uint32_t head = connection_buffer_mask(b->head); + uint32_t tail = connection_buffer_mask(b->tail); + if (head < tail) { + iov[0].iov_base = b->data + head; + iov[0].iov_len = tail - head; + *count = 1; + } else if (tail == 0) { + iov[0].iov_base = b->data + head; + iov[0].iov_len = sizeof b->data - head; + *count = 1; + } else { + iov[0].iov_base = b->data + head; + iov[0].iov_len = sizeof b->data - head; + iov[1].iov_base = b->data; + iov[1].iov_len = tail; + *count = 2; + } +} + +/* + * connection_buffer_copy copies from our ring buffer into a linear buffer. + */ +static void connection_buffer_copy(const struct connection_buffer *b, void *data, const size_t count) { + uint32_t tail = connection_buffer_mask(b->tail); + if (tail + count <= sizeof b->data) { + memcpy(data, b->data + tail, count); + return; + } + + uint32_t size = sizeof b->data - tail; + memcpy(data, b->data + tail, size); + memcpy((char *)data + size, b->data, count - size); +} +/* + * connection_buffer_copy copies from a linear buffer into our ring buffer. + */ +static int connection_buffer_put(struct connection_buffer *b, const void *data, const size_t count) { + if (count > sizeof(b->data)) { + errno = EOVERFLOW; + return -1; + } + + uint32_t head = connection_buffer_mask(b->head); + if (head + count <= sizeof b->data) { + memcpy(b->data + head, data, count); + } else { + uint32_t size = sizeof b->data - head; + memcpy(b->data + head, data, size); + memcpy(b->data, (const char *)data + size, count - size); + } + + b->head += count; + return 0; +} + +/* + * close_fds closes all fds within a connection_buffer + */ +static void connection_buffer_close_fds(struct connection_buffer *buffer) { + size_t size = connection_buffer_size(buffer); + if (size == 0) { + return; + } + int fds[sizeof(buffer->data) / sizeof(int)]; + connection_buffer_copy(buffer, fds, size); + int count = size / sizeof fds[0]; + size = count * sizeof fds[0]; + for (int idx = 0; idx < count; idx++) { + close(fds[idx]); + } + connection_buffer_consume(buffer, size); +} + +/* + * build_cmsg prepares a cmsg from a buffer full of fds + */ +static void build_cmsg(struct connection_buffer *buffer, char *data, int *clen) { + size_t size = connection_buffer_size(buffer); + if (size > MAX_FDS_OUT * sizeof(int)) { + size = MAX_FDS_OUT * sizeof(int); + } + + if (size <= 0) { + *clen = 0; + return; + } + + struct cmsghdr *cmsg = (struct cmsghdr *)data; + cmsg->cmsg_level = SOL_SOCKET; + cmsg->cmsg_type = SCM_RIGHTS; + cmsg->cmsg_len = CMSG_LEN(size); + connection_buffer_copy(buffer, CMSG_DATA(cmsg), size); + *clen = cmsg->cmsg_len; +} + +static int decode_cmsg(struct connection_buffer *buffer, struct msghdr *msg) { + bool overflow = false; + struct cmsghdr *cmsg; + for (cmsg = CMSG_FIRSTHDR(msg); cmsg != NULL; cmsg = CMSG_NXTHDR(msg, cmsg)) { + if (cmsg->cmsg_level != SOL_SOCKET || cmsg->cmsg_type != SCM_RIGHTS) { + continue; + } + + size_t size = cmsg->cmsg_len - CMSG_LEN(0); + size_t max = sizeof(buffer->data) - connection_buffer_size(buffer); + if (size > max || overflow) { + overflow = true; + size /= sizeof(int); + for (size_t idx = 0; idx < size; idx++) { + close(((int *)CMSG_DATA(cmsg))[idx]); + } + } else if (connection_buffer_put(buffer, CMSG_DATA(cmsg), size) < 0) { + return -1; + } + } + + if (overflow) { + errno = EOVERFLOW; + return -1; + } + return 0; +} + +int connection_read(struct connection *connection) { + if (connection_buffer_size(&connection->in) >= sizeof(connection->in.data)) { + errno = EOVERFLOW; + return -1; + } + + int count; + struct iovec iov[2]; + connection_buffer_put_iov(&connection->in, iov, &count); + + char cmsg[CLEN]; + struct msghdr msg = { + .msg_name = NULL, + .msg_namelen = 0, + .msg_iov = iov, + .msg_iovlen = count, + .msg_control = cmsg, + .msg_controllen = sizeof cmsg, + .msg_flags = 0, + }; + + int len; + do { + len = recvmsg(connection->fd, &msg, MSG_DONTWAIT | MSG_CMSG_CLOEXEC); + if (len == -1 && errno != EINTR) + return -1; + } while (len == -1); + + if (decode_cmsg(&connection->fds_in, &msg) != 0) { + return -1; + } + connection->in.head += len; + + return connection_buffer_size(&connection->in); +} + +int connection_flush(struct connection *connection) { + if (!connection->want_flush) { + return 0; + } + + uint32_t tail = connection->out.tail; + while (connection->out.head - connection->out.tail > 0) { + int count; + struct iovec iov[2]; + connection_buffer_get_iov(&connection->out, iov, &count); + + int clen; + char cmsg[CLEN]; + build_cmsg(&connection->fds_out, cmsg, &clen); + struct msghdr msg = { + .msg_name = NULL, + .msg_namelen = 0, + .msg_iov = iov, + .msg_iovlen = count, + .msg_control = (clen > 0) ? cmsg : NULL, + .msg_controllen = clen, + .msg_flags = 0, + }; + + int len; + do { + len = sendmsg(connection->fd, &msg, MSG_NOSIGNAL | MSG_DONTWAIT); + if (len == -1 && errno != EINTR) + return -1; + } while (len == -1); + connection_buffer_close_fds(&connection->fds_out); + connection->out.tail += len; + } + connection->want_flush = 0; + return connection->out.head - tail; +} + +int connection_put(struct connection *connection, const void *data, size_t count) { + if (connection_buffer_size(&connection->out) + count > CONNECTION_BUFFER_SIZE) { + connection->want_flush = 1; + if (connection_flush(connection) == -1) { + return -1; + } + } + + if (connection_buffer_put(&connection->out, data, count) == -1) { + return -1; + } + + connection->want_flush = 1; + return 0; +} + +int connection_put_fd(struct connection *connection, int fd) { + if (connection_buffer_size(&connection->fds_out) == MAX_FDS_OUT * sizeof fd) { + errno = EOVERFLOW; + return -1; + } + + return connection_buffer_put(&connection->fds_out, &fd, sizeof fd); +} + +int connection_get(struct connection *connection, void *dst, size_t count) { + if (count > connection_buffer_size(&connection->in)) { + errno = EAGAIN; + return -1; + } + connection_buffer_copy(&connection->in, dst, count); + connection_buffer_consume(&connection->in, count); + return count; +} + +int connection_get_fd(struct connection *connection) { + int fd; + if (sizeof fd > connection_buffer_size(&connection->fds_in)) { + errno = EAGAIN; + return -1; + } + connection_buffer_copy(&connection->fds_in, &fd, sizeof fd); + connection_buffer_consume(&connection->fds_in, sizeof fd); + return fd; +} + +void connection_close_fds(struct connection *connection) { + connection_buffer_close_fds(&connection->fds_in); + connection_buffer_close_fds(&connection->fds_out); +} + +size_t connection_pending(struct connection *connection) { + return connection_buffer_size(&connection->in); +} + +void connection_restore(struct connection *connection, size_t count) { + connection_buffer_restore(&connection->in, count); +} |