local sq = squill._internal
local Parser = sq.Parser
local byte, format = string.byte, string.format
local unpack = table.unpack or unpack

-- Byte values of the first character of string literals ('...') and of
-- parameter placeholders ($1, ?, :name)
local APOSTROPHE = byte("'")
local DOLLARS = byte("$")
local QUESTION_MARK = byte("?")
local COLON = byte(":")

-- Operator precedence tables, built from a list of levels ordered from the
-- loosest-binding (1) to the tightest-binding. Entries prefixed with "unary "
-- go into unary_precedence (keyed without the prefix); everything else is a
-- binary operator in precedence.
local precedence, unary_precedence = {}, {}
do
    local levels = {
        {"or"},
        {"and"},
        {"unary not"},
        {"<", ">", ">=", "<="},
        {"=", "<>", "is", "is not", "like", "not like", "between",
         "not between", "in", "not in"},
        {"+", "-"},
        {"*", "/", "%"},
        {"^"},
        {"||", "->", "->>"},
        {"unary +", "unary -"},
        {"."},
    }
    for level = 1, #levels do
        for _, name in ipairs(levels[level]) do
            local unary_name = name:match("^unary (.*)$")
            if unary_name then
                unary_precedence[unary_name] = level
            else
                precedence[name] = level
            end
        end
    end
end

-- Alternative spellings accepted for some operators
local operator_aliases = {
    ["=="] = "=",
    ["!="] = "<>",
}

-- Enum of AST node types; a subset is exported on `sq` for other modules
local AST_OP = 1
local AST_NUM = 2
local AST_STR = 3
local AST_VAR = 4
local AST_PAR = 5
local AST_CASE = 6
local AST_CALL = 7
local AST_BR = 8
local AST_CUSTOM_LUA = 9
local AST_STAR = 10

sq.AST_OP, sq.AST_VAR, sq.AST_CALL = AST_OP, AST_VAR, AST_CALL
sq.AST_CUSTOM_LUA, sq.AST_STAR = AST_CUSTOM_LUA, AST_STAR

