--
-- Squill: Database engine for Luanti (formerly known as Minetest)
--
-- Copyright © 2025 by luk3yx
--
-- This program is free software: you can redistribute it and/or modify
-- it under the terms of the GNU Lesser General Public License as published by
-- the Free Software Foundation, either version 2.1 of the License, or
-- (at your option) any later version.
--
-- This program is distributed in the hope that it will be useful,
-- but WITHOUT ANY WARRANTY; without even the implied warranty of
-- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-- GNU Lesser General Public License for more details.
--
-- You should have received a copy of the GNU Lesser General Public License
-- along with this program.  If not, see <https://www.gnu.org/licenses/>.
--

-- Global API table
squill = {}

-- Options for "return_mode" parameter
-- (RETURN_ALL_ROWS is nil, so omitting return_mode returns all rows)
squill.RETURN_ALL_ROWS = nil
squill.RETURN_FIRST_ROW = 1
squill.RETURN_SINGLE_COLUMN = 2
squill.RETURN_SINGLE_VALUE = 3

-- Internal namespace. It is only exposed as squill._internal while the files
-- below are loaded and is removed again afterwards; this file keeps using it
-- through the `sq` local.
-- NOTE(review): presumably the dofile'd files pick it up from
-- squill._internal - confirm.
local sq = {}
squill._internal = sq

local modpath = core.get_modpath("squill")
dofile(modpath .. "/storage.lua")
dofile(modpath .. "/schema.lua")
dofile(modpath .. "/parser.lua")
dofile(modpath .. "/parse_expr.lua")
dofile(modpath .. "/joins.lua")
dofile(modpath .. "/update_row.lua")

-- sq.bootstrap is defined by one of the files loaded above.
-- NOTE(review): presumably it sets up Squill's own "squill" database, which
-- is queried further down in this file - confirm against its definition.
sq.bootstrap()

-- Loaded only after bootstrap() has run
dofile(modpath .. "/dump.lua")
squill._internal = nil

-- Compatibility wrappers that emulate other database APIs. Each one lives in
-- compat/<name>.lua and is only loaded on first access, then cached.
local compat_apis = {pgmoon = true, lsqlite3 = true, luasql = true}
squill.compat = setmetatable({}, {
    __index = function(compat, name)
        -- Unknown names simply yield nil
        if not compat_apis[name] then
            return
        end

        local api = dofile(modpath .. "/compat/" .. name .. ".lua")
        -- Cache the loaded module on the table itself
        compat[name] = api
        return api
    end,
})

-- Connection API

-- Registered backends. Each entry maps a backend name to a function called as
-- backend(db_name, check_existing). That function returns a connection object,
-- or nothing when check_existing is set and the database does not exist.
local backends = {}

-- Opens db_name with one of supported_backends (default: {"squill"}).
-- A backend that already holds this database wins. Otherwise the first
-- registered backend in the list is used to create it.
function squill.connect(db_name, supported_backends)
    local candidates = supported_backends or {"squill"}

    -- First pass: look for a backend that already has this database
    for _, name in ipairs(candidates) do
        local open_backend = backends[name]
        local existing = open_backend and open_backend(db_name, true)
        if existing then
            return existing
        end
    end

    -- Second pass: create it with the first backend that is available
    for _, name in ipairs(candidates) do
        local open_backend = backends[name]
        if open_backend then
            return open_backend(db_name)
        end
    end

    error("No supported backends found")
end

-- Basic wrapper for native Squill "connections", so that they expose the same
-- prepare/exec/blob methods and dbms field as the other backends
do
    local SquillConnection = {dbms = "squill"}
    SquillConnection.__index = SquillConnection

    -- Prepares sql against this connection's database
    function SquillConnection:prepare(sql, return_mode)
        local db_name = self._db_name
        return squill.prepare_statement(db_name, sql, return_mode)
    end

    function SquillConnection:exec(sql, ...)
        -- Discard return values to match other backends
        squill.exec(self._db_name, sql, ...)
    end

    -- Native Squill takes blobs as plain strings, so this only validates
    function SquillConnection:blob(str)
        assert(type(str) == "string")
        return str
    end

    -- Prepared once at load time: true if any row of Squill's "schema" table
    -- has the given db_name
    local db_exists = squill.prepare_statement("squill", [[
        SELECT COUNT(*) > 0 FROM schema WHERE db_name = ?
    ]], squill.RETURN_SINGLE_VALUE)

    function backends.squill(db_name, check_existing)
        if check_existing and not db_exists(db_name) then
            return
        end
        return setmetatable({_db_name = db_name}, SquillConnection)
    end
