--- Restore tracking (chunk-level).
-- Tracks external mapblock modifications affecting a placed snapshot and enables
-- restoring only the changed 16x16x16 chunks.

local rt = map_octree.restore_tracking
local MAPBLOCK_SIZE = rt.MAPBLOCK_SIZE

local get_scan_budget_ms = rt.get_scan_budget_ms
local node_to_block_coord = rt.node_to_block_coord
local compute_block_bounds = rt.compute_block_bounds
local block_in_bounds = rt.block_in_bounds
local chunk_key = rt.chunk_key
local key_to_grid = rt.key_to_grid
local grid_in_range = rt.grid_in_range
local blockpos_to_grid_range = rt.blockpos_to_grid_range
local verify_chunk_against_snapshot = rt.verify_chunk_against_snapshot
local write_chunk_to_world = rt.write_chunk_to_world

local FifoSet = map_octree.FifoSet

--#region types
---@class RestoreChunkCoords
---@field gx integer
---@field gy integer
---@field gz integer

---@class RestoreChangeInfo
---@field t_us integer
---@field modified_block_count integer
---@field blocks_seen integer
---@field blocks_in_bounds integer
---@field blocks_out_of_bounds integer
---@field unique_chunks_enqueued integer

---@class RestoreTrackingStatus
---@field pending integer
---@field dirty integer
---@field base_corner vector
---@field max_corner vector
---@field min_blockpos vector
---@field max_blockpos vector
---@field chunk_size? vector
---@field last_change? RestoreChangeInfo

---@class RestoreFlushOpts
---@field pending? boolean
---@field dirty? boolean
---@field cancel_restore? boolean

---@class RestoreFlushResult
---@field pending_cleared integer
---@field dirty_cleared integer
---@field restore_cancelled boolean

---@class RestoreDisableOpts
---@field flush? boolean

---@class RestoreOptions

---@class RestoreStateInfo
---@field active boolean
---@field phase "idle"|"waiting"|"writing"
---@field token integer
---@field started_us integer
---@field err? string
---@field suppress boolean
---@field pending integer
---@field dirty integer

---@class RestoreProbeInfo
---@field live_node string
---@field expected_node string
---@field gx integer
---@field gy integer
---@field gz integer
---@field is_sparse boolean
---@field is_dirty boolean
---@field in_bounds boolean

---@class RestoreTracker
---@field id integer
---@field map OctMap
---@field base_corner vector
---@field max_corner vector
---@field req_min vector
---@field req_max vector
---@field min_blockpos vector
---@field max_blockpos vector
---@field suppress boolean
---@field restore_active boolean
---@field restore_phase "idle"|"waiting"|"writing"
---@field restore_token integer
---@field restore_started_us integer
---@field restore_err? string
---@field _restore_finish? fun(ok: boolean, err?: string)
---@field pending FifoSet
---@field dirty table<integer, boolean>
---@field dirty_coords table<integer, RestoreChunkCoords>
---@field last_change? RestoreChangeInfo
---@field _expected {data: integer[], param2: integer[], param1: integer[]}
---@field _write {data: integer[], param2: integer[], param1: integer[]}

---@class RestoreTrackingState
---@field next_id integer
---@field trackers table<integer, RestoreTracker>
---@field order integer[]
---@field rr integer
--#endregion

---Module-wide tracker registry.
---`trackers` maps tracker id -> tracker, `order` lists ids for round-robin scanning,
---and `rr` is the index into `order` of the next tracker the globalstep will scan.
---@type RestoreTrackingState
local tracking = {
	trackers = {},
	order = {},
	next_id = 1,
	rr = 1,
}

-- Time (core.get_us_time) of the last globalstep progress log; used to rate-limit that log.
local last_globalstep_log_us = 0
local MapMethods = octmap.MapMethods


---Queue the chunk at grid cell (gx, gy, gz) for verification on the tracker's pending set.
---@param tracker RestoreTracker
---@param gx integer
---@param gy integer
---@param gz integer
---@return boolean pushed whatever `FifoSet:push` reports (presumably true only when the key was newly added)
local function enqueue_chunk(tracker, gx, gy, gz)
	return tracker.pending:push(chunk_key(tracker.map, gx, gy, gz))
