local format = string.format
local sq = squill._internal
local Parser = sq.Parser

-- Prepared statement: returns the declared type of a column in the schema
-- table as a single value (presumably nil when the column is not listed —
-- TODO confirm against squill.RETURN_SINGLE_VALUE).
-- Arguments: db_name, table name, column name.
local get_column_type = sq.bootstrap_statement([[
    SELECT type FROM schema
    WHERE db_name = ? AND table = ? AND column = ?
    LIMIT 1
]], squill.RETURN_SINGLE_VALUE)

-- Prepared statement: returns every unique_id from the uniques table that
-- is linked (via uniques.column_id -> schema.id) to the given column, i.e.
-- the UNIQUE groups this column takes part in, as a single column of values.
-- Arguments: db_name, table name, column name.
local get_column_uniques = sq.bootstrap_statement([[
    SELECT unique_id FROM uniques u
    JOIN schema s ON u.column_id = s.id
    WHERE db_name = ? AND table = ? AND column = ?
]], squill.RETURN_SINGLE_COLUMN)

-- Prepared statement: counts how many columns of the given table share the
-- given unique_id (a result of 1 means the UNIQUE group covers a single
-- column). Arguments: db_name, table name, unique_id.
local count_unique_columns = sq.bootstrap_statement([[
    SELECT COUNT(*) FROM uniques u
    JOIN schema s ON u.column_id = s.id
    WHERE db_name = ? AND table = ? AND unique_id = ?
]], squill.RETURN_SINGLE_VALUE)

-- Reports whether sq.get_unique_index may be used to look rows up by the
-- given column. This requires the column to not be a BLOBS column and to be
-- the sole member of at least one UNIQUE group. Always false while
-- bootstrapping.
local function is_column_indexable(db_name, table_name, col_name)
    -- Hardcoded to false during bootstrap
    if sq.bootstrap then
        return false
    end

    -- BLOBS columns exist to keep their values out of memory; building a
    -- unique index over them would load those values anyway
    local col_type = get_column_type(db_name, table_name, col_name)
    if col_type == sq.BLOBS then
        return false
    end

    -- Look for a UNIQUE group made up of this column alone
    local uniques = get_column_uniques(db_name, table_name, col_name)
    for _, group_id in ipairs(uniques) do
        local member_count = count_unique_columns(db_name, table_name, group_id)
        if member_count == 1 then
            return true
        end
    end

    return false
end

-- Returns true when every table alias the join depends on (join.depends,
-- a set keyed by alias) is already present in known_tables. The join's own
-- alias is ignored, and a join without a depends set is always satisfied.
local function are_depends_satisfied(join, known_tables)
    local needed = join.depends
    if not needed then
        return true
    end

    for dep_alias in pairs(needed) do
        local is_self = dep_alias == join.alias
        if not is_self and not known_tables[dep_alias] then
            return false
        end
    end

    return true
end

-- Checks whether join.cond has the form `<column> = <key>` (or `==`), where
-- <column> is a column of the joined table and the column can be looked up
-- with sq.get_unique_index. If so, stores the column name in
-- join.index_column and the other operand (an AST node) in join.index_key.
-- join.cond itself is left untouched so the caller can still fall back to a
-- plain loop + condition if the index ends up unusable.
local function check_join_indexable(self, join)
    assert(not join.index_column)

    -- An indexed join stores its lookup result in rowid_<alias> (see
    -- write_joins), so joins without an alias can never use an index. This
    -- also avoids comparing resolved identifiers against a nil alias and
    -- querying the schema with a nil table name.
    if not join.alias then
        return
    end

    local cond = join.cond
    -- TODO: Maybe also optimise indexable_col = 1 AND <other cond>?
    if cond.type ~= sq.AST_OP or (cond.op ~= "=" and cond.op ~= "==") then
        return
    end

    assert(cond.argc == 2)
    for i = 1, 2 do
        local node = cond[i]

        -- Is this operand a column of the joined table, written either as
        -- <alias>.<column> or as a bare name resolving to this join's alias?
        local column
        if node.type == sq.AST_OP and node.op == "." and
                node[1].type == sq.AST_VAR and node[1].name == join.alias then
            column = node[2].name
        elseif node.type == sq.AST_VAR and not sq.special_vars[node.name] and
                self:resolve_bare_identifier(node.name) == join.alias then
            column = node.name
        end

        if column and is_column_indexable(self.db_name, join.table, column) then
            -- Record the column and the opposite operand as the lookup key.
            -- join.cond is deliberately kept so that a plain loop can still
            -- be generated (e.g. for a circular join loop).
            join.index_column = column
            join.index_key = cond[3 - i]
            return
        end
    end
end

