local sq = squill._internal
local format = string.format

-- Prepared statement that counts the schema entries matching a database name
-- and table name; parse_table_name uses it to verify that a table referenced
-- in FROM/JOIN exists. It is configured to return a single value.
local table_exists_stmt = sq.bootstrap_statement(
    [[
    SELECT COUNT(*) FROM schema WHERE db_name = ? AND table = ?
]],
    squill.RETURN_SINGLE_VALUE
)

--- Parses an optional alias: `AS name`, a bare `name`, or a `'quoted name'`.
-- Without an explicit AS keyword the next token is only consumed when it
-- looks like an alias (a single-quoted string or a valid identifier);
-- otherwise nothing is consumed and nil is returned.
-- @param parser The parser to read tokens from.
-- @return The alias (quotes stripped and '' unescaped to '), or nil when
--   there is no alias. Exactly one value is returned.
local function parse_as(parser)
    if not parser:pop_if_equals("as") then
        local alias = parser:peek()
        if not alias or (alias:sub(1, 1) ~= "'" and not sq.valid_identifier(alias)) then
            return nil
        end
    end

    local alias = parser:next()
    -- Only reachable with a nil token after an explicit AS; the peek above
    -- guarantees a token in the other case
    parser:assert(alias ~= nil, "Expected an alias after AS")

    if alias:sub(1, 1) == "'" then
        -- The parentheses drop gsub's second result (the substitution count)
        -- so callers only ever receive the alias itself
        return (alias:sub(2, -2):gsub("''", "'"))
    end

    parser:assert(sq.valid_identifier(alias),
        "Invalid variable name for alias: %q", alias)
    return alias
end

-- Forward declaration: select_cmd is defined further down, but sub-queries
-- in FROM/JOIN need to call it recursively.
local select_cmd

--- Parses a parenthesised sub-query used in place of a table name.
-- The opening "(" has already been consumed by the caller. The sub-query is
-- compiled with select_cmd and registered in env.nested_selects under its
-- alias (an auto-generated NESTED symbol when no alias is given).
-- @param env The Env of the SELECT containing this sub-query.
-- @return nil (there is no real table name) and the alias.
local function parse_nested_select(env)
    local p = env.parser
    p:expect("select")
    local code, columns = select_cmd(p)
    p:expect(")")

    -- The fallback symbol is upper-case so it cannot clash with an actual
    -- identifier
    local name = parse_as(p)
    if not name then
        name = env:sym("NESTED")
    end

    local already_used = env.table_lookup[name] or env.nested_selects[name]
    env:assert(not already_used, "Duplicate table name: %s", name)

    local entry = {}
    entry.code = code
    entry.returned_columns = columns
    entry.sym = env:sym("nested_select_rows")
    env.nested_selects[name] = entry

    return nil, name
end

--- Parses a table reference in FROM/JOIN: either a table name with an
-- optional alias, or a parenthesised sub-query (handed to
-- parse_nested_select).
-- The table must exist in the schema unless sq.bootstrap is set. The alias
-- is recorded in env.table_lookup and must not already be in use.
-- @param env The Env of the SELECT being parsed.
-- @return The table name (nil for sub-queries) and its alias.
local function parse_table_name(env)
    local p = env.parser
    local name = p:next()
    if name == "(" then
        return parse_nested_select(env)
    end

    local exists = sq.bootstrap or table_exists_stmt(env.db_name, name) > 0
    env:assert(exists, "Table %q does not exist", name)

    local alias = parse_as(p) or name
    local taken = env.table_lookup[alias] or env.nested_selects[alias]
    env:assert(not taken, "Duplicate table name: %s", alias)

    env.table_lookup[alias] = name
    return name, alias
end

-- Error message (format string taking the column name) raised when two
-- returned columns would end up with the same name.
local duplicate_err = table.concat({
    "Duplicate column name in returned value: %q",
    "This is not supported by Squill to make its implementation simpler.",
    "Try using AS to give one of the returned columns a different name.",
}, "\n")

-- Lua code templates that fold one more row into a running aggregate.
-- "total" is replaced with the accumulator expression by select_cmd and
-- "value" is the current row's value (select_cmd skips NULL values, except
-- for COUNT(*) which has no value at all).
local aggregate_funcs = {}
aggregate_funcs.count = "total + 1"
aggregate_funcs.min = "(total and total < value) and total or value"
aggregate_funcs.max = "(total and total > value) and total or value"
aggregate_funcs.sum = "(total or 0) + value"

