diff options
| author | mat <git@matdoes.dev> | 2024-12-25 06:16:10 +0000 |
|---|---|---|
| committer | mat <git@matdoes.dev> | 2024-12-25 06:16:10 +0000 |
| commit | 04eaa5c3d01a8f3a599a3a1abf7205eed80df4a2 (patch) | |
| tree | e37b34e8bf03045778f383f4e324414e2047ca92 /azalea-protocol/src/read.rs | |
| parent | 0ee9ed50e30222784d094e20302cadc879f2b6db (diff) | |
| download | azalea-drasl-04eaa5c3d01a8f3a599a3a1abf7205eed80df4a2.tar.xz | |
remove dependency on bytes crate for azalea-protocol and fix memory leak
Diffstat (limited to 'azalea-protocol/src/read.rs')
| -rwxr-xr-x | azalea-protocol/src/read.rs | 81 |
1 files changed, 49 insertions, 32 deletions
diff --git a/azalea-protocol/src/read.rs b/azalea-protocol/src/read.rs index 8569ca73..6f9b754a 100755 --- a/azalea-protocol/src/read.rs +++ b/azalea-protocol/src/read.rs @@ -9,13 +9,12 @@ use std::{ use azalea_buf::AzaleaReadVar; use azalea_buf::BufReadError; use azalea_crypto::Aes128CfbDec; -use bytes::Buf; -use bytes::BytesMut; use flate2::read::ZlibDecoder; use futures::StreamExt; use futures_lite::future; use thiserror::Error; use tokio::io::AsyncRead; +use tokio_util::bytes::Buf; use tokio_util::codec::{BytesCodec, FramedRead}; use tracing::trace; @@ -79,12 +78,12 @@ pub enum FrameSplitterError { ConnectionClosed, } -/// Read a length, then read that amount of bytes from `BytesMut`. If there's -/// not enough data, return None -fn parse_frame(buffer: &mut BytesMut) -> Result<BytesMut, FrameSplitterError> { +/// Read a length, then read that amount of bytes from the `Cursor<Vec<u8>>`. If +/// there's not enough data, return None +fn parse_frame(buffer: &mut Cursor<Vec<u8>>) -> Result<Box<[u8]>, FrameSplitterError> { // copy the buffer first and read from the copy, then once we make sure // the packet is all good we read it fully - let mut buffer_copy = Cursor::new(&buffer[..]); + let mut buffer_copy = Cursor::new(&buffer.get_ref()[buffer.position() as usize..]); // Packet Length let length = match u32::azalea_read_var(&mut buffer_copy) { Ok(length) => length as usize, @@ -106,18 +105,28 @@ fn parse_frame(buffer: &mut BytesMut) -> Result<BytesMut, FrameSplitterError> { // the length of the varint that says the length of the whole packet let varint_length = buffer.remaining() - buffer_copy.remaining(); + drop(buffer_copy); buffer.advance(varint_length); - let data = buffer.split_to(length); + let data = + buffer.get_ref()[buffer.position() as usize..buffer.position() as usize + length].to_vec(); + buffer.advance(length); + + if buffer.position() == buffer.get_ref().len() as u64 { + // reset the inner vec once we've reached the end of the buffer so we don't keep + // leaking memory + *buffer.get_mut() = Vec::new(); + buffer.set_position(0); + } - Ok(data) + Ok(data.into_boxed_slice()) } -fn frame_splitter(buffer: &mut BytesMut) -> Result<Option<Vec<u8>>, FrameSplitterError> { +fn frame_splitter(buffer: &mut Cursor<Vec<u8>>) -> Result<Option<Box<[u8]>>, FrameSplitterError> { // https://tokio.rs/tokio/tutorial/framing let read_frame = parse_frame(buffer); match read_frame { - Ok(frame) => return Ok(Some(frame.to_vec())), + Ok(frame) => return Ok(Some(frame)), Err(err) => match err { FrameSplitterError::BadLength { .. } | FrameSplitterError::Io { .. } => { // we probably just haven't read enough yet @@ -141,7 +150,7 @@ pub fn deserialize_packet<P: ProtocolPacket + Debug>( // this is always true in multiplayer, false in singleplayer static VALIDATE_DECOMPRESSED: bool = true; -pub static MAXIMUM_UNCOMPRESSED_LENGTH: u32 = 2097152; +pub static MAXIMUM_UNCOMPRESSED_LENGTH: u32 = 2_097_152; #[derive(Error, Debug)] pub enum DecompressionError { @@ -169,13 +178,15 @@ pub enum DecompressionError { pub fn compression_decoder( stream: &mut Cursor<&[u8]>, compression_threshold: u32, -) -> Result<Vec<u8>, DecompressionError> { +) -> Result<Box<[u8]>, DecompressionError> { // Data Length let n = u32::azalea_read_var(stream)?; if n == 0 { // no data size, no compression - let mut buf = vec![]; - std::io::Read::read_to_end(stream, &mut buf)?; + let buf = stream.get_ref()[stream.position() as usize..] + .to_vec() + .into_boxed_slice(); + stream.set_position(stream.get_ref().len() as u64); return Ok(buf); } @@ -194,11 +205,14 @@ pub fn compression_decoder( } } - let mut decoded_buf = vec![]; + // VALIDATE_DECOMPRESSED should always be true, so the max they can make us + // allocate here is 2mb + let mut decoded_buf = Vec::with_capacity(n as usize); + let mut decoder = ZlibDecoder::new(stream); decoder.read_to_end(&mut decoded_buf)?; - Ok(decoded_buf) + Ok(decoded_buf.into_boxed_slice()) } /// Read a single packet from a stream. @@ -211,7 +225,7 @@ pub fn compression_decoder( /// For the non-waiting version, see [`try_read_packet`]. pub async fn read_packet<P: ProtocolPacket + Debug, R>( stream: &mut R, - buffer: &mut BytesMut, + buffer: &mut Cursor<Vec<u8>>, compression_threshold: Option<u32>, cipher: &mut Option<Aes128CfbDec>, ) -> Result<P, Box<ReadPacketError>> @@ -219,7 +233,7 @@ where R: AsyncRead + std::marker::Unpin + std::marker::Send + std::marker::Sync, { let raw_packet = read_raw_packet(stream, buffer, compression_threshold, cipher).await?; - let packet = deserialize_packet(&mut Cursor::new(raw_packet.as_slice()))?; + let packet = deserialize_packet(&mut Cursor::new(&raw_packet))?; Ok(packet) } @@ -227,7 +241,7 @@ where /// received a full packet yet. pub fn try_read_packet<P: ProtocolPacket + Debug, R>( stream: &mut R, - buffer: &mut BytesMut, + buffer: &mut Cursor<Vec<u8>>, compression_threshold: Option<u32>, cipher: &mut Option<Aes128CfbDec>, ) -> Result<Option<P>, Box<ReadPacketError>> @@ -238,18 +252,18 @@ where else { return Ok(None); }; - let packet = deserialize_packet(&mut Cursor::new(raw_packet.as_slice()))?; + let packet = deserialize_packet(&mut Cursor::new(&raw_packet))?; Ok(Some(packet)) } pub async fn read_raw_packet<R>( stream: &mut R, - buffer: &mut BytesMut, + buffer: &mut Cursor<Vec<u8>>, compression_threshold: Option<u32>, // this has to be a &mut Option<T> instead of an Option<&mut T> because // otherwise the borrow checker complains about the cipher being moved cipher: &mut Option<Aes128CfbDec>, -) -> Result<Vec<u8>, Box<ReadPacketError>> +) -> Result<Box<[u8]>, Box<ReadPacketError>> where R: AsyncRead + std::marker::Unpin + std::marker::Send + std::marker::Sync, { @@ -260,15 +274,15 @@ where }; let bytes = read_and_decrypt_frame(stream, cipher).await?; - buffer.extend_from_slice(&bytes); + buffer.get_mut().extend_from_slice(&bytes); } } pub fn try_read_raw_packet<R>( stream: &mut R, - buffer: &mut BytesMut, + buffer: &mut Cursor<Vec<u8>>, compression_threshold: Option<u32>, cipher: &mut Option<Aes128CfbDec>, -) -> Result<Option<Vec<u8>>, Box<ReadPacketError>> +) -> Result<Option<Box<[u8]>>, Box<ReadPacketError>> where R: AsyncRead + std::marker::Unpin + std::marker::Send + std::marker::Sync, { @@ -282,14 +296,14 @@ where return Ok(None); }; // we got some data, so add it to the buffer and try again - buffer.extend_from_slice(&bytes); + buffer.get_mut().extend_from_slice(&bytes); } } async fn read_and_decrypt_frame<R>( stream: &mut R, cipher: &mut Option<Aes128CfbDec>, -) -> Result<BytesMut, Box<ReadPacketError>> +) -> Result<Box<[u8]>, Box<ReadPacketError>> where R: AsyncRead + Unpin + Send + Sync, { @@ -298,7 +312,9 @@ where let Some(message) = framed.next().await else { return Err(Box::new(ReadPacketError::ConnectionClosed)); }; - let mut bytes = message.map_err(ReadPacketError::from)?; + let bytes = message.map_err(ReadPacketError::from)?; + + let mut bytes = bytes.to_vec().into_boxed_slice(); // decrypt if necessary if let Some(cipher) = cipher { @@ -310,7 +326,7 @@ where fn try_read_and_decrypt_frame<R>( stream: &mut R, cipher: &mut Option<Aes128CfbDec>, -) -> Result<Option<BytesMut>, Box<ReadPacketError>> +) -> Result<Option<Box<[u8]>>, Box<ReadPacketError>> where R: AsyncRead + Unpin + Send + Sync, { @@ -323,7 +339,8 @@ where let Some(message) = message else { return Err(Box::new(ReadPacketError::ConnectionClosed)); }; - let mut bytes = message.map_err(ReadPacketError::from)?; + let bytes = message.map_err(ReadPacketError::from)?; + let mut bytes = bytes.to_vec().into_boxed_slice(); // decrypt if necessary if let Some(cipher) = cipher { @@ -334,9 +351,9 @@ where } pub fn read_raw_packet_from_buffer<R>( - buffer: &mut BytesMut, + buffer: &mut Cursor<Vec<u8>>, compression_threshold: Option<u32>, -) -> Result<Option<Vec<u8>>, Box<ReadPacketError>> +) -> Result<Option<Box<[u8]>>, Box<ReadPacketError>> where R: AsyncRead + std::marker::Unpin + std::marker::Send + std::marker::Sync, { |
