-- SQL environment (referenced tables, accessed columns, etc.) used when
-- compiling SQL statements

local format = string.format

local sq = squill._internal
local valid_identifier = sq.valid_identifier

-- Env "class" table. Instances find their methods through __index, and the
-- class-level var_refs_inserted = false serves as the default value until an
-- instance sets its own field.
local Env = {}
Env.__index = Env
Env.var_refs_inserted = false
sq.Env = Env

-- Creates a new compilation environment bound to `parser`.
-- The environment copies the parser's db_name and schema_ver and starts
-- with empty bookkeeping tables for tables, column references, unique index
-- references, USING columns, cached column lists and nested SELECTs.
function Env.new(parser)
    -- Cheap duck-typing check that this really is a Parser
    assert(parser.next)

    local env = {
        parser = parser,
        db_name = parser.db_name,
        schema_ver = parser.schema_ver,
    }

    -- Bookkeeping tables filled in while the statement is compiled
    env.table_lookup = {}
    env.column_refs = {}
    env.unique_index_refs = {}
    env.using_cols = {}
    env._col_cache = {}
    env.nested_selects = {}

    return setmetatable(env, Env)
end

-- Raises an error by delegating, with the same arguments, to the owning
-- parser's error method.
function Env:error(...)
    local parser = self.parser
    return parser.error(parser, ...)
end

-- Checks a condition by delegating, with the same arguments, to the owning
-- parser's assert method.
function Env:assert(...)
    local parser = self.parser
    return parser.assert(parser, ...)
end

-- Returns a unique symbol name: string.format(fmt, ...) followed by "_" and
-- a numeric counter. The counter is kept on the parser rather than on the
-- Env, so every Env built from the same parser draws from one sequence and
-- the generated symbols never collide.
function Env:sym(fmt, ...)
    local parser = self.parser
    local n = parser.next_symbol or 0
    parser.next_symbol = n + 1
    local base = format(fmt, ...)
    return format("%s_%d", base, n)
end

-- Column names of one table, ordered by their schema id
local list_columns_stmt = sq.bootstrap_statement([[
    SELECT column FROM schema
    WHERE db_name = ? AND table = ?
    ORDER BY id
]], squill.RETURN_SINGLE_COLUMN)

-- Returns the column names of `table_name` in this environment's database.
-- Query results are memoised in self._col_cache. While compiling the
-- built-in queries (sq.bootstrap) against the "squill" database, the
-- hardcoded sq.bootstrap_col_names entry is returned instead; that path is
-- never cached.
function Env:list_columns(table_name)
    local cache = self._col_cache
    local columns = cache[table_name]
    if not columns then
        if sq.bootstrap and self.db_name == "squill" then
            return assert(sq.bootstrap_col_names[table_name])
        end

        columns = list_columns_stmt(self.db_name, table_name)
        cache[table_name] = columns
    end
    return columns
end

-- Returns a Lua expression (as a string) reading column `col_name` of
-- `table_alias` at the row index held in the generated variable
-- rowid_<table_alias>, plus the name of the local variable that will hold
-- the column (nil for nested SELECTs).
-- Must be called before insert_var_refs, which declares those locals.
function Env:get_col_ref(table_alias, col_name)
    assert(not self.var_refs_inserted)

    local nested = self.nested_selects[table_alias]
    if nested then
        -- Nested SELECT results are stored per row, so there is no column
        -- variable to return. Their column names may also be auto-generated
        -- or come from "select *", so they are not identifier-checked.
        self:assert(table.indexof(nested.returned_columns, col_name) > 0,
            "Non-existent column in nested select: %q", col_name)
        local row = format("%s[rowid_%s]", nested.sym, table_alias)
        return format("(%s and %s[%q])", row, row, col_name), nil
    end

    self:assert(valid_identifier(col_name),
        "Not a valid column identifier: %q", col_name)

    -- Reuse the variable already allocated for this (alias, column) pair
    local refs = self.column_refs
    local found
    for name, ref in pairs(refs) do
        if ref.col_name == col_name and ref.table_alias == table_alias then
            found = name
            break
        end
    end

    if not found then
        found = self:sym("col_%s_%s", table_alias, col_name)
        refs[found] = {table_alias = table_alias, col_name = col_name}
    end

    return format("%s[rowid_%s]", found, table_alias), found
end

-- Declared type of a single column: (db_name, table, column) -> type
local COLUMN_TYPE_SQL = [[
    SELECT type FROM schema
    WHERE db_name = ? AND table = ? AND column = ?
    LIMIT 1
]]
local get_column_type = sq.bootstrap_statement(COLUMN_TYPE_SQL,
    squill.RETURN_SINGLE_VALUE)

-- First (column, type) row of a table; the ORDER BY puts columns whose
-- type equals the third parameter first (callers pass sq.BLOBS)
local LENGTH_COLUMN_SQL = [[
    SELECT column, type FROM schema WHERE db_name = ? AND table = ?
    ORDER BY type = ? DESC
]]
local find_column_for_length = sq.bootstrap_statement(LENGTH_COLUMN_SQL,
    squill.RETURN_FIRST_ROW)

