From 89b1fc1d8d4bd886d80af0fe1d492cc877bce022 Mon Sep 17 00:00:00 2001 From: Lizzy Fleckenstein Date: Sat, 25 Feb 2023 18:55:53 +0100 Subject: Use channels --- src/worker.rs | 288 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 288 insertions(+) create mode 100644 src/worker.rs (limited to 'src/worker.rs') diff --git a/src/worker.rs b/src/worker.rs new file mode 100644 index 0000000..72bf2b5 --- /dev/null +++ b/src/worker.rs @@ -0,0 +1,288 @@ +use super::*; +use async_recursion::async_recursion; +use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; +use std::{ + borrow::Cow, + collections::HashMap, + io, + pin::Pin, + sync::Arc, + time::{Duration, Instant}, +}; +use tokio::{ + sync::{mpsc, watch}, + time::{interval, sleep, Interval, Sleep}, +}; + +fn to_seqnum(seqnum: u16) -> usize { + (seqnum as usize) & (REL_BUFFER - 1) +} + +type Result = std::result::Result; + +#[derive(Debug)] +struct Split { + timestamp: Option, + chunks: Vec>>, + got: usize, +} + +#[derive(Debug)] +struct RecvChan { + packets: Vec>>, // char ** 😛 + splits: HashMap, + seqnum: u16, +} + +#[derive(Debug)] +pub struct Worker { + sender: Arc>, + chans: [RecvChan; NUM_CHANS], + input: R, + close: watch::Receiver, + resend: Interval, + ping: Interval, + cleanup: Interval, + timeout: Pin>, + output: mpsc::UnboundedSender>>, +} + +impl Worker { + pub(crate) fn new( + input: R, + close: watch::Receiver, + sender: Arc>, + output: mpsc::UnboundedSender>>, + ) -> Self { + Self { + input, + sender, + close, + output, + resend: interval(Duration::from_millis(500)), + ping: interval(Duration::from_secs(PING_TIMEOUT)), + cleanup: interval(Duration::from_secs(TIMEOUT)), + timeout: Box::pin(sleep(Duration::from_secs(TIMEOUT))), + chans: std::array::from_fn(|_| RecvChan { + packets: (0..REL_BUFFER).map(|_| None).collect(), + seqnum: INIT_SEQNUM, + splits: HashMap::new(), + }), + } + } + + pub async fn run(mut self) { + use Error::*; + + loop { + tokio::select! { + _ = self.close.changed() => { + self.sender + .send_rudp_type( + PktType::Ctl, + Pkt { + unrel: true, + chan: 0, + data: Cow::Borrowed(&[CtlType::Disco as u8]), + }, + ) + .await + .ok(); + + self.output.send(Err(LocalDisco)).ok(); + break; + }, + _ = &mut self.timeout => { + self.output.send(Err(RemoteDisco(true))).ok(); + break; + }, + _ = self.cleanup.tick() => { + let timeout = Duration::from_secs(TIMEOUT); + + for chan in self.chans.iter_mut() { + chan.splits = chan + .splits + .drain_filter( + |_k, v| !matches!(v.timestamp, Some(t) if t.elapsed() < timeout), + ) + .collect(); + } + }, + _ = self.resend.tick() => { + for chan in self.sender.chans.iter() { + for (_, ack) in chan.lock().await.acks.iter() { + self.sender.send_udp(&ack.data).await.ok(); + } + } + }, + _ = self.ping.tick() => { + self.sender + .send_rudp_type( + PktType::Ctl, + Pkt { + chan: 0, + unrel: false, + data: Cow::Borrowed(&[CtlType::Ping as u8]), + }, + ) + .await + .ok(); + } + pkt = self.input.recv() => { + if let Err(e) = self.handle_pkt(pkt).await { + self.output.send(Err(e)).ok(); + } + } + } + } + } + + async fn handle_pkt(&mut self, pkt: io::Result>) -> Result<()> { + use Error::*; + + let mut cursor = io::Cursor::new(pkt?); + + self.timeout + .as_mut() + .reset(tokio::time::Instant::now() + Duration::from_secs(TIMEOUT)); + + let proto_id = cursor.read_u32::()?; + if proto_id != PROTO_ID { + return Err(InvalidProtoId(proto_id)); + } + + let _peer_id = cursor.read_u16::()?; + + let chan = cursor.read_u8()?; + if chan >= NUM_CHANS as u8 { + return Err(InvalidChannel(chan)); + } + + self.process_pkt(cursor, true, chan).await + } + + #[async_recursion] + async fn process_pkt( + &mut self, + mut cursor: io::Cursor>, + unrel: bool, + chan: u8, + ) -> Result<()> { + use Error::*; + + let ch = chan as usize; + match cursor.read_u8()?.try_into()? { + PktType::Ctl => match cursor.read_u8()?.try_into()? { + CtlType::Ack => { + let seqnum = cursor.read_u16::()?; + if let Some(ack) = self.sender.chans[ch].lock().await.acks.remove(&seqnum) { + ack.tx.send(true).ok(); + } + } + CtlType::SetPeerID => { + let mut id = self.sender.remote_id.write().await; + + if *id != PeerID::Nil as u16 { + return Err(PeerIDAlreadySet); + } + + *id = cursor.read_u16::()?; + } + CtlType::Ping => {} + CtlType::Disco => { + return Err(RemoteDisco(false)); + } + }, + PktType::Orig => { + self.output + .send(Ok(Pkt { + chan, + unrel, + data: Cow::Owned(cursor.remaining_slice().into()), + })) + .ok(); + } + PktType::Split => { + let seqnum = cursor.read_u16::()?; + let chunk_count = cursor.read_u16::()? as usize; + let chunk_index = cursor.read_u16::()? as usize; + + let mut split = self.chans[ch] + .splits + .entry(seqnum) + .or_insert_with(|| Split { + got: 0, + chunks: (0..chunk_count).map(|_| None).collect(), + timestamp: None, + }); + + if split.chunks.len() != chunk_count { + return Err(InvalidChunkCount(split.chunks.len(), chunk_count)); + } + + if split + .chunks + .get_mut(chunk_index) + .ok_or(InvalidChunkIndex(chunk_index, chunk_count))? + .replace(cursor.remaining_slice().into()) + .is_none() + { + split.got += 1; + } + + split.timestamp = if unrel { Some(Instant::now()) } else { None }; + + if split.got == chunk_count { + let split = self.chans[ch].splits.remove(&seqnum).unwrap(); + + self.output + .send(Ok(Pkt { + chan, + unrel, + data: split + .chunks + .into_iter() + .map(|x| x.unwrap()) + .reduce(|mut a, mut b| { + a.append(&mut b); + a + }) + .unwrap_or_default() + .into(), + })) + .ok(); + } + } + PktType::Rel => { + let seqnum = cursor.read_u16::()?; + self.chans[ch].packets[to_seqnum(seqnum)].replace(cursor.remaining_slice().into()); + + let mut ack_data = Vec::with_capacity(3); + ack_data.write_u8(CtlType::Ack as u8)?; + ack_data.write_u16::(seqnum)?; + + self.sender + .send_rudp_type( + PktType::Ctl, + Pkt { + chan, + unrel: true, + data: ack_data.into(), + }, + ) + .await?; + + let next_pkt = |chan: &mut RecvChan| chan.packets[to_seqnum(chan.seqnum)].take(); + while let Some(pkt) = next_pkt(&mut self.chans[ch]) { + if let Err(e) = self.process_pkt(io::Cursor::new(pkt), false, chan).await { + self.output.send(Err(e)).ok(); + } + + self.chans[ch].seqnum = self.chans[ch].seqnum.overflowing_add(1).0; + } + } + } + + Ok(()) + } +} -- cgit v1.2.3