local sq = squill._internal
local format = string.format
local valid_identifier = sq.valid_identifier

-- Finds every column of a table that carries an autoincrement counter.
-- Bound parameters: database name, table name. Each result row exposes
-- `name` (the column name) and `id` (the schema row id that is later passed
-- to autoincrement_helper).
local GET_AUTOINCREMENTS_SQL = [[
    SELECT column AS name, id FROM schema
    WHERE db_name = ? AND table = ? AND autoincrement IS NOT NULL
]]

-- Bumps the autoincrement counter of the schema row with the given id and
-- returns `autoincrement - 1` (the pre-increment value, assuming RETURNING
-- observes the updated row as in SQLite) as a single value.
local AUTOINCREMENT_HELPER_SQL = [[
    UPDATE schema SET autoincrement = autoincrement + 1
    WHERE id = ?
    RETURNING autoincrement - 1
]]

local get_autoincrements = sq.bootstrap_statement(GET_AUTOINCREMENTS_SQL)
local autoincrement_helper = sq.bootstrap_statement(AUTOINCREMENT_HELPER_SQL,
    squill.RETURN_SINGLE_VALUE)

--- Compiles an INSERT statement into a Lua function.
-- The token stream on `self` must be positioned right after
-- "INSERT [OR <action>]" (or "REPLACE"), i.e. at
-- `INTO <table> [(<col>, ...)] VALUES (<expr>, ...) [RETURNING ...]`.
-- Only a single VALUES tuple is parsed.
-- @param self squill statement compiler state
-- @param on_conflict nil for a plain INSERT, or "ignore" / "replace" for the
--   matching conflict action; any other value raises
--   "Not implemented: INSERT OR <ACTION>".
-- @return the function produced by `self:compile_lua`
-- @return set of result column names (always contains `affected_rows`)
local function insert_stmt(self, on_conflict)
    self:expect("into")

    local table_name = self:next()
    self:assert(valid_identifier(table_name), "%q is not a valid table name", table_name)

    -- Map every column of the table to its generated row/column references
    local refs = {}
    local all_column_names = self:list_columns(table_name)
    for _, col in ipairs(all_column_names) do
        local row_ref, col_ref = self:get_col_ref(table_name, col)
        refs[col] = {row = row_ref, col = col_ref}
    end
    -- No columns at all is reported as a missing table
    self:assert(next(refs), "The table %q does not exist", table_name)

    -- Optional explicit column list; defaults to every column of the table
    local column_names = {}
    if self:pop_if_equals("(") then
        repeat
            local name = self:next()
            column_names[#column_names + 1] = name
            self:assert(valid_identifier(name), "%q is not a valid column name", name)
            self:assert(refs[name], "%q does not exist in table %q", name, table_name)
        until not self:pop_if_equals(",")
        self:expect(")")
    else
        table.insert_all(column_names, all_column_names)
    end

    self:expect("values")

    -- Parse the single VALUES tuple: the i-th value (as generated Lua source)
    -- is assigned to column_names[i]
    self:expect("(")
    local values_by_col = {}
    local i = 0
    repeat
        i = i + 1
        self:assert(column_names[i], "Too many values passed in")
        values_by_col[column_names[i]] = self:expr_to_lua(self:parse_expr(), {})
    until not self:pop_if_equals(",")
    self:expect(")")

    self:assert(i == #column_names, "Not enough values passed in")

    -- RETURNING expressions refer to the table under its own name
    self.table_lookup = {[table_name] = table_name}
    local returned_columns = {affected_rows = true}
    local returning = false
    local return_exprs_lua = {}
    if self:pop_if_equals("returning") then
        repeat
            local name = sq.parse_select_result_column(self, return_exprs_lua,
                returned_columns)
            return_exprs_lua[name] = self:expr_to_lua(return_exprs_lua[name])
        until not self:pop_if_equals(",")
        returning = true
    end

    -- Columns with an autoincrement counter that were not given a value get
    -- their value from autoincrement_helper, which is handed to the compiled
    -- chunk as its vararg (see compile_lua at the end)
    local code = {}
    local found_autoincrement = false
    local autoincrements = {}
    for _, col in ipairs(get_autoincrements(self.db_name, table_name)) do
        if not values_by_col[col.name] then
            -- If no value is specified, autoincrement
            if not found_autoincrement then
                code[#code + 1] = "local autoincrement_helper = ..."
                found_autoincrement = true
            end

            values_by_col[col.name] = format("autoincrement_helper(%d)", col.id)
            autoincrements[col.name] = true
        end
    end

    -- update_row() must be called before insert_var_refs so it can reference
    -- new columns
    local update_code, checks, checked_uniques = {}, {}, {}

    local on_unique_violation
    if on_conflict == "ignore" then
        -- This probably isn't the most efficient way to handle conflicts, but
        -- it's the easiest to implement
        -- Note that "INSERT OR IGNORE" is also supposed to ignore other
        -- constraints like checks and not null, but currently doesn't since I
        -- don't know if that would be useful for anything
        update_code[#update_code + 1] = "local function back_out()"

        -- Revert changes to any columns that got touched
        for _, ref in pairs(refs) do
            update_code[#update_code + 1] = format("if %s.length == new_length then", ref.col)
            update_code[#update_code + 1] = format("%s[new_length] = nil", ref.col)
            update_code[#update_code + 1] = format("%s.length = old_length", ref.col)
            update_code[#update_code + 1] = "end"
        end

        update_code[#update_code + 1] = "return {affected_rows = 0}"
        update_code[#update_code + 1] = "end"

        on_unique_violation = "return back_out()"
    elseif on_conflict == "replace" then
        -- https://www.sqlite.org/lang_conflict.html
        -- Just delete any conflicting rows and carry on with the insertion
        update_code[#update_code + 1] = "local to_delete = {}"
        on_unique_violation = "to_delete[i] = true"
    elseif on_conflict then
        error(format("Not implemented: INSERT OR %s", on_conflict:upper()))
    end

    local index_updates = {}
    for col, ref in pairs(refs) do
        update_code[#update_code + 1] = format(
            "assert(%s.length == old_length, 'Corrupted database table!')",
            ref.col
        )
        self:update_row(update_code, checks, table_name, col, values_by_col[col],
            ref.row, ref.col, false, checked_uniques, autoincrements[col],
            on_unique_violation, index_updates)

        -- With on_conflict == "replace", lengths are only updated once the
        -- conflicting rows have been deleted (see below)
        if on_conflict ~= "replace" then
            update_code[#update_code + 1] = format("%s.length = new_length", ref.col)
        end
    end

    -- Generate argument list
    code[#code + 1] = self:create_function_def()

    self:insert_var_refs(code)

    code[#code + 1] = format("local old_length = %s", assert(self.lengths[table_name]))
    code[#code + 1] = "local new_length = old_length + 1"
    code[#code + 1] = format("local rowid_%s = new_length", table_name)

    table.insert_all(code, update_code)
    table.insert_all(code, checks)

    local index_updated = "true"
    if on_conflict == "replace" then
        -- Delete any conflicting rows
        code[#code + 1] = "local to_delete_list = {}"
        code[#code + 1] = "for i in pairs(to_delete) do"
        code[#code + 1] = "to_delete_list[#to_delete_list + 1] = i"
        code[#code + 1] = "end"

        -- Keys must be deleted from highest to lowest to avoid the deletions
        -- moving around the indexes of rows that are yet to be deleted
        code[#code + 1] = "table_sort(to_delete_list)"
        code[#code + 1] = "for list_idx = #to_delete_list, 1, -1 do"

        code[#code + 1] = "local i = to_delete_list[list_idx]"

        -- The row at new_length is about to be moved into slot i. If it is
        -- the row being inserted, follow it so that rowid_<table> (read by
        -- the RETURNING expressions below) still points at the new row
        code[#code + 1] = format("if rowid_%s == new_length then", table_name)
        code[#code + 1] = format("rowid_%s = i", table_name)
        code[#code + 1] = "end"

        for _, ref in pairs(refs) do
            code[#code + 1] = format("%s[i] = %s[new_length]", ref.col, ref.col)
            code[#code + 1] = format("%s[new_length] = nil", ref.col)
        end

        code[#code + 1] = "new_length = new_length - 1"
        code[#code + 1] = "end"

        -- Only update new_length once everything is done
        for _, ref in pairs(refs) do
            code[#code + 1] = format("%s.length = new_length", ref.col)
        end

        -- TODO: Maybe keep indexes with INSERT OR REPLACE when deleting rows?
        -- Doing so would require making sure all indexes are correct though
        index_updated = "#to_delete_list == 0"
        code[#code + 1] = "if #to_delete_list == 0 then"
        table.insert_all(code, index_updates)
        code[#code + 1] = "end"
    else
        table.insert_all(code, index_updates)
    end

    self:add_set_columns(code, refs, index_updated)

    -- Result: an optional RETURNING row followed by the affected row count
    code[#code + 1] = "return {"
    if returning then
        code[#code + 1] = "{"
        for name, expr in pairs(return_exprs_lua) do
            code[#code + 1] = format("[%q] = %s,", name, expr)
        end
        code[#code + 1] = "},"
    end
    code[#code + 1] = "affected_rows = 1"
    code[#code + 1] = "}"

    code[#code + 1] = "end"

    local func = self:compile_lua(table.concat(code, "\n"), autoincrement_helper)
    return func, returned_columns
end

-- "REPLACE INTO ..." is accepted for compatibility and compiled exactly
-- like "INSERT OR REPLACE INTO ..."
sq.cmds.replace = function(self)
    local conflict_action = "replace"
    return insert_stmt(self, conflict_action)
end

-- Entry point for the INSERT command: reads an optional "OR <action>"
-- conflict clause, then hands off to insert_stmt
return function(self)
    if not self:pop_if_equals("or") then
        return insert_stmt(self, nil)
    end
    local conflict_action = self:next()
    return insert_stmt(self, conflict_action)
end
