summaryrefslogtreecommitdiff
path: root/internal/mkserialize/mkserialize.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/mkserialize/mkserialize.go')
-rw-r--r--internal/mkserialize/mkserialize.go611
1 files changed, 611 insertions, 0 deletions
diff --git a/internal/mkserialize/mkserialize.go b/internal/mkserialize/mkserialize.go
new file mode 100644
index 0000000..9b4b019
--- /dev/null
+++ b/internal/mkserialize/mkserialize.go
@@ -0,0 +1,611 @@
+package main
+
+import (
+ "bufio"
+ "flag"
+ "fmt"
+ "go/ast"
+ "go/printer"
+ "go/token"
+ "go/types"
+ "io"
+ "log"
+ "os"
+ "strconv"
+ "strings"
+
+ "golang.org/x/tools/go/ast/astutil"
+ "golang.org/x/tools/go/packages"
+)
+
+var (
+ pkg *packages.Package
+
+ serializeFmt = make(map[string]string)
+ deserializeFmt = make(map[string]string)
+
+ uint8T = types.Universe.Lookup("uint8").Type()
+ byteT = types.Universe.Lookup("byte").Type()
+
+ serialize []*types.Named
+ inSerialize = make(map[string]bool)
+
+ consts = make(map[*ast.StructType][]*ast.Comment)
+)
+
+func structPragma(c *ast.Comment, sp *[]func(), expr string, de bool) {
+ fields := strings.SplitN(strings.TrimPrefix(c.Text, "//mt:"), " ", 2)
+ arg := ""
+ if len(fields) == 2 {
+ arg = fields[1]
+ }
+ switch fields[0] {
+ case "const":
+ tv, err := types.Eval(pkg.Fset, pkg.Types, c.Slash, arg)
+ if err != nil {
+ error(c.Pos(), err)
+ }
+
+ if de {
+ fmt.Println("{")
+ x := newVar()
+ fmt.Println("var", x, typeStr(tv.Type))
+ genSerialize(tv.Type, x, token.NoPos, nil, de)
+ fmt.Println("if", x, "!=", "(", tv.Value, ")",
+ `{ chk(fmt.Errorf("const %v: %v",`, tv.Value, ",", x, ")) }")
+ fmt.Println("}")
+ } else {
+ v := newVar()
+ fmt.Println("{", v, ":=", arg)
+ genSerialize(tv.Type, v, c.Slash+token.Pos(len("//mt:const ")), nil, de)
+ fmt.Println("}")
+ }
+ case "assert":
+ fmt.Printf("if !("+arg+") {", expr)
+ fmt.Printf("chk(errors.New(%q))\n", "assertion failed: "+arg)
+ fmt.Println("}")
+ case "zlib":
+ if de {
+ fmt.Println("{ r, err := zlib.NewReader(byteReader{r}); chk(err)")
+ *sp = append(*sp, func() {
+ fmt.Println("chk(r.Close()) }")
+ })
+ } else {
+ fmt.Println("{ w := zlib.NewWriter(w)")
+ *sp = append(*sp, func() {
+ fmt.Println("chk(w.Close()) }")
+ })
+ }
+ case "lenhdr":
+ if arg != "8" && arg != "16" && arg != "32" {
+ error(c.Pos(), "usage: //mt:lenhdr (8|16|32)")
+ }
+
+ fmt.Println("{")
+
+ if !de {
+ fmt.Println("ow := w")
+ fmt.Println("w := new(bytes.Buffer)")
+ }
+
+ var cg ast.CommentGroup
+ if de {
+ t := types.Universe.Lookup("uint" + arg).Type()
+ fmt.Println("var n", t)
+ genSerialize(t, "n", token.NoPos, nil, de)
+ if arg == "64" {
+ fmt.Println(`if n > math.MaxInt64 { panic("too big len") }`)
+ }
+ fmt.Println("r := &io.LimitedReader{r, int64(n)}")
+ } else {
+ switch arg {
+ case "8", "32":
+ cg.List = []*ast.Comment{{Text: "//mt:len" + arg}}
+ case "16":
+ }
+ }
+
+ *sp = append(*sp, func() {
+ if de {
+ fmt.Println("if r.N > 0",
+ `{ chk(fmt.Errorf("%d bytes of trailing data", r.N)) }`)
+ } else {
+ fmt.Println("{")
+ fmt.Println("buf := w")
+ fmt.Println("w := ow")
+ byteSlice := types.NewSlice(types.Typ[types.Byte])
+ genSerialize(byteSlice, "buf.Bytes()", token.NoPos, &cg, de)
+ fmt.Println("}")
+ }
+
+ fmt.Println("}")
+ })
+ case "end":
+ (*sp)[len(*sp)-1]()
+ *sp = (*sp)[:len(*sp)-1]
+ case "if":
+ fmt.Printf(strings.TrimPrefix(c.Text, "//mt:")+" {\n", expr)
+ *sp = append(*sp, func() {
+ fmt.Println("}")
+ })
+ case "ifde":
+ if !de {
+ fmt.Println("/*")
+ }
+ }
+}
+
+func genSerialize(t types.Type, expr string, pos token.Pos, doc *ast.CommentGroup, de bool) {
+ var lenhdr types.Type = types.Typ[types.Uint16]
+
+ useMethod := true
+ if doc != nil {
+ for _, c := range doc.List {
+ pragma := true
+ switch c.Text {
+ case "//mt:utf16":
+ t = types.NewSlice(types.Typ[types.Uint16])
+ if de {
+ v := newVar()
+ fmt.Println("var", v, typeStr(t))
+ defer fmt.Println(expr + " = string(utf16.Decode(" + v + "))")
+ expr = v
+ } else {
+ expr = "utf16.Encode([]rune(" + expr + "))"
+ }
+ pos = token.NoPos
+ case "//mt:raw":
+ lenhdr = nil
+ case "//mt:len8":
+ lenhdr = types.Typ[types.Uint8]
+ case "//mt:len32":
+ lenhdr = types.Typ[types.Uint32]
+ case "//mt:opt":
+ fmt.Println("if err := pcall(func() {")
+ defer fmt.Println("}); err != nil && err != io.EOF",
+ "{ chk(err) }")
+ default:
+ pragma = false
+ }
+ if pragma {
+ useMethod = false
+ }
+ }
+ }
+
+ str := types.TypeString(t, types.RelativeTo(pkg.Types))
+ if de {
+ if or, ok := deserializeFmt[str]; ok {
+ fmt.Println("{")
+ fmt.Println("p := &" + expr)
+ fmt.Print(or)
+ fmt.Println("}")
+ return
+ }
+ } else {
+ if or, ok := serializeFmt[str]; ok {
+ fmt.Println("{")
+ fmt.Println("x := " + expr)
+ fmt.Print(or)
+ fmt.Println("}")
+ return
+ }
+ }
+
+ expr = "(" + expr + ")"
+
+ switch t := t.(type) {
+ case *types.Named:
+ if !useMethod {
+ t := t.Underlying()
+ genSerialize(t, "*(*"+typeStr(t)+")("+"&"+expr+")", pos, doc, de)
+ return
+ }
+
+ method := "Serialize"
+ if de {
+ method = "Deserialize"
+ }
+ for i := 0; i < t.NumMethods(); i++ {
+ m := t.Method(i)
+ if m.Name() == method {
+ rw := "w"
+ if de {
+ rw = "r"
+ }
+ fmt.Println("chk(" + expr + "." + method + "(" + rw + "))")
+ return
+ }
+ }
+
+ mkSerialize(t)
+
+ fmt.Println("if err := pcall(func() {")
+ if de {
+ fmt.Println(expr + ".deserialize(r)")
+ } else {
+ fmt.Println(expr + ".serialize(w)")
+ }
+ fmt.Println("}); err != nil",
+ `{`,
+ `if err == io.EOF { chk(io.EOF) };`,
+ `chk(fmt.Errorf("%s: %w", `+strconv.Quote(t.String())+`, err))`,
+ `}`)
+ case *types.Struct:
+ st := pos2node(pos)[0].(*ast.StructType)
+
+ a := consts[st]
+ b := st.Fields.List
+
+ // Merge sorted slices.
+ c := make([]ast.Node, 0, len(a)+len(b))
+ for i, j := 0, 0; i < len(a) || j < len(b); {
+ if i < len(a) && (j >= len(b) || a[i].Pos() < b[j].Pos()) {
+ c = append(c, a[i])
+ i++
+ } else {
+ c = append(c, b[j])
+ j++
+ }
+ }
+
+ var (
+ stk []func()
+ i int
+ )
+ for _, field := range c {
+ switch field := field.(type) {
+ case *ast.Comment:
+ structPragma(field, &stk, expr, de)
+ case *ast.Field:
+ n := len(field.Names)
+ if n == 0 {
+ n = 1
+ }
+ for ; n > 0; n-- {
+ f := t.Field(i)
+ genSerialize(f.Type(), expr+"."+f.Name(), field.Type.Pos(), field.Doc, de)
+ i++
+ }
+ }
+ }
+
+ if len(stk) > 0 {
+ error(pos, "missing //mt:end")
+ }
+ case *types.Basic:
+ switch t.Kind() {
+ case types.String:
+ byteSlice := types.NewSlice(types.Typ[types.Byte])
+ if de {
+ v := newVar()
+ fmt.Println("var", v, byteSlice)
+ genSerialize(byteSlice, v, token.NoPos, doc, de)
+ fmt.Println(expr, "=", "string(", v, ")")
+ } else {
+ genSerialize(byteSlice, "[]byte"+expr, token.NoPos, doc, de)
+ }
+ default:
+ error(pos, "can't serialize ", t)
+ }
+ case *types.Slice:
+ if de {
+ if lenhdr != nil {
+ v := newVar()
+ fmt.Println("var", v, lenhdr)
+ genSerialize(lenhdr, v, pos, nil, de)
+ fmt.Printf("%s = make(%v, %s)\n",
+ expr, typeStr(t), v)
+ genSerialize(types.NewArray(t.Elem(), 0), expr, pos, nil, de)
+ } else {
+ fmt.Println("for {")
+ v := newVar()
+ fmt.Println("var", v, typeStr(t.Elem()))
+ fmt.Println("err := pcall(func() {")
+ if pos.IsValid() {
+ pos = pos2node(pos)[0].(*ast.ArrayType).Elt.Pos()
+ }
+ genSerialize(t.Elem(), v, pos, nil, de)
+ fmt.Println("})")
+ fmt.Println("if err == io.EOF { break }")
+ fmt.Println(expr + " = append(" + expr + ", " + v + ")")
+ fmt.Println("chk(err)")
+ fmt.Println("}")
+ }
+ } else {
+ if lenhdr != nil {
+ fmt.Println("if len("+expr+") >",
+ "math.Max"+strings.Title(lenhdr.String()),
+ "{ chk(ErrTooLong) }")
+ genSerialize(lenhdr, lenhdr.String()+"(len("+expr+"))", pos, nil, de)
+ }
+ genSerialize(types.NewArray(t.Elem(), 0), expr, pos, nil, de)
+ }
+ case *types.Array:
+ et := t.Elem()
+ if et == byteT || et == uint8T {
+ if de {
+ fmt.Println("{",
+ "_, err := io.ReadFull(r, "+expr+"[:]);",
+ "chk(err)",
+ "}")
+ } else {
+ fmt.Println("{",
+ "_, err := w.Write("+expr+"[:]);",
+ "chk(err)",
+ "}")
+ }
+ break
+ }
+ i := newVar()
+ fmt.Println("for", i, ":= range", expr, "{")
+ if pos.IsValid() {
+ pos = pos2node(pos)[0].(*ast.ArrayType).Elt.Pos()
+ }
+ genSerialize(et, expr+"["+i+"]", pos, nil, de)
+ fmt.Println("}")
+ default:
+ error(pos, "can't serialize ", t)
+ }
+}
+
+func readOverrides(path string, override map[string]string) {
+ f, err := os.Open(path)
+ if err != nil {
+ log.Fatal(err)
+ }
+ defer f.Close()
+
+ b := bufio.NewReader(f)
+ line := 0
+ col1 := ""
+ for {
+ ln, err := b.ReadString('\n')
+ if err != nil {
+ if err == io.EOF {
+ if len(ln) > 0 {
+ log.Fatal("no newline at end of ", f.Name())
+ }
+ return
+ }
+ log.Fatal(err)
+ }
+ line++
+
+ if ln == "\n" {
+ continue
+ }
+
+ fields := strings.SplitN(ln, "\t", 2)
+ if len(fields) == 1 {
+ log.Fatal(f.Name(), ":", line, ": missing tab")
+ }
+ if fields[0] != "" {
+ col1 = fields[0]
+ }
+
+ if col1 == "" {
+ fmt.Print(fields[1])
+ continue
+ }
+
+ override[col1] += fields[1]
+ }
+}
+
+func mkSerialize(t *types.Named) {
+ if !inSerialize[t.String()] {
+ serialize = append(serialize, t)
+ inSerialize[t.String()] = true
+ }
+}
+
+var varNo int
+
+func newVar() string {
+ varNo++
+ return fmt.Sprint("local", varNo)
+}
+
+func pos2node(pos token.Pos) []ast.Node {
+ return interval2node(pos, pos)
+}
+
+func interval2node(start, end token.Pos) []ast.Node {
+ for _, f := range pkg.Syntax {
+ if f.Pos() <= start && end <= f.End() {
+ if path, _ := astutil.PathEnclosingInterval(f, start, end); path != nil {
+ return path
+ }
+ }
+ }
+ return nil
+}
+
+func error(pos token.Pos, a ...interface{}) {
+ if !pos.IsValid() {
+ log.Fatal(a...)
+ }
+ log.Fatal(append([]interface{}{pkg.Fset.Position(pos), ": "}, a...)...)
+}
+
+func typeStr(t types.Type) string {
+ return types.TypeString(t, func(p *types.Package) string {
+ if p == pkg.Types {
+ return ""
+ }
+
+ return p.Name()
+ })
+}
+
+var typeNames = []string{
+ "ToSrvNil",
+ "ToSrvInit",
+ "ToSrvInit2",
+ "ToSrvModChanJoin",
+ "ToSrvModChanLeave",
+ "ToSrvModChanMsg",
+ "ToSrvPlayerPos",
+ "ToSrvGotBlks",
+ "ToSrvDeletedBlks",
+ "ToSrvInvAction",
+ "ToSrvChatMsg",
+ "ToSrvFallDmg",
+ "ToSrvSelectItem",
+ "ToSrvRespawn",
+ "ToSrvInteract",
+ "ToSrvRemovedSounds",
+ "ToSrvNodeMetaFields",
+ "ToSrvInvFields",
+ "ToSrvReqMedia",
+ "ToSrvCltReady",
+ "ToSrvFirstSRP",
+ "ToSrvSRPBytesA",
+ "ToSrvSRPBytesM",
+
+ "ToCltHello",
+ "ToCltAcceptAuth",
+ "ToCltAcceptSudoMode",
+ "ToCltDenySudoMode",
+ "ToCltDisco",
+ "ToCltBlkData",
+ "ToCltAddNode",
+ "ToCltRemoveNode",
+ "ToCltInv",
+ "ToCltTimeOfDay",
+ "ToCltCSMRestrictionFlags",
+ "ToCltAddPlayerVel",
+ "ToCltMediaPush",
+ "ToCltChatMsg",
+ "ToCltAORmAdd",
+ "ToCltAOMsgs",
+ "ToCltHP",
+ "ToCltMovePlayer",
+ "ToCltDiscoLegacy",
+ "ToCltFOV",
+ "ToCltDeathScreen",
+ "ToCltMedia",
+ "ToCltNodeDefs",
+ "ToCltAnnounceMedia",
+ "ToCltItemDefs",
+ "ToCltPlaySound",
+ "ToCltStopSound",
+ "ToCltPrivs",
+ "ToCltInvFormspec",
+ "ToCltDetachedInv",
+ "ToCltShowFormspec",
+ "ToCltMovement",
+ "ToCltSpawnParticle",
+ "ToCltAddParticleSpawner",
+ "ToCltAddHUD",
+ "ToCltRmHUD",
+ "ToCltChangeHUD",
+ "ToCltHUDFlags",
+ "ToCltSetHotbarParam",
+ "ToCltBreath",
+ "ToCltSkyParams",
+ "ToCltOverrideDayNightRatio",
+ "ToCltLocalPlayerAnim",
+ "ToCltEyeOffset",
+ "ToCltDelParticleSpawner",
+ "ToCltCloudParams",
+ "ToCltFadeSound",
+ "ToCltUpdatePlayerList",
+ "ToCltModChanMsg",
+ "ToCltModChanSig",
+ "ToCltNodeMetasChanged",
+ "ToCltSunParams",
+ "ToCltMoonParams",
+ "ToCltStarParams",
+ "ToCltSRPBytesSaltB",
+ "ToCltFormspecPrepend",
+
+ "AOCmdProps",
+ "AOCmdPos",
+ "AOCmdTextureMod",
+ "AOCmdSprite",
+ "AOCmdHP",
+ "AOCmdArmorGroups",
+ "AOCmdAnim",
+ "AOCmdBonePos",
+ "AOCmdAttach",
+ "AOCmdPhysOverride",
+ "AOCmdSpawnInfant",
+ "AOCmdAnimSpeed",
+
+ "NodeMeta",
+ "MinimapMode",
+ "NodeDef",
+ "PointedNode",
+ "PointedAO",
+}
+
+func main() {
+ log.SetFlags(0)
+ log.SetPrefix("mkserialize: ")
+
+ flag.Parse()
+
+ cfg := &packages.Config{Mode: packages.NeedSyntax |
+ packages.NeedName |
+ packages.NeedDeps |
+ packages.NeedImports |
+ packages.NeedTypes |
+ packages.NeedTypesInfo}
+ pkgs, err := packages.Load(cfg, flag.Args()...)
+ if err != nil {
+ log.Fatal(err)
+ }
+ if packages.PrintErrors(pkgs) > 0 {
+ os.Exit(1)
+ }
+
+ if len(pkgs) != 1 {
+ log.Fatal("must be exactly 1 package")
+ }
+ pkg = pkgs[0]
+
+ fmt.Println("package", pkg.Name)
+
+ readOverrides("serialize.fmt", serializeFmt)
+ readOverrides("deserialize.fmt", deserializeFmt)
+
+ for _, f := range pkg.Syntax {
+ for _, cg := range f.Comments {
+ for _, c := range cg.List {
+ if !strings.HasPrefix(c.Text, "//mt:") {
+ continue
+ }
+ st := interval2node(c.Pos(), c.End())[1].(*ast.StructType)
+ consts[st] = append(consts[st], c)
+ }
+ }
+ }
+
+ for _, name := range typeNames {
+ obj := pkg.Types.Scope().Lookup(name)
+ if obj == nil {
+ log.Println("undeclared identifier: ", name)
+ continue
+ }
+ mkSerialize(obj.Type().(*types.Named))
+ }
+
+ for i := 0; i < len(serialize); i++ {
+ for _, de := range []bool{false, true} {
+ t := serialize[i]
+ sig := "serialize(w io.Writer)"
+ if de {
+ sig = "deserialize(r io.Reader)"
+ }
+ fmt.Println("\nfunc (obj *" + t.Obj().Name() + ") " + sig + " {")
+ pos := t.Obj().Pos()
+ tExpr := pos2node(pos)[1].(*ast.TypeSpec).Type
+ var b strings.Builder
+ printer.Fprint(&b, pkg.Fset, tExpr)
+ genSerialize(pkg.TypesInfo.Types[tExpr].Type, "*(*("+b.String()+"))(obj)", tExpr.Pos(), nil, de)
+ fmt.Println("}")
+ }
+ }
+}