#![feature(array_try_from_fn)] #![feature(associated_type_bounds)] #![feature(iterator_try_collect)] pub use flate2; pub use mt_ser_derive::{mt_derive, MtDeserialize, MtSerialize}; pub use paste; use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; use enumset::{EnumSet, EnumSetTypeWithRepr}; use paste::paste as paste_macro; use std::{ collections::{HashMap, HashSet}, convert::Infallible, fmt::Debug, io::{self, Read, Write}, num::TryFromIntError, ops::Deref, }; use thiserror::Error; #[cfg(test)] mod tests; #[derive(Error, Debug)] pub enum SerializeError { #[error("io error: {0}")] IoError(#[from] io::Error), #[error("collection too big: {0}")] TooBig(#[from] TryFromIntError), } impl From for SerializeError { fn from(_err: Infallible) -> Self { unreachable!("infallible") } } #[derive(Error, Debug)] pub enum DeserializeError { #[error("io error: {0}")] IoError(io::Error), #[error("unexpected end of file")] UnexpectedEof, #[error("collection too big: {0}")] TooBig(#[from] TryFromIntError), #[error("invalid UTF-16: {0}")] InvalidUtf16(#[from] std::char::DecodeUtf16Error), #[error("invalid {0} enum variant {1}")] InvalidEnumVariant(&'static str, u64), #[error("invalid constant - wanted: {0:?} - got: {1:?}")] InvalidConst(Box, Box), } impl From for DeserializeError { fn from(_err: Infallible) -> Self { unreachable!("infallible") } } impl From for DeserializeError { fn from(err: io::Error) -> Self { if err.kind() == io::ErrorKind::UnexpectedEof { DeserializeError::UnexpectedEof } else { DeserializeError::IoError(err) } } } pub trait OrDefault { fn or_default(self) -> Self; } pub struct WrapRead<'a, R: Read>(pub &'a mut R); impl<'a, R: Read> Read for WrapRead<'a, R> { fn read(&mut self, buf: &mut [u8]) -> io::Result { self.0.read(buf) } fn read_vectored(&mut self, bufs: &mut [io::IoSliceMut<'_>]) -> io::Result { self.0.read_vectored(bufs) } /* fn is_read_vectored(&self) -> bool { self.0.is_read_vectored() } */ fn read_to_end(&mut self, buf: &mut Vec) -> io::Result { self.0.read_to_end(buf) } fn read_to_string(&mut self, buf: &mut String) -> io::Result { self.0.read_to_string(buf) } fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> { self.0.read_exact(buf) } /* fn read_buf(&mut self, buf: io::BorrowedCursor<'_>) -> io::Result<()> { self.0.read_buf(buf) } fn read_buf_exact(&mut self, cursor: io::BorrowedCursor<'_>) -> io::Result<()> { self.0.read_buf_exact(cursor) } */ } impl OrDefault for Result { fn or_default(self) -> Self { match self { Err(DeserializeError::UnexpectedEof) => Ok(T::default()), x => x, } } } pub trait MtLen { fn option(&self) -> Option; type Range: Iterator + 'static; fn range(&self) -> Self::Range; type Take: Read; fn take(&self, reader: R) -> Self::Take; } pub trait MtCfg { type Len: MtLen; type Inner: MtCfg; fn utf16() -> bool { false } fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError>; fn read_len(reader: &mut impl Read) -> Result; } pub trait MtSerialize { fn mt_serialize(&self, writer: &mut impl Write) -> Result<(), SerializeError>; } pub trait MtDeserialize: Sized { fn mt_deserialize(reader: &mut impl Read) -> Result; } impl MtLen for usize { fn option(&self) -> Option { Some(*self) } type Range = std::ops::Range; fn range(&self) -> Self::Range { 0..*self } type Take = io::Take; fn take(&self, reader: R) -> Self::Take { reader.take(*self as u64) } } trait MtCfgLen: Sized + MtSerialize + MtDeserialize + TryFrom> + TryInto> { } impl MtCfg for T { type Len = usize; type Inner = DefCfg; fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError> { Self::try_from(len) .map_err(Into::into)? .mt_serialize::(writer) } fn read_len(reader: &mut impl Read) -> Result { Ok(Self::mt_deserialize::(reader)? .try_into() .map_err(Into::into)?) } } impl MtCfgLen for u8 {} impl MtCfgLen for u16 {} impl MtCfgLen for u32 {} impl MtCfgLen for u64 {} pub type DefCfg = u16; impl MtCfg for () { type Len = (); type Inner = DefCfg; fn write_len(_len: usize, _writer: &mut impl Write) -> Result<(), SerializeError> { Ok(()) } fn read_len(_writer: &mut impl Read) -> Result { Ok(()) } } impl MtLen for () { fn option(&self) -> Option { None } type Range = std::ops::RangeFrom; fn range(&self) -> Self::Range { 0.. } type Take = R; fn take(&self, reader: R) -> Self::Take { reader } } pub struct Utf16(pub B); impl MtCfg for Utf16 { type Len = B::Len; type Inner = B::Inner; fn utf16() -> bool { true } fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError> { B::write_len(len, writer) } fn read_len(reader: &mut impl Read) -> Result { B::read_len(reader) } } impl MtCfg for (A, B) { type Len = A::Len; type Inner = B; fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError> { A::write_len(len, writer) } fn read_len(reader: &mut impl Read) -> Result { A::read_len(reader) } } impl MtSerialize for u8 { fn mt_serialize(&self, writer: &mut impl Write) -> Result<(), SerializeError> { writer.write_u8(*self)?; Ok(()) } } impl MtDeserialize for u8 { fn mt_deserialize(reader: &mut impl Read) -> Result { Ok(reader.read_u8()?) } } impl MtSerialize for i8 { fn mt_serialize(&self, writer: &mut impl Write) -> Result<(), SerializeError> { writer.write_i8(*self)?; Ok(()) } } impl MtDeserialize for i8 { fn mt_deserialize(reader: &mut impl Read) -> Result { Ok(reader.read_i8()?) } } macro_rules! impl_num { ($T:ty) => { impl MtSerialize for $T { fn mt_serialize( &self, writer: &mut impl Write, ) -> Result<(), SerializeError> { paste_macro! { writer.[]::(*self)?; } Ok(()) } } impl MtDeserialize for $T { fn mt_deserialize(reader: &mut impl Read) -> Result { paste_macro! { Ok(reader.[]::()?) } } } }; } impl_num!(u16); impl_num!(i16); impl_num!(u32); impl_num!(i32); impl_num!(f32); impl_num!(u64); impl_num!(i64); impl_num!(f64); impl MtSerialize for () { fn mt_serialize(&self, _writer: &mut impl Write) -> Result<(), SerializeError> { Ok(()) } } impl MtDeserialize for () { fn mt_deserialize(_reader: &mut impl Read) -> Result { Ok(()) } } impl MtSerialize for bool { fn mt_serialize(&self, writer: &mut impl Write) -> Result<(), SerializeError> { (*self as u8).mt_serialize::(writer) } } impl MtDeserialize for bool { fn mt_deserialize(reader: &mut impl Read) -> Result { Ok(u8::mt_deserialize::(reader)? != 0) } } impl MtSerialize for &T { fn mt_serialize(&self, writer: &mut impl Write) -> Result<(), SerializeError> { (*self).mt_serialize::(writer) } } pub fn mt_serialize_seq( writer: &mut impl Write, iter: impl ExactSizeIterator + IntoIterator, ) -> Result<(), SerializeError> { C::write_len(iter.len(), writer)?; iter.into_iter() .try_for_each(|item| item.mt_serialize::(writer)) } pub fn mt_deserialize_seq<'a, C: MtCfg, T: MtDeserialize>( reader: &'a mut impl Read, ) -> Result> + 'a, DeserializeError> { let len = C::read_len(reader)?; mt_deserialize_sized_seq::(&len, reader) } pub fn mt_deserialize_sized_seq<'a, C: MtCfg, T: MtDeserialize>( len: &C::Len, reader: &'a mut impl Read, ) -> Result> + 'a, DeserializeError> { let variable = len.option().is_none(); Ok(len .range() .map_while(move |_| match T::mt_deserialize::(reader) { Err(DeserializeError::UnexpectedEof) if variable => None, x => Some(x), })) } impl MtSerialize for [T; N] { fn mt_serialize(&self, writer: &mut impl Write) -> Result<(), SerializeError> { mt_serialize_seq::<(), _>(writer, self.iter()) } } impl MtDeserialize for [T; N] { fn mt_deserialize(reader: &mut impl Read) -> Result { std::array::try_from_fn(|_| T::mt_deserialize::(reader)) } } impl> MtSerialize for EnumSet { fn mt_serialize(&self, writer: &mut impl Write) -> Result<(), SerializeError> { self.as_repr().mt_serialize::(writer) } } impl> MtDeserialize for EnumSet { fn mt_deserialize(reader: &mut impl Read) -> Result { Ok(Self::from_repr_truncated(T::mt_deserialize::( reader, )?)) } } impl MtSerialize for Option { fn mt_serialize(&self, writer: &mut impl Write) -> Result<(), SerializeError> { match self { Some(item) => item.mt_serialize::(writer), None => Ok(()), } } } impl MtDeserialize for Option { fn mt_deserialize(reader: &mut impl Read) -> Result { T::mt_deserialize::(reader).map(Some).or_default() } } impl MtSerialize for Vec { fn mt_serialize(&self, writer: &mut impl Write) -> Result<(), SerializeError> { mt_serialize_seq::(writer, self.iter()) } } impl MtDeserialize for Vec { fn mt_deserialize(reader: &mut impl Read) -> Result { mt_deserialize_seq::(reader)?.try_collect() } } impl MtSerialize for HashSet { fn mt_serialize(&self, writer: &mut impl Write) -> Result<(), SerializeError> { mt_serialize_seq::(writer, self.iter()) } } impl MtDeserialize for HashSet { fn mt_deserialize(reader: &mut impl Read) -> Result { mt_deserialize_seq::(reader)?.try_collect() } } // TODO: support more tuples impl MtSerialize for (A, B) { fn mt_serialize(&self, writer: &mut impl Write) -> Result<(), SerializeError> { self.0.mt_serialize::(writer)?; self.1.mt_serialize::(writer)?; Ok(()) } } impl MtDeserialize for (A, B) { fn mt_deserialize(reader: &mut impl Read) -> Result { let a = A::mt_deserialize::(reader)?; let b = B::mt_deserialize::(reader)?; Ok((a, b)) } } impl MtSerialize for HashMap where K: MtSerialize + std::cmp::Eq + std::hash::Hash, V: MtSerialize, { fn mt_serialize(&self, writer: &mut impl Write) -> Result<(), SerializeError> { mt_serialize_seq::(writer, self.iter()) } } impl MtDeserialize for HashMap where K: MtDeserialize + std::cmp::Eq + std::hash::Hash, V: MtDeserialize, { fn mt_deserialize(reader: &mut impl Read) -> Result { mt_deserialize_seq::(reader)?.try_collect() } } impl MtSerialize for String { fn mt_serialize(&self, writer: &mut impl Write) -> Result<(), SerializeError> { if C::utf16() { self.encode_utf16() .collect::>() // FIXME: is this allocation necessary? .mt_serialize::(writer) } else { mt_serialize_seq::(writer, self.as_bytes().iter()) } } } impl MtDeserialize for String { fn mt_deserialize(reader: &mut impl Read) -> Result { if C::utf16() { let mut err = None; let res = char::decode_utf16(mt_deserialize_seq::(reader)?.map_while(|x| match x { Ok(v) => Some(v), Err(e) => { err = Some(e); None } })) .try_collect(); match err { None => Ok(res?), Some(e) => Err(e), } } else { let len = C::read_len(reader)?; // use capacity if available let mut st = match len.option() { Some(x) => String::with_capacity(x), None => String::new(), }; len.take(WrapRead(reader)).read_to_string(&mut st)?; Ok(st) } } } impl MtSerialize for Box { fn mt_serialize(&self, writer: &mut impl Write) -> Result<(), SerializeError> { self.deref().mt_serialize::(writer) } } impl MtDeserialize for Box { fn mt_deserialize(reader: &mut impl Read) -> Result { Ok(Self::new(T::mt_deserialize::(reader)?)) } }