diff options
| -rw-r--r-- | sys/src/libsec/port/tlshand.c | 276 |
1 files changed, 133 insertions, 143 deletions
diff --git a/sys/src/libsec/port/tlshand.c b/sys/src/libsec/port/tlshand.c index d8ef46a0d..dfa1837d8 100644 --- a/sys/src/libsec/port/tlshand.c +++ b/sys/src/libsec/port/tlshand.c @@ -17,7 +17,6 @@ enum { TLSFinishedLen = 12, SSL3FinishedLen = MD5dlen+SHA1dlen, MaxKeyData = 160, // amount of secret we may need - MaxChunk = 1<<15, MAXdlen = SHA2_512dlen, RandomSize = 32, MasterSecretSize = 48, @@ -100,13 +99,8 @@ typedef struct TlsConnection{ HandshakeHash handhash; Finished finished; - // input buffer for handshake messages - uchar recvbuf[MaxChunk]; - uchar *rp, *ep; - - // output buffer - uchar sendbuf[MaxChunk]; - uchar *sendp; + uchar *sendp, *recvp, *recvw; + uchar buf[1<<16]; } TlsConnection; typedef struct Msg{ @@ -444,7 +438,7 @@ static int get24(uchar *p); static int get16(uchar *p); static Bytes* newbytes(int len); static Bytes* makebytes(uchar* buf, int len); -static Bytes* mptobytes(mpint* big); +static Bytes* mptobytes(mpint* big, int len); static mpint* bytestomp(Bytes* bytes); static void freebytes(Bytes* b); static Ints* newints(int len); @@ -696,6 +690,8 @@ tlsServer2(int ctl, int hand, c->hand = hand; c->trace = trace; c->version = ProtocolVersion; + c->sendp = c->buf; + c->recvp = c->recvw = &c->buf[sizeof(c->buf)]; memset(&m, 0, sizeof(m)); if(!msgRecv(c, &m)){ @@ -895,6 +891,7 @@ tlsSecDHEc(TlsSec *sec, Bytes *p, Bytes *g, Bytes *Ys) DHstate *dh = &sec->dh; mpint *G, *P, *Y, *K; Bytes *Yc; + int n; if(p == nil || g == nil || Ys == nil) return nil; @@ -907,14 +904,15 @@ tlsSecDHEc(TlsSec *sec, Bytes *p, Bytes *g, Bytes *Ys) if(dh_new(dh, P, nil, G) == nil) goto Out; - Yc = mptobytes(dh->y); + n = (mpsignif(P)+7)/8; + Yc = mptobytes(dh->y, n); K = dh_finish(dh, Y); /* zeros dh */ if(K == nil){ freebytes(Yc); Yc = nil; goto Out; } - setMasterSecret(sec, mptobytes(K)); + setMasterSecret(sec, mptobytes(K, n)); Out: mpfree(K); @@ -934,6 +932,7 @@ tlsSecECDHEc(TlsSec *sec, int curve, Bytes *Ys) ECpub *pub; ECpoint K; Bytes *Yc; + int n; if(Ys == nil) return nil; @@ -959,8 +958,10 @@ Found: ecgen(dom, Q); ecmul(dom, pub, Q->d, &K); - setMasterSecret(sec, mptobytes(K.x)); - Yc = newbytes(1 + 2*((mpsignif(dom->p)+7)/8)); + + n = (mpsignif(dom->p)+7)/8; + setMasterSecret(sec, mptobytes(K.x, n)); + Yc = newbytes(1 + 2*n); Yc->len = ecencodepub(dom, Q, Yc->data, Yc->len); mpfree(K.x); @@ -994,6 +995,8 @@ tlsClient2(int ctl, int hand, c->hand = hand; c->trace = trace; c->cert = nil; + c->sendp = c->buf; + c->recvp = c->recvw = &c->buf[sizeof(c->buf)]; c->version = ProtocolVersion; tlsSecInitc(c->sec, c->version); @@ -1257,14 +1260,13 @@ msgHash(TlsConnection *c, uchar *p, int n) static int msgSend(TlsConnection *c, Msg *m, int act) { - uchar *p; // sendp = start of new message; p = write pointer - int nn, n, i; + uchar *p, *e; // sendp = start of new message; p = write pointer; e = end pointer + int n, i; - if(c->sendp == nil) - c->sendp = c->sendbuf; p = c->sendp; + e = c->recvp; if(c->trace) - c->trace("send %s", msgPrint((char*)p, (sizeof(c->sendbuf)) - (p - c->sendbuf), m)); + c->trace("send %s", msgPrint((char*)p, e - p, m)); p[0] = m->tag; // header - fill in size later p += 4; @@ -1274,134 +1276,127 @@ msgSend(TlsConnection *c, Msg *m, int act) tlsError(c, EInternalError, "can't encode a %d", m->tag); goto Err; case HClientHello: - // version - put16(p, m->u.clientHello.version); - p += 2; - - // random + if(p+2+RandomSize > e) + goto Overflow; + put16(p, m->u.clientHello.version), p += 2; memmove(p, m->u.clientHello.random, RandomSize); p += RandomSize; - // sid - n = m->u.clientHello.sid->len; - p[0] = n; - memmove(p+1, m->u.clientHello.sid->data, n); - p += n+1; - - n = m->u.clientHello.ciphers->len; - put16(p, n*2); - p += 2; - for(i=0; i<n; i++) { - put16(p, m->u.clientHello.ciphers->data[i]); - p += 2; - } + if(p+1+(n = m->u.clientHello.sid->len) > e) + goto Overflow; + *p++ = n; + memmove(p, m->u.clientHello.sid->data, n); + p += n; - n = m->u.clientHello.compressors->len; - p[0] = n; - memmove(p+1, m->u.clientHello.compressors->data, n); - p += n+1; + if(p+2+(n = m->u.clientHello.ciphers->len) > e) + goto Overflow; + put16(p, n*2), p += 2; + for(i=0; i<n; i++) + put16(p, m->u.clientHello.ciphers->data[i]), p += 2; - if(m->u.clientHello.extensions == nil) - break; - n = m->u.clientHello.extensions->len; - if(n == 0) + if(p+1+(n = m->u.clientHello.compressors->len) > e) + goto Overflow; + *p++ = n; + memmove(p, m->u.clientHello.compressors->data, n); + p += n; + + if(m->u.clientHello.extensions == nil + || (n = m->u.clientHello.extensions->len) == 0) break; - put16(p, n); - memmove(p+2, m->u.clientHello.extensions->data, n); - p += n+2; + if(p+2+n > e) + goto Overflow; + put16(p, n), p += 2; + memmove(p, m->u.clientHello.extensions->data, n); + p += n; break; case HServerHello: - put16(p, m->u.serverHello.version); - p += 2; - - // random + if(p+2+RandomSize > e) + goto Overflow; + put16(p, m->u.serverHello.version), p += 2; memmove(p, m->u.serverHello.random, RandomSize); p += RandomSize; - // sid - n = m->u.serverHello.sid->len; - p[0] = n; - memmove(p+1, m->u.serverHello.sid->data, n); - p += n+1; + if(p+1+(n = m->u.serverHello.sid->len) > e) + goto Overflow; + *p++ = n; + memmove(p, m->u.serverHello.sid->data, n); + p += n; - put16(p, m->u.serverHello.cipher); - p += 2; - p[0] = m->u.serverHello.compressor; - p += 1; + if(p+2+1 > e) + goto Overflow; + put16(p, m->u.serverHello.cipher), p += 2; + *p++ = m->u.serverHello.compressor; - if(m->u.serverHello.extensions == nil) + if(m->u.serverHello.extensions == nil + || (n = m->u.serverHello.extensions->len) == 0) break; - n = m->u.serverHello.extensions->len; - if(n == 0) - break; - put16(p, n); - memmove(p+2, m->u.serverHello.extensions->data, n); - p += n+2; + if(p+2+n > e) + goto Overflow; + put16(p, n), p += 2; + memmove(p, m->u.serverHello.extensions->data, n); + p += n; break; case HServerHelloDone: break; case HCertificate: - nn = 0; + n = 0; for(i = 0; i < m->u.certificate.ncert; i++) - nn += 3 + m->u.certificate.certs[i]->len; - if(p + 3 + nn - c->sendbuf > sizeof(c->sendbuf)) { - tlsError(c, EInternalError, "output buffer too small for certificate"); - goto Err; - } - put24(p, nn); - p += 3; + n += 3 + m->u.certificate.certs[i]->len; + if(p+3+n > e) + goto Overflow; + put24(p, n), p += 3; for(i = 0; i < m->u.certificate.ncert; i++){ - put24(p, m->u.certificate.certs[i]->len); - p += 3; - memmove(p, m->u.certificate.certs[i]->data, m->u.certificate.certs[i]->len); - p += m->u.certificate.certs[i]->len; + n = m->u.certificate.certs[i]->len; + put24(p, n), p += 3; + memmove(p, m->u.certificate.certs[i]->data, n); + p += n; } break; case HCertificateVerify: - if(m->u.certificateVerify.sigalg != 0){ - put16(p, m->u.certificateVerify.sigalg); - p += 2; - } - put16(p, m->u.certificateVerify.signature->len); - p += 2; - memmove(p, m->u.certificateVerify.signature->data, m->u.certificateVerify.signature->len); - p += m->u.certificateVerify.signature->len; + if(p+2+2+(n = m->u.certificateVerify.signature->len) > e) + goto Overflow; + if(m->u.certificateVerify.sigalg != 0) + put16(p, m->u.certificateVerify.sigalg), p += 2; + put16(p, n), p += 2; + memmove(p, m->u.certificateVerify.signature->data, n); + p += n; break; case HServerKeyExchange: if(m->u.serverKeyExchange.pskid != nil){ - n = m->u.serverKeyExchange.pskid->len; - put16(p, n); - p += 2; + if(p+2+(n = m->u.serverKeyExchange.pskid->len) > e) + goto Overflow; + put16(p, n), p += 2; memmove(p, m->u.serverKeyExchange.pskid->data, n); p += n; } if(m->u.serverKeyExchange.dh_parameters == nil) break; - n = m->u.serverKeyExchange.dh_parameters->len; + if(p+(n = m->u.serverKeyExchange.dh_parameters->len) > e) + goto Overflow; memmove(p, m->u.serverKeyExchange.dh_parameters->data, n); p += n; if(m->u.serverKeyExchange.dh_signature == nil) break; - if(c->version >= TLS12Version){ - put16(p, m->u.serverKeyExchange.sigalg); - p += 2; - } - n = m->u.serverKeyExchange.dh_signature->len; + if(p+2+2+(n = m->u.serverKeyExchange.dh_signature->len) > e) + goto Overflow; + if(c->version >= TLS12Version) + put16(p, m->u.serverKeyExchange.sigalg), p += 2; put16(p, n), p += 2; memmove(p, m->u.serverKeyExchange.dh_signature->data, n); p += n; break; case HClientKeyExchange: if(m->u.clientKeyExchange.pskid != nil){ - n = m->u.clientKeyExchange.pskid->len; - put16(p, n); - p += 2; + if(p+2+(n = m->u.clientKeyExchange.pskid->len) > e) + goto Overflow; + put16(p, n), p += 2; memmove(p, m->u.clientKeyExchange.pskid->data, n); p += n; } if(m->u.clientKeyExchange.key == nil) break; - n = m->u.clientKeyExchange.key->len; + if(p+2+(n = m->u.clientKeyExchange.key->len) > e) + goto Overflow; if(isECDHE(c->cipher)) *p++ = n; else if(isDHE(c->cipher) || c->version != SSL3Version) @@ -1410,6 +1405,8 @@ msgSend(TlsConnection *c, Msg *m, int act) p += n; break; case HFinished: + if(p+m->u.finished.n > e) + goto Overflow; memmove(p, m->u.finished.verify, m->u.finished.n); p += m->u.finished.n; break; @@ -1417,7 +1414,6 @@ msgSend(TlsConnection *c, Msg *m, int act) // go back and fill in size n = p - c->sendp; - assert(n <= sizeof(c->sendbuf)); put24(c->sendp+1, n-4); // remember hash of Handshake messages @@ -1426,14 +1422,16 @@ msgSend(TlsConnection *c, Msg *m, int act) c->sendp = p; if(act == AFlush){ - c->sendp = c->sendbuf; - if(write(c->hand, c->sendbuf, p - c->sendbuf) < 0){ + c->sendp = c->buf; + if(write(c->hand, c->buf, p - c->buf) < 0){ fprint(2, "write error: %r\n"); goto Err; } } msgClear(m); return 1; +Overflow: + tlsError(c, EInternalError, "not enougth send buffer for message (%d)", m->tag); Err: msgClear(m); return 0; @@ -1442,25 +1440,28 @@ Err: static uchar* tlsReadN(TlsConnection *c, int n) { - uchar *p; - int nn, nr; + uchar *p, *e; - nn = c->ep - c->rp; - if(nn < n){ - if(c->rp != c->recvbuf){ - memmove(c->recvbuf, c->rp, nn); - c->rp = c->recvbuf; - c->ep = &c->recvbuf[nn]; - } - for(; nn < n; nn += nr) { - nr = read(c->hand, &c->rp[nn], n - nn); - if(nr <= 0) - return nil; - c->ep += nr; - } + p = c->recvp; + if(n <= c->recvw - p){ + c->recvp += n; + return p; + } + e = &c->buf[sizeof(c->buf)]; + c->recvp = e - n; + if(c->recvp < c->sendp || n > sizeof(c->buf)){ + tlsError(c, EDecodeError, "handshake message too long %d", n); + return nil; + } + memmove(c->recvp, p, c->recvw - p); + c->recvw -= p - c->recvp; + p = c->recvp; + c->recvp += n; + while(c->recvw < c->recvp){ + if((n = read(c->hand, c->recvw, e - c->recvw)) <= 0) + return nil; + c->recvw += n; } - p = c->rp; - c->rp += n; return p; } @@ -1486,11 +1487,6 @@ msgRecv(TlsConnection *c, Msg *m) } } - if(n > sizeof(c->recvbuf)) { - tlsError(c, EDecodeError, "handshake message too long %d %d", n, sizeof(c->recvbuf)); - return 0; - } - if(type == HSSL2ClientHello){ /* Cope with an SSL3 ClientHello expressed in SSL2 record format. This is sent by some clients that we must interoperate @@ -1513,10 +1509,8 @@ msgRecv(TlsConnection *c, Msg *m) p += 6; n -= 6; if(nsid != 0 /* no sid's, since shouldn't restart using ssl2 header */ - || nrandom < 16 || nn % 3) + || nrandom < 16 || nn % 3 || n - nrandom < nn) goto Err; - if(c->trace && (n - nrandom != nn)) - c->trace("n-nrandom!=nn: n=%d nrandom=%d nn=%d\n", n, nrandom, nn); /* ignore ssl2 ciphers and look for {0x00, ssl3 cipher} */ nciph = 0; for(i = 0; i < nn; i += 3) @@ -1806,15 +1800,11 @@ msgRecv(TlsConnection *c, Msg *m) break; } - if(type != HClientHello && type != HServerHello && n != 0) + if(n != 0 && type != HClientHello && type != HServerHello) goto Short; Ok: - if(c->trace){ - char *buf; - buf = emalloc(8000); - c->trace("recv %s", msgPrint(buf, 8000, m)); - free(buf); - } + if(c->trace) + c->trace("recv %s", msgPrint((char*)c->sendp, c->recvp - c->sendp, m)); return 1; Short: tlsError(c, EDecodeError, "handshake message (%d) has invalid length", type); @@ -2624,7 +2614,8 @@ tlsSecECDHEs2(TlsSec *sec, Bytes *Yc) K.y = mpnew(0); ecmul(dom, Y, Q->d, &K); - setMasterSecret(sec, mptobytes(K.x)); + + setMasterSecret(sec, mptobytes(K.x, (mpsignif(dom->p)+7)/8)); mpfree(K.x); mpfree(K.y); @@ -2858,7 +2849,7 @@ pkcs1_decrypt(TlsSec *sec, Bytes *data) y = factotum_rsa_decrypt(sec->rpc, bytestomp(data)); if(y == nil) return nil; - data = mptobytes(y); + data = mptobytes(y, (mpsignif(y)+7)/8); if((data->len = pkcs1unpadbuf(data->data, data->len, sec->rsapub->n, 2)) < 0){ freebytes(data); return nil; @@ -2884,10 +2875,11 @@ pkcs1_sign(TlsSec *sec, uchar *digest, int digestlen, int sigalg) werrstr("bad digest algorithm"); return nil; } + signedMP = factotum_rsa_decrypt(sec->rpc, pkcs1padbuf(buf, digestlen, sec->rsapub->n, 1)); if(signedMP == nil) return nil; - signature = mptobytes(signedMP); + signature = mptobytes(signedMP, (mpsignif(sec->rsapub->n)+7)/8); mpfree(signedMP); return signature; } @@ -2999,14 +2991,12 @@ bytestomp(Bytes* bytes) * Convert mpint* to Bytes, putting high order byte first. */ static Bytes* -mptobytes(mpint* big) +mptobytes(mpint *big, int len) { Bytes* ans; - int n; - n = (mpsignif(big)+7)/8; - if(n == 0) n = 1; - ans = newbytes(n); + if(len == 0) len++; + ans = newbytes(len); mptober(big, ans->data, ans->len); return ans; } |