end

do
    -- Optional GUI, enabled through the "squill.enable_gui" setting
    local gui_enabled = core.settings:get_bool("squill.enable_gui")
    if gui_enabled then
        dofile(modpath .. "/gui.lua")
    end

    -- Tests, only loaded when an enabled `mtt` global exists
    -- (presumably the mtt testing mod)
    local run_tests = core.global_exists("mtt") and mtt.enabled
    if run_tests then
        dofile(modpath .. "/test.lua")
    end
end

-- SQLite backend below. It must live in init.lua: code in a separate file
-- would have to be loaded via core.get_modpath(), and checking that function
-- hasn't been tampered with would mean guessing with the debug library.
if not core.settings:get_bool("squill.enable_sqlite_ffi_backend") then
    -- The rest of this file is the optional SQLite backend; stop here
    return
end

-- The backend needs the insecure environment (for FFI). It is only granted
-- when Squill is listed in secure.trusted_mods, hence the error below.
local ie = core.request_insecure_environment()
if not ie then
    error([[
        squill.enable_sqlite_ffi_backend is true but Squill is not in
        secure.trusted_mods.

        If you don't know what this error means, disable the "Enable SQLite FFI
        backend" setting for the "Squill" mod in Luanti's settings.
    ]])
end

-- The below code must take extra care to be secure and not do anything like
-- using the string metatable

-- The regular io.open from the shared global environment (not ie.io.open).
-- Other mods can modify the global environment, so this is only used where
-- security doesn't matter. It is captured before the luacheck directives
-- below mark `io` and the other globals as off-limits.
local insecure_io_open = io.open

-- Make luacheck complain about any accidental global access
-- luacheck: not_globals _G _VERSION arg assert bit collectgarbage coroutine
-- luacheck: not_globals debug dofile error gcinfo getfenv getmetatable io
-- luacheck: not_globals ipairs jit load loadfile loadstring math module
-- luacheck: not_globals newproxy next os package pairs pcall print rawequal
-- luacheck: not_globals rawget rawlen rawset require select setfenv
-- luacheck: not_globals setmetatable string table tonumber tostring type
-- luacheck: not_globals unpack xpcall

-- Using FFI seems to be the least error-prone way of doing insecure
-- environment things, since it doesn't leak any global variables and lets us
-- access sqlite3_limit().
local ffi = ie.require("ffi")

-- Declarations for the subset of the SQLite3 C API used below. FFI calls use
-- these prototypes exactly as written, so they have to match sqlite3.h.
ffi.cdef[[
    typedef struct sqlite3 sqlite3;
    typedef struct sqlite3_stmt sqlite3_stmt;

    int sqlite3_open(const char *filename, sqlite3 **ppDb);
    int sqlite3_close_v2(sqlite3*);
    const char *sqlite3_errmsg(sqlite3*);

    int sqlite3_limit(sqlite3*, int id, int newVal);

    int sqlite3_exec(
        sqlite3*,
        const char *sql,
        int (*callback)(void*, int, char**, char**),
        void *,
        char **errmsg
    );

    int sqlite3_prepare_v2(sqlite3 *db, const char *zSql, int nByte,
        sqlite3_stmt **ppStmt, const char **pzTail);
    int sqlite3_column_count(sqlite3_stmt *pStmt);
    const char *sqlite3_column_name(sqlite3_stmt*, int N);
    int sqlite3_step(sqlite3_stmt*);
    int sqlite3_reset(sqlite3_stmt*);
    int sqlite3_clear_bindings(sqlite3_stmt*);
    int sqlite3_finalize(sqlite3_stmt*);

    int sqlite3_bind_parameter_count(sqlite3_stmt*);
    const char *sqlite3_bind_parameter_name(sqlite3_stmt*, int);
    int sqlite3_bind_text(sqlite3_stmt*, int, const void*, int n, void(*)(void*));
    int sqlite3_bind_blob(sqlite3_stmt*, int, const void*, int n, void(*)(void*));
    int sqlite3_bind_double(sqlite3_stmt*, int, double);
    int sqlite3_bind_null(sqlite3_stmt*, int);

    const void *sqlite3_column_blob(sqlite3_stmt*, int iCol);
    double sqlite3_column_double(sqlite3_stmt*, int iCol);
    int sqlite3_column_bytes(sqlite3_stmt*, int iCol);
    int sqlite3_column_type(sqlite3_stmt*, int iCol);

    int sqlite3_changes(sqlite3*);
]]

