local sq = squill._internal
local format = string.format

-- Statement used to check whether a table exists: it counts the rows of the
-- "schema" table matching a database name and a table name (the two "?"
-- placeholders, in that order).
-- NOTE(review): squill.RETURN_SINGLE_VALUE presumably makes the statement
-- return the count directly (parse_table_name compares the result with 0) -
-- confirm against sq.bootstrap_statement.
local table_exists_stmt = sq.bootstrap_statement([[
    SELECT COUNT(*) FROM schema WHERE db_name = ? AND table = ?
]], squill.RETURN_SINGLE_VALUE)

--- Parses an optional alias, written either as "AS alias" or as a bare alias.
--
-- Without a leading AS, the next token is only treated as an alias if it is a
-- single-quoted string or passes sq.valid_identifier; otherwise nothing is
-- consumed and nothing is returned.
--
-- @param self The parser object (uses pop_if_equals, peek, next and assert)
-- @return The alias as a string (quoted aliases have their surrounding quotes
--         removed and '' unescaped to '), or nothing if no alias is present
local function parse_as(self)
    if not self:pop_if_equals("as") then
        -- No explicit AS: only continue if the next token looks like an alias
        local alias = self:peek()
        if not alias or (alias:sub(1, 1) ~= "'" and not sq.valid_identifier(alias)) then
            return
        end
    end

    local alias = self:next()
    if alias:sub(1, 1) == "'" then
        -- Strip the surrounding quotes and unescape doubled quotes.
        -- The extra parentheses discard gsub's second result (the number of
        -- replacements) so that callers only ever receive the alias itself.
        return (alias:sub(2, -2):gsub("''", "'"))
    end

    self:assert(sq.valid_identifier(alias),
        "Invalid variable name for alias: %q", alias)
    return alias
end

--- Reads a table reference (a table name followed by an optional alias).
--
-- Unless sq.bootstrap is set, the table must exist in the schema of
-- self.db_name. The alias (or the table name itself when no alias is given)
-- is recorded in self.table_lookup, mapping to the real table name.
--
-- @param self The parser object
-- @return The table name
-- @return The alias the table is referred to by
local function parse_table_name(self)
    local name = self:next()

    local exists = sq.bootstrap or table_exists_stmt(self.db_name, name) > 0
    self:assert(exists, "Table %q does not exist", name)

    local alias = parse_as(self)
    if not alias then
        -- No alias given, refer to the table by its own name
        alias = name
    end

    self.table_lookup[alias] = name
    return name, alias
end

-- Error message template used when two returned columns would share a name.
-- The single %q placeholder is filled in with the offending column name.
local duplicate_err = "Duplicate column name in returned value: %q\n" ..
    "This is not supported by Squill to make its implementation simpler.\n" ..
    "Try using AS to give one of the returned columns a different name."

-- Supported aggregate functions, keyed by function name.
-- Each value is a template for a Lua expression that computes the new
-- accumulator value. When code is generated, every occurrence of "total" is
-- replaced (with gsub) by the expression for the group's accumulator slot,
-- and "value" refers to the local variable the generated code assigns the
-- aggregated expression to. An unset accumulator is nil, hence the
-- "total or 0" / "total and ..." guards.
local aggregate_funcs = {
    -- Does not look at value, just counts
    count = "(total or 0) + 1",
    -- Keep the existing total only if it is set and smaller than value
    min = "(total and total < value) and total or value",
    -- Keep the existing total only if it is set and larger than value
    max = "(total and total > value) and total or value",
    sum = "(total or 0) + value",
}

