-- Local aliases for the string library functions used in this file.
local byte, find, format = string.byte, string.find, string.format
-- Byte values of the characters the tokeniser compares against. The names are
-- assigned in the same order as the characters of the string literal below.
local EQ, APOSTROPHE, LT, GT, NOT, ZERO, NINE, DOT, A_UPPER, Z_UPPER, A_LOWER,
    Z_LOWER, UNDERSCORE, DOLLARS, HYPHEN, ASTERISK, SLASH, PIPE, COLON =
    byte("='<>!09.AZaz_$-*/|:", 1, -1)

-- Directory of the "squill" mod; used further down to load commands/*.lua.
local modpath = core.get_modpath("squill")

-- Internal table shared between the mod's files, and the table that maps a
-- database name to its current schema version.
local sq = squill._internal
local schema_vers = sq.schema_vers

--- Tests whether a byte value is an ASCII decimal digit.
-- @param symbol Byte value, e.g. as returned by string.byte (must not be nil)
-- @return true when symbol lies between the bytes for "0" and "9" inclusive
local function is_digit(symbol)
    if symbol < ZERO then
        return false
    end
    return symbol <= NINE
end
-- Also exported on the shared internal table.
sq.is_digit = is_digit

--- Tests whether a byte value may appear in an identifier.
-- Accepts ASCII letters, digits and the underscore.
-- @param symbol Byte value, or nil (e.g. when reading past the end of input)
-- @return true for an identifier character, false otherwise (including nil)
local function is_variable_name_char(symbol)
    if not symbol then
        return false
    end

    if symbol == UNDERSCORE then
        return true
    end
    if symbol >= A_LOWER and symbol <= Z_LOWER then
        return true
    end
    if symbol >= A_UPPER and symbol <= Z_UPPER then
        return true
    end
    return is_digit(symbol)
end

-- Set (byte value -> true) of the characters that the tokeniser accepts as
-- single-character tokens when no other rule matches.
local allowed_chars = {}
do
    local punctuation = ".,()<>=+-*/^%?;"
    for i = 1, #punctuation do
        allowed_chars[byte(punctuation, i)] = true
    end
end

-- Prototype table for parser objects. Fields set here (such as the initial
-- param_count of 0) are inherited by every instance through the metatable.
local Parser = {
    param_count = 0,
}
sq.Parser = Parser

-- Metatable given to instances so that missing fields fall back to Parser.
local parser_mt = {
    __index = Parser,
}