-- Pops the top entry off the operator stack `ops` and pushes the resulting
-- node onto the value stack `values`. Operator nodes (AST_OP) first take their
-- `argc` operands off `values`, in source order. Bracket markers must never
-- reach this function.
local function push_operation(values, ops)
    local top = table.remove(ops)
    assert(top.type ~= AST_BR)

    if top.type ~= AST_OP then
        values[#values + 1] = top
        return
    end

    -- Operands were pushed left to right, so fill them in from the end
    for idx = top.argc, 1, -1 do
        top[idx] = assert(table.remove(values))
    end

    -- HACK: Convert "(1 BETWEEN 2) AND 3" to a 3-argument between. Only done
    -- when both were parsed inside the same brackets, so an explicitly
    -- bracketed "(x BETWEEN y) AND z" is left alone.
    local lhs = top[1]
    local completes_between = top.op == "and" and lhs.type == AST_OP
        and (lhs.op == "between" or lhs.op == "not between")
        and lhs.argc == 2
        and top.bracket_level == lhs.bracket_level
    if completes_between then
        lhs.argc = 3
        lhs[3] = top[2]
        top = lhs
    end

    values[#values + 1] = top
end

-- Parses one expression from the token stream and returns its AST, e.g. for
-- "1 + -2":
--[[
{
    type = AST_OP,
    op = "+",
    argc = 2,
    {type = AST_NUM, n = 1},
    {type = AST_OP, op = "unary -", argc = 1, {type = AST_NUM, n = 2}}
}
]]
-- Other node shapes built here: AST_STR {s}, AST_VAR {name}, AST_PAR {n} (or
-- {name} for ":name" parameters), AST_CALL {func, argc, distinct?, args...},
-- AST_STAR (only inside COUNT(*)) and AST_CASE {compare_to?, {when?, value}...}.
-- The "AST_BR" type is used internally and never returned.
-- "AST_CUSTOM_LUA" is not returned either; it's intended for other functions
-- to embed references to Lua variables or similar.
-- Parsing stops at the first token after a value that isn't an operator; a
-- ")" that doesn't close a bracket opened here is left for the caller.
function Parser:parse_expr()
    local values = {}
    local ops = {}
    local bracket_level = 0

    -- Consumes the binary operator following a value (after applying any
    -- closing brackets) and reduces higher-or-equal precedence operators.
    -- Returns the operator name and a new, not yet pushed, operator node, or
    -- nothing if the next token doesn't continue the expression.
    local function read_operator()
        while bracket_level > 0 and self:pop_if_equals(")") do
            while ops[#ops].type ~= AST_BR do
                push_operation(values, ops)
            end
            table.remove(ops)
            bracket_level = bracket_level - 1
        end

        local op = self:peek()
        op = operator_aliases[op] or op

        -- Special case for "is not" and "not like"
        if op == "is" or op == "not" then
            self:next()
            local new_op = op .. " " .. self:peek()
            if precedence[new_op] then
                op = operator_aliases[new_op] or new_op
                self:next()
            elseif op == "not" then
                -- A binary NOT only exists as NOT LIKE / NOT BETWEEN / NOT IN.
                -- Continuing would create an operator without a precedence.
                self:error("Expected LIKE, BETWEEN or IN after NOT")
            end
        elseif not precedence[op] then
            return
        else
            self:next() -- Consume the operator
        end

        local op_precedence = precedence[op]
        while #ops > 0 and ops[#ops].precedence >= op_precedence do
            local popped = ops[#ops]
            push_operation(values, ops)

            -- The AND right after "x BETWEEN y" belongs to the BETWEEN. Stop
            -- reducing here; otherwise an earlier AND or a unary NOT would
            -- take "x BETWEEN y" as its operand before the BETWEEN gets its
            -- upper bound (push_operation merges it once this AND is reduced).
            if op == "and" and popped.type == AST_OP and
                    (popped.op == "between" or popped.op == "not between") then
                break
            end
        end

        return op, {
            type = AST_OP,
            argc = 2,
            op = op,
            precedence = op_precedence,

            -- For BETWEEN syntax hack
            bracket_level = bracket_level
        }
    end

    -- https://en.wikipedia.org/wiki/Shunting_yard_algorithm
    while true do
        local value = self:next()
        if value == "(" then
            bracket_level = bracket_level + 1

            -- Push a special value
            ops[#ops + 1] = {
                type = AST_BR,

                -- Prevent the precedence checking code from trying to remove
                -- this
                precedence = -1
            }
        elseif unary_precedence[value] then
            -- Unary operators
            ops[#ops + 1] = {
                type = AST_OP,
                argc = 1,
                op = "unary " .. value,
                precedence = unary_precedence[value],
            }
        else
            -- Resolve these now so that error messages can be done better
            local chr = byte(value, 1)
            if sq.is_digit(chr) then
                values[#values + 1] = {
                    type = AST_NUM,
                    n = self:assert(tonumber(value))
                }
            elseif chr == APOSTROPHE then
                -- Strip the quotes and unescape doubled apostrophes
                values[#values + 1] = {
                    type = AST_STR,
                    s = value:sub(2, -2):gsub("''", "'")
                }
            elseif chr == DOLLARS or chr == QUESTION_MARK or chr == COLON then
                if self.param_chr ~= nil and self.param_chr ~= chr then
                    self:error("Cannot mix ? and $ argument substitutions")
                end
                self.param_chr = chr

                if chr == COLON then
                    values[#values + 1] = {type = AST_PAR, name = value:sub(2)}
                    self.param_count = 1
                else
                    local n
                    if chr == DOLLARS then
                        n = self:assert(tonumber(value:sub(2)))
                        self:assert(n % 1 == 0)
                    else
                        n = self.param_count + 1
                    end
                    values[#values + 1] = {type = AST_PAR, n = n}
                    self.param_count = math.max(self.param_count, n)
                end
            elseif value == "case" then
                local node = {type = AST_CASE}
                if self:peek() ~= "when" then
                    node.compare_to = self:parse_expr()
                end

                self:expect("when")
                repeat
                    local when = self:parse_expr()
                    self:expect("then")
                    node[#node + 1] = {when = when, value = self:parse_expr()}
                until not self:pop_if_equals("when")
                if self:pop_if_equals("else") then
                    node[#node + 1] = {value = self:parse_expr()}
                end
                self:expect("end")
                -- A CASE expression is an operand, so it goes on the value
                -- stack (it has no precedence to compare on the op stack)
                values[#values + 1] = node
            elseif self:pop_if_equals("(") then
                self:assert(sq.valid_identifier(value))
                self:assert(not ops[#ops] or ops[#ops].op ~= ".")
                local node = {type = AST_CALL, func = value}
                if value == "count" and self:pop_if_equals("*") then
                    -- Special case for COUNT(*)
                    node[#node + 1] = {type = AST_STAR}
                elseif self:peek() ~= ")" then
                    -- Special case for COUNT DISTINCT
                    if self:pop_if_equals("distinct") then
                        node.distinct = true
                    end

                    repeat
                        node[#node + 1] = self:parse_expr()
                    until not self:pop_if_equals(",")
                end
                self:expect(")")
                node.argc = #node
                values[#values + 1] = node
            else
                self:assert(sq.valid_identifier(value) or sq.special_vars[value])
                values[#values + 1] = {
                    type = AST_VAR,
                    name = value
                }
            end

            local op, node = read_operator()
            -- Special case to parse IN and NOT IN with an argument list
            while op == "in" or op == "not in" do
                node[1] = assert(table.remove(values))
                self:expect("(")
                if self:peek() ~= ")" then
                    repeat
                        node[#node + 1] = self:parse_expr()
                    until not self:pop_if_equals(",")
                end
                self:expect(")")
                node.argc = #node
                values[#values + 1] = node

                op, node = read_operator()
            end

            if not op then break end

            ops[#ops + 1] = node
        end
    end

    -- Reduce whatever operators are left
    while #ops > 0 do
        push_operation(values, ops)
    end

    -- We should now have only one value on the stack that's an operation
    assert(#values == 1)
    return values[1]
end

-- Returns an iterator that yields every node in the AST, including the
-- comparison expression of "CASE <expr> WHEN ...". Nodes are visited
-- depth-first using an explicit stack, so later operands come out first.
-- skip_table_names: don't yield the table name on the left side of ".".
-- skip_attributes: don't descend into "." nodes at all (the "." node itself
-- is still yielded).
function sq.walk_ast(tree, skip_table_names, skip_attributes)
    local stack = {tree}
    return function()
        local node = table.remove(stack)
        if not node then return end

        if node.type == AST_CALL or (node.type == AST_OP and
                (not skip_attributes or node.op ~= ".")) then
            local start = (skip_table_names and node.op == ".") and 2 or 1
            for i = start, node.argc do
                stack[#stack + 1] = node[i]
            end
        elseif node.type == AST_CASE then
            -- compare_to can reference columns too, so it must be walked
            if node.compare_to then
                stack[#stack + 1] = node.compare_to
            end
            for _, c in ipairs(node) do
                if c.when then
                    stack[#stack + 1] = c.when
                end
                stack[#stack + 1] = c.value
            end
        end
        return node
    end
end

-- Reports whether the expression `tree` references the bare variable `name`.
-- Table names on the left of "." are not considered, but attribute names on
-- the right are. Returns true or false.
function sq.expr_mentions_variable(tree, name)
    local found = false
    for node in sq.walk_ast(tree, true) do
        found = node.type == AST_VAR and node.name == name
        if found then break end
    end
    return found
end

-- Finds the table that an unqualified column name belongs to and returns that
-- table's alias. self.table_lookup maps alias -> table name. Raises an error
-- (without position info) if no table, or more than one table, has the column.
function Parser:resolve_bare_identifier(column_name)
    local matches = {}
    for alias, table_name in pairs(self.table_lookup) do
        -- A positive index means the column exists on this table
        if table.indexof(self:list_columns(table_name), column_name) > 0 then
            matches[#matches + 1] = alias
        end
    end

    if #matches == 1 then
        return matches[1]
    elseif #matches == 0 then
        local msg = format("Unknown column name: %q", column_name)
        if column_name == "nil" then
            msg = msg .. ", did you mean NULL?"
        end
        error(msg, 0)
    else
        -- pairs() order is unspecified; sort so the message is deterministic
        table.sort(matches)
        error(format(
            "Ambiguous column name: %q, exists on the following tables: %s",
            column_name, table.concat(matches, ", ")
        ), 0)
    end
end

-- Error raised when "." is applied to something other than a bare name
local bad_attr_msg = "The '.' operator can only be used with bare names " ..
    "(i.e. table.attribute)"

-- Collects the aliases of all tables the expression reads from, returned as
-- a set: {table_alias = true, some_other_alias = true}.
-- "table.attribute" contributes its table name directly; bare column names are
-- resolved via resolve_bare_identifier (which raises on unknown or ambiguous
-- names); special variables are skipped.
function Parser:list_accessed_tables(tree)
    local accessed = {}
    for node in sq.walk_ast(tree, false, true) do
        local kind = node.type
        if kind == AST_OP and node.op == "." then
            local table_node = node[1]
            assert(table_node.type == AST_VAR, bad_attr_msg)
            accessed[table_node.name] = true
        elseif kind == AST_VAR and not sq.special_vars[node.name] then
            accessed[self:resolve_bare_identifier(node.name)] = true
        end
    end
    return accessed
end

-- string.format templates for operators that compile to plain Lua when every
-- operand is pure arithmetic (see Parser:expr_to_lua)
local lua_ops = {
    -- Binary arithmetic
    ["+"] = "%s + %s",
    ["-"] = "%s - %s",
    ["*"] = "%s * %s",
    ["/"] = "%s / %s",
    ["^"] = "%s ^ %s",
    ["%"] = "%s %% %s", -- "%%" is a literal % in format strings

    -- Unary arithmetic
    ["unary +"] = "+%s",
    ["unary -"] = "-%s",
}

-- Appends the Lua statements for a CASE node to the `code` list: an optional
-- "local case = <expr>" followed by an if/elseif/else chain in which every
-- branch returns its value. Nested CASE values are inlined as nested chains.
-- If no branch matches and there is no ELSE, nothing is returned from the
-- enclosing function.
local function case_to_function_body(self, code, node)
    if node.compare_to then
        code[#code + 1] = format("local case = %s",
            self:expr_to_lua(node.compare_to))
    end

    for i, c in ipairs(node) do
        if c.when then
            local when, is_numeric = self:expr_to_lua(c.when)
            if node.compare_to then
                -- Inspect the AST node here: `when` is the generated code
                -- string and has no node type
                if is_numeric or c.when.type == AST_STR then
                    -- Guaranteed to not be null == null, use == for performance
                    when = format("case == %s", when)
                else
                    -- Use custom = operator to handle nulls correctly
                    when = format("OPS['='](case, %s)", when)
                end
            else
                when = format("OPS.to_boolean(%s)", when)
            end
            code[#code + 1] = format("%s %s then", i > 1 and "elseif" or "if", when)
        else
            code[#code + 1] = "else"
        end

        if c.value.type == AST_CASE then
            -- Embed CASE statements directly if possible
            case_to_function_body(self, code, c.value)
        else
            code[#code + 1] = format(
                "return %s",
                self:expr_to_lua(c.value)
            )
        end
    end
    code[#code + 1] = "end"
end

-- Code generators for SQL functions, keyed by lowercase function name. Each
-- takes (parser, AST_CALL node) and returns a Lua expression string.
local sql_funcs = {}

-- CONCAT(a, b, ...) is compiled into a native Lua concatenation.
function sql_funcs.concat(self, node)
    -- CONCAT() would otherwise compile to "()", which is not valid Lua
    self:assert(node.argc >= 1,
        "%s() needs at least %d argument(s)", node.func, 1)

    local code = {}
    for i = 1, node.argc do
        local arg, is_equation = self:expr_to_lua(node[i])
        -- String literals are used as-is. Pure arithmetic can rely on Lua's
        -- own number-to-string conversion in "..", except when it is the only
        -- argument (no ".." is emitted, so it would stay a number). Anything
        -- else goes through coerce_to_string, falling back to '' when that
        -- returns nil/false (presumably for NULL).
        if node[i].type ~= AST_STR and (node.argc == 1 or not is_equation) then
            arg = format("(coerce_to_string(%s) or '')", arg)
        end
        code[i] = arg
    end
    return "(" .. table.concat(code, " .. ") .. ")"
end

-- Builds a code generator for an SQL function implemented by the Lua
-- expression `name`. The generator checks the argument count (arg_min
-- defaults to 1, arg_max defaults to arg_min) and emits "name(arg1, arg2, ...)".
-- With not_a_function set, `name` is emitted as-is without a call (used for
-- constants such as math.pi).
local function runtime_evaluated(name, arg_min, arg_max, not_a_function)
    local min_args = arg_min or 1
    local max_args = arg_max or min_args

    return function(self, node)
        local argc = node.argc
        self:assert(argc >= min_args,
            "%s() needs at least %d argument(s)", node.func, min_args)
        self:assert(argc <= max_args,
            "%s() takes at most %d argument(s)", node.func, max_args)

        local lua_args = {}
        for idx = 1, argc do
            lua_args[idx] = self:expr_to_lua(node[idx])
        end

        if not_a_function then
            return name
        end
        local arg_list = table.concat(lua_args, ", ")
        return name .. "(" .. arg_list .. ")"
    end
end

-- String and NULL-handling functions implemented by runtime helpers
sql_funcs.upper = runtime_evaluated("str_upper")
sql_funcs.lower = runtime_evaluated("str_lower")
sql_funcs.coalesce = runtime_evaluated("coalesce", 2, math.huge)

-- Math functions that have a 1:1 equivalent in Lua. Missing ones are logged
-- and left unregistered, so using them reports "not implemented".
for _, name in ipairs({"abs", "acos", "asin", "atan", "ceil", "cos", "cosh", "exp",
        "floor", "log10", "round", "sign", "sin", "sinh", "sqrt", "tan", "tanh"}) do
    if math[name] then
        sql_funcs[name] = runtime_evaluated("math." .. name)
    else
        core.log("warning", "[squill] Could not find math." .. name ..
            ", is Luanti up to date?")
    end
end

-- Lua 5.3+ dropped math.atan2 (unless built with compatibility options);
-- there math.atan accepts (y, x) instead.
sql_funcs.atan2 = runtime_evaluated(
    math.atan2 and "math.atan2" or "math.atan", 2)
sql_funcs.degrees = runtime_evaluated("math.deg")
sql_funcs.ln = runtime_evaluated("math.log")
sql_funcs.ceiling = sql_funcs.ceil
sql_funcs.pi = runtime_evaluated("math.pi", 0, 0, true)
sql_funcs.pow = runtime_evaluated("OPS['^']", 2, 2)
sql_funcs.power = sql_funcs.pow
sql_funcs.radians = runtime_evaluated("math.rad")
-- Single-argument MIN()/MAX() are aggregates and rejected in expr_to_lua
sql_funcs.max = runtime_evaluated("math.max", 2, math.huge)
sql_funcs.min = runtime_evaluated("math.min", 2, math.huge)

-- Zero-argument date/time functions compile to the same Lua code as the
-- corresponding special variables
sql_funcs.date = runtime_evaluated(sq.special_vars.current_date, 0, 0, true)
sql_funcs.datetime = runtime_evaluated(sq.special_vars.current_timestamp, 0, 0, true)
sql_funcs.time = runtime_evaluated(sq.special_vars.current_time, 0, 0, true)

-- Translates an AST (as produced by Parser:parse_expr) into a Lua expression.
-- Returns lua_code, is_pure_arithmetic. The second value is true only for
-- number literals and lua_ops operations whose operands are all pure
-- arithmetic. Several node types return no second value (treated as false).
-- Raises errors for aggregate functions, unknown SQL functions, DISTINCT on
-- non-aggregates and incomplete BETWEEN expressions. The AST is not modified,
-- so the same tree can be translated more than once.
function Parser:expr_to_lua(node)
    if node.type == AST_NUM then
        if node.n == math.huge then
            return "1e999", true
        elseif node.n == -math.huge then
            return "-1e999", true
        end
        -- 17 significant digits are enough to round-trip any double
        return format("%.17g", node.n), true
    elseif node.type == AST_STR then
        return format("%q", node.s), false
    elseif node.type == AST_PAR then
        if node.name then
            return format("%s1[%q]", sq.PARAM_VAR_PREFIX, node.name)
        else
            return sq.PARAM_VAR_PREFIX .. node.n
        end
    elseif node.type == AST_VAR then
        if sq.special_vars[node.name] then
            return sq.special_vars[node.name], false
        end
        local table_alias = self:resolve_bare_identifier(node.name)
        return self:get_col_ref(table_alias, node.name), false
    elseif node.type == AST_CASE then
        -- An anonymous function like this seems like the easiest way to
        -- translate CASE statements without doing anything special for
        -- variable scoping.
        local code = {"(function()"}
        case_to_function_body(self, code, node)
        code[#code + 1] = "end)()"
        return table.concat(code, "\n")
    elseif node.type == AST_CALL then
        local func = node.func
        if func == "count" or func == "sum" or
                ((func == "min" or func == "max") and node.argc == 1) then
            error(format(
                "Aggregate function %q used in unsupported location",
                func
            ))
        end

        local func_codegen = sql_funcs[func]
        if not func_codegen then
            error(format("SQL function %q not implemented", func))
        elseif node.distinct then
            error(format("%s(DISTINCT ...) is invalid or not implemented",
                func:upper()))
        end
        return func_codegen(self, node)
    elseif node.type == AST_CUSTOM_LUA then
        return node.lua
    end

    assert(node.type == AST_OP)
    if node.op == "." then
        assert(node[1].type == AST_VAR, bad_attr_msg)
        assert(node[2].type == AST_VAR, bad_attr_msg)
        return self:get_col_ref(node[1].name, node[2].name), false
    end

    -- The operands to translate. Normally these are the node's own, but a
    -- LIKE with a literal pattern gets a separate list so the AST itself is
    -- left untouched.
    local operands = node
    if (node.op == "like" or node.op == "not like") and
            node[2].type == AST_STR then
        assert(node.argc == 2)

        -- Compile LIKE patterns now if they are given a string. The extra
        -- "true" operand presumably tells the runtime operator that the
        -- pattern is already compiled.
        operands = {
            argc = 3,
            node[1],
            {type = AST_STR, s = sq.compile_like_pattern(node[2].s)},
            {type = AST_VAR, name = "true"},
        }
    end

    -- BETWEEN only gets its third operand when followed by AND (see
    -- push_operation)
    if (node.op == "between" or node.op == "not between") and
            node.argc == 2 then
        self:error("Invalid syntax for BETWEEN")
    end

    local args = {}
    local all_numbers = true
    for i = 1, operands.argc do
        local is_equation
        args[i], is_equation = self:expr_to_lua(operands[i])
        all_numbers = all_numbers and is_equation
    end

    if all_numbers and lua_ops[node.op] then
        -- All operands are numbers or equations with numbers, compile to a
        -- plain Lua expression so that Lua can optimise them. They are wrapped
        -- in brackets to prevent Lua from interpreting precedence differently
        return "(" .. format(lua_ops[node.op], unpack(args)) .. ")", true
    end

    if node.op == "and" or node.op == "or" then
        -- Short circuit and/or, which are implemented with two helper
        -- functions to ensure that they work correctly with null
        assert(node.argc == 2)
        return format(
            "(OPS.%s_left(%s) %s OPS.%s_right(%s))",
            node.op, args[1], node.op, node.op, args[2]
        )
    elseif node.op == "is" then
        assert(node.argc == 2)
        return format("(%s == %s)", args[1], args[2]), false
    elseif node.op == "is not" then
        assert(node.argc == 2)
        return format("(%s ~= %s)", args[1], args[2]), false
    end

    return format("OPS[%q](%s)", node.op, table.concat(args, ", ")), false
end
