diff options
| author | mat <github@matdoes.dev> | 2022-10-07 23:56:23 -0500 |
|---|---|---|
| committer | mat <github@matdoes.dev> | 2022-10-07 23:56:23 -0500 |
| commit | 6f6289376a0d9ffe7e58506824e37f6b380961c3 (patch) | |
| tree | 97956fc560b338fbef630f0d0617a248e0e8b336 /azalea-protocol/src | |
| parent | e9d8d0357ee63cce321e177bf19a8974699894ee (diff) | |
| download | azalea-drasl-6f6289376a0d9ffe7e58506824e37f6b380961c3.tar.xz | |
fix errors with rewritten packet reading
i forgot i never tested it before LMAO
Diffstat (limited to 'azalea-protocol/src')
| -rw-r--r-- | azalea-protocol/src/connect.rs | 2 | ||||
| -rwxr-xr-x | azalea-protocol/src/lib.rs | 39 | ||||
| -rw-r--r-- | azalea-protocol/src/read.rs | 144 | ||||
| -rwxr-xr-x | azalea-protocol/src/write.rs | 4 |
4 files changed, 99 insertions, 90 deletions
diff --git a/azalea-protocol/src/connect.rs b/azalea-protocol/src/connect.rs index bd55e406..d7b9bd1d 100644 --- a/azalea-protocol/src/connect.rs +++ b/azalea-protocol/src/connect.rs @@ -57,7 +57,7 @@ where /// Write a packet to the server pub async fn write(&mut self, packet: W) -> std::io::Result<()> { write_packet( - packet, + &packet, &mut self.write_stream, self.compression_threshold, &mut self.enc_cipher, diff --git a/azalea-protocol/src/lib.rs b/azalea-protocol/src/lib.rs index 4da2ba90..58ffac0a 100755 --- a/azalea-protocol/src/lib.rs +++ b/azalea-protocol/src/lib.rs @@ -1,5 +1,9 @@ //! This lib is responsible for parsing Minecraft packets. +// these two are necessary for thiserror backtraces +#![feature(error_generic_member_access)] +#![feature(provide_any)] + use std::net::IpAddr; use std::str::FromStr; @@ -78,12 +82,10 @@ mod tests { } .get(); let mut stream = Vec::new(); - write_packet(packet, &mut stream, None, &mut None) + write_packet(&packet, &mut stream, None, &mut None) .await .unwrap(); - println!("stream: {stream:?}"); - let mut stream = Cursor::new(stream); let _ = read_packet::<ServerboundLoginPacket, _>( @@ -95,4 +97,35 @@ mod tests { .await .unwrap(); } + + #[tokio::test] + async fn test_double_hello_packet() { + let packet = ServerboundHelloPacket { + username: "test".to_string(), + public_key: Some(ProfilePublicKeyData { + expires_at: 0, + key: b"idontthinkthisreallymattersijustwantittobelongforthetest".to_vec(), + key_signature: b"idontthinkthisreallymattersijustwantittobelongforthetest".to_vec(), + }), + profile_id: Some(Uuid::from_u128(0)), + } + .get(); + let mut stream = Vec::new(); + write_packet(&packet, &mut stream, None, &mut None) + .await + .unwrap(); + write_packet(&packet, &mut stream, None, &mut None) + .await + .unwrap(); + let mut stream = Cursor::new(stream); + + let mut buffer = BytesMut::new(); + + let _ = read_packet::<ServerboundLoginPacket, _>(&mut stream, &mut buffer, None, &mut None) + .await + .unwrap(); + let _ = read_packet::<ServerboundLoginPacket, _>(&mut stream, &mut buffer, None, &mut None) + .await + .unwrap(); + } } diff --git a/azalea-protocol/src/read.rs b/azalea-protocol/src/read.rs index eceede9d..4c398e96 100644 --- a/azalea-protocol/src/read.rs +++ b/azalea-protocol/src/read.rs @@ -5,16 +5,15 @@ use azalea_crypto::Aes128CfbDec; use bytes::Buf; use bytes::BytesMut; use flate2::read::ZlibDecoder; +use futures::StreamExt; use log::{log_enabled, trace}; -use std::io::Cursor; use std::{ - cell::Cell, - io::Read, - pin::Pin, - task::{Context, Poll}, + fmt::Debug, + io::{Cursor, Read}, }; use thiserror::Error; use tokio::io::{AsyncRead, AsyncReadExt}; +use tokio_util::codec::{BytesCodec, FramedRead}; #[derive(Error, Debug)] pub enum ReadPacketError { @@ -28,18 +27,28 @@ pub enum ReadPacketError { UnknownPacketId { state_name: String, id: u32 }, #[error("Couldn't read packet id")] ReadPacketId { source: BufReadError }, - #[error("Couldn't decompress packet")] + #[error(transparent)] Decompress { #[from] + #[backtrace] source: DecompressionError, }, - #[error("Frame splitter error")] + #[error(transparent)] FrameSplitter { #[from] + #[backtrace] source: FrameSplitterError, }, #[error("Leftover data after reading packet {packet_name}: {data:?}")] LeftoverData { data: Vec<u8>, packet_name: String }, + #[error(transparent)] + IoError { + #[from] + #[backtrace] + source: std::io::Error, + }, + #[error("Connection closed")] + ConnectionClosed, } #[derive(Error, Debug)] @@ -52,6 +61,7 @@ pub enum FrameSplitterError { #[error("Io error")] Io { #[from] + #[backtrace] source: std::io::Error, }, #[error("Packet is longer than {max} bytes (is {size})")] @@ -77,9 +87,9 @@ fn parse_frame(buffer: &mut BytesMut) -> Result<BytesMut, FrameSplitterError> { }, }; - if length > buffer_copy.get_ref().len() { + if length > buffer_copy.remaining() { return Err(FrameSplitterError::BadLength { - max: buffer_copy.get_ref().len(), + max: buffer_copy.remaining(), size: length, }); } @@ -88,49 +98,33 @@ fn parse_frame(buffer: &mut BytesMut) -> Result<BytesMut, FrameSplitterError> { // from the real buffer now // the length of the varint that says the length of the whole packet - let varint_length = buffer.len() - buffer_copy.remaining(); - let _ = buffer.split_to(varint_length); + let varint_length = buffer.remaining() - buffer_copy.remaining(); + + buffer.advance(varint_length); let data = buffer.split_to(length); Ok(data) } -async fn frame_splitter<'a, R: ?Sized + Sized>( - stream: &mut R, - buffer: &'a mut BytesMut, -) -> Result<Vec<u8>, FrameSplitterError> -where - R: AsyncRead + std::marker::Unpin + std::marker::Send, -{ +fn frame_splitter<'a>(buffer: &'a mut BytesMut) -> Result<Option<Vec<u8>>, FrameSplitterError> { // https://tokio.rs/tokio/tutorial/framing - loop { - let read_frame = parse_frame(buffer); - match read_frame { - Ok(frame) => return Ok(frame.to_vec()), - Err(err) => match err { - FrameSplitterError::BadLength { .. } | FrameSplitterError::Io { .. } => { - // we probably just haven't read enough yet - } - _ => return Err(err), - }, - } - - let read_buf: usize = AsyncReadExt::read_buf(stream, buffer).await?; - if 0 == read_buf { - // The remote closed the connection. For this to be - // a clean shutdown, there should be no data in the - // read buffer. If there is, this means that the - // peer closed the socket while sending a frame. - if buffer.as_ref().is_empty() { - return Err(FrameSplitterError::ConnectionClosed); - } else { - return Err(FrameSplitterError::ConnectionReset); + let read_frame = parse_frame(buffer); + match read_frame { + Ok(frame) => return Ok(Some(frame.to_vec())), + Err(err) => match err { + FrameSplitterError::BadLength { .. } | FrameSplitterError::Io { .. } => { + // we probably just haven't read enough yet } - } + _ => return Err(err), + }, } + + Ok(None) } -fn packet_decoder<P: ProtocolPacket>(stream: &mut Cursor<&[u8]>) -> Result<P, ReadPacketError> { +fn packet_decoder<P: ProtocolPacket + Debug>( + stream: &mut Cursor<&[u8]>, +) -> Result<P, ReadPacketError> { // Packet ID let packet_id = u32::var_read_from(stream).map_err(|e| ReadPacketError::ReadPacketId { source: e })?; @@ -152,6 +146,7 @@ pub enum DecompressionError { #[error("Io error")] Io { #[from] + #[backtrace] source: std::io::Error, }, #[error("Badly compressed packet - size of {size} is below server threshold of {threshold}")] @@ -197,42 +192,7 @@ fn compression_decoder( Ok(decoded_buf) } -struct EncryptedStream<'a, R> -where - R: AsyncRead + std::marker::Unpin + std::marker::Send, -{ - cipher: Cell<&'a mut Option<Aes128CfbDec>>, - stream: &'a mut Pin<&'a mut R>, -} - -impl<R> AsyncRead for EncryptedStream<'_, R> -where - R: AsyncRead + Unpin + Send, -{ - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut tokio::io::ReadBuf<'_>, - ) -> Poll<std::io::Result<()>> { - // i hate this - let polled = self.as_mut().stream.as_mut().poll_read(cx, buf); - match polled { - Poll::Ready(r) => { - // if we don't check for the remaining then we decrypt big packets incorrectly - // (but only on linux and release mode for some reason LMAO) - if buf.remaining() == 0 { - if let Some(cipher) = self.as_mut().cipher.get_mut() { - azalea_crypto::decrypt_packet(cipher, buf.filled_mut()); - } - } - Poll::Ready(r) - } - Poll::Pending => Poll::Pending, - } - } -} - -pub async fn read_packet<'a, P: ProtocolPacket, R>( +pub async fn read_packet<'a, P: ProtocolPacket + Debug, R>( stream: &'a mut R, buffer: &mut BytesMut, compression_threshold: Option<u32>, @@ -241,13 +201,29 @@ pub async fn read_packet<'a, P: ProtocolPacket, R>( where R: AsyncRead + std::marker::Unpin + std::marker::Send + std::marker::Sync, { - // if we were given a cipher, decrypt the packet - let mut encrypted_stream = EncryptedStream { - cipher: Cell::new(cipher), - stream: &mut Pin::new(stream), - }; + let mut framed = FramedRead::new(stream, BytesCodec::new()); + let mut buf = loop { + if let Some(buf) = frame_splitter(buffer)? { + // we got a full packet!! + break buf; + } else { + // no full packet yet :( keep reading + }; + + // if we were given a cipher, decrypt the packet + if let Some(message) = framed.next().await { + let mut bytes = message.unwrap(); + println!("bytes: {:?}", bytes.len()); - let mut buf = frame_splitter(&mut encrypted_stream, buffer).await?; + if let Some(cipher) = cipher { + azalea_crypto::decrypt_packet(cipher, &mut bytes); + } + + buffer.extend_from_slice(&bytes); + } else { + return Err(ReadPacketError::ConnectionClosed); + }; + }; if let Some(compression_threshold) = compression_threshold { buf = compression_decoder(&mut Cursor::new(&buf[..]), compression_threshold)?; diff --git a/azalea-protocol/src/write.rs b/azalea-protocol/src/write.rs index b2ae2810..a04979a5 100755 --- a/azalea-protocol/src/write.rs +++ b/azalea-protocol/src/write.rs @@ -69,7 +69,7 @@ async fn compression_encoder( } pub async fn write_packet<P, W>( - packet: P, + packet: &P, stream: &mut W, compression_threshold: Option<u32>, cipher: &mut Option<Aes128CfbEnc>, @@ -78,7 +78,7 @@ where P: ProtocolPacket + Debug, W: AsyncWrite + Unpin + Send, { - let mut buf = packet_encoder(&packet).unwrap(); + let mut buf = packet_encoder(packet).unwrap(); if let Some(threshold) = compression_threshold { buf = compression_encoder(&buf, threshold).await.unwrap(); } |
