local sq = squill._internal
local valid_identifier = sq.valid_identifier
local format, unpack = string.format, table.unpack or unpack

-- SQL type names accepted in CREATE TABLE, mapped to squill's internal type
-- constants. Several spellings share one internal type. The table is also
-- exported on the internal namespace so other parts of squill can use it.
local known_column_types = {
    int = sq.INTEGERS,
    integer = sq.INTEGERS,
    real = sq.NUMBERS,
    text = sq.STRINGS,
    blob = sq.BLOBS,
    bool = sq.BOOLEANS,
    boolean = sq.BOOLEANS,
}
sq.known_column_types = known_column_types

-- Counts the schema rows recorded for a (db_name, table) pair; a result above
-- zero means the table already exists. Presumably RETURN_SINGLE_VALUE makes
-- the statement return that count directly instead of a row set — confirm in
-- sq.bootstrap_statement.
local table_exists_stmt = sq.bootstrap_statement([[
    SELECT COUNT(*) FROM schema WHERE db_name = ? AND table = ?
]], squill.RETURN_SINGLE_VALUE)

-- Records one column of a table in the schema. Positional parameters:
-- db_name, table, column, type, id, not_null, primary_key, autoincrement,
-- check.
local add_column = sq.bootstrap_statement([[
    INSERT INTO schema (
        db_name, table, column, type, id, not_null, primary_key,
        autoincrement, check
    )
    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
]])

-- Column ids are allocated manually from MAX(id), because AUTOINCREMENT
-- doesn't work properly on the "schema" table: it would try to modify the
-- schema table while the schema table is already being modified.
-- NOTE(review): MAX(id) is NULL when the schema table has no rows, so callers
-- presumably receive nil in that case and should handle it.
local get_max_column_id = sq.bootstrap_statement([[
    SELECT MAX(id) FROM schema
]], squill.RETURN_SINGLE_VALUE)

-- Links a column (by schema id) to a UNIQUE constraint group. Columns
-- inserted with the same unique_id form one (possibly multi-column) UNIQUE
-- constraint.
local add_unique_constraint = sq.bootstrap_statement([[
    INSERT INTO uniques (column_id, unique_id)
    VALUES (?, ?)
]])

-- Returns the next unused foreign key id, MAX(key_id) + 1. This is NULL (nil)
-- when foreign_keys is empty, so callers need a fallback.
local get_next_key_id = sq.bootstrap_statement([[
    SELECT MAX(key_id) + 1 FROM foreign_keys
]], squill.RETURN_SINGLE_VALUE)

-- Records one child-column -> parent-column pair of a foreign key. A
-- multi-column key is stored as several rows that share the same key_id.
local add_foreign_key = sq.bootstrap_statement([[
    INSERT INTO foreign_keys (
        key_id, child_id, parent_table, parent_column
    )
    VALUES (?, ?, ?, ?)
]])

-- Parses `REFERENCES parent_table (col, ...)` and appends a description of
-- the foreign key to `foreign_keys`.
--   self:         the statement parser (positioned at REFERENCES)
--   child_cols:   column names in the table being created; the number of
--                 parent columns must match
--   foreign_keys: list to append {parent_cols, child_cols, parent_table} to
local function parse_foreign_key_clause(self, child_cols, foreign_keys)
    self:expect("references")

    -- The parent table is intentionally not checked for existence yet, so
    -- that circular references are possible. The check happens when
    -- compiling queries that modify the affected columns.
    local parent_table = self:next()
    local table_name_ok = valid_identifier(parent_table) and
        not sq.special_vars[parent_table]
    self:assert(table_name_ok, "%q is not a valid table name", parent_table)

    self:expect("(")
    local parent_cols = {}
    repeat
        local col = self:next()
        self:assert(valid_identifier(col), "Invalid column name: %q", col)
        self:assert(table.indexof(parent_cols, col) < 0,
            "Duplicate column name in REFERENCES: %q", col)
        table.insert(parent_cols, col)
    until not self:pop_if_equals(",")
    self:expect(")")

    local same_arity = #parent_cols == #child_cols
    self:assert(same_arity,
        "Mismatch in number of columns on either side of FOREIGN KEY constraint")

    table.insert(foreign_keys, {
        parent_table = parent_table,
        parent_cols = parent_cols,
        child_cols = child_cols,
    })
end

