From 228ee4a2f0f94975fc233946e0c4d258f87fbcf4 Mon Sep 17 00:00:00 2001 From: mat Date: Wed, 22 Mar 2023 19:52:19 +0000 Subject: optimize nbt lists --- azalea-nbt/src/decode.rs | 233 +++++++++++++++++++++++++++++++++++----------- azalea-nbt/src/encode.rs | 152 +++++++++++++++++++----------- azalea-nbt/src/lib.rs | 1 + azalea-nbt/src/tag.rs | 67 ++++++++++--- azalea-nbt/tests/tests.rs | 38 ++++---- 5 files changed, 353 insertions(+), 138 deletions(-) diff --git a/azalea-nbt/src/decode.rs b/azalea-nbt/src/decode.rs index 6ec4bf13..da3c933c 100755 --- a/azalea-nbt/src/decode.rs +++ b/azalea-nbt/src/decode.rs @@ -1,9 +1,14 @@ +use crate::tag::NbtByteArray; +use crate::tag::NbtCompound; +use crate::tag::NbtIntArray; +use crate::tag::NbtList; +use crate::tag::NbtLongArray; +use crate::tag::NbtString; use crate::Error; use crate::Tag; use ahash::AHashMap; use azalea_buf::{BufReadError, McBufReadable}; use byteorder::{ReadBytesExt, BE}; -use compact_str::CompactString; use flate2::read::{GzDecoder, ZlibDecoder}; use log::warn; use std::io::Cursor; @@ -21,7 +26,7 @@ fn read_bytes<'a>(buf: &'a mut Cursor<&[u8]>, length: usize) -> Result<&'a [u8], } #[inline] -fn read_string(stream: &mut Cursor<&[u8]>) -> Result { +fn read_string(stream: &mut Cursor<&[u8]>) -> Result { let length = stream.read_u16::()? as usize; let buf = read_bytes(stream, length)?; @@ -35,6 +40,175 @@ fn read_string(stream: &mut Cursor<&[u8]>) -> Result { }) } +#[inline] +fn read_byte_array(stream: &mut Cursor<&[u8]>) -> Result { + let length = stream.read_u32::()? as usize; + let bytes = read_bytes(stream, length)?.to_vec(); + Ok(bytes) +} + +// https://stackoverflow.com/a/59707887 +fn vec_u8_into_i8(v: Vec) -> Vec { + // ideally we'd use Vec::into_raw_parts, but it's unstable, + // so we have to do it manually: + + // first, make sure v's destructor doesn't free the data + // it thinks it owns when it goes out of scope + let mut v = std::mem::ManuallyDrop::new(v); + + // then, pick apart the existing Vec + let p = v.as_mut_ptr(); + let len = v.len(); + let cap = v.capacity(); + + // finally, adopt the data into a new Vec + unsafe { Vec::from_raw_parts(p as *mut i8, len, cap) } +} + +#[inline] +fn read_list(stream: &mut Cursor<&[u8]>) -> Result { + let type_id = stream.read_u8()?; + let length = stream.read_u32::()?; + let list = match type_id { + 0 => NbtList::Empty, + 1 => NbtList::Byte(vec_u8_into_i8( + read_bytes(stream, length as usize)?.to_vec(), + )), + 2 => NbtList::Short({ + if ((length * 2) as usize) > (stream.get_ref().len() - stream.position() as usize) { + return Err(Error::UnexpectedEof); + } + (0..length) + .map(|_| stream.read_i16::()) + .collect::, _>>()? + }), + 3 => NbtList::Int({ + if ((length * 4) as usize) > (stream.get_ref().len() - stream.position() as usize) { + return Err(Error::UnexpectedEof); + } + (0..length) + .map(|_| stream.read_i32::()) + .collect::, _>>()? + }), + 4 => NbtList::Long({ + if ((length * 8) as usize) > (stream.get_ref().len() - stream.position() as usize) { + return Err(Error::UnexpectedEof); + } + (0..length) + .map(|_| stream.read_i64::()) + .collect::, _>>()? + }), + 5 => NbtList::Float({ + if ((length * 4) as usize) > (stream.get_ref().len() - stream.position() as usize) { + return Err(Error::UnexpectedEof); + } + (0..length) + .map(|_| stream.read_f32::()) + .collect::, _>>()? + }), + 6 => NbtList::Double({ + if ((length * 8) as usize) > (stream.get_ref().len() - stream.position() as usize) { + return Err(Error::UnexpectedEof); + } + (0..length) + .map(|_| stream.read_f64::()) + .collect::, _>>()? + }), + 7 => NbtList::ByteArray({ + if ((length * 4) as usize) > (stream.get_ref().len() - stream.position() as usize) { + return Err(Error::UnexpectedEof); + } + (0..length) + .map(|_| read_byte_array(stream)) + .collect::, _>>()? + }), + 8 => NbtList::String({ + if ((length * 4) as usize) > (stream.get_ref().len() - stream.position() as usize) { + return Err(Error::UnexpectedEof); + } + (0..length) + .map(|_| read_string(stream)) + .collect::, _>>()? + }), + 9 => NbtList::List({ + if ((length * 4) as usize) > (stream.get_ref().len() - stream.position() as usize) { + return Err(Error::UnexpectedEof); + } + (0..length) + .map(|_| read_list(stream)) + .collect::, _>>()? + }), + 10 => NbtList::Compound({ + if ((length * 4) as usize) > (stream.get_ref().len() - stream.position() as usize) { + return Err(Error::UnexpectedEof); + } + (0..length) + .map(|_| read_compound(stream)) + .collect::, _>>()? + }), + 11 => NbtList::IntArray({ + if ((length * 4) as usize) > (stream.get_ref().len() - stream.position() as usize) { + return Err(Error::UnexpectedEof); + } + (0..length) + .map(|_| read_int_array(stream)) + .collect::, _>>()? + }), + 12 => NbtList::LongArray({ + if ((length * 4) as usize) > (stream.get_ref().len() - stream.position() as usize) { + return Err(Error::UnexpectedEof); + } + (0..length) + .map(|_| read_long_array(stream)) + .collect::, _>>()? + }), + _ => return Err(Error::InvalidTagType(type_id)), + }; + Ok(list) +} + +#[inline] +fn read_compound(stream: &mut Cursor<&[u8]>) -> Result { + // we default to capacity 4 because it'll probably not be empty + let mut map = NbtCompound::with_capacity(4); + loop { + let tag_id = stream.read_u8().unwrap_or(0); + if tag_id == 0 { + break; + } + let name = read_string(stream)?; + let tag = Tag::read_known(stream, tag_id)?; + map.insert(name, tag); + } + Ok(map) +} + +#[inline] +fn read_int_array(stream: &mut Cursor<&[u8]>) -> Result { + let length = stream.read_u32::()? as usize; + if length * 4 > (stream.get_ref().len() - stream.position() as usize) { + return Err(Error::UnexpectedEof); + } + let mut ints = NbtIntArray::with_capacity(length); + for _ in 0..length { + ints.push(stream.read_i32::()?); + } + Ok(ints) +} + +#[inline] +fn read_long_array(stream: &mut Cursor<&[u8]>) -> Result { + let length = stream.read_u32::()? as usize; + if length * 8 > (stream.get_ref().len() - stream.position() as usize) { + return Err(Error::UnexpectedEof); + } + let mut longs = NbtLongArray::with_capacity(length); + for _ in 0..length { + longs.push(stream.read_i64::()?); + } + Ok(longs) +} + impl Tag { /// Read the NBT data when you already know the ID of the tag. You usually /// want [`Tag::read`] if you're reading an NBT file. @@ -60,11 +234,7 @@ impl Tag { 6 => Tag::Double(stream.read_f64::()?), // A length-prefixed array of signed bytes. The prefix is a signed // integer (thus 4 bytes) - 7 => { - let length = stream.read_u32::()? as usize; - let bytes = read_bytes(stream, length)?.to_vec(); - Tag::ByteArray(bytes) - } + 7 => Tag::ByteArray(read_byte_array(stream)?), // A length-prefixed modified UTF-8 string. The prefix is an // unsigned short (thus 2 bytes) signifying the length of the // string in bytes @@ -77,57 +247,16 @@ impl Tag { // notchian implementation uses TAG_End in that situation, but // another reference implementation by Mojang uses 1 instead; // parsers should accept any type if the length is <= 0). - 9 => { - let type_id = stream.read_u8()?; - let length = stream.read_u32::()?; - let mut list = Vec::new(); - for _ in 0..length { - list.push(Tag::read_known(stream, type_id)?); - } - Tag::List(list) - } + 9 => Tag::List(read_list(stream)?), // Effectively a list of a named tags. Order is not guaranteed. - 10 => { - // we default to capacity 4 because it'll probably not be empty - let mut map = AHashMap::with_capacity(4); - loop { - let tag_id = stream.read_u8().unwrap_or(0); - if tag_id == 0 { - break; - } - let name = read_string(stream)?; - let tag = Tag::read_known(stream, tag_id)?; - map.insert(name, tag); - } - Tag::Compound(map) - } + 10 => Tag::Compound(read_compound(stream)?), // A length-prefixed array of signed integers. The prefix is a // signed integer (thus 4 bytes) and indicates the number of 4 byte // integers. - 11 => { - let length = stream.read_u32::()? as usize; - if length * 4 > (stream.get_ref().len() - stream.position() as usize) { - return Err(Error::UnexpectedEof); - } - let mut ints = Vec::with_capacity(length); - for _ in 0..length { - ints.push(stream.read_i32::()?); - } - Tag::IntArray(ints) - } + 11 => Tag::IntArray(read_int_array(stream)?), // A length-prefixed array of signed longs. The prefix is a signed // integer (thus 4 bytes) and indicates the number of 8 byte longs. - 12 => { - let length = stream.read_u32::()? as usize; - if length * 8 > (stream.get_ref().len() - stream.position() as usize) { - return Err(Error::UnexpectedEof); - } - let mut longs = Vec::with_capacity(length); - for _ in 0..length { - longs.push(stream.read_i64::()?); - } - Tag::LongArray(longs) - } + 12 => Tag::LongArray(read_long_array(stream)?), _ => return Err(Error::InvalidTagType(id)), }) } diff --git a/azalea-nbt/src/encode.rs b/azalea-nbt/src/encode.rs index 09cfffac..8dfd8fa4 100755 --- a/azalea-nbt/src/encode.rs +++ b/azalea-nbt/src/encode.rs @@ -1,9 +1,9 @@ +use crate::tag::NbtCompound; +use crate::tag::NbtList; use crate::Error; use crate::Tag; -use ahash::AHashMap; use azalea_buf::McBufWritable; use byteorder::{WriteBytesExt, BE}; -use compact_str::CompactString; use flate2::write::{GzEncoder, ZlibEncoder}; use std::io::Write; @@ -16,11 +16,7 @@ fn write_string(writer: &mut dyn Write, string: &str) -> Result<(), Error> { } #[inline] -fn write_compound( - writer: &mut dyn Write, - value: &AHashMap, - end_tag: bool, -) -> Result<(), Error> { +fn write_compound(writer: &mut dyn Write, value: &NbtCompound, end_tag: bool) -> Result<(), Error> { for (key, tag) in value { match tag { Tag::End => {} @@ -57,7 +53,7 @@ fn write_compound( Tag::ByteArray(value) => { writer.write_u8(7)?; write_string(writer, key)?; - write_bytearray(writer, value)?; + write_byte_array(writer, value)?; } Tag::String(value) => { writer.write_u8(8)?; @@ -77,12 +73,12 @@ fn write_compound( Tag::IntArray(value) => { writer.write_u8(11)?; write_string(writer, key)?; - write_intarray(writer, value)?; + write_int_array(writer, value)?; } Tag::LongArray(value) => { writer.write_u8(12)?; write_string(writer, key)?; - write_longarray(writer, value)?; + write_long_array(writer, value)?; } } } @@ -93,45 +89,91 @@ fn write_compound( } #[inline] -fn write_list(writer: &mut dyn Write, value: &[Tag]) -> Result<(), Error> { - // we just get the type from the first item, or default the type to END - if value.is_empty() { - writer.write_all(&[0; 5])?; - } else { - let first_tag = &value[0]; - writer.write_u8(first_tag.id())?; - writer.write_i32::(value.len() as i32)?; - match first_tag { - Tag::Int(_) => { - for tag in value { - writer.write_i32::( - *tag.as_int().expect("List of Int should only contains Int"), - )?; - } - } - Tag::String(_) => { - for tag in value { - write_string( - writer, - tag.as_string() - .expect("List of String should only contain String"), - )?; - } - } - Tag::Compound(_) => { - for tag in value { - write_compound( - writer, - tag.as_compound() - .expect("List of Compound should only contain Compound"), - true, - )?; - } - } - _ => { - for tag in value { - tag.write_without_end(writer)?; - } +fn write_list(writer: &mut dyn Write, value: &NbtList) -> Result<(), Error> { + match value { + NbtList::Empty => writer.write_all(&[0; 5])?, + NbtList::Byte(l) => { + writer.write_u8(1)?; + writer.write_i32::(l.len() as i32)?; + for v in l { + writer.write_i8(*v)?; + } + } + NbtList::Short(l) => { + writer.write_u8(2)?; + writer.write_i32::(l.len() as i32)?; + for v in l { + writer.write_i16::(*v)?; + } + } + NbtList::Int(l) => { + writer.write_u8(3)?; + writer.write_i32::(l.len() as i32)?; + for v in l { + writer.write_i32::(*v)?; + } + } + NbtList::Long(l) => { + writer.write_u8(4)?; + writer.write_i32::(l.len() as i32)?; + for v in l { + writer.write_i64::(*v)?; + } + } + NbtList::Float(l) => { + writer.write_u8(5)?; + writer.write_i32::(l.len() as i32)?; + for v in l { + writer.write_f32::(*v)?; + } + } + NbtList::Double(l) => { + writer.write_u8(6)?; + writer.write_i32::(l.len() as i32)?; + for v in l { + writer.write_f64::(*v)?; + } + } + NbtList::ByteArray(l) => { + writer.write_u8(7)?; + writer.write_i32::(l.len() as i32)?; + for v in l { + write_byte_array(writer, v)?; + } + } + NbtList::String(l) => { + writer.write_u8(8)?; + writer.write_i32::(l.len() as i32)?; + for v in l { + write_string(writer, v)?; + } + } + NbtList::List(l) => { + writer.write_u8(9)?; + writer.write_i32::(l.len() as i32)?; + for v in l { + write_list(writer, v)?; + } + } + NbtList::Compound(l) => { + writer.write_u8(10)?; + writer.write_i32::(l.len() as i32)?; + for v in l { + write_compound(writer, v, true)?; + } + } + NbtList::IntArray(l) => { + writer.write_u8(11)?; + writer.write_i32::(l.len() as i32)?; + for v in l { + write_int_array(writer, v)?; + } + } + NbtList::LongArray(l) => { + writer.write_u8(12)?; + writer.write_i32::(l.len() as i32)?; + for v in l { + write_long_array(writer, v)?; } } } @@ -140,14 +182,14 @@ fn write_list(writer: &mut dyn Write, value: &[Tag]) -> Result<(), Error> { } #[inline] -fn write_bytearray(writer: &mut dyn Write, value: &Vec) -> Result<(), Error> { +fn write_byte_array(writer: &mut dyn Write, value: &Vec) -> Result<(), Error> { writer.write_u32::(value.len() as u32)?; writer.write_all(value)?; Ok(()) } #[inline] -fn write_intarray(writer: &mut dyn Write, value: &Vec) -> Result<(), Error> { +fn write_int_array(writer: &mut dyn Write, value: &Vec) -> Result<(), Error> { writer.write_u32::(value.len() as u32)?; for &int in value { writer.write_i32::(int)?; @@ -156,7 +198,7 @@ fn write_intarray(writer: &mut dyn Write, value: &Vec) -> Result<(), Error> } #[inline] -fn write_longarray(writer: &mut dyn Write, value: &Vec) -> Result<(), Error> { +fn write_long_array(writer: &mut dyn Write, value: &Vec) -> Result<(), Error> { writer.write_u32::(value.len() as u32)?; for &long in value { writer.write_i64::(long)?; @@ -179,12 +221,12 @@ impl Tag { Tag::Long(value) => writer.write_i64::(*value)?, Tag::Float(value) => writer.write_f32::(*value)?, Tag::Double(value) => writer.write_f64::(*value)?, - Tag::ByteArray(value) => write_bytearray(writer, value)?, + Tag::ByteArray(value) => write_byte_array(writer, value)?, Tag::String(value) => write_string(writer, value)?, Tag::List(value) => write_list(writer, value)?, Tag::Compound(value) => write_compound(writer, value, true)?, - Tag::IntArray(value) => write_intarray(writer, value)?, - Tag::LongArray(value) => write_longarray(writer, value)?, + Tag::IntArray(value) => write_int_array(writer, value)?, + Tag::LongArray(value) => write_long_array(writer, value)?, } Ok(()) diff --git a/azalea-nbt/src/lib.rs b/azalea-nbt/src/lib.rs index 0ceca39f..048e466e 100755 --- a/azalea-nbt/src/lib.rs +++ b/azalea-nbt/src/lib.rs @@ -6,6 +6,7 @@ mod error; mod tag; pub use error::Error; +pub use tag::NbtList; pub use tag::Tag; #[cfg(test)] diff --git a/azalea-nbt/src/tag.rs b/azalea-nbt/src/tag.rs index 4661479b..4a702339 100755 --- a/azalea-nbt/src/tag.rs +++ b/azalea-nbt/src/tag.rs @@ -5,6 +5,18 @@ use enum_as_inner::EnumAsInner; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; +pub type NbtByte = i8; +pub type NbtShort = i16; +pub type NbtInt = i32; +pub type NbtLong = i64; +pub type NbtFloat = f32; +pub type NbtDouble = f64; +pub type NbtByteArray = Vec; +pub type NbtString = CompactString; +pub type NbtCompound = AHashMap; +pub type NbtIntArray = Vec; +pub type NbtLongArray = Vec; + /// An NBT value. #[derive(Clone, Debug, PartialEq, Default, EnumAsInner)] #[repr(u8)] @@ -12,18 +24,38 @@ use serde::{Deserialize, Serialize}; pub enum Tag { #[default] End = 0, - Byte(i8) = 1, - Short(i16) = 2, - Int(i32) = 3, - Long(i64) = 4, - Float(f32) = 5, - Double(f64) = 6, - ByteArray(Vec) = 7, - String(CompactString) = 8, - List(Vec) = 9, - Compound(AHashMap) = 10, - IntArray(Vec) = 11, - LongArray(Vec) = 12, + Byte(NbtByte) = 1, + Short(NbtShort) = 2, + Int(NbtInt) = 3, + Long(NbtLong) = 4, + Float(NbtFloat) = 5, + Double(NbtDouble) = 6, + ByteArray(NbtByteArray) = 7, + String(NbtString) = 8, + List(NbtList) = 9, + Compound(NbtCompound) = 10, + IntArray(NbtIntArray) = 11, + LongArray(NbtLongArray) = 12, +} + +/// An NBT value. +#[derive(Clone, Debug, PartialEq)] +#[repr(u8)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize), serde(untagged))] +pub enum NbtList { + Empty, + Byte(Vec) = 1, + Short(Vec) = 2, + Int(Vec) = 3, + Long(Vec) = 4, + Float(Vec) = 5, + Double(Vec) = 6, + ByteArray(Vec) = 7, + String(Vec) = 8, + List(Vec) = 9, + Compound(Vec) = 10, + IntArray(Vec) = 11, + LongArray(Vec) = 12, } impl Tag { @@ -37,3 +69,14 @@ impl Tag { unsafe { *<*const _>::from(self).cast::() } } } +impl NbtList { + /// Get the numerical ID of the tag type. + #[inline] + pub fn id(&self) -> u8 { + // SAFETY: Because `Self` is marked `repr(u8)`, its layout is a `repr(C)` + // `union` between `repr(C)` structs, each of which has the `u8` + // discriminant as its first field, so we can read the discriminant + // without offsetting the pointer. + unsafe { *<*const _>::from(self).cast::() } + } +} diff --git a/azalea-nbt/tests/tests.rs b/azalea-nbt/tests/tests.rs index c0fe520d..62852578 100755 --- a/azalea-nbt/tests/tests.rs +++ b/azalea-nbt/tests/tests.rs @@ -1,5 +1,5 @@ use ahash::AHashMap; -use azalea_nbt::Tag; +use azalea_nbt::{NbtList, Tag}; use std::io::Cursor; #[test] @@ -53,24 +53,24 @@ fn test_bigtest() { fn test_stringtest() { let correct_tag = Tag::Compound(AHashMap::from_iter(vec![( "😃".into(), - Tag::List(vec![ - Tag::String("asdfkghasfjgihsdfogjsndfg".into()), - Tag::String("jnabsfdgihsabguiqwrntgretqwejirhbiqw".into()), - Tag::String("asd".into()), - Tag::String("wqierjgt7wqy8u4rtbwreithwretiwerutbwenryq8uwervqwer9iuqwbrgyuqrbtwierotugqewrtqwropethert".into()), - Tag::String("asdf".into()), - Tag::String("alsdkjiqwoe".into()), - Tag::String("lmqi9hyqd".into()), - Tag::String("qwertyuiop".into()), - Tag::String("asdfghjkl".into()), - Tag::String("zxcvbnm".into()), - Tag::String(" ".into()), - Tag::String("words words words words words words".into()), - Tag::String("aaaaaaaaaaaaaaaaaaaa".into()), - Tag::String("♥".into()), - Tag::String("a\nb\n\n\nc\r\rd".into()), - Tag::String("😁".into()), - ]) + Tag::List(NbtList::String(vec![ + "asdfkghasfjgihsdfogjsndfg".into(), + "jnabsfdgihsabguiqwrntgretqwejirhbiqw".into(), + "asd".into(), + "wqierjgt7wqy8u4rtbwreithwretiwerutbwenryq8uwervqwer9iuqwbrgyuqrbtwierotugqewrtqwropethert".into(), + "asdf".into(), + "alsdkjiqwoe".into(), + "lmqi9hyqd".into(), + "qwertyuiop".into(), + "asdfghjkl".into(), + "zxcvbnm".into(), + " ".into(), + "words words words words words words".into(), + "aaaaaaaaaaaaaaaaaaaa".into(), + "♥".into(), + "a\nb\n\n\nc\r\rd".into(), + "😁".into(), + ])) )])); let original = include_bytes!("stringtest.nbt").to_vec(); -- cgit v1.2.3