-- Join parser
-- This deals with join objects which look like this:
-- {
--     table = "mytable",
--     alias = "m",
--     cond = <expr>,
--     outer = true, -- For columns that should return NULL if no matching rows found
-- }
-- Join objects may have "table"+"alias", "cond", or both.
-- "outer" is used for left joins

-- Local aliases for the formatting function, the internal squill table and
-- the Parser class that the methods in this file are attached to
local format = string.format
local sq = squill._internal
local Parser = sq.Parser

-- Statement that looks up the type of a single column in the "schema" table
-- Parameters: db_name, table, column. Requested as a single value.
-- NOTE(review): the exact behaviour of sq.bootstrap_statement is not visible
-- in this file, presumably it prepares the query against the internal schema
local get_column_type = sq.bootstrap_statement([[
    SELECT type FROM schema
    WHERE db_name = ? AND table = ? AND column = ?
    LIMIT 1
]], squill.RETURN_SINGLE_VALUE)

-- Statement that lists the unique_id of every "uniques" entry that the given
-- column takes part in
-- Parameters: db_name, table, column. Requested as a single column (this
-- file iterates over the result with ipairs).
local get_column_uniques = sq.bootstrap_statement([[
    SELECT unique_id FROM uniques u
    JOIN schema s ON u.column_id = s.id
    WHERE db_name = ? AND table = ? AND column = ?
]], squill.RETURN_SINGLE_COLUMN)

-- Statement that counts how many columns of a table share a unique_id
-- Parameters: db_name, table, unique_id. Requested as a single value.
-- A count of 1 is treated by is_column_indexable as "this column has a
-- unique constraint to itself".
local count_unique_columns = sq.bootstrap_statement([[
    SELECT COUNT(*) FROM uniques u
    JOIN schema s ON u.column_id = s.id
    WHERE db_name = ? AND table = ? AND unique_id = ?
]], squill.RETURN_SINGLE_VALUE)

-- Checks whether sq.get_unique_index may be used for a column.
-- Returns true only if the column is not a BLOBS column and at least one of
-- the unique constraints it belongs to covers that column alone.
-- Always returns false while bootstrapping.
local function is_column_indexable(db_name, table_name, col_name)
    -- Nothing is indexable during bootstrap (hardcoded)
    if sq.bootstrap then
        return false
    end

    -- BLOBS exist to keep values out of memory, building a unique index on
    -- them with sq.get_unique_index would pull them all back in
    local col_type = get_column_type(db_name, table_name, col_name)
    if col_type == sq.BLOBS then
        return false
    end

    -- Look for a unique constraint that consists of just this column
    local indexable = false
    local unique_ids = get_column_uniques(db_name, table_name, col_name)
    for _, id in ipairs(unique_ids) do
        local columns_in_unique = count_unique_columns(db_name, table_name, id)
        if columns_in_unique == 1 then
            indexable = true
            break
        end
    end
    return indexable
end

