#![feature(array_try_from_fn)] #![feature(iterator_try_collect)] pub use flate2; pub use mt_ser_derive::{mt_derive, MtDeserialize, MtSerialize}; pub use paste; pub use zstd; use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; use cgmath::{Deg, Euler, Point1, Point2, Point3, Rad, Vector1, Vector2, Vector3, Vector4}; use collision::{Aabb2, Aabb3}; use enumset::{EnumSet, EnumSetTypeWithRepr}; use paste::paste as paste_macro; use std::{ collections::{HashMap, HashSet}, convert::Infallible, fmt::Debug, io::{self, Read, Write}, num::TryFromIntError, ops::{Deref, Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive}, }; use thiserror::Error; #[cfg(test)] mod tests; use crate as mt_ser; #[derive(Error, Debug)] pub enum SerializeError { #[error("io error: {0}")] IoError(#[from] io::Error), #[error("collection too big: {0}")] TooBig(#[from] TryFromIntError), #[error("{0}")] Other(String), } impl From for SerializeError { fn from(_err: Infallible) -> Self { unreachable!("infallible") } } #[derive(Error, Debug)] pub enum DeserializeError { #[error("io error: {0}")] IoError(io::Error), #[error("unexpected end of file")] UnexpectedEof, #[error("collection too big: {0}")] TooBig(#[from] TryFromIntError), #[error("invalid UTF-16: {0}")] InvalidUtf16(#[from] std::char::DecodeUtf16Error), #[error("invalid {0} enum variant {1:?}")] InvalidEnum(&'static str, Box), #[error("invalid constant - wanted: {0:?} - got: {1:?}")] InvalidConst(Box, Box), #[error("{0}")] Other(String), } impl From for DeserializeError { fn from(_err: Infallible) -> Self { unreachable!("infallible") } } impl From for DeserializeError { fn from(err: io::Error) -> Self { if err.kind() == io::ErrorKind::UnexpectedEof { DeserializeError::UnexpectedEof } else { DeserializeError::IoError(err) } } } pub trait OrDefault { fn or_default(self) -> Self; } pub struct WrapRead<'a, R: Read>(pub &'a mut R); impl<'a, R: Read> Read for WrapRead<'a, R> { fn read(&mut self, buf: &mut [u8]) -> io::Result { self.0.read(buf) } fn read_vectored(&mut self, bufs: &mut [io::IoSliceMut<'_>]) -> io::Result { self.0.read_vectored(bufs) } /* fn is_read_vectored(&self) -> bool { self.0.is_read_vectored() } */ fn read_to_end(&mut self, buf: &mut Vec) -> io::Result { self.0.read_to_end(buf) } fn read_to_string(&mut self, buf: &mut String) -> io::Result { self.0.read_to_string(buf) } fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> { self.0.read_exact(buf) } /* fn read_buf(&mut self, buf: io::BorrowedCursor<'_>) -> io::Result<()> { self.0.read_buf(buf) } fn read_buf_exact(&mut self, cursor: io::BorrowedCursor<'_>) -> io::Result<()> { self.0.read_buf_exact(cursor) } */ } impl OrDefault for Result { fn or_default(self) -> Self { match self { Err(DeserializeError::UnexpectedEof) => Ok(T::default()), x => x, } } } pub trait MtLen { fn option(&self) -> Option; type Range: Iterator + 'static; fn range(&self) -> Self::Range; type Take: Read; fn take(&self, reader: R) -> Self::Take; } pub trait MtCfg { type Len: MtLen; type Inner: MtCfg; fn utf16() -> bool { false } fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError>; fn read_len(reader: &mut impl Read) -> Result; } pub trait MtSerialize { fn mt_serialize(&self, writer: &mut impl Write) -> Result<(), SerializeError>; } pub trait MtDeserialize: Sized { fn mt_deserialize(reader: &mut impl Read) -> Result; } impl MtLen for usize { fn option(&self) -> Option { Some(*self) } type Range = std::ops::Range; fn range(&self) -> Self::Range { 0..*self } type Take = io::Take; fn take(&self, reader: R) -> Self::Take { reader.take(*self as u64) } } trait MtCfgLen: Sized + MtSerialize + MtDeserialize + TryFrom + TryInto {} impl MtCfg for T where SerializeError: From<>::Error>, DeserializeError: From<>::Error>, { type Len = usize; type Inner = DefCfg; fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError> { Self::try_from(len)?.mt_serialize::(writer) } fn read_len(reader: &mut impl Read) -> Result { Ok(Self::mt_deserialize::(reader)?.try_into()?) } } impl MtCfgLen for u8 {} impl MtCfgLen for u16 {} impl MtCfgLen for u32 {} impl MtCfgLen for u64 {} pub type DefCfg = u16; impl MtCfg for () { type Len = (); type Inner = DefCfg; fn write_len(_len: usize, _writer: &mut impl Write) -> Result<(), SerializeError> { Ok(()) } fn read_len(_writer: &mut impl Read) -> Result { Ok(()) } } impl MtLen for () { fn option(&self) -> Option { None } type Range = std::ops::RangeFrom; fn range(&self) -> Self::Range { 0.. } type Take = R; fn take(&self, reader: R) -> Self::Take { reader } } pub struct Utf16(pub B); impl MtCfg for Utf16 { type Len = B::Len; type Inner = B::Inner; fn utf16() -> bool { true } fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError> { B::write_len(len, writer) } fn read_len(reader: &mut impl Read) -> Result { B::read_len(reader) } } impl MtCfg for (A, B) { type Len = A::Len; type Inner = B; fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError> { A::write_len(len, writer) } fn read_len(reader: &mut impl Read) -> Result { A::read_len(reader) } } impl MtSerialize for u8 { fn mt_serialize(&self, writer: &mut impl Write) -> Result<(), SerializeError> { writer.write_u8(*self)?; Ok(()) } } impl MtDeserialize for u8 { fn mt_deserialize(reader: &mut impl Read) -> Result { Ok(reader.read_u8()?) } } impl MtSerialize for i8 { fn mt_serialize(&self, writer: &mut impl Write) -> Result<(), SerializeError> { writer.write_i8(*self)?; Ok(()) } } impl MtDeserialize for i8 { fn mt_deserialize(reader: &mut impl Read) -> Result { Ok(reader.read_i8()?) } } macro_rules! impl_num { ($T:ty) => { impl MtSerialize for $T { fn mt_serialize( &self, writer: &mut impl Write, ) -> Result<(), SerializeError> { paste_macro! { writer.[]::(*self)?; } Ok(()) } } impl MtDeserialize for $T { fn mt_deserialize(reader: &mut impl Read) -> Result { paste_macro! { Ok(reader.[]::()?) } } } }; } impl_num!(u16); impl_num!(i16); impl_num!(u32); impl_num!(i32); impl_num!(f32); impl_num!(u64); impl_num!(i64); impl_num!(f64); impl MtSerialize for () { fn mt_serialize(&self, _writer: &mut impl Write) -> Result<(), SerializeError> { Ok(()) } } impl MtDeserialize for () { fn mt_deserialize(_reader: &mut impl Read) -> Result { Ok(()) } } impl MtSerialize for bool { fn mt_serialize(&self, writer: &mut impl Write) -> Result<(), SerializeError> { (*self as u8).mt_serialize::(writer) } } impl MtDeserialize for bool { fn mt_deserialize(reader: &mut impl Read) -> Result { Ok(u8::mt_deserialize::(reader)? != 0) } } impl MtSerialize for &T { fn mt_serialize(&self, writer: &mut impl Write) -> Result<(), SerializeError> { (*self).mt_serialize::(writer) } } pub fn mt_serialize_seq( writer: &mut impl Write, iter: impl ExactSizeIterator + IntoIterator, ) -> Result<(), SerializeError> { C::write_len(iter.len(), writer)?; iter.into_iter() .try_for_each(|item| item.mt_serialize::(writer)) } pub fn mt_deserialize_seq( reader: &mut impl Read, ) -> Result> + '_, DeserializeError> { let len = C::read_len(reader)?; mt_deserialize_sized_seq::(&len, reader) } pub fn mt_deserialize_sized_seq<'a, C: MtCfg, T: MtDeserialize>( len: &C::Len, reader: &'a mut impl Read, ) -> Result> + 'a, DeserializeError> { let variable = len.option().is_none(); Ok(len .range() .map_while(move |_| match T::mt_deserialize::(reader) { Err(DeserializeError::UnexpectedEof) if variable => None, x => Some(x), })) } impl MtSerialize for [T; N] { fn mt_serialize(&self, writer: &mut impl Write) -> Result<(), SerializeError> { mt_serialize_seq::<(), _>(writer, self.iter()) } } impl MtDeserialize for [T; N] { fn mt_deserialize(reader: &mut impl Read) -> Result { std::array::try_from_fn(|_| T::mt_deserialize::(reader)) } } impl> MtSerialize for EnumSet { fn mt_serialize(&self, writer: &mut impl Write) -> Result<(), SerializeError> { self.as_repr().mt_serialize::(writer) } } impl> MtDeserialize for EnumSet { fn mt_deserialize(reader: &mut impl Read) -> Result { Ok(Self::from_repr_truncated(T::mt_deserialize::( reader, )?)) } } impl MtSerialize for Option { fn mt_serialize(&self, writer: &mut impl Write) -> Result<(), SerializeError> { match self { Some(item) => item.mt_serialize::(writer), None => Ok(()), } } } impl MtDeserialize for Option { fn mt_deserialize(reader: &mut impl Read) -> Result { T::mt_deserialize::(reader).map(Some).or_default() } } impl MtSerialize for Vec { fn mt_serialize(&self, writer: &mut impl Write) -> Result<(), SerializeError> { mt_serialize_seq::(writer, self.iter()) } } impl MtDeserialize for Vec { fn mt_deserialize(reader: &mut impl Read) -> Result { mt_deserialize_seq::(reader)?.try_collect() } } impl MtSerialize for HashSet { fn mt_serialize(&self, writer: &mut impl Write) -> Result<(), SerializeError> { mt_serialize_seq::(writer, self.iter()) } } impl MtDeserialize for HashSet { fn mt_deserialize(reader: &mut impl Read) -> Result { mt_deserialize_seq::(reader)?.try_collect() } } // TODO: support more tuples impl MtSerialize for (A, B) { fn mt_serialize(&self, writer: &mut impl Write) -> Result<(), SerializeError> { self.0.mt_serialize::(writer)?; self.1.mt_serialize::(writer)?; Ok(()) } } impl MtDeserialize for (A, B) { fn mt_deserialize(reader: &mut impl Read) -> Result { let a = A::mt_deserialize::(reader)?; let b = B::mt_deserialize::(reader)?; Ok((a, b)) } } impl MtSerialize for HashMap where K: MtSerialize + std::cmp::Eq + std::hash::Hash, V: MtSerialize, { fn mt_serialize(&self, writer: &mut impl Write) -> Result<(), SerializeError> { mt_serialize_seq::(writer, self.iter()) } } impl MtDeserialize for HashMap where K: MtDeserialize + std::cmp::Eq + std::hash::Hash, V: MtDeserialize, { fn mt_deserialize(reader: &mut impl Read) -> Result { mt_deserialize_seq::(reader)?.try_collect() } } impl MtSerialize for &str { fn mt_serialize(&self, writer: &mut impl Write) -> Result<(), SerializeError> { if C::utf16() { self.encode_utf16() .collect::>() // FIXME: is this allocation necessary? .mt_serialize::(writer) } else { mt_serialize_seq::(writer, self.as_bytes().iter()) } } } impl MtSerialize for String { fn mt_serialize(&self, writer: &mut impl Write) -> Result<(), SerializeError> { self.as_str().mt_serialize::(writer) } } impl MtDeserialize for String { fn mt_deserialize(reader: &mut impl Read) -> Result { if C::utf16() { let mut err = None; let res = char::decode_utf16(mt_deserialize_seq::(reader)?.map_while(|x| match x { Ok(v) => Some(v), Err(e) => { err = Some(e); None } })) .try_collect(); match err { None => Ok(res?), Some(e) => Err(e), } } else { let len = C::read_len(reader)?; // use capacity if available let mut st = match len.option() { Some(x) => String::with_capacity(x), None => String::new(), }; len.take(WrapRead(reader)).read_to_string(&mut st)?; Ok(st) } } } impl MtSerialize for Box { fn mt_serialize(&self, writer: &mut impl Write) -> Result<(), SerializeError> { self.deref().mt_serialize::(writer) } } impl MtDeserialize for Box { fn mt_deserialize(reader: &mut impl Read) -> Result { Ok(Self::new(T::mt_deserialize::(reader)?)) } } #[derive(MtSerialize, MtDeserialize)] #[mt(typename = "Range")] #[allow(unused)] struct RemoteRange { start: T, end: T, } #[derive(MtSerialize, MtDeserialize)] #[mt(typename = "RangeFrom")] #[allow(unused)] struct RemoteRangeFrom { start: T, } #[derive(MtSerialize, MtDeserialize)] #[mt(typename = "RangeFull")] #[allow(unused)] struct RemoteRangeFull; // RangeInclusive fields are private impl MtSerialize for RangeInclusive { fn mt_serialize(&self, writer: &mut impl Write) -> Result<(), SerializeError> { self.start().mt_serialize::(writer)?; self.end().mt_serialize::(writer)?; Ok(()) } } impl MtDeserialize for RangeInclusive { fn mt_deserialize(reader: &mut impl Read) -> Result { let start = T::mt_deserialize::(reader)?; let end = T::mt_deserialize::(reader)?; Ok(start..=end) } } #[derive(MtSerialize, MtDeserialize)] #[mt(typename = "RangeTo")] #[allow(unused)] struct RemoteRangeTo { end: T, } #[derive(MtSerialize, MtDeserialize)] #[mt(typename = "RangeToInclusive")] #[allow(unused)] struct RemoteRangeToInclusive { end: T, } #[derive(MtSerialize, MtDeserialize)] #[mt(typename = "Vector1")] #[allow(unused)] struct RemoteVector1 { x: T, } #[derive(MtSerialize, MtDeserialize)] #[mt(typename = "Vector2")] #[allow(unused)] struct RemoteVector2 { x: T, y: T, } #[derive(MtSerialize, MtDeserialize)] #[mt(typename = "Vector3")] #[allow(unused)] struct RemoteVector3 { x: T, y: T, z: T, } #[derive(MtSerialize, MtDeserialize)] #[mt(typename = "Vector4")] #[allow(unused)] struct RemoteVector4 { x: T, y: T, z: T, w: T, } #[derive(MtSerialize, MtDeserialize)] #[mt(typename = "Point1")] #[allow(unused)] struct RemotePoint1 { x: T, } #[derive(MtSerialize, MtDeserialize)] #[mt(typename = "Point2")] #[allow(unused)] struct RemotePoint2 { x: T, y: T, } #[derive(MtSerialize, MtDeserialize)] #[mt(typename = "Point3")] #[allow(unused)] struct RemotePoint3 { x: T, y: T, z: T, } #[derive(MtSerialize, MtDeserialize)] #[mt(typename = "Deg")] #[allow(unused)] struct RemoteDeg(T); #[derive(MtSerialize, MtDeserialize)] #[mt(typename = "Rad")] #[allow(unused)] struct RemoteRad(T); #[derive(MtSerialize, MtDeserialize)] #[mt(typename = "Euler")] #[allow(unused)] struct RemoteEuler { x: T, y: T, z: T, } #[derive(MtSerialize, MtDeserialize)] #[mt(typename = "Aabb2")] #[allow(unused)] struct RemoteAabb2 { min: Point2, max: Point2, } #[derive(MtSerialize, MtDeserialize)] #[mt(typename = "Aabb3")] #[allow(unused)] struct RemoteAabb3 { min: Point3, max: Point3, }