-- Returns the name ("unique_<n>") of the generated variable holding the
-- unique index for table_name.col_name. A column that has already been
-- referenced reuses its existing entry in self.unique_index_refs; otherwise
-- a new entry is appended.
function Parser:get_unique_index_ref(table_name, col_name, known_unique)
    local refs = self.unique_index_refs

    local ref_idx
    for idx, existing in ipairs(refs) do
        if existing.table == table_name and existing.col_name == col_name then
            ref_idx = idx
            break
        end
    end

    if not ref_idx then
        ref_idx = #refs + 1

        -- Only UNIQUE columns get cached, this may get called by
        -- update_row.lua for non-UNIQUE columns (and caching those is
        -- dangerous since the cache would not be updated)
        -- Hopefully this can be removed if/when ON DELETE CASCADE gets
        -- implemented (and the unique code therefore won't use this)
        local cacheable = known_unique or
            is_column_indexable(self.db_name, table_name, col_name)

        refs[ref_idx] = {
            table = table_name,
            col_name = col_name,
            allow_cache = cacheable,
        }
    end

    return format("unique_%d", ref_idx)
end

-- Returns the order in which the joins should be written out.
--
-- Joins are scheduled in rounds. Each round takes every remaining join whose
-- dependencies are already known. Within the round it puts indexable joins
-- first, then joins with conditions, then condition-less joins, and marks
-- their aliases as known. If a round can schedule nothing (a circular JOIN
-- loop), every remaining table is looped over first and all remaining
-- conditions are checked afterwards, without using indexes.
--
-- Scheduled joins also get index_key_lua + unique_index_ref (indexed joins)
-- or cond_lua (conditional joins) filled in for write_joins.
function Parser:get_optimal_join_order(all_joins)
    -- Copy the joins into a work list and work out indexes and dependencies
    local pending = {}
    for idx, join in ipairs(all_joins) do
        pending[idx] = join

        if join.cond then
            check_join_indexable(self, join)
            join.depends = self:list_accessed_tables(join.index_key or join.cond)

            -- A lookup key that references the joined table itself can't be
            -- evaluated before that table's row is known, so drop the index
            if join.index_column and join.depends[join.alias] then
                join.index_column = nil
                join.index_key = nil
            end
        end
    end

    local known_tables = {}
    local order = {}

    while #pending > 0 do
        -- Collect this round's joins into three groups (scanning backwards
        -- so table.remove doesn't disturb the indices still to be visited)
        local indexed, conditional, plain = {}, {}, {}
        for idx = #pending, 1, -1 do
            local join = pending[idx]
            if are_depends_satisfied(join, known_tables) then
                table.remove(pending, idx)

                local group
                if join.index_column then
                    group = indexed      -- fastest, so first
                elseif join.cond then
                    group = conditional  -- check conditions as early as possible
                else
                    group = plain        -- condition-less loops last
                end
                group[#group + 1] = join
            end
        end

        if #indexed + #conditional + #plain == 0 then
            -- Circular JOIN loop: add every loop, then every condition, and
            -- ignore indexes entirely
            for _, join in ipairs(pending) do
                order[#order + 1] = {table = join.table, alias = join.alias}
            end
            for _, join in ipairs(pending) do
                order[#order + 1] = {cond = join.cond}
            end
            break
        end

        -- Append the round to the final order and mark its tables as known
        for _, group in ipairs({indexed, conditional, plain}) do
            for _, join in ipairs(group) do
                if join.alias then
                    known_tables[join.alias] = true
                end
                order[#order + 1] = join
            end
        end
    end

    -- Pre-compile keys/conditions to Lua (this also records which columns
    -- the generated code needs)
    for _, entry in ipairs(order) do
        if entry.index_key then
            entry.index_key_lua = self:expr_to_lua(entry.index_key)
            entry.unique_index_ref = self:get_unique_index_ref(entry.table,
                entry.index_column, true)
        elseif entry.cond then
            entry.cond_lua = self:expr_to_lua(entry.cond)
        end
    end

    return order
end

-- Appends the Lua source lines that iterate over the given (ordered) joins
-- to the `code` array. Returns the number of `if`/`for` blocks opened; each
-- one still needs a matching `end`.
function Parser:write_joins(code, joins)
    local open_blocks = 0

    local function emit(line)
        code[#code + 1] = line
    end

    for _, join in ipairs(joins) do
        if not join.index_column then
            -- No usable index: loop over every row of the table (if this
            -- join has one), then filter on the condition (if any)
            if join.alias then
                emit(format("for rowid_%s = 1, %s do", join.alias,
                    assert(self.lengths[join.alias])))
                open_blocks = open_blocks + 1
            end

            if join.cond_lua then
                emit(format("if OPS.to_boolean(%s) then", join.cond_lua))
                open_blocks = open_blocks + 1
            end
        else
            -- Indexed join: a single lookup replaces the whole loop
            local alias = join.alias
            emit(format("local rowid_%s = %s[%s]", alias,
                join.unique_index_ref, join.index_key_lua))
            emit(format("if rowid_%s then", alias))
            open_blocks = open_blocks + 1
        end
    end

    return open_blocks
end