-- From here on, take the standard library functions from the insecure
-- environment instead of the shared (and modifiable) global environment
local assert = ie.assert
local debug = ie.debug
local error = ie.error
local pairs = ie.pairs
local select = ie.select
local setmetatable = ie.setmetatable
local string = ie.string
local tonumber = ie.tonumber
local type = ie.type

-- Fundamental datatype codes (as returned by sqlite3_column_type)
local SQLITE_INTEGER = 1
local SQLITE_FLOAT = 2
local SQLITE_TEXT = 3
local SQLITE_BLOB = 4
local SQLITE_NULL = 5

-- Result codes
local SQLITE_OK = 0
local SQLITE_ROW = 100
local SQLITE_DONE = 101

-- Limit categories for sqlite3_limit()
local SQLITE_LIMIT_LENGTH = 0
local SQLITE_LIMIT_SQL_LENGTH = 1
local SQLITE_LIMIT_COLUMN = 2
local SQLITE_LIMIT_EXPR_DEPTH = 3
local SQLITE_LIMIT_COMPOUND_SELECT = 4
local SQLITE_LIMIT_VDBE_OP = 5
local SQLITE_LIMIT_FUNCTION_ARG = 6
local SQLITE_LIMIT_ATTACHED = 7
local SQLITE_LIMIT_LIKE_PATTERN_LENGTH = 8
local SQLITE_LIMIT_VARIABLE_NUMBER = 9
local SQLITE_LIMIT_TRIGGER_DEPTH = 10

-- Destructor value telling SQLite to take its own copy of bound data
local SQLITE_TRANSIENT = ffi.cast("void(*)(void*)", -1)

-- Luanti has already loaded SQLite, so its symbols are reachable through
-- ffi.C and the library doesn't have to be loaded again, i.e. no need for:
-- local C = ffi.load("libsqlite3.so.0")
local C = ffi.C

-- Wrapper tables returned by SqliteDB:blob(), each mapped to the string it
-- holds. The keys are weak, so a wrapper (and its string) can be garbage
-- collected once the caller drops it.
-- Note: this must be __mode, not __index. A string __index would not make the
-- table weak, and failed lookups would go through the shared string
-- metatable, which this code must never rely on.
local blob_objs = setmetatable({}, {__mode = "k"})

-- Binds value to parameter idx (1-indexed) of stmt and returns SQLite's
-- result code. The Lua types map as follows:
--   nil    -> NULL
--   string -> TEXT
--   number -> REAL
--   boolean -> 1 / 0
--   SqliteDB:blob() wrapper -> BLOB
-- Anything else raises an error.
-- SQLITE_TRANSIENT makes SQLite copy the data, so the Lua string does not
-- have to outlive this call.
local function bind_value(stmt, idx, value)
    if value == nil then
        return C.sqlite3_bind_null(stmt, idx)
    end

    local val_type = type(value)
    if val_type == "string" then
        return C.sqlite3_bind_text(stmt, idx, value, string.len(value),
            SQLITE_TRANSIENT)
    elseif val_type == "number" then
        return C.sqlite3_bind_double(stmt, idx, value)
    elseif val_type == "boolean" then
        return C.sqlite3_bind_double(stmt, idx, value and 1 or 0)
    elseif val_type == "table" then
        -- Only tables created by SqliteDB:blob() are accepted
        local blob = blob_objs[value]
        if blob then
            return C.sqlite3_bind_blob(stmt, idx, blob, string.len(blob),
                SQLITE_TRANSIENT)
        end
    end

    error("Unsupported type passed to bind: " .. val_type)
end

