-- Shorthand aliases used throughout this file.
local sq = squill._internal
local Parser = sq.Parser
local format = string.format

-- Looks up the schema entry of a single column.
-- Parameters: db_name, table, column.
-- The result is used as a single row table (fields type, not_null,
-- autoincrement, check, id); callers in this file treat a falsy result as
-- "no such column".
local get_checks_stmt = sq.bootstrap_statement([[
    SELECT type, not_null, autoincrement, check, id FROM schema
    WHERE db_name = ? AND table = ? AND column = ?
    LIMIT 1
]], squill.RETURN_FIRST_ROW)

-- Lists the CHECK expressions declared on every *other* column of a table.
-- Parameters: db_name, table, column to exclude.
-- The result is iterated as a list of rows with `column` and `check` fields.
local get_other_checks_stmt = sq.bootstrap_statement([[
    SELECT column, check FROM schema
    WHERE db_name = ? AND table = ? AND column <> ? AND check IS NOT NULL
]])

-- Lists the ids of the UNIQUE constraints that a column takes part in.
-- Parameter: the column's schema id.
-- The result is iterated as a plain list of unique_id values.
local get_unique_ids = sq.bootstrap_statement([[
    SELECT unique_id FROM uniques
    WHERE column_id = ?
]], squill.RETURN_SINGLE_COLUMN)

-- Lists the names of all columns that belong to one UNIQUE constraint.
-- Parameters: db_name, table, unique_id.
-- The result is used as a plain list of column names.
local get_unique_columns = sq.bootstrap_statement([[
    SELECT column FROM uniques u
    JOIN schema s ON u.column_id = s.id
    WHERE db_name = ? AND table = ? and unique_id = ?
]], squill.RETURN_SINGLE_COLUMN)

-- Counts the schema entries of a table; update_row compares the result with 0
-- to tell "missing column" apart from "missing table".
-- Parameters: db_name, table.
local table_exists_stmt = sq.bootstrap_statement([[
    SELECT COUNT(*) FROM schema WHERE db_name = ? AND table = ?
]], squill.RETURN_SINGLE_VALUE)

-- Finds the foreign keys that involve a column, either as the child
-- (k2.child_id = $1) or as the referenced parent table/column.
-- Parameters: $1 = the column's schema id, $2 = db_name, $3 = table name,
-- $4 = column name.
-- Each result row carries key_id, child_table, child_column, parent_table,
-- parent_column and is_parent. is_parent is computed as `k2.child_id <> $1`,
-- i.e. it is set when the row was matched through the parent side.
-- I am starting to regret using Squill to implement itself
local get_foreign_key_cols = sq.bootstrap_statement([[
    SELECT
        k1.key_id, child.table AS child_table, child.column AS child_column,
        k1.parent_table, k1.parent_column,
        k2.child_id <> $1 AS is_parent
    FROM foreign_keys k1
    JOIN foreign_keys k2 ON k1.key_id = k2.key_id
    JOIN schema child ON k1.child_id = child.id
    WHERE k2.child_id = $1 OR (
        child.db_name = $2 AND k2.parent_table = $3 AND k2.parent_column = $4
    )
]])

-- Indexed by db_name. While an entry is truthy, handle_foreign_keys emits no
-- FOREIGN KEY checks for that database. Starts out empty.
sq.foreign_keys_disabled = {}

