From e46577a214ae30159d14c128c45488e3772c8f84 Mon Sep 17 00:00:00 2001 From: mat <27899617+mat-1@users.noreply.github.com> Date: Mon, 19 Sep 2022 21:21:46 -0500 Subject: Fix connection writer being locked (#23) * Split connection struct in az-protocol * az-client uses split conns * fix errors * add a convenience write_packet fn to az-client --- azalea-protocol/src/connect.rs | 99 +++++++++++++++++++++++++++++++----------- azalea-protocol/src/read.rs | 5 +-- 2 files changed, 75 insertions(+), 29 deletions(-) (limited to 'azalea-protocol') diff --git a/azalea-protocol/src/connect.rs b/azalea-protocol/src/connect.rs index dbca4214..3fdcecd3 100755 --- a/azalea-protocol/src/connect.rs +++ b/azalea-protocol/src/connect.rs @@ -12,37 +12,50 @@ use azalea_crypto::{Aes128CfbDec, Aes128CfbEnc}; use std::fmt::Debug; use std::marker::PhantomData; use thiserror::Error; +use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; use tokio::net::TcpStream; -pub struct Connection { - /// The buffered writer - pub stream: TcpStream, +pub struct ReadConnection { + pub read_stream: OwnedReadHalf, pub compression_threshold: Option, - pub enc_cipher: Option, pub dec_cipher: Option, _reading: PhantomData, +} + +pub struct WriteConnection { + pub write_stream: OwnedWriteHalf, + pub compression_threshold: Option, + pub enc_cipher: Option, _writing: PhantomData, } -impl Connection +pub struct Connection { + pub reader: ReadConnection, + pub writer: WriteConnection, +} + +impl ReadConnection where R: ProtocolPacket + Debug, - W: ProtocolPacket + Debug, { pub async fn read(&mut self) -> Result { read_packet::( - &mut self.stream, + &mut self.read_stream, self.compression_threshold, &mut self.dec_cipher, ) .await } - +} +impl WriteConnection +where + W: ProtocolPacket + Debug, +{ /// Write a packet to the server pub async fn write(&mut self, packet: W) -> std::io::Result<()> { write_packet( packet, - &mut self.stream, + &mut self.write_stream, self.compression_threshold, &mut self.enc_cipher, ) @@ -50,6 +63,26 @@ where } } +impl Connection +where + R: ProtocolPacket + Debug, + W: ProtocolPacket + Debug, +{ + pub async fn read(&mut self) -> Result { + self.reader.read().await + } + + /// Write a packet to the server + pub async fn write(&mut self, packet: W) -> std::io::Result<()> { + self.writer.write(packet).await + } + + /// Split the reader and writer into two objects. This doesn't allocate. + pub fn into_split(self) -> (ReadConnection, WriteConnection) { + (self.reader, self.writer) + } +} + #[derive(Error, Debug)] pub enum ConnectionError { #[error("{0}")] @@ -66,13 +99,21 @@ impl Connection { // enable tcp_nodelay stream.set_nodelay(true)?; + let (read_stream, write_stream) = stream.into_split(); + Ok(Connection { - stream, - compression_threshold: None, - enc_cipher: None, - dec_cipher: None, - _reading: PhantomData, - _writing: PhantomData, + reader: ReadConnection { + read_stream, + compression_threshold: None, + dec_cipher: None, + _reading: PhantomData, + }, + writer: WriteConnection { + write_stream, + compression_threshold: None, + enc_cipher: None, + _writing: PhantomData, + }, }) } @@ -89,17 +130,19 @@ impl Connection { pub fn set_compression_threshold(&mut self, threshold: i32) { // if you pass a threshold of less than 0, compression is disabled if threshold >= 0 { - self.compression_threshold = Some(threshold as u32); + self.reader.compression_threshold = Some(threshold as u32); + self.writer.compression_threshold = Some(threshold as u32); } else { - self.compression_threshold = None; + self.reader.compression_threshold = None; + self.writer.compression_threshold = None; } } pub fn set_encryption_key(&mut self, key: [u8; 16]) { // minecraft has a cipher decoder and encoder, i don't think it matters though? let (enc_cipher, dec_cipher) = azalea_crypto::create_cipher(&key); - self.enc_cipher = Some(enc_cipher); - self.dec_cipher = Some(dec_cipher); + self.writer.enc_cipher = Some(enc_cipher); + self.reader.dec_cipher = Some(dec_cipher); } pub fn game(self) -> Connection { @@ -120,12 +163,18 @@ where W2: ProtocolPacket + Debug, { Connection { - stream: connection.stream, - compression_threshold: connection.compression_threshold, - enc_cipher: connection.enc_cipher, - dec_cipher: connection.dec_cipher, - _reading: PhantomData, - _writing: PhantomData, + reader: ReadConnection { + read_stream: connection.reader.read_stream, + compression_threshold: connection.reader.compression_threshold, + dec_cipher: connection.reader.dec_cipher, + _reading: PhantomData, + }, + writer: WriteConnection { + compression_threshold: connection.writer.compression_threshold, + write_stream: connection.writer.write_stream, + enc_cipher: connection.writer.enc_cipher, + _writing: PhantomData, + }, } } } diff --git a/azalea-protocol/src/read.rs b/azalea-protocol/src/read.rs index 313fb412..3ff24f72 100755 --- a/azalea-protocol/src/read.rs +++ b/azalea-protocol/src/read.rs @@ -221,10 +221,7 @@ where #[cfg(test)] mod tests { use super::*; - use crate::packets::{ - game::{clientbound_player_chat_packet::ChatType, ClientboundGamePacket}, - handshake::ClientboundHandshakePacket, - }; + use crate::packets::game::{clientbound_player_chat_packet::ChatType, ClientboundGamePacket}; use std::io::Cursor; #[tokio::test] -- cgit v1.2.3