local byte, find, format = string.byte, string.find, string.format
local abs = math.abs
local EQ, APOSTROPHE, LT, GT, NOT, ZERO, NINE, DOT, A_UPPER, Z_UPPER,
    A_LOWER, Z_LOWER, UNDERSCORE, DOLLARS, HYPHEN = byte("='<>!09.AZaz_$-", 1, -1)
local type = type

local modpath = core.get_modpath("squill")

local sq = squill._internal

--- Return true if `symbol` (a byte value) is an ASCII digit 0-9.
local function is_digit(symbol)
    return ZERO <= symbol and symbol <= NINE
end
sq.is_digit = is_digit

--- Return true if `symbol` (a byte value, or nil when reading past the end of
-- the input) may appear in an identifier: an ASCII letter, digit or
-- underscore.
local function is_variable_name_char(symbol)
    if not symbol then
        return false
    end

    if symbol == UNDERSCORE or is_digit(symbol) then
        return true
    end
    return (A_LOWER <= symbol and symbol <= Z_LOWER) or
        (A_UPPER <= symbol and symbol <= Z_UPPER)
end

-- Single-character punctuation tokens accepted by the tokeniser, keyed by
-- byte value
local allowed_chars = {}
do
    local punctuation = ".,()<>=+-*/^%?;"
    for i = 1, #punctuation do
        allowed_chars[byte(punctuation, i)] = true
    end
end

-- Current schema version per database name. Compiled statements embed the
-- version they were built against and refuse to run once it differs.
local schema_vers = {}
sq.schema_vers = schema_vers

-- Parser "class"; instances are created by Parser.new. param_count is the
-- number of statement parameters (create_function_def emits one argument per
-- parameter) and defaults to 0.
local Parser = {param_count = 0}
local parser_mt = {__index = Parser}

sq.Parser = Parser

--- Create a parser (tokeniser plus shared parsing state) for `code`.
-- Identifiers are lowercased, string literals keep their apostrophes and
-- `--` comments are skipped. Unrecognised characters raise a syntax error.
-- @param db_name Database name; only letters, digits, `_` and `:` are allowed.
-- @param code SQL source text.
-- @return A new parser object whose methods come from `Parser`.
function Parser.new(db_name, code)
    assert(code, "No SQL code provided")
    if find(db_name, "[^A-Za-z0-9_:]") then
        error(format("Invalid database name: %q", db_name))
    end

    local code_length = #code
    -- Start index of the last token read (used to locate parse errors)
    local last_start_idx = 1
    -- Index of the first character after the last token read
    local last_idx = 1
    -- Token buffered by peek(), and the value last_idx had before reading it
    local peeked_token, peek_last_idx

    -- Return the next token. At end of input this returns nil when
    -- `optional` is truthy and raises an error otherwise.
    local function next_token(self, optional)
        if peeked_token then
            local res = peeked_token
            peeked_token = nil
            peek_last_idx = nil
            return res
        end

        local start_idx, end_idx = find(code, "%S", last_idx)
        if not start_idx then
            if not optional then
                error("Error parsing SQL statement: Unexpected end of file", 0)
            end
            return
        end

        local symbol = byte(code, start_idx)
        if symbol == APOSTROPHE then
            -- String literal; two apostrophes in a row are an escaped quote
            local idx = start_idx + 1
            repeat
                end_idx = assert(find(code, "'", idx, true), "Unclosed string literal")
                idx = end_idx + 2
            until byte(code, end_idx + 1) ~= APOSTROPHE
        elseif (
            (symbol == EQ and byte(code, start_idx + 1) == EQ) or
            (symbol == LT and byte(code, start_idx + 1) == GT) or
            (symbol == LT and byte(code, start_idx + 1) == EQ) or
            (symbol == GT and byte(code, start_idx + 1) == EQ) or
            (symbol == NOT and byte(code, start_idx + 1) == EQ)
        ) then
            -- Two-character comparison operators: == <> <= >= !=
            end_idx = end_idx + 1
        elseif symbol == HYPHEN and byte(code, start_idx + 1) == HYPHEN then
            -- Comment: skip to the end of the line and read the following
            -- token. `optional` is forwarded so that a trailing comment does
            -- not turn a required token into a silent nil.
            last_idx = (find(code, "\n", start_idx + 2, true) or code_length) + 1
            return next_token(self, optional)
        elseif is_digit(symbol) or ((symbol == DOT or symbol == DOLLARS) and
                is_digit(byte(code, start_idx + 1) or 0)) then
            -- Numbers, plus $N parameter placeholders
            end_idx = (find(code, "[^0-9%.e]", start_idx + 1) or code_length + 1) - 1
        elseif is_variable_name_char(symbol) then
            while is_variable_name_char(byte(code, end_idx + 1)) do
                end_idx = end_idx + 1
            end

            -- Record where this token starts so Parser:error reports the
            -- identifier itself rather than an earlier token
            last_start_idx = start_idx
            last_idx = end_idx + 1
            -- Casefold identifiers and keywords
            return code:sub(start_idx, end_idx):lower()
        elseif symbol == HYPHEN and byte(code, start_idx + 1) == GT then
            -- JSON operators: -> and ->>
            end_idx = end_idx + 1
            if byte(code, start_idx + 2) == GT then
                end_idx = end_idx + 1
            end
        elseif not allowed_chars[symbol] then
            error("Syntax error: Unrecognised character: " ..
                code:sub(start_idx, end_idx), 0)
        end

        last_start_idx = start_idx
        last_idx = end_idx + 1
        return code:sub(start_idx, end_idx)
    end

    -- Return the next token without consuming it (nil at end of input)
    local function peek(self)
        if peeked_token then return peeked_token end

        peek_last_idx = last_idx
        local token = next_token(self, true)
        peeked_token = token
        return token
    end

    -- Snapshot the current position. The returned function yields the
    -- source text consumed since the snapshot; a token that has only been
    -- peeked counts as not yet consumed.
    local function save_pos()
        local pos = find(code, "%S", peek_last_idx or last_idx) or code_length + 1
        return function()
            return code:sub(pos, (peek_last_idx or last_idx) - 1)
        end
    end

    local parser = setmetatable({
        db_name = db_name,
        next = next_token,
        _code = code,
        _get_last_idxs = function() return last_start_idx, last_idx - 1 end,
        save_pos = save_pos,
        peek = peek,
    }, parser_mt)

    parser:update_schema_ver()
    return parser
