diff options
| author | mat <27899617+mat-1@users.noreply.github.com> | 2022-10-07 20:12:36 -0500 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-10-07 20:12:36 -0500 |
| commit | bc3aa9467ae1e2d0ea1727093af9b0af14965e69 (patch) | |
| tree | 8db3b735daed484507129eb0683db88ddec14210 /azalea-protocol/src/read.rs | |
| parent | 695efef66fdf1e08f0cb6d8783c085875100fa2d (diff) | |
| download | azalea-drasl-bc3aa9467ae1e2d0ea1727093af9b0af14965e69.tar.xz | |
Replace impl Read with Cursor<&[u8]> (#26)
* Start getting rid of Cursor
* try to make the tests pass and fail
* make the tests pass
* remove unused uses
* fix clippy warnings
* fix potential OOM exploits
* fix OOM in az-nbt
* fix nbt benchmark
* fix a test
* start replacing it with Cursor<Vec<u8>>
* wip
* fix all the issues
* fix all tests
* fix nbt benchmark
* fix warnings
Diffstat (limited to 'azalea-protocol/src/read.rs')
| -rw-r--r--[-rwxr-xr-x] | azalea-protocol/src/read.rs | 98 |
1 files changed, 75 insertions, 23 deletions
diff --git a/azalea-protocol/src/read.rs b/azalea-protocol/src/read.rs index 8a2aaf7d..eceede9d 100755..100644 --- a/azalea-protocol/src/read.rs +++ b/azalea-protocol/src/read.rs @@ -1,9 +1,12 @@ use crate::packets::ProtocolPacket; +use azalea_buf::BufReadError; use azalea_buf::McBufVarReadable; -use azalea_buf::{read_varint_async, BufReadError}; use azalea_crypto::Aes128CfbDec; +use bytes::Buf; +use bytes::BytesMut; use flate2::read::ZlibDecoder; use log::{log_enabled, trace}; +use std::io::Cursor; use std::{ cell::Cell, io::Read, @@ -52,34 +55,82 @@ pub enum FrameSplitterError { source: std::io::Error, }, #[error("Packet is longer than {max} bytes (is {size})")] - BadLength { max: u32, size: u32 }, + BadLength { max: usize, size: usize }, + #[error("Connection reset by peer")] + ConnectionReset, + #[error("Connection closed")] + ConnectionClosed, } -async fn frame_splitter<R: ?Sized>(mut stream: &mut R) -> Result<Vec<u8>, FrameSplitterError> -where - R: AsyncRead + std::marker::Unpin + std::marker::Send, -{ +/// Read a length, then read that amount of bytes from BytesMut. If there's not +/// enough data, return None +fn parse_frame(buffer: &mut BytesMut) -> Result<BytesMut, FrameSplitterError> { + // copy the buffer first and read from the copy, then once we make sure + // the packet is all good we read it fully + let mut buffer_copy = Cursor::new(&buffer[..]); // Packet Length - let length = read_varint_async(&mut stream).await? as u32; + let length = match u32::var_read_from(&mut buffer_copy) { + Ok(length) => length as usize, + Err(err) => match err { + BufReadError::Io(io_err) => return Err(FrameSplitterError::Io { source: io_err }), + _ => return Err(err.into()), + }, + }; - // TODO: read individual tcp packets so we don't need this - // https://github.com/tokio-rs/tokio/blob/master/examples/print_each_packet.rs - let max_length: u32 = 2u32.pow(20u32); // 1mb, arbitrary - if length > max_length { - // minecraft *probably* won't send packets bigger than this + if length > buffer_copy.get_ref().len() { return Err(FrameSplitterError::BadLength { - max: max_length, + max: buffer_copy.get_ref().len(), size: length, }); } - let mut buf = vec![0; length as usize]; - stream.read_exact(&mut buf).await?; + // we read from the copy and we know it's legit, so we can take those bytes + // from the real buffer now + + // the length of the varint that says the length of the whole packet + let varint_length = buffer.len() - buffer_copy.remaining(); + let _ = buffer.split_to(varint_length); + let data = buffer.split_to(length); - Ok(buf) + Ok(data) +} + +async fn frame_splitter<'a, R: ?Sized + Sized>( + stream: &mut R, + buffer: &'a mut BytesMut, +) -> Result<Vec<u8>, FrameSplitterError> +where + R: AsyncRead + std::marker::Unpin + std::marker::Send, +{ + // https://tokio.rs/tokio/tutorial/framing + loop { + let read_frame = parse_frame(buffer); + match read_frame { + Ok(frame) => return Ok(frame.to_vec()), + Err(err) => match err { + FrameSplitterError::BadLength { .. } | FrameSplitterError::Io { .. } => { + // we probably just haven't read enough yet + } + _ => return Err(err), + }, + } + + let read_buf: usize = AsyncReadExt::read_buf(stream, buffer).await?; + if 0 == read_buf { + // The remote closed the connection. For this to be + // a clean shutdown, there should be no data in the + // read buffer. If there is, this means that the + // peer closed the socket while sending a frame. + if buffer.as_ref().is_empty() { + return Err(FrameSplitterError::ConnectionClosed); + } else { + return Err(FrameSplitterError::ConnectionReset); + } + } + } } -fn packet_decoder<P: ProtocolPacket>(stream: &mut impl Read) -> Result<P, ReadPacketError> { +fn packet_decoder<P: ProtocolPacket>(stream: &mut Cursor<&[u8]>) -> Result<P, ReadPacketError> { // Packet ID let packet_id = u32::var_read_from(stream).map_err(|e| ReadPacketError::ReadPacketId { source: e })?; @@ -112,7 +163,7 @@ pub enum DecompressionError { } fn compression_decoder( - stream: &mut impl Read, + stream: &mut Cursor<&[u8]>, compression_threshold: u32, ) -> Result<Vec<u8>, DecompressionError> { // Data Length @@ -120,7 +171,7 @@ fn compression_decoder( if n == 0 { // no data size, no compression let mut buf = vec![]; - stream.read_to_end(&mut buf)?; + std::io::Read::read_to_end(stream, &mut buf)?; return Ok(buf); } @@ -183,6 +234,7 @@ where pub async fn read_packet<'a, P: ProtocolPacket, R>( stream: &'a mut R, + buffer: &mut BytesMut, compression_threshold: Option<u32>, cipher: &mut Option<Aes128CfbDec>, ) -> Result<P, ReadPacketError> @@ -195,10 +247,10 @@ where stream: &mut Pin::new(stream), }; - let mut buf = frame_splitter(&mut encrypted_stream).await?; + let mut buf = frame_splitter(&mut encrypted_stream, buffer).await?; if let Some(compression_threshold) = compression_threshold { - buf = compression_decoder(&mut buf.as_slice(), compression_threshold)?; + buf = compression_decoder(&mut Cursor::new(&buf[..]), compression_threshold)?; } if log_enabled!(log::Level::Trace) { @@ -213,7 +265,7 @@ where trace!("Reading packet with bytes: {buf_string}"); } - let packet = packet_decoder(&mut buf.as_slice())?; + let packet = packet_decoder(&mut Cursor::new(&buf[..]))?; Ok(packet) } @@ -226,7 +278,7 @@ mod tests { #[tokio::test] async fn test_read_packet() { - let mut buf = Cursor::new(vec![ + let mut buf: Cursor<&[u8]> = Cursor::new(&[ 51, 0, 12, 177, 250, 155, 132, 106, 60, 218, 161, 217, 90, 157, 105, 57, 206, 20, 0, 5, 104, 101, 108, 108, 111, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 116, 123, 34, 101, 120, 116, 114, 97, 34, 58, 91, 123, 34, 99, 111, 108, 111, 114, 34, 58, |
