diff options
author | Lizzy Fleckenstein <eliasfleckenstein@web.de> | 2023-02-25 18:55:53 +0100 |
---|---|---|
committer | Lizzy Fleckenstein <eliasfleckenstein@web.de> | 2023-02-25 18:55:53 +0100 |
commit | 89b1fc1d8d4bd886d80af0fe1d492cc877bce022 (patch) | |
tree | dd6ff3b8987752788def5cbcc865979c040b60fb /src | |
parent | e1a5830622b3adf2a868e42e7f2259cb26a8a0f6 (diff) | |
download | mt_rudp-89b1fc1d8d4bd886d80af0fe1d492cc877bce022.tar.xz |
Use channels
Diffstat (limited to 'src')
-rw-r--r-- | src/client.rs | 47 | ||||
-rw-r--r-- | src/common.rs | 5 | ||||
-rw-r--r-- | src/lib.rs | 6 | ||||
-rw-r--r-- | src/send.rs | 77 | ||||
-rw-r--r-- | src/share.rs | 87 | ||||
-rw-r--r-- | src/worker.rs (renamed from src/recv.rs) | 167 |
6 files changed, 163 insertions, 226 deletions
diff --git a/src/client.rs b/src/client.rs index 56db92a..29244d0 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,16 +1,19 @@ use super::*; use async_trait::async_trait; use std::{io, sync::Arc}; -use tokio::net; +use tokio::{ + net, + sync::{mpsc, watch}, +}; #[derive(Debug)] -pub struct ToSrv(Arc<net::UdpSocket>); +pub struct UdpCltSender(Arc<net::UdpSocket>); #[derive(Debug)] -pub struct FromSrv(Arc<net::UdpSocket>); +pub struct UdpCltReceiver(Arc<net::UdpSocket>); #[async_trait] -impl UdpSender for ToSrv { +impl UdpSender for UdpCltSender { async fn send(&self, data: &[u8]) -> io::Result<()> { self.0.send(data).await?; Ok(()) @@ -18,7 +21,7 @@ impl UdpSender for ToSrv { } #[async_trait] -impl UdpReceiver for FromSrv { +impl UdpReceiver for UdpCltReceiver { async fn recv(&mut self) -> io::Result<Vec<u8>> { let mut buffer = Vec::new(); buffer.resize(UDP_PKT_SIZE, 0); @@ -30,21 +33,35 @@ impl UdpReceiver for FromSrv { } } -pub struct RemoteSrv; -impl UdpPeer for RemoteSrv { - type Sender = ToSrv; - type Receiver = FromSrv; +#[derive(Debug)] +pub struct CltReceiver(mpsc::UnboundedReceiver<Result<Pkt<'static>, Error>>); + +impl CltReceiver { + pub async fn recv_rudp(&mut self) -> Option<Result<Pkt<'static>, Error>> { + self.0.recv().await + } } -pub async fn connect(addr: &str) -> io::Result<(RudpSender<RemoteSrv>, RudpReceiver<RemoteSrv>)> { +pub type CltSender = Arc<Sender<UdpCltSender>>; +pub type CltWorker = Worker<UdpCltSender, UdpCltReceiver>; + +pub async fn connect(addr: &str) -> io::Result<(CltSender, CltReceiver, CltWorker)> { let sock = Arc::new(net::UdpSocket::bind("0.0.0.0:0").await?); sock.connect(addr).await?; - new( + let (close_tx, close_rx) = watch::channel(false); + let (pkt_tx, pkt_rx) = mpsc::unbounded_channel(); + + let sender = Sender::new( + UdpCltSender(Arc::clone(&sock)), + close_tx, PeerID::Srv as u16, PeerID::Nil as u16, - ToSrv(Arc::clone(&sock)), - FromSrv(sock), - ) - .await + ); + + Ok(( + Arc::clone(&sender), + CltReceiver(pkt_rx), + Worker::new(UdpCltReceiver(sock), close_rx, sender, pkt_tx), + )) } diff --git a/src/common.rs b/src/common.rs index 0ed08f3..bdae6d2 100644 --- a/src/common.rs +++ b/src/common.rs @@ -20,11 +20,6 @@ pub trait UdpReceiver: Send { async fn recv(&mut self) -> io::Result<Vec<u8>>; } -pub trait UdpPeer { - type Sender: UdpSender; - type Receiver: UdpReceiver; -} - #[derive(Debug, Copy, Clone, PartialEq)] #[repr(u16)] pub enum PeerID { @@ -4,13 +4,11 @@ mod client; mod common; mod error; -mod recv; mod send; -mod share; +mod worker; pub use client::*; pub use common::*; pub use error::*; -pub use recv::*; pub use send::*; -pub use share::*; +pub use worker::*; diff --git a/src/send.rs b/src/send.rs index 2c449e1..90bbe2d 100644 --- a/src/send.rs +++ b/src/send.rs @@ -1,34 +1,57 @@ use super::*; use byteorder::{BigEndian, WriteBytesExt}; use std::{ + collections::HashMap, io::{self, Write}, sync::Arc, }; -use tokio::sync::watch; +use tokio::sync::{watch, Mutex, RwLock}; -pub type AckResult = io::Result<Option<watch::Receiver<bool>>>; +pub type Ack = Option<watch::Receiver<bool>>; -pub struct RudpSender<P: UdpPeer> { - pub(crate) share: Arc<RudpShare<P>>, +#[derive(Debug)] +pub(crate) struct AckWait { + pub(crate) tx: watch::Sender<bool>, + pub(crate) rx: watch::Receiver<bool>, + pub(crate) data: Vec<u8>, } -// derive(Clone) adds unwanted Clone trait bound to P parameter -impl<P: UdpPeer> Clone for RudpSender<P> { - fn clone(&self) -> Self { - Self { - share: Arc::clone(&self.share), - } - } +#[derive(Debug)] +pub(crate) struct Chan { + pub(crate) acks: HashMap<u16, AckWait>, + pub(crate) seqnum: u16, } -impl<P: UdpPeer> RudpSender<P> { - pub async fn send(&self, pkt: Pkt<'_>) -> AckResult { - self.share.send(PktType::Orig, pkt).await // TODO: splits - } +#[derive(Debug)] +pub struct Sender<S: UdpSender> { + pub(crate) id: u16, + pub(crate) remote_id: RwLock<u16>, + pub(crate) chans: [Mutex<Chan>; NUM_CHANS], + udp: S, + close: watch::Sender<bool>, } -impl<P: UdpPeer> RudpShare<P> { - pub async fn send(&self, tp: PktType, pkt: Pkt<'_>) -> AckResult { +impl<S: UdpSender> Sender<S> { + pub fn new(udp: S, close: watch::Sender<bool>, id: u16, remote_id: u16) -> Arc<Self> { + Arc::new(Self { + id, + remote_id: RwLock::new(remote_id), + udp, + close, + chans: std::array::from_fn(|_| { + Mutex::new(Chan { + acks: HashMap::new(), + seqnum: INIT_SEQNUM, + }) + }), + }) + } + + pub async fn send_rudp(&self, pkt: Pkt<'_>) -> io::Result<Ack> { + self.send_rudp_type(PktType::Orig, pkt).await // TODO: splits + } + + pub async fn send_rudp_type(&self, tp: PktType, pkt: Pkt<'_>) -> io::Result<Ack> { let mut buf = Vec::with_capacity(4 + 2 + 1 + 1 + 2 + 1 + pkt.data.len()); buf.write_u32::<BigEndian>(PROTO_ID)?; buf.write_u16::<BigEndian>(*self.remote_id.read().await)?; @@ -45,7 +68,7 @@ impl<P: UdpPeer> RudpShare<P> { buf.write_u8(tp as u8)?; buf.write_all(pkt.data.as_ref())?; - self.send_raw(&buf).await?; + self.send_udp(&buf).await?; if pkt.unrel { Ok(None) @@ -54,7 +77,7 @@ impl<P: UdpPeer> RudpShare<P> { let (tx, rx) = watch::channel(false); chan.acks.insert( seqnum, - Ack { + AckWait { tx, rx: rx.clone(), data: buf, @@ -66,11 +89,23 @@ impl<P: UdpPeer> RudpShare<P> { } } - pub async fn send_raw(&self, data: &[u8]) -> io::Result<()> { + pub async fn send_udp(&self, data: &[u8]) -> io::Result<()> { if data.len() > UDP_PKT_SIZE { panic!("splitting packets is not implemented yet"); } - self.udp_tx.send(data).await + self.udp.send(data).await + } + + pub async fn peer_id(&self) -> u16 { + self.id + } + + pub async fn is_server(&self) -> bool { + self.id == PeerID::Srv as u16 + } + + pub fn close(&self) { + self.close.send(true).ok(); } } diff --git a/src/share.rs b/src/share.rs deleted file mode 100644 index 02e37b2..0000000 --- a/src/share.rs +++ /dev/null @@ -1,87 +0,0 @@ -use super::*; -use std::{borrow::Cow, collections::HashMap, io, sync::Arc}; -use tokio::sync::{watch, Mutex, RwLock}; - -#[derive(Debug)] -pub(crate) struct Ack { - pub(crate) tx: watch::Sender<bool>, - pub(crate) rx: watch::Receiver<bool>, - pub(crate) data: Vec<u8>, -} - -#[derive(Debug)] -pub(crate) struct Chan { - pub(crate) acks: HashMap<u16, Ack>, - pub(crate) seqnum: u16, -} - -#[derive(Debug)] -pub(crate) struct RudpShare<P: UdpPeer> { - pub(crate) id: u16, - pub(crate) remote_id: RwLock<u16>, - pub(crate) chans: [Mutex<Chan>; NUM_CHANS], - pub(crate) udp_tx: P::Sender, - pub(crate) close: watch::Sender<bool>, -} - -pub async fn new<P: UdpPeer>( - id: u16, - remote_id: u16, - udp_tx: P::Sender, - udp_rx: P::Receiver, -) -> io::Result<(RudpSender<P>, RudpReceiver<P>)> { - let (close_tx, close_rx) = watch::channel(false); - - let share = Arc::new(RudpShare { - id, - remote_id: RwLock::new(remote_id), - udp_tx, - close: close_tx, - chans: std::array::from_fn(|_| { - Mutex::new(Chan { - acks: HashMap::new(), - seqnum: INIT_SEQNUM, - }) - }), - }); - - Ok(( - RudpSender { - share: Arc::clone(&share), - }, - RudpReceiver::new(udp_rx, share, close_rx), - )) -} - -macro_rules! impl_share { - ($T:ident) => { - impl<P: UdpPeer> $T<P> { - pub async fn peer_id(&self) -> u16 { - self.share.id - } - - pub async fn is_server(&self) -> bool { - self.share.id == PeerID::Srv as u16 - } - - pub async fn close(self) { - self.share.close.send(true).ok(); // FIXME: handle err? - - self.share - .send( - PktType::Ctl, - Pkt { - unrel: true, - chan: 0, - data: Cow::Borrowed(&[CtlType::Disco as u8]), - }, - ) - .await - .ok(); - } - } - }; -} - -impl_share!(RudpReceiver); -impl_share!(RudpSender); diff --git a/src/recv.rs b/src/worker.rs index 309bf94..72bf2b5 100644 --- a/src/recv.rs +++ b/src/worker.rs @@ -3,14 +3,14 @@ use async_recursion::async_recursion; use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; use std::{ borrow::Cow, - collections::{HashMap, VecDeque}, + collections::HashMap, io, pin::Pin, sync::Arc, time::{Duration, Instant}, }; use tokio::{ - sync::watch, + sync::{mpsc, watch}, time::{interval, sleep, Interval, Sleep}, }; @@ -27,36 +27,38 @@ struct Split { got: usize, } +#[derive(Debug)] struct RecvChan { packets: Vec<Option<Vec<u8>>>, // char ** 😛 splits: HashMap<u16, Split>, seqnum: u16, } -pub struct RudpReceiver<P: UdpPeer> { - pub(crate) share: Arc<RudpShare<P>>, +#[derive(Debug)] +pub struct Worker<S: UdpSender, R: UdpReceiver> { + sender: Arc<Sender<S>>, chans: [RecvChan; NUM_CHANS], - udp: P::Receiver, + input: R, close: watch::Receiver<bool>, - closed: bool, resend: Interval, ping: Interval, cleanup: Interval, timeout: Pin<Box<Sleep>>, - queue: VecDeque<Result<Pkt<'static>>>, + output: mpsc::UnboundedSender<Result<Pkt<'static>>>, } -impl<P: UdpPeer> RudpReceiver<P> { +impl<S: UdpSender, R: UdpReceiver> Worker<S, R> { pub(crate) fn new( - udp: P::Receiver, - share: Arc<RudpShare<P>>, + input: R, close: watch::Receiver<bool>, + sender: Arc<Sender<S>>, + output: mpsc::UnboundedSender<Result<Pkt<'static>>>, ) -> Self { Self { - udp, - share, + input, + sender, close, - closed: false, + output, resend: interval(Duration::from_millis(500)), ping: interval(Duration::from_secs(PING_TIMEOUT)), cleanup: interval(Duration::from_secs(TIMEOUT)), @@ -66,42 +68,33 @@ impl<P: UdpPeer> RudpReceiver<P> { seqnum: INIT_SEQNUM, splits: HashMap::new(), }), - queue: VecDeque::new(), } } - fn handle_err(&mut self, res: Result<()>) -> Result<()> { + pub async fn run(mut self) { use Error::*; - match res { - Err(RemoteDisco(_)) | Err(LocalDisco) => { - self.closed = true; - res - } - Ok(_) => res, - Err(e) => { - self.queue.push_back(Err(e)); - Ok(()) - } - } - } - - pub async fn recv(&mut self) -> Option<Result<Pkt<'static>>> { - use Error::*; - - if self.closed { - return None; - } - loop { - if let Some(x) = self.queue.pop_front() { - return Some(x); - } - tokio::select! { _ = self.close.changed() => { - self.closed = true; - return Some(Err(LocalDisco)); + self.sender + .send_rudp_type( + PktType::Ctl, + Pkt { + unrel: true, + chan: 0, + data: Cow::Borrowed(&[CtlType::Disco as u8]), + }, + ) + .await + .ok(); + + self.output.send(Err(LocalDisco)).ok(); + break; + }, + _ = &mut self.timeout => { + self.output.send(Err(RemoteDisco(true))).ok(); + break; }, _ = self.cleanup.tick() => { let timeout = Duration::from_secs(TIMEOUT); @@ -116,15 +109,15 @@ impl<P: UdpPeer> RudpReceiver<P> { } }, _ = self.resend.tick() => { - for chan in self.share.chans.iter() { + for chan in self.sender.chans.iter() { for (_, ack) in chan.lock().await.acks.iter() { - self.share.send_raw(&ack.data).await.ok(); // TODO: handle error (?) + self.sender.send_udp(&ack.data).await.ok(); } } }, _ = self.ping.tick() => { - self.share - .send( + self.sender + .send_rudp_type( PktType::Ctl, Pkt { chan: 0, @@ -135,13 +128,9 @@ impl<P: UdpPeer> RudpReceiver<P> { .await .ok(); } - _ = &mut self.timeout => { - self.closed = true; - return Some(Err(RemoteDisco(true))); - }, - pkt = self.udp.recv() => { + pkt = self.input.recv() => { if let Err(e) = self.handle_pkt(pkt).await { - return Some(Err(e)); + self.output.send(Err(e)).ok(); } } } @@ -169,10 +158,7 @@ impl<P: UdpPeer> RudpReceiver<P> { return Err(InvalidChannel(chan)); } - let res = self.process_pkt(cursor, true, chan).await; - self.handle_err(res)?; - - Ok(()) + self.process_pkt(cursor, true, chan).await } #[async_recursion] @@ -188,17 +174,13 @@ impl<P: UdpPeer> RudpReceiver<P> { match cursor.read_u8()?.try_into()? { PktType::Ctl => match cursor.read_u8()?.try_into()? { CtlType::Ack => { - // println!("Ack"); - let seqnum = cursor.read_u16::<BigEndian>()?; - if let Some(ack) = self.share.chans[ch].lock().await.acks.remove(&seqnum) { + if let Some(ack) = self.sender.chans[ch].lock().await.acks.remove(&seqnum) { ack.tx.send(true).ok(); } } CtlType::SetPeerID => { - // println!("SetPeerID"); - - let mut id = self.share.remote_id.write().await; + let mut id = self.sender.remote_id.write().await; if *id != PeerID::Nil as u16 { return Err(PeerIDAlreadySet); @@ -206,26 +188,21 @@ impl<P: UdpPeer> RudpReceiver<P> { *id = cursor.read_u16::<BigEndian>()?; } - CtlType::Ping => { - // println!("Ping"); - } + CtlType::Ping => {} CtlType::Disco => { - // println!("Disco"); return Err(RemoteDisco(false)); } }, PktType::Orig => { - // println!("Orig"); - - self.queue.push_back(Ok(Pkt { - chan, - unrel, - data: Cow::Owned(cursor.remaining_slice().into()), - })); + self.output + .send(Ok(Pkt { + chan, + unrel, + data: Cow::Owned(cursor.remaining_slice().into()), + })) + .ok(); } PktType::Split => { - // println!("Split"); - let seqnum = cursor.read_u16::<BigEndian>()?; let chunk_count = cursor.read_u16::<BigEndian>()? as usize; let chunk_index = cursor.read_u16::<BigEndian>()? as usize; @@ -258,25 +235,25 @@ impl<P: UdpPeer> RudpReceiver<P> { if split.got == chunk_count { let split = self.chans[ch].splits.remove(&seqnum).unwrap(); - self.queue.push_back(Ok(Pkt { - chan, - unrel, - data: split - .chunks - .into_iter() - .map(|x| x.unwrap()) - .reduce(|mut a, mut b| { - a.append(&mut b); - a - }) - .unwrap_or_default() - .into(), - })); + self.output + .send(Ok(Pkt { + chan, + unrel, + data: split + .chunks + .into_iter() + .map(|x| x.unwrap()) + .reduce(|mut a, mut b| { + a.append(&mut b); + a + }) + .unwrap_or_default() + .into(), + })) + .ok(); } } PktType::Rel => { - // println!("Rel"); - let seqnum = cursor.read_u16::<BigEndian>()?; self.chans[ch].packets[to_seqnum(seqnum)].replace(cursor.remaining_slice().into()); @@ -284,8 +261,8 @@ impl<P: UdpPeer> RudpReceiver<P> { ack_data.write_u8(CtlType::Ack as u8)?; ack_data.write_u16::<BigEndian>(seqnum)?; - self.share - .send( + self.sender + .send_rudp_type( PktType::Ctl, Pkt { chan, @@ -297,8 +274,10 @@ impl<P: UdpPeer> RudpReceiver<P> { let next_pkt = |chan: &mut RecvChan| chan.packets[to_seqnum(chan.seqnum)].take(); while let Some(pkt) = next_pkt(&mut self.chans[ch]) { - let res = self.process_pkt(io::Cursor::new(pkt), false, chan).await; - self.handle_err(res)?; + if let Err(e) = self.process_pkt(io::Cursor::new(pkt), false, chan).await { + self.output.send(Err(e)).ok(); + } + self.chans[ch].seqnum = self.chans[ch].seqnum.overflowing_add(1).0; } } |