diff options
| author | mat <github@matdoes.dev> | 2022-10-07 23:56:23 -0500 |
|---|---|---|
| committer | mat <github@matdoes.dev> | 2022-10-07 23:56:23 -0500 |
| commit | 6f6289376a0d9ffe7e58506824e37f6b380961c3 (patch) | |
| tree | 97956fc560b338fbef630f0d0617a248e0e8b336 /azalea-protocol/src/read.rs | |
| parent | e9d8d0357ee63cce321e177bf19a8974699894ee (diff) | |
| download | azalea-drasl-6f6289376a0d9ffe7e58506824e37f6b380961c3.tar.xz | |
fix errors with rewritten packet reading
i forgot i never tested it before LMAO
Diffstat (limited to 'azalea-protocol/src/read.rs')
| -rw-r--r-- | azalea-protocol/src/read.rs | 144 |
1 files changed, 60 insertions, 84 deletions
diff --git a/azalea-protocol/src/read.rs b/azalea-protocol/src/read.rs index eceede9d..4c398e96 100644 --- a/azalea-protocol/src/read.rs +++ b/azalea-protocol/src/read.rs @@ -5,16 +5,15 @@ use azalea_crypto::Aes128CfbDec; use bytes::Buf; use bytes::BytesMut; use flate2::read::ZlibDecoder; +use futures::StreamExt; use log::{log_enabled, trace}; -use std::io::Cursor; use std::{ - cell::Cell, - io::Read, - pin::Pin, - task::{Context, Poll}, + fmt::Debug, + io::{Cursor, Read}, }; use thiserror::Error; use tokio::io::{AsyncRead, AsyncReadExt}; +use tokio_util::codec::{BytesCodec, FramedRead}; #[derive(Error, Debug)] pub enum ReadPacketError { @@ -28,18 +27,28 @@ pub enum ReadPacketError { UnknownPacketId { state_name: String, id: u32 }, #[error("Couldn't read packet id")] ReadPacketId { source: BufReadError }, - #[error("Couldn't decompress packet")] + #[error(transparent)] Decompress { #[from] + #[backtrace] source: DecompressionError, }, - #[error("Frame splitter error")] + #[error(transparent)] FrameSplitter { #[from] + #[backtrace] source: FrameSplitterError, }, #[error("Leftover data after reading packet {packet_name}: {data:?}")] LeftoverData { data: Vec<u8>, packet_name: String }, + #[error(transparent)] + IoError { + #[from] + #[backtrace] + source: std::io::Error, + }, + #[error("Connection closed")] + ConnectionClosed, } #[derive(Error, Debug)] @@ -52,6 +61,7 @@ pub enum FrameSplitterError { #[error("Io error")] Io { #[from] + #[backtrace] source: std::io::Error, }, #[error("Packet is longer than {max} bytes (is {size})")] @@ -77,9 +87,9 @@ fn parse_frame(buffer: &mut BytesMut) -> Result<BytesMut, FrameSplitterError> { }, }; - if length > buffer_copy.get_ref().len() { + if length > buffer_copy.remaining() { return Err(FrameSplitterError::BadLength { - max: buffer_copy.get_ref().len(), + max: buffer_copy.remaining(), size: length, }); } @@ -88,49 +98,33 @@ fn parse_frame(buffer: &mut BytesMut) -> Result<BytesMut, FrameSplitterError> { // from the real buffer now // the length of the varint that says the length of the whole packet - let varint_length = buffer.len() - buffer_copy.remaining(); - let _ = buffer.split_to(varint_length); + let varint_length = buffer.remaining() - buffer_copy.remaining(); + + buffer.advance(varint_length); let data = buffer.split_to(length); Ok(data) } -async fn frame_splitter<'a, R: ?Sized + Sized>( - stream: &mut R, - buffer: &'a mut BytesMut, -) -> Result<Vec<u8>, FrameSplitterError> -where - R: AsyncRead + std::marker::Unpin + std::marker::Send, -{ +fn frame_splitter<'a>(buffer: &'a mut BytesMut) -> Result<Option<Vec<u8>>, FrameSplitterError> { // https://tokio.rs/tokio/tutorial/framing - loop { - let read_frame = parse_frame(buffer); - match read_frame { - Ok(frame) => return Ok(frame.to_vec()), - Err(err) => match err { - FrameSplitterError::BadLength { .. } | FrameSplitterError::Io { .. } => { - // we probably just haven't read enough yet - } - _ => return Err(err), - }, - } - - let read_buf: usize = AsyncReadExt::read_buf(stream, buffer).await?; - if 0 == read_buf { - // The remote closed the connection. For this to be - // a clean shutdown, there should be no data in the - // read buffer. If there is, this means that the - // peer closed the socket while sending a frame. - if buffer.as_ref().is_empty() { - return Err(FrameSplitterError::ConnectionClosed); - } else { - return Err(FrameSplitterError::ConnectionReset); + let read_frame = parse_frame(buffer); + match read_frame { + Ok(frame) => return Ok(Some(frame.to_vec())), + Err(err) => match err { + FrameSplitterError::BadLength { .. } | FrameSplitterError::Io { .. } => { + // we probably just haven't read enough yet } - } + _ => return Err(err), + }, } + + Ok(None) } -fn packet_decoder<P: ProtocolPacket>(stream: &mut Cursor<&[u8]>) -> Result<P, ReadPacketError> { +fn packet_decoder<P: ProtocolPacket + Debug>( + stream: &mut Cursor<&[u8]>, +) -> Result<P, ReadPacketError> { // Packet ID let packet_id = u32::var_read_from(stream).map_err(|e| ReadPacketError::ReadPacketId { source: e })?; @@ -152,6 +146,7 @@ pub enum DecompressionError { #[error("Io error")] Io { #[from] + #[backtrace] source: std::io::Error, }, #[error("Badly compressed packet - size of {size} is below server threshold of {threshold}")] @@ -197,42 +192,7 @@ fn compression_decoder( Ok(decoded_buf) } -struct EncryptedStream<'a, R> -where - R: AsyncRead + std::marker::Unpin + std::marker::Send, -{ - cipher: Cell<&'a mut Option<Aes128CfbDec>>, - stream: &'a mut Pin<&'a mut R>, -} - -impl<R> AsyncRead for EncryptedStream<'_, R> -where - R: AsyncRead + Unpin + Send, -{ - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut tokio::io::ReadBuf<'_>, - ) -> Poll<std::io::Result<()>> { - // i hate this - let polled = self.as_mut().stream.as_mut().poll_read(cx, buf); - match polled { - Poll::Ready(r) => { - // if we don't check for the remaining then we decrypt big packets incorrectly - // (but only on linux and release mode for some reason LMAO) - if buf.remaining() == 0 { - if let Some(cipher) = self.as_mut().cipher.get_mut() { - azalea_crypto::decrypt_packet(cipher, buf.filled_mut()); - } - } - Poll::Ready(r) - } - Poll::Pending => Poll::Pending, - } - } -} - -pub async fn read_packet<'a, P: ProtocolPacket, R>( +pub async fn read_packet<'a, P: ProtocolPacket + Debug, R>( stream: &'a mut R, buffer: &mut BytesMut, compression_threshold: Option<u32>, @@ -241,13 +201,29 @@ pub async fn read_packet<'a, P: ProtocolPacket, R>( where R: AsyncRead + std::marker::Unpin + std::marker::Send + std::marker::Sync, { - // if we were given a cipher, decrypt the packet - let mut encrypted_stream = EncryptedStream { - cipher: Cell::new(cipher), - stream: &mut Pin::new(stream), - }; + let mut framed = FramedRead::new(stream, BytesCodec::new()); + let mut buf = loop { + if let Some(buf) = frame_splitter(buffer)? { + // we got a full packet!! + break buf; + } else { + // no full packet yet :( keep reading + }; + + // if we were given a cipher, decrypt the packet + if let Some(message) = framed.next().await { + let mut bytes = message.unwrap(); + println!("bytes: {:?}", bytes.len()); - let mut buf = frame_splitter(&mut encrypted_stream, buffer).await?; + if let Some(cipher) = cipher { + azalea_crypto::decrypt_packet(cipher, &mut bytes); + } + + buffer.extend_from_slice(&bytes); + } else { + return Err(ReadPacketError::ConnectionClosed); + }; + }; if let Some(compression_threshold) = compression_threshold { buf = compression_decoder(&mut Cursor::new(&buf[..]), compression_threshold)?; |