end

--- Reset the per-schema state of this parser: the recorded schema version
-- for its database (0 if none is recorded), plus the column-reference,
-- unique-index-reference and column-list caches.
function Parser:update_schema_ver()
    self.schema_ver = schema_vers[self.db_name] or 0
    self._col_cache = {}
    self.unique_index_refs = {}
    self.column_refs = {}
end

--- Raise a positioned parse error (via Parser:error) unless `cond` is truthy.
-- When extra arguments are given, `msg` is used as a format string for them.
-- @return `cond` unchanged.
function Parser:assert(cond, msg, ...)
    if not cond then
        local message = msg
        if select("#", ...) ~= 0 then
            message = format(msg, ...)
        end
        self:error(message)
    end
    return cond
end

--- Raise a parse error annotated with the line, column and text of the most
-- recently read token. Uses "Invalid syntax" when `msg` is nil.
function Parser:error(msg)
    local first, last = self:_get_last_idxs()
    local code = self._code
    local preceding = code:sub(1, first - 1)

    -- Line = number of newlines before the token, plus one
    local newline_count = select(2, preceding:gsub("\n", "\n"))
    -- Column = characters since the last newline, plus one
    local column = #preceding:match("[^\n]*$") + 1

    error(format("On line %d column %d (near %q): %s", newline_count + 1,
        column, code:sub(first, last), msg or "Invalid syntax"), 0)
end

--- Runtime implementations of SQL operators, exposed to compiled statements
-- as `OPS`. SQL NULL is represented by nil; operators return nil when a
-- required operand is NULL.
local operators = {}

-- Equality follows SQL three-valued logic: if either side is NULL the result
-- is NULL (not FALSE), matching the ordering operators below.
operators["="] = function(a, b)
    if a == nil or b == nil then return nil end
    return a == b
end
operators["=="] = operators["="]

operators["!="] = function(a, b)
    if a == nil or b == nil then return nil end
    return a ~= b
end
operators["<>"] = operators["!="]

-- Booleans compare as 1 (TRUE) / 0 (FALSE); other values are unchanged
local function coerce_to_comparable(n)
    if n == true then
        return 1
    elseif n == false then
        return 0
    else
        return n
    end
end

operators[">"] = function(a, b)
    a, b = coerce_to_comparable(a), coerce_to_comparable(b)
    return a and b and a > b
end

operators["<"] = function(a, b)
    a, b = coerce_to_comparable(a), coerce_to_comparable(b)
    return a and b and a < b
end

operators[">="] = function(a, b)
    a, b = coerce_to_comparable(a), coerce_to_comparable(b)
    return a and b and a >= b
end

operators["<="] = function(a, b)
    a, b = coerce_to_comparable(a), coerce_to_comparable(b)
    return a and b and a <= b
end

