--[[
Copyright 2023 ekl

This Source Code Form is subject to the terms of the Mozilla Public
License, v. 2.0. If a copy of the MPL was not distributed with this
file, You can obtain one at https://mozilla.org/MPL/2.0/.
]]
--This file only partially works; many MIDI features are unsupported.
--Anything related to timing (tempo, division, time signatures) is probably inaccurate.

---Plays a single MTrk chunk of a Standard MIDI File.
---@class MidiPlayer
---@field track string Whole MIDI file data; this player reads one MTrk chunk from it
---@field head integer Index of the next byte to read in `track` (starts at the chunk's "MTrk")
---@field length integer? Length of the track in bytes
---@field stop integer? Index one past the last byte of this track chunk
---@field wait number Ticks remaining until the next event is due
---@field tempo number Current tempo in microseconds per quarter note
---@field status integer? Running status: last channel-message status byte seen
---@field finished boolean? Set once End of Track (or the chunk end) has been reached
---@field handles {[integer]: any}? Per-instance map of pitch * 16 + channel to sound handles
---@field instrument table `name` (sound), `pitch` (MIDI note the sample plays unshifted at), `gain`
local MidiPlayer = {
	track = "MTrk\x00\x00\x00\x00",
	head = 1,
	wait = 0,
	-- MIDI's default tempo (120 bpm) until a Set Tempo meta event says otherwise
	tempo = 500000,
	instrument = { name = "cid_piano" },
}
MidiPlayer.__index = MidiPlayer

---Return this instance's own handle table, creating it on first use.
---A table on the prototype would be shared by every player, letting
---unrelated players stop each other's notes.
---@param player MidiPlayer
---@return {[integer]: any}
local function getHandles(player)
	local handles = rawget(player, "handles")
	if not handles then
		handles = {}
		player.handles = handles
	end
	return handles
end

---Decode a variable length encoded integer (7 bits per byte, most significant
---first; every byte except the last has its high bit set)
---@param str string
---@param start integer
---@return integer Decoded number
---@return integer Index Index past the end of the VLE integer (never before `start`)
local function decodeVle(str, start)
	local n = 0
	local i = start
	while i <= #str do
		local b = str:byte(i)
		i = i + 1
		n = n * 128 + b % 128
		if b < 128 then
			return n, i
		end
	end
	-- Truncated input: report everything as consumed so callers never step backwards
	return n, i
end

---Stop every sound this player still has running.
function MidiPlayer:destroy()
	local handles = getHandles(self)
	for slot, handle in pairs(handles) do
		minetest.sound_stop(handle)
		handles[slot] = nil
	end
end

---Fade out the note sounding at `pitch` on `channel`, if there is one.
---@param pitch integer MIDI note number
---@param channel integer MIDI channel (0-15)
function MidiPlayer:stopNote(pitch, channel)
	local handles = getHandles(self)
	local slot = pitch * 16 + channel
	local handle = handles[slot]
	if handle then
		minetest.sound_fade(handle, 10, 0)
		handles[slot] = nil
	end
end

---Start a note, replacing any note already sounding at the same pitch/channel.
---@param pitch integer MIDI note number
---@param velocity integer MIDI velocity (1-127)
---@param channel integer MIDI channel (0-15)
---@param objRef any? Object the sound is attached to (forwarded to minetest.sound_play)
function MidiPlayer:playNote(pitch, velocity, channel, objRef)
	self:stopNote(pitch, channel)
	local instrument = self.instrument
	-- NOTE(review): fallbacks for instruments lacking these fields (the default one
	-- does); a sample recorded at middle C (MIDI 60) is an assumption — confirm
	local ratio = 2 ^ ((pitch - (instrument.pitch or 60)) / 12)
	getHandles(self)[pitch * 16 + channel] = minetest.sound_play({
		name = instrument.name,
		pitch = ratio,
		-- Divided by the playback-rate ratio (presumably to even out loudness across pitches)
		gain = (instrument.gain or 1) * velocity / 64 / ratio,
	}, { object = objRef })
end

---Advance playback and fire every event that has come due.
---@param dtime number Elapsed time pre-scaled by the caller: seconds × 4 × ticks-per-quarter-note
---@param objRef any? Object the notes are attached to (forwarded to minetest.sound_play)
---@return boolean finished True once End of Track (or the end of the chunk) has been reached
function MidiPlayer:step(dtime, objRef)
	if self.finished then
		return true
	end

	local track = self.track
	local i = self.head
	-- ticks = seconds × ticksPerQuarter × 1e6 / tempo; dtime already carries
	-- 4 × ticksPerQuarter, which leaves a factor of 1e6 / 4.
	local wait = self.wait - dtime * 250000 / self.tempo

	if not self.length then
		-- First step: read the chunk length, then the first event's delta time
		local a, b, c, d = track:byte(i + 4, i + 7)
		if not d then
			self.finished = true
			return true
		end
		self.length = ((a * 256 + b) * 256 + c) * 256 + d
		self.stop = math.min(i + 8 + self.length, #track + 1)
		local delta
		delta, i = decodeVle(track, i + 8)
		wait = wait + delta
	end
	local stop = self.stop or #track + 1

	while wait <= 0 and i < stop do
		local status = track:byte(i)
		if status < 0x80 then
			-- Running status: this byte is already data; reuse the previous channel status
			status = self.status
		else
			i = i + 1
		end

		if status == nil then
			-- Data byte with no running status to reuse (malformed): skip it
			i = i + 1
		elseif status == 0xFF then
			-- Meta event: type byte, VLE length, then `length` data bytes
			self.status = nil
			local metaType = track:byte(i)
			local length, dataStart = decodeVle(track, i + 1)
			if metaType == 0x2F then
				-- End of Track
				self.head = i - 1
				self.wait = wait
				self.finished = true
				return true
			elseif metaType == 0x51 and length >= 3 then
				-- Set Tempo: 24-bit microseconds per quarter note
				local a, b, c = track:byte(dataStart, dataStart + 2)
				local tempo = (a * 256 + b) * 256 + c
				if tempo > 0 then
					self.tempo = tempo
				end
			end
			-- Other meta events (e.g. 0x58 time signature) are skipped
			i = dataStart + length
		elseif status == 0xF0 or status == 0xF7 then
			-- SysEx / escape event: VLE length, then that many bytes
			self.status = nil
			local length
			length, i = decodeVle(track, i)
			i = i + length
		elseif status > 0xF0 then
			-- System common / real-time messages don't belong in a MIDI file;
			-- skip their wire-protocol data bytes if one shows up anyway
			if status == 0xF1 or status == 0xF3 then
				i = i + 1
			elseif status == 0xF2 then
				i = i + 2
			end
		else
			-- Channel message: high nibble is the message kind, low nibble the channel
			self.status = status
			local channel = status % 16
			local kind = status - channel
			if kind == 0x90 then
				local pitch, velocity = track:byte(i, i + 1)
				if velocity == 0 then
					self:stopNote(pitch, channel)
				elseif velocity then
					self:playNote(pitch, velocity, channel, objRef)
				end
				i = i + 2
			elseif kind == 0x80 then
				local pitch = track:byte(i)
				if pitch then
					self:stopNote(pitch, channel)
				end
				i = i + 2
			elseif kind == 0xC0 or kind == 0xD0 then
				-- Program change / channel pressure: one data byte, ignored
				i = i + 1
			else
				-- 0xA0 key pressure, 0xB0 control change, 0xE0 pitch bend: ignored
				i = i + 2
			end
		end

		local delta
		delta, i = decodeVle(track, i)
		wait = wait + delta
	end

	self.head = i
	self.wait = wait
	if i >= stop then
		-- Ran off the end of the chunk without an End of Track event
		self.finished = true
		return true
	end
	return false
end

---Handles playing all tracks of a MIDI file
---By full, I mean all tracks, this supports relatively few MIDI features
---@class MidiPlayerFull
---@field tracks MidiPlayer[]
---@field tps number Multiplier applied to dtime before it is handed to each track
---(4 × ticks-per-quarter-note for metrical files)
local MidiPlayerFull = {
	tps = 120,
}
MidiPlayerFull.__index = MidiPlayerFull

---Parse the chunks of a Standard MIDI File and create one player per MTrk chunk.
---@param midiData string Raw bytes of the MIDI file
---@param instrument table? Instrument given to every track (nil keeps MidiPlayer's default)
---@return MidiPlayerFull
function MidiPlayerFull.new(midiData, instrument)
	local self = setmetatable({ tracks = {} }, MidiPlayerFull)
	local tracks = self.tracks
	local i = 1
	-- Every chunk is a 4-byte type, a 4-byte big-endian body length, then the body
	while i + 7 <= #midiData do
		local chunkType = midiData:sub(i, i + 3)
		local a, b, c, d = midiData:byte(i + 4, i + 7)
		local length = ((a * 256 + b) * 256 + c) * 256 + d
		if chunkType == "MTrk" then
			tracks[#tracks + 1] = setmetatable({
				track = midiData,
				head = i,
				instrument = instrument,
				-- Own handle table so tracks/players can't stop each other's notes
				handles = {},
			}, MidiPlayer)
		elseif chunkType == "MThd" then
			-- Body: format (2 bytes), track count (2 bytes), division (2 bytes).
			-- Format is just ignored since sequential track support isn't implemented.
			local hi, lo = midiData:byte(i + 12, i + 13)
			if hi and lo then
				local tps
				if hi < 128 then
					-- Metrical division: ticks per quarter note
					tps = (hi * 256 + lo) * 4
				else
					-- SMPTE division: high byte is -(frames per second), low byte ticks per frame.
					-- NOTE(review): MidiPlayer:step divides by the current tempo, so this is
					-- only approximate; the ×2 targets MIDI's default 500000 µs/quarter.
					tps = (256 - hi) * lo * 2
				end
				if tps > 0 then
					self.tps = tps
				end
			end
		end
		i = i + 8 + length
	end
	return self
end

---Stop all sounds and drop every track.
function MidiPlayerFull:destroy()
	for i, track in pairs(self.tracks) do
		track:destroy()
		self.tracks[i] = nil
	end
end

---Advance every track by `dtime` seconds.
---@param dtime number Seconds since the previous call
---@param objRef any? Object the notes are attached to, forwarded to each track (nil as before)
---@return boolean allFinished True once every track has finished
function MidiPlayerFull:step(dtime, objRef)
	local scaled = dtime * self.tps
	local allFinished = true
	for _, track in ipairs(self.tracks) do
		if not track:step(scaled, objRef) then
			allFinished = false
		end
	end
	return allFinished
end

local api = {
	_player = MidiPlayer,
	_playerFull = MidiPlayerFull,
}

---Create a player for every track of a MIDI file.
---@param data table Object whose `_midi` field holds the raw MIDI bytes
---@param instrument table? Instrument handed to every track
---@return MidiPlayerFull
function api.play(data, instrument)
	return MidiPlayerFull.new(data._midi, instrument)
end

return api
