#include #include #include #include #include #include #include #include "connection.h" static inline uint32_t connection_buffer_mask(const uint32_t idx) { return idx & (CONNECTION_BUFFER_SIZE - 1); } static inline uint32_t connection_buffer_size(const struct connection_buffer *b) { return b->head - b->tail; } static inline void connection_buffer_consume(struct connection_buffer *b, const size_t size) { b->tail += size; } static inline void connection_buffer_restore(struct connection_buffer *b, const size_t size) { b->tail -= size; } /* * connection_buffer_get_iov prepares I/O vectors pointing to our ring buffer. * Two may be used if the buffer has wrapped around. */ static void connection_buffer_get_iov(struct connection_buffer *b, struct iovec *iov, int *count) { uint32_t head = connection_buffer_mask(b->head); uint32_t tail = connection_buffer_mask(b->tail); if (tail < head) { iov[0].iov_base = b->data + tail; iov[0].iov_len = head - tail; *count = 1; } else if (head == 0) { iov[0].iov_base = b->data + tail; iov[0].iov_len = sizeof b->data - tail; *count = 1; } else { iov[0].iov_base = b->data + tail; iov[0].iov_len = sizeof b->data - tail; iov[1].iov_base = b->data; iov[1].iov_len = head; *count = 2; } } /* * connection_buffer_put_iov prepares I/O vectors pointing to our ring buffer. * Two may be used if the buffer has wrapped around. */ static void connection_buffer_put_iov(struct connection_buffer *b, struct iovec *iov, int *count) { uint32_t head = connection_buffer_mask(b->head); uint32_t tail = connection_buffer_mask(b->tail); if (head < tail) { iov[0].iov_base = b->data + head; iov[0].iov_len = tail - head; *count = 1; } else if (tail == 0) { iov[0].iov_base = b->data + head; iov[0].iov_len = sizeof b->data - head; *count = 1; } else { iov[0].iov_base = b->data + head; iov[0].iov_len = sizeof b->data - head; iov[1].iov_base = b->data; iov[1].iov_len = tail; *count = 2; } } /* * connection_buffer_copy copies from our ring buffer into a linear buffer. */ static void connection_buffer_copy(const struct connection_buffer *b, void *data, const size_t count) { uint32_t tail = connection_buffer_mask(b->tail); if (tail + count <= sizeof b->data) { memcpy(data, b->data + tail, count); return; } uint32_t size = sizeof b->data - tail; memcpy(data, b->data + tail, size); memcpy((char *)data + size, b->data, count - size); } /* * connection_buffer_copy copies from a linear buffer into our ring buffer. */ static int connection_buffer_put(struct connection_buffer *b, const void *data, const size_t count) { if (count > sizeof(b->data)) { errno = EOVERFLOW; return -1; } uint32_t head = connection_buffer_mask(b->head); if (head + count <= sizeof b->data) { memcpy(b->data + head, data, count); } else { uint32_t size = sizeof b->data - head; memcpy(b->data + head, data, size); memcpy(b->data, (const char *)data + size, count - size); } b->head += count; return 0; } /* * close_fds closes all fds within a connection_buffer */ static void connection_buffer_close_fds(struct connection_buffer *buffer) { size_t size = connection_buffer_size(buffer); if (size == 0) { return; } int fds[sizeof(buffer->data) / sizeof(int)]; connection_buffer_copy(buffer, fds, size); int count = size / sizeof fds[0]; size = count * sizeof fds[0]; for (int idx = 0; idx < count; idx++) { close(fds[idx]); } connection_buffer_consume(buffer, size); } /* * build_cmsg prepares a cmsg from a buffer full of fds */ static void build_cmsg(struct connection_buffer *buffer, char *data, int *clen) { size_t size = connection_buffer_size(buffer); if (size > MAX_FDS * sizeof(int)) { size = MAX_FDS * sizeof(int); } if (size <= 0) { *clen = 0; return; } struct cmsghdr *cmsg = (struct cmsghdr *)data; cmsg->cmsg_level = SOL_SOCKET; cmsg->cmsg_type = SCM_RIGHTS; cmsg->cmsg_len = CMSG_LEN(size); connection_buffer_copy(buffer, CMSG_DATA(cmsg), size); *clen = cmsg->cmsg_len; } static int decode_cmsg(struct connection_buffer *buffer, struct msghdr *msg) { bool overflow = false; struct cmsghdr *cmsg; for (cmsg = CMSG_FIRSTHDR(msg); cmsg != NULL; cmsg = CMSG_NXTHDR(msg, cmsg)) { if (cmsg->cmsg_level != SOL_SOCKET || cmsg->cmsg_type != SCM_RIGHTS) { continue; } size_t size = cmsg->cmsg_len - CMSG_LEN(0); size_t max = sizeof(buffer->data) - connection_buffer_size(buffer); if (size > max || overflow) { overflow = true; size /= sizeof(int); for (size_t idx = 0; idx < size; idx++) { close(((int *)CMSG_DATA(cmsg))[idx]); } } else if (connection_buffer_put(buffer, CMSG_DATA(cmsg), size) < 0) { return -1; } } if (overflow) { errno = EOVERFLOW; return -1; } return 0; } int connection_read(struct connection *connection) { if (connection_buffer_size(&connection->in) >= sizeof(connection->in.data)) { errno = EOVERFLOW; return -1; } int count; struct iovec iov[2]; connection_buffer_put_iov(&connection->in, iov, &count); char cmsg[CMSG_LEN(CONNECTION_BUFFER_SIZE)]; struct msghdr msg = { .msg_name = NULL, .msg_namelen = 0, .msg_iov = iov, .msg_iovlen = count, .msg_control = cmsg, .msg_controllen = sizeof cmsg, .msg_flags = 0, }; int len; do { len = recvmsg(connection->fd, &msg, MSG_DONTWAIT | MSG_CMSG_CLOEXEC); if (len == -1 && errno != EINTR) return -1; } while (len == -1); if (decode_cmsg(&connection->fds_in, &msg) != 0) { return -1; } connection->in.head += len; return connection_buffer_size(&connection->in); } int connection_flush(struct connection *connection) { if (!connection->want_flush) { return 0; } uint32_t tail = connection->out.tail; while (connection->out.head - connection->out.tail > 0) { int count; struct iovec iov[2]; connection_buffer_get_iov(&connection->out, iov, &count); int clen; char cmsg[CMSG_LEN(CONNECTION_BUFFER_SIZE)]; build_cmsg(&connection->fds_out, cmsg, &clen); struct msghdr msg = { .msg_name = NULL, .msg_namelen = 0, .msg_iov = iov, .msg_iovlen = count, .msg_control = (clen > 0) ? cmsg : NULL, .msg_controllen = clen, .msg_flags = 0, }; int len; do { len = sendmsg(connection->fd, &msg, MSG_NOSIGNAL | MSG_DONTWAIT); if (len == -1 && errno != EINTR) return -1; } while (len == -1); connection_buffer_close_fds(&connection->fds_out); connection->out.tail += len; } connection->want_flush = 0; return connection->out.head - tail; } int connection_put(struct connection *connection, const void *data, size_t count) { if (connection_buffer_size(&connection->out) + count > CONNECTION_BUFFER_SIZE) { connection->want_flush = 1; if (connection_flush(connection) == -1) { return -1; } } if (connection_buffer_put(&connection->out, data, count) == -1) { return -1; } connection->want_flush = 1; return 0; } int connection_put_fd(struct connection *connection, int fd) { if (connection_buffer_size(&connection->fds_out) == MAX_FDS * sizeof fd) { errno = EOVERFLOW; return -1; } return connection_buffer_put(&connection->fds_out, &fd, sizeof fd); } int connection_get(struct connection *connection, void *dst, size_t count) { if (count > connection_buffer_size(&connection->in)) { errno = EAGAIN; return -1; } connection_buffer_copy(&connection->in, dst, count); connection_buffer_consume(&connection->in, count); return count; } int connection_get_fd(struct connection *connection) { int fd; if (sizeof fd > connection_buffer_size(&connection->fds_in)) { errno = EAGAIN; return -1; } connection_buffer_copy(&connection->fds_in, &fd, sizeof fd); connection_buffer_consume(&connection->fds_in, sizeof fd); return fd; } void connection_close_fds(struct connection *connection) { connection_buffer_close_fds(&connection->fds_in); connection_buffer_close_fds(&connection->fds_out); } size_t connection_pending(struct connection *connection) { return connection_buffer_size(&connection->in); } void connection_restore(struct connection *connection, size_t count) { connection_buffer_restore(&connection->in, count); }