-- a BETWEEN b AND c (inclusive); NULL if any operand is NULL
operators["between"] = function(a, b, c)
    a, b, c = coerce_to_comparable(a), coerce_to_comparable(b), coerce_to_comparable(c)
    return a and b and a >= b and c and a <= c
end

-- Convert booleans and numeric strings to numbers; anything else gives nil
local function coerce_to_number(n)
    return tonumber(coerce_to_comparable(n))
end

-- Arithmetic: operands are coerced to numbers, and NULL/non-numeric operands
-- make the result NULL
operators["unary +"] = coerce_to_number
operators["unary -"] = function(n)
    n = coerce_to_number(n)
    return n and -n
end
operators["^"] = function(a, b)
    a, b = coerce_to_number(a), coerce_to_number(b)
    return a and b and a ^ b
end
operators["+"] = function(a, b)
    a, b = coerce_to_number(a), coerce_to_number(b)
    return a and b and a + b
end
operators["-"] = function(a, b)
    a, b = coerce_to_number(a), coerce_to_number(b)
    return a and b and a - b
end
operators["*"] = function(a, b)
    a, b = coerce_to_number(a), coerce_to_number(b)
    return a and b and a * b
end
operators["/"] = function(a, b)
    a, b = coerce_to_number(a), coerce_to_number(b)
    return a and b and a / b
end

--- Convert a value to its SQL string form.
-- Strings are returned unchanged, numbers are formatted with "%.17g" and
-- booleans become "TRUE"/"FALSE". Any other value (including nil) yields
-- nothing.
local function coerce_to_string(value)
    local kind = type(value)
    if kind == "number" then
        return format("%.17g", value)
    end
    if kind == "boolean" then
        if value then
            return "TRUE"
        end
        return "FALSE"
    end
    if kind == "string" then
        return value
    end
end

--- Translate an SQL LIKE pattern into an anchored, case-insensitive Lua
-- pattern: `%` becomes `.-`, `_` becomes `.`, other Lua magic characters are
-- escaped, and ASCII letters become `[Xx]` sets. A leading/trailing `.-` is
-- dropped instead of anchoring that end.
-- @param pattern The LIKE pattern (a string).
-- @param already_compiled If truthy, `pattern` is returned unchanged.
local function compile_like_pattern(pattern, already_compiled)
    if already_compiled then
        return pattern
    end

    local translated = pattern:gsub("[%(%)%.%%%+%-%*%?%[%^%$_]", function(c)
        if c == "%" then return ".-" end
        if c == "_" then return "." end
        return "%" .. c
    end)

    translated = translated:gsub("%a", function(letter)
        return format("[%s%s]", letter:upper(), letter:lower())
    end)

    if translated:sub(1, 2) == ".-" then
        translated = translated:sub(3)
    else
        translated = "^" .. translated
    end

    if translated:sub(-2) == ".-" then
        translated = translated:sub(1, -3)
    else
        translated = translated .. "$"
    end

    return translated
end
sq.compile_like_pattern = compile_like_pattern

-- str LIKE pattern. `compiled` signals that `pattern` has already been passed
-- through compile_like_pattern. NULL if either operand is NULL.
operators["like"] = function(str, pattern, compiled)
    local s, p = coerce_to_string(str), coerce_to_string(pattern)
    if s == nil or p == nil then
        return nil
    end
    local lua_pattern = compile_like_pattern(p, compiled)
    return find(s, lua_pattern) ~= nil
end

-- str NOT LIKE pattern; same conventions as LIKE
operators["not like"] = function(str, pattern, compiled)
    local s, p = coerce_to_string(str), coerce_to_string(pattern)
    if s == nil or p == nil then
        return nil
    end
    local lua_pattern = compile_like_pattern(p, compiled)
    return find(s, lua_pattern) == nil
end

-- Values with a fixed boolean meaning
local boolean_vals = {}
for _, v in ipairs({true, "true", "yes", "on", "1", 1}) do
    boolean_vals[v] = true
end
for _, v in ipairs({false, "false", "no", "off", "0", 0}) do
    boolean_vals[v] = false
end

--- Convert a value to a boolean using `boolean_vals`; any other number is
-- TRUE, and everything else (including nil) yields nil.
local function coerce_to_boolean(v)
    local known = boolean_vals[v]
    if known ~= nil then
        return known
    end
    if type(v) == "number" then
        return true
    end
    return nil
end
sq.coerce_to_boolean = coerce_to_boolean

operators["to_boolean"] = coerce_to_boolean

-- NOT v; NULL stays NULL
operators["unary not"] = function(v)
    local b = coerce_to_boolean(v)
    if b == nil then
        return nil
    end
    return not b
end

