--- Bytecode parsing.
-- Please note that this module is experimental and subject to change.
-- @module advtrains_doc_integration.bc
-- Public module table; bc.parse and bc.tostring are attached to it below.
local bc = {}

-- Bit operations from the global `bit` library (as provided by LuaJIT).
local band = bit.band
local brshift = bit.rshift

-- Bytecode parsing is opt-in via the "atdoc_enable_bc" setting (default
-- false); when disabled, the public entry points only report an error.
local enable = core.settings:get_bool("atdoc_enable_bc", false)

--- Return the byte at `pos` in `str`, or no value when `pos` is out of range.
local function read_u8(str, pos)
	local byte = string.byte
	return byte(str, pos)
end

--- Read an unsigned little-endian 16-bit integer.
-- @tparam string str Source bytes.
-- @tparam number pos Position of the low byte.
-- @treturn number Value in the range 0..65535.
local function read_u16le(str, pos)
	local lo, hi = string.byte(str, pos, pos+1)
	-- The high byte is worth 2^8, not 2^5.
	return hi*256 + lo
end

--- Read an unsigned little-endian 32-bit integer starting at `pos`.
local function read_u32le(str, pos)
	local byte0, byte1, byte2, byte3 = string.byte(str, pos, pos + 3)
	return byte0 + byte1 * 2^8 + byte2 * 2^16 + byte3 * 2^24
end

--- Assemble an IEEE 754 double from its two 32-bit halves.
-- @tparam number lo Low word (lower 32 mantissa bits).
-- @tparam number hi High word (sign bit, 11 exponent bits, upper 20 mantissa bits).
-- @treturn number|nil The value, or nil when the bits encode a NaN.
local function construct_double(lo, hi)
	local sign = (-1)^brshift(hi, 31)
	local biased_exp = band(brshift(hi, 20), 0x7ff)
	local fraction = (band(hi, 0xfffff) + lo/2^32) / 0x100000
	if biased_exp == 0x7ff then
		-- All-ones exponent: infinity for a zero fraction, NaN otherwise.
		if fraction ~= 0 then
			return nil
		end
		return sign*math.huge
	end
	if biased_exp == 0 then
		-- Zero or subnormal: no implicit leading 1.
		return sign*math.ldexp(fraction, -1022)
	end
	return sign*math.ldexp(1 + fraction, biased_exp - 1023)
end

--- Decode a bit field into a table of named flags.
-- @tparam number val The raw flag value.
-- @tparam table spec Map of flag name to bit mask.
-- @treturn table Map of set flag names to their masked (non-zero) bits.
local function readflags(val, spec)
	local set = {}
	for name, mask in pairs(spec) do
		set[name] = band(val, mask)
		if set[name] == 0 then
			set[name] = nil
		end
	end
	return set
end

--- Read an unsigned LEB128 number.
-- Each byte contributes its low 7 bits; a set high bit means another byte follows.
-- @treturn number The decoded value.
-- @treturn number Position just after the encoded number.
local function read_lj_uleb128(str, pos)
	local value, scale, cursor = 0, 1, pos
	local byte = string.byte(str, cursor)
	while byte >= 128 do
		value = value + (byte - 128) * scale
		scale = scale * 128
		cursor = cursor + 1
		byte = string.byte(str, cursor)
	end
	return value + byte * scale, cursor + 1
end

--- Read LuaJIT's 33-bit ULEB128 variant.
-- Bit 0 of the first byte is a tag that is discarded here. Bits 1-6 hold the
-- lowest value bits, and bit 7 flags a plain ULEB128 continuation.
-- @treturn number The decoded value (without the tag bit).
-- @treturn number Position just after the encoded number.
local function read_lj_uleb128_33(str, pos)
	local head = brshift(read_u8(str, pos), 1)
	if head < 64 then
		return head, pos + 1
	end
	local rest, after = read_lj_uleb128(str, pos + 1)
	return head % 64 + rest * 64, after
end

