From bc3aa9467ae1e2d0ea1727093af9b0af14965e69 Mon Sep 17 00:00:00 2001 From: mat <27899617+mat-1@users.noreply.github.com> Date: Fri, 7 Oct 2022 20:12:36 -0500 Subject: Replace impl Read with Cursor<&[u8]> (#26) * Start getting rid of Cursor * try to make the tests pass and fail * make the tests pass * remove unused uses * fix clippy warnings * fix potential OOM exploits * fix OOM in az-nbt * fix nbt benchmark * fix a test * start replacing it with Cursor> * wip * fix all the issues * fix all tests * fix nbt benchmark * fix warnings --- azalea-protocol/src/read.rs | 98 ++++++++++++++++++++++++++++++++++----------- 1 file changed, 75 insertions(+), 23 deletions(-) mode change 100755 => 100644 azalea-protocol/src/read.rs (limited to 'azalea-protocol/src/read.rs') diff --git a/azalea-protocol/src/read.rs b/azalea-protocol/src/read.rs old mode 100755 new mode 100644 index 8a2aaf7d..eceede9d --- a/azalea-protocol/src/read.rs +++ b/azalea-protocol/src/read.rs @@ -1,9 +1,12 @@ use crate::packets::ProtocolPacket; +use azalea_buf::BufReadError; use azalea_buf::McBufVarReadable; -use azalea_buf::{read_varint_async, BufReadError}; use azalea_crypto::Aes128CfbDec; +use bytes::Buf; +use bytes::BytesMut; use flate2::read::ZlibDecoder; use log::{log_enabled, trace}; +use std::io::Cursor; use std::{ cell::Cell, io::Read, @@ -52,34 +55,82 @@ pub enum FrameSplitterError { source: std::io::Error, }, #[error("Packet is longer than {max} bytes (is {size})")] - BadLength { max: u32, size: u32 }, + BadLength { max: usize, size: usize }, + #[error("Connection reset by peer")] + ConnectionReset, + #[error("Connection closed")] + ConnectionClosed, } -async fn frame_splitter(mut stream: &mut R) -> Result, FrameSplitterError> -where - R: AsyncRead + std::marker::Unpin + std::marker::Send, -{ +/// Read a length, then read that amount of bytes from BytesMut. If there's not +/// enough data, return None +fn parse_frame(buffer: &mut BytesMut) -> Result { + // copy the buffer first and read from the copy, then once we make sure + // the packet is all good we read it fully + let mut buffer_copy = Cursor::new(&buffer[..]); // Packet Length - let length = read_varint_async(&mut stream).await? as u32; + let length = match u32::var_read_from(&mut buffer_copy) { + Ok(length) => length as usize, + Err(err) => match err { + BufReadError::Io(io_err) => return Err(FrameSplitterError::Io { source: io_err }), + _ => return Err(err.into()), + }, + }; - // TODO: read individual tcp packets so we don't need this - // https://github.com/tokio-rs/tokio/blob/master/examples/print_each_packet.rs - let max_length: u32 = 2u32.pow(20u32); // 1mb, arbitrary - if length > max_length { - // minecraft *probably* won't send packets bigger than this + if length > buffer_copy.get_ref().len() { return Err(FrameSplitterError::BadLength { - max: max_length, + max: buffer_copy.get_ref().len(), size: length, }); } - let mut buf = vec![0; length as usize]; - stream.read_exact(&mut buf).await?; + // we read from the copy and we know it's legit, so we can take those bytes + // from the real buffer now + + // the length of the varint that says the length of the whole packet + let varint_length = buffer.len() - buffer_copy.remaining(); + let _ = buffer.split_to(varint_length); + let data = buffer.split_to(length); - Ok(buf) + Ok(data) +} + +async fn frame_splitter<'a, R: ?Sized + Sized>( + stream: &mut R, + buffer: &'a mut BytesMut, +) -> Result, FrameSplitterError> +where + R: AsyncRead + std::marker::Unpin + std::marker::Send, +{ + // https://tokio.rs/tokio/tutorial/framing + loop { + let read_frame = parse_frame(buffer); + match read_frame { + Ok(frame) => return Ok(frame.to_vec()), + Err(err) => match err { + FrameSplitterError::BadLength { .. } | FrameSplitterError::Io { .. } => { + // we probably just haven't read enough yet + } + _ => return Err(err), + }, + } + + let read_buf: usize = AsyncReadExt::read_buf(stream, buffer).await?; + if 0 == read_buf { + // The remote closed the connection. For this to be + // a clean shutdown, there should be no data in the + // read buffer. If there is, this means that the + // peer closed the socket while sending a frame. + if buffer.as_ref().is_empty() { + return Err(FrameSplitterError::ConnectionClosed); + } else { + return Err(FrameSplitterError::ConnectionReset); + } + } + } } -fn packet_decoder(stream: &mut impl Read) -> Result { +fn packet_decoder(stream: &mut Cursor<&[u8]>) -> Result { // Packet ID let packet_id = u32::var_read_from(stream).map_err(|e| ReadPacketError::ReadPacketId { source: e })?; @@ -112,7 +163,7 @@ pub enum DecompressionError { } fn compression_decoder( - stream: &mut impl Read, + stream: &mut Cursor<&[u8]>, compression_threshold: u32, ) -> Result, DecompressionError> { // Data Length @@ -120,7 +171,7 @@ fn compression_decoder( if n == 0 { // no data size, no compression let mut buf = vec![]; - stream.read_to_end(&mut buf)?; + std::io::Read::read_to_end(stream, &mut buf)?; return Ok(buf); } @@ -183,6 +234,7 @@ where pub async fn read_packet<'a, P: ProtocolPacket, R>( stream: &'a mut R, + buffer: &mut BytesMut, compression_threshold: Option, cipher: &mut Option, ) -> Result @@ -195,10 +247,10 @@ where stream: &mut Pin::new(stream), }; - let mut buf = frame_splitter(&mut encrypted_stream).await?; + let mut buf = frame_splitter(&mut encrypted_stream, buffer).await?; if let Some(compression_threshold) = compression_threshold { - buf = compression_decoder(&mut buf.as_slice(), compression_threshold)?; + buf = compression_decoder(&mut Cursor::new(&buf[..]), compression_threshold)?; } if log_enabled!(log::Level::Trace) { @@ -213,7 +265,7 @@ where trace!("Reading packet with bytes: {buf_string}"); } - let packet = packet_decoder(&mut buf.as_slice())?; + let packet = packet_decoder(&mut Cursor::new(&buf[..]))?; Ok(packet) } @@ -226,7 +278,7 @@ mod tests { #[tokio::test] async fn test_read_packet() { - let mut buf = Cursor::new(vec![ + let mut buf: Cursor<&[u8]> = Cursor::new(&[ 51, 0, 12, 177, 250, 155, 132, 106, 60, 218, 161, 217, 90, 157, 105, 57, 206, 20, 0, 5, 104, 101, 108, 108, 111, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 116, 123, 34, 101, 120, 116, 114, 97, 34, 58, 91, 123, 34, 99, 111, 108, 111, 114, 34, 58, -- cgit v1.2.3