--- Compiles one CHECK constraint into generated Lua and appends it to `checks`.
--
-- @param self Parser instance; its db_name and expr_to_lua are used
-- @param checks array of generated Lua statements, appended to in place
-- @param table_name name of the table the constraint belongs to
-- @param col_name name of the column the constraint is declared on (only
--     used in the violation message)
-- @param check source text of the CHECK expression
-- @param must_mention optional name; when given, nothing is emitted unless
--     sq.expr_mentions_variable reports that the expression mentions it
-- Raises an error if `check` contains anything after the first expression.
local function add_check(self, checks, table_name, col_name, check, must_mention)
    local p = Parser.new(self.db_name, check)
    local expr = p:parse_expr()

    -- The whole text must be consumed by a single expression. Report which
    -- constraint is malformed instead of a bare "assertion failed!".
    if p:peek() then
        error(format(
            "Unexpected trailing tokens in CHECK expression for %s.%s",
            table_name, col_name
        ))
    end

    if must_mention and not sq.expr_mentions_variable(expr, must_mention) then
        return
    end

    -- The generated statement raises when OPS['unary not'] of the compiled
    -- expression is truthy.
    -- NOTE(review): "contstraint" is misspelled, but the text is kept as-is
    -- because callers may match on it; fix it together with the NOT NULL and
    -- UNIQUE messages in update_row.
    checks[#checks + 1] = format(
        "if OPS['unary not'](%s) then error(%q) end",
        self:expr_to_lua(expr, {[table_name] = table_name}),
        format("CHECK contstraint violated in %s.%s", table_name, col_name)
    )
end

