aboutsummaryrefslogtreecommitdiff
path: root/azalea-protocol/src/read.rs
diff options
context:
space:
mode:
authormat <git@matdoes.dev>2024-12-25 06:16:10 +0000
committermat <git@matdoes.dev>2024-12-25 06:16:10 +0000
commit04eaa5c3d01a8f3a599a3a1abf7205eed80df4a2 (patch)
treee37b34e8bf03045778f383f4e324414e2047ca92 /azalea-protocol/src/read.rs
parent0ee9ed50e30222784d094e20302cadc879f2b6db (diff)
downloadazalea-drasl-04eaa5c3d01a8f3a599a3a1abf7205eed80df4a2.tar.xz
remove dependency on bytes crate for azalea-protocol and fix memory leak
Diffstat (limited to 'azalea-protocol/src/read.rs')
-rwxr-xr-xazalea-protocol/src/read.rs81
1 files changed, 49 insertions, 32 deletions
diff --git a/azalea-protocol/src/read.rs b/azalea-protocol/src/read.rs
index 8569ca73..6f9b754a 100755
--- a/azalea-protocol/src/read.rs
+++ b/azalea-protocol/src/read.rs
@@ -9,13 +9,12 @@ use std::{
use azalea_buf::AzaleaReadVar;
use azalea_buf::BufReadError;
use azalea_crypto::Aes128CfbDec;
-use bytes::Buf;
-use bytes::BytesMut;
use flate2::read::ZlibDecoder;
use futures::StreamExt;
use futures_lite::future;
use thiserror::Error;
use tokio::io::AsyncRead;
+use tokio_util::bytes::Buf;
use tokio_util::codec::{BytesCodec, FramedRead};
use tracing::trace;
@@ -79,12 +78,12 @@ pub enum FrameSplitterError {
ConnectionClosed,
}
-/// Read a length, then read that amount of bytes from `BytesMut`. If there's
-/// not enough data, return None
-fn parse_frame(buffer: &mut BytesMut) -> Result<BytesMut, FrameSplitterError> {
+/// Read a length, then read that amount of bytes from the `Cursor<Vec<u8>>`. If
+/// there's not enough data, return None
+fn parse_frame(buffer: &mut Cursor<Vec<u8>>) -> Result<Box<[u8]>, FrameSplitterError> {
// copy the buffer first and read from the copy, then once we make sure
// the packet is all good we read it fully
- let mut buffer_copy = Cursor::new(&buffer[..]);
+ let mut buffer_copy = Cursor::new(&buffer.get_ref()[buffer.position() as usize..]);
// Packet Length
let length = match u32::azalea_read_var(&mut buffer_copy) {
Ok(length) => length as usize,
@@ -106,18 +105,28 @@ fn parse_frame(buffer: &mut BytesMut) -> Result<BytesMut, FrameSplitterError> {
// the length of the varint that says the length of the whole packet
let varint_length = buffer.remaining() - buffer_copy.remaining();
+ drop(buffer_copy);
buffer.advance(varint_length);
- let data = buffer.split_to(length);
+ let data =
+ buffer.get_ref()[buffer.position() as usize..buffer.position() as usize + length].to_vec();
+ buffer.advance(length);
+
+ if buffer.position() == buffer.get_ref().len() as u64 {
+ // reset the inner vec once we've reached the end of the buffer so we don't keep
+ // leaking memory
+ *buffer.get_mut() = Vec::new();
+ buffer.set_position(0);
+ }
- Ok(data)
+ Ok(data.into_boxed_slice())
}
-fn frame_splitter(buffer: &mut BytesMut) -> Result<Option<Vec<u8>>, FrameSplitterError> {
+fn frame_splitter(buffer: &mut Cursor<Vec<u8>>) -> Result<Option<Box<[u8]>>, FrameSplitterError> {
// https://tokio.rs/tokio/tutorial/framing
let read_frame = parse_frame(buffer);
match read_frame {
- Ok(frame) => return Ok(Some(frame.to_vec())),
+ Ok(frame) => return Ok(Some(frame)),
Err(err) => match err {
FrameSplitterError::BadLength { .. } | FrameSplitterError::Io { .. } => {
// we probably just haven't read enough yet
@@ -141,7 +150,7 @@ pub fn deserialize_packet<P: ProtocolPacket + Debug>(
// this is always true in multiplayer, false in singleplayer
static VALIDATE_DECOMPRESSED: bool = true;
-pub static MAXIMUM_UNCOMPRESSED_LENGTH: u32 = 2097152;
+pub static MAXIMUM_UNCOMPRESSED_LENGTH: u32 = 2_097_152;
#[derive(Error, Debug)]
pub enum DecompressionError {
@@ -169,13 +178,15 @@ pub enum DecompressionError {
pub fn compression_decoder(
stream: &mut Cursor<&[u8]>,
compression_threshold: u32,
-) -> Result<Vec<u8>, DecompressionError> {
+) -> Result<Box<[u8]>, DecompressionError> {
// Data Length
let n = u32::azalea_read_var(stream)?;
if n == 0 {
// no data size, no compression
- let mut buf = vec![];
- std::io::Read::read_to_end(stream, &mut buf)?;
+ let buf = stream.get_ref()[stream.position() as usize..]
+ .to_vec()
+ .into_boxed_slice();
+ stream.set_position(stream.get_ref().len() as u64);
return Ok(buf);
}
@@ -194,11 +205,14 @@ pub fn compression_decoder(
}
}
- let mut decoded_buf = vec![];
+ // VALIDATE_DECOMPRESSED should always be true, so the max they can make us
+ // allocate here is 2mb
+ let mut decoded_buf = Vec::with_capacity(n as usize);
+
let mut decoder = ZlibDecoder::new(stream);
decoder.read_to_end(&mut decoded_buf)?;
- Ok(decoded_buf)
+ Ok(decoded_buf.into_boxed_slice())
}
/// Read a single packet from a stream.
@@ -211,7 +225,7 @@ pub fn compression_decoder(
/// For the non-waiting version, see [`try_read_packet`].
pub async fn read_packet<P: ProtocolPacket + Debug, R>(
stream: &mut R,
- buffer: &mut BytesMut,
+ buffer: &mut Cursor<Vec<u8>>,
compression_threshold: Option<u32>,
cipher: &mut Option<Aes128CfbDec>,
) -> Result<P, Box<ReadPacketError>>
@@ -219,7 +233,7 @@ where
R: AsyncRead + std::marker::Unpin + std::marker::Send + std::marker::Sync,
{
let raw_packet = read_raw_packet(stream, buffer, compression_threshold, cipher).await?;
- let packet = deserialize_packet(&mut Cursor::new(raw_packet.as_slice()))?;
+ let packet = deserialize_packet(&mut Cursor::new(&raw_packet))?;
Ok(packet)
}
@@ -227,7 +241,7 @@ where
/// received a full packet yet.
pub fn try_read_packet<P: ProtocolPacket + Debug, R>(
stream: &mut R,
- buffer: &mut BytesMut,
+ buffer: &mut Cursor<Vec<u8>>,
compression_threshold: Option<u32>,
cipher: &mut Option<Aes128CfbDec>,
) -> Result<Option<P>, Box<ReadPacketError>>
@@ -238,18 +252,18 @@ where
else {
return Ok(None);
};
- let packet = deserialize_packet(&mut Cursor::new(raw_packet.as_slice()))?;
+ let packet = deserialize_packet(&mut Cursor::new(&raw_packet))?;
Ok(Some(packet))
}
pub async fn read_raw_packet<R>(
stream: &mut R,
- buffer: &mut BytesMut,
+ buffer: &mut Cursor<Vec<u8>>,
compression_threshold: Option<u32>,
// this has to be a &mut Option<T> instead of an Option<&mut T> because
// otherwise the borrow checker complains about the cipher being moved
cipher: &mut Option<Aes128CfbDec>,
-) -> Result<Vec<u8>, Box<ReadPacketError>>
+) -> Result<Box<[u8]>, Box<ReadPacketError>>
where
R: AsyncRead + std::marker::Unpin + std::marker::Send + std::marker::Sync,
{
@@ -260,15 +274,15 @@ where
};
let bytes = read_and_decrypt_frame(stream, cipher).await?;
- buffer.extend_from_slice(&bytes);
+ buffer.get_mut().extend_from_slice(&bytes);
}
}
pub fn try_read_raw_packet<R>(
stream: &mut R,
- buffer: &mut BytesMut,
+ buffer: &mut Cursor<Vec<u8>>,
compression_threshold: Option<u32>,
cipher: &mut Option<Aes128CfbDec>,
-) -> Result<Option<Vec<u8>>, Box<ReadPacketError>>
+) -> Result<Option<Box<[u8]>>, Box<ReadPacketError>>
where
R: AsyncRead + std::marker::Unpin + std::marker::Send + std::marker::Sync,
{
@@ -282,14 +296,14 @@ where
return Ok(None);
};
// we got some data, so add it to the buffer and try again
- buffer.extend_from_slice(&bytes);
+ buffer.get_mut().extend_from_slice(&bytes);
}
}
async fn read_and_decrypt_frame<R>(
stream: &mut R,
cipher: &mut Option<Aes128CfbDec>,
-) -> Result<BytesMut, Box<ReadPacketError>>
+) -> Result<Box<[u8]>, Box<ReadPacketError>>
where
R: AsyncRead + Unpin + Send + Sync,
{
@@ -298,7 +312,9 @@ where
let Some(message) = framed.next().await else {
return Err(Box::new(ReadPacketError::ConnectionClosed));
};
- let mut bytes = message.map_err(ReadPacketError::from)?;
+ let bytes = message.map_err(ReadPacketError::from)?;
+
+ let mut bytes = bytes.to_vec().into_boxed_slice();
// decrypt if necessary
if let Some(cipher) = cipher {
@@ -310,7 +326,7 @@ where
fn try_read_and_decrypt_frame<R>(
stream: &mut R,
cipher: &mut Option<Aes128CfbDec>,
-) -> Result<Option<BytesMut>, Box<ReadPacketError>>
+) -> Result<Option<Box<[u8]>>, Box<ReadPacketError>>
where
R: AsyncRead + Unpin + Send + Sync,
{
@@ -323,7 +339,8 @@ where
let Some(message) = message else {
return Err(Box::new(ReadPacketError::ConnectionClosed));
};
- let mut bytes = message.map_err(ReadPacketError::from)?;
+ let bytes = message.map_err(ReadPacketError::from)?;
+ let mut bytes = bytes.to_vec().into_boxed_slice();
// decrypt if necessary
if let Some(cipher) = cipher {
@@ -334,9 +351,9 @@ where
}
pub fn read_raw_packet_from_buffer<R>(
- buffer: &mut BytesMut,
+ buffer: &mut Cursor<Vec<u8>>,
compression_threshold: Option<u32>,
-) -> Result<Option<Vec<u8>>, Box<ReadPacketError>>
+) -> Result<Option<Box<[u8]>>, Box<ReadPacketError>>
where
R: AsyncRead + std::marker::Unpin + std::marker::Send + std::marker::Sync,
{