summaryrefslogtreecommitdiff
path: root/sslio.c
diff options
context:
space:
mode:
authorMark Nunberg <mnunberg@users.noreply.github.com>2019-08-09 04:02:53 -0400
committerGitHub <noreply@github.com>2019-08-09 04:02:53 -0400
commitf9bccfb7baa0bd0c9fdbaee602398590ea364f67 (patch)
treee94e828e31b9f71166436c2a560f33f650ffb13e /sslio.c
parent300fc013c1559ece991abcbf886c86e807c36ba4 (diff)
parent5d013039a95402ff555ab51a7340d4af91b6f72b (diff)
Merge branch 'master' into createArray-size_t
Diffstat (limited to 'sslio.c')
-rw-r--r--sslio.c239
1 files changed, 239 insertions, 0 deletions
diff --git a/sslio.c b/sslio.c
new file mode 100644
index 0000000..25cb5c2
--- /dev/null
+++ b/sslio.c
@@ -0,0 +1,239 @@
+#include "hiredis.h"
+#include "sslio.h"
+
+#include <assert.h>
+#ifdef HIREDIS_SSL
+#include <pthread.h>
+#include <errno.h>
+
+void __redisSetError(redisContext *c, int type, const char *str);
+
+/**
+ * Callback used for debugging
+ */
+static void sslLogCallback(const SSL *ssl, int where, int ret) {
+ const char *retstr = "";
+ int should_log = 1;
+ /* Ignore low-level SSL stuff */
+
+ if (where & SSL_CB_ALERT) {
+ should_log = 1;
+ }
+ if (where == SSL_CB_HANDSHAKE_START || where == SSL_CB_HANDSHAKE_DONE) {
+ should_log = 1;
+ }
+ if ((where & SSL_CB_EXIT) && ret == 0) {
+ should_log = 1;
+ }
+
+ if (!should_log) {
+ return;
+ }
+
+ retstr = SSL_alert_type_string(ret);
+ printf("ST(0x%x). %s. R(0x%x)%s\n", where, SSL_state_string_long(ssl), ret, retstr);
+
+ if (where == SSL_CB_HANDSHAKE_DONE) {
+ printf("Using SSL version %s. Cipher=%s\n", SSL_get_version(ssl), SSL_get_cipher_name(ssl));
+ }
+}
+
+typedef pthread_mutex_t sslLockType;
+static void sslLockInit(sslLockType *l) {
+ pthread_mutex_init(l, NULL);
+}
+static void sslLockAcquire(sslLockType *l) {
+ pthread_mutex_lock(l);
+}
+static void sslLockRelease(sslLockType *l) {
+ pthread_mutex_unlock(l);
+}
+static pthread_mutex_t *ossl_locks;
+
+static void opensslDoLock(int mode, int lkid, const char *f, int line) {
+ sslLockType *l = ossl_locks + lkid;
+
+ if (mode & CRYPTO_LOCK) {
+ sslLockAcquire(l);
+ } else {
+ sslLockRelease(l);
+ }
+
+ (void)f;
+ (void)line;
+}
+
+static void initOpensslLocks(void) {
+ unsigned ii, nlocks;
+ if (CRYPTO_get_locking_callback() != NULL) {
+ /* Someone already set the callback before us. Don't destroy it! */
+ return;
+ }
+ nlocks = CRYPTO_num_locks();
+ ossl_locks = malloc(sizeof(*ossl_locks) * nlocks);
+ for (ii = 0; ii < nlocks; ii++) {
+ sslLockInit(ossl_locks + ii);
+ }
+ CRYPTO_set_locking_callback(opensslDoLock);
+}
+
+void redisFreeSsl(redisSsl *ssl){
+ if (ssl->ctx) {
+ SSL_CTX_free(ssl->ctx);
+ }
+ if (ssl->ssl) {
+ SSL_free(ssl->ssl);
+ }
+ free(ssl);
+}
+
+int redisSslCreate(redisContext *c, const char *capath, const char *certpath,
+ const char *keypath, const char *servername) {
+ assert(!c->ssl);
+ c->ssl = calloc(1, sizeof(*c->ssl));
+ static int isInit = 0;
+ if (!isInit) {
+ isInit = 1;
+ SSL_library_init();
+ initOpensslLocks();
+ }
+
+ redisSsl *s = c->ssl;
+ s->ctx = SSL_CTX_new(SSLv23_client_method());
+ SSL_CTX_set_info_callback(s->ctx, sslLogCallback);
+ SSL_CTX_set_mode(s->ctx, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER);
+ SSL_CTX_set_options(s->ctx, SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3);
+ SSL_CTX_set_verify(s->ctx, SSL_VERIFY_PEER, NULL);
+
+ if ((certpath != NULL && keypath == NULL) || (keypath != NULL && certpath == NULL)) {
+ __redisSetError(c, REDIS_ERR, "certpath and keypath must be specified together");
+ return REDIS_ERR;
+ }
+
+ if (capath) {
+ if (!SSL_CTX_load_verify_locations(s->ctx, capath, NULL)) {
+ __redisSetError(c, REDIS_ERR, "Invalid CA certificate");
+ return REDIS_ERR;
+ }
+ }
+ if (certpath) {
+ if (!SSL_CTX_use_certificate_chain_file(s->ctx, certpath)) {
+ __redisSetError(c, REDIS_ERR, "Invalid client certificate");
+ return REDIS_ERR;
+ }
+ if (!SSL_CTX_use_PrivateKey_file(s->ctx, keypath, SSL_FILETYPE_PEM)) {
+ __redisSetError(c, REDIS_ERR, "Invalid client key");
+ return REDIS_ERR;
+ }
+ }
+
+ s->ssl = SSL_new(s->ctx);
+ if (!s->ssl) {
+ __redisSetError(c, REDIS_ERR, "Couldn't create new SSL instance");
+ return REDIS_ERR;
+ }
+ if (servername) {
+ if (!SSL_set_tlsext_host_name(s->ssl, servername)) {
+ __redisSetError(c, REDIS_ERR, "Couldn't set server name indication");
+ return REDIS_ERR;
+ }
+ }
+
+ SSL_set_fd(s->ssl, c->fd);
+ SSL_set_connect_state(s->ssl);
+
+ c->flags |= REDIS_SSL;
+ int rv = SSL_connect(c->ssl->ssl);
+ if (rv == 1) {
+ return REDIS_OK;
+ }
+
+ rv = SSL_get_error(s->ssl, rv);
+ if (((c->flags & REDIS_BLOCK) == 0) &&
+ (rv == SSL_ERROR_WANT_READ || rv == SSL_ERROR_WANT_WRITE)) {
+ return REDIS_OK;
+ }
+
+ if (c->err == 0) {
+ __redisSetError(c, REDIS_ERR_IO, "SSL_connect() failed");
+ }
+ return REDIS_ERR;
+}
+
+static int maybeCheckWant(redisSsl *rssl, int rv) {
+ /**
+ * If the error is WANT_READ or WANT_WRITE, the appropriate flags are set
+ * and true is returned. False is returned otherwise
+ */
+ if (rv == SSL_ERROR_WANT_READ) {
+ rssl->wantRead = 1;
+ return 1;
+ } else if (rv == SSL_ERROR_WANT_WRITE) {
+ rssl->pendingWrite = 1;
+ return 1;
+ } else {
+ return 0;
+ }
+}
+
+int redisSslRead(redisContext *c, char *buf, size_t bufcap) {
+ int nread = SSL_read(c->ssl->ssl, buf, bufcap);
+ if (nread > 0) {
+ return nread;
+ } else if (nread == 0) {
+ __redisSetError(c, REDIS_ERR_EOF, "Server closed the connection");
+ return -1;
+ } else {
+ int err = SSL_get_error(c->ssl->ssl, nread);
+ if (c->flags & REDIS_BLOCK) {
+ /**
+ * In blocking mode, we should never end up in a situation where
+ * we get an error without it being an actual error, except
+ * in the case of EINTR, which can be spuriously received from
+ * debuggers or whatever.
+ */
+ if (errno == EINTR) {
+ return 0;
+ } else {
+ const char *msg = NULL;
+ if (errno == EAGAIN) {
+ msg = "Timed out";
+ }
+ __redisSetError(c, REDIS_ERR_IO, msg);
+ return -1;
+ }
+ }
+
+ /**
+ * We can very well get an EWOULDBLOCK/EAGAIN, however
+ */
+ if (maybeCheckWant(c->ssl, err)) {
+ return 0;
+ } else {
+ __redisSetError(c, REDIS_ERR_IO, NULL);
+ return -1;
+ }
+ }
+}
+
+int redisSslWrite(redisContext *c) {
+ size_t len = c->ssl->lastLen ? c->ssl->lastLen : sdslen(c->obuf);
+ int rv = SSL_write(c->ssl->ssl, c->obuf, len);
+
+ if (rv > 0) {
+ c->ssl->lastLen = 0;
+ } else if (rv < 0) {
+ c->ssl->lastLen = len;
+
+ int err = SSL_get_error(c->ssl->ssl, rv);
+ if ((c->flags & REDIS_BLOCK) == 0 && maybeCheckWant(c->ssl, err)) {
+ return 0;
+ } else {
+ __redisSetError(c, REDIS_ERR_IO, NULL);
+ return -1;
+ }
+ }
+ return rv;
+}
+
+#endif