--- Creates a parser for one SQL string.
--
-- The tokeniser state (scan position and a one-token lookahead buffer) is kept
-- in upvalues of the closures created here, so every parser is independent.
--
-- The returned object has these fields in addition to the Parser methods:
--   db_name          the database name that was passed in
--   next(optional)   consumes and returns the next token
--   peek()           returns the next token without consuming it
--   save_pos()       returns a function yielding the source text consumed
--                    since save_pos() was called
--   _code            the SQL source text
--   _get_last_idxs   returns the start and end index of the last token scanned
--
-- @param db_name Database name; only A-Z, a-z, 0-9, "_" and ":" are allowed
-- @param code The SQL source text (required)
-- @return The new parser object
function Parser.new(db_name, code)
    assert(code, "No SQL code provided")
    if find(db_name, "[^A-Za-z0-9_:]") then
        error(format("Invalid database name: %q", db_name))
    end

    local code_length = #code
    -- Start index of the most recently scanned token (for error messages)
    local last_start_idx = 1
    -- Index at which scanning for the next token resumes
    local last_idx = 1
    -- Lookahead buffer filled by peek(), and the value last_idx had just
    -- before the buffered token was scanned
    local peeked_token, peek_last_idx

    -- Consumes and returns the next token as a string. Identifiers and
    -- keywords are lowercased; string literals keep their quotes; comments
    -- are skipped. At the end of the input this raises an error, unless
    -- `optional` is truthy, in which case nothing is returned.
    local function next_token(self, optional)
        if peeked_token then
            local res = peeked_token
            peeked_token = nil
            peek_last_idx = nil
            return res
        end

        local start_idx, end_idx = find(code, "%S", last_idx)
        if not start_idx then
            if not optional then
                error("Error parsing SQL statement: Unexpected end of file", 0)
            end
            return
        end

        local symbol = byte(code, start_idx)
        if symbol == APOSTROPHE then
            -- Two apostrophes in a row are escaped
            local idx = start_idx + 1
            repeat
                end_idx = assert(find(code, "'", idx, true), "Unclosed string literal")
                idx = end_idx + 2
            until byte(code, end_idx + 1) ~= APOSTROPHE
        elseif (
            (symbol == EQ and byte(code, start_idx + 1) == EQ) or
            (symbol == LT and byte(code, start_idx + 1) == GT) or
            (symbol == LT and byte(code, start_idx + 1) == EQ) or
            (symbol == GT and byte(code, start_idx + 1) == EQ) or
            (symbol == NOT and byte(code, start_idx + 1) == EQ) or
            (symbol == PIPE and byte(code, start_idx + 1) == PIPE)
        ) then
            -- Two-character operators: == <> <= >= != ||
            end_idx = end_idx + 1
        elseif symbol == HYPHEN and byte(code, start_idx + 1) == HYPHEN then
            -- Comment: Skip to the end of the line and get a new token.
            -- `optional` is passed through so that a required token which is
            -- missing after a trailing comment is still reported as an error.
            last_idx = (find(code, "\n", start_idx + 2, true) or code_length) + 1
            return next_token(self, optional)
        elseif symbol == SLASH and byte(code, start_idx + 1) == ASTERISK then
            -- Block comment (see the line comment case for `optional`)
            local _, comment_end = find(code, "*/", start_idx + 2, true)
            self:assert(comment_end, "Unclosed block comment")
            last_idx = comment_end + 1
            return next_token(self, optional)
        elseif is_digit(symbol) or ((symbol == DOT or symbol == DOLLARS) and
                is_digit(byte(code, start_idx + 1) or 0)) then
            -- Find numbers (matches letters too so that "1a" is an error)
            end_idx = (find(code, "[^0-9%.A-Za-z]", start_idx + 1) or code_length + 1) - 1
        elseif is_variable_name_char(symbol) then
            while is_variable_name_char(byte(code, end_idx + 1)) do
                end_idx = end_idx + 1
            end

            -- Casefold variable names
            last_start_idx = start_idx
            last_idx = end_idx + 1
            return code:sub(start_idx, end_idx):lower()
        elseif symbol == COLON and is_variable_name_char(byte(code, start_idx + 1)) then
            -- Named parameters
            repeat
                end_idx = end_idx + 1
            until not is_variable_name_char(byte(code, end_idx + 1))

            local param_name = code:sub(start_idx + 1, end_idx)
            -- Update the position first so that the error points at this token
            last_start_idx = start_idx
            last_idx = end_idx + 1
            self:assert(param_name:lower() == param_name,
                "Named parameters must be lowercase")
        elseif symbol == HYPHEN and byte(code, start_idx + 1) == GT then
            -- JSON operators
            end_idx = end_idx + 1
            if byte(code, start_idx + 2) == GT then
                end_idx = end_idx + 1
            end
        elseif not allowed_chars[symbol] then
            error("Syntax error: Unrecognised character: " ..
                code:sub(start_idx, end_idx), 0)
        end

        last_start_idx = start_idx
        last_idx = end_idx + 1
        return code:sub(start_idx, end_idx)
    end

    -- Returns the next token without consuming it, or nil at the end of the
    -- input. The token is buffered and handed out by the next next_token call.
    local function peek(self)
        if peeked_token then return peeked_token end

        peek_last_idx = last_idx
        local token = next_token(self, true)
        peeked_token = token
        return token
    end

    -- Remembers the position of the first non-whitespace character that has
    -- not been consumed yet. The returned function gives the source text from
    -- that position up to the end of the last token consumed when it is called.
    local function save_pos()
        local pos = find(code, "%S", peek_last_idx or last_idx) or code_length + 1
        return function()
            return code:sub(pos, (peek_last_idx or last_idx) - 1)
        end
    end

    local parser = setmetatable({
        db_name = db_name,
        next = next_token,
        _code = code,
        _get_last_idxs = function() return last_start_idx, last_idx - 1 end,
        save_pos = save_pos,
        peek = peek,
    }, parser_mt)

    parser:update_schema_ver()
    return parser
