summaryrefslogtreecommitdiff
path: root/proto.go
blob: 597e146d54b4c00e9e5d646dd3f124c049c9a6a7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
package mt

import (
	"fmt"
	"io"
	"net"

	"github.com/dragonfireclient/mt/rudp"
)

// A Pkt is a deserialized rudp.Pkt.
//
// It pairs a command with the rudp.PktInfo it is sent or was received with.
// Both fields are embedded, so their members are promoted onto Pkt.
type Pkt struct {
	// Cmd is the command carried by the packet.
	Cmd
	// PktInfo is the rudp-level packet information accompanying the command.
	rudp.PktInfo
}

// Peer wraps rudp.Conn, adding (de)serialization.
//
// The *rudp.Conn is embedded, so its methods are promoted onto Peer; Send and
// Recv are redefined in this file to operate on Pkt values.
type Peer struct {
	// Conn is the underlying connection that packets are sent and received on.
	*rudp.Conn
}

// SerializePkt starts writing the wire form of pkt to w: the command number
// as two bytes (via be.PutUint16), followed by whatever serialize writes for
// pkt.
//
// toSrv selects which command number is used: pkt must be a ToSrvCmd when
// toSrv is true and a ToCltCmd otherwise. The type assertion panics if pkt
// does not have the required type.
//
// If the command number is 0xffff, nothing is written, w is left untouched
// and SerializePkt returns false. Otherwise it returns true immediately and
// the writing happens in a new goroutine, which closes w when it is done.
//
// When writing fails and w has a CloseWithError method (as *io.PipeWriter
// does), w is closed with that error so the reading side observes the failure
// instead of a clean end of a truncated packet. Writers without that method
// are closed with Close.
func SerializePkt(pkt Cmd, w io.WriteCloser, toSrv bool) bool {
	var cmdNo uint16
	if toSrv {
		cmdNo = pkt.(ToSrvCmd).toSrvCmdNo()
	} else {
		cmdNo = pkt.(ToCltCmd).toCltCmdNo()
	}

	if cmdNo == 0xffff {
		return false
	}

	go func() {
		var err error

		// The close must see the final value of err, so it happens inside a
		// deferred closure. Deferring w.CloseWithError(err) directly would
		// evaluate err at the defer statement, where it is still nil.
		defer func() {
			if err != nil {
				if cw, ok := w.(interface{ CloseWithError(error) error }); ok {
					cw.CloseWithError(err)
					return
				}
			}
			w.Close()
		}()

		var hdr [2]byte
		be.PutUint16(hdr[:], cmdNo)
		if _, err = w.Write(hdr[:]); err != nil {
			return
		}
		err = serialize(w, pkt)
	}()

	return true
}

// Send serializes pkt.Cmd into an in-memory pipe and passes the reading end,
// together with pkt.PktInfo, to the underlying rudp.Conn's Send. The results
// of that call are returned unchanged.
//
// If SerializePkt reports that the command is not to be serialized, nothing
// is sent: the Peer is closed instead, and the result of Close is returned
// along with a nil ack channel.
func (p Peer) Send(pkt Pkt) (ack <-chan struct{}, err error) {
	pr, pw := io.Pipe()

	if ok := SerializePkt(pkt.Cmd, pw, p.IsSrv()); ok {
		return p.Conn.Send(rudp.Pkt{pr, pkt.PktInfo})
	}

	return nil, p.Close()
}

// SendCmd sends cmd using the packet info that cmd itself reports as its
// default. It is equivalent to Send(Pkt{cmd, cmd.DefaultPktInfo()}).
func (p Peer) SendCmd(cmd Cmd) (ack <-chan struct{}, err error) {
	pkt := Pkt{
		Cmd:     cmd,
		PktInfo: cmd.DefaultPktInfo(),
	}
	return p.Send(pkt)
}

// DeserializePkt reads a single command from pkt: a two-byte command number
// followed by the command's payload.
//
// fromSrv selects the constructor table the command number is looked up in:
// newToCltCmd when true, newToSrvCmd otherwise.
//
// A nil *Cmd is returned when the command number cannot be read (the read
// error is returned as is), when the number has no constructor ("unknown
// cmd"), or when deserialize fails (the error is wrapped with the command's
// type).
//
// After a successful deserialize the rest of pkt is read. If any bytes are
// left over, the command is returned together with an error wrapping
// rudp.TrailingDataError; otherwise the command is returned with whatever
// error io.ReadAll produced, normally nil.
func DeserializePkt(pkt io.Reader, fromSrv bool) (*Cmd, error) {
	var hdr [2]byte
	if _, err := io.ReadFull(pkt, hdr[:]); err != nil {
		return nil, err
	}
	cmdNo := be.Uint16(hdr[:])

	// Packets coming from a server carry to-client commands, and vice versa.
	var ctor func() Cmd
	switch {
	case fromSrv:
		ctor = newToCltCmd[cmdNo]
	default:
		ctor = newToSrvCmd[cmdNo]
	}
	if ctor == nil {
		return nil, fmt.Errorf("unknown cmd: %d", cmdNo)
	}

	cmd := ctor()
	if err := deserialize(pkt, cmd); err != nil {
		return nil, fmt.Errorf("%T: %w", cmd, err)
	}

	rest, err := io.ReadAll(pkt)
	if len(rest) == 0 {
		return &cmd, err
	}
	return &cmd, fmt.Errorf("%T: %w", cmd, rudp.TrailingDataError(rest))
}

// Recv receives the next packet from the underlying rudp.Conn and
// deserializes it with DeserializePkt.
//
// An error from the connection is returned with a zero Pkt. When
// DeserializePkt returns a command, it is returned with the received
// packet's PktInfo and whatever error came with it; when it returns no
// command, a zero Pkt is returned with the error.
func (p Peer) Recv() (_ Pkt, rerr error) {
	raw, err := p.Conn.Recv()
	if err != nil {
		return Pkt{}, err
	}

	cmd, err := DeserializePkt(raw, p.IsSrv())
	if cmd != nil {
		return Pkt{Cmd: *cmd, PktInfo: raw.PktInfo}, err
	}
	return Pkt{}, err
}

// Connect passes conn to rudp.Connect and wraps the resulting connection in
// a Peer.
func Connect(conn net.Conn) Peer {
	return Peer{Conn: rudp.Connect(conn)}
}

// A Listener wraps rudp.Listener. Its Accept method returns accepted
// connections wrapped in a Peer.
type Listener struct {
	// Listener is the underlying rudp listener; its methods are promoted.
	*rudp.Listener
}

// Listen passes conn to rudp.Listen and wraps the resulting listener in a
// Listener.
func Listen(conn net.PacketConn) Listener {
	return Listener{Listener: rudp.Listen(conn)}
}

// Accept calls Accept on the underlying rudp.Listener and returns the
// accepted connection wrapped in a Peer, along with the error from that
// call. The Peer is built whether or not the error is nil.
func (l Listener) Accept() (Peer, error) {
	conn, err := l.Listener.Accept()
	return Peer{Conn: conn}, err
}