-- lsqlite3 compatibility layer
local unpack = table.unpack or unpack

-- Hack for implementing last_insert_rowid
-- Prepared squill query that returns one value per autoincrement primary-key
-- column of the given database, ordered by schema id. Stmt:step() takes a
-- snapshot of these values before running a statement flagged with
-- affected_rows, and DB:last_insert_rowid() diffs that snapshot against a fresh
-- one: the single counter that changed yields the inserted rowid.
-- Called as get_autoincrements(db_name). NOTE(review): the "- 1" suggests the
-- schema stores the *next* value to hand out, and RETURN_SINGLE_COLUMN
-- presumably yields a flat array of values — confirm against squill.
local get_autoincrements = squill.prepare_statement("squill", [[
    SELECT autoincrement - 1 FROM schema
    WHERE db_name = ? AND primary_key AND autoincrement IS NOT NULL
    ORDER BY id
]], squill.RETURN_SINGLE_COLUMN)

-- Result codes, numerically identical to SQLite's C API result codes.
local sqlite3 = {
    OK = 0,
    ERROR = 1,
    ABORT = 4,
    MISUSE = 21,
    ROW = 100,
    DONE = 101,
}

-- Metatables for connection objects (sqlite3.open) and prepared statements
-- (DB:prepare); each one serves as its own __index prototype.
local DB, Stmt = {}, {}
DB.__index, Stmt.__index = DB, Stmt

--- Opens a connection object for the squill database named by `path`.
-- The database name is the last path component with any trailing extension
-- removed (only lowercase letters, digits and underscores are recognised);
-- if the path does not match that shape it is used verbatim.
-- @param path file path or bare database name
-- @return a DB connection object
function sqlite3.open(path)
    local name = path:match("[/\\]([a-z0-9_]+)[%.a-z0-9_]-$")
    if name == nil then
        name = path
    end
    local conn = {_db_name = name}
    return setmetatable(conn, DB)
end

--- Returns the affected-row count recorded by the most recent statement
-- execution (see Stmt:step), or 0 when no count has been recorded.
function DB.changes(self)
    local count = self._affected_rows
    if not count then
        return 0
    end
    return count
end

--- Closes the connection by forgetting its database name; afterwards
-- isopen() reports false. Always returns sqlite3.OK.
function DB.close(self)
    self._db_name = nil
    local status = sqlite3.OK
    return status
end

--- Returns the connection's error code. Only error messages are tracked by
-- this layer, so the result is always sqlite3.ERROR, even after a success.
DB.errcode = function(_)
    return sqlite3.ERROR
end
DB.error_code = DB.errcode

--- Returns the most recent error recorded on this connection by exec(),
-- prepare() or Stmt:step(), or nil if none has been recorded.
DB.errmsg = function(self)
    local message = self._errmsg
    return message
end
DB.error_msg, DB.error_message = DB.errmsg, DB.errmsg

