diff options
Diffstat (limited to 'codegen/lib/code/packet.py')
| -rw-r--r-- | codegen/lib/code/packet.py | 246 |
1 files changed, 246 insertions, 0 deletions
diff --git a/codegen/lib/code/packet.py b/codegen/lib/code/packet.py new file mode 100644 index 00000000..36e0ba0c --- /dev/null +++ b/codegen/lib/code/packet.py @@ -0,0 +1,246 @@ +from .utils import burger_type_to_rust_type, write_packet_file +from ..utils import padded_hex, to_snake_case, to_camel_case +from ..mappings import Mappings +import os + + +def make_packet_mod_rs_line(packet_id: int, packet_class_name: str): + return f' {padded_hex(packet_id)}: {to_snake_case(packet_class_name)}::{to_camel_case(packet_class_name)},' + + +def fix_state(state: str): + return {'PLAY': 'game'}.get(state, state.lower()) + + +def generate_packet(burger_packets, mappings: Mappings, target_packet_id, target_packet_direction, target_packet_state): + for packet in burger_packets.values(): + if packet['id'] != target_packet_id: + continue + + direction = packet['direction'].lower() # serverbound or clientbound + state = fix_state(packet['state']) + + if state != target_packet_state or direction != target_packet_direction: + continue + + generated_packet_code = [] + uses = set() + generated_packet_code.append( + f'#[derive(Clone, Debug, McBuf, {to_camel_case(state)}Packet)]') + uses.add(f'packet_macros::{{{to_camel_case(state)}Packet, McBuf}}') + + obfuscated_class_name = packet['class'].split('.')[0].split('$')[0] + class_name = mappings.get_class( + obfuscated_class_name).split('.')[-1].split('$')[0] + + generated_packet_code.append( + f'pub struct {to_camel_case(class_name)} {{') + + for instruction in packet.get('instructions', []): + if instruction['operation'] == 'write': + obfuscated_field_name = instruction['field'] + if '.' in obfuscated_field_name or ' ' in obfuscated_field_name or '(' in obfuscated_field_name: + generated_packet_code.append(f'// TODO: {instruction}') + continue + field_name = mappings.get_field( + obfuscated_class_name, obfuscated_field_name) + if not field_name: + generated_packet_code.append( + f'// TODO: unknown field {instruction}') + continue + + field_type = instruction['type'] + field_type_rs, is_var, instruction_uses = burger_type_to_rust_type( + field_type) + if is_var: + generated_packet_code.append('#[var]') + generated_packet_code.append( + f'pub {to_snake_case(field_name)}: {field_type_rs},') + uses.update(instruction_uses) + else: + generated_packet_code.append(f'// TODO: {instruction}') + continue + + generated_packet_code.append('}') + + if uses: + # empty line before the `use` statements + generated_packet_code.insert(0, '') + for use in uses: + generated_packet_code.insert(0, f'use {use};') + + print(generated_packet_code) + write_packet_file(state, to_snake_case(class_name), + '\n'.join(generated_packet_code)) + print() + + mod_rs_dir = f'../azalea-protocol/src/packets/{state}/mod.rs' + with open(mod_rs_dir, 'r') as f: + mod_rs = f.read().splitlines() + + pub_mod_line = f'pub mod {to_snake_case(class_name)};' + if pub_mod_line not in mod_rs: + mod_rs.insert(0, pub_mod_line) + packet_mod_rs_line = make_packet_mod_rs_line( + packet['id'], class_name) + + in_serverbound = False + in_clientbound = False + for i, line in enumerate(mod_rs): + if line.strip() == 'Serverbound => {': + in_serverbound = True + continue + elif line.strip() == 'Clientbound => {': + in_clientbound = True + continue + elif line.strip() in ('}', '},'): + if (in_serverbound and direction == 'serverbound') or (in_clientbound and direction == 'clientbound'): + mod_rs.insert(i, packet_mod_rs_line) + break + in_serverbound = in_clientbound = False + continue + + if line.strip() == '' or line.strip().startswith('//') or (not in_serverbound and direction == 'serverbound') or (not in_clientbound and direction == 'clientbound'): + continue + + line_packet_id_hex = line.strip().split(':')[0] + assert line_packet_id_hex.startswith('0x') + line_packet_id = int(line_packet_id_hex[2:], 16) + if line_packet_id > packet['id']: + mod_rs.insert(i, packet_mod_rs_line) + break + + with open(mod_rs_dir, 'w') as f: + f.write('\n'.join(mod_rs)) + + +def set_packets(packet_ids: list[int], packet_class_names: list[str], direction: str, state: str): + assert len(packet_ids) == len(packet_class_names) + + # sort the packets by id + packet_ids, packet_class_names = [list(x) for x in zip( + *sorted(zip(packet_ids, packet_class_names), key=lambda pair: pair[0]))] # type: ignore + + mod_rs_dir = f'../azalea-protocol/src/packets/{state}/mod.rs' + with open(mod_rs_dir, 'r') as f: + mod_rs = f.read().splitlines() + new_mod_rs = [] + + required_modules = [] + + ignore_lines = False + + for line in mod_rs: + if line.strip() == 'Serverbound => {': + new_mod_rs.append(line) + if direction == 'serverbound': + ignore_lines = True + for packet_id, packet_class_name in zip(packet_ids, packet_class_names): + new_mod_rs.append( + make_packet_mod_rs_line(packet_id, packet_class_name) + ) + required_modules.append(packet_class_name) + else: + ignore_lines = False + continue + elif line.strip() == 'Clientbound => {': + new_mod_rs.append(line) + if direction == 'clientbound': + ignore_lines = True + for packet_id, packet_class_name in zip(packet_ids, packet_class_names): + new_mod_rs.append( + make_packet_mod_rs_line(packet_id, packet_class_name) + ) + else: + ignore_lines = False + continue + elif line.strip() in ('}', '},'): + ignore_lines = False + elif line.strip().startswith('pub mod '): + continue + + if not ignore_lines: + new_mod_rs.append(line) + # 0x00: clientbound_status_response_packet::ClientboundStatusResponsePacket, + if line.strip().startswith('0x'): + required_modules.append( + line.strip().split(':')[1].split('::')[0].strip()) + + for i, required_module in enumerate(required_modules): + if required_module not in mod_rs: + new_mod_rs.insert(i, f'pub mod {required_module};') + + with open(mod_rs_dir, 'w') as f: + f.write('\n'.join(new_mod_rs)) + + +def get_packets(direction: str, state: str): + mod_rs_dir = f'../azalea-protocol/src/packets/{state}/mod.rs' + with open(mod_rs_dir, 'r') as f: + mod_rs = f.read().splitlines() + + in_serverbound = False + in_clientbound = False + + packet_ids: list[int] = [] + packet_class_names: list[str] = [] + + for line in mod_rs: + if line.strip() == 'Serverbound => {': + in_serverbound = True + continue + elif line.strip() == 'Clientbound => {': + in_clientbound = True + continue + elif line.strip() in ('}', '},'): + if (in_serverbound and direction == 'serverbound') or (in_clientbound and direction == 'clientbound'): + break + in_serverbound = in_clientbound = False + continue + + if line.strip() == '' or line.strip().startswith('//') or (not in_serverbound and direction == 'serverbound') or (not in_clientbound and direction == 'clientbound'): + continue + + line_packet_id_hex = line.strip().split(':')[0] + assert line_packet_id_hex.startswith('0x') + line_packet_id = int(line_packet_id_hex[2:], 16) + packet_ids.append(line_packet_id) + + packet_class_name = line.strip().split(':')[1].strip() + packet_class_names.append(packet_class_name) + + return packet_ids, packet_class_names + + +def change_packet_ids(id_map: dict[int, int], direction: str, state: str): + existing_packet_ids, existing_packet_class_names = get_packets( + direction, state) + + new_packet_ids = [] + + for packet_id in existing_packet_ids: + new_packet_id = id_map.get(packet_id, packet_id) + new_packet_ids.append(new_packet_id) + + set_packets(new_packet_ids, existing_packet_class_names, direction, state) + + +def remove_packet_ids(removing_packet_ids: list[int], direction: str, state: str): + existing_packet_ids, existing_packet_class_names = get_packets( + direction, state) + + new_packet_ids = [] + new_packet_class_names = [] + + for packet_id, packet_class_name in zip(existing_packet_ids, existing_packet_class_names): + if packet_id in removing_packet_ids: + try: + os.remove( + f'../azalea-protocol/src/packets/{state}/{packet_class_name}.rs') + except: + pass + else: + new_packet_ids.append(packet_id) + new_packet_class_names.append(packet_class_name) + + set_packets(new_packet_ids, new_packet_class_names, direction, state) |
