diff options
Diffstat (limited to 'src/recv_worker.rs')
-rw-r--r-- | src/recv_worker.rs | 113 |
1 files changed, 90 insertions, 23 deletions
diff --git a/src/recv_worker.rs b/src/recv_worker.rs index 5a156eb..f83e8ef 100644 --- a/src/recv_worker.rs +++ b/src/recv_worker.rs @@ -1,7 +1,8 @@ use crate::{error::Error, *}; +use async_recursion::async_recursion; use byteorder::{BigEndian, ReadBytesExt}; use std::{ - cell::Cell, + cell::{Cell, OnceCell}, collections::HashMap, io, sync::{Arc, Weak}, @@ -13,14 +14,16 @@ fn to_seqnum(seqnum: u16) -> usize { (seqnum as usize) & (REL_BUFFER - 1) } -type Result = std::result::Result<(), Error>; +type Result<T> = std::result::Result<T, Error>; struct Split { - timestamp: time::Instant, + timestamp: Option<time::Instant>, + chunks: Vec<OnceCell<Vec<u8>>>, + got: usize, } struct Chan { - packets: Vec<Cell<Option<Vec<u8>>>>, // in the good old days this used to be called char ** + packets: Vec<Cell<Option<Vec<u8>>>>, // char ** 😛 splits: HashMap<u16, Split>, seqnum: u16, num: u8, @@ -65,7 +68,9 @@ impl<R: UdpReceiver, S: UdpSender> RecvWorker<R, S> { let mut ch = chan.lock().await; ch.splits = ch .splits - .drain_filter(|_k, v| v.timestamp.elapsed() < timeout) + .drain_filter( + |_k, v| !matches!(v.timestamp, Some(t) if t.elapsed() < timeout), + ) .collect(); } @@ -93,7 +98,7 @@ impl<R: UdpReceiver, S: UdpSender> RecvWorker<R, S> { } } - async fn recv_pkt(&self) -> Result { + async fn recv_pkt(&self) -> Result<()> { use Error::*; // todo: reset timeout @@ -101,10 +106,10 @@ impl<R: UdpReceiver, S: UdpSender> RecvWorker<R, S> { let proto_id = cursor.read_u32::<BigEndian>()?; if proto_id != PROTO_ID { - do yeet InvalidProtoId(proto_id); + return Err(InvalidProtoId(proto_id)); } - let peer_id = cursor.read_u16::<BigEndian>()?; + let _peer_id = cursor.read_u16::<BigEndian>()?; let n_chan = cursor.read_u8()?; let mut chan = self @@ -114,40 +119,102 @@ impl<R: UdpReceiver, S: UdpSender> RecvWorker<R, S> { .lock() .await; - self.process_pkt(cursor, &mut chan) + self.process_pkt(cursor, true, &mut chan).await } - fn process_pkt(&self, mut cursor: io::Cursor<Vec<u8>>, chan: &mut Chan) -> Result { - use CtlType::*; + #[async_recursion] + async fn process_pkt( + &self, + mut cursor: io::Cursor<Vec<u8>>, + unrel: bool, + chan: &mut Chan, + ) -> Result<()> { use Error::*; - use PktType::*; match cursor.read_u8()?.try_into()? { - Ctl => match cursor.read_u8()?.try_into()? { - Disco => return Err(RemoteDisco), - _ => {} + PktType::Ctl => match cursor.read_u8()?.try_into()? { + CtlType::Ack => { /* TODO */ } + CtlType::SetPeerID => { + let mut id = self.share.remote_id.write().await; + + if *id != PeerID::Nil as u16 { + return Err(PeerIDAlreadySet); + } + + *id = cursor.read_u16::<BigEndian>()?; + } + CtlType::Ping => {} + CtlType::Disco => return Err(RemoteDisco), }, - Orig => { + PktType::Orig => { println!("Orig"); self.pkt_tx.send(Ok(Pkt { chan: chan.num, - unrel: true, + unrel, data: cursor.remaining_slice().into(), }))?; } - Split => { + PktType::Split => { println!("Split"); - dbg!(cursor.remaining_slice()); + + let seqnum = cursor.read_u16::<BigEndian>()?; + let chunk_index = cursor.read_u16::<BigEndian>()? as usize; + let chunk_count = cursor.read_u16::<BigEndian>()? as usize; + + let mut split = chan.splits.entry(seqnum).or_insert_with(|| Split { + got: 0, + chunks: (0..chunk_count).map(|_| OnceCell::new()).collect(), + timestamp: None, + }); + + if split.chunks.len() != chunk_count { + return Err(InvalidChunkCount(split.chunks.len(), chunk_count)); + } + + if split + .chunks + .get(chunk_index) + .ok_or(InvalidChunkIndex(chunk_index, chunk_count))? + .set(cursor.remaining_slice().into()) + .is_ok() + { + split.got += 1; + } + + split.timestamp = if unrel { + Some(time::Instant::now()) + } else { + None + }; + + if split.got == chunk_count { + self.pkt_tx.send(Ok(Pkt { + chan: chan.num, + unrel, + data: split + .chunks + .iter() + .flat_map(|chunk| chunk.get().unwrap().iter()) + .copied() + .collect(), + }))?; + + chan.splits.remove(&seqnum); + } } - Rel => { + PktType::Rel => { println!("Rel"); let seqnum = cursor.read_u16::<BigEndian>()?; chan.packets[to_seqnum(seqnum)].set(Some(cursor.remaining_slice().into())); - while let Some(pkt) = chan.packets[to_seqnum(chan.seqnum)].take() { - self.handle(self.process_pkt(io::Cursor::new(pkt), chan))?; + fn next_pkt(chan: &mut Chan) -> Option<Vec<u8>> { + chan.packets[to_seqnum(chan.seqnum)].take() + } + + while let Some(pkt) = next_pkt(chan) { + self.handle(self.process_pkt(io::Cursor::new(pkt), false, chan).await)?; chan.seqnum = chan.seqnum.overflowing_add(1).0; } } @@ -156,7 +223,7 @@ impl<R: UdpReceiver, S: UdpSender> RecvWorker<R, S> { Ok(()) } - fn handle(&self, res: Result) -> Result { + fn handle(&self, res: Result<()>) -> Result<()> { use Error::*; match res { |