--- Parses one result column ("expr [[AS] alias]") of a SELECT statement.
--
-- If no alias is given, the column is named after the variable it refers to
-- ("tbl.col" and "col" are both named "col"). Any other expression gets an
-- automatically generated name.
--
-- @param self The parser object
-- @param return_exprs Table of {column_name = expr}, updated in place
-- @param returned_columns List of column names, appended to in place
-- @return The name of the parsed column
local function parse_select_result_column(self, return_exprs, returned_columns)
    local expr = self:parse_expr()
    local name = parse_as(self)

    if not name then
        local qualified_var = expr.type == sq.AST_OP and expr.op == "." and
            expr[1].type == sq.AST_VAR
        if qualified_var then
            -- table.column: use the column part
            name = expr[2].name
        elseif expr.type == sq.AST_VAR then
            name = expr.name
        else
            -- Anonymous expression. Duplicates of these are tolerated (handy
            -- for debugging and for APIs that do not return columns by
            -- name), but the fallback names are random so that nobody relies
            -- on them, as they may change in the future.
            name = "?column?"
            while return_exprs[name] do
                name = format("?column~%04x?", math.random(0, 65535))

                -- Sanity check, this should never be reached in practice
                assert(#returned_columns < 32768)
            end
        end
    end

    self:assert(not return_exprs[name], duplicate_err, name)

    local position = #returned_columns + 1
    returned_columns[position] = name
    return_exprs[name] = expr

    return name
end
sq.parse_select_result_column = parse_select_result_column

--- Extracts aggregate function calls from an expression tree.
--
-- Every single-argument call to a function listed in aggregate_funcs is
-- appended to "aggregations" and the AST node is rewritten in place into a
-- custom Lua node that reads the aggregated result from final_group_var.
--
-- @param self The parser object (used for expr_to_lua)
-- @param tree The expression tree to walk, modified in place
-- @param aggregations List of {func = ..., distinct = ..., lua = ...} tables,
--        appended to in place ("lua" is left unset for func(*))
-- @param final_group_var Name of the Lua variable holding the group's results
local function separate_aggregate_funcs(self, tree, aggregations, final_group_var)
    for node in sq.walk_ast(tree) do
        local is_aggregate_call = node.type == sq.AST_CALL and
            aggregate_funcs[node.func] and node.argc == 1
        if is_aggregate_call then
            local slot = #aggregations + 1
            local entry = {func = node.func, distinct = node.distinct}

            -- A star argument has no expression to evaluate
            local arg = node[1]
            if arg.type ~= sq.AST_STAR then
                entry.lua = self:expr_to_lua(arg)
            end
            aggregations[slot] = entry

            -- Replace the call with a lookup of the aggregated value
            node.type = sq.AST_CUSTOM_LUA
            node.lua = format("%s[%d]", final_group_var, slot)
        end
    end
end

--- Parses the body of a SELECT statement and compiles it into a Lua function.
--
-- Clauses are read in this order: [DISTINCT | ALL], the result column list,
-- FROM (with any INNER / LEFT [OUTER] JOINs), WHERE, GROUP BY, HAVING,
-- ORDER BY and LIMIT. The SELECT keyword itself is not consumed here
-- (presumably the caller has already done so).
--
-- Side effects on self: self.table_lookup is replaced with a fresh table and
-- self.using_cols is created/extended whenever JOIN ... USING is parsed.
--
-- @param self The parser/compiler object for this statement
-- @return The first value returned by self:compile_lua for the generated code
-- @return The list of returned column names, in output order
return function(self)
    local return_exprs = {} -- {column_name = expr}
    local returned_columns = {} -- List of column names
    local wildcard = false
    local distinct = self:pop_if_equals("distinct")
    if not distinct then
        self:pop_if_equals("all")
    end

    repeat
        if self:peek() == "*" then
            self:next()
            -- Remember how many columns come before the "*" so that the
            -- expanded columns can be inserted at that position later on
            -- (note that 0 is truthy in Lua)
            wildcard = #returned_columns
        else
            parse_select_result_column(self, return_exprs, returned_columns)
        end
    until not self:pop_if_equals(",")

    local joins = {}
    self.table_lookup = {}
    if self:pop_if_equals("from") then
        -- It's simpler to represent everything as a join
        local table_name, table_name_as = parse_table_name(self)
        joins[1] = {
            table = table_name,
            alias = table_name_as,
        }

        while true do
            local outer
            if self:pop_if_equals("left") then
                self:pop_if_equals("outer")
                outer = true
            elseif self:peek() ~= "join" and not self:pop_if_equals("inner") then
                break
            end

            self:expect("join")
            local t, alias = parse_table_name(self)

            local mode = self:next()
            local cond
            if mode == "on" then
                cond = self:parse_expr()
            elseif mode == "using" then
                self:expect("(")
                repeat
                    local using = self:next()
                    self:assert(sq.valid_identifier(using),
                        "Invalid column name in USING clause: %q",
                        tostring(using))

                    -- Treat them as equivalent by ignoring one when resolving
                    -- bare identifiers
                    self.using_cols = self.using_cols or {}
                    self.using_cols[alias .. "/" .. using] = true

                    -- Build "<existing table>.<col> = <joined table>.<col>"
                    local new_cond = {
                        type = sq.AST_OP, op = "=", argc = 2,
                        {
                            type = sq.AST_OP, op = ".", argc = 2,
                            {type = sq.AST_VAR,
                                name = self:resolve_bare_identifier(using)},
                            {type = sq.AST_VAR, name = using},
                        },
                        {
                            type = sq.AST_OP, op = ".", argc = 2,
                            {type = sq.AST_VAR, name = alias},
                            {type = sq.AST_VAR, name = using},
                        }
                    }
                    -- Multiple USING columns are combined with AND
                    cond = cond and {
                        type = sq.AST_OP, op = "and", argc = 2,
                        cond, new_cond,
                    } or new_cond
                until not self:pop_if_equals(",")
                self:expect(")")
            else
                -- tostring() so that running out of tokens (mode == nil)
                -- reports a parse error instead of a concatenation error
                self:error("Join must have either on or using, not " ..
                    tostring(mode))
            end
            joins[#joins + 1] = {
                table = t,
                alias = alias,
                cond = cond,
                outer = outer,
            }
        end
    end

    -- The WHERE clause is stored as the condition of the first "join"
    if self:pop_if_equals("where") then
        joins[1] = joins[1] or {}
        assert(not joins[1].cond)
        joins[1].cond = self:parse_expr()
    end


    local group_by_exprs = {}
    if self:pop_if_equals("group") then
        self:expect("by")
        repeat
            group_by_exprs[#group_by_exprs + 1] = self:parse_expr()
        until not self:pop_if_equals(",")
    end

    -- Figure out if this is an aggregate query
    local aggregations = {}
    local group_by = {}
    -- Name of the generated variable that holds the aggregated values: the
    -- innermost group with GROUP BY, or the single "aggregate" table without
    local final_group_var = #group_by_exprs > 0 and format("group_%d",
        #group_by_exprs) or "aggregate"
    for _, tree in pairs(return_exprs) do
        separate_aggregate_funcs(self, tree, aggregations, final_group_var)
    end
    for i, expr in ipairs(group_by_exprs) do
        separate_aggregate_funcs(self, expr, aggregations, final_group_var)
        group_by[i] = self:expr_to_lua(expr)
    end
    local aggregate = #aggregations > 0 or #group_by > 0

    local having
    if self:pop_if_equals("having") then
        self:assert(aggregate, "HAVING only works on aggregate queries")

        local having_expr = self:parse_expr()
        separate_aggregate_funcs(self, having_expr, aggregations, final_group_var)
        having = self:expr_to_lua(having_expr)
    end

    local order_by
    if self:pop_if_equals("order") then
        self:expect("by")
        order_by = {}
        repeat
            local expr = self:expr_to_lua(self:parse_expr())
            local desc = self:pop_if_equals("desc")
            if not desc then
                self:pop_if_equals("asc")
            end
            order_by[#order_by + 1] = {expr = expr, desc = desc}
        until not self:pop_if_equals(",")
    end

    -- limit_code is the Lua expression for the limit, limit_num is only set
    -- for constant limits >= 1
    local limit_num, limit_code
    if self:pop_if_equals("limit") then
        local is_number
        limit_code, is_number = self:expr_to_lua(self:parse_expr(), {})

        -- Parse limit at compile time so that 1 can be handled separately
        if is_number then
            limit_num = assert(core.deserialize("return " .. limit_code))
            self:assert(limit_num % 1 == 0, "LIMIT value must be an integer")

            -- Treat LIMIT 0 as if it were a variable, anything better is
            -- probably not worth implementing
            if limit_num < 1 then
                limit_num = nil
            end
        end
    end

    -- {column_name = lua_code} for every returned column
    local return_exprs_lua = {}
    for name, expr in pairs(return_exprs) do
        return_exprs_lua[name] = self:expr_to_lua(expr)
    end

    -- Parse wildcards
    if wildcard then
        self:assert(#joins > 0, "You can't do SELECT * without any tables")
        for _, join in ipairs(joins) do
            for _, column in ipairs(self:list_columns(join.table)) do
                -- Columns joined with USING are only returned once
                if not self.using_cols or not self.using_cols[join.alias .. "/" .. column] then
                    wildcard = wildcard + 1
                    -- Check against return_exprs_lua (and not return_exprs)
                    -- as it contains both the explicitly listed columns and
                    -- the wildcard columns added so far. Otherwise, two
                    -- joined tables with a column of the same name would
                    -- silently overwrite each other.
                    self:assert(not return_exprs_lua[column], duplicate_err, column)
                    table.insert(returned_columns, wildcard, column)
                    return_exprs_lua[column] = self:get_col_ref(join.alias, column)
                end
            end
        end
    end


    -- Figure out what order to loop in (note that ordered_joins might be
    -- longer than joins if joins had to be split up)
    -- Must be before insert_var_refs so all the variables it needs can be
    -- defined
    local ordered_joins = self:get_optimal_join_order(joins)

    -- Lines of generated Lua code
    local code = {}

    -- Generate argument list
    code[#code + 1] = self:create_function_def()

    local aggregate_defaults
    if aggregate then
        -- Make COUNT(value) default to 0
        local d = {}
        for idx, a in pairs(aggregations) do
            if a.func == "count" then
                d[#d + 1] = format("[%d] = 0", idx)
            end

            -- Set of values that have already been seen
            if a.distinct then
                d[#d + 1] = format("distinct_%d = {}", idx)
            end
        end
        aggregate_defaults = table.concat(d, ", ")

        if #group_by > 0 then
            -- If GROUP BY is present: aggregate_groups[group_by_value]
            -- If whatever group_by returns is nil, we use aggregate_groups
            -- itself as a key (as to not conflict with anything else it could
            -- possibly return) to avoid Lua complaining about a nil key
            code[#code + 1] = "local aggregate_groups = {}"
        else
            -- If there's no GROUP BY, we use one table and fill it in with the
            -- COUNT() defaults
            code[#code + 1] = format("local aggregate = {%s}", aggregate_defaults)
        end
    end

    code[#code + 1] = "local res = {}"

    -- One table per ORDER BY clause mapping result rows to their sort key
    if order_by then
        for i = 1, #order_by do
            code[#code + 1] = format("local order_keys_%d = {}", i)
        end
    end

    self:insert_var_refs(code, true)
    -- join_ends is the number of "end"s needed to close what write_joins
    -- opened
    local join_ends = self:write_joins(code, ordered_joins)

    -- Aggregate clauses have two loops:
    -- The first one (below) builds a list of requested values per group, and
    -- the second one which adds these all to "res" and deals with HAVING etc
    if aggregate then
        self:assert(#joins > 0, "Cannot use aggregate query without any tables")

        -- If GROUP BY clauses are present, get an "aggregate" value
        local group = "aggregate"
        for i, clause in ipairs(group_by) do
            local prev_group = i > 1 and group or "aggregate_groups"
            group = format("group_%d", i)

            -- group_key is only declared once and then reused
            code[#code + 1] = format("%sgroup_key = %s", i > 1 and "" or "local ", clause)
            -- Reuse the aggregate_groups table as a sentinel value for nils
            code[#code + 1] = "if group_key == nil then group_key = aggregate_groups end"

            -- Look up the (sub)group, creating it if it doesn't exist yet.
            -- Only the innermost group stores aggregated values and
            -- therefore needs the defaults.
            code[#code + 1] = format("local %s = %s[group_key]", group, prev_group)
            code[#code + 1] = format("if %s == nil then", group)
            code[#code + 1] = format("%s = {%s}", group,
                i == #group_by and aggregate_defaults or "")

            code[#code + 1] = format("%s[group_key] = %s", prev_group, group)
            code[#code + 1] = "end"
        end

        -- Whether "local value" has already been declared in generated code
        local local_value = false
        for idx, a in pairs(aggregations) do
            local total = format("%s[%d]", group, idx)

            -- NULL columns are ignored (except for COUNT(*))
            if a.lua then
                code[#code + 1] = format("%svalue = %s", local_value and "" or "local ", a.lua)
                local_value = true

                if a.distinct then
                    code[#code + 1] = format(
                        "if value ~= nil and not %s.distinct_%d[value] then",
                        group, idx
                    )
                    code[#code + 1] = format(
                        "%s.distinct_%d[value] = true",
                        group, idx
                    )
                else
                    code[#code + 1] = "if value ~= nil then"
                end
            else
                assert(not a.distinct)
            end

            -- Substitute the accumulator into the aggregate function template
            local c = aggregate_funcs[a.func]:gsub("total", total)
            code[#code + 1] = format("%s = %s", total, c)

            if a.lua then
                code[#code + 1] = "end"
            end
        end

        -- Store the rowids of any row
        code[#code + 1] = format("if %s.rowid_%s == nil then", group, joins[1].alias)
        for _, join in pairs(joins) do
            code[#code + 1] = format("%s.rowid_%s = rowid_%s", group, join.alias, join.alias)
        end
        code[#code + 1] = "end"

        -- End giant loop
        for _ = 1, join_ends do
            code[#code + 1] = "end"
        end

        -- Convert the aggregate table into a list of rows to return
        for i = 1, #group_by do
            code[#code + 1] = format(
                "for _, group_%d in pairs(%s) do", i,
                i > 1 and "group_" .. i - 1 or "aggregate_groups"
            )
        end

        -- Get the rowids of whatever column we found (if any)
        for _, join in ipairs(joins) do
            code[#code + 1] = format(
                "local rowid_%s = %s.rowid_%s", join.alias, group, join.alias
            )
        end

        -- From here on join_ends counts the blocks opened by the second loop
        join_ends = #group_by
        if having then
            join_ends = join_ends + 1
            code[#code + 1] = format("if OPS.to_boolean(%s) then", having)
        end
    end

    -- The limit check must be first if it's a variable so that LIMIT 0 works
    if limit_code and not order_by and not limit_num then
        code[#code + 1] = format("if #res >= %s then return res end", limit_code)
    end

    -- Add the row to the result
    code[#code + 1] = "res[#res + 1] = {"
    for name, expr in pairs(return_exprs_lua) do
        code[#code + 1] = format("[%q] = %s,", name, expr)
    end
    code[#code + 1] = "}"

    -- Sort keys are stored per row (using the row table itself as the key)
    if order_by then
        for i, clause in ipairs(order_by) do
            code[#code + 1] = format("order_keys_%d[res[#res]] = %s", i, clause.expr)
        end
    end

    if distinct then
        -- SELECT DISTINCT, deduplicate returned rows
        -- The new row is compared against every previous row and removed
        -- again if all of its columns are equal
        code[#code + 1] = "local t = res[#res]"
        code[#code + 1] = "for i = 1, #res - 1 do"
        local eq_nil_checks = {}
        for _, name in ipairs(returned_columns) do
            eq_nil_checks[#eq_nil_checks + 1] = format("t[%q] == res[i][%q]",
                name, name)
        end
        code[#code + 1] = format("if %s then res[#res] = nil; break end",
            table.concat(eq_nil_checks, " and "))
        code[#code + 1] = "end"
    end

    -- LIMIT can't just stop early with order by
    -- (join_ends > 0 ensures that the generated "return" is followed by an
    -- "end", as return has to be the last statement in a Lua block)
    if limit_num and not order_by and join_ends > 0 then
        if limit_num == 1 then
            code[#code + 1] = "return res"
        else
            code[#code + 1] = format("if #res >= %d then return res end", limit_num)
        end
    end

    -- End of filling in res loop
    for _ = 1, join_ends do
        code[#code + 1] = "end"
    end

    if order_by then
        -- Compare by each ORDER BY clause in turn, only moving on to the next
        -- clause if the keys of the current one are equal
        code[#code + 1] = "table_sort(res, function(a, b)"
        for i, clause in ipairs(order_by) do
            local loc = i == 1 and "local " or ""
            code[#code + 1] = format("%ska = order_keys_%d[a]", loc, i)
            code[#code + 1] = format("%skb = order_keys_%d[b]", loc, i)
            if i < #order_by then
                code[#code + 1] = "if ka ~= kb then"
            end
            code[#code + 1] = format(
                "return (%s == nil and %s ~= nil) or OPS[%q](ka, kb)",
                -- Sort nulls as smaller than anything else
                clause.desc and "kb" or "ka",
                clause.desc and "ka" or "kb",

                clause.desc and ">" or "<"
            )
            if i < #order_by then
                code[#code + 1] = "end"
            end
        end
        code[#code + 1] = "end)"
    end

    -- Remove any extra rows if we had to fetch every single row
    if limit_code and order_by then
        code[#code + 1] = format("for i = %s + 1, #res do", limit_code)
        code[#code + 1] = "res[i] = nil"
        code[#code + 1] = "end"
    end

    code[#code + 1] = "return res"

    -- Closes the function opened by create_function_def
    code[#code + 1] = "end"

    return self:compile_lua(table.concat(code, "\n")), returned_columns
end
