diff options
Diffstat (limited to 'src/recv.rs')
-rw-r--r-- | src/recv.rs | 293 |
1 files changed, 148 insertions, 145 deletions
diff --git a/src/recv.rs b/src/recv.rs index fd6f299..34e273c 100644 --- a/src/recv.rs +++ b/src/recv.rs @@ -3,14 +3,16 @@ use async_recursion::async_recursion; use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; use std::{ borrow::Cow, - cell::OnceCell, - collections::HashMap, + collections::{HashMap, VecDeque}, io, pin::Pin, sync::Arc, time::{Duration, Instant}, }; -use tokio::sync::{mpsc, watch, Mutex}; +use tokio::{ + sync::watch, + time::{interval, sleep, Interval, Sleep}, +}; fn to_seqnum(seqnum: u16) -> usize { (seqnum as usize) & (REL_BUFFER - 1) @@ -21,7 +23,7 @@ type Result<T> = std::result::Result<T, Error>; #[derive(Debug)] struct Split { timestamp: Option<Instant>, - chunks: Vec<OnceCell<Vec<u8>>>, + chunks: Vec<Option<Vec<u8>>>, got: usize, } @@ -29,61 +31,82 @@ struct RecvChan { packets: Vec<Option<Vec<u8>>>, // char ** 😛 splits: HashMap<u16, Split>, seqnum: u16, - num: u8, } -pub(crate) struct RecvWorker<R: UdpReceiver, S: UdpSender> { - share: Arc<RudpShare<S>>, +pub struct RudpReceiver<P: UdpPeer> { + pub(crate) share: Arc<RudpShare<P>>, + chans: [RecvChan; NUM_CHANS], + udp: P::Receiver, close: watch::Receiver<bool>, - chans: Arc<Vec<Mutex<RecvChan>>>, - pkt_tx: mpsc::UnboundedSender<InPkt>, - udp_rx: R, + closed: bool, + resend: Interval, + ping: Interval, + cleanup: Interval, + timeout: Pin<Box<Sleep>>, + queue: VecDeque<Result<Pkt<'static>>>, } -impl<R: UdpReceiver, S: UdpSender> RecvWorker<R, S> { - pub fn new( - udp_rx: R, - share: Arc<RudpShare<S>>, +impl<P: UdpPeer> RudpReceiver<P> { + pub(crate) fn new( + udp: P::Receiver, + share: Arc<RudpShare<P>>, close: watch::Receiver<bool>, - pkt_tx: mpsc::UnboundedSender<InPkt>, ) -> Self { Self { - udp_rx, + udp, share, close, - pkt_tx, - chans: Arc::new( - (0..NUM_CHANS as u8) - .map(|num| { - Mutex::new(RecvChan { - num, - packets: (0..REL_BUFFER).map(|_| None).collect(), - seqnum: INIT_SEQNUM, - splits: HashMap::new(), - }) - }) - .collect(), - ), + closed: false, + resend: interval(Duration::from_millis(500)), + ping: interval(Duration::from_secs(PING_TIMEOUT)), + cleanup: interval(Duration::from_secs(TIMEOUT)), + timeout: Box::pin(sleep(Duration::from_secs(TIMEOUT))), + chans: std::array::from_fn(|_| RecvChan { + packets: (0..REL_BUFFER).map(|_| None).collect(), + seqnum: INIT_SEQNUM, + splits: HashMap::new(), + }), + queue: VecDeque::new(), + } + } + + fn handle_err(&mut self, res: Result<()>) -> Result<()> { + use Error::*; + + match res { + Err(RemoteDisco(_)) | Err(LocalDisco) => { + self.closed = true; + res + } + Ok(_) => res, + Err(e) => { + self.queue.push_back(Err(e)); + Ok(()) + } } } - pub async fn run(&self) { + pub async fn recv(&mut self) -> Option<Result<Pkt<'static>>> { use Error::*; - let cleanup_chans = Arc::clone(&self.chans); - let mut cleanup_close = self.close.clone(); - self.share - .tasks - .lock() - .await - /*.build_task() - .name("cleanup_splits")*/ - .spawn(async move { - let timeout = Duration::from_secs(TIMEOUT); - - ticker!(timeout, cleanup_close, { - for chan_mtx in cleanup_chans.iter() { - let mut chan = chan_mtx.lock().await; + if self.closed { + return None; + } + + loop { + if let Some(x) = self.queue.pop_front() { + return Some(x); + } + + tokio::select! { + _ = self.close.changed() => { + self.closed = true; + return Some(Err(LocalDisco)); + }, + _ = self.cleanup.tick() => { + let timeout = Duration::from_secs(TIMEOUT); + + for chan in self.chans.iter_mut() { chan.splits = chan .splits .drain_filter( @@ -91,58 +114,48 @@ impl<R: UdpReceiver, S: UdpSender> RecvWorker<R, S> { ) .collect(); } - }); - }); - - let mut close = self.close.clone(); - let timeout = tokio::time::sleep(Duration::from_secs(TIMEOUT)); - tokio::pin!(timeout); - - loop { - if let Err(e) = self.handle(self.recv_pkt(&mut close, timeout.as_mut()).await) { - // TODO: figure out whether this is a good idea - if let RemoteDisco(to) = e { - self.pkt_tx.send(Err(RemoteDisco(to))).ok(); + }, + _ = self.resend.tick() => { + for chan in self.share.chans.iter() { + for (_, ack) in chan.lock().await.acks.iter() { + self.share.send_raw(&ack.data).await.ok(); // TODO: handle error (?) + } + } + }, + _ = self.ping.tick() => { + self.share + .send( + PktType::Ctl, + Pkt { + chan: 0, + unrel: false, + data: Cow::Borrowed(&[CtlType::Ping as u8]), + }, + ) + .await + .ok(); + } + _ = &mut self.timeout => { + self.closed = true; + return Some(Err(RemoteDisco(true))); + }, + pkt = self.udp.recv() => { + if let Err(e) = self.handle_pkt(pkt).await { + return Some(Err(e)); + } } - - #[allow(clippy::single_match)] - match e { - // anon5's mt notifies the peer on timeout, C++ MT does not - LocalDisco /*| RemoteDisco(true)*/ => drop( - self.share - .send( - PktType::Ctl, - Pkt { - unrel: true, - chan: 0, - data: Cow::Borrowed(&[CtlType::Disco as u8]), - }, - ) - .await - .ok(), - ), - _ => {} - } - - break; } } } - async fn recv_pkt( - &self, - close: &mut watch::Receiver<bool>, - timeout: Pin<&mut tokio::time::Sleep>, - ) -> Result<()> { + async fn handle_pkt(&mut self, pkt: io::Result<Vec<u8>>) -> Result<()> { use Error::*; - let mut cursor = io::Cursor::new(tokio::select! { - pkt = self.udp_rx.recv() => pkt?, - _ = tokio::time::sleep_until(timeout.deadline()) => return Err(RemoteDisco(true)), - _ = close.changed() => return Err(LocalDisco), - }); + let mut cursor = io::Cursor::new(pkt?); - timeout.reset(tokio::time::Instant::now() + Duration::from_secs(TIMEOUT)); + self.timeout + .as_mut() + .reset(tokio::time::Instant::now() + Duration::from_secs(TIMEOUT)); let proto_id = cursor.read_u32::<BigEndian>()?; if proto_id != PROTO_ID { @@ -151,38 +164,34 @@ impl<R: UdpReceiver, S: UdpSender> RecvWorker<R, S> { let _peer_id = cursor.read_u16::<BigEndian>()?; - let n_chan = cursor.read_u8()?; - let mut chan = self - .chans - .get(n_chan as usize) - .ok_or(InvalidChannel(n_chan))? - .lock() - .await; + let chan = cursor.read_u8()?; + if chan >= NUM_CHANS as u8 { + return Err(InvalidChannel(chan)); + } + + let res = self.process_pkt(cursor, true, chan).await; + self.handle_err(res)?; - self.process_pkt(cursor, true, &mut chan).await + Ok(()) } #[async_recursion] async fn process_pkt( - &self, + &mut self, mut cursor: io::Cursor<Vec<u8>>, unrel: bool, - chan: &mut RecvChan, + chan: u8, ) -> Result<()> { use Error::*; + let ch = chan as usize; match cursor.read_u8()?.try_into()? { PktType::Ctl => match cursor.read_u8()?.try_into()? { CtlType::Ack => { // println!("Ack"); let seqnum = cursor.read_u16::<BigEndian>()?; - if let Some(ack) = self.share.chans[chan.num as usize] - .lock() - .await - .acks - .remove(&seqnum) - { + if let Some(ack) = self.share.chans[ch].lock().await.acks.remove(&seqnum) { ack.tx.send(true).ok(); } } @@ -208,11 +217,11 @@ impl<R: UdpReceiver, S: UdpSender> RecvWorker<R, S> { PktType::Orig => { // println!("Orig"); - self.pkt_tx.send(Ok(Pkt { - chan: chan.num, + self.queue.push_back(Ok(Pkt { + chan, unrel, data: Cow::Owned(cursor.remaining_slice().into()), - }))?; + })); } PktType::Split => { // println!("Split"); @@ -221,11 +230,14 @@ impl<R: UdpReceiver, S: UdpSender> RecvWorker<R, S> { let chunk_index = cursor.read_u16::<BigEndian>()? as usize; let chunk_count = cursor.read_u16::<BigEndian>()? as usize; - let mut split = chan.splits.entry(seqnum).or_insert_with(|| Split { - got: 0, - chunks: (0..chunk_count).map(|_| OnceCell::new()).collect(), - timestamp: None, - }); + let mut split = self.chans[ch] + .splits + .entry(seqnum) + .or_insert_with(|| Split { + got: 0, + chunks: (0..chunk_count).map(|_| None).collect(), + timestamp: None, + }); if split.chunks.len() != chunk_count { return Err(InvalidChunkCount(split.chunks.len(), chunk_count)); @@ -233,10 +245,10 @@ impl<R: UdpReceiver, S: UdpSender> RecvWorker<R, S> { if split .chunks - .get(chunk_index) + .get_mut(chunk_index) .ok_or(InvalidChunkIndex(chunk_index, chunk_count))? - .set(cursor.remaining_slice().into()) - .is_ok() + .replace(cursor.remaining_slice().into()) + .is_none() { split.got += 1; } @@ -244,25 +256,29 @@ impl<R: UdpReceiver, S: UdpSender> RecvWorker<R, S> { split.timestamp = if unrel { Some(Instant::now()) } else { None }; if split.got == chunk_count { - self.pkt_tx.send(Ok(Pkt { - chan: chan.num, + let split = self.chans[ch].splits.remove(&seqnum).unwrap(); + + self.queue.push_back(Ok(Pkt { + chan, unrel, data: split .chunks - .iter() - .flat_map(|chunk| chunk.get().unwrap().iter()) - .copied() - .collect(), - }))?; - - chan.splits.remove(&seqnum); + .into_iter() + .map(|x| x.unwrap()) + .reduce(|mut a, mut b| { + a.append(&mut b); + a + }) + .unwrap_or_default() + .into(), + })); } } PktType::Rel => { // println!("Rel"); let seqnum = cursor.read_u16::<BigEndian>()?; - chan.packets[to_seqnum(seqnum)].replace(cursor.remaining_slice().into()); + self.chans[ch].packets[to_seqnum(seqnum)].replace(cursor.remaining_slice().into()); let mut ack_data = Vec::with_capacity(3); ack_data.write_u8(CtlType::Ack as u8)?; @@ -272,35 +288,22 @@ impl<R: UdpReceiver, S: UdpSender> RecvWorker<R, S> { .send( PktType::Ctl, Pkt { + chan, unrel: true, - chan: chan.num, - data: Cow::Borrowed(&ack_data), + data: ack_data.into(), }, ) .await?; - fn next_pkt(chan: &mut RecvChan) -> Option<Vec<u8>> { - chan.packets[to_seqnum(chan.seqnum)].take() - } - - while let Some(pkt) = next_pkt(chan) { - self.handle(self.process_pkt(io::Cursor::new(pkt), false, chan).await)?; - chan.seqnum = chan.seqnum.overflowing_add(1).0; + let next_pkt = |chan: &mut RecvChan| chan.packets[to_seqnum(chan.seqnum)].take(); + while let Some(pkt) = next_pkt(&mut self.chans[ch]) { + let res = self.process_pkt(io::Cursor::new(pkt), false, chan).await; + self.handle_err(res)?; + self.chans[ch].seqnum = self.chans[ch].seqnum.overflowing_add(1).0; } } } Ok(()) } - - fn handle(&self, res: Result<()>) -> Result<()> { - use Error::*; - - match res { - Ok(v) => Ok(v), - Err(RemoteDisco(to)) => Err(RemoteDisco(to)), - Err(LocalDisco) => Err(LocalDisco), - Err(e) => Ok(self.pkt_tx.send(Err(e))?), - } - } } |