end

---Remove and return the next chunk key from the tracker's pending set.
---@param tracker RestoreTracker
---@return integer|nil key nil when the set has nothing to pop (per FifoSet:pop)
local function pop_next_chunk_key(tracker)
	local queue = tracker.pending
	return queue:pop()
end

---Return how many keys a table holds, counting array and hash entries alike.
---@param t table
---@return integer
local function count_table_keys(t)
	local total = 0
	for _, _ in pairs(t) do
		total = total + 1
	end
	return total
end


---Enable restore tracking for a placed snapshot.
---Registers a tracker whose chunk grid spans the snapshot's minp/maxp padded by half an
---octchunk on each side, plus the mapblock bounds derived from the requested region.
---Does nothing if this map already has a tracker.
---@param self OctMap
function MapMethods:enable_tracking()
	if self._tracker_id then
		return
	end

	local lo, hi = self.minp, self.maxp
	local req_min, req_max = map_octree.get_requested_bounds(self)

	local half = octchunk.SIZE / 2
	local base_corner = vector.subtract(lo, {x = half, y = half, z = half})
	local max_corner = vector.add(hi, {x = half - 1, y = half - 1, z = half - 1})
	local min_blockpos, max_blockpos = compute_block_bounds(req_min, req_max)

	local id = tracking.next_id
	tracking.next_id = id + 1

	---@type RestoreTracker
	local tracker = {
		id = id,
		map = self,
		base_corner = base_corner,
		max_corner = max_corner,
		req_min = req_min,
		req_max = req_max,
		min_blockpos = min_blockpos,
		max_blockpos = max_blockpos,
		-- While true, mapblock-change events and queued verification skip this tracker.
		suppress = false,
		restore_active = false,
		restore_phase = "idle",
		restore_token = 0,
		restore_started_us = 0,
		restore_err = nil,
		_restore_finish = nil,
		pending = FifoSet.new(),
		dirty = {},
		dirty_coords = {},
		last_change = nil,
		-- Presumably scratch buffers reused by the verify/write helpers — TODO confirm.
		_expected = {data = {}, param2 = {}, param1 = {}},
		_write = {data = {}, param2 = {}, param1 = {}},
	}

	-- The registry key, tracker.id and self._tracker_id all carry the same id.
	tracking.trackers[id] = tracker
	tracking.order[#tracking.order + 1] = id
	self._tracker_id = id
end



---Report whether a restore operation (scheduled or synchronous) is in flight for this map.
---Always false when tracking is not enabled.
---@param self OctMap
---@return boolean
function MapMethods:is_restoring()
	local id = self._tracker_id
	local tr = id and tracking.trackers[id]
	return (tr and tr.restore_active == true) or false
end



---Snapshot the current restore operation state for this map.
---Covers both the "waiting" phase (pending queue drain) and the "writing" phase.
---Returns nothing when tracking is not enabled.
---@param self OctMap
---@return RestoreStateInfo|nil
function MapMethods:get_restore_state()
	local id = self._tracker_id
	local tr = id and tracking.trackers[id]
	if not tr then
		return
	end

	local info = {
		active = tr.restore_active == true,
		suppress = tr.suppress == true,
		phase = tr.restore_phase or "idle",
		token = tr.restore_token or 0,
		started_us = tr.restore_started_us or 0,
		err = tr.restore_err,
	}
	info.pending = tr.pending:len()
	info.dirty = count_table_keys(tr.dirty)
	return info
end



---Flush tracking state for this map.
---Each option defaults to true; only an explicit `false` disables it:
---  pending        — clear the pending verification queue,
---  dirty          — clear the dirty chunk set,
---  cancel_restore — cancel an in-flight scheduled restore (its callback fires before this returns).
---Returns nothing when tracking is not enabled.
---@param self OctMap
---@param opts? RestoreFlushOpts
---@return RestoreFlushResult|nil
function MapMethods:flush_tracking(opts)
	local id = self._tracker_id
	local tr = id and tracking.trackers[id]
	if not tr then
		return
	end

	local o = opts or {}
	local want_pending = o.pending ~= false
	local want_dirty = o.dirty ~= false
	local want_cancel = o.cancel_restore ~= false

	-- Counts are taken before cancelling, so they reflect the state the caller saw.
	local n_pending = tr.pending:len()
	local n_dirty = count_table_keys(tr.dirty)

	local cancelled = false
	if want_cancel and tr.restore_active and tr._restore_finish then
		cancelled = true
		tr.restore_err = "cancelled"
		-- Invoke the finish hook synchronously (tests rely on these semantics).
		tr._restore_finish(false, tr.restore_err)
	end

	if want_pending and n_pending > 0 then
		tr.pending:clear()
	end
	if want_dirty and n_dirty > 0 then
		tr.dirty = {}
		tr.dirty_coords = {}
	end

	local result = {pending_cleared = 0, dirty_cleared = 0, restore_cancelled = cancelled}
	if want_pending then
		result.pending_cleared = n_pending
	end
	if want_dirty then
		result.dirty_cleared = n_dirty
	end
	return result
end



---Disable restore tracking for this map and unregister its tracker.
---Any in-flight restore is always cancelled first, so its callback fires and nothing is
---left hanging. With `opts.flush`, pending and dirty state are also cleared before removal.
---@param self OctMap
---@param opts? RestoreDisableOpts
function MapMethods:disable_tracking(opts)
	local tracker_id = self._tracker_id
	if not tracker_id then
		return
	end

	local tr = tracking.trackers[tracker_id]
	if not tr then
		return
	end

	if tr.restore_active or tr._restore_finish then
		self:flush_tracking({cancel_restore = true, pending = false, dirty = false})
	end
	if opts and opts.flush then
		self:flush_tracking({cancel_restore = false, pending = true, dirty = true})
	end

	tracking.trackers[tracker_id] = nil

	local order = tracking.order
	for i = 1, #order do
		if order[i] == tracker_id then
			table.remove(order, i)
			-- Entries after i shift down by one. If the removed entry sat before the
			-- round-robin cursor, move the cursor back too so the tracker it pointed at
			-- is still the next one scanned (otherwise one tracker gets skipped).
			if i < tracking.rr then
				tracking.rr = tracking.rr - 1
			end
			break
		end
	end
	if tracking.rr > #order then
		tracking.rr = 1
	end

	self._tracker_id = nil
end



---Summarize the tracker for this map: queue sizes, tracked bounds and the latest change stats.
---Returns nothing when tracking is not enabled.
---@param self OctMap
---@return RestoreTrackingStatus|nil
function MapMethods:get_tracking_status()
	local id = self._tracker_id
	local tr = id and tracking.trackers[id]
	if not tr then
		return
	end

	local chunk_size
	local map = tr.map
	if map and map.trees then
		chunk_size = map.trees.size or nil
	end

	return {
		pending = tr.pending:len(),
		dirty = count_table_keys(tr.dirty),
		base_corner = tr.base_corner,
		max_corner = tr.max_corner,
		min_blockpos = tr.min_blockpos,
		max_blockpos = tr.max_blockpos,
		chunk_size = chunk_size,
		last_change = tr.last_change,
	}
end



---Synchronously rewrite every dirty chunk from the snapshot.
---Refuses to run while chunks are still queued for verification, because the dirty set
---would then be incomplete. Change tracking is suppressed while writing.
---@param self OctMap
---@param opts? RestoreOptions currently unused
---@return boolean ok
---@return string? err
function MapMethods:restore(opts)
	local tracker_id = self._tracker_id
	if not tracker_id then
		return false, "tracking not enabled"
	end
	local tr = tracking.trackers[tracker_id]
	if not tr then
		return false, "tracker not found"
	end

	local queued = tr.pending:len()
	core.log("action", string.format("[restore_tracking] restore: pending=%d, dirty=%d",
		queued, count_table_keys(tr.dirty)))
	if queued > 0 then
		return false, "cannot restore while pending queue not empty (" .. queued .. " chunks)"
	end

	-- Start a new restore operation; bumping the token marks any scheduled restore as stale.
	tr.restore_token = (tr.restore_token or 0) + 1
	tr.restore_started_us = core.get_us_time()
	tr.restore_active = true
	tr.restore_phase = "writing"
	tr.restore_err = nil

	local function write_all_dirty()
		for key in pairs(tr.dirty) do
			local coords = tr.dirty_coords[key]
			if coords then
				write_chunk_to_world(tr, coords.gx, coords.gy, coords.gz)
			end
			-- Clearing existing entries during a pairs() traversal is allowed in Lua.
			tr.dirty[key] = nil
			tr.dirty_coords[key] = nil
		end
	end

	tr.suppress = true
	local ok, err = xpcall(write_all_dirty, debug.traceback)
	tr.suppress = false
	tr.restore_active = false
	tr.restore_phase = "idle"

	if ok then
		return true
	end
	local msg = tostring(err)
	tr.restore_err = msg
	return false, msg
end



---Schedule an asynchronous restore of all dirty chunks.
---The operation runs in three steps:
---  1. "waiting": re-checked each step via core.after(0) until the pending verification queue is empty.
---  2. "writing": dirty chunks are rewritten from the snapshot in budgeted steps, with tracking suppressed.
---  3. `callback` is invoked.
---`callback` is invoked at most once: `(true)` on success, or `(false, err)` on a write error, on
---cancellation (flush_tracking, or a newer restore bumping the token) or when tracking is disabled.
---Errors raised by `callback` itself are caught and logged.
---@param self OctMap
---@param callback fun(ok: boolean, err?: string)
function MapMethods:schedule_restore(callback)
	assert(type(callback) == "function", "callback required")

	local tracker_id = self._tracker_id
	if not tracker_id then
		callback(false, "tracking not enabled")
		return
	end
	local tr = tracking.trackers[tracker_id]
	if not tr then
		callback(false, "tracker not found")
		return
	end

	---Invoke `callback`, logging (not propagating) any error it raises.
	---@param ok boolean
	---@param err? string
	local function safe_callback(ok, err)
		local ok_cb, e = xpcall(function()
			callback(ok, err)
		end, debug.traceback)
		if not ok_cb then
			core.log("error", "[map_octree] schedule_restore callback error: " .. tostring(e))
		end
	end

	-- Each restore gets a fresh token. A later restore (scheduled or synchronous) bumps it,
	-- which makes this operation stale.
	local token = (tr.restore_token or 0) + 1
	tr.restore_token = token
	tr.restore_started_us = core.get_us_time()
	tr.restore_active = true
	tr.restore_phase = "waiting"
	tr.restore_err = nil

	---Clear the tracker's suppress flag on behalf of this operation.
	---If this operation is stale and a newer restore is currently writing, the flag belongs to
	---that newer restore and is left set. Otherwise the newer restore's own writes would be
	---re-detected as external modifications.
	local function release_suppress()
		if tr.restore_token == token or tr.restore_phase ~= "writing" then
			tr.suppress = false
		end
	end

	local finished = false
	---Finalize the operation exactly once.
	---@param ok boolean
	---@param err? string
	local function finish(ok, err)
		if finished then
			return
		end
		finished = true

		-- Only reset shared restore state if no newer restore has taken over.
		if tr.restore_token == token then
			tr.restore_active = false
			tr.restore_phase = "idle"
			tr.restore_err = err
			tr._restore_finish = nil
			-- A synchronous cancel (flush_tracking) can land here mid-write. Release
			-- suppression immediately so edits made after the cancel are still tracked.
			tr.suppress = false
		end
		safe_callback(ok, err)
	end

	tr._restore_finish = finish

	---Finalize (if needed) and return true when this operation was cancelled, superseded,
	---or its tracker was removed.
	---@return boolean
	local function is_cancelled_or_stale()
		if tracking.trackers[tracker_id] ~= tr then
			finish(false, "tracking disabled")
			return true
		end
		if tr.restore_token ~= token then
			finish(false, "cancelled")
			return true
		end
		if not tr.restore_active then
			finish(false, tr.restore_err or "cancelled")
			return true
		end
		return false
	end

	---Rewrite the given dirty chunk keys from the snapshot in budgeted steps.
	---@param keys integer[]
	local function write_dirty_chunks(keys)
		tr.restore_phase = "writing"
		tr.suppress = true
		local idx = 1
		local failed = false

		map_octree.run_budgeted({
			budget_ms = get_scan_budget_ms(),
			delay = 0,
			cancel_fn = function()
				if is_cancelled_or_stale() then
					release_suppress()
					return true
				end
				return false
			end,
			step_fn = function()
				if failed or idx > #keys then
					return true
				end
				-- Never write on behalf of a cancelled/superseded operation, regardless of
				-- how often run_budgeted consults cancel_fn.
				if is_cancelled_or_stale() then
					release_suppress()
					return true
				end

				local ok_step, err_step = xpcall(function()
					local key = keys[idx]
					idx = idx + 1
					local coords = tr.dirty_coords[key]
					if coords then
						write_chunk_to_world(tr, coords.gx, coords.gy, coords.gz)
					end
					tr.dirty[key] = nil
					tr.dirty_coords[key] = nil
				end, debug.traceback)

				if not ok_step then
					failed = true
					release_suppress()
					core.after(0, function()
						finish(false, tostring(err_step))
					end)
					return true
				end

				return idx > #keys
			end,
			done_fn = function()
				if failed then
					return
				end
				release_suppress()
				if not is_cancelled_or_stale() then
					finish(true)
				end
			end,
		})
	end

	---Poll until the pending verification queue drains, then start writing.
	local function wait_for_pending()
		if is_cancelled_or_stale() then
			return
		end

		if tr.pending:len() > 0 then
			core.after(0, wait_for_pending)
			return
		end

		-- Snapshot the dirty keys now; the queue is drained so the set is complete.
		local keys = {}
		for key in pairs(tr.dirty) do
			keys[#keys + 1] = key
		end

		if #keys == 0 then
			core.after(0, function()
				finish(true)
			end)
			return
		end

		write_dirty_chunks(keys)
	end

	wait_for_pending()
end



---Internal test helper: push grid cell (gx, gy, gz) onto this map's pending verification queue.
---Silently does nothing when tracking is not enabled.
---@param self OctMap
---@param gx integer
---@param gy integer
---@param gz integer
function MapMethods:_enqueue_chunk_for_verification(gx, gy, gz)
	local id = self._tracker_id
	local tr = id and tracking.trackers[id]
	if tr then
		enqueue_chunk(tr, gx, gy, gz)
	end
end



---Internal test helper: pop one key from the pending queue, if there is one.
---Lets tests check queue semantics without depending on globalstep timing.
---Returns nothing when tracking is not enabled.
---@return integer|nil key
function MapMethods:_pop_pending_key_for_tests()
	local id = self._tracker_id
	local tr = id and tracking.trackers[id]
	if not tr then
		return
	end
	return tr.pending:pop()
end



---Internal test helper: number of chunk keys waiting in the pending queue.
---Returns 0 when tracking is not enabled.
---@return integer
function MapMethods:_pending_count_for_tests()
	local id = self._tracker_id
	local tr = id and tracking.trackers[id]
	return tr and tr.pending:len() or 0
end



---Internal test helper: verify grid cell (gx, gy, gz) against the snapshot right away,
---bypassing the pending queue. Silently does nothing when tracking is not enabled.
---@param self OctMap
---@param gx integer
---@param gy integer
---@param gz integer
function MapMethods:_verify_chunk_now(gx, gy, gz)
	local id = self._tracker_id
	local tr = id and tracking.trackers[id]
	if tr then
		verify_chunk_against_snapshot(tr, gx, gy, gz)
	end
end



---Diagnostic probe: compare the live node at a world position with the snapshot.
---Reports the grid cell containing `pos` (1-based, measured from the tracker's base corner),
---whether that cell lies inside the snapshot grid, whether its chunk is sparse (no stored
---cell) and whether it is currently marked dirty.
---Returns nil when tracking is not enabled.
---@param self OctMap
---@param pos vector
---@return RestoreProbeInfo|nil info
function MapMethods:probe_position(pos)
	local tracker_id = self._tracker_id
	if not tracker_id then
		return nil
	end
	local tr = tracking.trackers[tracker_id]
	if not tr then
		return nil
	end

	local base = tr.base_corner
	local gx = math.floor((pos.x - base.x) / MAPBLOCK_SIZE) + 1
	local gy = math.floor((pos.y - base.y) / MAPBLOCK_SIZE) + 1
	local gz = math.floor((pos.z - base.z) / MAPBLOCK_SIZE) + 1

	local s = self.trees.size
	local in_bounds = gx >= 1 and gx <= s.x and gy >= 1 and gy <= s.y and gz >= 1 and gz <= s.z

	local live_node = core.get_node(pos).name
	local expected_node = octmap.get_node_name(self, pos) or "(out of snapshot)"

	-- Sparse/dirty flags are only meaningful for cells inside the grid.
	local is_sparse = false
	local is_dirty = false
	if in_bounds then
		local cell = matrix3d.get(self.trees, gx, gy, gz)
		is_sparse = (cell == nil)
		local key = chunk_key(self, gx, gy, gz)
		is_dirty = (tr.dirty[key] == true)
	end

	return {
		live_node = live_node,
		expected_node = expected_node,
		gx = gx,
		gy = gy,
		gz = gz,
		is_sparse = is_sparse,
		is_dirty = is_dirty,
		in_bounds = in_bounds,
	}
end



---Immediately verify the chunk containing world position `pos` and report its dirty state.
---@param self OctMap
---@param pos vector
---@return boolean|nil is_dirty nil if tracking is off or pos is outside the grid; otherwise whether the chunk is dirty after verification
function MapMethods:verify_chunk_at_position(pos)
	local id = self._tracker_id
	local tr = id and tracking.trackers[id]
	if not tr then
		return nil
	end

	-- 1-based grid cell measured from the tracker's base corner.
	local origin = tr.base_corner
	local size = MAPBLOCK_SIZE
	local gx = math.floor((pos.x - origin.x) / size) + 1
	local gy = math.floor((pos.y - origin.y) / size) + 1
	local gz = math.floor((pos.z - origin.z) / size) + 1

	if grid_in_range(self, gx, gy, gz) then
		verify_chunk_against_snapshot(tr, gx, gy, gz)
		return tr.dirty[chunk_key(self, gx, gy, gz)] == true
	end
	return nil
end



---Account one modified mapblock against a tracker: update its last_change statistics and,
---when the block lies inside the tracked block bounds, enqueue every chunk it overlaps.
---Statistics restart whenever the callback timestamp differs from the stored one.
---@param tr RestoreTracker
---@param bp vector block position of the modified mapblock
---@param now_us integer
---@param modified_block_count integer
local function record_block_change(tr, bp, now_us, modified_block_count)
	local info = tr.last_change
	if info == nil or info.t_us ~= now_us then
		info = {
			t_us = now_us,
			modified_block_count = modified_block_count,
			blocks_seen = 0,
			blocks_in_bounds = 0,
			blocks_out_of_bounds = 0,
			unique_chunks_enqueued = 0,
		}
		tr.last_change = info
	end
	info.blocks_seen = info.blocks_seen + 1

	if not block_in_bounds(bp, tr.min_blockpos, tr.max_blockpos) then
		info.blocks_out_of_bounds = info.blocks_out_of_bounds + 1
		return
	end

	info.blocks_in_bounds = info.blocks_in_bounds + 1
	local gx1, gy1, gz1, gx2, gy2, gz2 = blockpos_to_grid_range(tr, bp)
	for gx = gx1, gx2 do
		for gy = gy1, gy2 do
			for gz = gz1, gz2 do
				if enqueue_chunk(tr, gx, gy, gz) then
					info.unique_chunks_enqueued = info.unique_chunks_enqueued + 1
				end
			end
		end
	end
end

-- Change detection relies solely on on_mapblocks_changed, which covers dig/place as well
-- as engine-side writes. Trackers with `suppress` set (during restores) are skipped.
core.register_on_mapblocks_changed(function(modified_blocks, modified_block_count)
	if modified_block_count == 0 or #tracking.order == 0 then
		return
	end

	local now_us = core.get_us_time()
	local order = tracking.order

	for hash in pairs(modified_blocks) do
		local bp = core.get_position_from_hash(hash)
		for i = 1, #order do
			local tr = tracking.trackers[order[i]]
			if tr and not tr.suppress then
				record_block_change(tr, bp, now_us, modified_block_count)
			end
		end
	end
end)


---Log the number of chunks verified this step, at most once per second.
---@param verified_count integer
---@param fmt string format string with a single %d for the count
local function log_globalstep_progress(verified_count, fmt)
	local now_us = core.get_us_time()
	if verified_count > 0 and (now_us - last_globalstep_log_us) >= 1000000 then
		last_globalstep_log_us = now_us
		core.log("action", string.format(fmt, verified_count))
	end
end

-- Each server step, verify queued chunks round-robin across trackers until either the
-- scan budget is spent or every tracker in a full rotation had nothing to do.
core.register_globalstep(function()
	local order = tracking.order
	if #order == 0 then
		return
	end

	local budget_us = get_scan_budget_ms() * 1000
	local started_us = core.get_us_time()
	local idle_streak = 0
	local verified = 0

	while core.get_us_time() - started_us < budget_us do
		local tr = tracking.trackers[order[tracking.rr]]
		tracking.rr = tracking.rr + 1
		if tracking.rr > #order then
			tracking.rr = 1
		end

		local key = nil
		if tr and not tr.suppress then
			key = pop_next_chunk_key(tr)
		end

		if key then
			local gx, gy, gz = key_to_grid(tr.map, key)
			verify_chunk_against_snapshot(tr, gx, gy, gz)
			verified = verified + 1
			idle_streak = 0
		else
			idle_streak = idle_streak + 1
			if idle_streak >= #order then
				-- A full rotation found no work: nothing more to do this step.
				log_globalstep_progress(verified, "[restore_tracking] globalstep verified %d chunks")
				return
			end
		end
	end

	log_globalstep_progress(verified, "[restore_tracking] globalstep verified %d chunks (budget exhausted)")
end)

---Callback invoked by OctreeManip after an async write completes.
---Enqueues, for every non-suppressed tracker, each chunk overlapped by the mapblocks that
---both the written region [pos1, pos2] and the tracker's block bounds cover.
---@param pos1 vector
---@param pos2 vector
function map_octree.on_async_write_complete(pos1, pos2)
	assert(pos1 and pos2, "pos1/pos2 required")
	local lo, hi = vector.sort(pos1, pos2)

	local function to_blockpos(p)
		return {
			x = node_to_block_coord(p.x),
			y = node_to_block_coord(p.y),
			z = node_to_block_coord(p.z),
		}
	end
	local written_min = to_blockpos(lo)
	local written_max = to_blockpos(hi)

	for _, id in ipairs(tracking.order) do
		local tr = tracking.trackers[id]
		if tr and not tr.suppress then
			local tmin, tmax = tr.min_blockpos, tr.max_blockpos
			-- Intersection of the written block range with the tracker's block bounds.
			local x1, x2 = math.max(written_min.x, tmin.x), math.min(written_max.x, tmax.x)
			local y1, y2 = math.max(written_min.y, tmin.y), math.min(written_max.y, tmax.y)
			local z1, z2 = math.max(written_min.z, tmin.z), math.min(written_max.z, tmax.z)

			if x1 <= x2 and y1 <= y2 and z1 <= z2 then
				for bz = z1, z2 do
					for by = y1, y2 do
						for bx = x1, x2 do
							local gx1, gy1, gz1, gx2, gy2, gz2 =
								blockpos_to_grid_range(tr, {x = bx, y = by, z = bz})
							for gx = gx1, gx2 do
								for gy = gy1, gy2 do
									for gz = gz1, gz2 do
										enqueue_chunk(tr, gx, gy, gz)
									end
								end
							end
						end
					end
				end
			end
		end
	end
end
