local sq = squill._internal
local format = string.format

-- All of the queries below read squill's internal schema metadata, which is
-- kept in the database named "squill". Any extra arguments (such as a result
-- mode) are forwarded to squill.prepare_statement unchanged.
local function prepare_schema_query(sql, ...)
    return squill.prepare_statement("squill", sql, ...)
end

-- Names of the tables in database `db_name`, one per result value.
local get_tables = prepare_schema_query([[
    SELECT DISTINCT table FROM schema WHERE db_name = ?
]], squill.RETURN_SINGLE_COLUMN)

-- Column definitions of one table (db_name, table), in schema id order.
local get_columns = prepare_schema_query([[
    SELECT
        column, type, not_null, primary_key, autoincrement, check,default_value
    FROM schema
    WHERE db_name = ? AND table = ?
    ORDER BY id
]])

-- (unique_id, column) rows for every column that takes part in a UNIQUE
-- constraint on one table (db_name, table).
local get_uniques = prepare_schema_query([[
    SELECT unique_id, column FROM uniques u
    JOIN schema s ON u.column_id = s.id
    WHERE db_name = ? AND table = ?
]])

-- Distinct (key_id, parent_table) pairs for the foreign keys declared on
-- one table (db_name, table).
local get_foreign_key_ids = prepare_schema_query([[
    SELECT DISTINCT key_id, parent_table FROM foreign_keys
    JOIN schema ON child_id = id
    WHERE db_name = ? AND table = ?
]])

-- (child_column, parent_column) pairs of a single foreign key, in
-- sort_order.
local get_foreign_key_cols = prepare_schema_query([[
    SELECT child.column AS child_column, parent_column
    FROM foreign_keys
    JOIN schema child ON child_id = id
    WHERE key_id = ?
    ORDER BY sort_order
]])