end

--- Refreshes self.schema_ver from the shared schema version table.
-- Databases without a recorded version are treated as version 0.
function Parser:update_schema_ver()
    local current = schema_vers[self.db_name]
    if not current then
        current = 0
    end
    self.schema_ver = current
end

--- Like the built-in assert, but failures go through Parser:error so that the
-- message carries the position of the last token.
-- @param cond Value to test
-- @param msg Error message; treated as a format string when extra arguments
--            are supplied
-- @param ... Optional values for string.format
-- @return cond
function Parser:assert(cond, msg, ...)
    if cond then
        return cond
    end

    if select("#", ...) ~= 0 then
        msg = format(msg, ...)
    end
    self:error(msg)
    return cond
end

--- Raises a syntax error annotated with the line, column and text of the last
-- token that was scanned.
-- @param msg Description of the problem (defaults to "Invalid syntax")
function Parser:error(msg)
    local source = self._code
    local first, last = self:_get_last_idxs()

    -- Everything in front of the token determines its line and column
    local prefix = source:sub(1, first - 1)
    local newline_count = select(2, prefix:gsub("\n", "\n"))
    local line_tail = prefix:match("[^\n]*$")

    local message = format("On line %d column %d (near %q): %s",
        newline_count + 1, #line_tail + 1, source:sub(first, last),
        msg or "Invalid syntax")
    -- Level 0: do not prepend the position of this Lua file
    error(message, 0)
end

-- Set (word -> true) of tokens that may not be used as identifiers.
local reserved_words = {}
for _, word in ipairs({
    "from", "join", "where", "order", "group", "using", "with", "limit", ";",
    "having", "transaction", "by", "as", "on", "all", "distinct",
    "current_time", "current_date", "current_timestamp", "primary", "default",

    -- Set operations
    "union", "intersect", "except",

    -- JOIN keywords
    "left", "right", "full", "outer", "inner", "cross",
}) do
    reserved_words[word] = true
end
sq.reserved_words = reserved_words

-- Maps special SQL names to the text stored for them in this table.
sq.special_vars = {
    null = "nil",
    ["true"] = "true",
    ["false"] = "false",
    current_time = "os_date('!%H:%M:%S')",
    current_date = "os_date('!%Y-%m-%d')",
    current_timestamp = "os_date('!%Y-%m-%d %H:%M:%S')",
}

-- Special variable names and operator names are reserved as well
for name in pairs(sq.special_vars) do
    reserved_words[name] = true
end
for operator in pairs(sq.operators) do
    reserved_words[operator] = true
end

--- Checks whether a name can be used as an identifier.
-- @param name The (already lowercased) name to check
-- @return true if name is not reserved and consists of lowercase letters,
--         digits and underscores without a leading digit; false otherwise
local function valid_identifier(name)
    if reserved_words[name] then
        return false
    end
    return find(name, "^[a-z_][a-z0-9_]*$") ~= nil
end
sq.valid_identifier = valid_identifier

-- Table of command name -> value returned by commands/<name>.lua.
local cmds = {}