--- Read a double stored as two ULEB128 words (low word first).
-- @treturn[1] number The value.
-- @treturn[1] number Position after both words.
-- @treturn[2] nil
-- @treturn[2] string Error message for NaN bit patterns.
local function read_lj_double(str, pos)
	local lo, mid = read_lj_uleb128(str, pos)
	local hi, after = read_lj_uleb128(str, mid)
	local value = construct_double(lo, hi)
	if value == nil then
		return nil, ("Bad double: 0x%08x%08x"):format(hi, lo)
	end
	return value, after
end

--- Read one numeric (KN) constant.
-- The value starts with a 33-bit ULEB128. Bit 0 of its first byte selects the
-- kind: clear means the remaining 32 bits are an integer, set means they are
-- the low word of a double whose high word follows as a plain ULEB128.
-- @treturn[1] number The constant.
-- @treturn[1] number Position after the constant.
-- @treturn[2] nil
-- @treturn[2] string Error message.
local function read_lj_int_or_double(str, pos)
	local is_double = read_u8(str, pos) % 2 == 1
	local lo, p2 = read_lj_uleb128_33(str, pos)
	-- Only 32 bits are meaningful (LuaJIT truncates to uint32 when reading).
	lo = lo % 2^32
	if not is_double then
		-- Integer constants are int32 values; reinterpret the unsigned
		-- bit pattern so negative constants come out negative.
		if lo >= 2^31 then
			lo = lo - 2^32
		end
		return lo, p2
	end
	local hi, p3 = read_lj_uleb128(str, p2)
	local d = construct_double(lo, hi)
	if not d then
		return nil, ("Bad double: 0x%08x%08x"):format(hi, lo)
	end
	return d, p3
end

--- LuaJIT opcode definitions, indexed by opcode number (starting at 0).
-- Mirrors the BCDEF list of LuaJIT 2.1's lj_bc.h. Each entry is
-- `{name, A kind, B kind, C/D kind}`; a nil kind means the operand is unused.
-- Entries without a B kind use a single 16-bit D operand instead of B and C.
local lj_bcdef = {
	[0] = {"ISLT", "var", nil, "var"},
	{"ISGE", "var", nil, "var"},
	{"ISLE", "var", nil, "var"},
	{"ISGT", "var", nil, "var"},
	--
	{"ISEQV", "var", nil, "var"},
	{"ISNEV", "var", nil, "var"},
	{"ISEQS", "var", nil, "str"},
	{"ISNES", "var", nil, "str"},
	{"ISEQN", "var", nil, "num"},
	{"ISNEN", "var", nil, "num"},
	{"ISEQP", "var", nil, "pri"},
	{"ISNEP", "var", nil, "pri"},
	--
	{"ISTC", "dst", nil, "var"},
	{"ISFC", "dst", nil, "var"},
	{"IST", nil, nil, "var"},
	{"ISF", nil, nil, "var"},
	{"ISTYPE", "var", nil, "lit"},
	{"ISNUM", "var", nil, "lit"},
	--
	{"MOV", "dst", nil, "var"},
	{"NOT", "dst", nil, "var"},
	{"UNM", "dst", nil, "var"},
	{"LEN", "dst", nil, "var"},
}
-- Arithmetic opcodes in LuaJIT order: ADDVN..MODVN, ADDNV..MODNV, ADDVV..MODVV.
-- In the VV variants both B and C are variable slots.
for _, var in ipairs {
	{"VN", "dst", "var", "num"},
	{"NV", "dst", "var", "num"},
	{"VV", "dst", "var", "var"},
} do
	for _, ins in ipairs {"ADD", "SUB", "MUL", "DIV", "MOD"} do
		table.insert(lj_bcdef, {ins..var[1], var[2], var[3], var[4]})
	end
