diff options
Diffstat (limited to 'data-code-generator/packetcodegen.py')
| -rw-r--r-- | data-code-generator/packetcodegen.py | 42 |
1 files changed, 39 insertions, 3 deletions
diff --git a/data-code-generator/packetcodegen.py b/data-code-generator/packetcodegen.py index ccbb3845..4c59b72b 100644 --- a/data-code-generator/packetcodegen.py +++ b/data-code-generator/packetcodegen.py @@ -66,9 +66,7 @@ def burger_type_to_rust_type(burger_type): def write_packet_file(state, packet_name_snake_case, code): - path = os.path.join( - '..', f'azalea-protocol/src/packets/{state}/{packet_name_snake_case}.rs') - with open(path, 'w') as f: + with open(f'../azalea-protocol/src/packets/{state}/{packet_name_snake_case}.rs', 'w') as f: f.write(code) @@ -128,3 +126,41 @@ def generate(burger_packets, mappings: Mappings): write_packet_file(state, to_snake_case(class_name), '\n'.join(generated_packet_code)) print() + + mod_rs_dir = f'../azalea-protocol/src/packets/{state}/mod.rs' + with open(mod_rs_dir, 'r') as f: + mod_rs = f.read().splitlines() + + pub_mod_line = f'pub mod {to_snake_case(class_name)};' + if pub_mod_line not in mod_rs: + mod_rs.insert(0, pub_mod_line) + packet_mod_rs_line = f' {hex(packet["id"])}: {to_snake_case(class_name)}::{to_camel_case(class_name)},' + + in_serverbound = False + in_clientbound = False + for i, line in enumerate(mod_rs): + if line.strip() == 'Serverbound => {': + in_serverbound = True + continue + elif line.strip() == 'Clientbound => {': + in_clientbound = True + continue + elif line.strip() in ('}', '},'): + if (in_serverbound and direction == 'serverbound') or (in_clientbound and direction == 'clientbound'): + mod_rs.insert(i, packet_mod_rs_line) + break + in_serverbound = in_clientbound = False + continue + + if line.strip() == '' or line.strip().startswith('//') or (not in_serverbound and direction == 'serverbound') or (not in_clientbound and direction == 'clientbound'): + continue + + line_packet_id_hex = line.strip().split(':')[0] + assert line_packet_id_hex.startswith('0x') + line_packet_id = int(line_packet_id_hex[2:], 16) + if line_packet_id > packet['id']: + mod_rs.insert(i, packet_mod_rs_line) + break + + with open(mod_rs_dir, 'w') as f: + f.write('\n'.join(mod_rs)) |
