summaryrefslogtreecommitdiff
path: root/async.c
diff options
context:
space:
mode:
Diffstat (limited to 'async.c')
-rw-r--r--async.c102
1 files changed, 74 insertions, 28 deletions
diff --git a/async.c b/async.c
index 68a656f..261fe24 100644
--- a/async.c
+++ b/async.c
@@ -47,8 +47,9 @@
#include "async_private.h"
-/* Forward declaration of function in hiredis.c */
+/* Forward declarations of hiredis.c functions */
int __redisAppendCommand(redisContext *c, const char *cmd, size_t len);
+void __redisSetError(redisContext *c, int type, const char *str);
/* Functions managing dictionary of callbacks for pub/sub. */
static unsigned int callbackHash(const void *key) {
@@ -58,7 +59,12 @@ static unsigned int callbackHash(const void *key) {
static void *callbackValDup(void *privdata, const void *src) {
((void) privdata);
- redisCallback *dup = hi_malloc(sizeof(*dup));
+ redisCallback *dup;
+
+ dup = hi_malloc(sizeof(*dup));
+ if (dup == NULL)
+ return NULL;
+
memcpy(dup,src,sizeof(*dup));
return dup;
}
@@ -80,7 +86,7 @@ static void callbackKeyDestructor(void *privdata, void *key) {
static void callbackValDestructor(void *privdata, void *val) {
((void) privdata);
- free(val);
+ hi_free(val);
}
static dictType callbackDict = {
@@ -94,10 +100,19 @@ static dictType callbackDict = {
static redisAsyncContext *redisAsyncInitialize(redisContext *c) {
redisAsyncContext *ac;
+ dict *channels = NULL, *patterns = NULL;
+
+ channels = dictCreate(&callbackDict,NULL);
+ if (channels == NULL)
+ goto oom;
+
+ patterns = dictCreate(&callbackDict,NULL);
+ if (patterns == NULL)
+ goto oom;
- ac = realloc(c,sizeof(redisAsyncContext));
+ ac = hi_realloc(c,sizeof(redisAsyncContext));
if (ac == NULL)
- return NULL;
+ goto oom;
c = &(ac->c);
@@ -126,9 +141,14 @@ static redisAsyncContext *redisAsyncInitialize(redisContext *c) {
ac->replies.tail = NULL;
ac->sub.invalid.head = NULL;
ac->sub.invalid.tail = NULL;
- ac->sub.channels = dictCreate(&callbackDict,NULL);
- ac->sub.patterns = dictCreate(&callbackDict,NULL);
+ ac->sub.channels = channels;
+ ac->sub.patterns = patterns;
+
return ac;
+oom:
+ if (channels) dictRelease(channels);
+ if (patterns) dictRelease(patterns);
+ return NULL;
}
/* We want the error field to be accessible directly instead of requiring
@@ -216,7 +236,7 @@ static int __redisPushCallback(redisCallbackList *list, redisCallback *source) {
redisCallback *cb;
/* Copy callback from stack to heap */
- cb = malloc(sizeof(*cb));
+ cb = hi_malloc(sizeof(*cb));
if (cb == NULL)
return REDIS_ERR_OOM;
@@ -244,7 +264,7 @@ static int __redisShiftCallback(redisCallbackList *list, redisCallback *target)
/* Copy callback from heap to stack */
if (target != NULL)
memcpy(target,cb,sizeof(*cb));
- free(cb);
+ hi_free(cb);
return REDIS_OK;
}
return REDIS_ERR;
@@ -275,17 +295,27 @@ static void __redisAsyncFree(redisAsyncContext *ac) {
__redisRunCallback(ac,&cb,NULL);
/* Run subscription callbacks callbacks with NULL reply */
- it = dictGetIterator(ac->sub.channels);
- while ((de = dictNext(it)) != NULL)
- __redisRunCallback(ac,dictGetEntryVal(de),NULL);
- dictReleaseIterator(it);
- dictRelease(ac->sub.channels);
-
- it = dictGetIterator(ac->sub.patterns);
- while ((de = dictNext(it)) != NULL)
- __redisRunCallback(ac,dictGetEntryVal(de),NULL);
- dictReleaseIterator(it);
- dictRelease(ac->sub.patterns);
+ if (ac->sub.channels) {
+ it = dictGetIterator(ac->sub.channels);
+ if (it != NULL) {
+ while ((de = dictNext(it)) != NULL)
+ __redisRunCallback(ac,dictGetEntryVal(de),NULL);
+ dictReleaseIterator(it);
+ }
+
+ dictRelease(ac->sub.channels);
+ }
+
+ if (ac->sub.patterns) {
+ it = dictGetIterator(ac->sub.patterns);
+ if (it != NULL) {
+ while ((de = dictNext(it)) != NULL)
+ __redisRunCallback(ac,dictGetEntryVal(de),NULL);
+ dictReleaseIterator(it);
+ }
+
+ dictRelease(ac->sub.patterns);
+ }
/* Signal event lib to clean up */
_EL_CLEANUP(ac);
@@ -388,6 +418,9 @@ static int __redisGetSubscribeCallback(redisAsyncContext *ac, redisReply *reply,
/* Locate the right callback */
assert(reply->element[1]->type == REDIS_REPLY_STRING);
sname = sdsnewlen(reply->element[1]->str,reply->element[1]->len);
+ if (sname == NULL)
+ goto oom;
+
de = dictFind(callbacks,sname);
if (de != NULL) {
cb = dictGetEntryVal(de);
@@ -421,6 +454,9 @@ static int __redisGetSubscribeCallback(redisAsyncContext *ac, redisReply *reply,
__redisShiftCallback(&ac->sub.invalid,dstcb);
}
return REDIS_OK;
+oom:
+ __redisSetError(&(ac->c), REDIS_ERR_OOM, "Out of memory");
+ return REDIS_ERR;
}
void redisProcessCallbacks(redisAsyncContext *ac) {
@@ -588,8 +624,6 @@ void redisAsyncHandleWrite(redisAsyncContext *ac) {
c->funcs->async_write(ac);
}
-void __redisSetError(redisContext *c, int type, const char *str);
-
void redisAsyncHandleTimeout(redisAsyncContext *ac) {
redisContext *c = &(ac->c);
redisCallback cb;
@@ -672,6 +706,9 @@ static int __redisAsyncCommand(redisAsyncContext *ac, redisCallbackFn *fn, void
/* Add every channel/pattern to the list of subscription callbacks. */
while ((p = nextArgument(p,&astr,&alen)) != NULL) {
sname = sdsnewlen(astr,alen);
+ if (sname == NULL)
+ goto oom;
+
if (pvariant)
cbdict = ac->sub.patterns;
else
@@ -715,6 +752,9 @@ static int __redisAsyncCommand(redisAsyncContext *ac, redisCallbackFn *fn, void
_EL_ADD_WRITE(ac);
return REDIS_OK;
+oom:
+ __redisSetError(&(ac->c), REDIS_ERR_OOM, "Out of memory");
+ return REDIS_ERR;
}
int redisvAsyncCommand(redisAsyncContext *ac, redisCallbackFn *fn, void *privdata, const char *format, va_list ap) {
@@ -728,7 +768,7 @@ int redisvAsyncCommand(redisAsyncContext *ac, redisCallbackFn *fn, void *privdat
return REDIS_ERR;
status = __redisAsyncCommand(ac,fn,privdata,cmd,len);
- free(cmd);
+ hi_free(cmd);
return status;
}
@@ -758,15 +798,21 @@ int redisAsyncFormattedCommand(redisAsyncContext *ac, redisCallbackFn *fn, void
return status;
}
-void redisAsyncSetTimeout(redisAsyncContext *ac, struct timeval tv) {
+int redisAsyncSetTimeout(redisAsyncContext *ac, struct timeval tv) {
if (!ac->c.timeout) {
ac->c.timeout = hi_calloc(1, sizeof(tv));
+ if (ac->c.timeout == NULL) {
+ __redisSetError(&ac->c, REDIS_ERR_OOM, "Out of memory");
+ __redisAsyncCopyError(ac);
+ return REDIS_ERR;
+ }
}
- if (tv.tv_sec == ac->c.timeout->tv_sec &&
- tv.tv_usec == ac->c.timeout->tv_usec) {
- return;
+ if (tv.tv_sec != ac->c.timeout->tv_sec ||
+ tv.tv_usec != ac->c.timeout->tv_usec)
+ {
+ *ac->c.timeout = tv;
}
- *ac->c.timeout = tv;
+ return REDIS_OK;
}