From 8715ba5c82af165f783ef8bc90b4d8d5a8072175 Mon Sep 17 00:00:00 2001
From: Yossi Gottlieb <yossigo@gmail.com>
Date: Thu, 29 Aug 2019 22:08:54 +0300
Subject: wip: SSL code reorganization, see #705.

---
 Makefile                        |  12 +-
 async.c                         | 143 +++--------------
 async.h                         |   2 +
 async_private.h                 |  72 +++++++++
 examples/example-libevent-ssl.c |   1 +
 examples/example-ssl.c          |   1 +
 hiredis.c                       |  25 +--
 hiredis.h                       |  27 ++--
 hiredis_ssl.h                   |  53 +++++++
 sslio.c                         | 343 +++++++++++++++++++++++++++++++---------
 sslio.h                         |  64 --------
 11 files changed, 455 insertions(+), 288 deletions(-)
 create mode 100644 async_private.h
 create mode 100644 hiredis_ssl.h
 delete mode 100644 sslio.h

diff --git a/Makefile b/Makefile
index 7a55f41..b509566 100644
--- a/Makefile
+++ b/Makefile
@@ -4,8 +4,10 @@
 # This file is released under the BSD license, see the COPYING file
 
 OBJ=net.o hiredis.o sds.o async.o read.o sockcompat.o sslio.o
-EXAMPLES=hiredis-example hiredis-example-libevent hiredis-example-libev hiredis-example-glib \
-		 hiredis-example-ssl hiredis-example-libevent-ssl
+EXAMPLES=hiredis-example hiredis-example-libevent hiredis-example-libev hiredis-example-glib
+ifeq ($(USE_SSL),1)
+EXAMPLES+=hiredis-example-ssl hiredis-example-libevent-ssl
+endif
 TESTS=hiredis-test
 LIBNAME=libhiredis
 PKGCONFNAME=hiredis.pc
@@ -87,12 +89,12 @@ all: $(DYLIBNAME) $(STLIBNAME) hiredis-test $(PKGCONFNAME)
 # Deps (use make dep to generate this)
 async.o: async.c fmacros.h async.h hiredis.h read.h sds.h net.h dict.c dict.h
 dict.o: dict.c fmacros.h dict.h
-hiredis.o: hiredis.c fmacros.h hiredis.h read.h sds.h net.h sslio.h win32.h
+hiredis.o: hiredis.c fmacros.h hiredis.h read.h sds.h net.h win32.h
 net.o: net.c fmacros.h net.h hiredis.h read.h sds.h sockcompat.h win32.h
 read.o: read.c fmacros.h read.h sds.h
 sds.o: sds.c sds.h
 sockcompat.o: sockcompat.c sockcompat.h
-sslio.o: sslio.c sslio.h hiredis.h
+sslio.o: sslio.c hiredis.h
 test.o: test.c fmacros.h hiredis.h read.h sds.h
 
 $(DYLIBNAME): $(OBJ)
@@ -205,7 +207,7 @@ endif
 
 install: $(DYLIBNAME) $(STLIBNAME) $(PKGCONFNAME)
 	mkdir -p $(INSTALL_INCLUDE_PATH) $(INSTALL_INCLUDE_PATH)/adapters $(INSTALL_LIBRARY_PATH)
-	$(INSTALL) hiredis.h async.h read.h sds.h sslio.h $(INSTALL_INCLUDE_PATH)
+	$(INSTALL) hiredis.h async.h read.h sds.h $(INSTALL_INCLUDE_PATH)
 	$(INSTALL) adapters/*.h $(INSTALL_INCLUDE_PATH)/adapters
 	$(INSTALL) $(DYLIBNAME) $(INSTALL_LIBRARY_PATH)/$(DYLIB_MINOR_NAME)
 	cd $(INSTALL_LIBRARY_PATH) && ln -sf $(DYLIB_MINOR_NAME) $(DYLIBNAME)
diff --git a/async.c b/async.c
index e46573f..4f422d5 100644
--- a/async.c
+++ b/async.c
@@ -42,42 +42,9 @@
 #include "net.h"
 #include "dict.c"
 #include "sds.h"
-#include "sslio.h"
 #include "win32.h"
 
-#define _EL_ADD_READ(ctx)                                         \
-    do {                                                          \
-        refreshTimeout(ctx);                                      \
-        if ((ctx)->ev.addRead) (ctx)->ev.addRead((ctx)->ev.data); \
-    } while (0)
-#define _EL_DEL_READ(ctx) do { \
-        if ((ctx)->ev.delRead) (ctx)->ev.delRead((ctx)->ev.data); \
-    } while(0)
-#define _EL_ADD_WRITE(ctx)                                          \
-    do {                                                            \
-        refreshTimeout(ctx);                                        \
-        if ((ctx)->ev.addWrite) (ctx)->ev.addWrite((ctx)->ev.data); \
-    } while (0)
-#define _EL_DEL_WRITE(ctx) do { \
-        if ((ctx)->ev.delWrite) (ctx)->ev.delWrite((ctx)->ev.data); \
-    } while(0)
-#define _EL_CLEANUP(ctx) do { \
-        if ((ctx)->ev.cleanup) (ctx)->ev.cleanup((ctx)->ev.data); \
-        ctx->ev.cleanup = NULL; \
-    } while(0);
-
-static void refreshTimeout(redisAsyncContext *ctx) {
-    if (ctx->c.timeout && ctx->ev.scheduleTimer &&
-        (ctx->c.timeout->tv_sec || ctx->c.timeout->tv_usec)) {
-        ctx->ev.scheduleTimer(ctx->ev.data, *ctx->c.timeout);
-    // } else {
-    //     printf("Not scheduling timer.. (tmo=%p)\n", ctx->c.timeout);
-    //     if (ctx->c.timeout){
-    //         printf("tv_sec: %u. tv_usec: %u\n", ctx->c.timeout->tv_sec,
-    //                ctx->c.timeout->tv_usec);
-    //     }
-    }
-}
+#include "async_private.h"
 
 /* Forward declaration of function in hiredis.c */
 int __redisAppendCommand(redisContext *c, const char *cmd, size_t len);
