local sq = squill._internal
local format = string.format
local valid_identifier = sq.valid_identifier

-- Compiled query against the internal `schema` table. Called with
-- (db_name, table_name), it returns the rows (iterated with ipairs by
-- insert_stmt) for every column of that table that has an autoincrement
-- setting or a stored DEFAULT expression. Each row carries the column's name,
-- its `id` (passed to autoincrement_helper() in generated code), and the
-- `autoincrement` / `default_value` fields. insert_stmt() uses this to supply
-- values for columns the INSERT statement did not mention.
local get_autoincrements = sq.bootstrap_statement([[
    SELECT column AS name, id, autoincrement, default_value FROM schema
    WHERE db_name = ? AND table = ? AND (
        autoincrement IS NOT NULL OR default_value IS NOT NULL
    )
]])

-- Conflict-resolution modes that insert_stmt() knows how to compile. Any
-- other "INSERT OR <mode>" is rejected before parsing/code generation starts.
local SUPPORTED_CONFLICT_MODES = {ignore = true, replace = true}

-- Compiles "INSERT INTO <table> [(col, ...)] VALUES (expr, ...)
-- [RETURNING ...]". The parser must be positioned just before "INTO".
--
-- parser: statement parser; used for reading tokens and reporting errors.
-- on_conflict: nil for a plain INSERT, otherwise the conflict mode name taken
--   from "INSERT OR <mode>" (or "replace" for a bare REPLACE statement).
--   Only "ignore" and "replace" are implemented; anything else is reported
--   as a parse error.
--
-- Returns the compiled statement (code:compile(parser)) and the set of result
-- column names, which always contains `affected_rows`. The generated code
-- returns {affected_rows = 1} (plus the RETURNING row at index 1 when
-- requested), or {affected_rows = 0} when OR IGNORE skips a conflicting row.
local function insert_stmt(parser, on_conflict)
    -- Reject unsupported conflict clauses first, through the parser like
    -- every other SQL error (a raw error() would prepend a Lua source
    -- location), and before any environment state is created or mutated.
    if on_conflict then
        parser:assert(SUPPORTED_CONFLICT_MODES[on_conflict],
            "Not implemented: INSERT OR %s", tostring(on_conflict):upper())
    end

    parser:expect("into")

    local table_name = parser:next()
    parser:assert(valid_identifier(table_name), "%q is not a valid table name", table_name)

    -- Resolve the row/column variable references for every column of the
    -- table; no columns at all means the table does not exist.
    local env = sq.Env.new(parser)
    local refs = {}
    local all_column_names = env:list_columns(table_name)
    for _, col in ipairs(all_column_names) do
        local row_ref, col_ref = env:get_col_ref(table_name, col)
        refs[col] = {row = row_ref, col = col_ref}
    end
    parser:assert(next(refs), "The table %q does not exist", table_name)

    -- Optional explicit column list; without one, values are assigned to all
    -- columns in the order returned by env:list_columns().
    local column_names = {}
    if parser:pop_if_equals("(") then
        repeat
            local name = parser:next()
            column_names[#column_names + 1] = name
            parser:assert(valid_identifier(name), "%q is not a valid column name", name)
            parser:assert(refs[name], "%q does not exist in table %q", name, table_name)
        until not parser:pop_if_equals(",")
        parser:expect(")")
    else
        table.insert_all(column_names, all_column_names)
    end

    parser:expect("values")

    -- VALUES (...): translate each expression to Lua code keyed by the
    -- column it is assigned to. The count must match the column list exactly.
    parser:expect("(")
    local values_by_col = {}
    local i = 0
    repeat
        i = i + 1
        parser:assert(column_names[i], "Too many values passed in")
        values_by_col[column_names[i]] = env:expr_to_lua(parser:parse_expr())
    until not parser:pop_if_equals(",")
    parser:expect(")")

    parser:assert(i == #column_names, "Not enough values passed in")

    -- RETURNING expressions may refer to the table by its own name
    env.table_lookup[table_name] = table_name
    local returned_columns = {affected_rows = true}
    local returning = false
    local return_exprs_lua = {}
    if parser:pop_if_equals("returning") then
        repeat
            local name = sq.parse_select_result_column(parser, return_exprs_lua,
                returned_columns)
            return_exprs_lua[name] = env:expr_to_lua(return_exprs_lua[name])
        until not parser:pop_if_equals(",")
        returning = true
    end

    -- Columns that were not given a value fall back to their autoincrement
    -- counter or to their stored DEFAULT expression.
    local autoincrements = {}
    for _, col in ipairs(get_autoincrements(parser.db_name, table_name)) do
        if not values_by_col[col.name] then
            if col.autoincrement then
                values_by_col[col.name] = format("autoincrement_helper(%d)", col.id)
                autoincrements[col.name] = true
            else
                local p = sq.Parser.new(parser.db_name, col.default_value)
                local expr = p:parse_expr()
                -- The stored default must be exactly one expression
                assert(not p:peek())
                values_by_col[col.name] = env:expr_to_lua(expr)
            end
        end
    end

    -- update_row() must be called before insert_var_refs so it can reference
    -- new columns
    local update_code = sq.CodeBuf.new()
    local checks = sq.CodeBuf.new()
    local checked_uniques = sq.CodeBuf.new()

    -- Lua statement emitted where a UNIQUE constraint would be violated; nil
    -- keeps update_row()'s default handling.
    local on_unique_violation
    if on_conflict == "ignore" then
        -- This probably isn't the most efficient way to handle conflicts, but
        -- it's the easiest to implement
        -- Note that "INSERT OR IGNORE" is also supposed to ignore other
        -- constraints like checks and not null, but currently doesn't since I
        -- don't know if that would be useful for anything
        update_code:put("local function back_out()")

        -- Revert changes to any columns that got touched
        for _, ref in pairs(refs) do
            update_code:putf("if %s.length == new_length then", ref.col)
            update_code:putf("%s[new_length] = nil", ref.col)
            update_code:putf("%s.length = old_length", ref.col)
            update_code:put("end")
        end

        update_code:put("return {affected_rows = 0}")
        update_code:put("end")

        on_unique_violation = "return back_out()"
    elseif on_conflict == "replace" then
        -- https://www.sqlite.org/lang_conflict.html
        -- Just delete any conflicting rows and carry on with the insertion
        update_code:put("local to_delete = {}")
        on_unique_violation = "to_delete[i] = true"

        -- Add delete checks now before insert_var_refs()
        for col_name, ref in pairs(refs) do
            ref.delete_checks = sq.CodeBuf.new()
            env:add_foreign_key_pre_delete_checks(ref.delete_checks, table_name, {
                [col_name] = {row = format("%s[i]", ref.col)}
            })
        end
    end

    -- Append the new value to every column (row index new_length)
    local index_updates = {}
    for col, ref in pairs(refs) do
        update_code:putf(
            "assert(%s.length == old_length, 'Corrupted database table!')",
            ref.col
        )
        env:update_row(update_code, checks, table_name, col, values_by_col[col],
            ref.row, ref.col, false, checked_uniques, autoincrements[col],
            on_unique_violation, index_updates)

        -- With OR REPLACE, the column lengths are only updated once the
        -- conflicting rows have been deleted (after the insert has finished)
        if on_conflict ~= "replace" then
            update_code:putf("%s.length = new_length", ref.col)
        end
    end

    local code = sq.CodeBuf.new()
    env:insert_var_refs(code)

    code:putf("local old_length = %s", assert(env.lengths[table_name]))
    code:put("local new_length = old_length + 1")
    code:putf("local rowid_%s = new_length", table_name)

    code:insert_all(update_code)
    code:insert_all(checks)

    -- Lua expression (as source text) that is true when index_updates were
    -- applied by the generated code; handed to add_set_columns().
    local index_updated = "true"
    if on_conflict == "replace" then
        -- Delete any conflicting rows
        code:put("local to_delete_list = {}")
        code:put("for i in pairs(to_delete) do")
        code:put("to_delete_list[#to_delete_list + 1] = i")
        code:put("end")

        -- Keys must be deleted from highest to lowest to avoid the deletions
        -- moving around the indexes of rows that are yet to be deleted
        code:put("table_sort(to_delete_list)")
        code:put("for list_idx = #to_delete_list, 1, -1 do")

        code:put("local i = to_delete_list[list_idx]")
        for _, ref in pairs(refs) do
            if #ref.delete_checks > 0 then
                -- If the newly inserted row has the same value, then foreign
                -- key checks should be skipped to allow updating
                code:putf("if %s[i] ~= %s then", ref.col, ref.row)
                code:insert_all(ref.delete_checks)
                code:put("end")
            end

            -- Move the last row into the deleted slot
            code:putf("%s[i] = %s[new_length]", ref.col, ref.col)
            code:putf("%s[new_length] = nil", ref.col)
        end

        -- Keep track of the index for RETURNING and constraints
        code:putf("if rowid_%s == new_length then", table_name)
        code:putf("rowid_%s = i", table_name)
        code:put("end")

        code:put("new_length = new_length - 1")
        code:put("end")

        -- Only update new_length once everything is done
        for _, ref in pairs(refs) do
            code:putf("%s.length = new_length", ref.col)
        end

        -- TODO: Maybe keep indexes with INSERT OR REPLACE when deleting rows?
        -- Doing so would require making sure all indexes are correct though
        index_updated = "#to_delete_list == 0"
        code:put("if #to_delete_list == 0 then")
        code:insert_all(index_updates)
        code:put("end")
    else
        code:insert_all(index_updates)
    end

    env:add_set_columns(code, refs, index_updated)

    -- Result: optional RETURNING row at index 1, plus the affected row count
    code:put("return {")
    if returning then
        code:put("{")
        for name, expr in pairs(return_exprs_lua) do
            code:putf("[%q] = %s,", name, expr)
        end
        code:put("},")
    end
    code:put("affected_rows = 1")
    code:put("}")

    return code:compile(parser), returned_columns
end

function sq.cmds:replace()
    -- A bare "REPLACE INTO ..." is accepted as shorthand for
    -- "INSERT OR REPLACE INTO ...", so the INSERT compiler is reused with
    -- its conflict mode preset.
    local conflict_mode = "replace"
    return insert_stmt(self, conflict_mode)
end

return function(parser)
    -- "INSERT OR <mode> INTO ..." names a conflict-resolution mode; a plain
    -- "INSERT INTO ..." compiles with no conflict mode at all.
    if not parser:pop_if_equals("or") then
        return insert_stmt(parser, nil)
    end
    return insert_stmt(parser, parser:next())
end
