diff options
Diffstat (limited to 'async.c')
| -rw-r--r-- | async.c | 102 | 
1 files changed, 74 insertions, 28 deletions
| @@ -47,8 +47,9 @@  #include "async_private.h" -/* Forward declaration of function in hiredis.c */ +/* Forward declarations of hiredis.c functions */  int __redisAppendCommand(redisContext *c, const char *cmd, size_t len); +void __redisSetError(redisContext *c, int type, const char *str);  /* Functions managing dictionary of callbacks for pub/sub. */  static unsigned int callbackHash(const void *key) { @@ -58,7 +59,12 @@ static unsigned int callbackHash(const void *key) {  static void *callbackValDup(void *privdata, const void *src) {      ((void) privdata); -    redisCallback *dup = hi_malloc(sizeof(*dup)); +    redisCallback *dup; + +    dup = hi_malloc(sizeof(*dup)); +    if (dup == NULL) +        return NULL; +      memcpy(dup,src,sizeof(*dup));      return dup;  } @@ -80,7 +86,7 @@ static void callbackKeyDestructor(void *privdata, void *key) {  static void callbackValDestructor(void *privdata, void *val) {      ((void) privdata); -    free(val); +    hi_free(val);  }  static dictType callbackDict = { @@ -94,10 +100,19 @@ static dictType callbackDict = {  static redisAsyncContext *redisAsyncInitialize(redisContext *c) {      redisAsyncContext *ac; +    dict *channels = NULL, *patterns = NULL; + +    channels = dictCreate(&callbackDict,NULL); +    if (channels == NULL) +        goto oom; + +    patterns = dictCreate(&callbackDict,NULL); +    if (patterns == NULL) +        goto oom; -    ac = realloc(c,sizeof(redisAsyncContext)); +    ac = hi_realloc(c,sizeof(redisAsyncContext));      if (ac == NULL) -        return NULL; +        goto oom;      c = &(ac->c); @@ -126,9 +141,14 @@ static redisAsyncContext *redisAsyncInitialize(redisContext *c) {      ac->replies.tail = NULL;      ac->sub.invalid.head = NULL;      ac->sub.invalid.tail = NULL; -    ac->sub.channels = dictCreate(&callbackDict,NULL); -    ac->sub.patterns = dictCreate(&callbackDict,NULL); +    ac->sub.channels = channels; +    ac->sub.patterns = patterns; +      return ac; +oom: +    if (channels) dictRelease(channels); +    if (patterns) dictRelease(patterns); +    return NULL;  }  /* We want the error field to be accessible directly instead of requiring @@ -216,7 +236,7 @@ static int __redisPushCallback(redisCallbackList *list, redisCallback *source) {      redisCallback *cb;      /* Copy callback from stack to heap */ -    cb = malloc(sizeof(*cb)); +    cb = hi_malloc(sizeof(*cb));      if (cb == NULL)          return REDIS_ERR_OOM; @@ -244,7 +264,7 @@ static int __redisShiftCallback(redisCallbackList *list, redisCallback *target)          /* Copy callback from heap to stack */          if (target != NULL)              memcpy(target,cb,sizeof(*cb)); -        free(cb); +        hi_free(cb);          return REDIS_OK;      }      return REDIS_ERR; @@ -275,17 +295,27 @@ static void __redisAsyncFree(redisAsyncContext *ac) {          __redisRunCallback(ac,&cb,NULL);      /* Run subscription callbacks callbacks with NULL reply */ -    it = dictGetIterator(ac->sub.channels); -    while ((de = dictNext(it)) != NULL) -        __redisRunCallback(ac,dictGetEntryVal(de),NULL); -    dictReleaseIterator(it); -    dictRelease(ac->sub.channels); - -    it = dictGetIterator(ac->sub.patterns); -    while ((de = dictNext(it)) != NULL) -        __redisRunCallback(ac,dictGetEntryVal(de),NULL); -    dictReleaseIterator(it); -    dictRelease(ac->sub.patterns); +    if (ac->sub.channels) { +        it = dictGetIterator(ac->sub.channels); +        if (it != NULL) { +            while ((de = dictNext(it)) != NULL) +                __redisRunCallback(ac,dictGetEntryVal(de),NULL); +            dictReleaseIterator(it); +        } + +        dictRelease(ac->sub.channels); +    } + +    if (ac->sub.patterns) { +        it = dictGetIterator(ac->sub.patterns); +        if (it != NULL) { +            while ((de = dictNext(it)) != NULL) +                __redisRunCallback(ac,dictGetEntryVal(de),NULL); +            dictReleaseIterator(it); +        } + +        dictRelease(ac->sub.patterns); +    }      /* Signal event lib to clean up */      _EL_CLEANUP(ac); @@ -388,6 +418,9 @@ static int __redisGetSubscribeCallback(redisAsyncContext *ac, redisReply *reply,          /* Locate the right callback */          assert(reply->element[1]->type == REDIS_REPLY_STRING);          sname = sdsnewlen(reply->element[1]->str,reply->element[1]->len); +        if (sname == NULL) +            goto oom; +          de = dictFind(callbacks,sname);          if (de != NULL) {              cb = dictGetEntryVal(de); @@ -421,6 +454,9 @@ static int __redisGetSubscribeCallback(redisAsyncContext *ac, redisReply *reply,          __redisShiftCallback(&ac->sub.invalid,dstcb);      }      return REDIS_OK; +oom: +    __redisSetError(&(ac->c), REDIS_ERR_OOM, "Out of memory"); +    return REDIS_ERR;  }  void redisProcessCallbacks(redisAsyncContext *ac) { @@ -588,8 +624,6 @@ void redisAsyncHandleWrite(redisAsyncContext *ac) {      c->funcs->async_write(ac);  } -void __redisSetError(redisContext *c, int type, const char *str); -  void redisAsyncHandleTimeout(redisAsyncContext *ac) {      redisContext *c = &(ac->c);      redisCallback cb; @@ -672,6 +706,9 @@ static int __redisAsyncCommand(redisAsyncContext *ac, redisCallbackFn *fn, void          /* Add every channel/pattern to the list of subscription callbacks. */          while ((p = nextArgument(p,&astr,&alen)) != NULL) {              sname = sdsnewlen(astr,alen); +            if (sname == NULL) +                goto oom; +              if (pvariant)                  cbdict = ac->sub.patterns;              else @@ -715,6 +752,9 @@ static int __redisAsyncCommand(redisAsyncContext *ac, redisCallbackFn *fn, void      _EL_ADD_WRITE(ac);      return REDIS_OK; +oom: +    __redisSetError(&(ac->c), REDIS_ERR_OOM, "Out of memory"); +    return REDIS_ERR;  }  int redisvAsyncCommand(redisAsyncContext *ac, redisCallbackFn *fn, void *privdata, const char *format, va_list ap) { @@ -728,7 +768,7 @@ int redisvAsyncCommand(redisAsyncContext *ac, redisCallbackFn *fn, void *privdat          return REDIS_ERR;      status = __redisAsyncCommand(ac,fn,privdata,cmd,len); -    free(cmd); +    hi_free(cmd);      return status;  } @@ -758,15 +798,21 @@ int redisAsyncFormattedCommand(redisAsyncContext *ac, redisCallbackFn *fn, void      return status;  } -void redisAsyncSetTimeout(redisAsyncContext *ac, struct timeval tv) { +int redisAsyncSetTimeout(redisAsyncContext *ac, struct timeval tv) {      if (!ac->c.timeout) {          ac->c.timeout = hi_calloc(1, sizeof(tv)); +        if (ac->c.timeout == NULL) { +            __redisSetError(&ac->c, REDIS_ERR_OOM, "Out of memory"); +            __redisAsyncCopyError(ac); +            return REDIS_ERR; +        }      } -    if (tv.tv_sec == ac->c.timeout->tv_sec && -        tv.tv_usec == ac->c.timeout->tv_usec) { -        return; +    if (tv.tv_sec != ac->c.timeout->tv_sec || +        tv.tv_usec != ac->c.timeout->tv_usec) +    { +        *ac->c.timeout = tv;      } -    *ac->c.timeout = tv; +    return REDIS_OK;  } | 
