diff options
Diffstat (limited to 'async.c')
-rw-r--r-- | async.c | 102 |
1 files changed, 74 insertions, 28 deletions
@@ -47,8 +47,9 @@ #include "async_private.h" -/* Forward declaration of function in hiredis.c */ +/* Forward declarations of hiredis.c functions */ int __redisAppendCommand(redisContext *c, const char *cmd, size_t len); +void __redisSetError(redisContext *c, int type, const char *str); /* Functions managing dictionary of callbacks for pub/sub. */ static unsigned int callbackHash(const void *key) { @@ -58,7 +59,12 @@ static unsigned int callbackHash(const void *key) { static void *callbackValDup(void *privdata, const void *src) { ((void) privdata); - redisCallback *dup = hi_malloc(sizeof(*dup)); + redisCallback *dup; + + dup = hi_malloc(sizeof(*dup)); + if (dup == NULL) + return NULL; + memcpy(dup,src,sizeof(*dup)); return dup; } @@ -80,7 +86,7 @@ static void callbackKeyDestructor(void *privdata, void *key) { static void callbackValDestructor(void *privdata, void *val) { ((void) privdata); - free(val); + hi_free(val); } static dictType callbackDict = { @@ -94,10 +100,19 @@ static dictType callbackDict = { static redisAsyncContext *redisAsyncInitialize(redisContext *c) { redisAsyncContext *ac; + dict *channels = NULL, *patterns = NULL; + + channels = dictCreate(&callbackDict,NULL); + if (channels == NULL) + goto oom; + + patterns = dictCreate(&callbackDict,NULL); + if (patterns == NULL) + goto oom; - ac = realloc(c,sizeof(redisAsyncContext)); + ac = hi_realloc(c,sizeof(redisAsyncContext)); if (ac == NULL) - return NULL; + goto oom; c = &(ac->c); @@ -126,9 +141,14 @@ static redisAsyncContext *redisAsyncInitialize(redisContext *c) { ac->replies.tail = NULL; ac->sub.invalid.head = NULL; ac->sub.invalid.tail = NULL; - ac->sub.channels = dictCreate(&callbackDict,NULL); - ac->sub.patterns = dictCreate(&callbackDict,NULL); + ac->sub.channels = channels; + ac->sub.patterns = patterns; + return ac; +oom: + if (channels) dictRelease(channels); + if (patterns) dictRelease(patterns); + return NULL; } /* We want the error field to be accessible directly instead of requiring @@ -216,7 +236,7 @@ static int __redisPushCallback(redisCallbackList *list, redisCallback *source) { redisCallback *cb; /* Copy callback from stack to heap */ - cb = malloc(sizeof(*cb)); + cb = hi_malloc(sizeof(*cb)); if (cb == NULL) return REDIS_ERR_OOM; @@ -244,7 +264,7 @@ static int __redisShiftCallback(redisCallbackList *list, redisCallback *target) /* Copy callback from heap to stack */ if (target != NULL) memcpy(target,cb,sizeof(*cb)); - free(cb); + hi_free(cb); return REDIS_OK; } return REDIS_ERR; @@ -275,17 +295,27 @@ static void __redisAsyncFree(redisAsyncContext *ac) { __redisRunCallback(ac,&cb,NULL); /* Run subscription callbacks callbacks with NULL reply */ - it = dictGetIterator(ac->sub.channels); - while ((de = dictNext(it)) != NULL) - __redisRunCallback(ac,dictGetEntryVal(de),NULL); - dictReleaseIterator(it); - dictRelease(ac->sub.channels); - - it = dictGetIterator(ac->sub.patterns); - while ((de = dictNext(it)) != NULL) - __redisRunCallback(ac,dictGetEntryVal(de),NULL); - dictReleaseIterator(it); - dictRelease(ac->sub.patterns); + if (ac->sub.channels) { + it = dictGetIterator(ac->sub.channels); + if (it != NULL) { + while ((de = dictNext(it)) != NULL) + __redisRunCallback(ac,dictGetEntryVal(de),NULL); + dictReleaseIterator(it); + } + + dictRelease(ac->sub.channels); + } + + if (ac->sub.patterns) { + it = dictGetIterator(ac->sub.patterns); + if (it != NULL) { + while ((de = dictNext(it)) != NULL) + __redisRunCallback(ac,dictGetEntryVal(de),NULL); + dictReleaseIterator(it); + } + + dictRelease(ac->sub.patterns); + } /* Signal event lib to clean up */ _EL_CLEANUP(ac); @@ -388,6 +418,9 @@ static int __redisGetSubscribeCallback(redisAsyncContext *ac, redisReply *reply, /* Locate the right callback */ assert(reply->element[1]->type == REDIS_REPLY_STRING); sname = sdsnewlen(reply->element[1]->str,reply->element[1]->len); + if (sname == NULL) + goto oom; + de = dictFind(callbacks,sname); if (de != NULL) { cb = dictGetEntryVal(de); @@ -421,6 +454,9 @@ static int __redisGetSubscribeCallback(redisAsyncContext *ac, redisReply *reply, __redisShiftCallback(&ac->sub.invalid,dstcb); } return REDIS_OK; +oom: + __redisSetError(&(ac->c), REDIS_ERR_OOM, "Out of memory"); + return REDIS_ERR; } void redisProcessCallbacks(redisAsyncContext *ac) { @@ -588,8 +624,6 @@ void redisAsyncHandleWrite(redisAsyncContext *ac) { c->funcs->async_write(ac); } -void __redisSetError(redisContext *c, int type, const char *str); - void redisAsyncHandleTimeout(redisAsyncContext *ac) { redisContext *c = &(ac->c); redisCallback cb; @@ -672,6 +706,9 @@ static int __redisAsyncCommand(redisAsyncContext *ac, redisCallbackFn *fn, void /* Add every channel/pattern to the list of subscription callbacks. */ while ((p = nextArgument(p,&astr,&alen)) != NULL) { sname = sdsnewlen(astr,alen); + if (sname == NULL) + goto oom; + if (pvariant) cbdict = ac->sub.patterns; else @@ -715,6 +752,9 @@ static int __redisAsyncCommand(redisAsyncContext *ac, redisCallbackFn *fn, void _EL_ADD_WRITE(ac); return REDIS_OK; +oom: + __redisSetError(&(ac->c), REDIS_ERR_OOM, "Out of memory"); + return REDIS_ERR; } int redisvAsyncCommand(redisAsyncContext *ac, redisCallbackFn *fn, void *privdata, const char *format, va_list ap) { @@ -728,7 +768,7 @@ int redisvAsyncCommand(redisAsyncContext *ac, redisCallbackFn *fn, void *privdat return REDIS_ERR; status = __redisAsyncCommand(ac,fn,privdata,cmd,len); - free(cmd); + hi_free(cmd); return status; } @@ -758,15 +798,21 @@ int redisAsyncFormattedCommand(redisAsyncContext *ac, redisCallbackFn *fn, void return status; } -void redisAsyncSetTimeout(redisAsyncContext *ac, struct timeval tv) { +int redisAsyncSetTimeout(redisAsyncContext *ac, struct timeval tv) { if (!ac->c.timeout) { ac->c.timeout = hi_calloc(1, sizeof(tv)); + if (ac->c.timeout == NULL) { + __redisSetError(&ac->c, REDIS_ERR_OOM, "Out of memory"); + __redisAsyncCopyError(ac); + return REDIS_ERR; + } } - if (tv.tv_sec == ac->c.timeout->tv_sec && - tv.tv_usec == ac->c.timeout->tv_usec) { - return; + if (tv.tv_sec != ac->c.timeout->tv_sec || + tv.tv_usec != ac->c.timeout->tv_usec) + { + *ac->c.timeout = tv; } - *ac->c.timeout = tv; + return REDIS_OK; } |