-- Returns true if every table alias listed in join.depends (other than the
-- join's own alias) is present in known_tables.
-- Joins without a depends table are always satisfied.
local function are_depends_satisfied(join, known_tables)
    local depends = join.depends
    if not depends then
        return true
    end

    local own_alias = join.alias
    for alias in pairs(depends) do
        -- A join is allowed to reference its own table
        local is_self = alias == own_alias
        if not is_self and not known_tables[alias] then
            return false
        end
    end
    return true
end

-- Returns the name of the column that `node` references on the table joined
-- by `join`, or nil if the node is not such a reference.
-- Handles both qualified ("alias.column") and bare ("column") references.
local function get_join_column(self, join, node)
    if node.type == sq.AST_OP and node.op == "." and
            node[1].type == sq.AST_VAR and node[1].name == join.alias then
        return node[2].name
    end

    if node.type == sq.AST_VAR and not sq.special_vars[node.name] and
            self:resolve_bare_identifier(node.name) == join.alias then
        return node.name
    end

    return nil
end

-- Checks whether a single condition can be turned into a unique index lookup
-- for the given join.
-- The condition must be an equality ("=" or "==") where one side references
-- an indexable column of the joined table. On success join.index_column is
-- set to that column, join.index_key to the other side of the comparison,
-- and true is returned. Otherwise the join is left untouched and false is
-- returned.
local function check_single_join_indexable(self, join, cond)
    local is_equality = cond.type == sq.AST_OP and
        (cond.op == "=" or cond.op == "==")
    if not is_equality then
        return false
    end

    assert(cond.argc == 2)
    for side = 1, 2 do
        local col_name = get_join_column(self, join, cond[side])
        if col_name then
            -- Only the first side that references this join is considered
            if not is_column_indexable(self.db_name, join.table, col_name) then
                return false
            end

            -- Store the column and key alongside the existing data (nothing
            -- is overwritten) so that a worse join can still be used as a
            -- fallback if there's a circular join loop or something
            join.index_column = col_name
            join.index_key = cond[3 - side]
            return true
        end
    end

    return false
end

-- Searches the list of separated conditions for one that can be used as a
-- unique index lookup for `join`.
-- The first matching condition is stored in join.cond (with
-- join.index_column and join.index_key filled in by
-- check_single_join_indexable) and removed from `conditions`. If there is no
-- match, neither the join nor the list is modified.
-- The join must not already have an index column.
local function check_join_indexable(self, join, conditions)
    assert(not join.index_column)

    for cond_idx, cond in ipairs(conditions) do
        if check_single_join_indexable(self, join, cond) then
            -- Associate the condition with this join so that it
            -- can be skipped when indexing
            join.cond = cond
            table.remove(conditions, cond_idx)

            -- Stop at the first match: a join can only hold one index
            -- condition, so claiming a second one would overwrite (and
            -- silently drop) the first. Returning here also avoids iterating
            -- over the list after it has been modified.
            return
        end
    end
end

-- Returns the name ("unique_<n>") under which the unique index for the given
-- table and column is referenced, registering a new entry in
-- self.unique_index_refs if that table/column pair has not been seen yet.
-- known_unique: pass a true value to skip the uniqueness lookup when the
-- caller already knows the column is UNIQUE.
function Parser:get_unique_index_ref(table_name, col_name, known_unique)
    local refs = self.unique_index_refs

    -- Reuse the existing reference if there is one
    for idx, existing in ipairs(refs) do
        local same_table = existing.table == table_name
        if same_table and existing.col_name == col_name then
            return format("unique_%d", idx)
        end
    end

    local new_idx = #refs + 1

    -- Only UNIQUE columns get cached, this may get called by
    -- update_row.lua for non-UNIQUE columns (and caching those is
    -- dangerous since the cache would not be updated)
    -- Hopefully this can be removed if/when ON DELETE CASCADE gets
    -- implemented (and the unique code therefore won't use this)
    local cacheable = known_unique or
        is_column_indexable(self.db_name, table_name, col_name)

    refs[new_idx] = {
        table = table_name,
        col_name = col_name,
        allow_cache = cacheable,
    }
    return format("unique_%d", new_idx)
end

-- Splits a condition into the operands of its (possibly nested) "and"
-- operators and appends them, in left-to-right order, to `conditions`.
-- A condition that is not an "and" is appended unchanged.
local function atomise_join_condition(conditions, cond)
    local is_conjunction = cond.type == sq.AST_OP and cond.op == "and"
    if not is_conjunction then
        table.insert(conditions, cond)
        return
    end

    -- Recurse so that nested ANDs are flattened as well
    for arg_idx = 1, cond.argc do
        atomise_join_condition(conditions, cond[arg_idx])
    end
end

-- Breaks the parsed joins up into smaller pieces that can be reordered.
-- Returns a new list of join objects where:
--  * OUTER joins are kept as they are (table and condition together),
--  * INNER joins are split into a bare table join plus separate
--    condition-only joins (one per operand of "and"),
--  * joins that can use a unique index have index_column/index_key set,
--  * every join with a condition has `depends` filled in.
-- Raises a parser error (via self:error) if the ON clause of an OUTER join
-- references a table that is joined to the right of it.
local function atomise_joins(self, all_joins)
    local atomised = {}
    local loose_conds = {}

    -- Phase 1: split the joins up
    for pos, src in ipairs(all_joins) do
        if not src.outer then
            -- INNER JOIN: the table and its conditions are independent
            if src.alias then
                atomised[#atomised + 1] = {
                    table = src.table,
                    alias = src.alias,
                }
            end
            if src.cond then
                atomise_join_condition(loose_conds, src.cond)
            end
        else
            -- LEFT JOINs depend on their condition so must stay in one piece
            atomised[#atomised + 1] = src

            -- References to tables further right are only rejected for
            -- OUTER JOINs because SQLite allows them for inner joins
            if src.cond then
                -- This can't be stored as the join's depends yet because the
                -- index check below may change which expression is used
                local accessed = self:list_accessed_tables(src.cond)
                for later = pos + 1, #all_joins do
                    local later_alias = all_joins[later].alias
                    if accessed[later_alias] then
                        self:error(format(
                            "ON clause references table to the right of itself: %q",
                            later_alias
                        ))
                    end
                end
            end
        end
    end

    -- Phase 2: look for indexes on every join that has a table
    for _, join in ipairs(atomised) do
        if join.alias and join.cond then
            check_single_join_indexable(self, join, join.cond)
        elseif join.alias then
            assert(not join.outer)
            check_join_indexable(self, join, loose_conds)
        end
    end

    -- Phase 3: conditions that weren't claimed by check_join_indexable become
    -- condition-only joins. They are appended last-to-first to account for
    -- the backwards iteration done later on.
    local cond_idx = #loose_conds
    while cond_idx >= 1 do
        table.insert(atomised, {cond = loose_conds[cond_idx]})
        cond_idx = cond_idx - 1
    end

    -- Phase 4: work out which tables each condition needs
    for _, join in ipairs(atomised) do
        if join.cond then
            local depends = self:list_accessed_tables(
                join.index_key or join.cond
            )
            join.depends = depends

            -- An index lookup whose key references the join's own table
            -- can't work, fall back to a normal condition
            if depends[join.alias] and join.index_column then
                join.index_column = nil
                join.index_key = nil
            end
        end
    end

    return atomised
end

-- Works out the order in which joins should be executed.
-- The joins are atomised and then added in rounds: each round takes every
-- join whose dependencies are already known, placing indexable joins first,
-- then joins with conditions, then bare loops.
-- Returns the ordered list of join objects, with index_key_lua and
-- unique_index_ref (indexed joins) or cond_lua (other joins with a
-- condition) filled in.
function Parser:get_optimal_join_order(all_joins)
    local pending = atomise_joins(self, all_joins)

    -- Try and be smart and add joins that we can eliminate before anything else
    local known_tables, order = {}, {}

    while #pending > 0 do
        -- Joins to add in this round, laid out as
        -- [indexed joins][condition joins][bare loops]
        local batch = {}
        local indexed_count, cond_count = 0, 0

        for i = #pending, 1, -1 do
            local join = pending[i]
            -- TODO: Add bare loops as late as possible
            if are_depends_satisfied(join, known_tables) then
                table.remove(pending, i)

                if join.index_column then
                    -- Indexable joins go first because they are faster
                    indexed_count = indexed_count + 1
                    table.insert(batch, indexed_count, join)
                elseif join.cond then
                    -- Joins with conditions are next so that the condition
                    -- gets checked less frequently
                    cond_count = cond_count + 1
                    table.insert(batch, indexed_count + cond_count, join)
                else
                    -- Condition-less joins go last
                    table.insert(batch, join)
                end
            end
        end

        if #batch == 0 then
            -- Nothing could be added, so there's a circular JOIN loop.
            -- Fall back to adding all loops and then all conditions
            -- (ignoring indexes)
            -- I'm not sure if there's a real use case for this
            for _, join in ipairs(pending) do
                assert(not join.outer)
                order[#order + 1] = {table = join.table, alias = join.alias}
            end
            for _, join in ipairs(pending) do
                order[#order + 1] = {cond = join.cond}
            end
            break
        end

        -- Append the round to the final order and mark its tables as known
        -- for the next round
        for _, join in ipairs(batch) do
            if join.alias then
                known_tables[join.alias] = true
            end
            order[#order + 1] = join
        end
    end

    -- Convert to lua for later use (stores what columns are needed)
    for _, join in ipairs(order) do
        local key = join.index_key
        if key then
            join.index_key_lua = self:expr_to_lua(key)
            join.unique_index_ref = self:get_unique_index_ref(
                join.table, join.index_column, true
            )
        elseif join.cond then
            join.cond_lua = self:expr_to_lua(join.cond)
        end
    end

    return order
end

-- Appends the generated Lua code for the given (ordered) joins to `code`.
-- Indexed joins become a lookup in the unique index, other joins with a
-- table become a for loop over its rows, and conditions become if blocks.
-- Returns the number of blocks that were opened, i.e. how many "end"s the
-- caller has to write to close them.
function Parser:write_joins(code, joins)
    local open_blocks = 0
    local function emit(line)
        code[#code + 1] = line
    end

    for _, join in ipairs(joins) do
        local alias = join.alias

        if join.index_column then
            -- Use indexes if possible
            emit(format("local rowid_%s = %s[%s]", alias,
                join.unique_index_ref, join.index_key_lua))

            -- LEFT JOINs don't need the nil check (table[nil] == nil)
            if not join.outer then
                emit(format("if rowid_%s then", alias))
                open_blocks = open_blocks + 1
            end
        else
            -- Otherwise use loops
            if alias then
                -- For LEFT JOINs: An extra iteration is done to add a NULL
                -- row if no rows were found.
                -- (Note: This iteration actually keeps the rowid as a number,
                -- but since it's greater than the length it must be nil)
                if join.outer then
                    emit(format("local null_%s = true", alias))
                end

                emit(format(
                    "for rowid_%s = 1, %s%s do",
                    alias,
                    assert(self.lengths[alias]),
                    join.outer and " + 1" or ""
                ))
                open_blocks = open_blocks + 1
            end

            if join.cond_lua then
                -- LEFT JOINs let the extra iteration through while no
                -- matching row has been seen
                local null_check = ""
                if join.outer then
                    null_check = format(
                        "(null_%s and rowid_%s > %s) or ",
                        alias, alias,
                        assert(self.lengths[alias])
                    )
                end

                emit(format("if %sOPS.to_boolean(%s) then",
                    null_check, join.cond_lua))
                open_blocks = open_blocks + 1

                if join.outer then
                    emit(format("null_%s = false", alias))
                end
            end
        end
    end

    return open_blocks
end