end
for _, ent in ipairs {
	{"POW", "dst", "var", "var"},
	{"CAT", "dst", "rbase", "rbase"},
	--
	{"KSTR", "dst", nil, "str"},
	{"KCDATA", "dst", nil, "cdata"},
	{"KSHORT", "dst", nil, "lits"},
	{"KNUM", "dst", nil, "num"},
	{"KPRI", "dst", nil, "pri"},
	{"KNIL", "dst", nil, "base"},
	{"UGET", "dst", nil, "uv"},
	{"USETV", "uv", nil, "var"},
	{"USETS", "uv", nil, "str"},
	{"USETN", "uv", nil, "num"},
	{"USETP", "uv", nil, "pri"},
	{"UCLO", "rbase", nil, "jump"},
	{"FNEW", "dst", nil, "func"},
	--
	{"TNEW", "dst", nil, "lit"},
	{"TDUP", "dst", nil, "tab"},
	{"GGET", "dst", nil, "str"},
	{"GSET", "var", nil, "str"},
	{"TGETV", "dst", "var", "var"},
	{"TGETS", "dst", "var", "str"},
	{"TGETB", "dst", "var", "lit"},
	{"TGETR", "dst", "var", "var"},
	{"TSETV", "var", "var", "var"},
	{"TSETS", "var", "var", "str"},
	{"TSETB", "var", "var", "lit"},
	{"TSETM", "base", nil, "num"},
	{"TSETR", "var", "var", "var"},
	--
	{"CALLM", "base", "lit", "lit"},
	{"CALL", "base", "lit", "lit"},
	{"CALLMT", "base", nil, "lit"},
	{"CALLT", "base", nil, "lit"},
	{"ITERC", "base", "lit", "lit"},
	{"ITERN", "base", "lit", "lit"},
	{"VARG", "base", "lit", "lit"},
	{"ISNEXT", "base", nil, "jump"},
	--
	{"RETM", "base", nil, "lit"},
	{"RET", "rbase", nil, "lit"},
	{"RET0", "rbase", nil, "lit"},
	{"RET1", "rbase", nil, "lit"},
	--
	{"FORI", "base", nil, "jump"},
	{"JFORI", "base", nil, "jump"},
	--
	{"FORL", "base", nil, "jump"},
	{"IFORL", "base", nil, "jump"},
	{"JFORL", "base", nil, "lit"},
	--
	{"ITERL", "base", nil, "jump"},
	{"IITERL", "base", nil, "jump"},
	{"JITERL", "base", nil, "lit"},
	--
	{"LOOP", "rbase", nil, "jump"},
	{"ILOOP", "rbase", nil, "jump"},
	{"JLOOP", "rbase", nil, "lit"},
	--
	{"JMP", "rbase", nil, "jump"},
	--
	{"FUNCF", "rbase", nil, nil},
	{"IFUNCF", "rbase", nil, nil},
	{"JFUNCF", "rbase", nil, "lit"},
	{"FUNCV", "rbase", nil, nil},
	{"IFUNCV", "rbase", nil, nil},
	{"JFUNCV", "rbase", nil, "lit"},
	{"FUNCC", "rbase", nil, nil},
	{"FUNCCW", "rbase", nil, nil},
} do
	table.insert(lj_bcdef, ent)
end

--- Read a ULEB128 length followed by that many raw bytes.
-- @treturn string The bytes (empty when the length is zero).
-- @treturn number Position after the bytes.
local function lj_read_nbytes(dump, pos)
	local len, start = read_lj_uleb128(dump, pos)
	if len > 0 then
		return string.sub(dump, start, start + len - 1), start + len
	end
	return "", start
end