-- Reads column idx (zero-indexed, as in the C API) of the current result row.
-- The SQLite types map as follows:
--   NULL -> nil
--   TEXT and BLOB -> Lua strings
--   INTEGER and FLOAT -> Lua numbers (via sqlite3_column_double)
-- Any other type code raises an error.
local function read_value(stmt, idx)
    local col_type = C.sqlite3_column_type(stmt, idx)

    if col_type == SQLITE_NULL then
        return nil
    end

    if col_type == SQLITE_TEXT or col_type == SQLITE_BLOB then
        -- Fetch the data pointer before asking for its size
        local data = C.sqlite3_column_blob(stmt, idx)
        local size = C.sqlite3_column_bytes(stmt, idx)
        return ffi.string(data, size)
    end

    if col_type == SQLITE_INTEGER or col_type == SQLITE_FLOAT then
        return C.sqlite3_column_double(stmt, idx)
    end

    error("Unknown SQL type? " .. col_type)
end

-- Raises SQLite's current error message for db (blaming the caller's caller)
-- unless value is SQLITE_OK
local function check(db, value)
    if value == SQLITE_OK then
        return
    end
    local message = ffi.string(C.sqlite3_errmsg(db))
    error(message, 2)
end

-- Private map from SqliteDB objects to their raw sqlite3* handles. Keeping
-- the handle here, rather than on the object, keeps it away from other mods.
-- The weak keys let unused connection objects be garbage collected.
local dbs = setmetatable({}, {__mode = "k"})

-- Close every connection that is still open when the server shuts down
local function close_all_databases()
    for _, handle in pairs(dbs) do
        -- Remove the GC finalizer first so the handle isn't closed twice
        ffi.gc(handle, nil)
        C.sqlite3_close_v2(handle)
    end
end
core.register_on_shutdown(close_all_databases)

-- Prototype for the connection objects returned by backends.sqlite. The raw
-- sqlite3* handle is not stored on the objects themselves (see dbs).
local SqliteDB = {dbms = "sqlite"}
SqliteDB.__index = SqliteDB

-- Imperfect check but better than not checking at all: if core.get_worldpath
-- had been replaced by a Lua function, debug.getinfo would report "Lua" here
-- instead of "C"
local get_worldpath = core.get_worldpath
assert(debug.getinfo(get_worldpath, "S").what == "C",
    "core.get_worldpath() has been tampered with!")
local worldpath = get_worldpath()

-- Another imperfect check in case the get_worldpath one was bypassed: the
-- directory must contain a world.mt file (opened through ie.io rather than
-- the global io table)
assert(ie.io.open(worldpath .. "/world.mt", "r")):close()
-- Directory inside the world where the SQLite database files are kept
local db_dir = worldpath .. "/squill-sqlite/"

-- Opens the SQLite database for db_name, creating it if needed, and returns
-- a SqliteDB connection object. If check_existing is true, returns nothing
-- when the database file doesn't exist yet (instead of creating it).
function backends.sqlite(db_name, check_existing)
    -- Validate before touching the filesystem. Only [A-Za-z0-9_:] is allowed,
    -- so the name cannot escape db_dir; ":" is stored as "-" in the filename.
    assert(not string.find(db_name, "[^A-Za-z0-9_:]"), "Invalid database name")
    core.mkdir(db_dir)
    local path = db_dir .. string.gsub(db_name, ":", "-") .. ".sqlite"

    if check_existing then
        -- No need to worry about security here
        local f = insecure_io_open(path, "rb")
        if not f then return end
        f:close()
    end

    local db_ptr = ffi.new("sqlite3*[1]")
    local rc = C.sqlite3_open(path, db_ptr)
    local db = db_ptr[0]
    if rc ~= SQLITE_OK then
        -- sqlite3_open() may allocate a handle even when it fails, and that
        -- handle still has to be released. A NULL handle means SQLite could
        -- not allocate one; closing NULL is a harmless no-op.
        local msg = "out of memory"
        if db ~= nil then
            msg = ffi.string(C.sqlite3_errmsg(db))
        end
        C.sqlite3_close_v2(db)
        error("Could not open SQLite database: " .. msg)
    end

    -- Close the database automatically
    ffi.gc(db, C.sqlite3_close_v2)

    -- https://sqlite.org/security.html
    C.sqlite3_limit(db, SQLITE_LIMIT_LENGTH, 1000000)
    C.sqlite3_limit(db, SQLITE_LIMIT_SQL_LENGTH, 100000)
    C.sqlite3_limit(db, SQLITE_LIMIT_COLUMN, 100)
    C.sqlite3_limit(db, SQLITE_LIMIT_EXPR_DEPTH, 10)
    C.sqlite3_limit(db, SQLITE_LIMIT_COMPOUND_SELECT, 3)
    C.sqlite3_limit(db, SQLITE_LIMIT_VDBE_OP, 25000)
    C.sqlite3_limit(db, SQLITE_LIMIT_FUNCTION_ARG, 8)
    C.sqlite3_limit(db, SQLITE_LIMIT_ATTACHED, 0)
    C.sqlite3_limit(db, SQLITE_LIMIT_LIKE_PATTERN_LENGTH, 50)
    C.sqlite3_limit(db, SQLITE_LIMIT_VARIABLE_NUMBER, 16)
    C.sqlite3_limit(db, SQLITE_LIMIT_TRIGGER_DEPTH, 10)

    -- Store the database and return a wrapped connection object
    local res = setmetatable({}, SqliteDB)
    dbs[res] = db

    -- Sensible default (since this is not security sensitive it doesn't have
    -- to use the C API directly)
    res:exec("PRAGMA foreign_keys = ON")

    return res
