diff options
| author | Pieter Noordhuis <pcnoordhuis@gmail.com> | 2011-04-21 20:59:41 +0200 | 
|---|---|---|
| committer | Pieter Noordhuis <pcnoordhuis@gmail.com> | 2011-04-21 20:59:41 +0200 | 
| commit | dd5fc26457017aaf6981a99cc07c04df2a8fe3c2 (patch) | |
| tree | 55dbe467a3aee3d1a3a3ba7574209342af73400b | |
| parent | d4ebb60d65499ca5e2290eedcd02beb4736771ca (diff) | |
| download | hiredict-dd5fc26457017aaf6981a99cc07c04df2a8fe3c2.tar.xz | |
Make command formatters gracefully abort when out of memory
| -rw-r--r-- | hiredis.c | 108 | 
1 files changed, 77 insertions, 31 deletions
| @@ -652,59 +652,74 @@ static int intlen(int i) {      return len;  } -/* Helper function for redisvFormatCommand(). */ -static void addArgument(sds a, char ***argv, int *argc, int *totlen) { -    (*argc)++; -    if ((*argv = realloc(*argv, sizeof(char*)*(*argc))) == NULL) redisOOM(); -    if (totlen) *totlen = *totlen+1+intlen(sdslen(a))+2+sdslen(a)+2; -    (*argv)[(*argc)-1] = a; +/* Helper that calculates the bulk length given a certain string length. */ +static size_t bulklen(size_t len) { +    return 1+intlen(len)+2+len+2;  }  int redisvFormatCommand(char **target, const char *format, va_list ap) { -    size_t size; -    const char *arg, *c = format; +    const char *c = format;      char *cmd = NULL; /* final command */      int pos; /* position in final command */ -    sds current; /* current argument */ +    sds curarg, newarg; /* current argument */      int touched = 0; /* was the current argument touched? */ -    char **argv = NULL; -    int argc = 0, j; +    char **curargv = NULL, **newargv = NULL; +    int argc = 0;      int totlen = 0; +    int j;      /* Abort if there is not target to set */      if (target == NULL)          return -1;      /* Build the command string accordingly to protocol */ -    current = sdsempty(); +    curarg = sdsempty(); +    if (curarg == NULL) +        return -1; +      while(*c != '\0') {          if (*c != '%' || c[1] == '\0') {              if (*c == ' ') {                  if (touched) { -                    addArgument(current, &argv, &argc, &totlen); -                    current = sdsempty(); +                    newargv = realloc(curargv,sizeof(char*)*(argc+1)); +                    if (newargv == NULL) goto err; +                    curargv = newargv; +                    curargv[argc++] = curarg; +                    totlen += bulklen(sdslen(curarg)); + +                    /* curarg is put in argv so it can be overwritten. */ +                    curarg = sdsempty(); +                    if (curarg == NULL) goto err;                      touched = 0;                  }              } else { -                current = sdscatlen(current,c,1); +                newarg = sdscatlen(curarg,c,1); +                if (newarg == NULL) goto err; +                curarg = newarg;                  touched = 1;              }          } else { +            char *arg; +            size_t size; + +            /* Set newarg so it can be checked even if it is not touched. */ +            newarg = curarg; +              switch(c[1]) {              case 's':                  arg = va_arg(ap,char*);                  size = strlen(arg);                  if (size > 0) -                    current = sdscatlen(current,arg,size); +                    newarg = sdscatlen(curarg,arg,size);                  break;              case 'b':                  arg = va_arg(ap,char*);                  size = va_arg(ap,size_t);                  if (size > 0) -                    current = sdscatlen(current,arg,size); +                    newarg = sdscatlen(curarg,arg,size);                  break;              case '%': -                current = sdscat(current,"%"); +                newarg = sdscat(curarg,"%");                  break;              default:                  /* Try to detect printf format */ @@ -746,7 +761,7 @@ int redisvFormatCommand(char **target, const char *format, va_list ap) {                              memcpy(_format,c,_l);                              _format[_l] = '\0';                              va_copy(_cpy,ap); -                            current = sdscatvprintf(current,_format,_cpy); +                            newarg = sdscatvprintf(curarg,_format,_cpy);                              va_end(_cpy);                              /* Update current position (note: outer blocks @@ -759,6 +774,10 @@ int redisvFormatCommand(char **target, const char *format, va_list ap) {                      va_arg(ap,void);                  }              } + +            if (newarg == NULL) goto err; +            curarg = newarg; +              touched = 1;              c++;          } @@ -767,31 +786,55 @@ int redisvFormatCommand(char **target, const char *format, va_list ap) {      /* Add the last argument if needed */      if (touched) { -        addArgument(current, &argv, &argc, &totlen); +        newargv = realloc(curargv,sizeof(char*)*(argc+1)); +        if (newargv == NULL) goto err; +        curargv = newargv; +        curargv[argc++] = curarg; +        totlen += bulklen(sdslen(curarg));      } else { -        sdsfree(current); +        sdsfree(curarg);      } +    /* Clear curarg because it was put in curargv or was free'd. */ +    curarg = NULL; +      /* Add bytes needed to hold multi bulk count */      totlen += 1+intlen(argc)+2;      /* Build the command at protocol level */      cmd = malloc(totlen+1); -    if (!cmd) redisOOM(); +    if (cmd == NULL) goto err; +      pos = sprintf(cmd,"*%d\r\n",argc);      for (j = 0; j < argc; j++) { -        pos += sprintf(cmd+pos,"$%zu\r\n",sdslen(argv[j])); -        memcpy(cmd+pos,argv[j],sdslen(argv[j])); -        pos += sdslen(argv[j]); -        sdsfree(argv[j]); +        pos += sprintf(cmd+pos,"$%zu\r\n",sdslen(curargv[j])); +        memcpy(cmd+pos,curargv[j],sdslen(curargv[j])); +        pos += sdslen(curargv[j]); +        sdsfree(curargv[j]);          cmd[pos++] = '\r';          cmd[pos++] = '\n';      }      assert(pos == totlen); -    free(argv); -    cmd[totlen] = '\0'; +    cmd[pos] = '\0'; + +    free(curargv);      *target = cmd;      return totlen; + +err: +    while(argc--) +        sdsfree(curargv[argc]); +    free(curargv); + +    if (curarg != NULL) +        sdsfree(curarg); + +    /* No need to check cmd since it is the last statement that can fail, +     * but do it anyway to be as defensive as possible. */ +    if (cmd != NULL) +        free(cmd); + +    return -1;  }  /* Format a command according to the Redis protocol. This function @@ -830,12 +873,14 @@ int redisFormatCommandArgv(char **target, int argc, const char **argv, const siz      totlen = 1+intlen(argc)+2;      for (j = 0; j < argc; j++) {          len = argvlen ? argvlen[j] : strlen(argv[j]); -        totlen += 1+intlen(len)+2+len+2; +        totlen += bulklen(len);      }      /* Build the command at protocol level */      cmd = malloc(totlen+1); -    if (!cmd) redisOOM(); +    if (cmd == NULL) +        return -1; +      pos = sprintf(cmd,"*%d\r\n",argc);      for (j = 0; j < argc; j++) {          len = argvlen ? argvlen[j] : strlen(argv[j]); @@ -846,7 +891,8 @@ int redisFormatCommandArgv(char **target, int argc, const char **argv, const siz          cmd[pos++] = '\n';      }      assert(pos == totlen); -    cmd[totlen] = '\0'; +    cmd[pos] = '\0'; +      *target = cmd;      return totlen;  } | 