--- Decode the instruction words of a prototype.
-- Each instruction is a little-endian 32-bit word. Bits 0-7 hold the opcode
-- and bits 8-15 hold A. The rest is either B (bits 24-31) and C (bits 16-23),
-- or a single 16-bit D operand (bits 16-31).
-- @tparam table phead Prototype header (`numbc` instructions are read).
-- @tparam number pos Position of the first instruction in `pstr`.
-- @tparam string pstr Prototype body bytes.
-- @treturn[1] table List of `{name, A, B, C}` or `{name, A, D}` entries,
-- where each operand is `{type = kind, value = number}`.
-- @treturn[1] number Position after the bytecode.
-- @treturn[2] nil
-- @treturn[2] string Error message.
local function lj_parse_bytecode(phead, pos, pstr, _)
	local numbc = phead.numbc
	if #pstr - pos + 1 < 4*numbc then
		return nil, "Truncated bytecode"
	end
	local inslist = {}
	for k = 1, numbc do
		local w = read_u32le(pstr, pos+4*k-4)
		local opcode = w % 256
		local op = lj_bcdef[opcode]
		if not op then
			-- Report the numeric opcode; `op` itself is nil here.
			return nil, ("Invalid opcode: %02X"):format(opcode)
		end
		local ins = {op[1], {type = op[2], value = math.floor(w/256)%256}}
		if op[3] then
			ins[3] = {type = op[3], value = math.floor(w/256^3)%256}
			ins[4] = {type = op[4], value = math.floor(w/256^2)%256}
		else
			ins[3] = {type = op[4], value = math.floor(w/256^2)}
		end
		inslist[k] = ins
	end
	return inslist, pos+4*numbc
end

--- Read the prototype's upvalue descriptors (one 16-bit word each).
-- @treturn table 1-based list of raw descriptor words.
-- @treturn number Position after the descriptors.
local function lj_parse_uv(phead, pos, pstr, _)
	local count = phead.numuv
	local uvlist = {}
	local offset = pos
	for idx = 1, count do
		uvlist[idx] = read_u16le(pstr, offset)
		offset = offset + 2
	end
	return uvlist, pos + 2*count
end

--- Type tags of keys and values inside dumped table constants.
-- Tags at or above `str` encode a string of length `tag - str`.
local lj_ktab_type = {
	["nil"] = 0,
	["false"] = 1,
	["true"] = 2,
	int = 3,
	num = 4,
	str = 5,
}

--- Read one key or value of a table constant.
-- @return[1] The constant (nil for a nil entry).
-- @treturn[1] number Position after the constant.
-- @treturn[2] nil
-- @treturn[2] string Error message.
local function lj_parse_ktabk(_, pos, pstr, _)
	local tp, p2 = read_lj_uleb128(pstr, pos)
	if tp >= lj_ktab_type.str then
		local len = tp - lj_ktab_type.str
		return string.sub(pstr, p2, p2+len-1), p2+len
	elseif tp == lj_ktab_type.int then
		-- Integers are written as the unsigned bit pattern of an int32.
		local n, p3 = read_lj_uleb128(pstr, p2)
		n = n % 2^32
		if n >= 2^31 then
			n = n - 2^32
		end
		return n, p3
	elseif tp == lj_ktab_type.num then
		return read_lj_double(pstr, p2)
	elseif tp == lj_ktab_type["nil"] then
		return nil, p2
	elseif tp == lj_ktab_type["true"] then
		return true, p2
	elseif tp == lj_ktab_type["false"] then
		return false, p2
	end
	return nil, ("Bad KTABK constant type %d"):format(tp)
end

--- Read a table constant: an array part (0-based) followed by a hash part.
-- @treturn[1] table The reconstructed table.
-- @treturn[1] number Position after the table.
-- @treturn[2] nil
-- @treturn[2] string Error message.
local function lj_parse_ktab(phead, pos, pstr, bcflags)
	local narr, nhash
	narr, pos = read_lj_uleb128(pstr, pos)
	nhash, pos = read_lj_uleb128(pstr, pos)
	local tab = {}
	for idx = 0, narr-1 do
		local val
		val, pos = lj_parse_ktabk(phead, pos, pstr, bcflags)
		-- A nil value with a string in place of the position is an error.
		if val == nil and type(pos) ~= "number" then
			return nil, pos
		end
		tab[idx] = val
	end
	for _ = 1, nhash do
		local key, after_key = lj_parse_ktabk(phead, pos, pstr, bcflags)
		if key == nil then
			if type(after_key) == "number" then
				return nil, "Table index is nil"
			end
			return nil, after_key
		end
		local val
		val, pos = lj_parse_ktabk(phead, after_key, pstr, bcflags)
		if val == nil and type(pos) ~= "number" then
			return nil, pos
		end
		tab[key] = val
	end
	return tab, pos
