From 89b1fc1d8d4bd886d80af0fe1d492cc877bce022 Mon Sep 17 00:00:00 2001 From: Lizzy Fleckenstein Date: Sat, 25 Feb 2023 18:55:53 +0100 Subject: Use channels --- src/client.rs | 47 ++++++--- src/common.rs | 5 - src/lib.rs | 6 +- src/recv.rs | 309 ---------------------------------------------------------- src/send.rs | 77 +++++++++++---- src/share.rs | 87 ----------------- src/worker.rs | 288 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 7 files changed, 378 insertions(+), 441 deletions(-) delete mode 100644 src/recv.rs delete mode 100644 src/share.rs create mode 100644 src/worker.rs diff --git a/src/client.rs b/src/client.rs index 56db92a..29244d0 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,16 +1,19 @@ use super::*; use async_trait::async_trait; use std::{io, sync::Arc}; -use tokio::net; +use tokio::{ + net, + sync::{mpsc, watch}, +}; #[derive(Debug)] -pub struct ToSrv(Arc); +pub struct UdpCltSender(Arc); #[derive(Debug)] -pub struct FromSrv(Arc); +pub struct UdpCltReceiver(Arc); #[async_trait] -impl UdpSender for ToSrv { +impl UdpSender for UdpCltSender { async fn send(&self, data: &[u8]) -> io::Result<()> { self.0.send(data).await?; Ok(()) @@ -18,7 +21,7 @@ impl UdpSender for ToSrv { } #[async_trait] -impl UdpReceiver for FromSrv { +impl UdpReceiver for UdpCltReceiver { async fn recv(&mut self) -> io::Result> { let mut buffer = Vec::new(); buffer.resize(UDP_PKT_SIZE, 0); @@ -30,21 +33,35 @@ impl UdpReceiver for FromSrv { } } -pub struct RemoteSrv; -impl UdpPeer for RemoteSrv { - type Sender = ToSrv; - type Receiver = FromSrv; +#[derive(Debug)] +pub struct CltReceiver(mpsc::UnboundedReceiver, Error>>); + +impl CltReceiver { + pub async fn recv_rudp(&mut self) -> Option, Error>> { + self.0.recv().await + } } -pub async fn connect(addr: &str) -> io::Result<(RudpSender, RudpReceiver)> { +pub type CltSender = Arc>; +pub type CltWorker = Worker; + +pub async fn connect(addr: &str) -> io::Result<(CltSender, CltReceiver, CltWorker)> { let sock = Arc::new(net::UdpSocket::bind("0.0.0.0:0").await?); sock.connect(addr).await?; - new( + let (close_tx, close_rx) = watch::channel(false); + let (pkt_tx, pkt_rx) = mpsc::unbounded_channel(); + + let sender = Sender::new( + UdpCltSender(Arc::clone(&sock)), + close_tx, PeerID::Srv as u16, PeerID::Nil as u16, - ToSrv(Arc::clone(&sock)), - FromSrv(sock), - ) - .await + ); + + Ok(( + Arc::clone(&sender), + CltReceiver(pkt_rx), + Worker::new(UdpCltReceiver(sock), close_rx, sender, pkt_tx), + )) } diff --git a/src/common.rs b/src/common.rs index 0ed08f3..bdae6d2 100644 --- a/src/common.rs +++ b/src/common.rs @@ -20,11 +20,6 @@ pub trait UdpReceiver: Send { async fn recv(&mut self) -> io::Result>; } -pub trait UdpPeer { - type Sender: UdpSender; - type Receiver: UdpReceiver; -} - #[derive(Debug, Copy, Clone, PartialEq)] #[repr(u16)] pub enum PeerID { diff --git a/src/lib.rs b/src/lib.rs index e7a8ebe..b9a042d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,13 +4,11 @@ mod client; mod common; mod error; -mod recv; mod send; -mod share; +mod worker; pub use client::*; pub use common::*; pub use error::*; -pub use recv::*; pub use send::*; -pub use share::*; +pub use worker::*; diff --git a/src/recv.rs b/src/recv.rs deleted file mode 100644 index 309bf94..0000000 --- a/src/recv.rs +++ /dev/null @@ -1,309 +0,0 @@ -use super::*; -use async_recursion::async_recursion; -use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; -use std::{ - borrow::Cow, - collections::{HashMap, VecDeque}, - io, - pin::Pin, - sync::Arc, - time::{Duration, Instant}, -}; -use tokio::{ - sync::watch, - time::{interval, sleep, Interval, Sleep}, -}; - -fn to_seqnum(seqnum: u16) -> usize { - (seqnum as usize) & (REL_BUFFER - 1) -} - -type Result = std::result::Result; - -#[derive(Debug)] -struct Split { - timestamp: Option, - chunks: Vec>>, - got: usize, -} - -struct RecvChan { - packets: Vec>>, // char ** 😛 - splits: HashMap, - seqnum: u16, -} - -pub struct RudpReceiver { - pub(crate) share: Arc>, - chans: [RecvChan; NUM_CHANS], - udp: P::Receiver, - close: watch::Receiver, - closed: bool, - resend: Interval, - ping: Interval, - cleanup: Interval, - timeout: Pin>, - queue: VecDeque>>, -} - -impl RudpReceiver