--- Executes the SQL text `sql`, optionally calling `func` once per result row.
-- Mirrors lsqlite3's db:exec: the callback is invoked as
-- `func(udata, column_count, values, names)` and must return 0 (or a numeric
-- string equal to 0) to continue; any other return value, including nil,
-- aborts the iteration.
-- Boolean column values are passed to the callback as 1/0, matching
-- Stmt:get_named_values(), because SQLite has no boolean type.
-- @param sql SQL text handed to squill.exec
-- @param func optional per-row callback
-- @param udata value passed through as the callback's first argument
-- @return sqlite3.OK, sqlite3.ERROR (message available via errmsg()),
--   or sqlite3.ABORT when the callback stops the iteration
function DB:exec(sql, func, udata)
    -- exec() takes no autoincrement snapshot, so any previous one is stale.
    self._old_rowids = nil

    assert(self:isopen())
    local ok, all_rows, all_cols = pcall(squill.exec, self._db_name, sql)
    if not ok then
        self._errmsg = all_rows
        return sqlite3.ERROR
    end

    if not func then return sqlite3.OK end

    -- all_rows[i] holds the rows of one result set and all_cols[i] the column
    -- names for it (presumably one entry per statement in `sql`).
    for all_rows_idx, rows in ipairs(all_rows) do
        local names = all_cols[all_rows_idx]

        for _, row in ipairs(rows) do
            local values = {}
            for i, col_name in ipairs(names) do
                local value = row[col_name]
                if value == true then
                    value = 1
                elseif value == false then
                    value = 0
                end
                values[i] = value
            end

            local r = func(udata, #names, values, names)
            if tonumber(r) ~= 0 then
                return sqlite3.ABORT
            end
        end
    end

    return sqlite3.OK
end
DB.execute = DB.exec

--- Accepted for lsqlite3 API compatibility; does nothing in this layer.
DB.interrupt = function(_) end

--- Reports whether the connection still has a database name, i.e. whether
-- close() has not been called.
function DB.isopen(self)
    return not (self._db_name == nil)
end

--- Prepares `sql` and returns an iterator yielding each row as a table keyed
-- by column name. Raises the prepare error if `sql` cannot be compiled.
function DB.nrows(self, sql)
    local prepared = self:prepare(sql)
    local failure = self:errmsg()
    return assert(prepared, failure):nrows()
end

--- Compiles `sql` into a prepared statement object.
-- @return the Stmt on success; on failure nil and sqlite3.ERROR, with the
--   compile error stored for errmsg()
function DB.prepare(self, sql)
    assert(self:isopen())
    local ok, compiled, columns =
        pcall(squill.prepare_statement, self._db_name, sql)
    if ok then
        local stmt = setmetatable(
            {_db = self, _func = compiled, _columns = columns},
            Stmt
        )
        -- reset() sets up the initial binding state and row cursor.
        stmt:reset()
        return stmt
    end
    self._errmsg = compiled
    return nil, sqlite3.ERROR
end

--- Prepares `sql` and returns an iterator yielding each row as an array of
-- column values. Raises the prepare error if `sql` cannot be compiled.
function DB.rows(self, sql)
    local prepared = self:prepare(sql)
    local failure = self:errmsg()
    return assert(prepared, failure):rows()
end

-- No separate running total is kept, so total_changes reports the same
-- count as changes().
DB.total_changes = DB.changes

--- Prepares `sql` and returns an iterator yielding each row's column values
-- as multiple return values. Raises the prepare error on failure.
function DB.urows(self, sql)
    local prepared = self:prepare(sql)
    local failure = self:errmsg()
    return assert(prepared, failure):urows()
end

--- Returns the rowid allocated by the most recent single-row statement.
-- Works by diffing the autoincrement snapshot that Stmt:step() stored before
-- running that statement against the current counters; exactly one counter
-- must have changed.
-- Raises an error when no snapshot is available (e.g. after exec() or a
-- statement that changed a number of rows other than one), when the set of
-- autoincrement columns changed, or when zero or several counters moved.
-- @return the new rowid
function DB:last_insert_rowid()
    local old_rowids = self._old_rowids
    assert(old_rowids, "last_insert_rowid() is only available after a "
        .. "prepared statement that changed exactly one row")

    local new_rowids = get_autoincrements(self._db_name)
    assert(#old_rowids == #new_rowids,
        "set of autoincrement columns changed since the last statement")

    local res
    for i, old_rowid in ipairs(old_rowids) do
        if new_rowids[i] ~= old_rowid then
            assert(res == nil, "more than one autoincrement counter changed; "
                .. "cannot determine last_insert_rowid()")
            res = new_rowids[i]
        end
    end
    assert(res, "no autoincrement counter changed since the last statement")
    return res
end

--- Binds `value` to the 1-based positional parameter `n`.
-- Only positive integer positions are accepted; the highest position bound
-- so far determines how many parameters are passed on execution.
function Stmt.bind(self, n, value)
    local is_position = type(n) == "number" and n >= 1 and n % 1 == 0
    assert(is_position)
    self._bind_values[n] = value
    if n > self._bind_count then
        self._bind_count = n
    end
    return sqlite3.OK
end

-- Blobs need no special handling, so bind_blob is plain bind.
Stmt.bind_blob = Stmt.bind

--- Returns the highest parameter position bound so far. This is the number
-- of parameters that will be passed, not a count parsed from the SQL text.
function Stmt.bind_parameter_count(self)
    local count = self._bind_count
    return count
end

--- Replaces all bindings with the given values, bound to positions 1..N.
-- Trailing nils are counted, so they are passed as explicit NULL parameters.
function Stmt.bind_values(self, ...)
    local count = select("#", ...)
    self._bind_values, self._bind_count = {...}, count
    return sqlite3.OK
end

--- Returns the number of columns in the statement's result set.
function Stmt.columns(self)
    local names = self._columns
    return #names
end

--- Releases the statement.
-- Drops the compiled function and cached results, so that isopen() reports
-- false and step() returns sqlite3.MISUSE afterwards, as lsqlite3 specifies
-- for finalized statements.
-- @return sqlite3.ERROR if the most recent execution failed, otherwise
--   sqlite3.OK
function Stmt:finalize()
    local code = self._errcode or sqlite3.OK
    self._func = nil
    self._rows = nil
    self._idx = nil
    return code
end

--- Returns the name of result column `n` (0-based, as in lsqlite3).
-- Raises an assertion error if the column does not exist.
function Stmt.get_name(self, n)
    local name = self._columns[n + 1]
    return assert(name)
end

--- Returns the current row as a table keyed by column name.
-- The cached row table itself is returned, with booleans rewritten in place
-- as 1/0 because SQLite has no boolean type. Raises an assertion error if
-- there is no current row.
function Stmt.get_named_values(self)
    local current = self._rows[self._idx]
    assert(current)
    for column, value in pairs(current) do
        if type(value) == "boolean" then
            current[column] = value and 1 or 0
        end
    end
    return current
end

--- Returns the array of result column names; this is the statement's own
-- table, not a copy.
function Stmt.get_names(self)
    local names = self._columns
    return names
end

--- Returns the result column names as multiple return values.
function Stmt.get_unames(self)
    local names = self._columns
    return unpack(names)
end

--- Returns the current row's values as multiple return values, one per
-- column (NULL columns come back as nil).
function Stmt.get_uvalues(self)
    local values = self:get_values()
    local width = #self._columns
    return unpack(values, 1, width)
end

--- Returns the value of column `n` (0-based) in the current row.
function Stmt.get_value(self, n)
    local row = self:get_named_values()
    local column = self:get_name(n)
    return row[column]
end

--- Returns the current row as a fresh array ordered like the result columns.
function Stmt.get_values(self)
    local named = self:get_named_values()
    local ordered = {}
    for position, column in ipairs(self._columns) do
        ordered[position] = named[column]
    end
    return ordered
end

--- Reports whether the statement still holds its compiled function.
function Stmt.isopen(self)
    return not (self._func == nil)
end

--- Builds a generic-for iterator over a statement's result rows.
-- Each call steps the statement and returns `get_values(self)` for the new
-- row, or nothing once the rows are exhausted.
-- If step() fails, the error recorded on the connection is raised rather than
-- a bare "assertion failed!", so callers see the actual SQL error.
-- @param self the Stmt to iterate
-- @param get_values accessor used to shape each row (e.g. Stmt.get_values)
local function make_iterator(self, get_values)
    return function()
        local status = self:step()
        if status == sqlite3.DONE then
            return
        end
        if status ~= sqlite3.ROW then
            local msg = status == sqlite3.ERROR and self._db:errmsg() or nil
            error(msg or ("statement step failed with status "
                .. tostring(status)), 2)
        end
        return get_values(self)
    end
end

--- Returns an iterator yielding each row as a table keyed by column name.
function Stmt.nrows(self)
    local accessor = self.get_named_values
    return make_iterator(self, accessor)
end

--- Rewinds the statement so that the next step() executes it again.
-- As in SQLite/lsqlite3, parameter bindings keep their values across a
-- reset; use bind()/bind_values() to change them. The binding table is only
-- created when missing, which DB:prepare relies on to initialise a new
-- statement.
-- @return sqlite3.OK
function Stmt:reset()
    if self._bind_values == nil then
        self._bind_values = {}
        self._bind_count = 0
    end
    self._rows = nil
    self._idx = nil
    return sqlite3.OK
end

--- Returns an iterator yielding each row as an array of column values.
function Stmt.rows(self)
    local accessor = self.get_values
    return make_iterator(self, accessor)
end

--- Advances the statement to its next result row.
-- The first call after prepare()/reset() runs the compiled squill function
-- with the bound parameters and caches every returned row; later calls walk
-- that cache.
-- @return sqlite3.ROW when a row is available, sqlite3.DONE when the rows
--   are exhausted, sqlite3.ERROR if execution failed (message stored on the
--   connection), or sqlite3.MISUSE if there is no compiled function
function Stmt:step()
    if not self._rows then
        if not self._func then
            return sqlite3.MISUSE
        end
        -- Snapshot the autoincrement counters before statements flagged with
        -- affected_rows so DB:last_insert_rowid() can diff them afterwards.
        local old_rowids
        if self._columns.affected_rows then
            old_rowids = get_autoincrements(self._db._db_name)
        end
        self._db._old_rowids = nil
        local ok, rows = pcall(
            self._func,
            unpack(self._bind_values, 1, self._bind_count)
        )
        if not ok then
            self._errcode = sqlite3.ERROR
            self._db._errmsg = rows
            return sqlite3.ERROR
        end
        -- This evaluation succeeded, so finalize() must no longer report an
        -- error left over from an earlier failed attempt.
        self._errcode = nil
        -- Only a single-row change leaves one unambiguous rowid to recover.
        if rows.affected_rows == 1 then
            self._db._old_rowids = old_rowids
        end
        self._db._affected_rows = rows.affected_rows
        self._rows = rows
        self._idx = 0
    end

    self._idx = self._idx + 1
    return self._rows[self._idx] and sqlite3.ROW or sqlite3.DONE
end

--- Returns an iterator yielding each row's column values as multiple
-- return values.
function Stmt.urows(self)
    local accessor = self.get_uvalues
    return make_iterator(self, accessor)
end

return sqlite3