-- sq.cmds is only set while the command files are being loaded
-- (presumably so that they can refer to each other; verify in commands/*.lua).
sq.cmds = cmds
do
    local command_names = {
        "select", "insert", "update", "delete", "create",
        "drop", "tx", "alter", "explain", "pragma",
    }
    for i = 1, #command_names do
        local name = command_names[i]
        cmds[name] = dofile(modpath .. "/commands/" .. name .. ".lua")
    end
end
sq.cmds = nil

-- Command names cannot be used as identifiers
for name in pairs(cmds) do
    reserved_words[name] = true
end

--- Consumes the next token only if it equals the given value.
-- @param value The token text to look for
-- @return true if the token matched and was consumed, false otherwise
function Parser:pop_if_equals(value)
    local matched = self:peek() == value
    if matched then
        self:next()
    end
    return matched
end

--- Consumes the next token and raises a syntax error unless it equals `value`.
-- @param value The expected token text (shown in upper case in the message)
function Parser:expect(value)
    local token = self:next()
    if token ~= value then
        -- tostring() keeps the message well-formed even if no token was
        -- returned; "%q" rejects nil on Lua 5.1 / LuaJIT.
        -- Parser:error only takes a message, so no error level is passed.
        self:error(format("Expected %q, not %q", value:upper(), tostring(token)))
    end
end

-- Compiled statements: query_cache[db_name][sql] = cache entry
local query_cache = {}

--- Stores a compiled statement in the query cache, keyed by database name and
-- the exact SQL text. Statements longer than 4096 bytes are not cached.
-- @param func The compiled statement
-- @param returned_columns The column list that belongs to func
function Parser:add_to_query_cache(func, returned_columns)
    local sql = self._code
    if #sql > 4096 then
        -- This is probably a once-off query like importing a table
        return
    end

    local db_cache = query_cache[self.db_name]
    if not db_cache then
        db_cache = {}
        query_cache[self.db_name] = db_cache
    end

    db_cache[sql] = {
        func = func,
        returned_columns = returned_columns,
        schema_ver = self.schema_ver,
        last_used = core.get_us_time(),
    }
end

--- Discards every cached statement for every database.
-- The cache table is emptied in place, so existing references stay valid.
function squill.drop_query_cache()
    for db_name in next, query_cache do
        query_cache[db_name] = nil
    end
end

-- Remove old entries from the query cache

-- Number of seconds an unused entry may stay in the cache. The
-- "squill.aggressive_caching" setting keeps entries ten times longer.
local query_cache_ttl = 60
if core.settings:get_bool("squill.aggressive_caching") then
    query_cache_ttl = query_cache_ttl * 10
end

--- Evicts cache entries that were not used for query_cache_ttl seconds or that
-- were compiled against an outdated schema version, drops per-database tables
-- that became empty, and reschedules itself.
local function clean_up_query_cache()
    -- core.get_us_time() is in microseconds
    local min_age = core.get_us_time() - query_cache_ttl * 1e6

    for db_name, db_cache in pairs(query_cache) do
        local schema_ver = schema_vers[db_name] or 0
        for sql, cache in pairs(db_cache) do
            if cache.last_used < min_age or cache.schema_ver ~= schema_ver then
                -- Entries are keyed by their SQL text. Clearing an existing
                -- field while traversing with pairs() is permitted.
                db_cache[sql] = nil
            end
        end

        if next(db_cache) == nil then
            query_cache[db_name] = nil
        end
    end
    core.after(query_cache_ttl / 2, clean_up_query_cache)
end
core.after(query_cache_ttl / 2, clean_up_query_cache)

--- Prepares a single query.
-- Parses exactly one statement (with an optional trailing ";"), stores the
-- result in the query cache and returns it without executing it.
-- @return The compiled statement function
-- @return The list of returned columns (an empty table if there are none)
function Parser:parse_query()
    local handler = self:assert(cmds[self:next()], "Unsupported command")
    local func, returned_columns = handler(self)

    self:pop_if_equals(";")
    local trailing = self:peek()
    if trailing then
        self:error(format("Unexpected token: %q", trailing))
    end

    if not returned_columns then
        returned_columns = {}
    end
    self:add_to_query_cache(func, returned_columns)
    return func, returned_columns
end

--- Executes one or more queries separated by ";".
-- Each statement is parsed and then run immediately with the given
-- arguments, so later statements see schema changes made by earlier ones.
-- @param ... Arguments passed to every compiled statement
-- @return A list with one rows table per statement (empty table if the
--         statement returned no rows value)
-- @return A list with one returned-columns table per statement
function Parser:exec_multiple(...)
    local all_rows = {}
    local all_returned_columns = {}

    local old_schema_ver = self.schema_ver
    local func, returned_columns
    repeat
        local cmd = self:next()
        self:assert(cmds[cmd], "Unsupported command: %q", cmd)
        func, returned_columns = cmds[cmd](self)
        -- Normalise once so that the value returned here and the value
        -- stored in the query cache below are the same table
        returned_columns = returned_columns or {}
        local rows = func(...)
        all_rows[#all_rows + 1] = rows or {}
        all_returned_columns[#all_returned_columns + 1] = returned_columns
        self:update_schema_ver()
    until not self:pop_if_equals(";") or not self:peek()

    if self:peek() then
        self:error(format("Unexpected token: %q", self:peek()))
    end

    if #all_returned_columns == 1 and self.schema_ver == old_schema_ver then
        -- There was only one statement and it didn't touch the schema, cache it
        -- Caching multiple statements probably isn't worth the effort
        self:add_to_query_cache(func, returned_columns)
    end

    return all_rows, all_returned_columns
end

--- Looks up a compiled statement in the query cache.
-- An entry is only used if it was compiled against the database's current
-- schema version; a hit also refreshes the entry's last_used timestamp.
-- @param db_name Database name
-- @param sql The exact SQL text
-- @return The cached function and its returned columns, or nothing on a miss
local function get_cached_stmt(db_name, sql)
    local db_cache = query_cache[db_name]
    local entry = db_cache and db_cache[sql]
    if not entry then
        return
    end
    if entry.schema_ver ~= (schema_vers[db_name] or 0) then
        return
    end

    entry.last_used = core.get_us_time()
    return entry.func, entry.returned_columns
end

sq.ERR_NOT_ONE_COLUMN = "The SQL statement does not return exactly one column"

--- Compiles a single SQL statement (or fetches it from the query cache) and
-- wraps it according to return_mode.
-- @param db_name Database name
-- @param sql SQL text containing exactly one statement
-- @param return_mode One of squill.RETURN_ALL_ROWS, squill.RETURN_FIRST_ROW,
--        squill.RETURN_SINGLE_COLUMN or squill.RETURN_SINGLE_VALUE; any other
--        value raises an error. The two SINGLE modes also raise unless the
--        statement returns exactly one column.
-- @return A function that executes the statement with the arguments given
-- @return The list of returned columns
function squill.prepare_statement(db_name, sql, return_mode)
    local func, returned_columns = get_cached_stmt(db_name, sql)
    if not func then
        func, returned_columns = Parser.new(db_name, sql):parse_query()
    end
    -- A cache entry may have been stored without a column list; callers
    -- always get a table, as they do on the uncached path
    returned_columns = returned_columns or {}

    -- Just transform the return output instead of creating it as desired for
    -- now to reduce the number of things to account for in commands/*.lua.
    if return_mode == squill.RETURN_FIRST_ROW then
        return function(...)
            -- The statement may return no rows value at all
            local rows = func(...)
            return rows and rows[1]
        end, returned_columns
    elseif return_mode == squill.RETURN_SINGLE_COLUMN then
        assert(#returned_columns == 1, sq.ERR_NOT_ONE_COLUMN)
        local col = returned_columns[1]
        return function(...)
            local rows = func(...)
            -- Replace every row with the value of its only column (in place)
            for i = 1, #rows do
                rows[i] = rows[i][col]
            end
            return rows
        end, returned_columns
    elseif return_mode == squill.RETURN_SINGLE_VALUE then
        assert(#returned_columns == 1, sq.ERR_NOT_ONE_COLUMN)
        local col = returned_columns[1]
        return function(...)
            local rows = func(...)
            return rows[1] and rows[1][col]
        end, returned_columns
    elseif return_mode ~= squill.RETURN_ALL_ROWS then
        error("Invalid value for return_mode parameter", 2)
    end

    return func, returned_columns
end

--- Executes one or more SQL statements.
-- @param db_name Database name
-- @param sql SQL text; several statements may be separated by ";"
-- @param ... Arguments passed to the statements
-- @return A list with one rows table per statement
-- @return A list with one returned-columns table per statement
function squill.exec(db_name, sql, ...)
    local func, returned_columns = get_cached_stmt(db_name, sql)
    if func then
        -- Mirror the shape produced by Parser:exec_multiple: exactly one
        -- entry per statement, with an empty table standing in for nil
        local rows = func(...)
        return {rows or {}}, {returned_columns or {}}
    end

    return Parser.new(db_name, sql):exec_multiple(...)
end
