summaryrefslogtreecommitdiff
path: root/proto.go
blob: c5566f038ad3546ea6fbfafde66a2c30e404b89a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
package mt

import (
	"fmt"
	"io"
	"net"
	"sync"

	"github.com/dragonfireclient/mt/rudp"
)

// A Pkt is a deserialized rudp.Pkt: a decoded command paired with the
// rudp.PktInfo it is sent or was received with.
type Pkt struct {
	// Cmd is the decoded command carried by the packet.
	Cmd
	// PktInfo is handed to / taken from the rudp layer unchanged by Send and Recv.
	rudp.PktInfo
}

// Peer wraps rudp.Conn, adding (de)serialization: Send and Recv exchange Pkt
// values, encoding and decoding Cmds on top of the embedded connection.
// Methods of *rudp.Conn that Peer does not redefine (such as Close and IsSrv)
// are promoted from the embedded connection.
type Peer struct {
	*rudp.Conn
}

// SerializePkt starts encoding pkt into w in a new goroutine. The encoding is
// the 2-byte command number (written with be) followed by serialize(w, pkt).
//
// toSrv selects the numbering: true encodes pkt as a ToSrvCmd, false as a
// ToCltCmd. SerializePkt returns false without starting a goroutine, and
// without touching w, if pkt is not a command of that direction or its number
// is 0xffff. Otherwise it returns true.
//
// Once started, the goroutine owns w and closes it when finished. If w has a
// CloseWithError method (as *io.PipeWriter does), any write/serialize error is
// passed to it, so the reading side observes the failure instead of a clean
// EOF on a truncated packet. If wg is non-nil it is incremented before the
// goroutine starts and marked done after w has been closed.
func SerializePkt(pkt Cmd, w io.WriteCloser, toSrv bool, wg *sync.WaitGroup) bool {
	var cmdNo uint16
	if toSrv {
		cmd, ok := pkt.(ToSrvCmd)
		if !ok {
			return false
		}
		cmdNo = cmd.toSrvCmdNo()
	} else {
		cmd, ok := pkt.(ToCltCmd)
		if !ok {
			return false
		}
		cmdNo = cmd.toCltCmdNo()
	}

	// 0xffff is treated as "no command number in this direction".
	if cmdNo == 0xffff {
		return false
	}

	if wg != nil {
		wg.Add(1)
	}
	go func() {
		if wg != nil {
			// Registered first so it runs last, after w has been closed.
			defer wg.Done()
		}

		var err error
		defer func() {
			// The goroutine has no caller to report to, so close errors are
			// ignored; the serialization error (if any) is forwarded to the
			// reader when the writer supports it.
			if cw, ok := w.(interface{ CloseWithError(error) error }); ok {
				_ = cw.CloseWithError(err)
			} else {
				_ = w.Close()
			}
		}()

		buf := make([]byte, 2)
		be.PutUint16(buf, cmdNo)
		if _, err = w.Write(buf); err != nil {
			return
		}
		err = serialize(w, pkt)
	}()

	return true
}

// Send encodes pkt.Cmd through an io.Pipe and hands the read side to the
// embedded rudp.Conn together with pkt.PktInfo. The command is encoded as a
// to-server command when p.IsSrv() reports true, and as a to-client command
// otherwise.
//
// If the command cannot be encoded in that direction, Send closes the
// connection and returns a non-nil error. It never returns a nil ack channel
// together with a nil error, because receiving from a nil channel blocks
// forever.
func (p Peer) Send(pkt Pkt) (ack <-chan struct{}, err error) {
	r, w := io.Pipe()
	if !SerializePkt(pkt.Cmd, w, p.IsSrv(), &sync.WaitGroup{}) {
		err = fmt.Errorf("%T: cannot be sent to this peer", pkt.Cmd)
		if cerr := p.Close(); cerr != nil {
			err = fmt.Errorf("%w (closing connection: %v)", err, cerr)
		}
		return nil, err
	}

	ack, err = p.Conn.Send(rudp.Pkt{r, pkt.PktInfo})
	if err != nil {
		// The serializer goroutine may be blocked writing into the pipe that
		// nobody will read; closing the read side makes its writes fail so it
		// can exit instead of leaking.
		_ = r.CloseWithError(err)
	}
	return ack, err
}

