aboutsummaryrefslogtreecommitdiff
path: root/azalea-protocol/packet-macros/src
diff options
context:
space:
mode:
authormat <github@matdoes.dev>2022-01-01 23:55:19 -0600
committermat <github@matdoes.dev>2022-01-01 23:55:19 -0600
commita1afbb6031527c1db5831fc8e916bc0ecce633b4 (patch)
tree56fbccf52645cc2eefd231506ffe8ed71018dc01 /azalea-protocol/packet-macros/src
parente81b85dd5bdd6d42ee84f24ed4a142f6141f170e (diff)
downloadazalea-drasl-a1afbb6031527c1db5831fc8e916bc0ecce633b4.tar.xz
start adding packet macros
Diffstat (limited to 'azalea-protocol/packet-macros/src')
-rw-r--r--azalea-protocol/packet-macros/src/lib.rs164
1 files changed, 164 insertions, 0 deletions
diff --git a/azalea-protocol/packet-macros/src/lib.rs b/azalea-protocol/packet-macros/src/lib.rs
new file mode 100644
index 00000000..470ac4c1
--- /dev/null
+++ b/azalea-protocol/packet-macros/src/lib.rs
@@ -0,0 +1,164 @@
+use quote::{quote, quote_spanned, ToTokens};
+use syn::{self, parse_macro_input, spanned::Spanned, DeriveInput, FieldsNamed};
+
+fn as_packet_derive(
+ input: proc_macro::TokenStream,
+ state: proc_macro2::TokenStream,
+) -> proc_macro::TokenStream {
+ let DeriveInput { ident, data, .. } = parse_macro_input!(input);
+
+ let fields = match data {
+ syn::Data::Struct(syn::DataStruct { fields, .. }) => fields,
+ _ => panic!("#[derive(*Packet)] can only be used on structs"),
+ };
+ let FieldsNamed { named, .. } = match fields {
+ syn::Fields::Named(f) => f,
+ _ => panic!("#[derive(*Packet)] can only be used on structs with named fields"),
+ };
+
+ let write_fields = named
+ .iter()
+ .map(|f| {
+ let field_name = &f.ident;
+ let field_type = &f.ty;
+ // do a different buf.write_* for each field depending on the type
+ // if it's a string, use buf.write_string
+ match field_type {
+ syn::Type::Path(syn::TypePath { path, .. }) => {
+ if path.is_ident("String") {
+ quote! { buf.write_utf(&self.#field_name)?; }
+ } else if path.is_ident("ResourceLocation") {
+ quote! { buf.write_resource_location(&self.#field_name)?; }
+ // i don't know how to do this in a way that isn't terrible
+ } else if path.to_token_stream().to_string() == "Vec < u8 >" {
+ quote! { buf.write_bytes(&self.#field_name)?; }
+ } else if path.is_ident("i32") {
+ // only treat it as a varint if it has the varint attribute
+ if f.attrs.iter().any(|attr| attr.path.is_ident("varint")) {
+ quote! { buf.write_varint(self.#field_name)?; }
+ } else {
+ quote! { buf.write_i32(self.#field_name)?; }
+ }
+ } else if path.is_ident("u32") {
+ if f.attrs.iter().any(|attr| attr.path.is_ident("varint")) {
+ quote! { buf.write_varint(self.#field_name as i32)?; }
+ } else {
+ quote! { buf.write_u32(self.#field_name)?; }
+ }
+ } else if path.is_ident("u16") {
+ quote! { buf.write_short(self.#field_name as i16)?; }
+ } else if path.is_ident("ConnectionProtocol") {
+ quote! { buf.write_varint(self.#field_name.clone() as i32)?; }
+ } else {
+ panic!(
+ "#[derive(*Packet)] doesn't know how to write {}",
+ path.to_token_stream()
+ );
+ }
+ }
+ _ => panic!(
+ "Error writing field {}: {}",
+ field_name.clone().unwrap(),
+ field_type.to_token_stream()
+ ),
+ }
+ })
+ .collect::<Vec<_>>();
+
+ let read_fields = named
+ .iter()
+ .map(|f| {
+ let field_name = &f.ident;
+ let field_type = &f.ty;
+ // do a different buf.write_* for each field depending on the type
+ // if it's a string, use buf.write_string
+ match field_type {
+ syn::Type::Path(syn::TypePath { path, .. }) => {
+ if path.is_ident("String") {
+ quote! { let #field_name = buf.read_utf().await?; }
+ } else if path.is_ident("ResourceLocation") {
+ quote! { let #field_name = buf.read_resource_location().await?; }
+ // i don't know how to do this in a way that isn't terrible
+ } else if path.to_token_stream().to_string() == "Vec < u8 >" {
+ quote! { let #field_name = buf.read_bytes().await?; }
+ } else if path.is_ident("i32") {
+ // only treat it as a varint if it has the varint attribute
+ if f.attrs.iter().any(|a| a.path.is_ident("varint")) {
+ quote! { let #field_name = buf.read_varint().await?; }
+ } else {
+ quote! { let #field_name = buf.read_i32().await?; }
+ }
+ } else if path.is_ident("u32") {
+ if f.attrs.iter().any(|a| a.path.is_ident("varint")) {
+ quote! { let #field_name = buf.read_varint().await? as u32; }
+ } else {
+ quote! { let #field_name = buf.read_u32().await?; }
+ }
+ } else if path.is_ident("u16") {
+ quote! { let #field_name = buf.read_short().await? as u16; }
+ } else if path.is_ident("ConnectionProtocol") {
+ quote! {
+ let #field_name = ConnectionProtocol::from_i32(buf.read_varint().await?)
+ .ok_or_else(|| "Invalid intention".to_string())?;
+ }
+ } else {
+ panic!(
+ "#[derive(*Packet)] doesn't know how to read {}",
+ path.to_token_stream()
+ );
+ }
+ }
+ _ => panic!(
+ "Error reading field {}: {}",
+ field_name.clone().unwrap(),
+ field_type.to_token_stream()
+ ),
+ }
+ })
+ .collect::<Vec<_>>();
+ let read_field_names = named.iter().map(|f| &f.ident).collect::<Vec<_>>();
+
+ let gen = quote! {
+ impl #ident {
+ pub fn get(self) -> #state {
+ #state::#ident(self)
+ }
+
+ pub fn write(&self, buf: &mut Vec<u8>) -> Result<(), std::io::Error> {
+ #(#write_fields)*
+ Ok(())
+ }
+
+ pub async fn read<T: tokio::io::AsyncRead + std::marker::Unpin + std::marker::Send>(
+ buf: &mut T,
+ ) -> Result<#state, String> {
+ #(#read_fields)*
+ Ok(#ident {
+ #(#read_field_names: #read_field_names),*
+ }.get())
+ }
+ }
+ };
+
+ gen.into()
+}
+
+#[proc_macro_derive(GamePacket, attributes(varint))]
+pub fn derive_game_packet(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
+ as_packet_derive(input, quote! {crate::packets::game::GamePacket})
+}
+
+#[proc_macro_derive(HandshakePacket, attributes(varint))]
+pub fn derive_handshake_packet(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
+ as_packet_derive(input, quote! {crate::packets::handshake::HandshakePacket})
+}
+
+#[proc_macro_derive(LoginPacket, attributes(varint))]
+pub fn derive_login_packet(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
+ as_packet_derive(input, quote! {crate::packets::login::LoginPacket})
+}
+
+#[proc_macro_derive(StatusPacket, attributes(varint))]
+pub fn derive_status_packet(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
+ as_packet_derive(input, quote! {crate::packets::status::StatusPacket})
+}