aboutsummaryrefslogtreecommitdiff
path: root/azalea-protocol/src/read.rs
diff options
context:
space:
mode:
authormat <27899617+mat-1@users.noreply.github.com>2022-10-07 20:12:36 -0500
committerGitHub <noreply@github.com>2022-10-07 20:12:36 -0500
commitbc3aa9467ae1e2d0ea1727093af9b0af14965e69 (patch)
tree8db3b735daed484507129eb0683db88ddec14210 /azalea-protocol/src/read.rs
parent695efef66fdf1e08f0cb6d8783c085875100fa2d (diff)
downloadazalea-drasl-bc3aa9467ae1e2d0ea1727093af9b0af14965e69.tar.xz
Replace impl Read with Cursor<&[u8]> (#26)
* Start getting rid of Cursor * try to make the tests pass and fail * make the tests pass * remove unused uses * fix clippy warnings * fix potential OOM exploits * fix OOM in az-nbt * fix nbt benchmark * fix a test * start replacing it with Cursor<Vec<u8>> * wip * fix all the issues * fix all tests * fix nbt benchmark * fix warnings
Diffstat (limited to 'azalea-protocol/src/read.rs')
-rw-r--r--[-rwxr-xr-x]azalea-protocol/src/read.rs98
1 files changed, 75 insertions, 23 deletions
diff --git a/azalea-protocol/src/read.rs b/azalea-protocol/src/read.rs
index 8a2aaf7d..eceede9d 100755..100644
--- a/azalea-protocol/src/read.rs
+++ b/azalea-protocol/src/read.rs
@@ -1,9 +1,12 @@
use crate::packets::ProtocolPacket;
+use azalea_buf::BufReadError;
use azalea_buf::McBufVarReadable;
-use azalea_buf::{read_varint_async, BufReadError};
use azalea_crypto::Aes128CfbDec;
+use bytes::Buf;
+use bytes::BytesMut;
use flate2::read::ZlibDecoder;
use log::{log_enabled, trace};
+use std::io::Cursor;
use std::{
cell::Cell,
io::Read,
@@ -52,34 +55,82 @@ pub enum FrameSplitterError {
source: std::io::Error,
},
#[error("Packet is longer than {max} bytes (is {size})")]
- BadLength { max: u32, size: u32 },
+ BadLength { max: usize, size: usize },
+ #[error("Connection reset by peer")]
+ ConnectionReset,
+ #[error("Connection closed")]
+ ConnectionClosed,
}
-async fn frame_splitter<R: ?Sized>(mut stream: &mut R) -> Result<Vec<u8>, FrameSplitterError>
-where
- R: AsyncRead + std::marker::Unpin + std::marker::Send,
-{
+/// Read a length, then read that amount of bytes from BytesMut. If there's not
+/// enough data, return None
+fn parse_frame(buffer: &mut BytesMut) -> Result<BytesMut, FrameSplitterError> {
+ // copy the buffer first and read from the copy, then once we make sure
+ // the packet is all good we read it fully
+ let mut buffer_copy = Cursor::new(&buffer[..]);
// Packet Length
- let length = read_varint_async(&mut stream).await? as u32;
+ let length = match u32::var_read_from(&mut buffer_copy) {
+ Ok(length) => length as usize,
+ Err(err) => match err {
+ BufReadError::Io(io_err) => return Err(FrameSplitterError::Io { source: io_err }),
+ _ => return Err(err.into()),
+ },
+ };
- // TODO: read individual tcp packets so we don't need this
- // https://github.com/tokio-rs/tokio/blob/master/examples/print_each_packet.rs
- let max_length: u32 = 2u32.pow(20u32); // 1mb, arbitrary
- if length > max_length {
- // minecraft *probably* won't send packets bigger than this
+ if length > buffer_copy.get_ref().len() {
return Err(FrameSplitterError::BadLength {
- max: max_length,
+ max: buffer_copy.get_ref().len(),
size: length,
});
}
- let mut buf = vec![0; length as usize];
- stream.read_exact(&mut buf).await?;
+ // we read from the copy and we know it's legit, so we can take those bytes
+ // from the real buffer now
+
+ // the length of the varint that says the length of the whole packet
+ let varint_length = buffer.len() - buffer_copy.remaining();
+ let _ = buffer.split_to(varint_length);
+ let data = buffer.split_to(length);
- Ok(buf)
+ Ok(data)
+}
+
+async fn frame_splitter<'a, R: ?Sized + Sized>(
+ stream: &mut R,
+ buffer: &'a mut BytesMut,
+) -> Result<Vec<u8>, FrameSplitterError>
+where
+ R: AsyncRead + std::marker::Unpin + std::marker::Send,
+{
+ // https://tokio.rs/tokio/tutorial/framing
+ loop {
+ let read_frame = parse_frame(buffer);
+ match read_frame {
+ Ok(frame) => return Ok(frame.to_vec()),
+ Err(err) => match err {
+ FrameSplitterError::BadLength { .. } | FrameSplitterError::Io { .. } => {
+ // we probably just haven't read enough yet
+ }
+ _ => return Err(err),
+ },
+ }
+
+ let read_buf: usize = AsyncReadExt::read_buf(stream, buffer).await?;
+ if 0 == read_buf {
+ // The remote closed the connection. For this to be
+ // a clean shutdown, there should be no data in the
+ // read buffer. If there is, this means that the
+ // peer closed the socket while sending a frame.
+ if buffer.as_ref().is_empty() {
+ return Err(FrameSplitterError::ConnectionClosed);
+ } else {
+ return Err(FrameSplitterError::ConnectionReset);
+ }
+ }
+ }
}
-fn packet_decoder<P: ProtocolPacket>(stream: &mut impl Read) -> Result<P, ReadPacketError> {
+fn packet_decoder<P: ProtocolPacket>(stream: &mut Cursor<&[u8]>) -> Result<P, ReadPacketError> {
// Packet ID
let packet_id =
u32::var_read_from(stream).map_err(|e| ReadPacketError::ReadPacketId { source: e })?;
@@ -112,7 +163,7 @@ pub enum DecompressionError {
}
fn compression_decoder(
- stream: &mut impl Read,
+ stream: &mut Cursor<&[u8]>,
compression_threshold: u32,
) -> Result<Vec<u8>, DecompressionError> {
// Data Length
@@ -120,7 +171,7 @@ fn compression_decoder(
if n == 0 {
// no data size, no compression
let mut buf = vec![];
- stream.read_to_end(&mut buf)?;
+ std::io::Read::read_to_end(stream, &mut buf)?;
return Ok(buf);
}
@@ -183,6 +234,7 @@ where
pub async fn read_packet<'a, P: ProtocolPacket, R>(
stream: &'a mut R,
+ buffer: &mut BytesMut,
compression_threshold: Option<u32>,
cipher: &mut Option<Aes128CfbDec>,
) -> Result<P, ReadPacketError>
@@ -195,10 +247,10 @@ where
stream: &mut Pin::new(stream),
};
- let mut buf = frame_splitter(&mut encrypted_stream).await?;
+ let mut buf = frame_splitter(&mut encrypted_stream, buffer).await?;
if let Some(compression_threshold) = compression_threshold {
- buf = compression_decoder(&mut buf.as_slice(), compression_threshold)?;
+ buf = compression_decoder(&mut Cursor::new(&buf[..]), compression_threshold)?;
}
if log_enabled!(log::Level::Trace) {
@@ -213,7 +265,7 @@ where
trace!("Reading packet with bytes: {buf_string}");
}
- let packet = packet_decoder(&mut buf.as_slice())?;
+ let packet = packet_decoder(&mut Cursor::new(&buf[..]))?;
Ok(packet)
}
@@ -226,7 +278,7 @@ mod tests {
#[tokio::test]
async fn test_read_packet() {
- let mut buf = Cursor::new(vec![
+ let mut buf: Cursor<&[u8]> = Cursor::new(&[
51, 0, 12, 177, 250, 155, 132, 106, 60, 218, 161, 217, 90, 157, 105, 57, 206, 20, 0, 5,
104, 101, 108, 108, 111, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 116,
123, 34, 101, 120, 116, 114, 97, 34, 58, 91, 123, 34, 99, 111, 108, 111, 114, 34, 58,