From 04eaa5c3d01a8f3a599a3a1abf7205eed80df4a2 Mon Sep 17 00:00:00 2001 From: mat Date: Wed, 25 Dec 2024 06:16:10 +0000 Subject: remove dependency on bytes crate for azalea-protocol and fix memory leak --- azalea-protocol/src/connect.rs | 17 ++++----- azalea-protocol/src/lib.rs | 12 ++++--- azalea-protocol/src/read.rs | 81 +++++++++++++++++++++++++----------------- azalea-protocol/src/write.rs | 4 +-- 4 files changed, 66 insertions(+), 48 deletions(-) (limited to 'azalea-protocol/src') diff --git a/azalea-protocol/src/connect.rs b/azalea-protocol/src/connect.rs index f33ce2a5..ef202378 100755 --- a/azalea-protocol/src/connect.rs +++ b/azalea-protocol/src/connect.rs @@ -8,7 +8,6 @@ use std::net::SocketAddr; use azalea_auth::game_profile::GameProfile; use azalea_auth::sessionserver::{ClientSessionServerError, ServerSessionServerError}; use azalea_crypto::{Aes128CfbDec, Aes128CfbEnc}; -use bytes::BytesMut; use thiserror::Error; use tokio::io::{AsyncWriteExt, BufStream}; use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf, ReuniteError}; @@ -28,7 +27,7 @@ use crate::write::{serialize_packet, write_raw_packet}; pub struct RawReadConnection { pub read_stream: OwnedReadHalf, - pub buffer: BytesMut, + pub buffer: Cursor>, pub compression_threshold: Option, pub dec_cipher: Option, } @@ -135,7 +134,7 @@ pub struct Connection { } impl RawReadConnection { - pub async fn read(&mut self) -> Result, Box> { + pub async fn read(&mut self) -> Result, Box> { read_raw_packet::<_>( &mut self.read_stream, &mut self.buffer, @@ -145,7 +144,7 @@ impl RawReadConnection { .await } - pub fn try_read(&mut self) -> Result>, Box> { + pub fn try_read(&mut self) -> Result>, Box> { try_read_raw_packet::<_>( &mut self.read_stream, &mut self.buffer, @@ -190,7 +189,7 @@ where /// Read a packet from the stream. pub async fn read(&mut self) -> Result> { let raw_packet = self.raw.read().await?; - deserialize_packet(&mut Cursor::new(raw_packet.as_slice())) + deserialize_packet(&mut Cursor::new(&raw_packet)) } /// Try to read a packet from the stream, or return Ok(None) if there's no @@ -199,9 +198,7 @@ where let Some(raw_packet) = self.raw.try_read()? else { return Ok(None); }; - Ok(Some(deserialize_packet(&mut Cursor::new( - raw_packet.as_slice(), - ))?)) + Ok(Some(deserialize_packet(&mut Cursor::new(&raw_packet))?)) } } impl WriteConnection @@ -304,7 +301,7 @@ impl Connection { reader: ReadConnection { raw: RawReadConnection { read_stream, - buffer: BytesMut::new(), + buffer: Cursor::new(Vec::new()), compression_threshold: None, dec_cipher: None, }, @@ -562,7 +559,7 @@ where reader: ReadConnection { raw: RawReadConnection { read_stream, - buffer: BytesMut::new(), + buffer: Cursor::new(Vec::new()), compression_threshold: None, dec_cipher: None, }, diff --git a/azalea-protocol/src/lib.rs b/azalea-protocol/src/lib.rs index 5e663c8f..12243de6 100644 --- a/azalea-protocol/src/lib.rs +++ b/azalea-protocol/src/lib.rs @@ -9,7 +9,7 @@ //! //! See [`crate::connect::Connection`] for an example. -// these two are necessary for thiserror backtraces +// this is necessary for thiserror backtraces #![feature(error_generic_member_access)] use std::{fmt::Display, net::SocketAddr, str::FromStr}; @@ -111,7 +111,6 @@ impl serde::Serialize for ServerAddress { mod tests { use std::io::Cursor; - use bytes::BytesMut; use uuid::Uuid; use crate::{ @@ -135,11 +134,16 @@ mod tests { .await .unwrap(); + assert_eq!( + stream, + [22, 0, 4, 116, 101, 115, 116, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + ); + let mut stream = Cursor::new(stream); let _ = read_packet::( &mut stream, - &mut BytesMut::new(), + &mut Cursor::new(Vec::new()), None, &mut None, ) @@ -163,7 +167,7 @@ mod tests { .unwrap(); let mut stream = Cursor::new(stream); - let mut buffer = BytesMut::new(); + let mut buffer = Cursor::new(Vec::new()); let _ = read_packet::(&mut stream, &mut buffer, None, &mut None) .await diff --git a/azalea-protocol/src/read.rs b/azalea-protocol/src/read.rs index 8569ca73..6f9b754a 100755 --- a/azalea-protocol/src/read.rs +++ b/azalea-protocol/src/read.rs @@ -9,13 +9,12 @@ use std::{ use azalea_buf::AzaleaReadVar; use azalea_buf::BufReadError; use azalea_crypto::Aes128CfbDec; -use bytes::Buf; -use bytes::BytesMut; use flate2::read::ZlibDecoder; use futures::StreamExt; use futures_lite::future; use thiserror::Error; use tokio::io::AsyncRead; +use tokio_util::bytes::Buf; use tokio_util::codec::{BytesCodec, FramedRead}; use tracing::trace; @@ -79,12 +78,12 @@ pub enum FrameSplitterError { ConnectionClosed, } -/// Read a length, then read that amount of bytes from `BytesMut`. If there's -/// not enough data, return None -fn parse_frame(buffer: &mut BytesMut) -> Result { +/// Read a length, then read that amount of bytes from the `Cursor>`. If +/// there's not enough data, return None +fn parse_frame(buffer: &mut Cursor>) -> Result, FrameSplitterError> { // copy the buffer first and read from the copy, then once we make sure // the packet is all good we read it fully - let mut buffer_copy = Cursor::new(&buffer[..]); + let mut buffer_copy = Cursor::new(&buffer.get_ref()[buffer.position() as usize..]); // Packet Length let length = match u32::azalea_read_var(&mut buffer_copy) { Ok(length) => length as usize, @@ -106,18 +105,28 @@ fn parse_frame(buffer: &mut BytesMut) -> Result { // the length of the varint that says the length of the whole packet let varint_length = buffer.remaining() - buffer_copy.remaining(); + drop(buffer_copy); buffer.advance(varint_length); - let data = buffer.split_to(length); + let data = + buffer.get_ref()[buffer.position() as usize..buffer.position() as usize + length].to_vec(); + buffer.advance(length); + + if buffer.position() == buffer.get_ref().len() as u64 { + // reset the inner vec once we've reached the end of the buffer so we don't keep + // leaking memory + *buffer.get_mut() = Vec::new(); + buffer.set_position(0); + } - Ok(data) + Ok(data.into_boxed_slice()) } -fn frame_splitter(buffer: &mut BytesMut) -> Result>, FrameSplitterError> { +fn frame_splitter(buffer: &mut Cursor>) -> Result>, FrameSplitterError> { // https://tokio.rs/tokio/tutorial/framing let read_frame = parse_frame(buffer); match read_frame { - Ok(frame) => return Ok(Some(frame.to_vec())), + Ok(frame) => return Ok(Some(frame)), Err(err) => match err { FrameSplitterError::BadLength { .. } | FrameSplitterError::Io { .. } => { // we probably just haven't read enough yet @@ -141,7 +150,7 @@ pub fn deserialize_packet( // this is always true in multiplayer, false in singleplayer static VALIDATE_DECOMPRESSED: bool = true; -pub static MAXIMUM_UNCOMPRESSED_LENGTH: u32 = 2097152; +pub static MAXIMUM_UNCOMPRESSED_LENGTH: u32 = 2_097_152; #[derive(Error, Debug)] pub enum DecompressionError { @@ -169,13 +178,15 @@ pub enum DecompressionError { pub fn compression_decoder( stream: &mut Cursor<&[u8]>, compression_threshold: u32, -) -> Result, DecompressionError> { +) -> Result, DecompressionError> { // Data Length let n = u32::azalea_read_var(stream)?; if n == 0 { // no data size, no compression - let mut buf = vec![]; - std::io::Read::read_to_end(stream, &mut buf)?; + let buf = stream.get_ref()[stream.position() as usize..] + .to_vec() + .into_boxed_slice(); + stream.set_position(stream.get_ref().len() as u64); return Ok(buf); } @@ -194,11 +205,14 @@ pub fn compression_decoder( } } - let mut decoded_buf = vec![]; + // VALIDATE_DECOMPRESSED should always be true, so the max they can make us + // allocate here is 2mb + let mut decoded_buf = Vec::with_capacity(n as usize); + let mut decoder = ZlibDecoder::new(stream); decoder.read_to_end(&mut decoded_buf)?; - Ok(decoded_buf) + Ok(decoded_buf.into_boxed_slice()) } /// Read a single packet from a stream. @@ -211,7 +225,7 @@ pub fn compression_decoder( /// For the non-waiting version, see [`try_read_packet`]. pub async fn read_packet( stream: &mut R, - buffer: &mut BytesMut, + buffer: &mut Cursor>, compression_threshold: Option, cipher: &mut Option, ) -> Result> @@ -219,7 +233,7 @@ where R: AsyncRead + std::marker::Unpin + std::marker::Send + std::marker::Sync, { let raw_packet = read_raw_packet(stream, buffer, compression_threshold, cipher).await?; - let packet = deserialize_packet(&mut Cursor::new(raw_packet.as_slice()))?; + let packet = deserialize_packet(&mut Cursor::new(&raw_packet))?; Ok(packet) } @@ -227,7 +241,7 @@ where /// received a full packet yet. pub fn try_read_packet( stream: &mut R, - buffer: &mut BytesMut, + buffer: &mut Cursor>, compression_threshold: Option, cipher: &mut Option, ) -> Result, Box> @@ -238,18 +252,18 @@ where else { return Ok(None); }; - let packet = deserialize_packet(&mut Cursor::new(raw_packet.as_slice()))?; + let packet = deserialize_packet(&mut Cursor::new(&raw_packet))?; Ok(Some(packet)) } pub async fn read_raw_packet( stream: &mut R, - buffer: &mut BytesMut, + buffer: &mut Cursor>, compression_threshold: Option, // this has to be a &mut Option instead of an Option<&mut T> because // otherwise the borrow checker complains about the cipher being moved cipher: &mut Option, -) -> Result, Box> +) -> Result, Box> where R: AsyncRead + std::marker::Unpin + std::marker::Send + std::marker::Sync, { @@ -260,15 +274,15 @@ where }; let bytes = read_and_decrypt_frame(stream, cipher).await?; - buffer.extend_from_slice(&bytes); + buffer.get_mut().extend_from_slice(&bytes); } } pub fn try_read_raw_packet( stream: &mut R, - buffer: &mut BytesMut, + buffer: &mut Cursor>, compression_threshold: Option, cipher: &mut Option, -) -> Result>, Box> +) -> Result>, Box> where R: AsyncRead + std::marker::Unpin + std::marker::Send + std::marker::Sync, { @@ -282,14 +296,14 @@ where return Ok(None); }; // we got some data, so add it to the buffer and try again - buffer.extend_from_slice(&bytes); + buffer.get_mut().extend_from_slice(&bytes); } } async fn read_and_decrypt_frame( stream: &mut R, cipher: &mut Option, -) -> Result> +) -> Result, Box> where R: AsyncRead + Unpin + Send + Sync, { @@ -298,7 +312,9 @@ where let Some(message) = framed.next().await else { return Err(Box::new(ReadPacketError::ConnectionClosed)); }; - let mut bytes = message.map_err(ReadPacketError::from)?; + let bytes = message.map_err(ReadPacketError::from)?; + + let mut bytes = bytes.to_vec().into_boxed_slice(); // decrypt if necessary if let Some(cipher) = cipher { @@ -310,7 +326,7 @@ where fn try_read_and_decrypt_frame( stream: &mut R, cipher: &mut Option, -) -> Result, Box> +) -> Result>, Box> where R: AsyncRead + Unpin + Send + Sync, { @@ -323,7 +339,8 @@ where let Some(message) = message else { return Err(Box::new(ReadPacketError::ConnectionClosed)); }; - let mut bytes = message.map_err(ReadPacketError::from)?; + let bytes = message.map_err(ReadPacketError::from)?; + let mut bytes = bytes.to_vec().into_boxed_slice(); // decrypt if necessary if let Some(cipher) = cipher { @@ -334,9 +351,9 @@ where } pub fn read_raw_packet_from_buffer( - buffer: &mut BytesMut, + buffer: &mut Cursor>, compression_threshold: Option, -) -> Result>, Box> +) -> Result>, Box> where R: AsyncRead + std::marker::Unpin + std::marker::Send + std::marker::Sync, { diff --git a/azalea-protocol/src/write.rs b/azalea-protocol/src/write.rs index 512d08ad..f1ffd82e 100755 --- a/azalea-protocol/src/write.rs +++ b/azalea-protocol/src/write.rs @@ -31,7 +31,7 @@ where pub fn serialize_packet( packet: &P, -) -> Result, PacketEncodeError> { +) -> Result, PacketEncodeError> { let mut buf = Vec::new(); packet.id().azalea_write_var(&mut buf)?; packet.write(&mut buf)?; @@ -42,7 +42,7 @@ pub fn serialize_packet( packet_string: format!("{packet:?}"), }); } - Ok(buf) + Ok(buf.into_boxed_slice()) } pub async fn write_raw_packet( -- cgit v1.2.3