-- OR is split into a left and a right helper that share state, so that
-- "NULL OR FALSE" yields NULL rather than FALSE. or_left remembers whether
-- the left operand was NULL; or_right turns a falsy right operand into NULL
-- in that case.
local or_is_null
operators["or_left"] = function(v)
    local b = coerce_to_boolean(v)
    or_is_null = (b == nil)
    return b
end

operators["or_right"] = function(v)
    local b = coerce_to_boolean(v)
    if not b and or_is_null then
        return nil
    end
    return b
end

-- json.lua returns the implementations of "->" and "->>", in that order
operators["->"], operators["->>"] = dofile(modpath .. "/json.lua")

-- Words that may not be used as identifiers. Special variable names and
-- operator names are added below.
local reserved_words = {}
for _, word in ipairs({
    "from", "join", "where", "order", "group", "using", "with", "limit", ";",
    "having", "transaction", "by", "as", "on", "all", "distinct",
    "current_time", "current_date", "current_timestamp", "primary",
}) do
    reserved_words[word] = true
end
sq.reserved_words = reserved_words

-- Lua source expressions for SQL keywords that denote values (os_date is
-- provided by the statement environment)
sq.special_vars = {
    null = "nil",
    ["true"] = "true",
    ["false"] = "false",
    current_time = "os_date('!%H:%M:%S')",
    current_date = "os_date('!%Y-%m-%d')",
    current_timestamp = "os_date('!%Y-%m-%d %H:%M:%S')",
}
for name in pairs(sq.special_vars) do
    reserved_words[name] = true
end
for name in pairs(operators) do
    reserved_words[name] = true
end

--- Return true if `name` is a lowercase identifier that is not reserved.
local function valid_identifier(name)
    if reserved_words[name] then
        return false
    end
    return find(name, "^[a-z_][a-z0-9_]*$") ~= nil
end
sq.valid_identifier = valid_identifier

