diff options
Diffstat (limited to 'azalea-protocol/src/write.rs')
| -rw-r--r-- | azalea-protocol/src/write.rs | 84 |
1 files changed, 59 insertions, 25 deletions
diff --git a/azalea-protocol/src/write.rs b/azalea-protocol/src/write.rs index bf9fd0aa..4ae9f1c1 100644 --- a/azalea-protocol/src/write.rs +++ b/azalea-protocol/src/write.rs @@ -1,31 +1,65 @@ -use tokio::{io::AsyncWriteExt, net::TcpStream}; +use std::io::Read; -use crate::{mc_buf::Writable, packets::ProtocolPacket}; +use crate::{mc_buf::Writable, packets::ProtocolPacket, read::MAXIMUM_UNCOMPRESSED_LENGTH}; +use async_compression::tokio::bufread::ZlibEncoder; +use tokio::{ + io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}, + net::TcpStream, +}; -pub async fn write_packet(packet: impl ProtocolPacket, stream: &mut TcpStream) { - // TODO: implement compression - - // packet structure: - // length (varint) + id (varint) + data - - // write the packet id - let mut id_and_data_buf = vec![]; - id_and_data_buf - .write_varint(packet.id() as i32) - .expect("Writing packet id failed"); - packet.write(&mut id_and_data_buf); +fn frame_prepender(data: &mut Vec<u8>) -> Result<Vec<u8>, String> { + let mut buf = Vec::new(); + buf.write_varint(data.len() as i32) + .map_err(|e| e.to_string())?; + buf.append(data); + Ok(buf) +} - // write the packet data +fn packet_encoder<P: ProtocolPacket + std::fmt::Debug>(packet: &P) -> Result<Vec<u8>, String> { + let mut buf = Vec::new(); + buf.write_varint(packet.id() as i32) + .map_err(|e| e.to_string())?; + packet.write(&mut buf); + if buf.len() > MAXIMUM_UNCOMPRESSED_LENGTH as usize { + return Err(format!( + "Packet too big (is {} bytes, should be less than {}): {:?}", + buf.len(), + MAXIMUM_UNCOMPRESSED_LENGTH, + packet + )); + } + Ok(buf) +} - // make a new buffer that has the length at the beginning - // and id+data at the end - let mut complete_buf: Vec<u8> = Vec::new(); - complete_buf - .write_varint(id_and_data_buf.len() as i32) - .expect("Writing packet length failed"); - complete_buf.append(&mut id_and_data_buf); +async fn compression_encoder(data: &[u8], compression_threshold: u32) -> Result<Vec<u8>, String> { + let n = data.len(); + // if it's less than the compression threshold, don't compress + if n < compression_threshold as usize { + let mut buf = Vec::new(); + buf.write_varint(0).map_err(|e| e.to_string())?; + buf.write_all(data).await.map_err(|e| e.to_string())?; + Ok(buf) + } else { + // otherwise, compress + let mut deflater = ZlibEncoder::new(data); + // write deflated data to buf + let mut buf = Vec::new(); + deflater + .read_to_end(&mut buf) + .await + .map_err(|e| e.to_string())?; + Ok(buf) + } +} - // finally, write and flush to the stream - stream.write_all(&complete_buf).await.unwrap(); - stream.flush().await.unwrap(); +pub async fn write_packet<P>(packet: P, stream: &mut TcpStream, compression_threshold: Option<u32>) +where + P: ProtocolPacket + std::fmt::Debug, +{ + let mut buf = packet_encoder(&packet).unwrap(); + if let Some(threshold) = compression_threshold { + buf = compression_encoder(&buf, threshold).await.unwrap(); + } + buf = frame_prepender(&mut buf).unwrap(); + stream.write_all(&buf).await.unwrap(); } |