--- Serialise a Lua value as an SQL literal.
-- * booleans become TRUE / FALSE, nil becomes NULL;
-- * strings are single-quoted, with embedded quotes doubled;
-- * numbers are written with enough digits to round-trip.
--   Infinities and NaN have no plain literal and are handled specially.
-- @param value boolean, nil, number or string
-- @treturn string the SQL literal
-- @raise error for any other Lua type (tables, functions, ...)
local function dump_value(value)
    local value_type = type(value)
    if value_type == "boolean" then
        return value and "TRUE" or "FALSE"
    elseif value_type == "nil" then
        return "NULL"
    elseif value_type == "number" then
        if value ~= value then
            -- NaN: "%.17g" would print "nan"/"-nan", which is not a valid
            -- literal. SQLite itself stores NaN as NULL.
            return "NULL"
        elseif value == math.huge then
            -- "inf" is not a literal either. An overflowing literal parses
            -- back to infinity (SQLite's own .dump uses the same spelling).
            return "1e999"
        elseif value == -math.huge then
            return "-1e999"
        elseif math.type and math.type(value) == "integer" then
            -- Lua 5.3+ integer subtype: "%.17g" would convert to a double
            -- and lose precision above 2^53.
            return format("%d", value)
        end
        return format("%.17g", value)
    elseif value_type == "string" then
        -- Parenthesised so only gsub's first result (the string) is used
        return "'" .. (value:gsub("'", "''")) .. "'"
    end

    error("Cannot convert type " .. value_type .. " to SQL")
end
-- Make the value serialiser available to other squill code.
squill.dump_value = dump_value

-- Reverse lookup: column type id -> type name. Several names may map to the
-- same id; the longest one is kept (on equal lengths, the first one visited).
local type_id_to_name = {}
for name, id in pairs(sq.known_column_types) do
    local current = type_id_to_name[id]
    if current == nil or #name > #current then
        type_id_to_name[id] = name
    end
end

-- Writes the table-level UNIQUE constraints still left in `uniques` (those
-- not already emitted inline on a column). A constraint whose column set
-- equals the table's primary-key columns is written as PRIMARY KEY instead.
-- NOTE(review): a composite primary key only reaches the dump if squill also
-- records it as an entry in the uniques table — confirm.
local function write_unique_constraints(out, uniques, primary_keys)
    if not next(uniques) then
        return
    end

    out:write(",\n")
    local first = true
    for _, unique_cols in pairs(uniques) do
        local constr = "UNIQUE"
        if #primary_keys == #unique_cols then
            local eq = true
            for _, name in ipairs(unique_cols) do
                -- table.indexof is not standard Lua (presumably Minetest's
                -- helper); a negative result means `name` is not a PK column
                if table.indexof(primary_keys, name) < 0 then
                    eq = false
                    break
                end
            end

            if eq then
                constr = "PRIMARY KEY"
            end
        end

        if not first then
            out:write(",")
        end
        first = false
        out:write(format("\n    %s (%s)", constr,
            table.concat(unique_cols, ", ")))
    end
end

-- Writes a FOREIGN KEY table constraint for every foreign key on table `t`.
local function write_foreign_keys(out, db_name, t)
    for _, k in ipairs(get_foreign_key_ids(db_name, t)) do
        local child_cols = {}
        local parent_cols = {}
        for i, c in ipairs(get_foreign_key_cols(k.key_id)) do
            child_cols[i] = c.child_column
            parent_cols[i] = c.parent_column
        end

        out:write(format(
            ",\n    FOREIGN KEY (%s) REFERENCES %s (%s)",
            table.concat(child_cols, ", "),
            k.parent_table,
            table.concat(parent_cols, ", ")
        ))
    end
end

-- Writes the CREATE TABLE statement for table `t` of database `db_name` and
-- returns the table's column rows (schema id order) for the data section.
local function write_create_table(out, db_name, t)
    -- The name is spliced into SQL text, so it must be a plain identifier
    assert(sq.valid_identifier(t))

    out:write(format("\nCREATE TABLE %s (", t))

    -- Group the columns of each UNIQUE constraint by constraint id
    local uniques = {}
    for _, u in ipairs(get_uniques(db_name, t)) do
        uniques[u.unique_id] = uniques[u.unique_id] or {}
        table.insert(uniques[u.unique_id], u.column)
    end

    local cols = get_columns(db_name, t)

    local primary_keys = {}
    for _, c in ipairs(cols) do
        if c.primary_key then
            primary_keys[#primary_keys + 1] = c.column
        end
    end

    for i, c in ipairs(cols) do
        if i > 1 then
            out:write(",")
        end

        local type_name = assert(type_id_to_name[c.type])

        -- Prefer the inline UNIQUE constraints where possible: a
        -- single-column constraint on this column is consumed here so it is
        -- not repeated as a table constraint afterwards.
        local unique = false
        for unique_id, unique_cols in pairs(uniques) do
            if #unique_cols == 1 and unique_cols[1] == c.column then
                unique = true
                -- Clearing an existing key during pairs() is allowed
                uniques[unique_id] = nil
                break
            end
        end

        -- A lone primary-key column gets an inline PRIMARY KEY; NOT NULL and
        -- UNIQUE are then omitted for it.
        local is_pk = #primary_keys == 1 and primary_keys[1] == c.column
        -- One "%s" per argument: the CHECK constraint is the 8th one and was
        -- previously dropped because its placeholder was missing.
        -- NOTE(review): default_value and check are presumably stored as SQL
        -- expression text and are written verbatim — confirm.
        out:write(format(
            "\n    %s %s%s%s%s%s%s%s",
            c.column,
            type_name:upper(),
            (c.not_null and not is_pk) and " NOT NULL" or "",
            (unique and not is_pk) and " UNIQUE" or "",
            is_pk and " PRIMARY KEY" or "",
            c.autoincrement and " AUTOINCREMENT" or "",
            c.default_value and format(" DEFAULT %s", c.default_value) or "",
            c.check and format(" CHECK (%s)", c.check) or ""
        ))
    end

    write_unique_constraints(out, uniques, primary_keys)
    write_foreign_keys(out, db_name, t)

    out:write("\n);\n")
    return cols
end

-- Writes one INSERT statement per row of table `t`; values follow the order
-- of `cols`, the table's column rows in schema order.
local function write_inserts(out, db_name, t, cols)
    -- squill.dump only reaches this after write_create_table has checked `t`
    -- with sq.valid_identifier, so concatenating it into the query is safe.
    local rows = squill.prepare_statement(db_name,
        "SELECT * FROM " .. t)()

    for _, row in ipairs(rows) do
        local values = {}
        for i, c in ipairs(cols) do
            values[i] = dump_value(row[c.column])
        end

        out:write(format(
            "\nINSERT INTO %s VALUES (%s);",
            t, table.concat(values, ", ")
        ))
    end

    out:write("\n")
end

--- Write an SQL script that recreates database `db_name` to `out`.
-- The script has a CREATE TABLE statement (with columns, UNIQUE/PRIMARY KEY
-- and FOREIGN KEY constraints) for every table. It is followed by an INSERT
-- statement for every row, all inside one transaction with foreign-key
-- enforcement turned off. That is presumably so that rows can be inserted
-- regardless of table order. If the database has no tables, only a comment
-- saying so is written.
-- @param db_name name of the database to dump
-- @param out object with a write method taking a string (e.g. a file handle)
function squill.dump(db_name, out)
    local tables = get_tables(db_name)
    if #tables == 0 then
        out:write(format("-- The database %s is empty.\n", db_name))
        return
    end

    out:write(format("-- Squill dump of database: %s\n", db_name))
    local cols_by_table = {}
    for _, t in ipairs(tables) do
        cols_by_table[t] = write_create_table(out, db_name, t)
    end

    -- Use a transaction to speed up writing
    out:write("\nPRAGMA foreign_keys = OFF;\nBEGIN;\n")

    for _, t in ipairs(tables) do
        write_inserts(out, db_name, t, cols_by_table[t])
    end

    out:write("\nCOMMIT;\nPRAGMA foreign_keys = ON;\n")
end