-- CREATE TABLE parser. It starts at the TABLE keyword, so the caller has
-- presumably already consumed CREATE (confirm against the dispatcher).
--
-- Accepted form:
--   TABLE [IF NOT EXISTS] name (
--       col type [NOT NULL | UNIQUE | CHECK (expr) |
--                 PRIMARY KEY [AUTOINCREMENT] | REFERENCES ...]...
--       [, col type ...]...
--       [, PRIMARY KEY (cols) | , UNIQUE (cols) |
--        , FOREIGN KEY (cols) REFERENCES ...]...
--   ) [STRICT]
--
-- Parsing only validates the statement. The returned function does the work:
-- it creates the table (or returns nothing if the table exists and IF NOT
-- EXISTS was given), bumps sq.schema_vers for the database and returns {}.
return function(self)
    self:expect("table")

    local if_not_exists = self:pop_if_equals("if")
    if if_not_exists then
        self:expect("not")
        self:expect("exists")
    end

    local table_name = self:next()
    self:assert(valid_identifier(table_name) and not sq.special_vars[table_name],
        "%q is not a valid table name", table_name)
    self:assert(table_name:sub(1, 7) ~= "squill_",
        "Tables starting with 'squill_' are reserved for future use")

    self:expect("(")

    -- columns[name] = {internal_type, position, not_null, primary_key,
    --                  autoincrement, check}
    -- `position` is the 1-based index of the column in this statement. It is
    -- converted to a schema id only when the statement is executed, so ids
    -- are allocated against the schema as it is at execution time rather
    -- than at parse time.
    local columns = {}
    local column_count = 0
    -- Each entry is a list of column names that together form one UNIQUE
    -- constraint; its index becomes the unique_id stored in the schema.
    local unique_constraints = {}
    local foreign_keys = {}
    -- Maps "col1/col2/..." to "unique" (explicit UNIQUE) or "primary key"
    -- (UNIQUE implied by PRIMARY KEY) so duplicates can be detected.
    local seen_uniques = {}
    local have_primary_key = false

    -- Registers a UNIQUE constraint over col_names; `implied` marks one that
    -- comes from PRIMARY KEY. A PRIMARY KEY and an explicit UNIQUE over the
    -- same columns share one constraint in either order. Two explicit UNIQUEs
    -- over the same columns are an error.
    local function add_unique(col_names, implied)
        local key = table.concat(col_names, "/")
        local seen = seen_uniques[key]
        if seen == nil then
            unique_constraints[#unique_constraints + 1] = col_names
            seen_uniques[key] = implied and "primary key" or "unique"
        elseif not implied then
            if seen == "unique" then
                self:error(format(
                    "Duplicate UNIQUE(%s) constraint", table.concat(col_names, ", ")
                ))
            end
            seen_uniques[key] = "unique"
        end
    end

    -- True if the next token starts a table-wide constraint.
    local function at_table_constraint()
        local token = self:peek()
        return token == "unique" or token == "primary" or token == "foreign"
    end

    -- Column definitions. This loop stops either without a comma (end of the
    -- list) or after a comma that is followed by a table-wide constraint.
    repeat
        local name = self:next()

        self:assert(not columns[name], "Duplicate column name: %q", name)
        self:assert(name ~= ")",
            "SQL requires that you do not have a comma after the last entry in CREATE TABLE")
        self:assert(valid_identifier(name) and not sq.special_vars[name],
            "%q is not a valid column name", name)

        local col_type = self:next()
        local internal_type = self:assert(known_column_types[col_type],
            "Unknown column type: %q", col_type)

        -- Column constraints
        local not_null = false
        local primary_key = false
        local check
        while true do
            if self:pop_if_equals("not") then
                self:expect("null")
                self:assert(not not_null,
                    "Duplicate NOT NULL constraint on column %q", name)
                not_null = true
            elseif self:pop_if_equals("unique") then
                add_unique({name}, false)
            elseif self:pop_if_equals("check") then
                self:assert(check == nil,
                    "Only one CHECK constraint can be specified per column")

                self:expect("(")
                local get_check_code = self:save_pos()
                self:parse_expr()
                check = get_check_code()
                self:expect(")")
            elseif self:pop_if_equals("primary") then
                self:expect("key")
                self:assert(not have_primary_key,
                    "Only one primary key can be specified")
                primary_key = true
                have_primary_key = true
                if self:pop_if_equals("autoincrement") then
                    -- Implicitly enabled below
                    self:assert(internal_type == sq.INTEGERS,
                        "AUTOINCREMENT is only supported on integer columns")
                end
            elseif self:peek() == "references" then
                parse_foreign_key_clause(self, {name}, foreign_keys)
            else
                break
            end
        end

        -- Primary keys are always NOT NULL and UNIQUE (but it should be
        -- possible to explicitly specify both)
        local autoincrement
        if primary_key then
            not_null = true
            add_unique({name}, true)

            -- Autoincrement is implicitly enabled for all integer primary keys
            -- to be as close to SQLite as possible, since there isn't a
            -- persistent rowid like sqlite
            if internal_type == sq.INTEGERS then
                autoincrement = 0
            end
        end

        column_count = column_count + 1
        columns[name] = {internal_type, column_count, not_null, primary_key,
            autoincrement, check}
    until not self:pop_if_equals(",") or at_table_constraint()

    -- Table-wide PRIMARY KEY / UNIQUE / FOREIGN KEY constraints
    local first_constraint = true
    repeat
        local kind
        if self:pop_if_equals("primary") then
            self:expect("key")
            self:assert(not have_primary_key,
                "Only one primary key can be specified")
            have_primary_key = true
            kind = "primary"
        elseif self:pop_if_equals("foreign") then
            self:expect("key")
            kind = "foreign"
        elseif self:pop_if_equals("unique") then
            kind = "unique"
        else
            -- Past the first iteration we only get here after consuming a
            -- comma, which must not be the trailing one before ")".
            self:assert(first_constraint or self:peek() ~= ")",
                "SQL requires that you do not have a comma after the last entry in CREATE TABLE")
            break
        end
        first_constraint = false

        self:expect("(")
        local col_names = {}
        repeat
            local column = self:next()
            self:assert(columns[column], "Unknown column: %q", column)
            self:assert(table.indexof(col_names, column) < 0,
                "Duplicate column name in constraint: %q", column)
            col_names[#col_names + 1] = column
        until not self:pop_if_equals(",")
        self:expect(")")

        if kind == "foreign" then
            parse_foreign_key_clause(self, col_names, foreign_keys)
        elseif kind == "primary" then
            -- Same rules as a column-level PRIMARY KEY: NOT NULL and UNIQUE
            for _, col_name in ipairs(col_names) do
                local col_def = columns[col_name]
                col_def[3] = true -- not_null
                col_def[4] = true -- primary_key
            end
            add_unique(col_names, true)
        else
            add_unique(col_names, false)
        end
    until not self:pop_if_equals(",")

    self:expect(")")

    -- Compatibility with SQLite
    self:pop_if_equals("strict")

    -- Inserts the schema, uniques and foreign_keys rows for the new table.
    -- Column ids continue from the current maximum id in the schema.
    local function write_schema(db_name)
        local base_id = get_max_column_id() or 0
        local function column_id(column)
            return base_id + columns[column][2]
        end

        for column, col_def in pairs(columns) do
            add_column(db_name, table_name, column, col_def[1], column_id(column),
                col_def[3], col_def[4], col_def[5], col_def[6])
        end

        for unique_id, col_names in ipairs(unique_constraints) do
            for _, column in ipairs(col_names) do
                add_unique_constraint(column_id(column), unique_id)
            end
        end

        for _, foreign_key in ipairs(foreign_keys) do
            local key_id = get_next_key_id() or 0
            for i, parent_col in ipairs(foreign_key.parent_cols) do
                add_foreign_key(key_id, column_id(foreign_key.child_cols[i]),
                    foreign_key.parent_table, parent_col)
            end
        end
    end

    return function()
        local db_name = self.db_name
        sq.assert_no_transaction(db_name)
        sq.begin_transaction("squill")

        if table_exists_stmt(db_name, table_name) > 0 then
            sq.rollback_transaction("squill")
            if if_not_exists then
                return
            end

            error(format("Table %q already exists in database %q",
                table_name, db_name), 2)
        end

        -- Roll back on failure so a failed CREATE TABLE neither leaves the
        -- "squill" transaction open nor leaves a half-written schema behind.
        local ok, err = pcall(write_schema, db_name)
        if not ok then
            sq.rollback_transaction("squill")
            error(err, 0)
        end

        sq.commit_transaction("squill")

        sq.schema_vers[db_name] = (sq.schema_vers[db_name] or 0) + 1
        return {}
    end
end
