aboutsummaryrefslogtreecommitdiff
path: root/codegen/lib
diff options
context:
space:
mode:
authormat <27899617+mat-1@users.noreply.github.com>2022-06-08 23:37:54 +0000
committerGitHub <noreply@github.com>2022-06-08 23:37:54 +0000
commit601637bd48fcba826da01725430268f706181449 (patch)
tree5b58723b931450d358d7e4387d87cc8e8b9166b2 /codegen/lib
parentea7249fb77a8e07d232600081c9c3df5f698d70f (diff)
parentfb1d419a3d4207a293a1ad6001253192f1b4d12f (diff)
downloadazalea-drasl-601637bd48fcba826da01725430268f706181449.tar.xz
Merge pull request #7 from mat-1/1.19
1.19
Diffstat (limited to 'codegen/lib')
-rw-r--r--codegen/lib/code/packet.py246
-rw-r--r--codegen/lib/code/utils.py75
-rw-r--r--codegen/lib/code/version.py59
-rw-r--r--codegen/lib/download.py90
-rw-r--r--codegen/lib/mappings.py60
-rw-r--r--codegen/lib/utils.py46
6 files changed, 576 insertions, 0 deletions
diff --git a/codegen/lib/code/packet.py b/codegen/lib/code/packet.py
new file mode 100644
index 00000000..36e0ba0c
--- /dev/null
+++ b/codegen/lib/code/packet.py
@@ -0,0 +1,246 @@
+from .utils import burger_type_to_rust_type, write_packet_file
+from ..utils import padded_hex, to_snake_case, to_camel_case
+from ..mappings import Mappings
+import os
+
+
+def make_packet_mod_rs_line(packet_id: int, packet_class_name: str):
+ return f' {padded_hex(packet_id)}: {to_snake_case(packet_class_name)}::{to_camel_case(packet_class_name)},'
+
+
+def fix_state(state: str):
+ return {'PLAY': 'game'}.get(state, state.lower())
+
+
+def generate_packet(burger_packets, mappings: Mappings, target_packet_id, target_packet_direction, target_packet_state):
+ for packet in burger_packets.values():
+ if packet['id'] != target_packet_id:
+ continue
+
+ direction = packet['direction'].lower() # serverbound or clientbound
+ state = fix_state(packet['state'])
+
+ if state != target_packet_state or direction != target_packet_direction:
+ continue
+
+ generated_packet_code = []
+ uses = set()
+ generated_packet_code.append(
+ f'#[derive(Clone, Debug, McBuf, {to_camel_case(state)}Packet)]')
+ uses.add(f'packet_macros::{{{to_camel_case(state)}Packet, McBuf}}')
+
+ obfuscated_class_name = packet['class'].split('.')[0].split('$')[0]
+ class_name = mappings.get_class(
+ obfuscated_class_name).split('.')[-1].split('$')[0]
+
+ generated_packet_code.append(
+ f'pub struct {to_camel_case(class_name)} {{')
+
+ for instruction in packet.get('instructions', []):
+ if instruction['operation'] == 'write':
+ obfuscated_field_name = instruction['field']
+ if '.' in obfuscated_field_name or ' ' in obfuscated_field_name or '(' in obfuscated_field_name:
+ generated_packet_code.append(f'// TODO: {instruction}')
+ continue
+ field_name = mappings.get_field(
+ obfuscated_class_name, obfuscated_field_name)
+ if not field_name:
+ generated_packet_code.append(
+ f'// TODO: unknown field {instruction}')
+ continue
+
+ field_type = instruction['type']
+ field_type_rs, is_var, instruction_uses = burger_type_to_rust_type(
+ field_type)
+ if is_var:
+ generated_packet_code.append('#[var]')
+ generated_packet_code.append(
+ f'pub {to_snake_case(field_name)}: {field_type_rs},')
+ uses.update(instruction_uses)
+ else:
+ generated_packet_code.append(f'// TODO: {instruction}')
+ continue
+
+ generated_packet_code.append('}')
+
+ if uses:
+ # empty line before the `use` statements
+ generated_packet_code.insert(0, '')
+ for use in uses:
+ generated_packet_code.insert(0, f'use {use};')
+
+ print(generated_packet_code)
+ write_packet_file(state, to_snake_case(class_name),
+ '\n'.join(generated_packet_code))
+ print()
+
+ mod_rs_dir = f'../azalea-protocol/src/packets/{state}/mod.rs'
+ with open(mod_rs_dir, 'r') as f:
+ mod_rs = f.read().splitlines()
+
+ pub_mod_line = f'pub mod {to_snake_case(class_name)};'
+ if pub_mod_line not in mod_rs:
+ mod_rs.insert(0, pub_mod_line)
+ packet_mod_rs_line = make_packet_mod_rs_line(
+ packet['id'], class_name)
+
+ in_serverbound = False
+ in_clientbound = False
+ for i, line in enumerate(mod_rs):
+ if line.strip() == 'Serverbound => {':
+ in_serverbound = True
+ continue
+ elif line.strip() == 'Clientbound => {':
+ in_clientbound = True
+ continue
+ elif line.strip() in ('}', '},'):
+ if (in_serverbound and direction == 'serverbound') or (in_clientbound and direction == 'clientbound'):
+ mod_rs.insert(i, packet_mod_rs_line)
+ break
+ in_serverbound = in_clientbound = False
+ continue
+
+ if line.strip() == '' or line.strip().startswith('//') or (not in_serverbound and direction == 'serverbound') or (not in_clientbound and direction == 'clientbound'):
+ continue
+
+ line_packet_id_hex = line.strip().split(':')[0]
+ assert line_packet_id_hex.startswith('0x')
+ line_packet_id = int(line_packet_id_hex[2:], 16)
+ if line_packet_id > packet['id']:
+ mod_rs.insert(i, packet_mod_rs_line)
+ break
+
+ with open(mod_rs_dir, 'w') as f:
+ f.write('\n'.join(mod_rs))
+
+
+def set_packets(packet_ids: list[int], packet_class_names: list[str], direction: str, state: str):
+ assert len(packet_ids) == len(packet_class_names)
+
+ # sort the packets by id
+ packet_ids, packet_class_names = [list(x) for x in zip(
+ *sorted(zip(packet_ids, packet_class_names), key=lambda pair: pair[0]))] # type: ignore
+
+ mod_rs_dir = f'../azalea-protocol/src/packets/{state}/mod.rs'
+ with open(mod_rs_dir, 'r') as f:
+ mod_rs = f.read().splitlines()
+ new_mod_rs = []
+
+ required_modules = []
+
+ ignore_lines = False
+
+ for line in mod_rs:
+ if line.strip() == 'Serverbound => {':
+ new_mod_rs.append(line)
+ if direction == 'serverbound':
+ ignore_lines = True
+ for packet_id, packet_class_name in zip(packet_ids, packet_class_names):
+ new_mod_rs.append(
+ make_packet_mod_rs_line(packet_id, packet_class_name)
+ )
+ required_modules.append(packet_class_name)
+ else:
+ ignore_lines = False
+ continue
+ elif line.strip() == 'Clientbound => {':
+ new_mod_rs.append(line)
+ if direction == 'clientbound':
+ ignore_lines = True
+ for packet_id, packet_class_name in zip(packet_ids, packet_class_names):
+ new_mod_rs.append(
+ make_packet_mod_rs_line(packet_id, packet_class_name)
+ )
+ else:
+ ignore_lines = False
+ continue
+ elif line.strip() in ('}', '},'):
+ ignore_lines = False
+ elif line.strip().startswith('pub mod '):
+ continue
+
+ if not ignore_lines:
+ new_mod_rs.append(line)
+ # 0x00: clientbound_status_response_packet::ClientboundStatusResponsePacket,
+ if line.strip().startswith('0x'):
+ required_modules.append(
+ line.strip().split(':')[1].split('::')[0].strip())
+
+ for i, required_module in enumerate(required_modules):
+ if required_module not in mod_rs:
+ new_mod_rs.insert(i, f'pub mod {required_module};')
+
+ with open(mod_rs_dir, 'w') as f:
+ f.write('\n'.join(new_mod_rs))
+
+
+def get_packets(direction: str, state: str):
+ mod_rs_dir = f'../azalea-protocol/src/packets/{state}/mod.rs'
+ with open(mod_rs_dir, 'r') as f:
+ mod_rs = f.read().splitlines()
+
+ in_serverbound = False
+ in_clientbound = False
+
+ packet_ids: list[int] = []
+ packet_class_names: list[str] = []
+
+ for line in mod_rs:
+ if line.strip() == 'Serverbound => {':
+ in_serverbound = True
+ continue
+ elif line.strip() == 'Clientbound => {':
+ in_clientbound = True
+ continue
+ elif line.strip() in ('}', '},'):
+ if (in_serverbound and direction == 'serverbound') or (in_clientbound and direction == 'clientbound'):
+ break
+ in_serverbound = in_clientbound = False
+ continue
+
+ if line.strip() == '' or line.strip().startswith('//') or (not in_serverbound and direction == 'serverbound') or (not in_clientbound and direction == 'clientbound'):
+ continue
+
+ line_packet_id_hex = line.strip().split(':')[0]
+ assert line_packet_id_hex.startswith('0x')
+ line_packet_id = int(line_packet_id_hex[2:], 16)
+ packet_ids.append(line_packet_id)
+
+ packet_class_name = line.strip().split(':')[1].strip()
+ packet_class_names.append(packet_class_name)
+
+ return packet_ids, packet_class_names
+
+
+def change_packet_ids(id_map: dict[int, int], direction: str, state: str):
+ existing_packet_ids, existing_packet_class_names = get_packets(
+ direction, state)
+
+ new_packet_ids = []
+
+ for packet_id in existing_packet_ids:
+ new_packet_id = id_map.get(packet_id, packet_id)
+ new_packet_ids.append(new_packet_id)
+
+ set_packets(new_packet_ids, existing_packet_class_names, direction, state)
+
+
+def remove_packet_ids(removing_packet_ids: list[int], direction: str, state: str):
+ existing_packet_ids, existing_packet_class_names = get_packets(
+ direction, state)
+
+ new_packet_ids = []
+ new_packet_class_names = []
+
+ for packet_id, packet_class_name in zip(existing_packet_ids, existing_packet_class_names):
+ if packet_id in removing_packet_ids:
+ try:
+ os.remove(
+ f'../azalea-protocol/src/packets/{state}/{packet_class_name}.rs')
+ except:
+ pass
+ else:
+ new_packet_ids.append(packet_id)
+ new_packet_class_names.append(packet_class_name)
+
+ set_packets(new_packet_ids, new_packet_class_names, direction, state)
diff --git a/codegen/lib/code/utils.py b/codegen/lib/code/utils.py
new file mode 100644
index 00000000..28a5ef3c
--- /dev/null
+++ b/codegen/lib/code/utils.py
@@ -0,0 +1,75 @@
+
+import os
+
+# utilities specifically for codegen
+
+
+def burger_type_to_rust_type(burger_type):
+ is_var = False
+ uses = set()
+
+ if burger_type == 'byte':
+ field_type_rs = 'i8'
+ elif burger_type == 'short':
+ field_type_rs = 'i16'
+ elif burger_type == 'int':
+ field_type_rs = 'i32'
+ elif burger_type == 'long':
+ field_type_rs = 'i64'
+ elif burger_type == 'float':
+ field_type_rs = 'f32'
+ elif burger_type == 'double':
+ field_type_rs = 'f64'
+
+ elif burger_type == 'varint':
+ is_var = True
+ field_type_rs = 'i32'
+ elif burger_type == 'varlong':
+ is_var = True
+ field_type_rs = 'i64'
+
+ elif burger_type == 'boolean':
+ field_type_rs = 'bool'
+ elif burger_type == 'string':
+ field_type_rs = 'String'
+
+ elif burger_type == 'chatcomponent':
+ field_type_rs = 'Component'
+ uses.add('azalea_chat::component::Component')
+ elif burger_type == 'identifier':
+ field_type_rs = 'ResourceLocation'
+ uses.add('azalea_core::resource_location::ResourceLocation')
+ elif burger_type == 'uuid':
+ field_type_rs = 'Uuid'
+ uses.add('uuid::Uuid')
+ elif burger_type == 'position':
+ field_type_rs = 'BlockPos'
+ uses.add('azalea_core::BlockPos')
+ elif burger_type == 'nbtcompound':
+ field_type_rs = 'azalea_nbt::Tag'
+ elif burger_type == 'itemstack':
+ field_type_rs = 'Slot'
+ uses.add('azalea_core::Slot')
+ elif burger_type == 'metadata':
+ field_type_rs = 'EntityMetadata'
+ uses.add('crate::mc_buf::EntityMetadata')
+ elif burger_type == 'enum':
+ # enums are too complicated, leave those to the user
+ field_type_rs = 'todo!()'
+ elif burger_type.endswith('[]'):
+ field_type_rs, is_var, uses = burger_type_to_rust_type(
+ burger_type[:-2])
+ field_type_rs = f'Vec<{field_type_rs}>'
+ else:
+ print('Unknown field type:', burger_type)
+ exit()
+ return field_type_rs, is_var, uses
+
+
+def write_packet_file(state, packet_name_snake_case, code):
+ with open(f'../azalea-protocol/src/packets/{state}/{packet_name_snake_case}.rs', 'w') as f:
+ f.write(code)
+
+
+def fmt():
+ os.system('cd .. && cargo fmt')
diff --git a/codegen/lib/code/version.py b/codegen/lib/code/version.py
new file mode 100644
index 00000000..e131a598
--- /dev/null
+++ b/codegen/lib/code/version.py
@@ -0,0 +1,59 @@
+import re
+import os
+
+README_DIR = os.path.join(os.path.dirname(__file__), '../../../README.md')
+VERSION_REGEX = r'\*Currently supported Minecraft version: `(.*)`.\*'
+
+
+def get_version_id() -> str:
+ with open(README_DIR, 'r') as f:
+ readme_text = f.read()
+
+ version_line_match = re.search(VERSION_REGEX, readme_text)
+ if version_line_match:
+ version_id = version_line_match.group(1)
+ return version_id
+ else:
+ raise Exception('Could not find version id in README.md')
+
+
+def set_version_id(version_id: str) -> None:
+ with open(README_DIR, 'r') as f:
+ readme_text = f.read()
+
+ version_line_match = re.search(VERSION_REGEX, readme_text)
+ if version_line_match:
+ readme_text = readme_text.replace(
+ version_line_match.group(1), version_id)
+ else:
+ raise Exception('Could not find version id in README.md')
+
+ with open(README_DIR, 'w') as f:
+ f.write(readme_text)
+
+
+def get_protocol_version() -> str:
+ # azalea-protocol/src/packets/mod.rs
+ # pub const PROTOCOL_VERSION: u32 = 758;
+ with open('../azalea-protocol/src/packets/mod.rs', 'r') as f:
+ mod_rs = f.read().splitlines()
+ for line in mod_rs:
+ if line.strip().startswith('pub const PROTOCOL_VERSION'):
+ return line.strip().split(' ')[-1].strip(';')
+ raise Exception(
+ 'Could not find protocol version in azalea-protocol/src/packets/mod.rs')
+
+
+def set_protocol_version(protocol_version: str) -> None:
+ with open('../azalea-protocol/src/packets/mod.rs', 'r') as f:
+ mod_rs = f.read().splitlines()
+ for i, line in enumerate(mod_rs):
+ if line.strip().startswith('pub const PROTOCOL_VERSION'):
+ mod_rs[i] = f'pub const PROTOCOL_VERSION: u32 = {protocol_version};'
+ break
+ else:
+ raise Exception(
+ 'Could not find protocol version in azalea-protocol/src/packets/mod.rs')
+
+ with open('../azalea-protocol/src/packets/mod.rs', 'w') as f:
+ f.write('\n'.join(mod_rs))
diff --git a/codegen/lib/download.py b/codegen/lib/download.py
new file mode 100644
index 00000000..7d14a3a3
--- /dev/null
+++ b/codegen/lib/download.py
@@ -0,0 +1,90 @@
+from .mappings import Mappings
+import requests
+import json
+import os
+
+# make sure the downloads directory exists
+if not os.path.exists('downloads'):
+ os.mkdir('downloads')
+
+
+def get_burger():
+ if not os.path.exists('downloads/Burger'):
+ print('\033[92mDownloading Burger...\033[m')
+ os.system(
+ 'cd downloads && git clone https://github.com/pokechu22/Burger && cd Burger && git pull')
+
+ print('\033[92mInstalling dependencies...\033[m')
+ os.system('cd downloads/Burger && pip install six jawa')
+
+
+def get_version_manifest():
+ if not os.path.exists(f'downloads/version_manifest.json'):
+ print(
+ f'\033[92mDownloading version manifest...\033[m')
+ version_manifest_data = requests.get(
+ 'https://launchermeta.mojang.com/mc/game/version_manifest.json').json()
+ with open(f'downloads/version_manifest.json', 'w') as f:
+ json.dump(version_manifest_data, f)
+ else:
+ with open(f'downloads/version_manifest.json', 'r') as f:
+ version_manifest_data = json.load(f)
+ return version_manifest_data
+
+
+def get_version_data(version_id: str):
+ if not os.path.exists(f'downloads/{version_id}.json'):
+ version_manifest_data = get_version_manifest()
+
+ print(
+ f'\033[92mGetting data for \033[1m{version_id}..\033[m')
+ try:
+ package_url = next(
+ filter(lambda v: v['id'] == version_id, version_manifest_data['versions']))['url']
+ except StopIteration:
+ raise ValueError(
+ f'No version with id {version_id} found. Maybe delete downloads/version_manifest.json and try again?')
+ package_data = requests.get(package_url).json()
+ with open(f'downloads/{version_id}.json', 'w') as f:
+ json.dump(package_data, f)
+ else:
+ with open(f'downloads/{version_id}.json', 'r') as f:
+ package_data = json.load(f)
+ return package_data
+
+
+def get_client_jar(version_id: str):
+ if not os.path.exists(f'downloads/client-{version_id}.jar'):
+ package_data = get_version_data(version_id)
+ print('\033[92mDownloading client jar...\033[m')
+ client_jar_url = package_data['downloads']['client']['url']
+ with open(f'downloads/client-{version_id}.jar', 'wb') as f:
+ f.write(requests.get(client_jar_url).content)
+
+
+def get_burger_data_for_version(version_id: str):
+ if not os.path.exists(f'downloads/burger-{version_id}.json'):
+ get_burger()
+ get_client_jar(version_id)
+
+ os.system(
+ f'cd downloads/Burger && python munch.py ../client-{version_id}.jar --output ../burger-{version_id}.json'
+ )
+ with open(f'downloads/burger-{version_id}.json', 'r') as f:
+ return json.load(f)
+
+
+def get_mappings_for_version(version_id: str):
+ if not os.path.exists(f'downloads/mappings-{version_id}.txt'):
+ package_data = get_version_data(version_id)
+
+ client_mappings_url = package_data['downloads']['client_mappings']['url']
+
+ mappings_text = requests.get(client_mappings_url).text
+
+ with open(f'downloads/mappings-{version_id}.txt', 'w') as f:
+ f.write(mappings_text)
+ else:
+ with open(f'downloads/mappings-{version_id}.txt', 'r') as f:
+ mappings_text = f.read()
+ return Mappings.parse(mappings_text)
diff --git a/codegen/lib/mappings.py b/codegen/lib/mappings.py
new file mode 100644
index 00000000..fb3e8bda
--- /dev/null
+++ b/codegen/lib/mappings.py
@@ -0,0 +1,60 @@
+class Mappings:
+ __slots__ = ('classes', 'fields', 'methods')
+
+ def __init__(self, classes, fields, methods):
+ self.classes = classes
+ self.fields = fields
+ self.methods = methods
+
+ @staticmethod
+ def parse(mappings_txt):
+ classes = {}
+ fields = {}
+ methods = {}
+
+ current_obfuscated_class_name = None
+
+ for line in mappings_txt.splitlines():
+ if line.startswith('#') or line == '':
+ continue
+
+ if line.startswith(' '):
+ # if a line starts with 4 spaces, that means it's a method or a field
+ if '(' in line:
+ # if it has an opening parenthesis, it's a method
+ real_name_with_parameters_and_line, obfuscated_name = line.strip().split(' -> ')
+ real_name_with_parameters = real_name_with_parameters_and_line.split(
+ ':')[-1]
+
+ real_name = real_name_with_parameters.split('(')[0]
+ parameters = real_name_with_parameters.split('(')[1]
+
+ if current_obfuscated_class_name not in methods:
+ methods[current_obfuscated_class_name] = {}
+ methods[current_obfuscated_class_name][
+ f'{obfuscated_name}({parameters})'] = real_name
+ else:
+ # otherwise, it's a field
+ real_name_with_type, obfuscated_name = line.strip().split(' -> ')
+ real_name = real_name_with_type.split(' ')[1]
+
+ if current_obfuscated_class_name not in fields:
+ fields[current_obfuscated_class_name] = {}
+ fields[current_obfuscated_class_name][obfuscated_name] = real_name
+ else:
+ # otherwise it's a class
+ real_name, obfuscated_name = line.strip(':').split(' -> ')
+ current_obfuscated_class_name = obfuscated_name
+
+ classes[obfuscated_name] = real_name
+
+ return Mappings(classes, fields, methods)
+
+ def get_field(self, obfuscated_class_name, obfuscated_field_name):
+ return self.fields.get(obfuscated_class_name, {}).get(obfuscated_field_name)
+
+ def get_class(self, obfuscated_class_name):
+ return self.classes[obfuscated_class_name]
+
+ def get_method(self, obfuscated_class_name, obfuscated_method_name, obfuscated_signature):
+ return self.methods[obfuscated_class_name][f'{obfuscated_method_name}({obfuscated_signature})']
diff --git a/codegen/lib/utils.py b/codegen/lib/utils.py
new file mode 100644
index 00000000..c185c0e5
--- /dev/null
+++ b/codegen/lib/utils.py
@@ -0,0 +1,46 @@
+import re
+
+# utilities that could be used for things other than codegen
+
+
+def to_snake_case(name: str):
+ s = re.sub('([A-Z])', r'_\1', name)
+ return s.lower().strip('_')
+
+
+def to_camel_case(name: str):
+ s = re.sub('_([a-z])', lambda m: m.group(1).upper(), name)
+ return s[0].upper() + s[1:]
+
+
+def padded_hex(n: int):
+ return f'0x{n:02x}'
+
+
+class PacketIdentifier:
+ def __init__(self, packet_id: int, direction: str, state: str):
+ self.packet_id = packet_id
+ self.direction = direction
+ self.state = state
+
+ def __eq__(self, other):
+ return self.packet_id == other.packet_id and self.direction == other.direction and self.state == other.state
+
+ def __hash__(self):
+ return hash((self.packet_id, self.direction, self.state))
+
+ def __str__(self):
+ return f'{self.packet_id} {self.direction} {self.state}'
+
+ def __repr__(self):
+ return f'PacketIdentifier({self.packet_id}, {self.direction}, {self.state})'
+
+
+def group_packets(packets: list[PacketIdentifier]):
+ packet_groups: dict[tuple[str, str], list[int]] = {}
+ for packet in packets:
+ key = (packet.direction, packet.state)
+ if key not in packet_groups:
+ packet_groups[key] = []
+ packet_groups[key].append(packet.packet_id)
+ return packet_groups