diff options
author | anon5 <anon5clam@protonmail.com> | 2020-11-07 18:01:24 +0000 |
---|---|---|
committer | anon5 <anon5clam@protonmail.com> | 2020-11-07 18:01:24 +0000 |
commit | a865d2bce1aa097273fdb9d0d02d9cfa8460aefd (patch) | |
tree | 95ff9b35607b57b330ec579f4b96f84e3bf33349 /rudp/process.go | |
download | mt-a865d2bce1aa097273fdb9d0d02d9cfa8460aefd.tar.xz |
Initial public release
Diffstat (limited to 'rudp/process.go')
-rw-r--r-- | rudp/process.go | 259 |
1 files changed, 259 insertions, 0 deletions
diff --git a/rudp/process.go b/rudp/process.go new file mode 100644 index 0000000..c36af81 --- /dev/null +++ b/rudp/process.go @@ -0,0 +1,259 @@ +package rudp + +import ( + "encoding/binary" + "encoding/hex" + "errors" + "fmt" + "io" +) + +// A PktError is an error that occured while processing a packet. +type PktError struct { + Type string // "net", "raw" or "rel". + Data []byte + Err error +} + +func (e PktError) Error() string { + return "error processing " + e.Type + " pkt: " + + hex.EncodeToString(e.Data) + ": " + + e.Err.Error() +} + +func (e PktError) Unwrap() error { return e.Err } + +func (p *Peer) processNetPkts(pkts <-chan netPkt) { + for pkt := range pkts { + if err := p.processNetPkt(pkt); err != nil { + p.errs <- PktError{"net", pkt.Data, err} + } + } + + close(p.pkts) +} + +// A TrailingDataError reports a packet with trailing data, +// it doesn't stop a packet from being processed. +type TrailingDataError []byte + +func (e TrailingDataError) Error() string { + return "trailing data: " + hex.EncodeToString([]byte(e)) +} + +func (p *Peer) processNetPkt(pkt netPkt) (err error) { + if pkt.SrcAddr.String() != p.Addr().String() { + return fmt.Errorf("got pkt from wrong addr: %s", p.Addr().String()) + } + + if len(pkt.Data) < MtHdrSize { + return io.ErrUnexpectedEOF + } + + if id := binary.BigEndian.Uint32(pkt.Data[0:4]); id != protoID { + return fmt.Errorf("unsupported protocol id: 0x%08x", id) + } + + // src PeerID at pkt.Data[4:6] + + chno := pkt.Data[6] + if chno >= ChannelCount { + return fmt.Errorf("invalid channel number: %d: >= ChannelCount", chno) + } + + p.mu.RLock() + if p.timeout != nil { + p.timeout.Reset(ConnTimeout) + } + p.mu.RUnlock() + + rpkt := rawPkt{ + Data: pkt.Data[MtHdrSize:], + ChNo: chno, + Unrel: true, + } + if err := p.processRawPkt(rpkt); err != nil { + p.errs <- PktError{"raw", rpkt.Data, err} + } + + return nil +} + +func (p *Peer) processRawPkt(pkt rawPkt) (err error) { + errWrap := func(format string, a ...interface{}) { + if err != nil { + err = fmt.Errorf(format, append(a, err)...) + } + } + + c := &p.chans[pkt.ChNo] + + if len(pkt.Data) < 1 { + return fmt.Errorf("can't read pkt type: %w", io.ErrUnexpectedEOF) + } + switch t := rawType(pkt.Data[0]); t { + case rawTypeCtl: + defer errWrap("ctl: %w") + + if len(pkt.Data) < 1+1 { + return fmt.Errorf("can't read type: %w", io.ErrUnexpectedEOF) + } + switch ct := ctlType(pkt.Data[1]); ct { + case ctlAck: + defer errWrap("ack: %w") + + if len(pkt.Data) < 1+1+2 { + return io.ErrUnexpectedEOF + } + + sn := seqnum(binary.BigEndian.Uint16(pkt.Data[2:4])) + + if ack, ok := c.ackchans.LoadAndDelete(sn); ok { + close(ack.(chan struct{})) + } + + if len(pkt.Data) > 1+1+2 { + return TrailingDataError(pkt.Data[1+1+2:]) + } + case ctlSetPeerID: + defer errWrap("set peer id: %w") + + if len(pkt.Data) < 1+1+2 { + return io.ErrUnexpectedEOF + } + + // Ensure no concurrent senders while peer id changes. + p.mu.Lock() + if p.idOfPeer != PeerIDNil { + return errors.New("peer id already set") + } + + p.idOfPeer = PeerID(binary.BigEndian.Uint16(pkt.Data[2:4])) + p.mu.Unlock() + + if len(pkt.Data) > 1+1+2 { + return TrailingDataError(pkt.Data[1+1+2:]) + } + case ctlPing: + defer errWrap("ping: %w") + + if len(pkt.Data) > 1+1 { + return TrailingDataError(pkt.Data[1+1:]) + } + case ctlDisco: + defer errWrap("disco: %w") + + if err := p.Close(); err != nil { + return fmt.Errorf("can't close: %w", err) + } + + if len(pkt.Data) > 1+1 { + return TrailingDataError(pkt.Data[1+1:]) + } + default: + return fmt.Errorf("unsupported ctl type: %d", ct) + } + case rawTypeOrig: + p.pkts <- Pkt{ + Data: pkt.Data[1:], + ChNo: pkt.ChNo, + Unrel: pkt.Unrel, + } + case rawTypeSplit: + defer errWrap("split: %w") + + if len(pkt.Data) < 1+2+2+2 { + return io.ErrUnexpectedEOF + } + + sn := seqnum(binary.BigEndian.Uint16(pkt.Data[1:3])) + count := binary.BigEndian.Uint16(pkt.Data[3:5]) + i := binary.BigEndian.Uint16(pkt.Data[5:7]) + + if i >= count { + return nil + } + + splitpkts := p.chans[pkt.ChNo].insplit + + // Delete old incomplete split packets + // so new ones don't get corrupted. + delete(splitpkts, sn-0x8000) + + if splitpkts[sn] == nil { + splitpkts[sn] = make([][]byte, count) + } + + chunks := splitpkts[sn] + + if int(count) != len(chunks) { + return fmt.Errorf("chunk count changed on seqnum: %d", sn) + } + + chunks[i] = pkt.Data[7:] + + for _, chunk := range chunks { + if chunk == nil { + return nil + } + } + + var data []byte + for _, chunk := range chunks { + data = append(data, chunk...) + } + + p.pkts <- Pkt{ + Data: data, + ChNo: pkt.ChNo, + Unrel: pkt.Unrel, + } + + delete(splitpkts, sn) + case rawTypeRel: + defer errWrap("rel: %w") + + if len(pkt.Data) < 1+2 { + return io.ErrUnexpectedEOF + } + + sn := seqnum(binary.BigEndian.Uint16(pkt.Data[1:3])) + + ackdata := make([]byte, 1+1+2) + ackdata[0] = uint8(rawTypeCtl) + ackdata[1] = uint8(ctlAck) + binary.BigEndian.PutUint16(ackdata[2:4], uint16(sn)) + ack := rawPkt{ + Data: ackdata, + ChNo: pkt.ChNo, + Unrel: true, + } + if _, err := p.sendRaw(ack); err != nil { + return fmt.Errorf("can't ack %d: %w", sn, err) + } + + if sn-c.inrelsn >= 0x8000 { + return nil // Already received. + } + + c.inrel[sn] = pkt.Data[3:] + + for ; c.inrel[c.inrelsn] != nil; c.inrelsn++ { + data := c.inrel[c.inrelsn] + delete(c.inrel, c.inrelsn) + + rpkt := rawPkt{ + Data: data, + ChNo: pkt.ChNo, + Unrel: false, + } + if err := p.processRawPkt(rpkt); err != nil { + p.errs <- PktError{"rel", rpkt.Data, err} + } + } + default: + return fmt.Errorf("unsupported pkt type: %d", t) + } + + return nil +} |