diff options
| author | Pieter Noordhuis <pcnoordhuis@gmail.com> | 2011-04-20 13:15:58 +0200 | 
|---|---|---|
| committer | Pieter Noordhuis <pcnoordhuis@gmail.com> | 2011-04-20 13:15:58 +0200 | 
| commit | 5d78214557f043ee135ef898f4738a93bbcbe525 (patch) | |
| tree | 810597b969d1c71ccac016aaacebf56b880219f2 | |
| parent | 178024244d4040a4c760527bb58fcf5bdc962b02 (diff) | |
| download | hiredict-5d78214557f043ee135ef898f4738a93bbcbe525.tar.xz | |
First pass at making the protocol reader properly handle OOM
| -rw-r--r-- | hiredis.c | 124 | 
1 files changed, 94 insertions, 30 deletions
| @@ -41,6 +41,10 @@  #include "sds.h"  #include "util.h" +#define REDIS_READER_OOM -2 +#define REDIS_READER_NEED_MORE_DATA -1 +#define REDIS_READER_OK 0 +  typedef struct redisReader {      struct redisReplyObjectFunctions *fn;      sds error; /* holds optional error */ @@ -62,7 +66,8 @@ static void *createIntegerObject(const redisReadTask *task, long long value);  static void *createNilObject(const redisReadTask *task);  static void redisSetReplyReaderError(redisReader *r, sds err); -/* Default set of functions to build the reply. */ +/* Default set of functions to build the reply. Keep in mind that such a + * function returning NULL is interpreted as OOM. */  static redisReplyObjectFunctions defaultFunctions = {      createStringObject,      createArrayObject, @@ -73,9 +78,11 @@ static redisReplyObjectFunctions defaultFunctions = {  /* Create a reply object */  static redisReply *createReplyObject(int type) { -    redisReply *r = malloc(sizeof(*r)); +    redisReply *r = calloc(1,sizeof(*r)); + +    if (r == NULL) +        return NULL; -    if (!r) redisOOM();      r->type = type;      return r;  } @@ -89,35 +96,49 @@ void freeReplyObject(void *reply) {      case REDIS_REPLY_INTEGER:          break; /* Nothing to free */      case REDIS_REPLY_ARRAY: -        for (j = 0; j < r->elements; j++) -            if (r->element[j]) freeReplyObject(r->element[j]); -        free(r->element); +        if (r->elements > 0 && r->element != NULL) { +            for (j = 0; j < r->elements; j++) +                if (r->element[j] != NULL) +                    freeReplyObject(r->element[j]); +            free(r->element); +        }          break;      case REDIS_REPLY_ERROR:      case REDIS_REPLY_STATUS:      case REDIS_REPLY_STRING: -        free(r->str); +        if (r->str != NULL) +            free(r->str);          break;      }      free(r);  }  static void *createStringObject(const redisReadTask *task, char *str, size_t len) { -    redisReply *r = createReplyObject(task->type); -    char *value = malloc(len+1); -    if (!value) redisOOM(); -    assert(task->type == REDIS_REPLY_ERROR || +    redisReply *r, *parent; +    char *buf; + +    r = createReplyObject(task->type); +    if (r == NULL) +        return NULL; + +    buf = malloc(len+1); +    if (buf == NULL) { +        freeReplyObject(r); +        return NULL; +    } + +    assert(task->type == REDIS_REPLY_ERROR  ||             task->type == REDIS_REPLY_STATUS ||             task->type == REDIS_REPLY_STRING);      /* Copy string value */ -    memcpy(value,str,len); -    value[len] = '\0'; -    r->str = value; +    memcpy(buf,str,len); +    buf[len] = '\0'; +    r->str = buf;      r->len = len;      if (task->parent) { -        redisReply *parent = task->parent->obj; +        parent = task->parent->obj;          assert(parent->type == REDIS_REPLY_ARRAY);          parent->element[task->idx] = r;      } @@ -125,12 +146,22 @@ static void *createStringObject(const redisReadTask *task, char *str, size_t len  }  static void *createArrayObject(const redisReadTask *task, int elements) { -    redisReply *r = createReplyObject(REDIS_REPLY_ARRAY); +    redisReply *r, *parent; + +    r = createReplyObject(REDIS_REPLY_ARRAY); +    if (r == NULL) +        return NULL; + +    r->element = calloc(elements,sizeof(redisReply*)); +    if (r->element == NULL) { +        freeReplyObject(r); +        return NULL; +    } +      r->elements = elements; -    if ((r->element = calloc(sizeof(redisReply*),elements)) == NULL) -        redisOOM(); +      if (task->parent) { -        redisReply *parent = task->parent->obj; +        parent = task->parent->obj;          assert(parent->type == REDIS_REPLY_ARRAY);          parent->element[task->idx] = r;      } @@ -138,10 +169,16 @@ static void *createArrayObject(const redisReadTask *task, int elements) {  }  static void *createIntegerObject(const redisReadTask *task, long long value) { -    redisReply *r = createReplyObject(REDIS_REPLY_INTEGER); +    redisReply *r, *parent; + +    r = createReplyObject(REDIS_REPLY_INTEGER); +    if (r == NULL) +        return NULL; +      r->integer = value; +      if (task->parent) { -        redisReply *parent = task->parent->obj; +        parent = task->parent->obj;          assert(parent->type == REDIS_REPLY_ARRAY);          parent->element[task->idx] = r;      } @@ -149,9 +186,14 @@ static void *createIntegerObject(const redisReadTask *task, long long value) {  }  static void *createNilObject(const redisReadTask *task) { -    redisReply *r = createReplyObject(REDIS_REPLY_NIL); +    redisReply *r, *parent; + +    r = createReplyObject(REDIS_REPLY_NIL); +    if (r == NULL) +        return NULL; +      if (task->parent) { -        redisReply *parent = task->parent->obj; +        parent = task->parent->obj;          assert(parent->type == REDIS_REPLY_ARRAY);          parent->element[task->idx] = r;      } @@ -284,12 +326,16 @@ static int processLineItem(redisReader *r) {                  obj = (void*)(size_t)(cur->type);          } +        if (obj == NULL) +            return REDIS_READER_OOM; +          /* Set reply if this is the root object. */          if (r->ridx == 0) r->reply = obj;          moveToNextTask(r); -        return 0; +        return REDIS_READER_OK;      } -    return -1; + +    return REDIS_READER_NEED_MORE_DATA;  }  static int processBulkItem(redisReader *r) { @@ -328,15 +374,19 @@ static int processBulkItem(redisReader *r) {          /* Proceed when obj was created. */          if (success) { +            if (obj == NULL) +                return REDIS_READER_OOM; +              r->pos += bytelen;              /* Set reply if this is the root object. */              if (r->ridx == 0) r->reply = obj;              moveToNextTask(r); -            return 0; +            return REDIS_READER_OK;          }      } -    return -1; + +    return REDIS_READER_NEED_MORE_DATA;  }  static int processMultiBulkItem(redisReader *r) { @@ -362,6 +412,10 @@ static int processMultiBulkItem(redisReader *r) {                  obj = r->fn->createNil(cur);              else                  obj = (void*)REDIS_REPLY_NIL; + +            if (obj == NULL) +                return REDIS_READER_OOM; +              moveToNextTask(r);          } else {              if (r->fn && r->fn->createArray) @@ -369,6 +423,9 @@ static int processMultiBulkItem(redisReader *r) {              else                  obj = (void*)REDIS_REPLY_ARRAY; +            if (obj == NULL) +                return REDIS_READER_OOM; +              /* Modify task stack when there are more than 0 elements. */              if (elements > 0) {                  cur->elements = elements; @@ -387,9 +444,10 @@ static int processMultiBulkItem(redisReader *r) {          /* Set reply if this is the root object. */          if (root) r->reply = obj; -        return 0; +        return REDIS_READER_OK;      } -    return -1; + +    return REDIS_READER_NEED_MORE_DATA;  }  static int processItem(redisReader *r) { @@ -534,6 +592,8 @@ void redisReplyReaderFeed(void *reader, const char *buf, size_t len) {  int redisReplyReaderGetReply(void *reader, void **reply) {      redisReader *r = reader; +    int ret = REDIS_READER_OK; +      if (reply != NULL) *reply = NULL;      /* When the buffer is empty, there will never be a reply. */ @@ -553,9 +613,13 @@ int redisReplyReaderGetReply(void *reader, void **reply) {      /* Process items in reply. */      while (r->ridx >= 0) -        if (processItem(r) < 0) +        if ((ret = processItem(r)) != REDIS_READER_OK)              break; +    /* Set errors on OOM. */ +    if (ret == REDIS_READER_OOM) +        return REDIS_ERR; +      /* Discard part of the buffer when we've consumed at least 1k, to avoid       * doing unnecessary calls to memmove() in sds.c. */      if (r->pos >= 1024) { | 