{ - pub(crate) fn new( - udp: P::Receiver, - share: Arc>, - close: watch::Receiver, - ) -> Self { - Self { - udp, - share, - close, - closed: false, - resend: interval(Duration::from_millis(500)), - ping: interval(Duration::from_secs(PING_TIMEOUT)), - cleanup: interval(Duration::from_secs(TIMEOUT)), - timeout: Box::pin(sleep(Duration::from_secs(TIMEOUT))), - chans: std::array::from_fn(|_| RecvChan { - packets: (0..REL_BUFFER).map(|_| None).collect(), - seqnum: INIT_SEQNUM, - splits: HashMap::new(), - }), - queue: VecDeque::new(), - } - } - - fn handle_err(&mut self, res: Result<()>) -> Result<()> { - use Error::*; - - match res { - Err(RemoteDisco(_)) | Err(LocalDisco) => { - self.closed = true; - res - } - Ok(_) => res, - Err(e) => { - self.queue.push_back(Err(e)); - Ok(()) - } - } - } - - pub async fn recv(&mut self) -> Option>> { - use Error::*; - - if self.closed { - return None; - } - - loop { - if let Some(x) = self.queue.pop_front() { - return Some(x); - } - - tokio::select! { - _ = self.close.changed() => { - self.closed = true; - return Some(Err(LocalDisco)); - }, - _ = self.cleanup.tick() => { - let timeout = Duration::from_secs(TIMEOUT); - - for chan in self.chans.iter_mut() { - chan.splits = chan - .splits - .drain_filter( - |_k, v| !matches!(v.timestamp, Some(t) if t.elapsed() < timeout), - ) - .collect(); - } - }, - _ = self.resend.tick() => { - for chan in self.share.chans.iter() { - for (_, ack) in chan.lock().await.acks.iter() { - self.share.send_raw(&ack.data).await.ok(); // TODO: handle error (?) - } - } - }, - _ = self.ping.tick() => { - self.share - .send( - PktType::Ctl, - Pkt { - chan: 0, - unrel: false, - data: Cow::Borrowed(&[CtlType::Ping as u8]), - }, - ) - .await - .ok(); - } - _ = &mut self.timeout => { - self.closed = true; - return Some(Err(RemoteDisco(true))); - }, - pkt = self.udp.recv() => { - if let Err(e) = self.handle_pkt(pkt).await { - return Some(Err(e)); - } - } - } - } - } - - async fn handle_pkt(&mut self, pkt: io::Result>) -> Result<()> { - use Error::*; - - let mut cursor = io::Cursor::new(pkt?); - - self.timeout - .as_mut() - .reset(tokio::time::Instant::now() + Duration::from_secs(TIMEOUT)); - - let proto_id = cursor.read_u32::()?; - if proto_id != PROTO_ID { - return Err(InvalidProtoId(proto_id)); - } - - let _peer_id = cursor.read_u16::()?; - - let chan = cursor.read_u8()?; - if chan >= NUM_CHANS as u8 { - return Err(InvalidChannel(chan)); - } - - let res = self.process_pkt(cursor, true, chan).await; - self.handle_err(res)?; - - Ok(()) - } - - #[async_recursion] - async fn process_pkt( - &mut self, - mut cursor: io::Cursor>, - unrel: bool, - chan: u8, - ) -> Result<()> { - use Error::*; - - let ch = chan as usize; - match cursor.read_u8()?.try_into()? { - PktType::Ctl => match cursor.read_u8()?.try_into()? { - CtlType::Ack => { - // println!("Ack"); - - let seqnum = cursor.read_u16::()?; - if let Some(ack) = self.share.chans[ch].lock().await.acks.remove(&seqnum) { - ack.tx.send(true).ok(); - } - } - CtlType::SetPeerID => { - // println!("SetPeerID"); - - let mut id = self.share.remote_id.write().await; - - if *id != PeerID::Nil as u16 { - return Err(PeerIDAlreadySet); - } - - *id = cursor.read_u16::()?; - } - CtlType::Ping => { - // println!("Ping"); - } - CtlType::Disco => { - // println!("Disco"); - return Err(RemoteDisco(false)); - } - }, - PktType::Orig => { - // println!("Orig"); - - self.queue.push_back(Ok(Pkt { - chan, - unrel, - data: Cow::Owned(cursor.remaining_slice().into()), - })); - } - PktType::Split => { - // println!("Split"); - - let seqnum = cursor.read_u16::()?; - let chunk_count = cursor.read_u16::()? as usize; - let chunk_index = cursor.read_u16::()? as usize; - - let mut split = self.chans[ch] - .splits - .entry(seqnum) - .or_insert_with(|| Split { - got: 0, - chunks: (0..chunk_count).map(|_| None).collect(), - timestamp: None, - }); - - if split.chunks.len() != chunk_count { - return Err(InvalidChunkCount(split.chunks.len(), chunk_count)); - } - - if split - .chunks - .get_mut(chunk_index) - .ok_or(InvalidChunkIndex(chunk_index, chunk_count))? - .replace(cursor.remaining_slice().into()) - .is_none() - { - split.got += 1; - } - - split.timestamp = if unrel { Some(Instant::now()) } else { None }; - - if split.got == chunk_count { - let split = self.chans[ch].splits.remove(&seqnum).unwrap(); - - self.queue.push_back(Ok(Pkt { - chan, - unrel, - data: split - .chunks - .into_iter() - .map(|x| x.unwrap()) - .reduce(|mut a, mut b| { - a.append(&mut b); - a - }) - .unwrap_or_default() - .into(), - })); - } - } - PktType::Rel => { - // println!("Rel"); - - let seqnum = cursor.read_u16::()?; - self.chans[ch].packets[to_seqnum(seqnum)].replace(cursor.remaining_slice().into()); - - let mut ack_data = Vec::with_capacity(3); - ack_data.write_u8(CtlType::Ack as u8)?; - ack_data.write_u16::(seqnum)?; - - self.share - .send( - PktType::Ctl, - Pkt { - chan, - unrel: true, - data: ack_data.into(), - }, - ) - .await?; - - let next_pkt = |chan: &mut RecvChan| chan.packets[to_seqnum(chan.seqnum)].take(); - while let Some(pkt) = next_pkt(&mut self.chans[ch]) { - let res = self.process_pkt(io::Cursor::new(pkt), false, chan).await; - self.handle_err(res)?; - self.chans[ch].seqnum = self.chans[ch].seqnum.overflowing_add(1).0; - } - } - } - - Ok(()) - } -} diff --git a/src/send.rs b/src/send.rs index 2c449e1..90bbe2d 100644 --- a/src/send.rs +++ b/src/send.rs @@ -1,34 +1,57 @@ use super::*; use byteorder::{BigEndian, WriteBytesExt}; use std::{ + collections::HashMap, io::{self, Write}, sync::Arc, }; -use tokio::sync::watch; +use tokio::sync::{watch, Mutex, RwLock}; -pub type AckResult = io::Result>>; +pub type Ack = Option>; -pub struct RudpSender { - pub(crate) share: Arc>, +#[derive(Debug)] +pub(crate) struct AckWait { + pub(crate) tx: watch::Sender, + pub(crate) rx: watch::Receiver, + pub(crate) data: Vec, } -// derive(Clone) adds unwanted Clone trait bound to P parameter -impl Clone for RudpSender

{ - fn clone(&self) -> Self { - Self { - share: Arc::clone(&self.share), - } - } +#[derive(Debug)] +pub(crate) struct Chan { + pub(crate) acks: HashMap, + pub(crate) seqnum: u16, } -impl RudpSender

{ - pub async fn send(&self, pkt: Pkt<'_>) -> AckResult { - self.share.send(PktType::Orig, pkt).await // TODO: splits - } +#[derive(Debug)] +pub struct Sender { + pub(crate) id: u16, + pub(crate) remote_id: RwLock, + pub(crate) chans: [Mutex; NUM_CHANS], + udp: S, + close: watch::Sender, } -impl RudpShare

{ - pub async fn send(&self, tp: PktType, pkt: Pkt<'_>) -> AckResult { +impl Sender { + pub fn new(udp: S, close: watch::Sender, id: u16, remote_id: u16) -> Arc { + Arc::new(Self { + id, + remote_id: RwLock::new(remote_id), + udp, + close, + chans: std::array::from_fn(|_| { + Mutex::new(Chan { + acks: HashMap::new(), + seqnum: INIT_SEQNUM, + }) + }), + }) + } + + pub async fn send_rudp(&self, pkt: Pkt<'_>) -> io::Result { + self.send_rudp_type(PktType::Orig, pkt).await // TODO: splits + } + + pub async fn send_rudp_type(&self, tp: PktType, pkt: Pkt<'_>) -> io::Result { let mut buf = Vec::with_capacity(4 + 2 + 1 + 1 + 2 + 1 + pkt.data.len()); buf.write_u32::(PROTO_ID)?; buf.write_u16::(*self.remote_id.read().await)?; @@ -45,7 +68,7 @@ impl RudpShare

{ buf.write_u8(tp as u8)?; buf.write_all(pkt.data.as_ref())?; - self.send_raw(&buf).await?; + self.send_udp(&buf).await?; if pkt.unrel { Ok(None) @@ -54,7 +77,7 @@ impl RudpShare

{ let (tx, rx) = watch::channel(false); chan.acks.insert( seqnum, - Ack { + AckWait { tx, rx: rx.clone(), data: buf, @@ -66,11 +89,23 @@ impl RudpShare

{ } } - pub async fn send_raw(&self, data: &[u8]) -> io::Result<()> { + pub async fn send_udp(&self, data: &[u8]) -> io::Result<()> { if data.len() > UDP_PKT_SIZE { panic!("splitting packets is not implemented yet"); } - self.udp_tx.send(data).await + self.udp.send(data).await + } + + pub async fn peer_id(&self) -> u16 { + self.id + } + + pub async fn is_server(&self) -> bool { + self.id == PeerID::Srv as u16 + } + + pub fn close(&self) { + self.close.send(true).ok(); } } diff --git a/src/share.rs b/src/share.rs deleted file mode 100644 index 02e37b2..0000000 --- a/src/share.rs +++ /dev/null @@ -1,87 +0,0 @@ -use super::*; -use std::{borrow::Cow, collections::HashMap, io, sync::Arc}; -use tokio::sync::{watch, Mutex, RwLock}; - -#[derive(Debug)] -pub(crate) struct Ack { - pub(crate) tx: watch::Sender, - pub(crate) rx: watch::Receiver, - pub(crate) data: Vec, -} - -#[derive(Debug)] -pub(crate) struct Chan { - pub(crate) acks: HashMap, - pub(crate) seqnum: u16, -} - -#[derive(Debug)] -pub(crate) struct RudpShare { - pub(crate) id: u16, - pub(crate) remote_id: RwLock, - pub(crate) chans: [Mutex; NUM_CHANS], - pub(crate) udp_tx: P::Sender, - pub(crate) close: watch::Sender, -} - -pub async fn new( - id: u16, - remote_id: u16, - udp_tx: P::Sender, - udp_rx: P::Receiver, -) -> io::Result<(RudpSender

, RudpReceiver

)> { - let (close_tx, close_rx) = watch::channel(false); - - let share = Arc::new(RudpShare { - id, - remote_id: RwLock::new(remote_id), - udp_tx, - close: close_tx, - chans: std::array::from_fn(|_| { - Mutex::new(Chan { - acks: HashMap::new(), - seqnum: INIT_SEQNUM, - }) - }), - }); - - Ok(( - RudpSender { - share: Arc::clone(&share), - }, - RudpReceiver::new(udp_rx, share, close_rx), - )) -} - -macro_rules! impl_share { - ($T:ident) => { - impl $T

{ - pub async fn peer_id(&self) -> u16 { - self.share.id - } - - pub async fn is_server(&self) -> bool { - self.share.id == PeerID::Srv as u16 - } - - pub async fn close(self) { - self.share.close.send(true).ok(); // FIXME: handle err? - - self.share - .send( - PktType::Ctl, - Pkt { - unrel: true, - chan: 0, - data: Cow::Borrowed(&[CtlType::Disco as u8]), - }, - ) - .await - .ok(); - } - } - }; -} - -impl_share!(RudpReceiver); -impl_share!(RudpSender); diff --git a/src/worker.rs b/src/worker.rs new file mode 100644 index 0000000..72bf2b5 --- /dev/null +++ b/src/worker.rs @@ -0,0 +1,288 @@ +use super::*; +use async_recursion::async_recursion; +use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; +use std::{ + borrow::Cow, + collections::HashMap, + io, + pin::Pin, + sync::Arc, + time::{Duration, Instant}, +}; +use tokio::{ + sync::{mpsc, watch}, + time::{interval, sleep, Interval, Sleep}, +}; + +fn to_seqnum(seqnum: u16) -> usize { + (seqnum as usize) & (REL_BUFFER - 1) +} + +type Result = std::result::Result; + +#[derive(Debug)] +struct Split { + timestamp: Option, + chunks: Vec>>, + got: usize, +} + +#[derive(Debug)] +struct RecvChan { + packets: Vec>>, // char ** 😛 + splits: HashMap, + seqnum: u16, +} + +#[derive(Debug)] +pub struct Worker { + sender: Arc>, + chans: [RecvChan; NUM_CHANS], + input: R, + close: watch::Receiver, + resend: Interval, + ping: Interval, + cleanup: Interval, + timeout: Pin>, + output: mpsc::UnboundedSender>>, +} + +impl Worker { + pub(crate) fn new( + input: R, + close: watch::Receiver, + sender: Arc>, + output: mpsc::UnboundedSender>>, + ) -> Self { + Self { + input, + sender, + close, + output, + resend: interval(Duration::from_millis(500)), + ping: interval(Duration::from_secs(PING_TIMEOUT)), + cleanup: interval(Duration::from_secs(TIMEOUT)), + timeout: Box::pin(sleep(Duration::from_secs(TIMEOUT))), + chans: std::array::from_fn(|_| RecvChan { + packets: (0..REL_BUFFER).map(|_| None).collect(), + seqnum: INIT_SEQNUM, + splits: HashMap::new(), + }), + } + } + + pub async fn run(mut self) { + use Error::*; + + loop { + tokio::select! { + _ = self.close.changed() => { + self.sender + .send_rudp_type( + PktType::Ctl, + Pkt { + unrel: true, + chan: 0, + data: Cow::Borrowed(&[CtlType::Disco as u8]), + }, + ) + .await + .ok(); + + self.output.send(Err(LocalDisco)).ok(); + break; + }, + _ = &mut self.timeout => { + self.output.send(Err(RemoteDisco(true))).ok(); + break; + }, + _ = self.cleanup.tick() => { + let timeout = Duration::from_secs(TIMEOUT); + + for chan in self.chans.iter_mut() { + chan.splits = chan + .splits + .drain_filter( + |_k, v| !matches!(v.timestamp, Some(t) if t.elapsed() < timeout), + ) + .collect(); + } + }, + _ = self.resend.tick() => { + for chan in self.sender.chans.iter() { + for (_, ack) in chan.lock().await.acks.iter() { + self.sender.send_udp(&ack.data).await.ok(); + } + } + }, + _ = self.ping.tick() => { + self.sender + .send_rudp_type( + PktType::Ctl, + Pkt { + chan: 0, + unrel: false, + data: Cow::Borrowed(&[CtlType::Ping as u8]), + }, + ) + .await + .ok(); + } + pkt = self.input.recv() => { + if let Err(e) = self.handle_pkt(pkt).await { + self.output.send(Err(e)).ok(); + } + } + } + } + } + + async fn handle_pkt(&mut self, pkt: io::Result>) -> Result<()> { + use Error::*; + + let mut cursor = io::Cursor::new(pkt?); + + self.timeout + .as_mut() + .reset(tokio::time::Instant::now() + Duration::from_secs(TIMEOUT)); + + let proto_id = cursor.read_u32::()?; + if proto_id != PROTO_ID { + return Err(InvalidProtoId(proto_id)); + } + + let _peer_id = cursor.read_u16::()?; + + let chan = cursor.read_u8()?; + if chan >= NUM_CHANS as u8 { + return Err(InvalidChannel(chan)); + } + + self.process_pkt(cursor, true, chan).await + } + + #[async_recursion] + async fn process_pkt( + &mut self, + mut cursor: io::Cursor>, + unrel: bool, + chan: u8, + ) -> Result<()> { + use Error::*; + + let ch = chan as usize; + match cursor.read_u8()?.try_into()? { + PktType::Ctl => match cursor.read_u8()?.try_into()? { + CtlType::Ack => { + let seqnum = cursor.read_u16::()?; + if let Some(ack) = self.sender.chans[ch].lock().await.acks.remove(&seqnum) { + ack.tx.send(true).ok(); + } + } + CtlType::SetPeerID => { + let mut id = self.sender.remote_id.write().await; + + if *id != PeerID::Nil as u16 { + return Err(PeerIDAlreadySet); + } + + *id = cursor.read_u16::()?; + } + CtlType::Ping => {} + CtlType::Disco => { + return Err(RemoteDisco(false)); + } + }, + PktType::Orig => { + self.output + .send(Ok(Pkt { + chan, + unrel, + data: Cow::Owned(cursor.remaining_slice().into()), + })) + .ok(); + } + PktType::Split => { + let seqnum = cursor.read_u16::()?; + let chunk_count = cursor.read_u16::()? as usize; + let chunk_index = cursor.read_u16::()? as usize; + + let mut split = self.chans[ch] + .splits + .entry(seqnum) + .or_insert_with(|| Split { + got: 0, + chunks: (0..chunk_count).map(|_| None).collect(), + timestamp: None, + }); + + if split.chunks.len() != chunk_count { + return Err(InvalidChunkCount(split.chunks.len(), chunk_count)); + } + + if split + .chunks + .get_mut(chunk_index) + .ok_or(InvalidChunkIndex(chunk_index, chunk_count))? + .replace(cursor.remaining_slice().into()) + .is_none() + { + split.got += 1; + } + + split.timestamp = if unrel { Some(Instant::now()) } else { None }; + + if split.got == chunk_count { + let split = self.chans[ch].splits.remove(&seqnum).unwrap(); + + self.output + .send(Ok(Pkt { + chan, + unrel, + data: split + .chunks + .into_iter() + .map(|x| x.unwrap()) + .reduce(|mut a, mut b| { + a.append(&mut b); + a + }) + .unwrap_or_default() + .into(), + })) + .ok(); + } + } + PktType::Rel => { + let seqnum = cursor.read_u16::()?; + self.chans[ch].packets[to_seqnum(seqnum)].replace(cursor.remaining_slice().into()); + + let mut ack_data = Vec::with_capacity(3); + ack_data.write_u8(CtlType::Ack as u8)?; + ack_data.write_u16::(seqnum)?; + + self.sender + .send_rudp_type( + PktType::Ctl, + Pkt { + chan, + unrel: true, + data: ack_data.into(), + }, + ) + .await?; + + let next_pkt = |chan: &mut RecvChan| chan.packets[to_seqnum(chan.seqnum)].take(); + while let Some(pkt) = next_pkt(&mut self.chans[ch]) { + if let Err(e) = self.process_pkt(io::Cursor::new(pkt), false, chan).await { + self.output.send(Err(e)).ok(); + } + + self.chans[ch].seqnum = self.chans[ch].seqnum.overflowing_add(1).0; + } + } + } + + Ok(()) + } +} -- cgit v1.2.3