--- Appends generated Lua enforcing single-column FOREIGN KEY constraints.
--
-- Does nothing while sq.foreign_keys_disabled[self.db_name] is truthy.
--
-- @param self Parser instance; its db_name and get_unique_index_ref are used
-- @param code array of generated Lua statements, appended to in place
-- @param row schema row of the column being written; only not_null is read
-- @param foreign_key_cols rows returned by get_foreign_key_cols
-- @param old_value Lua source expression that is used to index the child
--     column's unique index when this column is the parent side of a key
-- @param delete truthy when a row is being deleted; child-side checks are
--     then skipped
-- Raises an error if a key spans more than one column.
local function handle_foreign_keys(self, code, row, foreign_key_cols, old_value, delete)
    if sq.foreign_keys_disabled[self.db_name] then
        return
    end

    -- TODO: Verify that constraints are unique
    -- Group the rows by key_id. The order in which keys are first seen is
    -- recorded so that checks are emitted in query order; iterating the
    -- grouping table with pairs() would make the generated code's statement
    -- order unspecified.
    local keys_by_id, key_order = {}, {}
    for _, col in ipairs(foreign_key_cols) do
        local cols = keys_by_id[col.key_id]
        if not cols then
            cols = {}
            keys_by_id[col.key_id] = cols
            key_order[#key_order + 1] = cols
        end
        cols[#cols + 1] = col
    end

    for _, cols in ipairs(key_order) do
        assert(#cols == 1, "Multi-column FOREIGN KEY constraints are not implemented")

        local col = cols[1]
        if col.is_parent then
            -- This column is referenced by a child column: fail if the
            -- child's index still has an entry for the old value.
            code[#code + 1] = format(
                "if %s[%s] then error('FOREIGN KEY constraint failed') end",
                self:get_unique_index_ref(col.child_table, col.child_column),
                old_value
            )
        elseif not delete then
            -- This column references a parent: the generated local `value`
            -- (declared by update_row) must be present in the parent's index.
            -- When the column is nullable, a nil value skips the check.
            code[#code + 1] = format(
                "if%s not %s[value] then error('FOREIGN KEY constraint failed') end",
                row.not_null and "" or " value ~= nil and",
                self:get_unique_index_ref(col.parent_table, col.parent_column)
            )
        end
    end
end

--- Emits the FOREIGN KEY checks that have to run before rows are deleted.
-- For delete.lua.
--
-- @param code array of generated Lua statements, appended to in place
-- @param table_name name of the table rows are deleted from
-- @param refs table mapping column names to reference tables; each one must
--     have a truthy `row` field, which is handed to handle_foreign_keys as
--     the old value expression
-- Raises an error if a column has no schema entry or a reference lacks `row`.
function Parser:add_foreign_key_pre_delete_checks(code, table_name, refs)
    for col_name, ref in pairs(refs) do
        local db_name = self.db_name
        local schema_row = assert(get_checks_stmt(db_name, table_name, col_name))
        local fk_cols = get_foreign_key_cols(
            schema_row.id, db_name, table_name, col_name
        )
        local old_value_ref = assert(ref.row)

        -- Final argument: this is a deletion, so only parent-side checks
        -- are emitted.
        handle_foreign_keys(self, code, schema_row, fk_cols, old_value_ref, true)
    end
end

--- Generates the Lua code that writes one column of one row and enforces the
-- column's constraints (type coercion, CHECK, NOT NULL, autoincrement
-- tracking, FOREIGN KEY and UNIQUE).
--
-- @param code array of generated Lua statements. The assignment and most
--     constraint checks are appended here; index cleanup for existing rows is
--     inserted at the position `code` had when this call started.
-- @param checks array of generated Lua statements that receives CHECK
--     constraints and the scan loop for multi-column UNIQUE constraints
-- @param table_name name of the table being written
-- @param col_name name of the column being written
-- @param expr Lua source expression producing the new value
-- @param target_var Lua source expression for the cell being assigned; it is
--     also read before the assignment as the old value
-- @param entire_column_ref Lua source expression for the whole column; its
--     `.length` bounds the multi-column UNIQUE scan
-- @param get_all_checks truthy for UPDATE: also emit CHECK constraints of
--     other columns that mention this column
-- @param checked_uniques table keyed by unique_id; entries are set to true
--     here so that each UNIQUE constraint is only emitted once
-- @param already_autoincremented truthy to suppress the autoincrement_set_min
--     call for autoincrement columns
-- @param on_unique_violation optional Lua statement (as source text) to run
--     when a UNIQUE constraint is violated; defaults to raising an error
-- @param index_updates optional array that receives statements storing the
--     new value in the column's unique index
-- @param row_is_not_new truthy when the row already exists, in which case its
--     old unique index entry is cleared before the assignment
-- Raises an error if the column or the table does not exist.
function Parser:update_row(code, checks, table_name, col_name, expr,
        target_var, entire_column_ref, get_all_checks, checked_uniques,
        already_autoincremented, on_unique_violation, index_updates,
        row_is_not_new)
    local row = get_checks_stmt(self.db_name, table_name, col_name)
    if not row then
        -- Work out whether the column or the whole table is missing so that
        -- the error message is accurate
        if table_exists_stmt(self.db_name, table_name) > 0 then
            error(format("Non-existent column: %q", col_name))
        else
            error(format("The table %q does not exist", table_name))
        end
    end

    -- Check value types to avoid data corruption crashes
    local qualname = format("%s.%s", table_name, col_name)
    expr = format("coerce_or_error[%q](%s, %q)", row.type, expr, qualname)

    -- The column's own CHECK constraint is always emitted
    if row.check then
        add_check(self, checks, table_name, col_name, row.check)
    end

    -- If get_all_checks is set (i.e. this is an UPDATE), also add checks from
    -- other columns that reference this one
    if get_all_checks then
        for _, c in ipairs(get_other_checks_stmt(self.db_name, table_name, col_name)) do
            add_check(self, checks, table_name, c.column, c.check, col_name)
        end
    end

    local autoincrement = row.autoincrement and not already_autoincremented

    -- Remember where this call's statements start: code that must run before
    -- the assignment (clearing old index entries) is inserted here later on
    local pre_update_idx = #code + 1
    local foreign_key_cols = get_foreign_key_cols(row.id, self.db_name, table_name, col_name)
    if row.not_null or autoincrement or #foreign_key_cols > 0 then
        -- The new value is needed more than once, so evaluate it into a
        -- generated local called "value" first
        code[#code + 1] = format("local value = %s", expr)
        if row.not_null then
            -- NOTE(review): "contstraint" is misspelled in these messages
            code[#code + 1] = format("if value == nil then error(%q) end",
                format("NOT NULL contstraint violated in %s", qualname))
        end

        if autoincrement then
            code[#code + 1] = format(
                "autoincrement_set_min(%q, %q, %q, value)",
                self.db_name, table_name, col_name
            )
        end

        -- Emitted before the assignment below, so target_var still refers to
        -- the old value when the foreign key checks run
        handle_foreign_keys(self, code, row, foreign_key_cols, target_var)

        code[#code + 1] = format("%s = value", target_var)
    else
        -- No constraint needs the value twice: assign the expression directly
        code[#code + 1] = format("%s = %s", target_var, expr)
    end

    -- Fall back to raising an error when the caller did not supply its own
    -- handler for UNIQUE violations
    local custom_violation_handler = on_unique_violation ~= nil
    on_unique_violation = on_unique_violation or format(
        "error(%q)",
        format("UNIQUE contstraint violated in %s", qualname)
    )

    local unique_ids = get_unique_ids(row.id)
    -- Column lists of the multi-column UNIQUE constraints found below
    local complicated_uniques = {}
    for _, unique_id in ipairs(unique_ids) do
        -- Skip constraints that have already been handled (checked_uniques is
        -- supplied by the caller)
        if not checked_uniques[unique_id] then
            checked_uniques[unique_id] = true
            local cols = get_unique_columns(self.db_name, table_name, unique_id)
            if #cols == 1 then
                assert(cols[1] == col_name)

                -- Use get_unique_index for fast lookups (if the index is still
                -- in RAM)
                -- This runs after the assignment, so target_var is the new
                -- value here
                local unique_index = format(
                    "%s[%s]",
                    self:get_unique_index_ref(table_name, col_name),
                    target_var
                )
                if custom_violation_handler then
                    -- Custom UNIQUE violation handlers expect "i" to be the
                    -- index of the conflicting row
                    code[#code + 1] = format("local i = %s", unique_index)
                    code[#code + 1] = format("if i then %s end", on_unique_violation)
                else
                    code[#code + 1] = format("if %s then %s end",
                        unique_index, on_unique_violation)
                end

                if index_updates then
                    local null_check, null_check_end = "", ""
                    if not row.not_null then
                        -- Null values are not indexed
                        null_check = format("if %s ~= nil then ", target_var)
                        null_check_end = " end"
                    end

                    if row_is_not_new then
                        -- Inserted before this call's assignment, where
                        -- target_var still holds the old value: drops the
                        -- row's old index entry (presumably so that the row
                        -- cannot conflict with itself)
                        table.insert(code, pre_update_idx, format(
                            "%s%s = nil%s",
                            null_check, unique_index, null_check_end
                        ))
                    end

                    -- Queue the statement that points the index entry for the
                    -- new value at this row (rowid_<table_name>)
                    index_updates[#index_updates + 1] = format(
                        "%s%s = rowid_%s%s",
                        null_check,
                        unique_index, table_name,
                        null_check_end
                    )
                end
            else
                -- Fall back to slow lookups
                complicated_uniques[#complicated_uniques + 1] = cols
            end
        end
    end

    -- Skip adding the loop if no multi-column UNIQUE constraints exist
    if #complicated_uniques == 0 then return end

    -- Emit a scan over every row of the column; it is skipped entirely when
    -- the new value of this column is nil
    checks[#checks + 1] = format("if %s ~= nil then", target_var)
    checks[#checks + 1] = format("for i = 1, %s.length do", entire_column_ref)
    -- Don't check the currently modified row
    checks[#checks + 1] = format("if i ~= rowid_%s then", table_name)

    for _, cols in ipairs(complicated_uniques) do
        -- Row i conflicts when every column of the constraint equals the
        -- corresponding value of the current row
        local conditions = {}
        for i, col in ipairs(cols) do
            local ref, entire_col = self:get_col_ref(table_name, col)
            conditions[i] = format("%s[i] == %s", entire_col, ref)
        end
        assert(#conditions > 0)

        checks[#checks + 1] = format("if %s then %s end",
            table.concat(conditions, " and "), on_unique_violation)
    end

    -- Close the rowid check, the for loop and the nil check, in that order
    checks[#checks + 1] = "end"
    checks[#checks + 1] = "end"
    checks[#checks + 1] = "end"
end
