Diffstat (limited to 'rudp/process.go')
-rw-r--r--  rudp/process.go  259
1 files changed, 259 insertions, 0 deletions
diff --git a/rudp/process.go b/rudp/process.go
new file mode 100644
index 0000000..c36af81
--- /dev/null
+++ b/rudp/process.go
@@ -0,0 +1,259 @@
+package rudp
+
+import (
+ "encoding/binary"
+ "encoding/hex"
+ "errors"
+ "fmt"
+ "io"
+)
+
+// A PktError is an error that occurred while processing a packet.
+type PktError struct {
+ Type string // "net", "raw" or "rel".
+ Data []byte
+ Err error
+}
+
+func (e PktError) Error() string {
+ return "error processing " + e.Type + " pkt: " +
+ hex.EncodeToString(e.Data) + ": " +
+ e.Err.Error()
+}
+
+func (e PktError) Unwrap() error { return e.Err }
+
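+// processNetPkts processes packets from pkts until the channel is closed,
+// reporting processing errors on p.errs, and then closes p.pkts.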
+func (p *Peer) processNetPkts(pkts <-chan netPkt) {
+ for pkt := range pkts {
+ if err := p.processNetPkt(pkt); err != nil {
+ p.errs <- PktError{"net", pkt.Data, err}
+ }
+ }
+
+ close(p.pkts)
+}
+
+// A TrailingDataError reports a packet with trailing data;
+// it does not prevent the packet from being processed.
+type TrailingDataError []byte
+
+func (e TrailingDataError) Error() string {
+ return "trailing data: " + hex.EncodeToString([]byte(e))
+}
+
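+// processNetPkt verifies the sender and the MT protocol header of pkt
+// and hands the payload to processRawPkt on the packet's channel.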
+func (p *Peer) processNetPkt(pkt netPkt) (err error) {
+ if pkt.SrcAddr.String() != p.Addr().String() {
+ return fmt.Errorf("got pkt from wrong addr: %s", p.Addr().String())
+ }
+
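+ // MT header layout: protocol ID (4 bytes), src peer ID (2 bytes),
+ // channel number (1 byte).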
+ if len(pkt.Data) < MtHdrSize {
+ return io.ErrUnexpectedEOF
+ }
+
+ if id := binary.BigEndian.Uint32(pkt.Data[0:4]); id != protoID {
+ return fmt.Errorf("unsupported protocol id: 0x%08x", id)
+ }
+
+ // src PeerID at pkt.Data[4:6]
+
+ chno := pkt.Data[6]
+ if chno >= ChannelCount {
+ return fmt.Errorf("invalid channel number: %d: >= ChannelCount", chno)
+ }
+
+ p.mu.RLock()
+ if p.timeout != nil {
+ p.timeout.Reset(ConnTimeout)
+ }
+ p.mu.RUnlock()
+
+ rpkt := rawPkt{
+ Data: pkt.Data[MtHdrSize:],
+ ChNo: chno,
+ Unrel: true,
+ }
+ if err := p.processRawPkt(rpkt); err != nil {
+ p.errs <- PktError{"raw", rpkt.Data, err}
+ }
+
+ return nil
+}
+
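+// processRawPkt processes a packet stripped of its MT header,
+// dispatching on its raw type: ctl, orig, split or rel.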
+func (p *Peer) processRawPkt(pkt rawPkt) (err error) {
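+ // errWrap, when deferred, wraps a non-nil named return error with context.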
+ errWrap := func(format string, a ...interface{}) {
+ if err != nil {
+ err = fmt.Errorf(format, append(a, err)...)
+ }
+ }
+
+ c := &p.chans[pkt.ChNo]
+
+ if len(pkt.Data) < 1 {
+ return fmt.Errorf("can't read pkt type: %w", io.ErrUnexpectedEOF)
+ }
+ switch t := rawType(pkt.Data[0]); t {
+ case rawTypeCtl:
+ defer errWrap("ctl: %w")
+
+ if len(pkt.Data) < 1+1 {
+ return fmt.Errorf("can't read type: %w", io.ErrUnexpectedEOF)
+ }
+ switch ct := ctlType(pkt.Data[1]); ct {
+ case ctlAck:
+ defer errWrap("ack: %w")
+
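+ // An ack confirms delivery of the reliable packet with this seqnum;
+ // wake up the sender waiting on it.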
+ if len(pkt.Data) < 1+1+2 {
+ return io.ErrUnexpectedEOF
+ }
+
+ sn := seqnum(binary.BigEndian.Uint16(pkt.Data[2:4]))
+
+ if ack, ok := c.ackchans.LoadAndDelete(sn); ok {
+ close(ack.(chan struct{}))
+ }
+
+ if len(pkt.Data) > 1+1+2 {
+ return TrailingDataError(pkt.Data[1+1+2:])
+ }
+ case ctlSetPeerID:
+ defer errWrap("set peer id: %w")
+
+ if len(pkt.Data) < 1+1+2 {
+ return io.ErrUnexpectedEOF
+ }
+
+ // Ensure no concurrent senders while peer id changes.
+ p.mu.Lock()
+ if p.idOfPeer != PeerIDNil {
+ p.mu.Unlock()
+ return errors.New("peer id already set")
+ }
+
+ p.idOfPeer = PeerID(binary.BigEndian.Uint16(pkt.Data[2:4]))
+ p.mu.Unlock()
+
+ if len(pkt.Data) > 1+1+2 {
+ return TrailingDataError(pkt.Data[1+1+2:])
+ }
+ case ctlPing:
+ defer errWrap("ping: %w")
+
+ if len(pkt.Data) > 1+1 {
+ return TrailingDataError(pkt.Data[1+1:])
+ }
+ case ctlDisco:
+ defer errWrap("disco: %w")
+
+ if err := p.Close(); err != nil {
+ return fmt.Errorf("can't close: %w", err)
+ }
+
+ if len(pkt.Data) > 1+1 {
+ return TrailingDataError(pkt.Data[1+1:])
+ }
+ default:
+ return fmt.Errorf("unsupported ctl type: %d", ct)
+ }
+ case rawTypeOrig:
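+ // Orig packets carry a complete payload; deliver it without the type byte.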
+ p.pkts <- Pkt{
+ Data: pkt.Data[1:],
+ ChNo: pkt.ChNo,
+ Unrel: pkt.Unrel,
+ }
+ case rawTypeSplit:
+ defer errWrap("split: %w")
+
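+ // Split layout: type (1), seqnum (2), chunk count (2), chunk index (2),
+ // chunk data. Chunks are collected in insplit until all have arrived.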
+ if len(pkt.Data) < 1+2+2+2 {
+ return io.ErrUnexpectedEOF
+ }
+
+ sn := seqnum(binary.BigEndian.Uint16(pkt.Data[1:3]))
+ count := binary.BigEndian.Uint16(pkt.Data[3:5])
+ i := binary.BigEndian.Uint16(pkt.Data[5:7])
+
+ if i >= count {
+ return nil
+ }
+
+ splitpkts := p.chans[pkt.ChNo].insplit
+
+ // Delete old incomplete split packets
+ // so new ones don't get corrupted.
+ delete(splitpkts, sn-0x8000)
+
+ if splitpkts[sn] == nil {
+ splitpkts[sn] = make([][]byte, count)
+ }
+
+ chunks := splitpkts[sn]
+
+ if int(count) != len(chunks) {
+ return fmt.Errorf("chunk count changed on seqnum: %d", sn)
+ }
+
+ chunks[i] = pkt.Data[7:]
+
+ for _, chunk := range chunks {
+ if chunk == nil {
+ return nil
+ }
+ }
+
+ var data []byte
+ for _, chunk := range chunks {
+ data = append(data, chunk...)
+ }
+
+ p.pkts <- Pkt{
+ Data: data,
+ ChNo: pkt.ChNo,
+ Unrel: pkt.Unrel,
+ }
+
+ delete(splitpkts, sn)
+ case rawTypeRel:
+ defer errWrap("rel: %w")
+
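+ // Rel layout: type (1), seqnum (2), embedded raw packet.
+ // Ack it immediately, then deliver in seqnum order.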
+ if len(pkt.Data) < 1+2 {
+ return io.ErrUnexpectedEOF
+ }
+
+ sn := seqnum(binary.BigEndian.Uint16(pkt.Data[1:3]))
+
+ ackdata := make([]byte, 1+1+2)
+ ackdata[0] = uint8(rawTypeCtl)
+ ackdata[1] = uint8(ctlAck)
+ binary.BigEndian.PutUint16(ackdata[2:4], uint16(sn))
+ ack := rawPkt{
+ Data: ackdata,
+ ChNo: pkt.ChNo,
+ Unrel: true,
+ }
+ if _, err := p.sendRaw(ack); err != nil {
+ return fmt.Errorf("can't ack %d: %w", sn, err)
+ }
+
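+ // Seqnums wrap around; a difference of 0x8000 or more means
+ // sn is behind c.inrelsn, i.e. the packet was already received.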
+ if sn-c.inrelsn >= 0x8000 {
+ return nil // Already received.
+ }
+
+ c.inrel[sn] = pkt.Data[3:]
+
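+ // Deliver any now-consecutive reliable packets in order.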
+ for ; c.inrel[c.inrelsn] != nil; c.inrelsn++ {
+ data := c.inrel[c.inrelsn]
+ delete(c.inrel, c.inrelsn)
+
+ rpkt := rawPkt{
+ Data: data,
+ ChNo: pkt.ChNo,
+ Unrel: false,
+ }
+ if err := p.processRawPkt(rpkt); err != nil {
+ p.errs <- PktError{"rel", rpkt.Data, err}
+ }
+ }
+ default:
+ return fmt.Errorf("unsupported pkt type: %d", t)
+ }
+
+ return nil
+}