use super::*; use async_recursion::async_recursion; use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; use std::{ borrow::Cow, collections::{HashMap, VecDeque}, io, pin::Pin, sync::Arc, time::{Duration, Instant}, }; use tokio::{ sync::watch, time::{interval, sleep, Interval, Sleep}, }; fn to_seqnum(seqnum: u16) -> usize { (seqnum as usize) & (REL_BUFFER - 1) } type Result = std::result::Result; #[derive(Debug)] struct Split { timestamp: Option, chunks: Vec>>, got: usize, } struct RecvChan { packets: Vec>>, // char ** 😛 splits: HashMap, seqnum: u16, } pub struct RudpReceiver { pub(crate) share: Arc>, chans: [RecvChan; NUM_CHANS], udp: P::Receiver, close: watch::Receiver, closed: bool, resend: Interval, ping: Interval, cleanup: Interval, timeout: Pin>, queue: VecDeque>>, } impl RudpReceiver

{ pub(crate) fn new( udp: P::Receiver, share: Arc>, close: watch::Receiver, ) -> Self { Self { udp, share, close, closed: false, resend: interval(Duration::from_millis(500)), ping: interval(Duration::from_secs(PING_TIMEOUT)), cleanup: interval(Duration::from_secs(TIMEOUT)), timeout: Box::pin(sleep(Duration::from_secs(TIMEOUT))), chans: std::array::from_fn(|_| RecvChan { packets: (0..REL_BUFFER).map(|_| None).collect(), seqnum: INIT_SEQNUM, splits: HashMap::new(), }), queue: VecDeque::new(), } } fn handle_err(&mut self, res: Result<()>) -> Result<()> { use Error::*; match res { Err(RemoteDisco(_)) | Err(LocalDisco) => { self.closed = true; res } Ok(_) => res, Err(e) => { self.queue.push_back(Err(e)); Ok(()) } } } pub async fn recv(&mut self) -> Option>> { use Error::*; if self.closed { return None; } loop { if let Some(x) = self.queue.pop_front() { return Some(x); } tokio::select! { _ = self.close.changed() => { self.closed = true; return Some(Err(LocalDisco)); }, _ = self.cleanup.tick() => { let timeout = Duration::from_secs(TIMEOUT); for chan in self.chans.iter_mut() { chan.splits = chan .splits .drain_filter( |_k, v| !matches!(v.timestamp, Some(t) if t.elapsed() < timeout), ) .collect(); } }, _ = self.resend.tick() => { for chan in self.share.chans.iter() { for (_, ack) in chan.lock().await.acks.iter() { self.share.send_raw(&ack.data).await.ok(); // TODO: handle error (?) } } }, _ = self.ping.tick() => { self.share .send( PktType::Ctl, Pkt { chan: 0, unrel: false, data: Cow::Borrowed(&[CtlType::Ping as u8]), }, ) .await .ok(); } _ = &mut self.timeout => { self.closed = true; return Some(Err(RemoteDisco(true))); }, pkt = self.udp.recv() => { if let Err(e) = self.handle_pkt(pkt).await { return Some(Err(e)); } } } } } async fn handle_pkt(&mut self, pkt: io::Result>) -> Result<()> { use Error::*; let mut cursor = io::Cursor::new(pkt?); self.timeout .as_mut() .reset(tokio::time::Instant::now() + Duration::from_secs(TIMEOUT)); let proto_id = cursor.read_u32::()?; if proto_id != PROTO_ID { return Err(InvalidProtoId(proto_id)); } let _peer_id = cursor.read_u16::()?; let chan = cursor.read_u8()?; if chan >= NUM_CHANS as u8 { return Err(InvalidChannel(chan)); } let res = self.process_pkt(cursor, true, chan).await; self.handle_err(res)?; Ok(()) } #[async_recursion] async fn process_pkt( &mut self, mut cursor: io::Cursor>, unrel: bool, chan: u8, ) -> Result<()> { use Error::*; let ch = chan as usize; match cursor.read_u8()?.try_into()? { PktType::Ctl => match cursor.read_u8()?.try_into()? { CtlType::Ack => { // println!("Ack"); let seqnum = cursor.read_u16::()?; if let Some(ack) = self.share.chans[ch].lock().await.acks.remove(&seqnum) { ack.tx.send(true).ok(); } } CtlType::SetPeerID => { // println!("SetPeerID"); let mut id = self.share.remote_id.write().await; if *id != PeerID::Nil as u16 { return Err(PeerIDAlreadySet); } *id = cursor.read_u16::()?; } CtlType::Ping => { // println!("Ping"); } CtlType::Disco => { // println!("Disco"); return Err(RemoteDisco(false)); } }, PktType::Orig => { // println!("Orig"); self.queue.push_back(Ok(Pkt { chan, unrel, data: Cow::Owned(cursor.remaining_slice().into()), })); } PktType::Split => { // println!("Split"); let seqnum = cursor.read_u16::()?; let chunk_count = cursor.read_u16::()? as usize; let chunk_index = cursor.read_u16::()? as usize; let mut split = self.chans[ch] .splits .entry(seqnum) .or_insert_with(|| Split { got: 0, chunks: (0..chunk_count).map(|_| None).collect(), timestamp: None, }); if split.chunks.len() != chunk_count { return Err(InvalidChunkCount(split.chunks.len(), chunk_count)); } if split .chunks .get_mut(chunk_index) .ok_or(InvalidChunkIndex(chunk_index, chunk_count))? .replace(cursor.remaining_slice().into()) .is_none() { split.got += 1; } split.timestamp = if unrel { Some(Instant::now()) } else { None }; if split.got == chunk_count { let split = self.chans[ch].splits.remove(&seqnum).unwrap(); self.queue.push_back(Ok(Pkt { chan, unrel, data: split .chunks .into_iter() .map(|x| x.unwrap()) .reduce(|mut a, mut b| { a.append(&mut b); a }) .unwrap_or_default() .into(), })); } } PktType::Rel => { // println!("Rel"); let seqnum = cursor.read_u16::()?; self.chans[ch].packets[to_seqnum(seqnum)].replace(cursor.remaining_slice().into()); let mut ack_data = Vec::with_capacity(3); ack_data.write_u8(CtlType::Ack as u8)?; ack_data.write_u16::(seqnum)?; self.share .send( PktType::Ctl, Pkt { chan, unrel: true, data: ack_data.into(), }, ) .await?; let next_pkt = |chan: &mut RecvChan| chan.packets[to_seqnum(chan.seqnum)].take(); while let Some(pkt) = next_pkt(&mut self.chans[ch]) { let res = self.process_pkt(io::Cursor::new(pkt), false, chan).await; self.handle_err(res)?; self.chans[ch].seqnum = self.chans[ch].seqnum.overflowing_add(1).0; } } } Ok(()) } }