end

--- Type tags of garbage-collected constants.
-- Tags at or above `str` encode a string of length `tag - str`.
local lj_kgc_type = {
	child = 0,
	tab = 1,
	str = 5,
}

--- Read the GC constants of a prototype.
-- Constants are stored from the highest index down to 0. A child entry takes
-- the topmost entry of the prototype stack tracked by `bcflags.top`. It
-- records that stack slot (`bcflags.top - 1`) and pops it.
-- @treturn[1] table 0-based list of strings, tables and child slot numbers.
-- @treturn[1] number Position after the constants.
-- @treturn[2] nil
-- @treturn[2] string Error message.
local function lj_parse_kgc(phead, pos, pstr, bcflags)
	local gclist = {}
	local slot = phead.numkgc - 1
	while slot >= 0 do
		local tp, after = read_lj_uleb128(pstr, pos)
		if tp >= lj_kgc_type.str then
			local len = tp - lj_kgc_type.str
			gclist[slot] = string.sub(pstr, after, after+len-1)
			pos = after + len
		elseif tp == lj_kgc_type.tab then
			local tbl, after_tab = lj_parse_ktab(phead, after, pstr, bcflags)
			if not tbl then
				return nil, after_tab
			end
			gclist[slot] = tbl
			pos = after_tab
		elseif tp == lj_kgc_type.child then
			local newtop = bcflags.top - 1
			if newtop < 0 then
				return nil, "Child stack underflow"
			end
			gclist[slot] = newtop
			bcflags.top = newtop
			pos = after
		else
			return nil, ("Bad constant type %d"):format(tp)
		end
		slot = slot - 1
	end
	return gclist, pos
end

--- Read the numeric constants of a prototype into a 0-based list.
-- @treturn[1] table The constants.
-- @treturn[1] number Position after the constants.
-- @treturn[2] nil
-- @treturn[2] string Error message.
local function lj_parse_kn(phead, pos, pstr, _)
	local numlist = {}
	local cursor = pos
	for idx = 0, phead.numkn-1 do
		local num, after = read_lj_int_or_double(pstr, cursor)
		if num == nil then
			return nil, after
		end
		numlist[idx], cursor = num, after
	end
	return numlist, cursor
end

--- Parse the sections following a prototype header, in dump order.
-- @tparam table phead Parsed prototype header.
-- @tparam string pstr Bytes following the header.
-- @tparam table bcflags Dump-level flags, shared with the section parsers.
-- @treturn[1] table `{header, bytecode, uv, kgc, kn}`.
-- @treturn[2] nil
-- @treturn[2] string Error message.
local function lj_parse_pdata_body(phead, pstr, bcflags)
	local pdata = {header = phead}
	local sections = {
		{"bytecode", lj_parse_bytecode},
		{"uv", lj_parse_uv},
		{"kgc", lj_parse_kgc},
		{"kn", lj_parse_kn},
	}
	local pos = 1
	for _, section in ipairs(sections) do
		local value, next_pos = section[2](phead, pos, pstr, bcflags)
		if not value then
			return nil, next_pos
		end
		pdata[section[1]] = value
		pos = next_pos
	end
	return pdata
end

--- Flag bits of the first prototype header byte, named after LuaJIT's
-- PROTO_CHILD, PROTO_VARARG and PROTO_FFI (lj_obj.h).
local lj_proto_flags = {
	child = 0x01,
	vararg = 0x02,
	ffi = 0x04,
}

