diff options
author | Lizzy Fleckenstein <eliasfleckenstein@web.de> | 2023-01-06 17:45:16 +0100 |
---|---|---|
committer | Lizzy Fleckenstein <eliasfleckenstein@web.de> | 2023-01-06 17:45:16 +0100 |
commit | fd23bb3a2b57d43c115005dcd70f1e18bb005032 (patch) | |
tree | 7fa77c2db1faa55685e24a180bbd419a78d7be53 /src/recv.rs | |
parent | d3b8019227137853406891e2aa84e0c8a9e3c31c (diff) | |
download | mt_rudp-fd23bb3a2b57d43c115005dcd70f1e18bb005032.tar.xz |
clean shutdown; send reliables
Diffstat (limited to 'src/recv.rs')
-rw-r--r-- | src/recv.rs | 283 |
1 files changed, 283 insertions, 0 deletions
diff --git a/src/recv.rs b/src/recv.rs new file mode 100644 index 0000000..15811f2 --- /dev/null +++ b/src/recv.rs @@ -0,0 +1,283 @@ +use crate::{error::Error, *}; +use async_recursion::async_recursion; +use byteorder::{BigEndian, ReadBytesExt}; +use std::{ + cell::{Cell, OnceCell}, + collections::HashMap, + io, + sync::Arc, + time::{Duration, Instant}, +}; +use tokio::sync::{mpsc, Mutex}; + +fn to_seqnum(seqnum: u16) -> usize { + (seqnum as usize) & (REL_BUFFER - 1) +} + +type Result<T> = std::result::Result<T, Error>; + +struct Split { + timestamp: Option<Instant>, + chunks: Vec<OnceCell<Vec<u8>>>, + got: usize, +} + +struct RecvChan { + packets: Vec<Cell<Option<Vec<u8>>>>, // char ** 😛 + splits: HashMap<u16, Split>, + seqnum: u16, + num: u8, +} + +pub struct RecvWorker<R: UdpReceiver, S: UdpSender> { + share: Arc<RudpShare<S>>, + close: watch::Receiver<bool>, + chans: Arc<Vec<Mutex<RecvChan>>>, + pkt_tx: mpsc::UnboundedSender<InPkt>, + udp_rx: R, +} + +impl<R: UdpReceiver, S: UdpSender> RecvWorker<R, S> { + pub fn new( + udp_rx: R, + share: Arc<RudpShare<S>>, + close: watch::Receiver<bool>, + pkt_tx: mpsc::UnboundedSender<InPkt>, + ) -> Self { + Self { + udp_rx, + share, + close, + pkt_tx, + chans: Arc::new( + (0..NUM_CHANS as u8) + .map(|num| { + Mutex::new(RecvChan { + num, + packets: (0..REL_BUFFER).map(|_| Cell::new(None)).collect(), + seqnum: INIT_SEQNUM, + splits: HashMap::new(), + }) + }) + .collect(), + ), + } + } + + pub async fn run(&self) { + let cleanup_chans = Arc::clone(&self.chans); + let mut cleanup_close = self.close.clone(); + self.share + .tasks + .lock() + .await + /*.build_task() + .name("cleanup_splits")*/ + .spawn(async move { + let timeout = Duration::from_secs(TIMEOUT); + + ticker!(timeout, cleanup_close, { + for chan_mtx in cleanup_chans.iter() { + let mut chan = chan_mtx.lock().await; + chan.splits = chan + .splits + .drain_filter( + |_k, v| !matches!(v.timestamp, Some(t) if t.elapsed() < timeout), + ) + .collect(); + } + }); + }); + + let mut close = self.close.clone(); + loop { + if let Err(e) = self.handle(self.recv_pkt(&mut close).await) { + if let Error::LocalDisco = e { + self.share + .send( + PktType::Ctl, + Pkt { + unrel: true, + chan: 0, + data: &[CtlType::Disco as u8], + }, + ) + .await + .ok(); + } + break; + } + } + } + + async fn recv_pkt(&self, close: &mut watch::Receiver<bool>) -> Result<()> { + use Error::*; + + // TODO: reset timeout + let mut cursor = io::Cursor::new(tokio::select! { + pkt = self.udp_rx.recv() => pkt?, + _ = close.changed() => return Err(LocalDisco), + }); + + println!("recv"); + + let proto_id = cursor.read_u32::<BigEndian>()?; + if proto_id != PROTO_ID { + return Err(InvalidProtoId(proto_id)); + } + + let _peer_id = cursor.read_u16::<BigEndian>()?; + + let n_chan = cursor.read_u8()?; + let mut chan = self + .chans + .get(n_chan as usize) + .ok_or(InvalidChannel(n_chan))? + .lock() + .await; + + self.process_pkt(cursor, true, &mut chan).await + } + + #[async_recursion] + async fn process_pkt( + &self, + mut cursor: io::Cursor<Vec<u8>>, + unrel: bool, + chan: &mut RecvChan, + ) -> Result<()> { + use Error::*; + + match cursor.read_u8()?.try_into()? { + PktType::Ctl => match cursor.read_u8()?.try_into()? { + CtlType::Ack => { + println!("Ack"); + + let seqnum = cursor.read_u16::<BigEndian>()?; + if let Some(ack) = self.share.chans[chan.num as usize] + .lock() + .await + .acks + .remove(&seqnum) + { + ack.tx.send(true).ok(); + } + } + CtlType::SetPeerID => { + println!("SetPeerID"); + + let mut id = self.share.remote_id.write().await; + + if *id != PeerID::Nil as u16 { + return Err(PeerIDAlreadySet); + } + + *id = cursor.read_u16::<BigEndian>()?; + } + CtlType::Ping => { + println!("Ping"); + } + CtlType::Disco => { + println!("Disco"); + return Err(RemoteDisco); + } + }, + PktType::Orig => { + println!("Orig"); + + self.pkt_tx.send(Ok(Pkt { + chan: chan.num, + unrel, + data: cursor.remaining_slice().into(), + }))?; + } + PktType::Split => { + println!("Split"); + + let seqnum = cursor.read_u16::<BigEndian>()?; + let chunk_index = cursor.read_u16::<BigEndian>()? as usize; + let chunk_count = cursor.read_u16::<BigEndian>()? as usize; + + let mut split = chan.splits.entry(seqnum).or_insert_with(|| Split { + got: 0, + chunks: (0..chunk_count).map(|_| OnceCell::new()).collect(), + timestamp: None, + }); + + if split.chunks.len() != chunk_count { + return Err(InvalidChunkCount(split.chunks.len(), chunk_count)); + } + + if split + .chunks + .get(chunk_index) + .ok_or(InvalidChunkIndex(chunk_index, chunk_count))? + .set(cursor.remaining_slice().into()) + .is_ok() + { + split.got += 1; + } + + split.timestamp = if unrel { Some(Instant::now()) } else { None }; + + if split.got == chunk_count { + self.pkt_tx.send(Ok(Pkt { + chan: chan.num, + unrel, + data: split + .chunks + .iter() + .flat_map(|chunk| chunk.get().unwrap().iter()) + .copied() + .collect(), + }))?; + + chan.splits.remove(&seqnum); + } + } + PktType::Rel => { + println!("Rel"); + + let seqnum = cursor.read_u16::<BigEndian>()?; + chan.packets[to_seqnum(seqnum)].set(Some(cursor.remaining_slice().into())); + + let mut ack_data = Vec::with_capacity(3); + ack_data.write_u8(CtlType::Ack as u8)?; + ack_data.write_u16::<BigEndian>(seqnum)?; + + self.share + .send( + PktType::Ctl, + Pkt { + unrel: true, + chan: chan.num, + data: &ack_data, + }, + ) + .await?; + + fn next_pkt(chan: &mut RecvChan) -> Option<Vec<u8>> { + chan.packets[to_seqnum(chan.seqnum)].take() + } + + while let Some(pkt) = next_pkt(chan) { + self.handle(self.process_pkt(io::Cursor::new(pkt), false, chan).await)?; + chan.seqnum = chan.seqnum.overflowing_add(1).0; + } + } + } + + Ok(()) + } + + fn handle(&self, res: Result<()>) -> Result<()> { + use Error::*; + + match res { + Ok(v) => Ok(v), + Err(RemoteDisco) => Err(RemoteDisco), + Err(LocalDisco) => Err(LocalDisco), + Err(e) => Ok(self.pkt_tx.send(Err(e))?), + } + } +} |