diff options
-rw-r--r-- | common/connection.c | 11 | ||||
-rw-r--r-- | include/connection.h | 2 | ||||
-rw-r--r-- | libseat/backend/seatd.c | 195 |
3 files changed, 102 insertions, 106 deletions
diff --git a/common/connection.c b/common/connection.c index a6739c0..6b2c366 100644 --- a/common/connection.c +++ b/common/connection.c @@ -284,15 +284,14 @@ int connection_get(struct connection *connection, void *dst, size_t count) { return count; } -int connection_get_fd(struct connection *connection) { - int fd; - if (sizeof fd > connection_buffer_size(&connection->fds_in)) { +int connection_get_fd(struct connection *connection, int *fd) { + if (sizeof(int) > connection_buffer_size(&connection->fds_in)) { errno = EAGAIN; return -1; } - connection_buffer_copy(&connection->fds_in, &fd, sizeof fd); - connection_buffer_consume(&connection->fds_in, sizeof fd); - return fd; + connection_buffer_copy(&connection->fds_in, fd, sizeof(int)); + connection_buffer_consume(&connection->fds_in, sizeof(int)); + return 0; } void connection_close_fds(struct connection *connection) { diff --git a/include/connection.h b/include/connection.h index 3e15403..d32e766 100644 --- a/include/connection.h +++ b/include/connection.h @@ -28,7 +28,7 @@ int connection_put_fd(struct connection *connection, int fd); size_t connection_pending(struct connection *connection); int connection_get(struct connection *connection, void *dst, size_t count); -int connection_get_fd(struct connection *connection); +int connection_get_fd(struct connection *connection, int *fd); void connection_restore(struct connection *connection, size_t count); void connection_close_fds(struct connection *connection); diff --git a/libseat/backend/seatd.c b/libseat/backend/seatd.c index 6321d39..718838a 100644 --- a/libseat/backend/seatd.c +++ b/libseat/backend/seatd.c @@ -58,9 +58,11 @@ static int seatd_connect(void) { } addr = {{0}}; int fd = socket(AF_UNIX, SOCK_STREAM, 0); if (fd == -1) { + log_errorf("Could not create socket: %s", strerror(errno)); return -1; } if (set_nonblock(fd) == -1) { + log_errorf("Could not make socket non-blocking: %s", strerror(errno)); close(fd); return -1; } @@ -72,6 +74,7 @@ static int seatd_connect(void) { strncpy(addr.unix.sun_path, path, sizeof addr.unix.sun_path); socklen_t size = offsetof(struct sockaddr_un, sun_path) + strlen(addr.unix.sun_path); if (connect(fd, &addr.generic, size) == -1) { + log_debugf("Could not connect to socket: %s", strerror(errno)); close(fd); return -1; }; @@ -88,21 +91,54 @@ static struct backend_seatd *backend_seatd_from_libseat_backend(struct libseat * return (struct backend_seatd *)base; } -static size_t read_header(struct connection *connection, uint16_t expected_opcode) { +static inline int conn_put(struct backend_seatd *backend, const void *data, const size_t data_len) { + if (connection_put(&backend->connection, data, data_len) == -1) { + log_errorf("Could not send request: %s", strerror(errno)); + return -1; + } + return 0; +} + +static inline int conn_flush(struct backend_seatd *backend) { + if (connection_flush(&backend->connection) == -1) { + log_errorf("Could not flush connection: %s", strerror(errno)); + return -1; + } + return 0; +} + +static inline int conn_get(struct backend_seatd *backend, void *target, const size_t target_len) { + if (connection_get(&backend->connection, target, target_len) == -1) { + log_error("Invalid message: insufficient data received"); + errno = EBADMSG; + return -1; + } + return 0; +} + +static inline int conn_get_fd(struct backend_seatd *backend, int *fd) { + if (connection_get_fd(&backend->connection, fd) == -1) { + log_error("Invalid message: insufficient data received"); + errno = EBADMSG; + return -1; + } + return 0; +} + +static size_t read_header(struct backend_seatd *backend, uint16_t expected_opcode, + size_t expected_size, bool variable) { struct proto_header header; - if (connection_get(connection, &header, sizeof header) == -1) { - log_error("Received invalid message: header too short"); + if (conn_get(backend, &header, sizeof header) == -1) { return SIZE_MAX; } if (header.opcode != expected_opcode) { - connection_restore(connection, sizeof header); + connection_restore(&backend->connection, sizeof header); struct proto_server_error msg; if (header.opcode != SERVER_ERROR) { - log_errorf("Received invalid message: expected opcode %d, received opcode %d", + log_errorf("Unexpected response: expected opcode %d, received opcode %d", expected_opcode, header.opcode); errno = EBADMSG; - } else if (connection_get(connection, &msg, sizeof msg) == -1) { - log_error("Received invalid message"); + } else if (conn_get(backend, &msg, sizeof msg) == -1) { errno = EBADMSG; } else { errno = msg.error_code; @@ -110,12 +146,19 @@ static size_t read_header(struct connection *connection, uint16_t expected_opcod return SIZE_MAX; } + if ((!variable && header.size != expected_size) || (variable && header.size < expected_size)) { + log_errorf("Invalid message: does not match expected size: variable: %d, header.size: %d, expected size: %zd", + variable, header.size, expected_size); + errno = EBADMSG; + return SIZE_MAX; + } return header.size; } static int queue_event(struct backend_seatd *backend, int opcode) { struct pending_event *ev = calloc(1, sizeof(struct pending_event)); if (ev == NULL) { + log_errorf("Allocation failed: %s", strerror(errno)); return -1; } @@ -211,8 +254,7 @@ static int poll_connection(struct backend_seatd *backend, int timeout) { } static int dispatch(struct backend_seatd *backend) { - if (connection_flush(&backend->connection) == -1) { - log_errorf("Could not flush connection: %s", strerror(errno)); + if (conn_flush(backend) == -1) { return -1; } int opcode = 0, res = 0; @@ -278,9 +320,10 @@ static struct libseat *_open_seat(struct libseat_seat_listener *listener, void * assert(listener->enable_seat != NULL && listener->disable_seat != NULL); struct backend_seatd *backend = calloc(1, sizeof(struct backend_seatd)); if (backend == NULL) { - close(fd); - return NULL; + log_errorf("Allocation failed: %s", strerror(errno)); + goto alloc_error; } + backend->seat_listener = listener; backend->seat_listener_data = data; backend->connection.fd = fd; @@ -291,42 +334,31 @@ static struct libseat *_open_seat(struct libseat_seat_listener *listener, void * .opcode = CLIENT_OPEN_SEAT, .size = 0, }; - - if (connection_put(&backend->connection, &header, sizeof header) == -1 || - dispatch(backend) == -1) { - destroy(backend); - return NULL; - } - - size_t size = read_header(&backend->connection, SERVER_SEAT_OPENED); - if (size == SIZE_MAX) { - destroy(backend); - return NULL; + if (conn_put(backend, &header, sizeof header) == -1 || dispatch(backend) == -1) { + goto backend_error; } struct proto_server_seat_opened rmsg; - if (sizeof rmsg > size) { - goto badmsg_error; + size_t size = read_header(backend, SERVER_SEAT_OPENED, sizeof rmsg, true); + if (size == SIZE_MAX || conn_get(backend, &rmsg, sizeof rmsg) == -1) { + goto backend_error; } - - if (connection_get(&backend->connection, &rmsg, sizeof rmsg) == -1) { - goto badmsg_error; - }; - - if (sizeof rmsg + rmsg.seat_name_len > size || - rmsg.seat_name_len >= sizeof backend->seat_name) { - goto badmsg_error; + if (rmsg.seat_name_len != size - sizeof rmsg) { + log_errorf("Invalid message: seat_name_len does not match remaining message size (%d != %zd)", + rmsg.seat_name_len, size); + errno = EBADMSG; + goto backend_error; + } + if (conn_get(backend, backend->seat_name, rmsg.seat_name_len) == -1) { + goto backend_error; } - - if (connection_get(&backend->connection, backend->seat_name, rmsg.seat_name_len) == -1) { - goto badmsg_error; - }; return &backend->base; -badmsg_error: - log_error("Received invalid message"); - errno = EBADMSG; +backend_error: + destroy(backend); +alloc_error: + close(fd); return NULL; } @@ -346,21 +378,20 @@ static int close_seat(struct libseat *base) { .opcode = CLIENT_CLOSE_SEAT, .size = 0, }; - - if (connection_put(&backend->connection, &header, sizeof header) == -1 || - dispatch(backend) == -1) { - destroy(backend); - return -1; + if (conn_put(backend, &header, sizeof header) == -1 || dispatch(backend) == -1) { + goto error; } - size_t size = read_header(&backend->connection, SERVER_SEAT_CLOSED); - if (size == SIZE_MAX) { - destroy(backend); - return -1; + if (read_header(backend, SERVER_SEAT_CLOSED, 0, false) == SIZE_MAX) { + goto error; } destroy(backend); return 0; + +error: + destroy(backend); + return -1; } static const char *seat_name(struct libseat *base) { @@ -384,38 +415,19 @@ static int open_device(struct libseat *base, const char *path, int *fd) { .opcode = CLIENT_OPEN_DEVICE, .size = sizeof msg + pathlen, }; - - if (connection_put(&backend->connection, &header, sizeof header) == -1 || - connection_put(&backend->connection, &msg, sizeof msg) == -1 || - connection_put(&backend->connection, path, pathlen) == -1 || dispatch(backend) == -1) { - return -1; - } - - size_t size = read_header(&backend->connection, SERVER_DEVICE_OPENED); - if (size == SIZE_MAX) { + if (conn_put(backend, &header, sizeof header) == -1 || + conn_put(backend, &msg, sizeof msg) == -1 || conn_put(backend, path, pathlen) == -1 || + dispatch(backend) == -1) { return -1; } struct proto_server_device_opened rmsg; - if (sizeof rmsg > size) { - goto badmsg_error; - } - if (connection_get(&backend->connection, &rmsg, sizeof rmsg) == -1) { - goto badmsg_error; - } - - int received_fd = connection_get_fd(&backend->connection); - if (received_fd == -1) { - goto badmsg_error; + if (read_header(backend, SERVER_DEVICE_OPENED, sizeof rmsg, false) == SIZE_MAX || + conn_get(backend, &rmsg, sizeof rmsg) == -1 || conn_get_fd(backend, fd)) { + return -1; } - *fd = received_fd; return rmsg.device_id; - -badmsg_error: - log_error("Received invalid message"); - errno = EBADMSG; - return -1; } static int close_device(struct libseat *base, int device_id) { @@ -432,34 +444,23 @@ static int close_device(struct libseat *base, int device_id) { .opcode = CLIENT_CLOSE_DEVICE, .size = sizeof msg, }; - - if (connection_put(&backend->connection, &header, sizeof header) == -1 || - connection_put(&backend->connection, &msg, sizeof msg) == -1 || dispatch(backend) == -1) { - return -1; - } - - size_t size = read_header(&backend->connection, SERVER_DEVICE_CLOSED); - if (size == SIZE_MAX) { + if (conn_put(backend, &header, sizeof header) == -1 || + conn_put(backend, &msg, sizeof msg) == -1 || dispatch(backend) == -1) { return -1; } struct proto_server_device_closed rmsg; - if (sizeof rmsg > size) { - goto badmsg_error; - } - if (connection_get(&backend->connection, &rmsg, sizeof rmsg) == -1) { - goto badmsg_error; + if (read_header(backend, SERVER_DEVICE_CLOSED, sizeof rmsg, false) == SIZE_MAX || + conn_get(backend, &rmsg, sizeof rmsg) == -1) { + return -1; } if (rmsg.device_id != device_id) { - goto badmsg_error; + log_errorf("Unexpected response: expected device close for %d, got device close for %d", + rmsg.device_id, device_id); + return -1; } return 0; - -badmsg_error: - log_error("Received invalid message"); - errno = EBADMSG; - return -1; } static int switch_session(struct libseat *base, int session) { @@ -475,10 +476,8 @@ static int switch_session(struct libseat *base, int session) { .opcode = CLIENT_SWITCH_SESSION, .size = sizeof msg, }; - - if (connection_put(&backend->connection, &header, sizeof header) == -1 || - connection_put(&backend->connection, &msg, sizeof msg) == -1 || - connection_flush(&backend->connection) == -1) { + if (conn_put(backend, &header, sizeof header) == -1 || + conn_put(backend, &msg, sizeof msg) == -1 || conn_flush(backend) == -1) { return -1; } @@ -491,9 +490,7 @@ static int disable_seat(struct libseat *base) { .opcode = CLIENT_DISABLE_SEAT, .size = 0, }; - - if (connection_put(&backend->connection, &header, sizeof header) == -1 || - connection_flush(&backend->connection) == -1) { + if (conn_put(backend, &header, sizeof header) == -1 || conn_flush(backend) == -1) { return -1; } |