--- Parse one prototype.
-- @tparam string pstr Prototype bytes (without the length prefix).
-- @tparam table bcflags Dump-level flags. When `strip` is unset, the header
-- carries debug-info sizes.
-- @treturn[1] table Prototype data (see lj_parse_pdata_body).
-- @treturn[2] nil
-- @treturn[2] string Error message.
local function lj_parse_proto(pstr, bcflags)
	local phead, pos = {}
	phead.flags = readflags(read_u8(pstr, 1), lj_proto_flags)
	phead.numparams = read_u8(pstr, 2)
	phead.framesize = read_u8(pstr, 3)
	phead.numuv = read_u8(pstr, 4)
	phead.numkgc, pos = read_lj_uleb128(pstr, 5)
	phead.numkn, pos = read_lj_uleb128(pstr, pos)
	phead.numbc, pos = read_lj_uleb128(pstr, pos)
	if not bcflags.strip then
		phead.debuglen, pos = read_lj_uleb128(pstr, pos)
		if phead.debuglen > 0 then
			phead.firstline, pos = read_lj_uleb128(pstr, pos)
			phead.numline, pos = read_lj_uleb128(pstr, pos)
		end
	end
	return lj_parse_pdata_body(phead, string.sub(pstr, pos), bcflags)
end

--- Dump-level flag bits of a LuaJIT bytecode dump.
local lj_bcdump_flags = {
	be = 1,
	strip = 2,
	ffi = 4,
	fr2 = 8,
}

--- Parse a version-2 LuaJIT dump (signature and version byte removed).
-- @tparam string dump The remaining dump bytes.
-- @treturn[1] table `{chunkname = string|nil, prototypes = list}`; child
-- references in each `kgc` are 0-based indices into `prototypes`.
-- @treturn[2] nil
-- @treturn[2] string Error message.
local function parse_lj2(dump)
	local flags, pos = read_lj_uleb128(dump, 1)
	flags = readflags(flags, lj_bcdump_flags)
	if flags.ffi then
		return nil, "LuaJIT bytecode dump with FFI is not supported"
	elseif flags.be then
		return nil, "Big-endian LuaJIT bytecode is not supported"
	end
	local chunkname
	if not flags.strip then
		local cname, p2 = lj_read_nbytes(dump, pos)
		pos = p2
		if cname ~= "" then
			chunkname = cname
		end
	end
	local prototypes = {}
	-- `flags.top` is the depth of the prototype stack. It is raised here after
	-- each prototype and lowered by the KGC parser, which records the popped
	-- stack slot for each child. With nested closures, slots and positions in
	-- `prototypes` differ, so `stack[slot+1]` remembers which prototype
	-- (0-based index) occupies each slot.
	local stack = {}
	flags.top = 0
	while true do
		local proto, p2 = lj_read_nbytes(dump, pos)
		pos = p2
		if proto == "" then
			break
		end
		local pdata, err = lj_parse_proto(proto, flags)
		if pdata == nil then
			return nil, err
		end
		-- Numbers in kgc are child stack slots; translate them to indices.
		for id, ref in pairs(pdata.kgc) do
			if type(ref) == "number" then
				pdata.kgc[id] = stack[ref+1]
			end
		end
		-- Drop the slots popped while parsing this prototype's children.
		for slot = #stack, flags.top+1, -1 do
			stack[slot] = nil
		end
		table.insert(prototypes, pdata)
		flags.top = flags.top + 1
		stack[flags.top] = #prototypes - 1
	end
	flags.top = nil
	return {
		chunkname = chunkname,
		prototypes = prototypes,
	}
end

--- Dispatch on the LuaJIT dump version byte (only version 2 is supported).
-- @tparam string dump Dump bytes following the "\27LJ" signature.
local function parse_lj(dump)
	if string.byte(dump, 1) ~= 2 then
		return nil, "Unsupported LuaJIT bytecode version"
	end
	return parse_lj2(string.sub(dump, 2))
end

--- Prefix a successful result with a status tag.
-- When the first value in `...` is non-nil, returns `st` followed by all of
-- `...`. Otherwise `...` (typically nil plus an error) is returned unchanged.
local function ensure_result(st, ...)
	local first = ...
	if select("#", ...) == 0 or first == nil then
		return ...
	end
	return st, ...
end

