diff options
Diffstat (limited to 'azalea-protocol/src')
| -rwxr-xr-x | azalea-protocol/src/connect.rs | 17 | ||||
| -rw-r--r-- | azalea-protocol/src/lib.rs | 12 | ||||
| -rwxr-xr-x | azalea-protocol/src/read.rs | 81 | ||||
| -rwxr-xr-x | azalea-protocol/src/write.rs | 4 |
4 files changed, 66 insertions, 48 deletions
diff --git a/azalea-protocol/src/connect.rs b/azalea-protocol/src/connect.rs index f33ce2a5..ef202378 100755 --- a/azalea-protocol/src/connect.rs +++ b/azalea-protocol/src/connect.rs @@ -8,7 +8,6 @@ use std::net::SocketAddr; use azalea_auth::game_profile::GameProfile; use azalea_auth::sessionserver::{ClientSessionServerError, ServerSessionServerError}; use azalea_crypto::{Aes128CfbDec, Aes128CfbEnc}; -use bytes::BytesMut; use thiserror::Error; use tokio::io::{AsyncWriteExt, BufStream}; use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf, ReuniteError}; @@ -28,7 +27,7 @@ use crate::write::{serialize_packet, write_raw_packet}; pub struct RawReadConnection { pub read_stream: OwnedReadHalf, - pub buffer: BytesMut, + pub buffer: Cursor<Vec<u8>>, pub compression_threshold: Option<u32>, pub dec_cipher: Option<Aes128CfbDec>, } @@ -135,7 +134,7 @@ pub struct Connection<R: ProtocolPacket, W: ProtocolPacket> { } impl RawReadConnection { - pub async fn read(&mut self) -> Result<Vec<u8>, Box<ReadPacketError>> { + pub async fn read(&mut self) -> Result<Box<[u8]>, Box<ReadPacketError>> { read_raw_packet::<_>( &mut self.read_stream, &mut self.buffer, @@ -145,7 +144,7 @@ impl RawReadConnection { .await } - pub fn try_read(&mut self) -> Result<Option<Vec<u8>>, Box<ReadPacketError>> { + pub fn try_read(&mut self) -> Result<Option<Box<[u8]>>, Box<ReadPacketError>> { try_read_raw_packet::<_>( &mut self.read_stream, &mut self.buffer, @@ -190,7 +189,7 @@ where /// Read a packet from the stream. pub async fn read(&mut self) -> Result<R, Box<ReadPacketError>> { let raw_packet = self.raw.read().await?; - deserialize_packet(&mut Cursor::new(raw_packet.as_slice())) + deserialize_packet(&mut Cursor::new(&raw_packet)) } /// Try to read a packet from the stream, or return Ok(None) if there's no @@ -199,9 +198,7 @@ where let Some(raw_packet) = self.raw.try_read()? else { return Ok(None); }; - Ok(Some(deserialize_packet(&mut Cursor::new( - raw_packet.as_slice(), - ))?)) + Ok(Some(deserialize_packet(&mut Cursor::new(&raw_packet))?)) } } impl<W> WriteConnection<W> @@ -304,7 +301,7 @@ impl Connection<ClientboundHandshakePacket, ServerboundHandshakePacket> { reader: ReadConnection { raw: RawReadConnection { read_stream, - buffer: BytesMut::new(), + buffer: Cursor::new(Vec::new()), compression_threshold: None, dec_cipher: None, }, @@ -562,7 +559,7 @@ where reader: ReadConnection { raw: RawReadConnection { read_stream, - buffer: BytesMut::new(), + buffer: Cursor::new(Vec::new()), compression_threshold: None, dec_cipher: None, }, diff --git a/azalea-protocol/src/lib.rs b/azalea-protocol/src/lib.rs index 5e663c8f..12243de6 100644 --- a/azalea-protocol/src/lib.rs +++ b/azalea-protocol/src/lib.rs @@ -9,7 +9,7 @@ //! //! See [`crate::connect::Connection`] for an example. -// these two are necessary for thiserror backtraces +// this is necessary for thiserror backtraces #![feature(error_generic_member_access)] use std::{fmt::Display, net::SocketAddr, str::FromStr}; @@ -111,7 +111,6 @@ impl serde::Serialize for ServerAddress { mod tests { use std::io::Cursor; - use bytes::BytesMut; use uuid::Uuid; use crate::{ @@ -135,11 +134,16 @@ mod tests { .await .unwrap(); + assert_eq!( + stream, + [22, 0, 4, 116, 101, 115, 116, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + ); + let mut stream = Cursor::new(stream); let _ = read_packet::<ServerboundLoginPacket, _>( &mut stream, - &mut BytesMut::new(), + &mut Cursor::new(Vec::new()), None, &mut None, ) @@ -163,7 +167,7 @@ mod tests { .unwrap(); let mut stream = Cursor::new(stream); - let mut buffer = BytesMut::new(); + let mut buffer = Cursor::new(Vec::new()); let _ = read_packet::<ServerboundLoginPacket, _>(&mut stream, &mut buffer, None, &mut None) .await diff --git a/azalea-protocol/src/read.rs b/azalea-protocol/src/read.rs index 8569ca73..6f9b754a 100755 --- a/azalea-protocol/src/read.rs +++ b/azalea-protocol/src/read.rs @@ -9,13 +9,12 @@ use std::{ use azalea_buf::AzaleaReadVar; use azalea_buf::BufReadError; use azalea_crypto::Aes128CfbDec; -use bytes::Buf; -use bytes::BytesMut; use flate2::read::ZlibDecoder; use futures::StreamExt; use futures_lite::future; use thiserror::Error; use tokio::io::AsyncRead; +use tokio_util::bytes::Buf; use tokio_util::codec::{BytesCodec, FramedRead}; use tracing::trace; @@ -79,12 +78,12 @@ pub enum FrameSplitterError { ConnectionClosed, } -/// Read a length, then read that amount of bytes from `BytesMut`. If there's -/// not enough data, return None -fn parse_frame(buffer: &mut BytesMut) -> Result<BytesMut, FrameSplitterError> { +/// Read a length, then read that amount of bytes from the `Cursor<Vec<u8>>`. If +/// there's not enough data, return None +fn parse_frame(buffer: &mut Cursor<Vec<u8>>) -> Result<Box<[u8]>, FrameSplitterError> { // copy the buffer first and read from the copy, then once we make sure // the packet is all good we read it fully - let mut buffer_copy = Cursor::new(&buffer[..]); + let mut buffer_copy = Cursor::new(&buffer.get_ref()[buffer.position() as usize..]); // Packet Length let length = match u32::azalea_read_var(&mut buffer_copy) { Ok(length) => length as usize, @@ -106,18 +105,28 @@ fn parse_frame(buffer: &mut BytesMut) -> Result<BytesMut, FrameSplitterError> { // the length of the varint that says the length of the whole packet let varint_length = buffer.remaining() - buffer_copy.remaining(); + drop(buffer_copy); buffer.advance(varint_length); - let data = buffer.split_to(length); + let data = + buffer.get_ref()[buffer.position() as usize..buffer.position() as usize + length].to_vec(); + buffer.advance(length); + + if buffer.position() == buffer.get_ref().len() as u64 { + // reset the inner vec once we've reached the end of the buffer so we don't keep + // leaking memory + *buffer.get_mut() = Vec::new(); + buffer.set_position(0); + } - Ok(data) + Ok(data.into_boxed_slice()) } -fn frame_splitter(buffer: &mut BytesMut) -> Result<Option<Vec<u8>>, FrameSplitterError> { +fn frame_splitter(buffer: &mut Cursor<Vec<u8>>) -> Result<Option<Box<[u8]>>, FrameSplitterError> { // https://tokio.rs/tokio/tutorial/framing let read_frame = parse_frame(buffer); match read_frame { - Ok(frame) => return Ok(Some(frame.to_vec())), + Ok(frame) => return Ok(Some(frame)), Err(err) => match err { FrameSplitterError::BadLength { .. } | FrameSplitterError::Io { .. } => { // we probably just haven't read enough yet @@ -141,7 +150,7 @@ pub fn deserialize_packet<P: ProtocolPacket + Debug>( // this is always true in multiplayer, false in singleplayer static VALIDATE_DECOMPRESSED: bool = true; -pub static MAXIMUM_UNCOMPRESSED_LENGTH: u32 = 2097152; +pub static MAXIMUM_UNCOMPRESSED_LENGTH: u32 = 2_097_152; #[derive(Error, Debug)] pub enum DecompressionError { @@ -169,13 +178,15 @@ pub enum DecompressionError { pub fn compression_decoder( stream: &mut Cursor<&[u8]>, compression_threshold: u32, -) -> Result<Vec<u8>, DecompressionError> { +) -> Result<Box<[u8]>, DecompressionError> { // Data Length let n = u32::azalea_read_var(stream)?; if n == 0 { // no data size, no compression - let mut buf = vec![]; - std::io::Read::read_to_end(stream, &mut buf)?; + let buf = stream.get_ref()[stream.position() as usize..] + .to_vec() + .into_boxed_slice(); + stream.set_position(stream.get_ref().len() as u64); return Ok(buf); } @@ -194,11 +205,14 @@ pub fn compression_decoder( } } - let mut decoded_buf = vec![]; + // VALIDATE_DECOMPRESSED should always be true, so the max they can make us + // allocate here is 2mb + let mut decoded_buf = Vec::with_capacity(n as usize); + let mut decoder = ZlibDecoder::new(stream); decoder.read_to_end(&mut decoded_buf)?; - Ok(decoded_buf) + Ok(decoded_buf.into_boxed_slice()) } /// Read a single packet from a stream. @@ -211,7 +225,7 @@ pub fn compression_decoder( /// For the non-waiting version, see [`try_read_packet`]. pub async fn read_packet<P: ProtocolPacket + Debug, R>( stream: &mut R, - buffer: &mut BytesMut, + buffer: &mut Cursor<Vec<u8>>, compression_threshold: Option<u32>, cipher: &mut Option<Aes128CfbDec>, ) -> Result<P, Box<ReadPacketError>> @@ -219,7 +233,7 @@ where R: AsyncRead + std::marker::Unpin + std::marker::Send + std::marker::Sync, { let raw_packet = read_raw_packet(stream, buffer, compression_threshold, cipher).await?; - let packet = deserialize_packet(&mut Cursor::new(raw_packet.as_slice()))?; + let packet = deserialize_packet(&mut Cursor::new(&raw_packet))?; Ok(packet) } @@ -227,7 +241,7 @@ where /// received a full packet yet. pub fn try_read_packet<P: ProtocolPacket + Debug, R>( stream: &mut R, - buffer: &mut BytesMut, + buffer: &mut Cursor<Vec<u8>>, compression_threshold: Option<u32>, cipher: &mut Option<Aes128CfbDec>, ) -> Result<Option<P>, Box<ReadPacketError>> @@ -238,18 +252,18 @@ where else { return Ok(None); }; - let packet = deserialize_packet(&mut Cursor::new(raw_packet.as_slice()))?; + let packet = deserialize_packet(&mut Cursor::new(&raw_packet))?; Ok(Some(packet)) } pub async fn read_raw_packet<R>( stream: &mut R, - buffer: &mut BytesMut, + buffer: &mut Cursor<Vec<u8>>, compression_threshold: Option<u32>, // this has to be a &mut Option<T> instead of an Option<&mut T> because // otherwise the borrow checker complains about the cipher being moved cipher: &mut Option<Aes128CfbDec>, -) -> Result<Vec<u8>, Box<ReadPacketError>> +) -> Result<Box<[u8]>, Box<ReadPacketError>> where R: AsyncRead + std::marker::Unpin + std::marker::Send + std::marker::Sync, { @@ -260,15 +274,15 @@ where }; let bytes = read_and_decrypt_frame(stream, cipher).await?; - buffer.extend_from_slice(&bytes); + buffer.get_mut().extend_from_slice(&bytes); } } pub fn try_read_raw_packet<R>( stream: &mut R, - buffer: &mut BytesMut, + buffer: &mut Cursor<Vec<u8>>, compression_threshold: Option<u32>, cipher: &mut Option<Aes128CfbDec>, -) -> Result<Option<Vec<u8>>, Box<ReadPacketError>> +) -> Result<Option<Box<[u8]>>, Box<ReadPacketError>> where R: AsyncRead + std::marker::Unpin + std::marker::Send + std::marker::Sync, { @@ -282,14 +296,14 @@ where return Ok(None); }; // we got some data, so add it to the buffer and try again - buffer.extend_from_slice(&bytes); + buffer.get_mut().extend_from_slice(&bytes); } } async fn read_and_decrypt_frame<R>( stream: &mut R, cipher: &mut Option<Aes128CfbDec>, -) -> Result<BytesMut, Box<ReadPacketError>> +) -> Result<Box<[u8]>, Box<ReadPacketError>> where R: AsyncRead + Unpin + Send + Sync, { @@ -298,7 +312,9 @@ where let Some(message) = framed.next().await else { return Err(Box::new(ReadPacketError::ConnectionClosed)); }; - let mut bytes = message.map_err(ReadPacketError::from)?; + let bytes = message.map_err(ReadPacketError::from)?; + + let mut bytes = bytes.to_vec().into_boxed_slice(); // decrypt if necessary if let Some(cipher) = cipher { @@ -310,7 +326,7 @@ where fn try_read_and_decrypt_frame<R>( stream: &mut R, cipher: &mut Option<Aes128CfbDec>, -) -> Result<Option<BytesMut>, Box<ReadPacketError>> +) -> Result<Option<Box<[u8]>>, Box<ReadPacketError>> where R: AsyncRead + Unpin + Send + Sync, { @@ -323,7 +339,8 @@ where let Some(message) = message else { return Err(Box::new(ReadPacketError::ConnectionClosed)); }; - let mut bytes = message.map_err(ReadPacketError::from)?; + let bytes = message.map_err(ReadPacketError::from)?; + let mut bytes = bytes.to_vec().into_boxed_slice(); // decrypt if necessary if let Some(cipher) = cipher { @@ -334,9 +351,9 @@ where } pub fn read_raw_packet_from_buffer<R>( - buffer: &mut BytesMut, + buffer: &mut Cursor<Vec<u8>>, compression_threshold: Option<u32>, -) -> Result<Option<Vec<u8>>, Box<ReadPacketError>> +) -> Result<Option<Box<[u8]>>, Box<ReadPacketError>> where R: AsyncRead + std::marker::Unpin + std::marker::Send + std::marker::Sync, { diff --git a/azalea-protocol/src/write.rs b/azalea-protocol/src/write.rs index 512d08ad..f1ffd82e 100755 --- a/azalea-protocol/src/write.rs +++ b/azalea-protocol/src/write.rs @@ -31,7 +31,7 @@ where pub fn serialize_packet<P: ProtocolPacket + Debug>( packet: &P, -) -> Result<Vec<u8>, PacketEncodeError> { +) -> Result<Box<[u8]>, PacketEncodeError> { let mut buf = Vec::new(); packet.id().azalea_write_var(&mut buf)?; packet.write(&mut buf)?; @@ -42,7 +42,7 @@ pub fn serialize_packet<P: ProtocolPacket + Debug>( packet_string: format!("{packet:?}"), }); } - Ok(buf) + Ok(buf.into_boxed_slice()) } pub async fn write_raw_packet<W>( |
