diff options
| author | mat <github@matdoes.dev> | 2022-05-25 00:21:05 -0500 |
|---|---|---|
| committer | mat <github@matdoes.dev> | 2022-05-25 00:21:05 -0500 |
| commit | 479c05474704a5a2f68b79468d2cde05c0ceec62 (patch) | |
| tree | 59e7192b74caac1566993b202419824993167df0 | |
| parent | fb3b002d94076de463e2af776666387db9e75835 (diff) | |
| download | azalea-drasl-479c05474704a5a2f68b79468d2cde05c0ceec62.tar.xz | |
Migrate might be working
| -rw-r--r-- | codegen/lib/code/packet.py | 72 | ||||
| -rw-r--r-- | codegen/lib/code/utils.py | 2 | ||||
| -rw-r--r-- | codegen/lib/utils.py | 25 | ||||
| -rw-r--r-- | codegen/migrate.py | 51 | ||||
| -rw-r--r-- | codegen/newpacket.py | 4 |
5 files changed, 132 insertions, 22 deletions
diff --git a/codegen/lib/code/packet.py b/codegen/lib/code/packet.py index 0d3ad138..59c773c1 100644 --- a/codegen/lib/code/packet.py +++ b/codegen/lib/code/packet.py @@ -7,7 +7,7 @@ def make_packet_mod_rs_line(packet_id: int, packet_class_name: str): return f' {padded_hex(packet_id)}: {to_snake_case(packet_class_name)}::{to_camel_case(packet_class_name)},' -def generate(burger_packets, mappings: Mappings, target_packet_id, target_packet_direction, target_packet_state): +def generate_packet(burger_packets, mappings: Mappings, target_packet_id, target_packet_direction, target_packet_state): for packet in burger_packets.values(): if packet['id'] != target_packet_id: continue @@ -109,9 +109,13 @@ def generate(burger_packets, mappings: Mappings, target_packet_id, target_packet f.write('\n'.join(mod_rs)) -def set_packet_ids(packet_ids: list, packet_class_names: list, direction: str, state: str): +def set_packets(packet_ids: list, packet_class_names: list, direction: str, state: str): assert len(packet_ids) == len(packet_class_names) + # sort the packets by id + packet_ids, packet_class_names = [list(x) for x in zip( + *sorted(zip(packet_ids, packet_class_names), key=lambda pair: pair[0]))] + mod_rs_dir = f'../azalea-protocol/src/packets/{state}/mod.rs' with open(mod_rs_dir, 'r') as f: mod_rs = f.read().splitlines() @@ -146,3 +150,67 @@ def set_packet_ids(packet_ids: list, packet_class_names: list, direction: str, s with open(mod_rs_dir, 'w') as f: f.write('\n'.join(new_mod_rs)) + + +def get_packets(direction: str, state: str): + mod_rs_dir = f'../azalea-protocol/src/packets/{state}/mod.rs' + with open(mod_rs_dir, 'r') as f: + mod_rs = f.read().splitlines() + + in_serverbound = False + in_clientbound = False + + packet_ids: list[int] = [] + packet_class_names: list[str] = [] + + for line in mod_rs: + if line.strip() == 'Serverbound => {': + in_serverbound = True + continue + elif line.strip() == 'Clientbound => {': + in_clientbound = True + continue + elif line.strip() in ('}', '},'): + if (in_serverbound and direction == 'serverbound') or (in_clientbound and direction == 'clientbound'): + break + in_serverbound = in_clientbound = False + continue + + if line.strip() == '' or line.strip().startswith('//') or (not in_serverbound and direction == 'serverbound') or (not in_clientbound and direction == 'clientbound'): + continue + + line_packet_id_hex = line.strip().split(':')[0] + assert line_packet_id_hex.startswith('0x') + line_packet_id = int(line_packet_id_hex[2:], 16) + packet_ids.append(line_packet_id) + + packet_class_name = line.strip().split(':')[1].strip() + packet_class_names.append(packet_class_name) + + return packet_ids, packet_class_names + + +def change_packet_ids(id_map: dict[int, int], direction: str, state: str): + existing_packet_ids, existing_packet_class_names = get_packets( + direction, state) + + new_packet_ids = [] + + for packet_id in existing_packet_ids: + new_packet_id = id_map.get(packet_id, packet_id) + new_packet_ids.append(new_packet_id) + + set_packets(new_packet_ids, existing_packet_class_names, direction, state) + + +def remove_packet_ids(packet_ids: list[int], direction: str, state: str): + existing_packet_ids, existing_packet_class_names = get_packets( + direction, state) + + new_packet_ids = [] + + for packet_id in existing_packet_ids: + if packet_id not in packet_ids: + new_packet_ids.append(packet_id) + + set_packets(new_packet_ids, existing_packet_class_names, direction, state) diff --git a/codegen/lib/code/utils.py b/codegen/lib/code/utils.py index 92d1a9e9..28a5ef3c 100644 --- a/codegen/lib/code/utils.py +++ b/codegen/lib/code/utils.py @@ -1,6 +1,8 @@ import os +# utilities specifically for codegen + def burger_type_to_rust_type(burger_type): is_var = False diff --git a/codegen/lib/utils.py b/codegen/lib/utils.py index 051ffe51..ff1a5d36 100644 --- a/codegen/lib/utils.py +++ b/codegen/lib/utils.py @@ -1,5 +1,7 @@ import re +# utilities that could be used for things other than codegen + def to_snake_case(name: str): s = re.sub('([A-Z])', r'_\1', name) @@ -13,3 +15,26 @@ def to_camel_case(name: str): def padded_hex(n: int): return f'0x{n:02x}' + + +class PacketIdentifier: + def __init__(self, packet_id, direction, state): + self.packet_id = packet_id + self.direction = direction + self.state = state + + def __eq__(self, other): + return self.packet_id == other.packet_id and self.direction == other.direction and self.state == other.state + + def __hash__(self): + return hash((self.packet_id, self.direction, self.state)) + + +def group_packets(packets: list[PacketIdentifier]): + packet_groups: dict[tuple[str, str], list[int]] = {} + for packet in packets: + key = (packet.direction, packet.state) + if key not in packet_groups: + packet_groups[key] = [] + packet_groups[key].append(packet.packet_id) + return packet_groups diff --git a/codegen/migrate.py b/codegen/migrate.py index c0748400..6928cea1 100644 --- a/codegen/migrate.py +++ b/codegen/migrate.py @@ -1,5 +1,7 @@ +from codegen.lib.utils import PacketIdentifier, group_packets import lib.code.utils import lib.code.version +import lib.code.packet import lib.download import sys import os @@ -14,39 +16,52 @@ new_mappings = lib.download.get_mappings_for_version(new_version_id) new_burger_data = lib.download.get_burger_data_for_version(new_version_id) new_packet_list = list(new_burger_data[0]['packets']['packet'].values()) -old_packet_ids = {} -new_packet_ids = {} + +old_packets: dict[PacketIdentifier, str] = {} +new_packets: dict[PacketIdentifier, str] = {} for packet in old_packet_list: assert packet['class'].endswith('.class') packet_name = old_mappings.get_class(packet['class'][:-6]) - old_packet_ids[packet_name] = packet['id'] + old_packets[PacketIdentifier( + packet['id'], packet['direction'], packet['state'])] = packet_name for packet in new_packet_list: assert packet['class'].endswith('.class') packet_name = new_mappings.get_class(packet['class'][:-6]) - new_packet_ids[packet_name] = packet['id'] + new_packets[PacketIdentifier( + packet['id'], packet['direction'], packet['state'])] = packet_name -# find packets that changed ids -for packet_name in old_packet_ids: - if packet_name in new_packet_ids: - if old_packet_ids[packet_name] != new_packet_ids[packet_name]: - print(packet_name, 'id changed from', - old_packet_ids[packet_name], 'to', new_packet_ids[packet_name]) + +# find removed packets +removed_packets: list[PacketIdentifier] = [] +for packet in old_packets: + if packet not in new_packets: + removed_packets.append(packet) +for (direction, state), packets in group_packets(removed_packets).items(): + lib.code.packet.remove_packet_ids(packets, direction, state) print() -# find removed packets -for packet_name in old_packet_ids: - if packet_name not in new_packet_ids: - print(packet_name, 'removed') +# find packets that changed ids +changed_packets: dict[PacketIdentifier, int] = {} +for old_packet, old_packet_name in old_packets.items(): + for new_packet, new_packet_name in new_packets.items(): + if old_packet == new_packet and old_packet.packet_id != new_packet.packet_id: + changed_packets[old_packet] = new_packet.packet_id +for (direction, state), packets in group_packets(list(changed_packets.keys())).items(): + lib.code.packet.remove_packet_ids(packets, direction, state) + print() # find added packets -for packet_name in new_packet_ids: - if packet_name not in old_packet_ids: - print(packet_name, 'added') - +added_packets: list[PacketIdentifier] = [] +for packet in new_packets: + if packet not in old_packets: + added_packets.append(packet) +for packet in added_packets: + lib.code.packet.generate_packet( + new_burger_data, new_mappings, packet.packet_id, packet.direction, packet.state) lib.code.utils.fmt() print('Done!') diff --git a/codegen/newpacket.py b/codegen/newpacket.py index b3a1c64f..2e4c77d7 100644 --- a/codegen/newpacket.py +++ b/codegen/newpacket.py @@ -9,8 +9,8 @@ burger_packets_data = burger_data[0]['packets']['packet'] packet_id, direction, state = int(sys.argv[1]), sys.argv[2], sys.argv[3] print( f'Generating code for packet id: {packet_id} with direction {direction} and state {state}') -code.packetcodegen.generate(burger_packets_data, mappings, - packet_id, direction, state) +code.packetcodegen.generate_packet(burger_packets_data, mappings, + packet_id, direction, state) code.fmt() |
