aboutsummaryrefslogtreecommitdiff
path: root/azalea-protocol/src
diff options
context:
space:
mode:
authormat <github@matdoes.dev>2022-10-07 23:56:23 -0500
committermat <github@matdoes.dev>2022-10-07 23:56:23 -0500
commit6f6289376a0d9ffe7e58506824e37f6b380961c3 (patch)
tree97956fc560b338fbef630f0d0617a248e0e8b336 /azalea-protocol/src
parente9d8d0357ee63cce321e177bf19a8974699894ee (diff)
downloadazalea-drasl-6f6289376a0d9ffe7e58506824e37f6b380961c3.tar.xz
fix errors with rewritten packet reading
i forgot i never tested it before LMAO
Diffstat (limited to 'azalea-protocol/src')
-rw-r--r--azalea-protocol/src/connect.rs2
-rwxr-xr-xazalea-protocol/src/lib.rs39
-rw-r--r--azalea-protocol/src/read.rs144
-rwxr-xr-xazalea-protocol/src/write.rs4
4 files changed, 99 insertions, 90 deletions
diff --git a/azalea-protocol/src/connect.rs b/azalea-protocol/src/connect.rs
index bd55e406..d7b9bd1d 100644
--- a/azalea-protocol/src/connect.rs
+++ b/azalea-protocol/src/connect.rs
@@ -57,7 +57,7 @@ where
/// Write a packet to the server
pub async fn write(&mut self, packet: W) -> std::io::Result<()> {
write_packet(
- packet,
+ &packet,
&mut self.write_stream,
self.compression_threshold,
&mut self.enc_cipher,
diff --git a/azalea-protocol/src/lib.rs b/azalea-protocol/src/lib.rs
index 4da2ba90..58ffac0a 100755
--- a/azalea-protocol/src/lib.rs
+++ b/azalea-protocol/src/lib.rs
@@ -1,5 +1,9 @@
//! This lib is responsible for parsing Minecraft packets.
+// these two are necessary for thiserror backtraces
+#![feature(error_generic_member_access)]
+#![feature(provide_any)]
+
use std::net::IpAddr;
use std::str::FromStr;
@@ -78,12 +82,10 @@ mod tests {
}
.get();
let mut stream = Vec::new();
- write_packet(packet, &mut stream, None, &mut None)
+ write_packet(&packet, &mut stream, None, &mut None)
.await
.unwrap();
- println!("stream: {stream:?}");
-
let mut stream = Cursor::new(stream);
let _ = read_packet::<ServerboundLoginPacket, _>(
@@ -95,4 +97,35 @@ mod tests {
.await
.unwrap();
}
+
+ #[tokio::test]
+ async fn test_double_hello_packet() {
+ let packet = ServerboundHelloPacket {
+ username: "test".to_string(),
+ public_key: Some(ProfilePublicKeyData {
+ expires_at: 0,
+ key: b"idontthinkthisreallymattersijustwantittobelongforthetest".to_vec(),
+ key_signature: b"idontthinkthisreallymattersijustwantittobelongforthetest".to_vec(),
+ }),
+ profile_id: Some(Uuid::from_u128(0)),
+ }
+ .get();
+ let mut stream = Vec::new();
+ write_packet(&packet, &mut stream, None, &mut None)
+ .await
+ .unwrap();
+ write_packet(&packet, &mut stream, None, &mut None)
+ .await
+ .unwrap();
+ let mut stream = Cursor::new(stream);
+
+ let mut buffer = BytesMut::new();
+
+ let _ = read_packet::<ServerboundLoginPacket, _>(&mut stream, &mut buffer, None, &mut None)
+ .await
+ .unwrap();
+ let _ = read_packet::<ServerboundLoginPacket, _>(&mut stream, &mut buffer, None, &mut None)
+ .await
+ .unwrap();
+ }
}
diff --git a/azalea-protocol/src/read.rs b/azalea-protocol/src/read.rs
index eceede9d..4c398e96 100644
--- a/azalea-protocol/src/read.rs
+++ b/azalea-protocol/src/read.rs
@@ -5,16 +5,15 @@ use azalea_crypto::Aes128CfbDec;
use bytes::Buf;
use bytes::BytesMut;
use flate2::read::ZlibDecoder;
+use futures::StreamExt;
use log::{log_enabled, trace};
-use std::io::Cursor;
use std::{
- cell::Cell,
- io::Read,
- pin::Pin,
- task::{Context, Poll},
+ fmt::Debug,
+ io::{Cursor, Read},
};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncReadExt};
+use tokio_util::codec::{BytesCodec, FramedRead};
#[derive(Error, Debug)]
pub enum ReadPacketError {
@@ -28,18 +27,28 @@ pub enum ReadPacketError {
UnknownPacketId { state_name: String, id: u32 },
#[error("Couldn't read packet id")]
ReadPacketId { source: BufReadError },
- #[error("Couldn't decompress packet")]
+ #[error(transparent)]
Decompress {
#[from]
+ #[backtrace]
source: DecompressionError,
},
- #[error("Frame splitter error")]
+ #[error(transparent)]
FrameSplitter {
#[from]
+ #[backtrace]
source: FrameSplitterError,
},
#[error("Leftover data after reading packet {packet_name}: {data:?}")]
LeftoverData { data: Vec<u8>, packet_name: String },
+ #[error(transparent)]
+ IoError {
+ #[from]
+ #[backtrace]
+ source: std::io::Error,
+ },
+ #[error("Connection closed")]
+ ConnectionClosed,
}
#[derive(Error, Debug)]
@@ -52,6 +61,7 @@ pub enum FrameSplitterError {
#[error("Io error")]
Io {
#[from]
+ #[backtrace]
source: std::io::Error,
},
#[error("Packet is longer than {max} bytes (is {size})")]
@@ -77,9 +87,9 @@ fn parse_frame(buffer: &mut BytesMut) -> Result<BytesMut, FrameSplitterError> {
},
};
- if length > buffer_copy.get_ref().len() {
+ if length > buffer_copy.remaining() {
return Err(FrameSplitterError::BadLength {
- max: buffer_copy.get_ref().len(),
+ max: buffer_copy.remaining(),
size: length,
});
}
@@ -88,49 +98,33 @@ fn parse_frame(buffer: &mut BytesMut) -> Result<BytesMut, FrameSplitterError> {
// from the real buffer now
// the length of the varint that says the length of the whole packet
- let varint_length = buffer.len() - buffer_copy.remaining();
- let _ = buffer.split_to(varint_length);
+ let varint_length = buffer.remaining() - buffer_copy.remaining();
+
+ buffer.advance(varint_length);
let data = buffer.split_to(length);
Ok(data)
}
-async fn frame_splitter<'a, R: ?Sized + Sized>(
- stream: &mut R,
- buffer: &'a mut BytesMut,
-) -> Result<Vec<u8>, FrameSplitterError>
-where
- R: AsyncRead + std::marker::Unpin + std::marker::Send,
-{
+fn frame_splitter<'a>(buffer: &'a mut BytesMut) -> Result<Option<Vec<u8>>, FrameSplitterError> {
// https://tokio.rs/tokio/tutorial/framing
- loop {
- let read_frame = parse_frame(buffer);
- match read_frame {
- Ok(frame) => return Ok(frame.to_vec()),
- Err(err) => match err {
- FrameSplitterError::BadLength { .. } | FrameSplitterError::Io { .. } => {
- // we probably just haven't read enough yet
- }
- _ => return Err(err),
- },
- }
-
- let read_buf: usize = AsyncReadExt::read_buf(stream, buffer).await?;
- if 0 == read_buf {
- // The remote closed the connection. For this to be
- // a clean shutdown, there should be no data in the
- // read buffer. If there is, this means that the
- // peer closed the socket while sending a frame.
- if buffer.as_ref().is_empty() {
- return Err(FrameSplitterError::ConnectionClosed);
- } else {
- return Err(FrameSplitterError::ConnectionReset);
+ let read_frame = parse_frame(buffer);
+ match read_frame {
+ Ok(frame) => return Ok(Some(frame.to_vec())),
+ Err(err) => match err {
+ FrameSplitterError::BadLength { .. } | FrameSplitterError::Io { .. } => {
+ // we probably just haven't read enough yet
}
- }
+ _ => return Err(err),
+ },
}
+
+ Ok(None)
}
-fn packet_decoder<P: ProtocolPacket>(stream: &mut Cursor<&[u8]>) -> Result<P, ReadPacketError> {
+fn packet_decoder<P: ProtocolPacket + Debug>(
+ stream: &mut Cursor<&[u8]>,
+) -> Result<P, ReadPacketError> {
// Packet ID
let packet_id =
u32::var_read_from(stream).map_err(|e| ReadPacketError::ReadPacketId { source: e })?;
@@ -152,6 +146,7 @@ pub enum DecompressionError {
#[error("Io error")]
Io {
#[from]
+ #[backtrace]
source: std::io::Error,
},
#[error("Badly compressed packet - size of {size} is below server threshold of {threshold}")]
@@ -197,42 +192,7 @@ fn compression_decoder(
Ok(decoded_buf)
}
-struct EncryptedStream<'a, R>
-where
- R: AsyncRead + std::marker::Unpin + std::marker::Send,
-{
- cipher: Cell<&'a mut Option<Aes128CfbDec>>,
- stream: &'a mut Pin<&'a mut R>,
-}
-
-impl<R> AsyncRead for EncryptedStream<'_, R>
-where
- R: AsyncRead + Unpin + Send,
-{
- fn poll_read(
- mut self: Pin<&mut Self>,
- cx: &mut Context<'_>,
- buf: &mut tokio::io::ReadBuf<'_>,
- ) -> Poll<std::io::Result<()>> {
- // i hate this
- let polled = self.as_mut().stream.as_mut().poll_read(cx, buf);
- match polled {
- Poll::Ready(r) => {
- // if we don't check for the remaining then we decrypt big packets incorrectly
- // (but only on linux and release mode for some reason LMAO)
- if buf.remaining() == 0 {
- if let Some(cipher) = self.as_mut().cipher.get_mut() {
- azalea_crypto::decrypt_packet(cipher, buf.filled_mut());
- }
- }
- Poll::Ready(r)
- }
- Poll::Pending => Poll::Pending,
- }
- }
-}
-
-pub async fn read_packet<'a, P: ProtocolPacket, R>(
+pub async fn read_packet<'a, P: ProtocolPacket + Debug, R>(
stream: &'a mut R,
buffer: &mut BytesMut,
compression_threshold: Option<u32>,
@@ -241,13 +201,29 @@ pub async fn read_packet<'a, P: ProtocolPacket, R>(
where
R: AsyncRead + std::marker::Unpin + std::marker::Send + std::marker::Sync,
{
- // if we were given a cipher, decrypt the packet
- let mut encrypted_stream = EncryptedStream {
- cipher: Cell::new(cipher),
- stream: &mut Pin::new(stream),
- };
+ let mut framed = FramedRead::new(stream, BytesCodec::new());
+ let mut buf = loop {
+ if let Some(buf) = frame_splitter(buffer)? {
+ // we got a full packet!!
+ break buf;
+ } else {
+ // no full packet yet :( keep reading
+ };
+
+ // if we were given a cipher, decrypt the packet
+ if let Some(message) = framed.next().await {
+ let mut bytes = message.unwrap();
+ println!("bytes: {:?}", bytes.len());
- let mut buf = frame_splitter(&mut encrypted_stream, buffer).await?;
+ if let Some(cipher) = cipher {
+ azalea_crypto::decrypt_packet(cipher, &mut bytes);
+ }
+
+ buffer.extend_from_slice(&bytes);
+ } else {
+ return Err(ReadPacketError::ConnectionClosed);
+ };
+ };
if let Some(compression_threshold) = compression_threshold {
buf = compression_decoder(&mut Cursor::new(&buf[..]), compression_threshold)?;
diff --git a/azalea-protocol/src/write.rs b/azalea-protocol/src/write.rs
index b2ae2810..a04979a5 100755
--- a/azalea-protocol/src/write.rs
+++ b/azalea-protocol/src/write.rs
@@ -69,7 +69,7 @@ async fn compression_encoder(
}
pub async fn write_packet<P, W>(
- packet: P,
+ packet: &P,
stream: &mut W,
compression_threshold: Option<u32>,
cipher: &mut Option<Aes128CfbEnc>,
@@ -78,7 +78,7 @@ where
P: ProtocolPacket + Debug,
W: AsyncWrite + Unpin + Send,
{
- let mut buf = packet_encoder(&packet).unwrap();
+ let mut buf = packet_encoder(packet).unwrap();
if let Some(threshold) = compression_threshold {
buf = compression_encoder(&buf, threshold).await.unwrap();
}