-- Emits, via code:putf, the local declarations a compiled statement needs:
-- one per unique index reference, one per nested SELECT (whose query code is
-- registered as a separate function) and one per referenced column. Also
-- fills self.lengths, mapping every table alias to a Lua expression giving
-- that table's length. May only be called once per Env, and get_col_ref
-- cannot be used afterwards.
--
-- code:             code buffer (code:putf / code:register_func are used)
-- readonly:         when truthy, every column is fetched with get_column's
--                   final argument set to true
-- modified_columns: optional table keyed by column name; when given, columns
--                   it does not contain are also fetched with that argument
--                   set to true
--
-- User errors (unknown table alias, unknown column, nonexistent table) are
-- raised with error level 0 so the message carries no position in this file.
function Env:insert_var_refs(code, readonly, modified_columns)
    assert(not self.var_refs_inserted)
    self.var_refs_inserted = true

    -- Table alias -> Lua expression evaluating to that table's length
    self.lengths = {}

    for var_name, var in pairs(self.unique_index_refs) do
        assert(var.allow_cache ~= nil)
        local col_type = get_column_type(self.db_name, var.table, var.col_name)
        self:assert(col_type, "Column or table not found: %s.%s",
            var.table, var.col_name)
        code:putf("local %s = get_unique_index(%q, %q, %q, %q, %s)",
            var_name, self.db_name, var.table, var.col_name, col_type,
            var.allow_cache and "true" or "false")
    end

    -- Nested SELECTs are compiled into their own functions and called once
    -- up front; their length is the length of the returned row list
    for alias, nested in pairs(self.nested_selects) do
        local func_name = self:sym("nested_select_func")
        code:register_func(func_name, nested.code)
        code:putf("local %s = %s(%s)", nested.sym, func_name,
            sq.get_param_arg_list(self.parser))
        self.lengths[alias] = "#" .. nested.sym
    end

    for var_name, var in pairs(self.column_refs) do
        local table_name = self.table_lookup[var.table_alias]
        if not table_name then
            error(format(
                "Reference to table that isn't in the query: %q", var.table_alias
            ), 0)
        end

        -- Any referenced column of a table can supply that table's length
        if not self.lengths[var.table_alias] then
            self.lengths[var.table_alias] = format("%s.length", var_name)
        end

        local col_type
        if sq.bootstrap then
            -- Use hardcoded column types while compiling built-in queries
            assert(self.db_name == "squill")
            col_type = assert(sq.bootstrap_col_types[table_name][
                table.indexof(sq.bootstrap_col_names[table_name], var.col_name)
            ])
        else
            col_type = get_column_type(self.db_name, table_name, var.col_name)
            if not col_type then
                error(format("Non-existent column: %q", var.col_name), 0)
            end
        end

        -- get_column's final argument appears to be a read-only flag (it is
        -- also true for the length-only lookups below): it is set when the
        -- whole statement is read-only, or when this column is not one of
        -- the columns being modified.
        local col_readonly = readonly or
            (modified_columns and not modified_columns[var.col_name])
        code:putf("local %s = get_column(%q, %q, %q, %q, %s)",
            var_name, self.db_name, table_name, var.col_name,
            col_type, col_readonly and "true" or "false")
    end

    -- Edge case: fill in lengths for tables that have no referenced columns,
    -- and throw an error for tables that do not exist (no schema rows)
    for alias, table_name in pairs(self.table_lookup) do
        if not self.lengths[alias] then
            -- Prefer to use blob columns to get the size since the entire
            -- column doesn't have to be read
            -- The bootstrap SQL statements don't do this, so it doesn't have
            -- to account for them.
            -- TODO: Maybe store the size separately?
            local column = find_column_for_length(self.db_name, table_name,
                sq.BLOBS)
            if not column then
                error(format("Table %q does not exist", table_name), 0)
            end

            local var_name = format("length_of_%s", alias)
            self.lengths[alias] = var_name
            code:putf(
                "local %s = get_column(%q, %q, %q, %q, true).length",
                var_name, self.db_name, table_name, column.column, column.type
            )
        end
    end
end

-- Emits a set_column(...) call for every referenced column whose name is a
-- key of modified_columns, writing the column variable back.
--
-- code:             code buffer (code:putf is used)
-- modified_columns: table keyed by column name. It holds different types of
--                   values; only their truthiness is relied upon here.
-- index_updated:    optional value emitted verbatim (via %s) as
--                   set_column's final argument, "false" when omitted.
--                   Presumably a Lua expression string; verify with callers.
--
-- Raises through self:error when writing to the internal "squill" database
-- unless sq.allow_modifying_internal_db is set.
function Env:add_set_columns(code, modified_columns, index_updated)
    if self.db_name == "squill" and not sq.allow_modifying_internal_db then
        -- Please don't set sq.allow_modifying_internal_db unless you know
        -- what you're doing or don't care about your data
        self:error("The 'squill' database is read-only to prevent " ..
            "accidental data corruption.")
    end

    for col_var, col in pairs(self.column_refs) do
        if modified_columns[col.col_name] then
            -- Resolve the alias to the real table name, just like
            -- insert_var_refs does before calling get_column. Fall back to
            -- the alias itself when it is not in table_lookup.
            local table_name = self.table_lookup[col.table_alias] or
                col.table_alias
            code:putf(
                "set_column(%q, %q, %q, %s, %s)",
                self.db_name,
                table_name, col.col_name, col_var,
                index_updated or "false"
            )
        end
    end
end