--- Try to parse a bytecode dump.
-- Functions are serialized with `string.dump` before parsing.
-- @tparam string|function dump The bytecode input or the function to read.
-- @return[1] "luajit" If `dump` is valid LuaJIT bytecode.
-- @treturn[1] ... Data parsed from the bytecode.
-- @treturn[2] nil If the dump cannot be parsed.
-- @treturn[2] string A message indicating the error.
function bc.parse(dump)
	local kind = type(dump)
	if kind == "function" then
		return bc.parse(string.dump(dump))
	end
	if kind ~= "string" then
		return nil, "Invalid bytecode dump type"
	end
	if string.sub(dump, 1, 3) ~= "\27LJ" then
		return nil, "Unsupported bytecode dump format"
	end
	return ensure_result("luajit", parse_lj(string.sub(dump, 4)))
end

-- Named escapes for characters that have a conventional backslash form.
local escape_string_table = {
	["\n"] = [[\n]],
	["\r"] = [[\r]],
	["\t"] = [[\t]],
	["\""] = [[\"]],
	["\\"] = [[\\]],
}

--- Escape a string for display between double quotes.
-- Quotes, backslashes and the characters in `escape_string_table` get named
-- escapes. Any other control or non-ASCII byte, including NUL, becomes a
-- three-digit decimal escape, so a following digit cannot extend it.
-- @tparam string str
-- @treturn string
local function escape_string(str)
	return (string.gsub(str, '[%z\1-\31"\\\127-\255]', function(c)
		return escape_string_table[c] or ("\\%03d"):format(string.byte(c))
	end))
end

--- Format one instruction operand.
-- @tparam table proto Parsed prototype (for constant lookups).
-- @tparam number line 1-based instruction number (for jump targets).
-- @tparam table value Operand `{type = kind, value = number}`.
-- @treturn string The operand text.
-- @treturn[opt] string A note about the referenced constant, when one applies.
local function lj_value_tostring(proto, line, value)
	local vt, vv = value.type, value.value
	local vs = ("%3d"):format(vv)
	if vt == nil then
		return "   "
	elseif vt == "jump" then
		return ("=> %04d"):format(line+vv-32767)
	elseif vt == "lits" then
		-- Signed 16-bit literal.
		return vv >= 32768 and ("%3d"):format(vv-65536) or vs
	end
	local note
	if vt == "str" then
		local ref = proto.kgc[vv]
		if type(ref) == "string" then
			note = ([["%s"]]):format(escape_string(ref))
		end
	elseif vt == "func" then
		local ref = proto.kgc[vv]
		if type(ref) == "number" then
			note = ("BYTECODE %d"):format(ref)
		end
	elseif vt == "num" then
		local num = proto.kn[vv]
		if type(num) == "number" then
			note = tostring(num)
		end
	end
	if note then
		return vs, note
	end
	return vs
end

--- Format a single GC constant or table-constant key/value.
-- @param val The value to format.
-- @tparam[opt] boolean top True for a top-level KGC entry, where numbers are
-- child prototype references rather than plain numbers.
-- @treturn string
local function lj_kgcatom_tostring(val, top)
	local tp = type(val)
	if tp == "string" then
		return ([["%s"]]):format(escape_string(val))
	elseif tp == "number" then
		if top then
			return ("[BYTECODE %d]"):format(val)
		end
		return tostring(val)
	elseif tp == "boolean" then
		-- Table constants may contain true/false keys and values.
		return tostring(val)
	elseif tp == "table" then
		if next(val) == nil then
			return "{}"
		end
		return "[TABLE]"
	end
	return "[???]"
end

--- Render one parsed prototype as a listing similar to `luajit -bl`.
-- @tparam number index 1-based position of the prototype; printed 0-based.
-- @tparam table proto Parsed prototype (`header`, `bytecode`, `kgc`, `kn`).
-- @treturn string KGC constants, KN constants, then one line per instruction,
-- with jump targets marked by "=>".
local function lj_proto_tostring(index, proto)
	local st = {("-- BYTECODE -- %d"):format(index-1)}
	-- First pass: collect jump targets so they can be marked.
	local jmp_target = {}
	for ln, line in ipairs(proto.bytecode) do
		local d = line[3]
		if d and d.type == "jump" then
			jmp_target[ln+d.value-32767] = true
		end
	end
	-- Constant lists are 0-based; take their sizes from the header instead
	-- of relying on `#`.
	for id = 0, proto.header.numkgc-1 do
		local val = proto.kgc[id]
		if type(val) == "table" and next(val) ~= nil then
			table.insert(st, ("%-7s %6d {"):format("KGC", id))
			local prepend = ("%-7s %6s"):format("KGC", ".")
			-- LJ bytecode does not have nested tables
			for k, v in pairs(val) do
				table.insert(st, ("%s   [%s] = %s,"):format(prepend,
					lj_kgcatom_tostring(k), lj_kgcatom_tostring(v)))
			end
			table.insert(st, ("%s }"):format(prepend))
		else
			table.insert(st, ("%-7s %6d %s"):format("KGC", id, lj_kgcatom_tostring(val, true)))
		end
	end
	for id = 0, proto.header.numkn-1 do
		table.insert(st, ("%-7s %6d %s"):format("KN", id, proto.kn[id]))
	end
	for ln, line in ipairs(proto.bytecode) do
		local lt = {("%04d"):format(ln)}
		table.insert(lt, jmp_target[ln] and "=>" or "  ")
		table.insert(lt, ("%-6s"):format(line[1]))
		-- Parenthesized so that only the operand text is inserted; a second
		-- return value would make table.insert treat the text as a position.
		table.insert(lt, (lj_value_tostring(proto, ln, line[2])))
		if line[4] then
			-- Separate B and C operands; both texts are already padded.
			local bs, bn = lj_value_tostring(proto, ln, line[3])
			local cs, cn = lj_value_tostring(proto, ln, line[4])
			table.insert(lt, ("%s %s"):format(bs, cs))
			if bn or cn then
				table.insert(lt, ";")
				if bn then
					table.insert(lt, bn)
				end
				if cn then
					table.insert(lt, cn)
				end
			end
		else
			local ds, dn = lj_value_tostring(proto, ln, line[3])
			table.insert(lt, ("%-7s"):format(ds))
			if dn then
				table.insert(lt, ("; %s"):format(dn))
			end
		end
		table.insert(st, table.concat(lt, " "))
	end
	return table.concat(st, "\n")
end

--- Try to format a bytecode dump.
-- @tparam string|function dump The bytecode input or the function to read.
-- @return[1] "luajit" If `dump` is valid LuaJIT bytecode.
-- @treturn[1] string A string describing the bytecode dump. The format is similar to that of `luajit -bl`
-- @treturn[2] nil If the dump cannot be parsed.
-- @treturn[2] string A message indicating the error.
function bc.tostring(dump)
	local tp, data = bc.parse(dump)
	if tp ~= "luajit" then
		return nil, data
	end
	local listings = {}
	local prototypes = data.prototypes
	for idx = 1, #prototypes do
		listings[idx] = lj_proto_tostring(idx, prototypes[idx])
	end
	return tp, table.concat(listings, "\n\n")
end

if not enable then
	-- Parsing is opt-in: replace both entry points with a stub that only
	-- reports why nothing was parsed.
	local function disabled()
		return nil, "Bytecode parsing is disabled"
	end
	bc.parse, bc.tostring = disabled, disabled
end

-- luacheck: ignore 511
if false then -- debugging only
	-- Chat command that runs a Lua snippet and prints the bytecode listing of
	-- the function it returns.
	core.register_chatcommand("atdoc_format_function", {
		params = "<lua code>",
		description = "Execute the given lua code and dump the resulting function",
		privs = {server = true},
		func = function(_, param)
			local f, err = loadstring(param)
			if not f then
				return false, err
			end
			local st, val = pcall(f)
			if not st then
				return false, val
			end
			local tp, desc = bc.tostring(val)
			if not tp then
				return false, desc
			end
			return true, desc
		end,
	})
end

return bc