--- Register a reference to column `col_name` of table alias `tbl`.
-- @return The Lua expression reading that column for the current row
--   (`<key>[rowid_<tbl>]`), and the key under which the reference was stored
--   in self.column_refs.
function Parser:get_col_ref(tbl, col_name)
    self:assert(valid_identifier(col_name), "Not a valid column identifier: %q", col_name)

    local key = format("col%d_%s_%s", #tbl, tbl, col_name)
    self.column_refs[key] = {table = tbl, col_name = col_name}

    local lua_expr = format("%s[rowid_%s]", key, tbl)
    return lua_expr, key
end

-- get_column_type(db_name, table, column): the stored type of one column
-- from the internal "schema" table, or nil when no such row exists.
local get_column_type = sq.bootstrap_statement([[
    SELECT type FROM schema
    WHERE db_name = ? AND table = ? AND column = ?
    LIMIT 1
]], squill.RETURN_SINGLE_VALUE)

-- find_column_for_length(db_name, table, preferred_type): the first
-- {column = ..., type = ...} row for the table, with columns of
-- `preferred_type` ordered first; nil when the table has no schema rows.
local find_column_for_length = sq.bootstrap_statement([[
    SELECT column, type FROM schema WHERE db_name = ? AND table = ?
    ORDER BY type = ? DESC
]], squill.RETURN_FIRST_ROW)

--- Append `local` declarations for every referenced unique index and column
-- to `code`. Also fill self.lengths with a Lua expression giving the length
-- of each table alias in self.table_lookup.
-- @param code List of generated Lua source lines; appended to in place.
-- @param readonly If truthy, every column is fetched with get_column's final
--   argument set to true.
-- @param modified_columns Optional set of column names written by the
--   statement; columns not in it are fetched with that argument set to true.
function Parser:insert_var_refs(code, readonly, modified_columns)
    -- Look up attributes
    self.lengths = {}

    -- One `unique_<i>` local per unique index lookup
    for i, var in pairs(self.unique_index_refs) do
        assert(var.allow_cache ~= nil)
        local col_type = get_column_type(self.db_name, var.table, var.col_name)
        self:assert(col_type, "Column or table not found: %s.%s", var.table, var.col_name)
        code[#code + 1] = format("local unique_%d = get_unique_index(%q, %q, %q, %q, %s)",
            i, self.db_name, var.table, var.col_name, col_type,
            var.allow_cache and "true" or "false")
    end

    for var_name, var in pairs(self.column_refs) do
        -- References without a table are resolved through expr_to_lua.
        -- NOTE(review): if expr_to_lua adds new keys to self.column_refs
        -- (e.g. via get_col_ref) while this pairs() loop runs, the Lua manual
        -- says traversal behaviour is undefined -- confirm.
        if not var.table then
            local new_key, table_name_as = self:expr_to_lua(var_name)
            if not self.column_refs[new_key] then
                var_name = new_key
                var.table = table_name_as
            end
        end

        if var.table then
            local table_name = self.table_lookup[var.table]
            if not table_name then
                error(format(
                    "Reference to table that isn't in the query: %q", var.table
                ))
            end
            -- The first column variable seen for an alias supplies its length
            if not self.lengths[var.table] then
                self.lengths[var.table] = format("%s.length", var_name)
            end

            local col_type
            if sq.bootstrap then
                -- Use hardcoded column types while compiling built-in queries
                assert(self.db_name == "squill")
                col_type = assert(sq.bootstrap_col_types[table_name][
                    table.indexof(sq.bootstrap_col_names[table_name], var.col_name)
                ])
            else
                col_type = get_column_type(self.db_name, table_name, var.col_name)
                if not col_type then
                    error(format("Non-existent column: %q", var.col_name))
                end
            end

            -- NOTE(review): despite its name, `modified` is true when the
            -- statement is read-only or does NOT write this column; it is
            -- emitted as get_column's final argument.
            local modified = readonly or
                (modified_columns and not modified_columns[var.col_name])
            code[#code + 1] = format("local %s = get_column(%q, %q, %q, %q, %s)",
                var_name, self.db_name, table_name, var.col_name,
                col_type, modified and "true" or "false")
        end
    end

    -- Edge case: Fill in lengths for tables that aren't included, and throw an
    -- error on empty tables
    for alias, table_name in pairs(self.table_lookup) do
        if not self.lengths[alias] then
            -- Prefer to use blob columns to get the size since the entire
            -- column doesn't have to be read
            -- The bootstrap SQL statements don't do this, so it doesn't have
            -- to account for them.
            -- TODO: Maybe store the size separately?
            local column = find_column_for_length(self.db_name, table_name,
                sq.BLOBS)
            if not column then
                error(format("Table %q does not exist", table_name), 0)
            end

            local var_name = format("length_of_%s", alias)
            self.lengths[alias] = var_name
            code[#code + 1] = format(
                "local %s = get_column(%q, %q, %q, %q, true).length",
                var_name, self.db_name, table_name, column.column, column.type
            )
        end
    end
end

-- Database of the statement currently running (set by begin_stmt), so that
-- runtime errors can roll back its transaction
local current_db_name

--- Wrap a coercion function for use when storing values into columns.
-- nil (NULL) passes through untouched. Any other value that `coerce` maps to
-- nil rolls back the current transaction and raises an error naming the
-- column `qualname`.
local function make_coerce_helper(coerce)
    return function(value, qualname)
        if value ~= nil then
            local coerced = coerce(value)
            if coerced == nil then
                sq.rollback_transaction(current_db_name, true)
                error(format("Cannot store value %s (of type %s) in column %s",
                    dump(value), type(value), qualname), 3)
            end
            return coerced
        end
        return nil
    end
end

-- Coercion helper shared by STRING and BLOB columns
local str_coerce_or_error = make_coerce_helper(coerce_to_string)

-- Environment that compiled statements run in (see Parser:compile_lua).
-- Generated code can only see the names defined here; a metatable set further
-- down makes reading or writing any other global raise an error.
local stmt_env = {
    -- SQL operator implementations
    OPS = operators,

    -- Make assert and error automatically roll back transactions
    -- Level 3 attributes the error to the caller of the function that
    -- invoked the wrapper.
    assert = function(value, msg)
        if not value then
            sq.rollback_transaction(current_db_name, true)
            error(msg, 3)
        end
        return value
    end,
    error = function(msg)
        sq.rollback_transaction(current_db_name, true)
        error(msg, 3)
    end,

    -- ipairs is not included to prevent its accidental usage over column.length
    -- (since ipairs cannot handle null values)
    pairs = pairs,
    table_sort = table.sort,
    get_column = sq.get_column,
    set_column = sq.set_column,
    get_unique_index = sq.get_unique_index,
    -- Emitted at the start of every compiled statement (see
    -- create_function_def). Rejects statements compiled against an older
    -- schema version and records the active database name.
    begin_stmt = function(db_name, ver)
        if (schema_vers[db_name] or 0) ~= ver then
            sq.rollback_transaction(db_name, true)
            error("Prepared statement is no longer valid, the database " ..
                "schema has been modified since it was created.", 3)
        end

        -- Store current database name for automatic rollbacks
        current_db_name = db_name
    end,
    coerce_to_string = coerce_to_string,
    -- Per-column-type coercions used when storing values; each rolls back
    -- and raises an error on values that cannot be stored
    coerce_or_error = {
        [sq.NUMBERS] = make_coerce_helper(coerce_to_number),
        -- Whole numbers only, with magnitude below 2^53 (the range in which
        -- doubles represent every integer exactly)
        [sq.INTEGERS] = make_coerce_helper(function(n)
            n = coerce_to_number(n)
            if n and n % 1 == 0 and abs(n) < 2^53 then
                return n
            end
        end),
        [sq.BOOLEANS] = make_coerce_helper(coerce_to_boolean),
        [sq.STRINGS] = str_coerce_or_error,
        [sq.BLOBS] = str_coerce_or_error,
    },
    -- autoincrement_set_min(db_name, table, column, value): raises the
    -- column's autoincrement counter to value + 1 if it is currently <= value
    autoincrement_set_min = sq.bootstrap_statement([[
        UPDATE schema SET autoincrement = $4 + 1
        WHERE db_name = $1 AND table = $2 AND column = $3 AND autoincrement <= $4
    ]]),

    -- SQL string functions; NULL and uncoercible values yield nil
    str_upper = function(s)
        s = coerce_to_string(s)
        return s and s:upper()
    end,
    str_lower = function(s)
        s = coerce_to_string(s)
        return s and s:lower()
    end,
    -- Returns the first non-nil argument, or nil if there is none
    coalesce = function(...)
        for i = 1, select("#", ...) do
            local arg = select(i, ...)
            if arg ~= nil then
                return arg
            end
        end
        return nil
    end,

    math = math,
    os_date = os.date,
}

-- local a stmt_env.print=print stmt_env.dump=dump

-- Command compilers, keyed by the first keyword of a statement. Each
-- commands/<name>.lua file is loaded with dofile and its return value stored
-- here; the parser later calls cmds[cmd](parser) to get the compiled function
-- and its returned columns.
-- sq.cmds is exposed only while the command files load (presumably so they
-- can refer to each other -- confirm) and is cleared afterwards. The load
-- order below is preserved deliberately.
local cmds = {}
sq.cmds = cmds
for _, cmd in ipairs({"select", "insert", "update", "delete", "create",
        "drop", "tx", "alter", "explain", "pragma"}) do
    cmds[cmd] = dofile(modpath .. "/commands/" .. cmd .. ".lua")
end
sq.cmds = nil

-- Command keywords can't be used as identifiers
for cmd in pairs(cmds) do
    reserved_words[cmd] = true
end

-- Catch mistakes in generated code: reading or assigning any global that is
-- not defined in stmt_env raises an error instead of silently using nil
setmetatable(stmt_env, {
    __index = function(_, name)
        error(format(
            "Attempt to get undefined variable %q in statement environment", name
        ))
    end,
    __newindex = function(_, name)
        error(format(
            "Attempt to create global variable %q in statement environment", name
        ))
    end
})

-- Lua source that copies every stmt_env entry into a local of the same name
-- ("local a, b = a, b"); prepended to every compiled statement
local stmt_localise
do
    local env_names = {}
    for name in pairs(stmt_env) do
        env_names[#env_names + 1] = name
    end
    local joined = table.concat(env_names, ", ")
    stmt_localise = "local " .. joined .. " = " .. joined
end

sq.PARAM_VAR_PREFIX = "param_"

--- Build the opening of a compiled statement's Lua source: the localised
-- environment, `return function(param_1, ..., param_n)` with one argument per
-- statement parameter, and a begin_stmt call carrying the database name and
-- schema version.
function Parser:create_function_def()
    local params = {}
    for n = 1, self.param_count do
        params[#params + 1] = sq.PARAM_VAR_PREFIX .. n
    end

    local header = format("%s\nreturn function(%s)\nbegin_stmt(%q, %d)",
        stmt_localise, table.concat(params, ", "), self.db_name, self.schema_ver)
    return header
end

--- Consume the next token only if it equals `value`.
-- @return true if the token was consumed, false otherwise.
function Parser:pop_if_equals(value)
    if self:peek() ~= value then
        return false
    end
    self:next()
    return true
end

--- Consume the next token and raise a positioned parse error unless it equals
-- `value`. The comparison is exact; note that next() lowercases identifiers.
-- @param value The expected token.
function Parser:expect(value)
    local token = self:next()
    if token ~= value then
        -- Parser:error takes only a message. A nil token is described
        -- explicitly so that format's %q never receives nil.
        local found = token == nil and "end of input" or format("%q", token)
        self:error(format("Expected %q, not %s", value:upper(), found))
    end
end

--- Append a set_column(...) call to `code` for every referenced column whose
-- name is a key of `modified_columns`.
-- @param code List of generated Lua source lines; appended to in place.
-- @param modified_columns Table keyed by column name; only truthiness of the
--   values is relied upon.
-- @param index_updated Lua source for set_column's last argument ("false" if
--   nil).
function Parser:add_set_columns(code, modified_columns, index_updated)
    -- Refuse to generate writes to the internal 'squill' database unless
    -- sq.allow_modifying_internal_db is set
    if self.db_name == "squill" and not sq.allow_modifying_internal_db then
        self:error("The 'squill' database is read-only to prevent " ..
            "accidental data corruption.")
    end

    local updated_flag = index_updated or "false"
    for col_var, col in pairs(self.column_refs) do
        if modified_columns[col.col_name] then
            code[#code + 1] = format("set_column(%q, %q, %q, %s, %s)", self.db_name,
                col.table, col.col_name, col_var, updated_flag)
        end
    end
end

local list_columns_stmt = sq.bootstrap_statement([[
    SELECT column FROM schema
    WHERE db_name = ? AND table = ?
    ORDER BY id
]], squill.RETURN_SINGLE_COLUMN)

--- Return the column names of `table_name`, ordered by schema id.
-- Results are cached per parser (the cache is reset by update_schema_ver).
-- While bootstrapping the internal "squill" database, the hardcoded column
-- lists are returned instead and are not cached.
function Parser:list_columns(table_name)
    local cached = self._col_cache[table_name]
    if cached then
        return cached
    end

    if sq.bootstrap and self.db_name == "squill" then
        return assert(sq.bootstrap_col_names[table_name])
    end

    local columns = list_columns_stmt(self.db_name, table_name)
    self._col_cache[table_name] = columns
    return columns
end

--- Compile the generated Lua source `lua` inside stmt_env and run the
-- resulting chunk with `...`, returning whatever it returns. The original
-- SQL is used as the chunk name, so it appears in Lua error messages.
function Parser:compile_lua(lua, ...)
    local chunk
    if not loadstring then
        -- Lua 5.2+: load() takes the environment directly
        chunk = assert(load(lua, self._code, "t", stmt_env))
    else
        -- Lua 5.1 / LuaJIT: attach the environment afterwards
        chunk = assert(loadstring(lua, self._code))
        setfenv(chunk, stmt_env)
    end
    return chunk(...)
end

-- Compiled statements: query_cache[db_name][sql] = {func, returned_columns,
-- schema_ver, last_used}
local query_cache = {}

--- Cache a compiled statement under this parser's SQL code. Code longer than
-- 4096 bytes is not cached.
function Parser:add_to_query_cache(func, returned_columns)
    local sql = self._code
    if #sql > 4096 then
        -- Probably a one-off query (e.g. importing a table); not worth caching
        return
    end

    local db_cache = query_cache[self.db_name]
    if not db_cache then
        db_cache = {}
        query_cache[self.db_name] = db_cache
    end
    db_cache[sql] = {
        func = func,
        returned_columns = returned_columns,
        schema_ver = self.schema_ver,
        last_used = core.get_us_time(),
    }
end

--- Remove every entry from the query cache (cleared in place).
function squill.drop_query_cache()
    local db_name = next(query_cache)
    while db_name ~= nil do
        query_cache[db_name] = nil
        db_name = next(query_cache)
    end
end

-- Remove old entries from the query cache.
-- Entries live for query_cache_ttl seconds after their last use (ten times
-- longer with the squill.aggressive_caching setting); the sweep runs every
-- query_cache_ttl / 2 seconds.
local query_cache_ttl = 60
if core.settings:get_bool("squill.aggressive_caching") then
    query_cache_ttl = query_cache_ttl * 10
end

--- Evict cached statements that were not used within query_cache_ttl seconds
-- or that were compiled against an outdated schema version, drop databases
-- whose cache becomes empty, then reschedule itself.
local function clean_up_query_cache()
    -- Entries last used before this timestamp (in microseconds) are stale
    local oldest_allowed = core.get_us_time() - query_cache_ttl * 1e6

    for db_name, db_cache in pairs(query_cache) do
        local is_empty = true
        local schema_ver = schema_vers[db_name] or 0
        -- Remove entries by their SQL key. Assigning nil to an existing field
        -- during pairs() traversal is permitted.
        for sql, cache in pairs(db_cache) do
            if cache.last_used < oldest_allowed or cache.schema_ver ~= schema_ver then
                db_cache[sql] = nil
            else
                is_empty = false
            end
        end

        if is_empty then
            query_cache[db_name] = nil
        end
    end
    core.after(query_cache_ttl / 2, clean_up_query_cache)
end
core.after(query_cache_ttl / 2, clean_up_query_cache)

--- Parse and compile exactly one statement (an optional trailing ";" is
-- allowed) and add it to the query cache.
-- @return The compiled statement function and its list of returned columns
--   (an empty table if the command returned none).
function Parser:parse_query()
    local cmd_name = self:next()
    local compile_cmd = self:assert(cmds[cmd_name], "Unsupported command")
    local func, columns = compile_cmd(self)
    self:pop_if_equals(";")

    local leftover = self:peek()
    if leftover then
        self:error(format("Unexpected token: %q", leftover))
    end

    columns = columns or {}
    self:add_to_query_cache(func, columns)
    return func, columns
end

--- Parse and immediately run one or more ";"-separated statements, passing
-- `...` as the parameters of each.
-- @return A list with each statement's rows and a parallel list with each
--   statement's returned column names; nil results become empty tables.
function Parser:exec_multiple(...)
    local all_rows = {}
    local all_returned_columns = {}

    local old_schema_ver = self.schema_ver
    local func, returned_columns
    repeat
        local cmd = self:next()
        -- self:assert (as in parse_query) so the error carries the position
        self:assert(cmds[cmd], "Unsupported command")
        func, returned_columns = cmds[cmd](self)
        -- Normalise before caching below, matching what parse_query caches;
        -- prepare_statement takes #returned_columns of cached entries
        returned_columns = returned_columns or {}

        local rows = func(...)
        all_rows[#all_rows + 1] = rows or {}
        all_returned_columns[#all_returned_columns + 1] = returned_columns
        -- The statement may have changed the schema; refresh per-schema state
        self:update_schema_ver()
    until not self:pop_if_equals(";") or not self:peek()

    local leftover = self:peek()
    if leftover then
        self:error(format("Unexpected token: %q", leftover))
    end

    if #all_returned_columns == 1 and self.schema_ver == old_schema_ver then
        -- There was only one statement and it didn't touch the schema, cache it
        -- Caching multiple statements probably isn't worth the effort
        self:add_to_query_cache(func, returned_columns)
    end

    return all_rows, all_returned_columns
end

--- Look up a compiled statement in the query cache. Entries compiled against
-- a different schema version are ignored; a hit refreshes last_used.
-- @return func, returned_columns on a hit; nothing otherwise.
local function get_cached_stmt(db_name, sql)
    local db_cache = query_cache[db_name]
    local entry = db_cache and db_cache[sql]
    if not entry or entry.schema_ver ~= (schema_vers[db_name] or 0) then
        return
    end

    entry.last_used = core.get_us_time()
    return entry.func, entry.returned_columns
end

--- Compile `sql` (or fetch it from the query cache) and return a callable.
-- Compiled statements return a list of rows; `return_mode` picks the shape:
--   squill.RETURN_ALL_ROWS      - the row list unchanged
--   squill.RETURN_FIRST_ROW     - only the first row (nil if there is none)
--   squill.RETURN_SINGLE_COLUMN - a list of the only column's values
--   squill.RETURN_SINGLE_VALUE  - the only column's value in the first row
-- @return The statement function and its list of returned column names.
function squill.prepare_statement(db_name, sql, return_mode)
    local func, returned_columns = get_cached_stmt(db_name, sql)
    if not func then
        func, returned_columns = Parser.new(db_name, sql):parse_query()
    end

    -- Reshaping the row-list output here keeps commands/*.lua from having to
    -- generate each output shape themselves.
    if return_mode == squill.RETURN_FIRST_ROW then
        local function first_row(...)
            return func(...)[1]
        end
        return first_row, returned_columns
    end

    if return_mode == squill.RETURN_SINGLE_COLUMN then
        assert(#returned_columns == 1,
            "The SQL statement does not return exactly one column")
        local only_col = returned_columns[1]
        local function column_values(...)
            local rows = func(...)
            for i = 1, #rows do
                rows[i] = rows[i][only_col]
            end
            return rows
        end
        return column_values, returned_columns
    end

    if return_mode == squill.RETURN_SINGLE_VALUE then
        assert(#returned_columns == 1,
            "The SQL statement does not return exactly one column")
        local only_col = returned_columns[1]
        local function single_value(...)
            local first = func(...)[1]
            return first and first[only_col]
        end
        return single_value, returned_columns
    end

    if return_mode ~= squill.RETURN_ALL_ROWS then
        error("Invalid value for return_mode parameter", 2)
    end
    return func, returned_columns
end

--- Execute one or more SQL statements with parameters `...`.
-- @return A list of row lists (one per statement) and a parallel list of
--   returned-column lists. The shape is the same whether or not the
--   statement came from the query cache (nil results become empty tables,
--   as in Parser:exec_multiple).
function squill.exec(db_name, sql, ...)
    local func, returned_columns = get_cached_stmt(db_name, sql)
    if func then
        return {func(...) or {}}, {returned_columns or {}}
    end

    return Parser.new(db_name, sql):exec_multiple(...)
end