--- Parses one result column of a SELECT (an expression plus optional alias)
-- and records it.
-- Without an alias, a plain `col` or qualified `tbl.col` reference is named
-- after the column; any other expression is called "?column?", or gets a
-- random "?column~xxxx?" name if that one is already taken.
-- @param parser The parser to read from.
-- @param return_exprs Map of column name -> expression AST (updated).
-- @param returned_columns Ordered list of column names (appended to).
-- @return The name given to the column.
-- Raises a parse error if the name is already used by another column.
local function parse_select_result_column(parser, return_exprs, returned_columns)
    local expr = parser:parse_expr()
    local name = parse_as(parser)

    if not name then
        local is_qualified = expr.type == sq.AST_OP and expr.op == "." and
            expr[1].type == sq.AST_VAR
        if is_qualified then
            name = expr[2].name
        elseif expr.type == sq.AST_VAR then
            name = expr.name
        else
            local candidate = "?column?"
            while return_exprs[candidate] do
                -- Duplicate auto-named columns are allowed for easier
                -- debugging (and for APIs that don't return columns by name).
                -- The random suffix discourages relying on the generated
                -- name, which may change in the future.
                candidate = format("?column~%04x?", math.random(0, 65535))

                -- Should be impossible anyway
                assert(#returned_columns < 32768)
            end
            name = candidate
        end
    end

    parser:assert(not return_exprs[name], duplicate_err, name)

    return_exprs[name] = expr
    table.insert(returned_columns, name)

    return name
end
sq.parse_select_result_column = parse_select_result_column

--- Replaces every one-argument aggregate call (a function listed in
-- aggregate_funcs) in an expression tree with a reference to its accumulator
-- slot, and records the aggregation so select_cmd can generate the code that
-- accumulates it.
-- The tree is modified in place: matching AST_CALL nodes become
-- AST_CUSTOM_LUA nodes whose code is `<final_group_var>[<index>]`.
-- @param env The Env used to compile aggregate arguments to Lua.
-- @param tree Expression AST to scan.
-- @param aggregations List of {func, distinct, lua} records (appended to);
--   `lua` is left nil when the argument is `*` (e.g. COUNT(*)).
-- @param final_group_var Name of the generated variable holding the
--   accumulators of the current group.
local function separate_aggregate_funcs(env, tree, aggregations, final_group_var)
    for node in sq.walk_ast(tree) do
        local is_aggregate = node.type == sq.AST_CALL and
            aggregate_funcs[node.func] ~= nil and node.argc == 1
        if is_aggregate then
            local slot = #aggregations + 1
            local record = {func = node.func, distinct = node.distinct}

            local arg = node[1]
            if arg.type ~= sq.AST_STAR then
                record.lua = env:expr_to_lua(arg)
            end
            aggregations[slot] = record

            node.type = sq.AST_CUSTOM_LUA
            node.lua = format("%s[%d]", final_group_var, slot)
        end
    end
end

--- Parses a single "core" SELECT (everything up to, but not including, any
-- UNION / ORDER BY / LIMIT) into a description table used by select_cmd.
-- The SELECT keyword must already have been consumed.
-- @param parser The parser to read from.
-- @return A table with the fields:
--   aggregate        true if the query uses aggregate functions or GROUP BY
--   aggregations     list of aggregation records (see separate_aggregate_funcs)
--   env              the Env this core SELECT was parsed with
--   distinct         truthy for SELECT DISTINCT
--   group_by         list of Lua expressions, one per GROUP BY term
--   having           Lua expression for HAVING, or nil
--   joins            list of {table, alias, cond, outer}; the first entry's
--                    `cond` is the WHERE condition. Without FROM but with
--                    WHERE, joins[1] is a table-less entry with no alias.
--   return_exprs_lua map of column name -> Lua expression
--   returned_columns ordered list of column names
local function parse_select_core(parser)
    -- Each core SELECT statement gets its own Env instance so they don't
    -- interfere with each other
    local env = sq.Env.new(parser)

    local returned_columns = {} -- List of column names
    local return_exprs = {} -- {column_name = expr}

    -- Number of explicit columns before "*" (the expanded columns are
    -- inserted after them), or false if there is no "*". Note that 0 is truthy
    -- in Lua, so a leading "*" still counts.
    local wildcard = false
    local distinct = parser:pop_if_equals("distinct")
    if not distinct then
        parser:pop_if_equals("all")
    end

    repeat
        if parser:peek() == "*" then
            parser:next()
            wildcard = #returned_columns
        else
            parse_select_result_column(parser, return_exprs, returned_columns)
        end
    until not parser:pop_if_equals(",")

    local joins = {}
    if parser:pop_if_equals("from") then
        -- It's simpler to represent everything as a join
        local table_name, table_name_as = parse_table_name(env)
        joins[1] = {
            table = table_name,
            alias = table_name_as,
        }

        while true do
            local outer
            if parser:pop_if_equals("left") then
                parser:pop_if_equals("outer")
                outer = true
            elseif parser:peek() ~= "join" and not parser:pop_if_equals("inner") then
                break
            end

            parser:expect("join")
            local t, alias = parse_table_name(env)

            local mode = parser:next()
            local cond
            if mode == "on" then
                cond = parser:parse_expr()
            elseif mode == "using" then
                parser:expect("(")
                repeat
                    local using = parser:next()
                    parser:assert(sq.valid_identifier(using))

                    -- Treat them as equivalent by ignoring one when resolving
                    -- bare identifiers
                    env.using_cols[alias .. "/" .. using] = true

                    -- <earlier table>.using = <alias>.using
                    local new_cond = {
                        type = sq.AST_OP, op = "=", argc = 2,
                        {
                            type = sq.AST_OP, op = ".", argc = 2,
                            {type = sq.AST_VAR,
                                name = env:resolve_bare_identifier(using)},
                            {type = sq.AST_VAR, name = using},
                        },
                        {
                            type = sq.AST_OP, op = ".", argc = 2,
                            {type = sq.AST_VAR, name = alias},
                            {type = sq.AST_VAR, name = using},
                        }
                    }
                    cond = cond and {
                        type = sq.AST_OP, op = "and", argc = 2,
                        cond, new_cond,
                    } or new_cond
                until not parser:pop_if_equals(",")
                parser:expect(")")
            else
                parser:error("Join must have either on or using, not " .. mode)
            end
            joins[#joins + 1] = {
                table = t,
                alias = alias,
                cond = cond,
                outer = outer,
            }
        end
    end

    if parser:pop_if_equals("where") then
        -- Without FROM this creates a table-less join that only carries the
        -- WHERE condition
        joins[1] = joins[1] or {}
        assert(not joins[1].cond)
        joins[1].cond = parser:parse_expr()
    end

    local group_by_exprs = {}
    if parser:pop_if_equals("group") then
        parser:expect("by")
        repeat
            group_by_exprs[#group_by_exprs + 1] = parser:parse_expr()
        until not parser:pop_if_equals(",")
    end

    -- Figure out if this is an aggregate query
    local aggregations = {}
    local group_by = {}
    -- Generated variable that holds the innermost group's accumulators
    local final_group_var = #group_by_exprs > 0 and format("group_%d",
        #group_by_exprs) or "aggregate"
    for _, tree in pairs(return_exprs) do
        separate_aggregate_funcs(env, tree, aggregations, final_group_var)
    end
    for i, expr in ipairs(group_by_exprs) do
        separate_aggregate_funcs(env, expr, aggregations, final_group_var)
        group_by[i] = env:expr_to_lua(expr)
    end
    local aggregate = #aggregations > 0 or #group_by > 0

    local having
    if parser:pop_if_equals("having") then
        parser:assert(aggregate, "HAVING only works on aggregate queries")

        local having_expr = parser:parse_expr()
        separate_aggregate_funcs(env, having_expr, aggregations, final_group_var)
        having = env:expr_to_lua(having_expr)
    end

    local return_exprs_lua = {}
    for name, expr in pairs(return_exprs) do
        return_exprs_lua[name] = env:expr_to_lua(expr)
    end

    -- Expand wildcards
    if wildcard then
        -- A table-less join (WHERE without FROM) has no alias and no columns
        parser:assert(#joins > 0 and joins[1].alias ~= nil,
            "You can't do SELECT * without any tables")
        for _, join in ipairs(joins) do
            local columns = join.table and env:list_columns(join.table) or
                env.nested_selects[join.alias].returned_columns
            for _, column in ipairs(columns) do
                if not env.using_cols[join.alias .. "/" .. column] then
                    wildcard = wildcard + 1
                    -- Check return_exprs_lua (explicit columns plus columns
                    -- already expanded from earlier tables) so that e.g. two
                    -- joined tables both having "id" are reported instead of
                    -- one silently overwriting the other
                    parser:assert(not return_exprs_lua[column], duplicate_err, column)
                    table.insert(returned_columns, wildcard, column)
                    return_exprs_lua[column] = env:get_col_ref(join.alias, column)
                end
            end
        end
    end

    return {
        aggregate = aggregate,
        aggregations = aggregations,
        env = env,
        distinct = distinct,
        group_by = group_by,
        having = having,
        joins = joins,
        return_exprs_lua = return_exprs_lua,
        returned_columns = returned_columns,
    }
end

--- Parses a complete SELECT statement (after the SELECT keyword) and
-- generates the Lua code that executes it.
-- The statement is one or more core SELECTs joined with UNION / UNION ALL,
-- followed by optional ORDER BY and LIMIT clauses. The generated code builds
-- a list of row tables (column name -> value) in `res`, sorts and truncates
-- it as needed, and returns it.
-- @param parser The parser to read from.
-- @return A CodeBuf with the generated code, and the ordered list of column
--   names (taken from the first core SELECT).
function select_cmd(parser)
    local unions = {parse_select_core(parser)}
    local returned_columns = unions[1].returned_columns

    -- Merge UNIONed queries together
    while parser:pop_if_equals("union") do
        local union_all = parser:pop_if_equals("all")
        parser:expect("select")

        local s = parse_select_core(parser)
        parser:assert(#s.returned_columns == #returned_columns,
            "Queries in UNION do not return the same number of columns")

        -- Rename columns in this statement to match the first one
        local new_returns = {}
        for i, col_name in ipairs(s.returned_columns) do
            new_returns[returned_columns[i]] = s.return_exprs_lua[col_name]
        end
        s.return_exprs_lua = new_returns

        -- Prevent accidentally using the wrong returned_columns later on
        s.returned_columns = nil

        unions[#unions + 1] = s

        if not union_all then
            -- A plain UNION deduplicates everything before it too, so enable
            -- DISTINCT on every query parsed so far
            for _, union2 in ipairs(unions) do
                union2.union_distinct = true
            end
        end
    end

    local order_by
    if parser:pop_if_equals("order") then
        parser:expect("by")
        order_by = {}

        -- HACK: Wrap the env object to prioritise returned columns (like
        -- SQLite does)
        local base_env = #unions == 1 and unions[1].env or sq.Env.new(parser)
        local env = setmetatable({
            get_col_ref = function(self, table_alias, col_name)
                if table_alias == "*result*" then
                    return format("row[%q]", col_name), nil
                end
                return base_env:get_col_ref(table_alias, col_name)
            end,
            resolve_bare_identifier = function(self, column_name)
                if table.indexof(returned_columns, column_name) > 0 then
                    return "*result*"
                end
                return base_env:resolve_bare_identifier(column_name)
            end,
        }, {__index = base_env, __newindex = base_env})

        repeat
            -- Note: "ORDER BY 5 + 0" orders by the literal value "5" in SQLite
            -- but seems pointless and potentially confusing so it's not allowed
            local expr, is_numeric = env:expr_to_lua(parser:parse_expr())
            parser:assert(not is_numeric,
                "Numeric column indexes are not supported in ORDER BY")
            local desc = parser:pop_if_equals("desc")
            if not desc then
                parser:pop_if_equals("asc")
            end
            order_by[#order_by + 1] = {expr = expr, desc = desc}
        until not parser:pop_if_equals(",")
    end

    local limit_num, limit_code
    if parser:pop_if_equals("limit") then
        local is_number

        -- Parse LIMIT in a separate Env so that it can't access anything
        -- from the query
        local env = sq.Env.new(parser)
        limit_code, is_number = env:expr_to_lua(parser:parse_expr())

        -- Parse limit at compile time so that 1 can be handled separately
        if is_number then
            limit_num = assert(core.deserialize("return " .. limit_code))
            parser:assert(limit_num % 1 == 0, "LIMIT value must be an integer")

            -- Treat LIMIT 0 as if it were a variable, anything better is
            -- probably not worth implementing
            if limit_num < 1 then
                limit_num = nil
            end
        end
    end

    local code = sq.CodeBuf.new(parser)
    code:put("local res = {}")

    if order_by then
        for i = 1, #order_by do
            code:putf("local order_keys_%d = {}", i)
        end
    end

    for union_num, s in ipairs(unions) do
        local aggregate, aggregations, env, group_by, having, joins,
            return_exprs_lua =
            s.aggregate, s.aggregations, s.env, s.group_by, s.having, s.joins,
            s.return_exprs_lua

        local aggregate_defaults
        if aggregate then
            -- Make COUNT(value) default to 0
            local d = {}
            for idx, a in pairs(aggregations) do
                if a.func == "count" then
                    d[#d + 1] = format("[%d] = 0", idx)
                end

                if a.distinct then
                    d[#d + 1] = format("distinct_%d = {}", idx)
                end
            end
            aggregate_defaults = table.concat(d, ", ")

            if #group_by > 0 then
                -- If GROUP BY is present: aggregate_groups[group_by_value]
                -- If whatever group_by returns is nil, we use aggregate_groups
                -- itself as a key (as to not conflict with anything else it could
                -- possibly return) to avoid Lua complaining about a nil key
                code:put("local aggregate_groups = {}")
            else
                -- If there's no GROUP BY, we use one table and fill it in with the
                -- COUNT() defaults
                code:putf("local aggregate = {%s}", aggregate_defaults)
            end
        end

        -- Make SELECT DISTINCTs inside UNION ALLs only scan the newly unioned
        -- rows for duplicates
        local distinct_start_at = "1"
        local distinct = s.distinct or s.union_distinct
        if distinct and not s.union_distinct and union_num > 1 then
            distinct_start_at = env:sym("union_start_idx")
            code:putf("local %s = #res + 1", distinct_start_at)
        end

        -- Figure out what order to loop in (note that ordered_joins might be
        -- longer than joins if joins had to be split up)
        -- Must be before insert_var_refs so all the variables it needs can be
        -- defined
        local ordered_joins = env:get_optimal_join_order(joins)
        env:insert_var_refs(code, true)
        -- Number of blocks opened by write_joins that must be closed with "end"
        local join_ends = env:write_joins(code, ordered_joins)

        -- Aggregate clauses have two loops:
        -- The first one (below) builds a list of requested values per group, and
        -- the second one which adds these all to "res" and deals with HAVING etc
        if aggregate then
            env:assert(#joins > 0, "Cannot use aggregate query without any tables")

            -- If GROUP BY clauses are present, get an "aggregate" value
            local group = "aggregate"
            for i, clause in ipairs(group_by) do
                local prev_group = i > 1 and group or "aggregate_groups"
                group = format("group_%d", i)

                code:putf("%sgroup_key = %s", i > 1 and "" or "local ", clause)
                -- Reuse the aggregate_groups table as a sentinel value for nils
                code:put("if group_key == nil then group_key = aggregate_groups end")

                code:putf("local %s = %s[group_key]", group, prev_group)
                code:putf("if %s == nil then", group)
                code:putf("%s = {%s}", group,
                    i == #group_by and aggregate_defaults or "")

                code:putf("%s[group_key] = %s", prev_group, group)
                code:put("end")
            end

            local local_value = false
            for idx, a in pairs(aggregations) do
                local total = format("%s[%d]", group, idx)

                -- NULL columns are ignored (except for COUNT(*))
                if a.lua then
                    code:putf("%svalue = %s",
                        local_value and "" or "local ", a.lua)
                    local_value = true

                    if a.distinct then
                        code:putf(
                            "if value ~= nil and not %s.distinct_%d[value] then",
                            group, idx
                        )
                        code:putf(
                            "%s.distinct_%d[value] = true",
                            group, idx
                        )
                    else
                        code:put("if value ~= nil then")
                    end
                else
                    assert(not a.distinct)
                end

                local c = aggregate_funcs[a.func]:gsub("total", total)
                code:putf("%s = %s", total, c)

                if a.lua then
                    code:put("end")
                end
            end

            -- Store the rowids of any row so they can be restored as
            -- rowid_<alias> locals in the output loop below. A table-less join
            -- (WHERE without FROM) has no alias and therefore no rowids.
            local has_tables = joins[1].alias ~= nil
            if has_tables then
                code:putf("if %s.rowid_%s == nil then", group,
                    joins[1].alias)
                for _, join in ipairs(joins) do
                    code:putf("%s.rowid_%s = rowid_%s", group,
                        join.alias, join.alias)
                end
                code:put("end")
            end

            -- End giant loop
            for _ = 1, join_ends do
                code:put("end")
            end

            -- Convert the aggregate table into a list of rows to return
            for i = 1, #group_by do
                code:putf(
                    "for _, group_%d in pairs(%s) do", i,
                    i > 1 and "group_" .. i - 1 or "aggregate_groups"
                )
            end

            -- Get the rowids of whatever column we found (if any)
            if has_tables then
                for _, join in ipairs(joins) do
                    code:putf(
                        "local rowid_%s = %s.rowid_%s",
                        join.alias, group, join.alias
                    )
                end
            end

            join_ends = #group_by
            if having then
                join_ends = join_ends + 1
                code:putf("if OPS.to_boolean(%s) then", having)
            end
        end

        -- The limit check must be first if it's a variable so that LIMIT 0 works
        if limit_code and not order_by and not limit_num then
            -- Parenthesised in case the expression contains operators with a
            -- lower precedence than >=, such as "or"
            code:putf("if #res >= (%s) then return res end",
                limit_code)
        end

        local row_variable = order_by or distinct
        code:putf("%s = {", row_variable and "local row" or "res[#res + 1]")
        for name, expr in pairs(return_exprs_lua) do
            code:putf("[%q] = %s,", name, expr)
        end
        code:put("}")

        if row_variable then
            code:put("res[#res + 1] = row")
        end

        if order_by then
            for i, clause in ipairs(order_by) do
                code:putf("order_keys_%d[row] = %s", i, clause.expr)
            end
        end

        if distinct then
            -- SELECT DISTINCT, deduplicate returned rows
            code:putf("for i = %s, #res - 1 do", distinct_start_at)
            local eq_nil_checks = {}
            for _, name in ipairs(returned_columns) do
                eq_nil_checks[#eq_nil_checks + 1] = format(
                    "row[%q] == res[i][%q]", name, name
                )
            end
            code:putf("if %s then res[#res] = nil; break end",
                table.concat(eq_nil_checks, " and "))
            code:put("end")
        end

        -- LIMIT can't just stop early with order by
        if limit_num and not order_by then
            if limit_num == 1 and join_ends > 0 then
                -- Inside a loop the first row can be returned straight away:
                -- any earlier query that produced a row has already returned
                code:put("return res")
            else
                -- Also required outside of any loop (join_ends == 0), e.g.
                -- "SELECT 1 UNION ALL SELECT 2 LIMIT 1", where each query
                -- adds exactly one row and would otherwise never be checked
                code:putf("if #res >= %d then return res end", limit_num)
            end
        end

        -- End of filling in res loop
        for _ = 1, join_ends do
            code:put("end")
        end
    end

    if order_by then
        code:put("table_sort(res, function(a, b)")
        for i, clause in ipairs(order_by) do
            local loc = i == 1 and "local " or ""
            code:putf("%ska = order_keys_%d[a]", loc, i)
            code:putf("%skb = order_keys_%d[b]", loc, i)
            if i < #order_by then
                code:put("if ka ~= kb then")
            end
            code:putf(
                "return (%s == nil and %s ~= nil) or OPS[%q](ka, kb)",
                -- Sort nulls as smaller than anything else
                clause.desc and "kb" or "ka",
                clause.desc and "ka" or "kb",

                clause.desc and ">" or "<"
            )
            if i < #order_by then
                code:put("end")
            end
        end
        code:put("end)")
    end

    -- Remove any extra rows if we had to fetch every single row
    if limit_code and order_by then
        -- Parenthesised so "+ 1" applies to the whole LIMIT expression
        code:putf("for i = (%s) + 1, #res do", limit_code)
        code:put("res[i] = nil")
        code:put("end")
    end

    code:put("return res")

    return code, returned_columns
end

-- Module entry point: parses a SELECT statement with select_cmd (which reads
-- from just after the SELECT keyword — NOTE(review): presumably the caller has
-- already consumed it; confirm) and returns the compiled code together with
-- the ordered list of returned column names.
return function(parser)
    local buf, columns = select_cmd(parser)
    local compiled = buf:compile(parser)
    return compiled, columns
end
