-- RNN API module
local RNN = {}

-- Utility functions (local to the module)
-- Build a sequence holding `n` zeros at indices 1..n.
-- Returns an empty table when n is less than 1.
local function zeros(n)
    local filled = {}
    for _ = 1, n do
        filled[#filled + 1] = 0
    end
    return filled
end

-- Create a rows x cols matrix (a table of row tables) of small random values.
-- Each entry is math.random() shifted and scaled into [-0.05, 0.05); entries
-- are drawn row by row, left to right.
local function random_matrix(rows, cols)
    local matrix = {}
    for r = 1, rows do
        local row = {}
        for c = 1, cols do
            row[c] = (math.random() - 0.5) * 0.1
        end
        matrix[r] = row
    end
    return matrix
end

-- Multiply matrix `mat` (table of rows) by vector `vec`.
-- Only the first #vec columns of every row are read; the result has one
-- entry per row of `mat`.
local function mat_vec_mul(mat, vec)
    local width = #vec
    local product = {}
    for r = 1, #mat do
        local row = mat[r]
        local acc = 0
        for c = 1, width do
            acc = acc + row[c] * vec[c]
        end
        product[r] = acc
    end
    return product
end

-- Element-wise sum of two vectors.
-- The result has #a entries; `b` is indexed at the same positions.
local function vec_add(a, b)
    local length = #a
    local total = {}
    for k = 1, length do
        total[k] = a[k] + b[k]
    end
    return total
end

-- Scalar hyperbolic tangent.
-- math.tanh was deprecated in Lua 5.3 and is not present in stock Lua 5.4,
-- so use it when available and otherwise fall back to an exp-based formula.
local scalar_tanh = math.tanh or function(x)
    -- exp(-2|x|) lies in (0, 1], so this form cannot overflow for large |x|.
    local e = math.exp(-2 * math.abs(x))
    local t = (1 - e) / (1 + e)
    if x < 0 then
        return -t
    end
    return t
end

-- Apply tanh element-wise to `vec`.
-- Returns a new vector with #vec entries; `vec` is not modified.
local function tanh(vec)
    local out = {}
    for i = 1, #vec do
        out[i] = scalar_tanh(vec[i])
    end
    return out
end

-- Convert a vector of scores into a probability distribution.
-- The largest score is subtracted before exponentiating so math.exp never
-- receives a positive argument. Returns a new vector with #vec entries.
local function softmax(vec)
    local count = #vec

    local peak = -math.huge
    for k = 1, count do
        local score = vec[k]
        if score > peak then
            peak = score
        end
    end

    local probs = {}
    local total = 0
    for k = 1, count do
        local e = math.exp(vec[k] - peak)
        probs[k] = e
        total = total + e
    end

    for k = 1, count do
        probs[k] = probs[k] / total
    end
    return probs
end

-- Build a vector of `size` zeros with a 1 stored at position `idx`.
local function one_hot(idx, size)
    local encoded = zeros(size)
    encoded[idx] = 1
    return encoded
end

-- RNN instance constructor
function RNN.new(params)
    params = params or {}
    local instance = {}

    -- Configuration
    instance.charset = params.charset or "abcdefghijklmnopqrstuvwxyz .,!?'"
    instance.hidden_size = params.hidden_size or 16
    instance.learning_rate = params.learning_rate or 0.005

    -- Character to index mappings
    instance.char_to_idx = {}
    instance.idx_to_char = {}
    for i = 1, #instance.charset do
        local c = instance.charset:sub(i, i)
        instance.char_to_idx[c] = i
        instance.idx_to_char[i] = c
    end
    instance.vocab_size = #instance.charset

    -- Model weights
    instance.Wxh = random_matrix(instance.hidden_size, instance.vocab_size)
    instance.Whh = random_matrix(instance.hidden_size, instance.hidden_size)
    instance.Why = random_matrix(instance.vocab_size, instance.hidden_size)
    instance.bh = zeros(instance.hidden_size)
    instance.by = zeros(instance.vocab_size)

    -- Training data storage
    instance.training_phrases = {}

    -- Internal: Convert string to indices
    -- Lower-cases `phrase` and maps each character to its index in
    -- char_to_idx; characters without a mapping are dropped.
    -- Returns the resulting sequence of indices (possibly empty).
    function instance:phrase_to_indices(phrase)
        local lowered = phrase:lower()
        local indices = {}
        for pos = 1, #lowered do
            local idx = self.char_to_idx[lowered:sub(pos, pos)]
            if idx then
                indices[#indices + 1] = idx
            end
        end
        return indices
    end

    -- Internal: Forward pass for sequence
    -- inputs: sequence of character indices.
    -- hprev:  optional initial hidden state; a zero vector is used when absent.
    -- At each step the target is the next input index, or the index of ' '
    -- after the last input. Returns the summed negative log-likelihood plus
    -- the per-step tables xs (one-hot inputs), hs (hidden states, hs[0] is the
    -- initial state), ys (output scores) and ps (softmax probabilities).
    function instance:forward(inputs, hprev)
        local xs, hs, ys, ps = {}, {}, {}, {}
        local loss = 0
        hs[0] = hprev or zeros(self.hidden_size)

        for step = 1, #inputs do
            local encoded = one_hot(inputs[step], self.vocab_size)
            local from_input = mat_vec_mul(self.Wxh, encoded)
            local from_hidden = mat_vec_mul(self.Whh, hs[step - 1])
            local hidden = tanh(vec_add(vec_add(from_input, from_hidden), self.bh))
            local scores = vec_add(mat_vec_mul(self.Why, hidden), self.by)
            local probs = softmax(scores)

            xs[step], hs[step], ys[step], ps[step] = encoded, hidden, scores, probs

            local target = inputs[step + 1]
            if not target then
                target = self.char_to_idx[' ']
            end
            -- 1e-8 stands in when no probability exists for the target.
            loss = loss - math.log(probs[target] or 1e-8)
        end

        return loss, xs, hs, ys, ps
    end

    -- Internal: Backward pass (BPTT)
    -- xs, hs, ps: per-step one-hot inputs, hidden states (hs[0] = initial
    --             state) and softmax outputs as produced by forward().
    -- inputs:     the index sequence that was fed to forward().
    -- Returns gradients dWxh, dWhh, dWhy, dbh, dby with the same shapes as
    -- the corresponding parameters.
    function instance:backward(xs, hs, ps, inputs)
        local hidden_size, vocab_size = self.hidden_size, self.vocab_size

        local dWxh_local = {}
        local dWhh_local = {}
        local dWhy_local = {}
        local dbh_local = zeros(hidden_size)
        local dby_local = zeros(vocab_size)

        for i = 1, hidden_size do
            dWxh_local[i] = zeros(vocab_size)
            dWhh_local[i] = zeros(hidden_size)
        end
        for i = 1, vocab_size do
            dWhy_local[i] = zeros(hidden_size)
        end

        -- Gradient flowing into h[t] from step t + 1 (zero after the last step).
        local dhnext = zeros(hidden_size)

        for t = #xs, 1, -1 do
            -- Softmax + cross-entropy gradient w.r.t. the output scores: p - onehot(target).
            local dy
            local target = inputs[t + 1] or self.char_to_idx[' ']
            if target then
                dy = {}
                for i = 1, vocab_size do
                    dy[i] = ps[t][i]
                end
                dy[target] = dy[target] - 1
            else
                -- No target exists (last step and the charset has no ' ').
                -- forward() charges a constant loss here, whose gradient is zero.
                dy = zeros(vocab_size)
            end

            for i = 1, vocab_size do
                dby_local[i] = dby_local[i] + dy[i]
                for j = 1, hidden_size do
                    dWhy_local[i][j] = dWhy_local[i][j] + dy[i] * hs[t][j]
                end
            end

            -- Gradient w.r.t. h[t]: Why^T * dy plus what arrives from step t + 1.
            local dh = {}
            for j = 1, hidden_size do
                local sum = 0
                for i = 1, vocab_size do
                    sum = sum + self.Why[i][j] * dy[i]
                end
                dh[j] = sum + dhnext[j]
            end

            -- Back through tanh: d/dz tanh(z) = 1 - tanh(z)^2.
            local dt = {}
            for j = 1, hidden_size do
                dt[j] = (1 - hs[t][j] ^ 2) * dh[j]
            end

            for i = 1, hidden_size do
                for j = 1, vocab_size do
                    dWxh_local[i][j] = dWxh_local[i][j] + dt[i] * xs[t][j]
                end
            end
            for i = 1, hidden_size do
                for j = 1, hidden_size do
                    dWhh_local[i][j] = dWhh_local[i][j] + dt[i] * hs[t - 1][j]
                end
            end
            for i = 1, hidden_size do
                dbh_local[i] = dbh_local[i] + dt[i]
            end

            -- Gradient w.r.t. h[t - 1] is Whh^T * dt, because
            -- h_raw[t][i] = sum_j Whh[i][j] * h[t - 1][j] + ...
            local dhprev = {}
            for j = 1, hidden_size do
                local sum = 0
                for i = 1, hidden_size do
                    sum = sum + self.Whh[i][j] * dt[i]
                end
                dhprev[j] = sum
            end
            dhnext = dhprev
        end
        return dWxh_local, dWhh_local, dWhy_local, dbh_local, dby_local
    end

    -- Internal: Clip gradients
    -- Clamps every entry of the three gradient matrices and two gradient
    -- vectors into [-thresh, thresh]. All tables are modified in place;
    -- nothing is returned.
    function instance:clip_gradients(dWxh, dWhh, dWhy, dbh, dby, thresh)
        local function clamp_values(values)
            for k = 1, #values do
                local g = values[k]
                if g > thresh then
                    values[k] = thresh
                elseif g < -thresh then
                    values[k] = -thresh
                end
            end
        end

        local function clamp_rows(matrix)
            for r = 1, #matrix do
                clamp_values(matrix[r])
            end
        end

        clamp_rows(dWxh)
        clamp_rows(dWhh)
        clamp_rows(dWhy)
        clamp_values(dbh)
        clamp_values(dby)
    end

    -- Internal: Update weights
    -- Applies one gradient-descent step, param = param - learning_rate * grad,
    -- to Wxh, Whh, Why, bh and by. Parameters are updated in place.
    function instance:update_weights(dWxh, dWhh, dWhy, dbh, dby)
        local rate = self.learning_rate
        local hidden, vocab = self.hidden_size, self.vocab_size

        local function step_vector(params, grads, count)
            for k = 1, count do
                params[k] = params[k] - rate * grads[k]
            end
        end

        local function step_matrix(params, grads, rows, cols)
            for r = 1, rows do
                step_vector(params[r], grads[r], cols)
            end
        end

        step_matrix(self.Wxh, dWxh, hidden, vocab)
        step_matrix(self.Whh, dWhh, hidden, hidden)
        step_matrix(self.Why, dWhy, vocab, hidden)
        step_vector(self.bh, dbh, hidden)
        step_vector(self.by, dby, vocab)
    end

    -- Internal: Train one phrase
    -- Runs forward, backward, gradient clipping (threshold 5) and a weight
    -- update for a single phrase. Phrases that map to fewer than two known
    -- characters are skipped.
    -- Returns the phrase loss and the hidden state after the last character,
    -- or 0 and the unchanged `hprev` when the phrase is skipped.
    function instance:train_phrase(phrase, hprev)
        local indices = self:phrase_to_indices(phrase)
        local steps = #indices
        if steps >= 2 then
            local loss, xs, hs, _, ps = self:forward(indices, hprev)
            local dWxh, dWhh, dWhy, dbh, dby = self:backward(xs, hs, ps, indices)
            self:clip_gradients(dWxh, dWhh, dWhy, dbh, dby, 5)
            self:update_weights(dWxh, dWhh, dWhy, dbh, dby)
            return loss, hs[steps]
        end
        return 0, hprev
    end

    -- Public API: Add training data
    -- phrases: a table of strings
    -- Appends every entry of `phrases`, in order, to self.training_phrases.
    function instance:add_training_data(phrases)
        local store = self.training_phrases
        for _, text in ipairs(phrases) do
            store[#store + 1] = text
        end
    end

    -- Public API: Train the model
    -- epochs: number of training epochs
    -- on_epoch_callback: optional function called after each epoch (epoch_num, current_loss)
    -- Returns true when training ran, or false plus a message when no
    -- training phrases have been added. The hidden state is carried from one
    -- phrase to the next and across epochs. The callback receives the epoch
    -- number and the summed phrase loss divided by the number of phrases.
    function instance:train(epochs, on_epoch_callback)
        local phrase_count = #self.training_phrases
        if phrase_count == 0 then
            return false, "No training data added. Use add_training_data first."
        end

        local hidden = zeros(self.hidden_size)
        for epoch = 1, epochs do
            local epoch_loss = 0
            for _, phrase in ipairs(self.training_phrases) do
                local phrase_loss, next_hidden = self:train_phrase(phrase, hidden)
                hidden = next_hidden
                epoch_loss = epoch_loss + phrase_loss
            end
            if on_epoch_callback then
                on_epoch_callback(epoch, epoch_loss / phrase_count)
            end
        end
        return true
    end

    -- Public API: Generate text from a seed
    -- seed: starting text string
    -- length: number of characters to generate
    -- The whole seed conditions the network: it is normalised with
    -- phrase_to_indices (lower-cased, unknown characters dropped), every
    -- character but the last primes the hidden state, and sampling starts from
    -- the last one. If the seed has no known characters, index 1 is used.
    -- Returns the seed exactly as given, followed by `length` sampled characters.
    function instance:generate(seed, length)
        seed = seed or ""
        local h = zeros(self.hidden_size)

        -- Advance the hidden state by one character and return the
        -- probability distribution over the next character.
        local function step(ix)
            local x = one_hot(ix, self.vocab_size)
            local h_raw = vec_add(vec_add(mat_vec_mul(self.Wxh, x), mat_vec_mul(self.Whh, h)), self.bh)
            h = tanh(h_raw)
            local y = vec_add(mat_vec_mul(self.Why, h), self.by)
            return softmax(y)
        end

        local seed_indices = self:phrase_to_indices(seed)
        if #seed_indices == 0 then
            seed_indices = { 1 }
        end
        for k = 1, #seed_indices - 1 do
            step(seed_indices[k])
        end
        local ix = seed_indices[#seed_indices]

        -- Collect pieces and join once instead of repeated concatenation.
        local pieces = { seed }
        for _ = 1, length do
            local p = step(ix)

            -- Inverse-CDF sampling. Default to the last index so a cumulative
            -- sum that rounds to slightly below 1 selects the final bucket.
            local r = math.random()
            local cumulative = 0
            local next_ix = #p
            for j = 1, #p do
                cumulative = cumulative + p[j]
                if r < cumulative then
                    next_ix = j
                    break
                end
            end

            pieces[#pieces + 1] = self.idx_to_char[next_ix]
            ix = next_ix
        end
        return table.concat(pieces)
    end

    -- Public API: Save model weights
    -- Returns a table containing all weights and hyperparameters
    -- The weight matrices and bias vectors are deep-copied: update_weights
    -- modifies the live tables in place, so returning them directly would let
    -- later training alter a previously saved snapshot.
    function instance:save_weights()
        local function copy_vector(v)
            local out = {}
            for i = 1, #v do
                out[i] = v[i]
            end
            return out
        end

        local function copy_matrix(m)
            local out = {}
            for i = 1, #m do
                out[i] = copy_vector(m[i])
            end
            return out
        end

        return {
            charset = self.charset,
            hidden_size = self.hidden_size,
            learning_rate = self.learning_rate,
            Wxh = copy_matrix(self.Wxh),
            Whh = copy_matrix(self.Whh),
            Why = copy_matrix(self.Why),
            bh = copy_vector(self.bh),
            by = copy_vector(self.by),
        }
    end

    -- Public API: Load model weights
    -- data: table returned by save_weights
    -- Fields missing from `data` keep their current values. The resulting
    -- combination of charset, hidden_size and weights is checked for matching
    -- shapes before anything is modified, and loaded tables are copied so later
    -- training does not alter the caller's data.
    -- Returns true on success, or false plus a message on failure (in which
    -- case the instance is left unchanged).
    function instance:load_weights(data)
        if not data then return false, "No data provided to load." end
        if type(data) ~= "table" then
            return false, "Weights data must be a table."
        end

        local charset = data.charset or self.charset
        local hidden_size = data.hidden_size or self.hidden_size
        local vocab_size = #charset

        local function is_vector(v, n)
            return type(v) == "table" and #v == n
        end

        local function is_matrix(m, rows, cols)
            if not is_vector(m, rows) then
                return false
            end
            for i = 1, rows do
                if not is_vector(m[i], cols) then
                    return false
                end
            end
            return true
        end

        local function copy_vector(v)
            local out = {}
            for i = 1, #v do
                out[i] = v[i]
            end
            return out
        end

        local function copy_matrix(m)
            local out = {}
            for i = 1, #m do
                out[i] = copy_vector(m[i])
            end
            return out
        end

        -- Validate the state the instance would end up with, including any
        -- current weights that are kept because `data` omits them.
        local shapes_ok = is_matrix(data.Wxh or self.Wxh, hidden_size, vocab_size)
            and is_matrix(data.Whh or self.Whh, hidden_size, hidden_size)
            and is_matrix(data.Why or self.Why, vocab_size, hidden_size)
            and is_vector(data.bh or self.bh, hidden_size)
            and is_vector(data.by or self.by, vocab_size)
        if not shapes_ok then
            return false, "Weight dimensions do not match charset and hidden_size."
        end

        self.charset = charset
        self.hidden_size = hidden_size
        self.learning_rate = data.learning_rate or self.learning_rate

        -- Re-initialize char mappings based on loaded charset
        self.char_to_idx = {}
        self.idx_to_char = {}
        for i = 1, #self.charset do
            local c = self.charset:sub(i, i)
            self.char_to_idx[c] = i
            self.idx_to_char[i] = c
        end
        self.vocab_size = vocab_size

        if data.Wxh then self.Wxh = copy_matrix(data.Wxh) end
        if data.Whh then self.Whh = copy_matrix(data.Whh) end
        if data.Why then self.Why = copy_matrix(data.Why) end
        if data.bh then self.bh = copy_vector(data.bh) end
        if data.by then self.by = copy_vector(data.by) end
        return true
    end

    -- Initialize random seed on instance creation
    math.randomseed(os.time())

    return instance
end

return RNN