// SendCmd sends cmd with the packet info the command itself provides.
// It is equivalent to Send(Pkt{cmd, cmd.DefaultPktInfo()}).
func (p Peer) SendCmd(cmd Cmd) (ack <-chan struct{}, err error) {
	info := cmd.DefaultPktInfo()
	return p.Send(Pkt{Cmd: cmd, PktInfo: info})
}

// DeserializePkt decodes one command from pkt.
//
// It reads a 2-byte command number (decoded with be) and looks up its
// constructor. The to-client table is used when fromSrv is true, the
// to-server table otherwise. It then decodes the rest of pkt into the new
// command. An unknown number or a decode failure yields a nil command and an
// error.
//
// After decoding, any remaining bytes are reported as a rudp.TrailingDataError
// (wrapped with the command's type). In that case, and if draining pkt fails,
// the decoded command is still returned alongside the error.
func DeserializePkt(pkt io.Reader, fromSrv bool) (*Cmd, error) {
	var hdr [2]byte
	if _, err := io.ReadFull(pkt, hdr[:]); err != nil {
		return nil, err
	}
	cmdNo := be.Uint16(hdr[:])

	var mk func() Cmd
	switch {
	case fromSrv:
		mk = newToCltCmd[cmdNo]
	default:
		mk = newToSrvCmd[cmdNo]
	}
	if mk == nil {
		return nil, fmt.Errorf("unknown cmd: %d", cmdNo)
	}

	cmd := mk()
	if err := deserialize(pkt, cmd); err != nil {
		return nil, fmt.Errorf("%T: %w", cmd, err)
	}

	rest, err := io.ReadAll(pkt)
	if len(rest) != 0 {
		err = fmt.Errorf("%T: %w", cmd, rudp.TrailingDataError(rest))
	}
	return &cmd, err
}

// Recv receives the next packet from the embedded rudp.Conn and decodes it
// with DeserializePkt. The packet is treated as coming from a server when
// p.IsSrv() reports true.
//
// Whenever decoding yielded a command, the returned Pkt carries it even if err
// is non-nil (for example on trailing data). Otherwise a zero Pkt is returned.
func (p Peer) Recv() (_ Pkt, rerr error) {
	raw, err := p.Conn.Recv()
	if err != nil {
		return Pkt{}, err
	}

	cmd, err := DeserializePkt(raw, p.IsSrv())
	if cmd != nil {
		return Pkt{Cmd: *cmd, PktInfo: raw.PktInfo}, err
	}
	return Pkt{}, err
}

// Connect wraps the connection returned by rudp.Connect(conn) in a Peer.
func Connect(conn net.Conn) Peer {
	rc := rudp.Connect(conn)
	return Peer{Conn: rc}
}

// Listener wraps rudp.Listener so that Accept yields Peers instead of raw
// *rudp.Conn values; other rudp.Listener methods are promoted unchanged.
type Listener struct {
	*rudp.Listener
}

// Listen wraps the listener returned by rudp.Listen(conn) in a Listener whose
// Accept returns Peers.
func Listen(conn net.PacketConn) Listener {
	rl := rudp.Listen(conn)
	return Listener{Listener: rl}
}

// Accept calls the embedded rudp.Listener's Accept and wraps the resulting
// connection in a Peer. The error is passed through unchanged, and the Peer
// is built from whatever connection value was returned alongside it.
func (l Listener) Accept() (Peer, error) {
	conn, err := l.Listener.Accept()
	peer := Peer{Conn: conn}
	return peer, err
}