From 190bca88d0635f1e53d9e39fc69f4f21b67a1baf Mon Sep 17 00:00:00 2001 From: Yossi Gottlieb Date: Fri, 22 May 2020 14:10:08 +0300 Subject: New SSL API to replace redisSecureConnection(). --- ssl.c | 266 ++++++++++++++++++++++++++++++++++++++---------------------------- 1 file changed, 153 insertions(+), 113 deletions(-) (limited to 'ssl.c') diff --git a/ssl.c b/ssl.c index 8cb133a..f25dce2 100644 --- a/ssl.c +++ b/ssl.c @@ -47,17 +47,20 @@ #include "win32.h" #include "async_private.h" +#include "hiredis_ssl.h" void __redisSetError(redisContext *c, int type, const char *str); -/* The SSL context is attached to SSL/TLS connections as a privdata. */ -typedef struct redisSSLContext { - /** - * OpenSSL SSL_CTX; It is optional and will not be set when using - * user-supplied SSL. - */ +struct redisSSLContext { + /* Associated OpenSSL SSL_CTX as created by redisCreateSSLContext() */ SSL_CTX *ssl_ctx; + /* Requested SNI, or NULL */ + char *server_name; +}; + +/* The SSL connection context is attached to SSL/TLS connections as a privdata. */ +typedef struct redisSSL { /** * OpenSSL SSL object. */ @@ -77,43 +80,11 @@ typedef struct redisSSLContext { * should resume whenever a read takes place, if possible */ int pendingWrite; -} redisSSLContext; +} redisSSL; /* Forward declaration */ redisContextFuncs redisContextSSLFuncs; -#ifdef HIREDIS_SSL_TRACE -/** - * Callback used for debugging - */ -static void sslLogCallback(const SSL *ssl, int where, int ret) { - const char *retstr; - int should_log = 0; - /* Ignore low-level SSL stuff */ - - if (where & SSL_CB_ALERT) { - should_log = 1; - } - if (where == SSL_CB_HANDSHAKE_START || where == SSL_CB_HANDSHAKE_DONE) { - should_log = 1; - } - if ((where & SSL_CB_EXIT) && ret == 0) { - should_log = 1; - } - - if (!should_log) { - return; - } - - retstr = SSL_alert_type_string(ret); - printf("ST(0x%x). %s. R(0x%x)%s\n", where, SSL_state_string_long(ssl), ret, retstr); - - if (where == SSL_CB_HANDSHAKE_DONE) { - printf("Using SSL version %s. Cipher=%s\n", SSL_get_version(ssl), SSL_get_cipher_name(ssl)); - } -} -#endif - /** * OpenSSL global initialization and locking handling callbacks. * Note that this is only required for OpenSSL < 1.1.0. @@ -182,26 +153,130 @@ static int initOpensslLocks(void) { } #endif /* HIREDIS_USE_CRYPTO_LOCKS */ +int redisInitOpenSSL(void) +{ + SSL_library_init(); +#ifdef HIREDIS_USE_CRYPTO_LOCKS + initOpensslLocks(); +#endif + + return REDIS_OK; +} + +/** + * redisSSLContext helper context destruction. + */ + +const char *redisSSLContextGetError(redisSSLContextError error) +{ + switch (error) { + case REDIS_SSL_CTX_NONE: + return "No Error"; + case REDIS_SSL_CTX_CREATE_FAILED: + return "Failed to create OpenSSL SSL_CTX"; + case REDIS_SSL_CTX_CERT_KEY_REQUIRED: + return "Client cert and key must both be specified or skipped"; + case REDIS_SSL_CTX_CA_CERT_LOAD_FAILED: + return "Failed to load CA Certificate or CA Path"; + case REDIS_SSL_CTX_CLIENT_CERT_LOAD_FAILED: + return "Failed to load client certificate"; + case REDIS_SSL_CTX_PRIVATE_KEY_LOAD_FAILED: + return "Failed to load private key"; + default: + return "Unknown error code"; + } +} + +void redisFreeSSLContext(redisSSLContext *ctx) +{ + if (!ctx) + return; + + if (ctx->server_name) { + hi_free(ctx->server_name); + ctx->server_name = NULL; + } + + if (ctx->ssl_ctx) { + SSL_CTX_free(ctx->ssl_ctx); + ctx->ssl_ctx = NULL; + } + + hi_free(ctx); +} + + +/** + * redisSSLContext helper context initialization. + */ + +redisSSLContext *redisCreateSSLContext(const char *cacert_filename, const char *capath, + const char *cert_filename, const char *private_key_filename, + const char *server_name, redisSSLContextError *error) +{ + redisSSLContext *ctx = hi_calloc(1, sizeof(redisSSLContext)); + + ctx->ssl_ctx = SSL_CTX_new(SSLv23_client_method()); + if (!ctx->ssl_ctx) { + if (error) *error = REDIS_SSL_CTX_CREATE_FAILED; + goto error; + } + + SSL_CTX_set_options(ctx->ssl_ctx, SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3); + SSL_CTX_set_verify(ctx->ssl_ctx, SSL_VERIFY_PEER, NULL); + + if ((cert_filename != NULL && private_key_filename == NULL) || + (private_key_filename != NULL && cert_filename == NULL)) { + if (error) *error = REDIS_SSL_CTX_CERT_KEY_REQUIRED; + goto error; + } + + if (capath || cacert_filename) { + if (!SSL_CTX_load_verify_locations(ctx->ssl_ctx, cacert_filename, capath)) { + if (error) *error = REDIS_SSL_CTX_CA_CERT_LOAD_FAILED; + goto error; + } + } + + if (cert_filename) { + if (!SSL_CTX_use_certificate_chain_file(ctx->ssl_ctx, cert_filename)) { + if (error) *error = REDIS_SSL_CTX_CLIENT_CERT_LOAD_FAILED; + goto error; + } + if (!SSL_CTX_use_PrivateKey_file(ctx->ssl_ctx, private_key_filename, SSL_FILETYPE_PEM)) { + if (error) *error = REDIS_SSL_CTX_PRIVATE_KEY_LOAD_FAILED; + goto error; + } + } + + if (server_name) + ctx->server_name = hi_strdup(server_name); + + return ctx; + +error: + redisFreeSSLContext(ctx); + return NULL; +} + /** * SSL Connection initialization. */ -static int redisSSLConnect(redisContext *c, SSL_CTX *ssl_ctx, SSL *ssl) { - redisSSLContext *rssl; +static int redisSSLConnect(redisContext *c, SSL *ssl) { if (c->privdata) { __redisSetError(c, REDIS_ERR_OTHER, "redisContext was already associated"); return REDIS_ERR; } - rssl = hi_calloc(1, sizeof(redisSSLContext)); + redisSSL *rssl = hi_calloc(1, sizeof(redisSSL)); if (rssl == NULL) { __redisSetError(c, REDIS_ERR_OOM, "Out of memory"); return REDIS_ERR; } c->funcs = &redisContextSSLFuncs; - rssl->ssl_ctx = ssl_ctx; rssl->ssl = ssl; SSL_set_mode(rssl->ssl, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER); @@ -238,84 +313,53 @@ static int redisSSLConnect(redisContext *c, SSL_CTX *ssl_ctx, SSL *ssl) { return REDIS_ERR; } +/** + * A wrapper around redisSSLConnect() for users who manage their own context and + * create their own SSL object. + */ + int redisInitiateSSL(redisContext *c, SSL *ssl) { - return redisSSLConnect(c, NULL, ssl); + return redisSSLConnect(c, ssl); } -int redisSecureConnection(redisContext *c, const char *capath, - const char *certpath, const char *keypath, const char *servername) { - - SSL_CTX *ssl_ctx = NULL; - SSL *ssl = NULL; - - /* Initialize global OpenSSL stuff */ - static int isInit = 0; - if (!isInit) { - isInit = 1; - SSL_library_init(); -#ifdef HIREDIS_USE_CRYPTO_LOCKS - if (initOpensslLocks() == REDIS_ERR) { - __redisSetError(c, REDIS_ERR_OOM, "Out of memory"); - goto error; - } -#endif - } - - ssl_ctx = SSL_CTX_new(SSLv23_client_method()); - if (!ssl_ctx) { - __redisSetError(c, REDIS_ERR_OTHER, "Failed to create SSL_CTX"); - goto error; - } +/** + * A wrapper around redisSSLConnect() for users who use redisSSLContext and don't + * manage their own SSL objects. + */ -#ifdef HIREDIS_SSL_TRACE - SSL_CTX_set_info_callback(ssl_ctx, sslLogCallback); -#endif - SSL_CTX_set_options(ssl_ctx, SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3); - SSL_CTX_set_verify(ssl_ctx, SSL_VERIFY_PEER, NULL); - if ((certpath != NULL && keypath == NULL) || (keypath != NULL && certpath == NULL)) { - __redisSetError(c, REDIS_ERR_OTHER, "certpath and keypath must be specified together"); - goto error; - } +int redisInitiateSSLWithContext(redisContext *c, redisSSLContext *redis_ssl_ctx) +{ + if (!c || !redis_ssl_ctx) + return REDIS_ERR; - if (capath) { - if (!SSL_CTX_load_verify_locations(ssl_ctx, capath, NULL)) { - __redisSetError(c, REDIS_ERR_OTHER, "Invalid CA certificate"); - goto error; - } - } - if (certpath) { - if (!SSL_CTX_use_certificate_chain_file(ssl_ctx, certpath)) { - __redisSetError(c, REDIS_ERR_OTHER, "Invalid client certificate"); - goto error; - } - if (!SSL_CTX_use_PrivateKey_file(ssl_ctx, keypath, SSL_FILETYPE_PEM)) { - __redisSetError(c, REDIS_ERR_OTHER, "Invalid client key"); - goto error; - } - } + /* We want to verify that redisSSLConnect() won't fail on this, as it will + * not own the SSL object in that case and we'll end up leaking. + */ + if (c->privdata) + return REDIS_ERR; - ssl = SSL_new(ssl_ctx); + SSL *ssl = SSL_new(redis_ssl_ctx->ssl_ctx); if (!ssl) { __redisSetError(c, REDIS_ERR_OTHER, "Couldn't create new SSL instance"); goto error; } - if (servername) { - if (!SSL_set_tlsext_host_name(ssl, servername)) { - __redisSetError(c, REDIS_ERR_OTHER, "Couldn't set server name indication"); + + if (redis_ssl_ctx->server_name) { + if (!SSL_set_tlsext_host_name(ssl, redis_ssl_ctx->server_name)) { + __redisSetError(c, REDIS_ERR_OTHER, "Failed to set server_name/SNI"); goto error; } } - if (redisSSLConnect(c, ssl_ctx, ssl) == REDIS_OK) - return REDIS_OK; + return redisSSLConnect(c, ssl); error: - if (ssl) SSL_free(ssl); - if (ssl_ctx) SSL_CTX_free(ssl_ctx); + if (ssl) + SSL_free(ssl); return REDIS_ERR; } -static int maybeCheckWant(redisSSLContext *rssl, int rv) { +static int maybeCheckWant(redisSSL *rssl, int rv) { /** * If the error is WANT_READ or WANT_WRITE, the appropriate flags are set * and true is returned. False is returned otherwise @@ -335,23 +379,19 @@ static int maybeCheckWant(redisSSLContext *rssl, int rv) { * Implementation of redisContextFuncs for SSL connections. */ -static void redisSSLFreeContext(void *privdata){ - redisSSLContext *rsc = privdata; +static void redisSSLFree(void *privdata){ + redisSSL *rsc = privdata; if (!rsc) return; if (rsc->ssl) { SSL_free(rsc->ssl); rsc->ssl = NULL; } - if (rsc->ssl_ctx) { - SSL_CTX_free(rsc->ssl_ctx); - rsc->ssl_ctx = NULL; - } hi_free(rsc); } static int redisSSLRead(redisContext *c, char *buf, size_t bufcap) { - redisSSLContext *rssl = c->privdata; + redisSSL *rssl = c->privdata; int nread = SSL_read(rssl->ssl, buf, bufcap); if (nread > 0) { @@ -393,7 +433,7 @@ static int redisSSLRead(redisContext *c, char *buf, size_t bufcap) { } static int redisSSLWrite(redisContext *c) { - redisSSLContext *rssl = c->privdata; + redisSSL *rssl = c->privdata; size_t len = rssl->lastLen ? rssl->lastLen : sdslen(c->obuf); int rv = SSL_write(rssl->ssl, c->obuf, len); @@ -416,7 +456,7 @@ static int redisSSLWrite(redisContext *c) { static void redisSSLAsyncRead(redisAsyncContext *ac) { int rv; - redisSSLContext *rssl = ac->c.privdata; + redisSSL *rssl = ac->c.privdata; redisContext *c = &ac->c; rssl->wantRead = 0; @@ -446,7 +486,7 @@ static void redisSSLAsyncRead(redisAsyncContext *ac) { static void redisSSLAsyncWrite(redisAsyncContext *ac) { int rv, done = 0; - redisSSLContext *rssl = ac->c.privdata; + redisSSL *rssl = ac->c.privdata; redisContext *c = &ac->c; rssl->pendingWrite = 0; @@ -475,7 +515,7 @@ static void redisSSLAsyncWrite(redisAsyncContext *ac) { } redisContextFuncs redisContextSSLFuncs = { - .free_privdata = redisSSLFreeContext, + .free_privdata = redisSSLFree, .async_read = redisSSLAsyncRead, .async_write = redisSSLAsyncWrite, .read = redisSSLRead, -- cgit v1.2.3