end

-- Wraps str in an opaque table, so that passing it as a bind parameter binds
-- a BLOB instead of TEXT. The string itself is kept in blob_objs.
function SqliteDB:blob(str)
    assert(type(str) == "string")
    local wrapper = {}
    blob_objs[wrapper] = str
    return wrapper
end

-- Compiles sql (exactly one statement) and returns two values:
--  * a function that binds its arguments, runs the statement and returns the
--    result in the shape selected by return_mode (a squill.RETURN_* value)
--  * a 1-indexed list of the result column names. For INSERT/UPDATE/DELETE
--    it also has affected_rows = true.
-- Bind parameters can be written in one of two styles, which must not be
-- mixed:
--  * "?" or "$N": positional, passed as separate arguments
--  * ":name": named, passed as a single table
function SqliteDB:prepare(sql, return_mode)
    assert(type(sql) == "string")

    local db = dbs[self]
    local stmt_ptr = ffi.new("sqlite3_stmt*[1]")
    local tail_ptr = ffi.new("const char*[1]")
    check(db, C.sqlite3_prepare_v2(db, sql, #sql, stmt_ptr, tail_ptr))

    -- sqlite3_prepare_v2() yields a NULL statement when sql contains no
    -- statement at all (e.g. it is empty or only has comments). Not every
    -- sqlite3_* statement function used below accepts NULL, so refuse it.
    local stmt = stmt_ptr[0]
    if stmt == nil then
        error("Empty SQL statements are not supported by prepare()")
    end

    -- Call sqlite3_finalize if the statement gets garbage collected
    ffi.gc(stmt, C.sqlite3_finalize)

    -- Checking whether there is more code after this SQL statement is not
    -- security sensitive, so it can just use Squill's parser
    local tail = ffi.string(tail_ptr[0])
    if sq.Parser.new("", tail):peek() then
        error("Multiple SQL statements are not supported by prepare()")
    end

    -- Read column names into Lua early to avoid having to do it every time
    local column_count = C.sqlite3_column_count(stmt)
    local columns = {}
    for i = 0, column_count - 1 do
        columns[i] = ffi.string(C.sqlite3_column_name(stmt, i))
    end

    -- Reset the prepared statement back to its initial state
    local function cleanup()
        check(db, C.sqlite3_reset(stmt))
        check(db, C.sqlite3_clear_bindings(stmt))
    end

    -- Advances the statement by one step. Returns true when a row is
    -- available and false once the statement has finished. On an error, the
    -- message is copied into a Lua string and the statement is reset before
    -- raising. Otherwise sqlite3_reset() would report the same error again
    -- the next time this statement is run.
    local function step()
        local status = C.sqlite3_step(stmt)
        if status == SQLITE_ROW then
            return true
        elseif status == SQLITE_DONE then
            return false
        end

        local msg = ffi.string(C.sqlite3_errmsg(db))
        C.sqlite3_reset(stmt)
        C.sqlite3_clear_bindings(stmt)
        error(msg, 0)
    end

    -- Returns the next row as a {column_name = value} table, or nothing once
    -- there are no more rows
    local function read_row()
        if not step() then
            return
        end

        local row = {}
        for i = 0, column_count - 1 do
            row[columns[i]] = read_value(stmt, i)
        end
        return row
    end

    -- Figure out what we have to bind: bind_map maps each SQL parameter index
    -- to a Lua argument position ("?" / "$N") or table key (":name")
    local bind_map = {}
    local param_type
    for i = 1, C.sqlite3_bind_parameter_count(stmt) do
        local name_ptr = C.sqlite3_bind_parameter_name(stmt, i)
        local is_nameless = name_ptr == nil

        local first_char
        if is_nameless then
            first_char = "?"
            bind_map[i] = i
        else
            -- Support $1, $2, etc as well
            local name = ffi.string(name_ptr)
            first_char = string.sub(name, 1, 1)
            if first_char == "$" then
                bind_map[i] = assert(tonumber(string.sub(name, 2)))
                assert(bind_map[i] >= 1 and bind_map[i] % 1 == 0)
            elseif first_char == ":" then
                bind_map[i] = string.sub(name, 2)
            else
                error("Unsupported bind parameter")
            end
        end

        param_type = param_type or first_char
        assert(param_type == first_char,
            "Cannot mix different types of bind parameters")
    end

    -- Bind all values to the prepared statement
    local bind
    if param_type == "?" or param_type == "$" then
        -- Positional parameters
        function bind(...)
            -- In case the statement was run and errored
            cleanup()

            for sql_idx, lua_idx in pairs(bind_map) do
                check(db, bind_value(stmt, sql_idx, (select(lua_idx, ...))))
            end
        end
    elseif param_type == ":" then
        -- Named parameters
        function bind(params)
            cleanup()

            for sql_idx, lua_idx in pairs(bind_map) do
                check(db, bind_value(stmt, sql_idx, params[lua_idx]))
            end
        end
    else
        assert(param_type == nil)

        -- No parameters
        bind = cleanup
    end

    -- Try and figure out whether we need to return affected_rows (this may
    -- as well use Squill's tokeniser since it's there, it shouldn't be
    -- security-sensitive)
    local cmd = sq.Parser.new("", sql):next()
    local affected_rows = cmd == "insert" or cmd == "update" or cmd == "delete"

    -- Make a 1-indexed copy of columns to return
    local returned_columns = {affected_rows = affected_rows or nil}
    for i = 1, column_count do
        returned_columns[i] = columns[i - 1]
    end

    if return_mode == squill.RETURN_ALL_ROWS then
        -- List of row tables (plus affected_rows for write statements)
        return function(...)
            bind(...)

            local res = {}
            repeat
                local row = read_row()
                res[#res + 1] = row
            until not row

            if affected_rows then
                res.affected_rows = C.sqlite3_changes(db)
            end

            cleanup()
            return res
        end, returned_columns
    elseif return_mode == squill.RETURN_FIRST_ROW then
        -- First row table, or nil if there are no rows
        return function(...)
            bind(...)
            local row = read_row()
            cleanup()
            return row
        end, returned_columns
    elseif return_mode == squill.RETURN_SINGLE_COLUMN then
        assert(column_count == 1, sq.ERR_NOT_ONE_COLUMN)
        -- List of the values of the only column. An explicit counter is used
        -- so that NULL values keep their position as holes.
        return function(...)
            bind(...)
            local res = {}
            local idx = 1
            while step() do
                res[idx] = read_value(stmt, 0)
                idx = idx + 1
            end
            cleanup()
            return res
        end, returned_columns
    elseif return_mode == squill.RETURN_SINGLE_VALUE then
        assert(column_count == 1, sq.ERR_NOT_ONE_COLUMN)
        -- Value of the first row's only column, or nil if there are no rows
        return function(...)
            bind(...)
            local res
            if step() then
                res = read_value(stmt, 0)
            end
            -- Reset in both cases, like the other return modes
            cleanup()
            return res
        end, returned_columns
    else
        error("Unsupported value for return_mode")
    end
end

-- Runs sql and discards any results. The two call forms behave differently:
--  * Without extra arguments, the whole string goes to sqlite3_exec().
--  * With arguments, sql is prepared as a single statement and run with
--    them bound as parameters.
function SqliteDB:exec(sql, ...)
    assert(type(sql) == "string")

    if select("#", ...) == 0 then
        local db = dbs[self]
        check(db, C.sqlite3_exec(db, sql, nil, nil, nil))
        return
    end

    -- Just use the prepare API
    local run = self:prepare(sql)
    run(...)
end