@@ -347,7 +314,7 @@ void redisAsyncFree(redisAsyncContext *ac) {
 }
 
 /* Helper function to make the disconnect happen and clean up. */
-static void __redisAsyncDisconnect(redisAsyncContext *ac) {
+void __redisAsyncDisconnect(redisAsyncContext *ac) {
     redisContext *c = &(ac->c);
 
     /* Make sure error is accessible if there is any */
@@ -552,76 +519,18 @@ static int __redisAsyncHandleConnect(redisAsyncContext *ac) {
     }
 }
 
-/**
- * Handle SSL when socket becomes available for reading. This also handles
- * read-while-write and write-while-read.
- *
- * These functions will not work properly unless `HIREDIS_SSL` is defined
- * (however, they will compile)
- */
-static void asyncSslRead(redisAsyncContext *ac) {
-    int rv;
-    redisSsl *ssl = ac->c.ssl;
-    redisContext *c = &ac->c;
-
-    ssl->wantRead = 0;
-
-    if (ssl->pendingWrite) {
-        int done;
-
-        /* This is probably just a write event */
-        ssl->pendingWrite = 0;
-        rv = redisBufferWrite(c, &done);
-        if (rv == REDIS_ERR) {
-            __redisAsyncDisconnect(ac);
-            return;
-        } else if (!done) {
-            _EL_ADD_WRITE(ac);
-        }
-    }
+void redisAsyncRead(redisAsyncContext *ac) {
+    redisContext *c = &(ac->c);
 
-    rv = redisBufferRead(c);
-    if (rv == REDIS_ERR) {
+    if (redisBufferRead(c) == REDIS_ERR) {
         __redisAsyncDisconnect(ac);
     } else {
+        /* Always re-schedule reads */
         _EL_ADD_READ(ac);
         redisProcessCallbacks(ac);
     }
 }
 
-/**
- * Handle SSL when socket becomes available for writing
- */
-static void asyncSslWrite(redisAsyncContext *ac) {
-    int rv, done = 0;
-    redisSsl *ssl = ac->c.ssl;
-    redisContext *c = &ac->c;
-
-    ssl->pendingWrite = 0;
-    rv = redisBufferWrite(c, &done);
-    if (rv == REDIS_ERR) {
-        __redisAsyncDisconnect(ac);
-        return;
-    }
-
-    if (!done) {
-        if (ssl->wantRead) {
-            /* Need to read-before-write */
-            ssl->pendingWrite = 1;
-            _EL_DEL_WRITE(ac);
-        } else {
-            /* No extra reads needed, just need to write more */
-            _EL_ADD_WRITE(ac);
-        }
-    } else {
-        /* Already done! */
-        _EL_DEL_WRITE(ac);
-    }
-
-    /* Always reschedule a read */
-    _EL_ADD_READ(ac);
-}
-
 /* This function should be called when the socket is readable.
  * It processes all replies that can be read and executes their callbacks.
  */
@@ -637,23 +546,29 @@ void redisAsyncHandleRead(redisAsyncContext *ac) {
             return;
     }
 
-    if (c->flags & REDIS_SSL) {
-        asyncSslRead(ac);
-        return;
-    }
+    c->funcs->async_read(ac);
+}
 
-    if (redisBufferRead(c) == REDIS_ERR) {
+void redisAsyncWrite(redisAsyncContext *ac) {
+    redisContext *c = &(ac->c);
+    int done = 0;
+
+    if (redisBufferWrite(c,&done) == REDIS_ERR) {
         __redisAsyncDisconnect(ac);
     } else {
-        /* Always re-schedule reads */
+        /* Continue writing when not done, stop writing otherwise */
+        if (!done)
+            _EL_ADD_WRITE(ac);
+        else
+            _EL_DEL_WRITE(ac);
+
+        /* Always schedule reads after writes */
         _EL_ADD_READ(ac);
-        redisProcessCallbacks(ac);
     }
 }
 
 void redisAsyncHandleWrite(redisAsyncContext *ac) {
     redisContext *c = &(ac->c);
-    int done = 0;
 
     if (!(c->flags & REDIS_CONNECTED)) {
         /* Abort connect was not successful. */
@@ -664,23 +579,7 @@ void redisAsyncHandleWrite(redisAsyncContext *ac) {
             return;
     }
 
-    if (c->flags & REDIS_SSL) {
-        asyncSslWrite(ac);
-        return;
-    }
-
-    if (redisBufferWrite(c,&done) == REDIS_ERR) {
-        __redisAsyncDisconnect(ac);
-    } else {
-        /* Continue writing when not done, stop writing otherwise */
-        if (!done)
-            _EL_ADD_WRITE(ac);
-        else
-            _EL_DEL_WRITE(ac);
-
-        /* Always schedule reads after writes */
-        _EL_ADD_READ(ac);
-    }
+    c->funcs->async_write(ac);
 }
 
 void __redisSetError(redisContext *c, int type, const char *str);
diff --git a/async.h b/async.h
index 40a1819..4f6b3b7 100644
--- a/async.h
+++ b/async.h
@@ -125,6 +125,8 @@ void redisAsyncFree(redisAsyncContext *ac);
 void redisAsyncHandleRead(redisAsyncContext *ac);
 void redisAsyncHandleWrite(redisAsyncContext *ac);
 void redisAsyncHandleTimeout(redisAsyncContext *ac);
+void redisAsyncRead(redisAsyncContext *ac);
+void redisAsyncWrite(redisAsyncContext *ac);
 
 /* Command functions for an async context. Write the command to the
  * output buffer and register the provided callback. */
diff --git a/async_private.h b/async_private.h
new file mode 100644
index 0000000..d0133ae
--- /dev/null
+++ b/async_private.h
@@ -0,0 +1,72 @@
+/*
+ * Copyright (c) 2009-2011, Salvatore Sanfilippo <antirez at gmail dot com>
+ * Copyright (c) 2010-2011, Pieter Noordhuis <pcnoordhuis at gmail dot com>
+ *
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ *   * Redistributions of source code must retain the above copyright notice,
+ *     this list of conditions and the following disclaimer.
+ *   * Redistributions in binary form must reproduce the above copyright
+ *     notice, this list of conditions and the following disclaimer in the
+ *     documentation and/or other materials provided with the distribution.
+ *   * Neither the name of Redis nor the names of its contributors may be used
+ *     to endorse or promote products derived from this software without
+ *     specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+ * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+ * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+ * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+ * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+ * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+ * POSSIBILITY OF SUCH DAMAGE.
+ */
+
+#ifndef __HIREDIS_ASYNC_PRIVATE_H
+#define __HIREDIS_ASYNC_PRIVATE_H
+
+#define _EL_ADD_READ(ctx)                                         \
+    do {                                                          \
+        refreshTimeout(ctx);                                      \
+        if ((ctx)->ev.addRead) (ctx)->ev.addRead((ctx)->ev.data); \
+    } while (0)
+#define _EL_DEL_READ(ctx) do { \
+        if ((ctx)->ev.delRead) (ctx)->ev.delRead((ctx)->ev.data); \
+    } while(0)
+#define _EL_ADD_WRITE(ctx)                                          \
+    do {                                                            \
+        refreshTimeout(ctx);                                        \
+        if ((ctx)->ev.addWrite) (ctx)->ev.addWrite((ctx)->ev.data); \
+    } while (0)
+#define _EL_DEL_WRITE(ctx) do { \
+        if ((ctx)->ev.delWrite) (ctx)->ev.delWrite((ctx)->ev.data); \
+    } while(0)
+#define _EL_CLEANUP(ctx) do { \
+        if ((ctx)->ev.cleanup) (ctx)->ev.cleanup((ctx)->ev.data); \
+        ctx->ev.cleanup = NULL; \
+    } while(0);
+
+static inline void refreshTimeout(redisAsyncContext *ctx) {
+    if (ctx->c.timeout && ctx->ev.scheduleTimer &&
+        (ctx->c.timeout->tv_sec || ctx->c.timeout->tv_usec)) {
+        ctx->ev.scheduleTimer(ctx->ev.data, *ctx->c.timeout);
+    // } else {
+    //     printf("Not scheduling timer.. (tmo=%p)\n", ctx->c.timeout);
+    //     if (ctx->c.timeout){
+    //         printf("tv_sec: %u. tv_usec: %u\n", ctx->c.timeout->tv_sec,
+    //                ctx->c.timeout->tv_usec);
+    //     }
+    }
+}
+
+void __redisAsyncDisconnect(redisAsyncContext *ac);
+void redisProcessCallbacks(redisAsyncContext *ac);
+
+#endif  /* __HIREDIS_ASYNC_PRIVATE_H */
diff --git a/examples/example-libevent-ssl.c b/examples/example-libevent-ssl.c
index 562e1a1..1021113 100644
--- a/examples/example-libevent-ssl.c
+++ b/examples/example-libevent-ssl.c
@@ -4,6 +4,7 @@
 #include <signal.h>
 
 #include <hiredis.h>
+#include <hiredis_ssl.h>
 #include <async.h>
 #include <adapters/libevent.h>
 
diff --git a/examples/example-ssl.c b/examples/example-ssl.c
index 156f524..81f4648 100644
--- a/examples/example-ssl.c
+++ b/examples/example-ssl.c
@@ -3,6 +3,7 @@
 #include <string.h>
 
 #include <hiredis.h>
+#include <hiredis_ssl.h>
 
 int main(int argc, char **argv) {
     unsigned int j;
diff --git a/hiredis.c b/hiredis.c
index 9627832..0658e34 100644
--- a/hiredis.c
+++ b/hiredis.c
@@ -41,9 +41,17 @@
 #include "hiredis.h"
 #include "net.h"
 #include "sds.h"
-#include "sslio.h"
+#include "async.h"
 #include "win32.h"
 
+static redisContextFuncs redisContextDefaultFuncs = {
+    .free_privdata = NULL,
+    .async_read = redisAsyncRead,
+    .async_write = redisAsyncWrite,
+    .read = redisNetRead,
+    .write = redisNetWrite
+};
+
 static redisReply *createReplyObject(int type);
 static void *createStringObject(const redisReadTask *task, char *str, size_t len);
 static void *createArrayObject(const redisReadTask *task, size_t elements);
@@ -657,6 +665,7 @@ static redisContext *redisContextInit(const redisOptions *options) {
     if (c == NULL)
         return NULL;
 
+    c->funcs = &redisContextDefaultFuncs;
     c->obuf = sdsempty();
     c->reader = redisReaderCreate();
     c->fd = REDIS_INVALID_FD;
@@ -681,8 +690,8 @@ void redisFree(redisContext *c) {
     free(c->unix_sock.path);
     free(c->timeout);
     free(c->saddr);
-    if (c->ssl) {
-        redisFreeSsl(c->ssl);
+    if (c->funcs->free_privdata) {
+        c->funcs->free_privdata(c->privdata);
     }
     memset(c, 0xff, sizeof(*c));
     free(c);
@@ -824,11 +833,6 @@ redisContext *redisConnectFd(redisFD fd) {
     return redisConnectWithOptions(&options);
 }
 
-int redisSecureConnection(redisContext *c, const char *caPath,
-                          const char *certPath, const char *keyPath, const char *servername) {
-    return redisSslCreate(c, caPath, certPath, keyPath, servername);
-}
-
 /* Set read/write timeout on a blocking socket. */
 int redisSetTimeout(redisContext *c, const struct timeval tv) {
     if (c->flags & REDIS_BLOCK)
@@ -856,8 +860,7 @@ int redisBufferRead(redisContext *c) {
     if (c->err)
         return REDIS_ERR;
 
-    nread = c->flags & REDIS_SSL ?
-        redisSslRead(c, buf, sizeof(buf)) : redisNetRead(c, buf, sizeof(buf));
+    nread = c->funcs->read(c, buf, sizeof(buf));
     if (nread > 0) {
         if (redisReaderFeed(c->reader, buf, nread) != REDIS_OK) {
             __redisSetError(c, c->reader->err, c->reader->errstr);
@@ -886,7 +889,7 @@ int redisBufferWrite(redisContext *c, int *done) {
         return REDIS_ERR;
 
     if (sdslen(c->obuf) > 0) {
-        int nwritten = (c->flags & REDIS_SSL) ? redisSslWrite(c) : redisNetWrite(c);
+        int nwritten = c->funcs->write(c);
         if (nwritten < 0) {
             return REDIS_ERR;
         } else if (nwritten > 0) {
diff --git a/hiredis.h b/hiredis.h
index d76a9e3..68afb26 100644
--- a/hiredis.h
+++ b/hiredis.h
@@ -78,9 +78,6 @@ struct timeval; /* forward declaration */
 /* Flag that is set when we should set SO_REUSEADDR before calling bind() */
 #define REDIS_REUSEADDR 0x80
 
-/* Flag that is set when this connection is done through SSL */
-#define REDIS_SSL 0x100
-
 /**
  * Flag that indicates the user does not want the context to
  * be automatically freed upon error
@@ -193,8 +190,21 @@ typedef struct {
     (opts)->type = REDIS_CONN_UNIX;        \
     (opts)->endpoint.unix_socket = path;
 
+struct redisAsyncContext;
+struct redisContext;
+
+typedef struct redisContextFuncs {
+    void (*free_privdata)(void *);
+    void (*async_read)(struct redisAsyncContext *);
+    void (*async_write)(struct redisAsyncContext *);
+    int (*read)(struct redisContext *, char *, size_t);
+    int (*write)(struct redisContext *);
+} redisContextFuncs;
+
 /* Context for a connection to Redis */
 typedef struct redisContext {
+    redisContextFuncs *funcs;   /* Function table */
+
     int err; /* Error flags, 0 when there is no error */
     char errstr[128]; /* String representation of error when applicable */
     redisFD fd;
@@ -218,9 +228,9 @@ typedef struct redisContext {
     /* For non-blocking connect */
     struct sockadr *saddr;
     size_t addrlen;
-    /* For SSL communication */
-    struct redisSsl *ssl;
 
+    /* Additional private data for hiredis addons such as SSL */
+    void *privdata;
 } redisContext;
 
 redisContext *redisConnectWithOptions(const redisOptions *options);
@@ -236,13 +246,6 @@ redisContext *redisConnectUnixWithTimeout(const char *path, const struct timeval
 redisContext *redisConnectUnixNonBlock(const char *path);
 redisContext *redisConnectFd(redisFD fd);
 
-/**
- * Secure the connection using SSL. This should be done before any command is
- * executed on the connection.
- */
-int redisSecureConnection(redisContext *c, const char *capath, const char *certpath,
-                          const char *keypath, const char *servername);
-
 /**
  * Reconnect the given context using the saved information.
  *
diff --git a/hiredis_ssl.h b/hiredis_ssl.h
new file mode 100644
index 0000000..f844f95
--- /dev/null
+++ b/hiredis_ssl.h
@@ -0,0 +1,53 @@
+
+/*
+ * Copyright (c) 2019, Redis Labs
+ *
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ *   * Redistributions of source code must retain the above copyright notice,
+ *     this list of conditions and the following disclaimer.
+ *   * Redistributions in binary form must reproduce the above copyright
+ *     notice, this list of conditions and the following disclaimer in the
+ *     documentation and/or other materials provided with the distribution.
+ *   * Neither the name of Redis nor the names of its contributors may be used
+ *     to endorse or promote products derived from this software without
+ *     specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+ * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+ * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+ * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+ * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+ * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+ * POSSIBILITY OF SUCH DAMAGE.
+ */
+
+#ifndef __HIREDIS_SSL_H
+#define __HIREDIS_SSL_H
+
+/* This is the underlying struct for SSL in ssl.h, which is not included to
+ * keep build dependencies short here.
+ */
+struct ssl_st;
+
+/**
+ * Secure the connection using SSL. This should be done before any command is
+ * executed on the connection.
+ */
+int redisSecureConnection(redisContext *c, const char *capath, const char *certpath,
+                          const char *keypath, const char *servername);
+
+/**
+ * Initiate SSL/TLS negotiation on a provided context.
+ */
+
+int redisInitiateSSL(redisContext *c, struct ssl_st *ssl);
+
+#endif  /* __HIREDIS_SSL_H */
diff --git a/sslio.c b/sslio.c
index f2f50a8..5c76370 100644
--- a/sslio.c
+++ b/sslio.c
@@ -1,5 +1,37 @@
+/*
+ * Copyright (c) 2009-2011, Salvatore Sanfilippo <antirez at gmail dot com>
+ * Copyright (c) 2010-2011, Pieter Noordhuis <pcnoordhuis at gmail dot com>
+ * Copyright (c) 2019, Redis Labs
+ *
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ *   * Redistributions of source code must retain the above copyright notice,
+ *     this list of conditions and the following disclaimer.
+ *   * Redistributions in binary form must reproduce the above copyright
+ *     notice, this list of conditions and the following disclaimer in the
+ *     documentation and/or other materials provided with the distribution.
+ *   * Neither the name of Redis nor the names of its contributors may be used
+ *     to endorse or promote products derived from this software without
+ *     specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+ * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+ * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+ * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+ * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+ * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+ * POSSIBILITY OF SUCH DAMAGE.
+ */
+
 #include "hiredis.h"
-#include "sslio.h"
+#include "async.h"
 
 #include <assert.h>
 #ifdef HIREDIS_SSL
@@ -7,10 +39,45 @@
 #include <errno.h>
 #include <string.h>
 
+#include <openssl/ssl.h>
 #include <openssl/err.h>
 
+#include "async_private.h"
+
 void __redisSetError(redisContext *c, int type, const char *str);
 
+/* The SSL context is attached to SSL/TLS connections as a privdata. */
+typedef struct redisSSLContext {
+    /**
+     * OpenSSL SSL_CTX; It is optional and will not be set when using
+     * user-supplied SSL.
+     */
+    SSL_CTX *ssl_ctx;
+
+    /**
+     * OpenSSL SSL object.
+     */
+    SSL *ssl;
+
+    /**
+     * SSL_write() requires to be called again with the same arguments it was
+     * previously called with in the event of an SSL_read/SSL_write situation
+     */
+    size_t lastLen;
+
+    /** Whether the SSL layer requires read (possibly before a write) */
+    int wantRead;
+
+    /**
+     * Whether a write was requested prior to a read. If set, the write()
+     * should resume whenever a read takes place, if possible
+     */
+    int pendingWrite;
+} redisSSLContext;
+
+/* Forward declaration */
+redisContextFuncs redisContextSSLFuncs;
+
 #ifdef HIREDIS_SSL_TRACE
 /**
  * Callback used for debugging
@@ -43,6 +110,16 @@ static void sslLogCallback(const SSL *ssl, int where, int ret) {
 }
 #endif
 
+/**
+ * OpenSSL global initialization and locking handling callbacks.
+ * Note that this is only required for OpenSSL < 1.1.0.
+ */
+
+#if OPENSSL_VERSION_NUMBER < 0x10100000L
+#define HIREDIS_USE_CRYPTO_LOCKS
+#endif
+
+#ifdef HIREDIS_USE_CRYPTO_LOCKS
 typedef pthread_mutex_t sslLockType;
 static void sslLockInit(sslLockType *l) {
     pthread_mutex_init(l, NULL);
@@ -81,102 +158,129 @@ static void initOpensslLocks(void) {
     }
     CRYPTO_set_locking_callback(opensslDoLock);
 }
+#endif /* HIREDIS_USE_CRYPTO_LOCKS */
 
-void redisFreeSsl(redisSsl *ssl){
-    if (ssl->ctx) {
-        SSL_CTX_free(ssl->ctx);
+/**
+ * SSL Connection initialization.
+ */
+
+static int redisSSLConnect(redisContext *c, SSL_CTX *ssl_ctx, SSL *ssl) {
+    if (c->privdata) {
+        __redisSetError(c, REDIS_ERR_OTHER, "redisContext was already associated");
+        return REDIS_ERR;
     }
-    if (ssl->ssl) {
-        SSL_free(ssl->ssl);
+    c->privdata = calloc(1, sizeof(redisSSLContext));
+
+    c->funcs = &redisContextSSLFuncs;
+    redisSSLContext *rssl = c->privdata;
+
+    rssl->ssl_ctx = ssl_ctx;
+    rssl->ssl = ssl;
+
+    SSL_set_mode(rssl->ssl, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER);
+    SSL_set_fd(rssl->ssl, c->fd);
+    SSL_set_connect_state(rssl->ssl);
+
+    ERR_clear_error();
+    int rv = SSL_connect(rssl->ssl);
+    if (rv == 1) {
+        return REDIS_OK;
+    }
+
+    rv = SSL_get_error(rssl->ssl, rv);
+    if (((c->flags & REDIS_BLOCK) == 0) &&
+        (rv == SSL_ERROR_WANT_READ || rv == SSL_ERROR_WANT_WRITE)) {
+        return REDIS_OK;
+    }
+
+    if (c->err == 0) {
+        char err[512];
+        if (rv == SSL_ERROR_SYSCALL)
+            snprintf(err,sizeof(err)-1,"SSL_connect failed: %s",strerror(errno));
+        else {
+            unsigned long e = ERR_peek_last_error();
+            snprintf(err,sizeof(err)-1,"SSL_connect failed: %s",
+                    ERR_reason_error_string(e));
+        }
+        __redisSetError(c, REDIS_ERR_IO, err);
     }
-    free(ssl);
+    return REDIS_ERR;
+}
+
+int redisInitiateSSL(redisContext *c, SSL *ssl) {
+    return redisSSLConnect(c, NULL, ssl);
 }
 
-int redisSslCreate(redisContext *c, const char *capath, const char *certpath,
-                   const char *keypath, const char *servername) {
-    assert(!c->ssl);
-    c->ssl = calloc(1, sizeof(*c->ssl));
+int redisSecureConnection(redisContext *c, const char *capath,
+                          const char *certpath, const char *keypath, const char *servername) {
+
+    SSL_CTX *ssl_ctx = NULL;
+    SSL *ssl = NULL;
+
+    /* Initialize global OpenSSL stuff */
     static int isInit = 0;
     if (!isInit) {
         isInit = 1;
         SSL_library_init();
+#ifdef HIREDIS_USE_CRYPTO_LOCKS
         initOpensslLocks();
+#endif
+    }
+
+    ssl_ctx = SSL_CTX_new(SSLv23_client_method());
+    if (!ssl_ctx) {
+        __redisSetError(c, REDIS_ERR_OTHER, "Failed to create SSL_CTX");
+        goto error;
     }
 
-    redisSsl *s = c->ssl;
-    s->ctx = SSL_CTX_new(SSLv23_client_method());
 #ifdef HIREDIS_SSL_TRACE
-    SSL_CTX_set_info_callback(s->ctx, sslLogCallback);
+    SSL_CTX_set_info_callback(ssl_ctx, sslLogCallback);
 #endif
-    SSL_CTX_set_mode(s->ctx, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER);
-    SSL_CTX_set_options(s->ctx, SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3);
-    SSL_CTX_set_verify(s->ctx, SSL_VERIFY_PEER, NULL);
-
+    SSL_CTX_set_options(ssl_ctx, SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3);
+    SSL_CTX_set_verify(ssl_ctx, SSL_VERIFY_PEER, NULL);
     if ((certpath != NULL && keypath == NULL) || (keypath != NULL && certpath == NULL)) {
-        __redisSetError(c, REDIS_ERR, "certpath and keypath must be specified together");
-        return REDIS_ERR;
+        __redisSetError(c, REDIS_ERR_OTHER, "certpath and keypath must be specified together");
+        goto error;
     }
 
     if (capath) {
-        if (!SSL_CTX_load_verify_locations(s->ctx, capath, NULL)) {
-            __redisSetError(c, REDIS_ERR, "Invalid CA certificate");
-            return REDIS_ERR;
+        if (!SSL_CTX_load_verify_locations(ssl_ctx, capath, NULL)) {
+            __redisSetError(c, REDIS_ERR_OTHER, "Invalid CA certificate");
+            goto error;
         }
     }
     if (certpath) {
-        if (!SSL_CTX_use_certificate_chain_file(s->ctx, certpath)) {
-            __redisSetError(c, REDIS_ERR, "Invalid client certificate");
-            return REDIS_ERR;
+        if (!SSL_CTX_use_certificate_chain_file(ssl_ctx, certpath)) {
+            __redisSetError(c, REDIS_ERR_OTHER, "Invalid client certificate");
+            goto error;
         }
-        if (!SSL_CTX_use_PrivateKey_file(s->ctx, keypath, SSL_FILETYPE_PEM)) {
-            __redisSetError(c, REDIS_ERR, "Invalid client key");
-            return REDIS_ERR;
+        if (!SSL_CTX_use_PrivateKey_file(ssl_ctx, keypath, SSL_FILETYPE_PEM)) {
+            __redisSetError(c, REDIS_ERR_OTHER, "Invalid client key");
+            goto error;
         }
     }
 
-    s->ssl = SSL_new(s->ctx);
-    if (!s->ssl) {
-        __redisSetError(c, REDIS_ERR, "Couldn't create new SSL instance");
-        return REDIS_ERR;
+    ssl = SSL_new(ssl_ctx);
+    if (!ssl) {
+        __redisSetError(c, REDIS_ERR_OTHER, "Couldn't create new SSL instance");
+        goto error;
     }
     if (servername) {
-        if (!SSL_set_tlsext_host_name(s->ssl, servername)) {
-            __redisSetError(c, REDIS_ERR, "Couldn't set server name indication");
-            return REDIS_ERR;
+        if (!SSL_set_tlsext_host_name(ssl, servername)) {
+            __redisSetError(c, REDIS_ERR_OTHER, "Couldn't set server name indication");
+            goto error;
         }
     }
 
-    SSL_set_fd(s->ssl, c->fd);
-    SSL_set_connect_state(s->ssl);
+    return redisSSLConnect(c, ssl_ctx, ssl);
 
-    c->flags |= REDIS_SSL;
-    ERR_clear_error();
-    int rv = SSL_connect(c->ssl->ssl);
-    if (rv == 1) {
-        return REDIS_OK;
-    }
-
-    rv = SSL_get_error(s->ssl, rv);
-    if (((c->flags & REDIS_BLOCK) == 0) &&
-        (rv == SSL_ERROR_WANT_READ || rv == SSL_ERROR_WANT_WRITE)) {
-        return REDIS_OK;
-    }
-
-    if (c->err == 0) {
-        char err[512];
-        if (rv == SSL_ERROR_SYSCALL)
-            snprintf(err,sizeof(err)-1,"SSL_connect failed: %s",strerror(errno));
-        else {
-            unsigned long e = ERR_peek_last_error();
-            snprintf(err,sizeof(err)-1,"SSL_connect failed: %s",
-                    ERR_reason_error_string(e));
-        }
-        __redisSetError(c, REDIS_ERR_IO, err);
-    }
+error:
+    if (ssl) SSL_free(ssl);
+    if (ssl_ctx) SSL_CTX_free(ssl_ctx);
     return REDIS_ERR;
 }
 
-static int maybeCheckWant(redisSsl *rssl, int rv) {
+static int maybeCheckWant(redisSSLContext *rssl, int rv) {
     /**
      * If the error is WANT_READ or WANT_WRITE, the appropriate flags are set
      * and true is returned. False is returned otherwise
@@ -192,15 +296,36 @@ static int maybeCheckWant(redisSsl *rssl, int rv) {
     }
 }
 
-int redisSslRead(redisContext *c, char *buf, size_t bufcap) {
-    int nread = SSL_read(c->ssl->ssl, buf, bufcap);
+/**
+ * Implementation of redisContextFuncs for SSL connections.
+ */
+
+static void redisSSLFreeContext(void *privdata){
+    redisSSLContext *rsc = privdata;
+
+    if (!rsc) return;
+    if (rsc->ssl) {
+        SSL_free(rsc->ssl);
+        rsc->ssl = NULL;
+    }
+    if (rsc->ssl_ctx) {
+        SSL_CTX_free(rsc->ssl_ctx);
+        rsc->ssl_ctx = NULL;
+    }
+    free(rsc);
+}
+
+static int redisSSLRead(redisContext *c, char *buf, size_t bufcap) {
+    redisSSLContext *rssl = c->privdata;
+
+    int nread = SSL_read(rssl->ssl, buf, bufcap);
     if (nread > 0) {
         return nread;
     } else if (nread == 0) {
         __redisSetError(c, REDIS_ERR_EOF, "Server closed the connection");
         return -1;
     } else {
-        int err = SSL_get_error(c->ssl->ssl, nread);
+        int err = SSL_get_error(rssl->ssl, nread);
         if (c->flags & REDIS_BLOCK) {
             /**
              * In blocking mode, we should never end up in a situation where
@@ -223,7 +348,7 @@ int redisSslRead(redisContext *c, char *buf, size_t bufcap) {
         /**
          * We can very well get an EWOULDBLOCK/EAGAIN, however
          */
-        if (maybeCheckWant(c->ssl, err)) {
+        if (maybeCheckWant(rssl, err)) {
             return 0;
         } else {
             __redisSetError(c, REDIS_ERR_IO, NULL);
@@ -232,17 +357,19 @@ int redisSslRead(redisContext *c, char *buf, size_t bufcap) {
     }
 }
 
-int redisSslWrite(redisContext *c) {
-    size_t len = c->ssl->lastLen ? c->ssl->lastLen : sdslen(c->obuf);
-    int rv = SSL_write(c->ssl->ssl, c->obuf, len);
+static int redisSSLWrite(redisContext *c) {
+    redisSSLContext *rssl = c->privdata;
+
+    size_t len = rssl->lastLen ? rssl->lastLen : sdslen(c->obuf);
+    int rv = SSL_write(rssl->ssl, c->obuf, len);
 
     if (rv > 0) {
-        c->ssl->lastLen = 0;
+        rssl->lastLen = 0;
     } else if (rv < 0) {
-        c->ssl->lastLen = len;
+        rssl->lastLen = len;
 
-        int err = SSL_get_error(c->ssl->ssl, rv);
-        if ((c->flags & REDIS_BLOCK) == 0 && maybeCheckWant(c->ssl, err)) {
+        int err = SSL_get_error(rssl->ssl, rv);
+        if ((c->flags & REDIS_BLOCK) == 0 && maybeCheckWant(rssl, err)) {
             return 0;
         } else {
             __redisSetError(c, REDIS_ERR_IO, NULL);
@@ -252,4 +379,72 @@ int redisSslWrite(redisContext *c) {
     return rv;
 }
 
+static void redisSSLAsyncRead(redisAsyncContext *ac) {
+    int rv;
+    redisSSLContext *rssl = ac->c.privdata;
+    redisContext *c = &ac->c;
+
+    rssl->wantRead = 0;
+
+    if (rssl->pendingWrite) {
+        int done;
+
+        /* This is probably just a write event */
+        rssl->pendingWrite = 0;
+        rv = redisBufferWrite(c, &done);
+        if (rv == REDIS_ERR) {
+            __redisAsyncDisconnect(ac);
+            return;
+        } else if (!done) {
+            _EL_ADD_WRITE(ac);
+        }
+    }
+
+    rv = redisBufferRead(c);
+    if (rv == REDIS_ERR) {
+        __redisAsyncDisconnect(ac);
+    } else {
+        _EL_ADD_READ(ac);
+        redisProcessCallbacks(ac);
+    }
+}
+
+static void redisSSLAsyncWrite(redisAsyncContext *ac) {
+    int rv, done = 0;
+    redisSSLContext *rssl = ac->c.privdata;
+    redisContext *c = &ac->c;
+
+    rssl->pendingWrite = 0;
+    rv = redisBufferWrite(c, &done);
+    if (rv == REDIS_ERR) {
+        __redisAsyncDisconnect(ac);
+        return;
+    }
+
+    if (!done) {
+        if (rssl->wantRead) {
+            /* Need to read-before-write */
+            rssl->pendingWrite = 1;
+            _EL_DEL_WRITE(ac);
+        } else {
+            /* No extra reads needed, just need to write more */
+            _EL_ADD_WRITE(ac);
+        }
+    } else {
+        /* Already done! */
+        _EL_DEL_WRITE(ac);
+    }
+
+    /* Always reschedule a read */
+    _EL_ADD_READ(ac);
+}
+
+redisContextFuncs redisContextSSLFuncs = {
+    .free_privdata = redisSSLFreeContext,
+    .async_read = redisSSLAsyncRead,
+    .async_write = redisSSLAsyncWrite,
+    .read = redisSSLRead,
+    .write = redisSSLWrite
+};
+
 #endif
diff --git a/sslio.h b/sslio.h
deleted file mode 100644
index e5493b7..0000000
--- a/sslio.h
+++ /dev/null
@@ -1,64 +0,0 @@
-#ifndef REDIS_SSLIO_H
-#define REDIS_SSLIO_H
-
-
-#ifndef HIREDIS_SSL
-typedef struct redisSsl {
-    size_t lastLen;
-    int wantRead;
-    int pendingWrite;
-} redisSsl;
-static inline void redisFreeSsl(redisSsl *ssl) {
-    (void)ssl;
-}
-static inline int redisSslCreate(struct redisContext *c, const char *ca,
-                          const char *cert, const char *key, const char *servername) {
-    (void)c;(void)ca;(void)cert;(void)key;(void)servername;
-    return REDIS_ERR;
-}
-static inline int redisSslRead(struct redisContext *c, char *s, size_t n) {
-    (void)c;(void)s;(void)n;
-    return -1;
-}
-static inline int redisSslWrite(struct redisContext *c) {
-    (void)c;
-    return -1;
-}
-#else
-#include <openssl/ssl.h>
-
-/**
- * This file contains routines for HIREDIS' SSL
- */
-
-typedef struct redisSsl {
-    SSL *ssl;
-    SSL_CTX *ctx;
-
-    /**
-     * SSL_write() requires to be called again with the same arguments it was
-     * previously called with in the event of an SSL_read/SSL_write situation
-     */
-    size_t lastLen;
-
-    /** Whether the SSL layer requires read (possibly before a write) */
-    int wantRead;
-
-    /**
-     * Whether a write was requested prior to a read. If set, the write()
-     * should resume whenever a read takes place, if possible
-     */
-    int pendingWrite;
-} redisSsl;
-
-struct redisContext;
-
-void redisFreeSsl(redisSsl *);
-int redisSslCreate(struct redisContext *c, const char *caPath,
-                   const char *certPath, const char *keyPath, const char *servername);
-
-int redisSslRead(struct redisContext *c, char *buf, size_t bufcap);
-int redisSslWrite(struct redisContext *c);
-
-#endif /* HIREDIS_SSL */
-#endif /* HIREDIS_SSLIO_H */
-- 
cgit v1.2.3