diff options
Diffstat (limited to 'fromlua')
-rwxr-xr-x | fromlua/generate.lua | 198 | ||||
-rw-r--r-- | fromlua/generated.go | 320 | ||||
-rw-r--r-- | fromlua/static.go | 59 |
3 files changed, 577 insertions, 0 deletions
diff --git a/fromlua/generate.lua b/fromlua/generate.lua new file mode 100755 index 0000000..75899b7 --- /dev/null +++ b/fromlua/generate.lua @@ -0,0 +1,198 @@ +#!/usr/bin/env lua +dofile("../parse_spec.lua") + +local readers = { + SliceByte = true, + Byte = true, + String = true, + SliceField = true, + Field = true, + Bool = true, + PointedThing = true, +} + +local static_uses = { + "[3]int16", + "AOID" +} + +local function generate(name) + local fnname, index, child, childfn, childtype + local type = name + + local open = name:find("%[") + local clos = name:find("%]") + + if open == 1 then + index = name:sub(open + 1, clos - 1) + child = name:sub(clos + 1) + childfn, childtype = generate(child) + fnname = (index == "" and "Slice" or "Vec" .. index) .. childfn + + type = "[" .. index .. "]" .. childtype + else + fnname = camel_case(name) + + local c = name:sub(1, 1) + if c == c:upper() then + type = "mt." .. name + end + end + + if not readers[fnname] then + local fun = "func read" .. fnname .. "(l *lua.LState, val lua.LValue, ptr *" .. type .. ") {\n" + + if child then + fun = fun .. "\tif val.Type() != lua.LTTable {\n\t\tpanic(\"invalid value for " + .. name .. ": must be a table\")\n\t}\n" + + if index == "" then + fun = fun .. +[[ + tbl := val.(*lua.LTable) + n := tbl.MaxN() + *ptr = make(]] .. type .. [[, n) + for i := range *ptr { + read]] .. childfn .. [[(l, l.RawGetInt(tbl, i+1), &(*ptr)[i]) + } +]] + else + local n = tonumber(index) + for i, v in ipairs({"x", "y", "z"}) do + if i > n then + break + end + + fun = fun + .. "\tread" .. childfn + .. "(l, l.GetField(val, \"" .. v .. "\"), &(*ptr)[" .. (i - 1) .. "])\n" + end + end + else + fun = fun .. "\tif val.Type() != lua.LTNumber {\n\t\tpanic(\"invalid value for " + .. name .. ": must be a number\")\n\t}\n" + .. "\t*ptr = " .. type .. "(val.(lua.LNumber))\n" + end + + fun = fun .. "}\n\n" + + readers[fnname] = fun + end + + return fnname, type +end + +for _, use in ipairs(static_uses) do + generate(use) +end + +local function signature(name, prefix, type) + local camel = camel_case(name) + return "func read" .. camel .. "(l *lua.LState, val lua.LValue, ptr *" .. prefix .. camel .. ") {\n" +end + +for name, fields in spairs(parse_spec("server/enum")) do + local camel = camel_case(name) + local fun = signature(name, "mt.") + + local impl = "" + for _, var in ipairs(fields) do + local equals = "*ptr = mt." .. apply_prefix(fields, var) .. "\n" + + if var == "no" then + fun = fun .. "\tif val.Type() == lua.LTNil {\n\t\t" .. equals .. "\t\treturn\n\t}\n" + else + impl = impl .. "\tcase \"" .. var .. "\":\n\t\t" .. equals + end + end + + fun = fun + .. "\tif val.Type() != lua.LTString {\n\t\tpanic(\"invalid value for " + .. camel .. ": must be a string\")\n\t}\n" + .. "\tstr := string(val.(lua.LString))\n" + .. "\tswitch str {\n" .. impl + .. "\tdefault:\n\t\tpanic(\"invalid value for " .. name .. ": \" + str)\n\t}\n}\n\n" + + readers[camel] = fun +end + +for name, fields in spairs(parse_spec("server/flag")) do + local camel = camel_case(name) + local fun = signature(name, "mt.") + .. "\tif val.Type() != lua.LTTable {\n\t\tpanic(\"invalid value for " + .. camel .. ": must be a table\")\n\t}\n" + + for _, var in ipairs(fields) do + fun = fun .. "\tif l.GetField(val, \"" .. var .. "\") == lua.LTrue {\n" + .. "\t\t*ptr = *ptr | mt." .. apply_prefix(fields, var) .. "\n\t}\n" + end + + fun = fun .. "}\n\n" + readers[camel] = fun +end + +local function fields_fromlua(fields, indent) + local impl = "" + + for name, type in spairs(fields) do + impl = impl .. indent .. "read" .. generate(type) .. "(l, l.GetField(val, \"" .. name .. "\"), &ptr." + .. camel_case(name) .. ")\n" + end + + return impl +end + +for name, fields in spairs(parse_spec("server/struct", true)) do + local camel = camel_case(name) + readers[camel] = signature(name, "mt.") + .. "\tif val.Type() != lua.LTTable {\n" + .. "\t\tpanic(\"invalid value for " .. camel .. ": must be a table\")\n\t}\n" + .. fields_fromlua(fields, "\t") + .. "}\n\n" +end + +local pkt_impl = "" + +for name, fields in spairs(parse_spec("server/pkt", true)) do + pkt_impl = pkt_impl + .. "\tcase \"" .. name .. "\"" .. "" .. ":\n" + .. "\t\tptr := &mt.ToSrv" .. camel_case(name) .. "{}\n" + + if next(fields) then + pkt_impl = pkt_impl + .. "\t\tval := l.CheckTable(3)\n" + .. fields_fromlua(fields, "\t\t") + end + + pkt_impl = pkt_impl + .. "\t\treturn ptr\n" +end + +local funcs = "" +for _, fn in spairs(readers) do + if type(fn) == "string" then + funcs = funcs .. fn + end +end + +local f = io.open("generated.go", "w") +f:write([[ +// generated by generate.lua, DO NOT EDIT +package fromlua + +import ( + "github.com/anon55555/mt" + "github.com/yuin/gopher-lua" +) + +]] .. funcs .. [[ +func Cmd(l *lua.LState) mt.Cmd { + str := l.CheckString(2) + switch str { +]] .. pkt_impl .. [[ + } + + panic("invalid packet type: " + str) +} +]]) +f:close() diff --git a/fromlua/generated.go b/fromlua/generated.go new file mode 100644 index 0000000..c93c229 --- /dev/null +++ b/fromlua/generated.go @@ -0,0 +1,320 @@ +// generated by generate.lua, DO NOT EDIT +package fromlua + +import ( + "github.com/anon55555/mt" + "github.com/yuin/gopher-lua" +) + +func readAOID(l *lua.LState, val lua.LValue, ptr *mt.AOID) { + if val.Type() != lua.LTNumber { + panic("invalid value for AOID: must be a number") + } + *ptr = mt.AOID(val.(lua.LNumber)) +} + +func readCompressionModes(l *lua.LState, val lua.LValue, ptr *mt.CompressionModes) { + if val.Type() != lua.LTNumber { + panic("invalid value for CompressionModes: must be a number") + } + *ptr = mt.CompressionModes(val.(lua.LNumber)) +} + +func readInt16(l *lua.LState, val lua.LValue, ptr *int16) { + if val.Type() != lua.LTNumber { + panic("invalid value for int16: must be a number") + } + *ptr = int16(val.(lua.LNumber)) +} + +func readInt32(l *lua.LState, val lua.LValue, ptr *int32) { + if val.Type() != lua.LTNumber { + panic("invalid value for int32: must be a number") + } + *ptr = int32(val.(lua.LNumber)) +} + +func readInteraction(l *lua.LState, val lua.LValue, ptr *mt.Interaction) { + if val.Type() != lua.LTString { + panic("invalid value for Interaction: must be a string") + } + str := string(val.(lua.LString)) + switch str { + case "dig": + *ptr = mt.Dig + case "stop_digging": + *ptr = mt.StopDigging + case "dug": + *ptr = mt.Dug + case "place": + *ptr = mt.Place + case "use": + *ptr = mt.Use + case "activate": + *ptr = mt.Activate + default: + panic("invalid value for interaction: " + str) + } +} + +func readKeys(l *lua.LState, val lua.LValue, ptr *mt.Keys) { + if val.Type() != lua.LTTable { + panic("invalid value for Keys: must be a table") + } + if l.GetField(val, "forward") == lua.LTrue { + *ptr = *ptr | mt.ForwardKey + } + if l.GetField(val, "backward") == lua.LTrue { + *ptr = *ptr | mt.BackwardKey + } + if l.GetField(val, "left") == lua.LTrue { + *ptr = *ptr | mt.LeftKey + } + if l.GetField(val, "right") == lua.LTrue { + *ptr = *ptr | mt.RightKey + } + if l.GetField(val, "jump") == lua.LTrue { + *ptr = *ptr | mt.JumpKey + } + if l.GetField(val, "special") == lua.LTrue { + *ptr = *ptr | mt.SpecialKey + } + if l.GetField(val, "sneak") == lua.LTrue { + *ptr = *ptr | mt.SneakKey + } + if l.GetField(val, "dig") == lua.LTrue { + *ptr = *ptr | mt.DigKey + } + if l.GetField(val, "place") == lua.LTrue { + *ptr = *ptr | mt.PlaceKey + } + if l.GetField(val, "zoom") == lua.LTrue { + *ptr = *ptr | mt.ZoomKey + } +} + +func readPlayerPos(l *lua.LState, val lua.LValue, ptr *mt.PlayerPos) { + if val.Type() != lua.LTTable { + panic("invalid value for PlayerPos: must be a table") + } + readUint8(l, l.GetField(val, "fov80"), &ptr.FOV80) + readKeys(l, l.GetField(val, "keys"), &ptr.Keys) + readInt32(l, l.GetField(val, "pitch100"), &ptr.Pitch100) + readVec3Int32(l, l.GetField(val, "pos100"), &ptr.Pos100) + readVec3Int32(l, l.GetField(val, "vel100"), &ptr.Vel100) + readUint8(l, l.GetField(val, "wanted_range"), &ptr.WantedRange) + readInt32(l, l.GetField(val, "yaw100"), &ptr.Yaw100) +} + +func readSliceSoundID(l *lua.LState, val lua.LValue, ptr *[]mt.SoundID) { + if val.Type() != lua.LTTable { + panic("invalid value for []SoundID: must be a table") + } + tbl := val.(*lua.LTable) + n := tbl.MaxN() + *ptr = make([]mt.SoundID, n) + for i := range *ptr { + readSoundID(l, l.RawGetInt(tbl, i+1), &(*ptr)[i]) + } +} + +func readSliceString(l *lua.LState, val lua.LValue, ptr *[]string) { + if val.Type() != lua.LTTable { + panic("invalid value for []string: must be a table") + } + tbl := val.(*lua.LTable) + n := tbl.MaxN() + *ptr = make([]string, n) + for i := range *ptr { + readString(l, l.RawGetInt(tbl, i+1), &(*ptr)[i]) + } +} + +func readSliceVec3Int16(l *lua.LState, val lua.LValue, ptr *[][3]int16) { + if val.Type() != lua.LTTable { + panic("invalid value for [][3]int16: must be a table") + } + tbl := val.(*lua.LTable) + n := tbl.MaxN() + *ptr = make([][3]int16, n) + for i := range *ptr { + readVec3Int16(l, l.RawGetInt(tbl, i+1), &(*ptr)[i]) + } +} + +func readSoundID(l *lua.LState, val lua.LValue, ptr *mt.SoundID) { + if val.Type() != lua.LTNumber { + panic("invalid value for SoundID: must be a number") + } + *ptr = mt.SoundID(val.(lua.LNumber)) +} + +func readUint16(l *lua.LState, val lua.LValue, ptr *uint16) { + if val.Type() != lua.LTNumber { + panic("invalid value for uint16: must be a number") + } + *ptr = uint16(val.(lua.LNumber)) +} + +func readUint8(l *lua.LState, val lua.LValue, ptr *uint8) { + if val.Type() != lua.LTNumber { + panic("invalid value for uint8: must be a number") + } + *ptr = uint8(val.(lua.LNumber)) +} + +func readVec3Int16(l *lua.LState, val lua.LValue, ptr *[3]int16) { + if val.Type() != lua.LTTable { + panic("invalid value for [3]int16: must be a table") + } + readInt16(l, l.GetField(val, "x"), &(*ptr)[0]) + readInt16(l, l.GetField(val, "y"), &(*ptr)[1]) + readInt16(l, l.GetField(val, "z"), &(*ptr)[2]) +} + +func readVec3Int32(l *lua.LState, val lua.LValue, ptr *[3]int32) { + if val.Type() != lua.LTTable { + panic("invalid value for [3]int32: must be a table") + } + readInt32(l, l.GetField(val, "x"), &(*ptr)[0]) + readInt32(l, l.GetField(val, "y"), &(*ptr)[1]) + readInt32(l, l.GetField(val, "z"), &(*ptr)[2]) +} + +func Cmd(l *lua.LState) mt.Cmd { + str := l.CheckString(2) + switch str { + case "chat_msg": + ptr := &mt.ToSrvChatMsg{} + val := l.CheckTable(3) + readString(l, l.GetField(val, "msg"), &ptr.Msg) + return ptr + case "clt_ready": + ptr := &mt.ToSrvCltReady{} + val := l.CheckTable(3) + readUint16(l, l.GetField(val, "formspec"), &ptr.Formspec) + readUint8(l, l.GetField(val, "major"), &ptr.Major) + readUint8(l, l.GetField(val, "minor"), &ptr.Minor) + readUint8(l, l.GetField(val, "patch"), &ptr.Patch) + readString(l, l.GetField(val, "version"), &ptr.Version) + return ptr + case "deleted_blks": + ptr := &mt.ToSrvDeletedBlks{} + val := l.CheckTable(3) + readSliceVec3Int16(l, l.GetField(val, "blks"), &ptr.Blks) + return ptr + case "fall_dmg": + ptr := &mt.ToSrvFallDmg{} + val := l.CheckTable(3) + readUint16(l, l.GetField(val, "amount"), &ptr.Amount) + return ptr + case "first_srp": + ptr := &mt.ToSrvFirstSRP{} + val := l.CheckTable(3) + readBool(l, l.GetField(val, "empty_passwd"), &ptr.EmptyPasswd) + readSliceByte(l, l.GetField(val, "salt"), &ptr.Salt) + readSliceByte(l, l.GetField(val, "verifier"), &ptr.Verifier) + return ptr + case "got_blks": + ptr := &mt.ToSrvGotBlks{} + val := l.CheckTable(3) + readSliceVec3Int16(l, l.GetField(val, "blks"), &ptr.Blks) + return ptr + case "init": + ptr := &mt.ToSrvInit{} + val := l.CheckTable(3) + readUint16(l, l.GetField(val, "max_proto_ver"), &ptr.MaxProtoVer) + readUint16(l, l.GetField(val, "min_proto_ver"), &ptr.MinProtoVer) + readString(l, l.GetField(val, "player_name"), &ptr.PlayerName) + readBool(l, l.GetField(val, "send_full_item_meta"), &ptr.SendFullItemMeta) + readUint8(l, l.GetField(val, "serialize_ver"), &ptr.SerializeVer) + readCompressionModes(l, l.GetField(val, "supported_compression"), &ptr.SupportedCompression) + return ptr + case "init2": + ptr := &mt.ToSrvInit2{} + val := l.CheckTable(3) + readString(l, l.GetField(val, "lang"), &ptr.Lang) + return ptr + case "interact": + ptr := &mt.ToSrvInteract{} + val := l.CheckTable(3) + readInteraction(l, l.GetField(val, "action"), &ptr.Action) + readUint16(l, l.GetField(val, "item_slot"), &ptr.ItemSlot) + readPointedThing(l, l.GetField(val, "pointed"), &ptr.Pointed) + readPlayerPos(l, l.GetField(val, "pos"), &ptr.Pos) + return ptr + case "inv_action": + ptr := &mt.ToSrvInvAction{} + val := l.CheckTable(3) + readString(l, l.GetField(val, "action"), &ptr.Action) + return ptr + case "inv_fields": + ptr := &mt.ToSrvInvFields{} + val := l.CheckTable(3) + readSliceField(l, l.GetField(val, "fields"), &ptr.Fields) + readString(l, l.GetField(val, "formname"), &ptr.Formname) + return ptr + case "join_mod_chan": + ptr := &mt.ToSrvJoinModChan{} + val := l.CheckTable(3) + readString(l, l.GetField(val, "channel"), &ptr.Channel) + return ptr + case "leave_mod_chan": + ptr := &mt.ToSrvLeaveModChan{} + val := l.CheckTable(3) + readString(l, l.GetField(val, "channel"), &ptr.Channel) + return ptr + case "msg_mod_chan": + ptr := &mt.ToSrvMsgModChan{} + val := l.CheckTable(3) + readString(l, l.GetField(val, "channel"), &ptr.Channel) + readString(l, l.GetField(val, "msg"), &ptr.Msg) + return ptr + case "nil": + ptr := &mt.ToSrvNil{} + return ptr + case "node_meta_fields": + ptr := &mt.ToSrvNodeMetaFields{} + val := l.CheckTable(3) + readSliceField(l, l.GetField(val, "fields"), &ptr.Fields) + readString(l, l.GetField(val, "formname"), &ptr.Formname) + readVec3Int16(l, l.GetField(val, "pos"), &ptr.Pos) + return ptr + case "player_pos": + ptr := &mt.ToSrvPlayerPos{} + val := l.CheckTable(3) + readPlayerPos(l, l.GetField(val, "pos"), &ptr.Pos) + return ptr + case "removed_sounds": + ptr := &mt.ToSrvRemovedSounds{} + val := l.CheckTable(3) + readSliceSoundID(l, l.GetField(val, "ids"), &ptr.IDs) + return ptr + case "req_media": + ptr := &mt.ToSrvReqMedia{} + val := l.CheckTable(3) + readSliceString(l, l.GetField(val, "filenames"), &ptr.Filenames) + return ptr + case "respawn": + ptr := &mt.ToSrvRespawn{} + return ptr + case "select_item": + ptr := &mt.ToSrvSelectItem{} + val := l.CheckTable(3) + readUint16(l, l.GetField(val, "slot"), &ptr.Slot) + return ptr + case "srp_bytes_a": + ptr := &mt.ToSrvSRPBytesA{} + val := l.CheckTable(3) + readSliceByte(l, l.GetField(val, "a"), &ptr.A) + readBool(l, l.GetField(val, "no_sha1"), &ptr.NoSHA1) + return ptr + case "srp_bytes_m": + ptr := &mt.ToSrvSRPBytesM{} + val := l.CheckTable(3) + readSliceByte(l, l.GetField(val, "m"), &ptr.M) + return ptr + } + + panic("invalid packet type: " + str) +} diff --git a/fromlua/static.go b/fromlua/static.go new file mode 100644 index 0000000..b989db8 --- /dev/null +++ b/fromlua/static.go @@ -0,0 +1,59 @@ +package fromlua + +import ( + "github.com/anon55555/mt" + "github.com/yuin/gopher-lua" +) + +//go:generate ./generate.lua + +func readBool(l *lua.LState, val lua.LValue, ptr *bool) { + if val.Type() != lua.LTBool { + panic("invalid value for bool: must be a boolean") + } + *ptr = bool(val.(lua.LBool)) +} + +func readString(l *lua.LState, val lua.LValue, ptr *string) { + if val.Type() != lua.LTString { + panic("invalid value for string: must be a string") + } + *ptr = string(val.(lua.LString)) +} + +func readSliceByte(l *lua.LState, val lua.LValue, ptr *[]byte) { + if val.Type() != lua.LTString { + panic("invalid value for []byte: must be a string") + } + *ptr = []byte(val.(lua.LString)) +} + +func readSliceField(l *lua.LState, val lua.LValue, ptr *[]mt.Field) { + if val.Type() != lua.LTTable { + panic("invalid value for []Field: must be a table") + } + val.(*lua.LTable).ForEach(func(k, v lua.LValue) { + if k.Type() != lua.LTString || v.Type() != lua.LTString { + panic("invalid value for Field: key and value must be strings") + } + *ptr = append(*ptr, mt.Field{Name: string(k.(lua.LString)), Value: string(v.(lua.LString))}) + }) +} + +func readPointedThing(l *lua.LState, val lua.LValue, ptr *mt.PointedThing) { + if val.Type() != lua.LTTable { + panic("invalid value for PointedThing: must be a table") + } + id := l.GetField(val, "id") + + if id == lua.LNil { + pt := &mt.PointedAO{} + readAOID(l, id, &(*pt).ID) + *ptr = pt + } else { + pt := &mt.PointedNode{} + readVec3Int16(l, l.GetField(val, "under"), &(*pt).Under) + readVec3Int16(l, l.GetField(val, "above"), &(*pt).Above) + *ptr = pt + } +} |