local ID = octchunk.ID
local CHILDREN = octchunk.CHILDREN
local track_tree_content_ids_impl

--#region types
---@class EmergeAreaInfo
---@field minp vector
---@field maxp vector
---@field padded_min vector
---@field padded_max vector
---@field had_error boolean
---@field param any
--#endregion



---Order two positions per axis (delegates to `vector.sort`).
---The first result is the minimum corner, the second the maximum corner.
---@param p1 vector
---@param p2 vector
---@return vector lo
---@return vector hi
local function sorted_pos(p1, p2)
	local lo, hi = vector.sort(p1, p2)
	return lo, hi
end


---Snap a region to chunk centers and compute the voxel span those chunks cover.
---The corners may be given in any order. The padded bounds extend half a chunk
---(`octchunk.SIZE / 2`) beyond the snapped centers on every axis, i.e. the full
---voxel region needed to build trees.
---@param pos1 vector First corner
---@param pos2 vector Opposite corner
---@return vector minp Snapped min
---@return vector maxp Snapped max
---@return vector padded_min Voxel min
---@return vector padded_max Voxel max
function octmap.get_padded_bounds(pos1, pos2)
	assert(pos1 and pos2, "pos1 and pos2 required")
	local lo, hi = sorted_pos(pos1, pos2)
	local half = octchunk.SIZE / 2
	local snapped_lo = octchunk.snap_to_center(lo)
	local snapped_hi = octchunk.snap_to_center(hi)
	return snapped_lo, snapped_hi, vector.subtract(snapped_lo, half), vector.add(snapped_hi, half)
end



---Ensure a snapshot region exists by triggering map generation.
---Low-level: prefer octmap.new_async() unless you need only emerge info.
---`callback` and `param` may be passed in either order: if `callback` is not a
---function but `param` is, the two are swapped.
---The padded region (see octmap.get_padded_bounds) is emerged. The callback
---runs when the engine reports `calls_remaining == 0`. `ok` is false if any
---block reported EMERGE_CANCELLED or EMERGE_ERRORED.
---@param pos1 vector
---@param pos2 vector
---@param callback fun(ok: boolean, info: EmergeAreaInfo)
---@param param? any User parameter, handed back to the callback as `info.param`
function octmap.emerge_area_async(pos1, pos2, callback, param)
	local user_cb, user_param = callback, param
	if type(user_cb) ~= "function" and type(user_param) == "function" then
		user_cb, user_param = user_param, user_cb
	end
	assert(type(user_cb) == "function", "callback required")

	local minp, maxp, padded_min, padded_max = octmap.get_padded_bounds(pos1, pos2)
	local failed = false

	-- Invoked by the engine once per emerged block.
	local function on_block(_, action, calls_remaining, cb_param)
		failed = failed
			or action == core.EMERGE_CANCELLED
			or action == core.EMERGE_ERRORED
		if calls_remaining ~= 0 then
			return
		end
		user_cb(not failed, {
			minp = minp,
			maxp = maxp,
			padded_min = padded_min,
			padded_max = padded_max,
			had_error = failed,
			param = cb_param,
		})
	end

	core.emerge_area(padded_min, padded_max, on_block, user_param)
end



---Async variant of octmap.new() that guarantees the area has been generated.
---`opts` may be omitted, so `octmap.new_async(pos1, pos2, callback)` is accepted.
---If the emerge fails or is cancelled, the callback receives `nil` plus an error
---string. Otherwise it receives the map built by octmap.new().
---@param pos1 vector
---@param pos2 vector
---@param opts? MapCreationOpts
---@param callback fun(map: OctMap|nil, err?: string)
function octmap.new_async(pos1, pos2, opts, callback)
	local map_opts, done = opts, callback
	if done == nil and type(map_opts) == "function" then
		map_opts, done = nil, map_opts
	end
	assert(type(done) == "function", "callback required")

	local function on_emerged(ok)
		if ok then
			local map = octmap.new(pos1, pos2, map_opts)
			done(map)
		else
			done(nil, "emerge_area failed or was cancelled")
		end
	end

	octmap.emerge_area_async(pos1, pos2, on_emerged)
end



---Report whether a tree is uniform, meaning it is a single leaf with no children table.
---@param tree OctNode
---@return boolean
local function is_uniform_tree(tree)
	local kids = tree[CHILDREN]
	if kids == nil then
		return true
	end
	return false
end

---Read the content_id stored at the root of a uniform tree.
---Only meaningful when is_uniform_tree(tree) is true.
---@param tree OctNode
---@return integer
local function get_uniform_content_id(tree)
	local root_cid = tree[ID]
	return root_cid
end

---Record every content_id used in `tree` into `content_id_map` (content_id -> name).
---Thin wrapper over the recursive walker that is forward-declared at the top of the file.
---@param tree OctNode
---@param content_id_map table<integer, string>
local function track_tree_content_ids(tree, content_id_map)
	return track_tree_content_ids_impl(tree, content_id_map)
end

---Process a batch of chunks using a single VoxelManip read.
---The voxel region read is the batch bounds extended by half a chunk
---(`octchunk.SIZE / 2`) on every axis. For each chunk in the batch this:
---  * builds a tree centred on the snapped chunk center,
---  * records its content_ids in `map.content_id_map`,
---  * counts it in `uniform_counts` when it is uniform,
---  * stores it (or its serialized blob when `opts.store_chunk_blobs` is set)
---    at grid index (`x_idx_start + bx`, `y_idx_start + by`, `z_idx_start + bz`).
---@param map OctMap Map under construction; `map.trees` and `map.content_id_map` are mutated
---@param batch_minp vector Position of the batch's first chunk (each tree center is snapped)
---@param batch_maxp vector Position of the batch's last chunk
---@param x_idx_start integer Grid x index of the first chunk in the batch
---@param y_idx_start integer Grid y index of the first chunk in the batch
---@param z_idx_start integer Grid z index of the first chunk in the batch
---@param trees_x integer Number of chunks along x
---@param trees_y integer Number of chunks along y
---@param trees_z integer Number of chunks along z
---@param opts? MapCreationOpts
---@param uniform_counts table<integer, integer> content_id -> uniform chunk count (mutated)
local function process_batch(map, batch_minp, batch_maxp, x_idx_start, y_idx_start, z_idx_start, trees_x, trees_y, trees_z, opts, uniform_counts)
	local store_chunk_blobs = opts and opts.store_chunk_blobs
	local size = octchunk.SIZE
	local padded_min = vector.subtract(batch_minp, size / 2)
	local padded_max = vector.add(batch_maxp, size / 2)

	core.load_area(padded_min, padded_max)

	local manip = core.get_voxel_manip()
	local emerged_pos1, emerged_pos2 = manip:read_from_map(padded_min, padded_max)
	local area = VoxelArea(emerged_pos1, emerged_pos2)
	local data = {}
	local param2_data = {}
	manip:get_data(data)
	manip:get_param2_data(param2_data)
	-- Called without a buffer, get_light_data() returns the light array itself,
	-- so no table needs to be pre-allocated for it (the old one was discarded unused).
	local param1_data = manip:get_light_data()
	---@cast param1_data integer[]

	local trees = map.trees
	local content_id_map = map.content_id_map

	for bx = 0, trees_x - 1 do
		-- Per-axis coordinate and grid index are invariant in the inner loops.
		local cx, gx = batch_minp.x + bx * size, x_idx_start + bx
		for by = 0, trees_y - 1 do
			local cy, gy = batch_minp.y + by * size, y_idx_start + by
			for bz = 0, trees_z - 1 do
				local gz = z_idx_start + bz
				local tree_center = octchunk.snap_to_center(vector.new(cx, cy, batch_minp.z + bz * size))

				local new_tree = {
					center = tree_center,
					size = size,
				}
				octchunk.populate_tree_from_area(new_tree, area, data, param2_data, param1_data)

				-- Track content_ids globally
				track_tree_content_ids(new_tree, content_id_map)

				-- Track uniform trees for later sparsification (by content_id)
				if is_uniform_tree(new_tree) then
					local cid = get_uniform_content_id(new_tree)
					uniform_counts[cid] = (uniform_counts[cid] or 0) + 1
				end

				if store_chunk_blobs then
					matrix3d.set(trees, gx, gy, gz, octchunk.serialize(new_tree))
				else
					matrix3d.set(trees, gx, gy, gz, new_tree)
				end
				map_octree.add_wireframe(tree_center)
			end
		end
	end
end


---Pick the content_id shared by the largest number of uniform chunks.
---Ties are broken in favour of the smallest content_id, so the result does not
---depend on the (unspecified) iteration order of `pairs`.
---@param uniform_counts table<integer, integer> content_id -> number of uniform chunks
---@return integer|nil best_cid nil when no uniform chunk was counted
---@return integer best_count
local function pick_default_content_id(uniform_counts)
	local best_cid, best_count = nil, 0
	for cid, count in pairs(uniform_counts) do
		if count > best_count
			or (best_cid ~= nil and count == best_count and cid < best_cid) then
			best_cid, best_count = cid, count
		end
	end
	return best_cid, best_count
end

---Create a new OctMap snapshot from the world.
---The two corners are sorted and rounded (kept as `requested_pos1/2`), then
---snapped to chunk centers (`minp`/`maxp`). Chunks are built from one VoxelManip
---read when the batch plan has a single batch within `max_voxelmanip_volume` and
---`force_batches` is unset; otherwise each planned batch goes through process_batch().
---The most common uniform chunk content then becomes `map.default_node`, and
---uniform chunks with that content are stored as nil (sparse grid). The remaining
---table trees (not blobs) are registered with octcache afterwards.
---@param minp vector
---@param maxp vector
---@param opts? MapCreationOpts
---@return OctMap
function octmap.new(minp, maxp, opts)
	opts = octmap.apply_server_limits(opts, 1)
	local max_voxelmanip_volume = opts.max_voxelmanip_volume
	local store_chunk_blobs = opts.store_chunk_blobs
	local t_start = core.get_us_time()

	-- Normalize the request: sort corners, round to whole nodes, snap to chunk centers.
	local req_min, req_max = vector.sort(minp, maxp)
	req_min = assert(vector.round(req_min))
	req_max = assert(vector.round(req_max))
	minp, maxp = octchunk.snap_to_center(req_min), octchunk.snap_to_center(req_max)

	local map = {}
	map.minp = minp
	map.maxp = maxp
	map.requested_pos1 = req_min
	map.requested_pos2 = req_max
	map.default_node = nil -- will be set after build if uniform trees exist
	map.content_id_map = {} -- global content_id->name map
	map.cache_mb = opts.cache_mb
	map.trees = matrix3d.new(
		math.floor((maxp.x - minp.x) / octchunk.SIZE) + 1,
		math.floor((maxp.y - minp.y) / octchunk.SIZE) + 1,
		math.floor((maxp.z - minp.z) / octchunk.SIZE) + 1,
		nil -- sparse: nil means "use default_node"
	)

	local chunk_count = map.trees.size.x * map.trees.size.y * map.trees.size.z
	map_octree.debug(string.format("Creating %d chunks: min=%s max=%s",
		chunk_count, core.pos_to_string(minp), core.pos_to_string(maxp)))

	local plan = octmap.plan_batches(minp, maxp, opts)
	local total_volume = 0
	if plan.batches[1] then
		-- for logging only
		local padded_min = vector.subtract(minp, octchunk.SIZE / 2)
		local padded_max = vector.add(maxp, octchunk.SIZE / 2)
		local size = vector.subtract(padded_max, padded_min)
		total_volume = (size.x + 1) * (size.y + 1) * (size.z + 1)
	end

	local t_voxelmanip = t_start
	local uniform_counts = {} -- collected during build for sparsification (by content_id)

	if #plan.batches == 1 and not opts.force_batches and plan.batches[1].volume <= max_voxelmanip_volume then
		-- Fast path: single VoxelManip read for entire area
		local b = plan.batches[1]
		core.load_area(b.padded_min, b.padded_max)

		local manip = core.get_voxel_manip()
		local emerged_pos1, emerged_pos2 = manip:read_from_map(b.padded_min, b.padded_max)
		local area = VoxelArea(emerged_pos1, emerged_pos2)
		local data = {}
		local param2_data = {}
		manip:get_data(data)
		manip:get_param2_data(param2_data)
		-- Called without a buffer, get_light_data() returns the light array itself.
		local param1_data = manip:get_light_data()
		---@cast param1_data integer[]

		t_voxelmanip = core.get_us_time()

		matrix3d.iterate(map.trees, function(x, y, z)
			local tree_center = vector.new(
				minp.x + (x - 1) * octchunk.SIZE,
				minp.y + (y - 1) * octchunk.SIZE,
				minp.z + (z - 1) * octchunk.SIZE
			)
			tree_center = octchunk.snap_to_center(tree_center)

			local new_tree = {
				center = tree_center,
				size = octchunk.SIZE
			}
			octchunk.populate_tree_from_area(new_tree, area, data, param2_data, param1_data)

			-- Track content_ids globally
			track_tree_content_ids(new_tree, map.content_id_map)

			-- Track uniform trees (by content_id)
			if is_uniform_tree(new_tree) then
				local cid = get_uniform_content_id(new_tree)
				uniform_counts[cid] = (uniform_counts[cid] or 0) + 1
			end

			if store_chunk_blobs then
				matrix3d.set(map.trees, x, y, z, octchunk.serialize(new_tree))
			else
				matrix3d.set(map.trees, x, y, z, new_tree)
			end
			map_octree.add_wireframe(tree_center)
		end)
	else
		map_octree.debug(string.format("Large area (%.1fM voxels), processing in %d batches",
			total_volume / 1000000, #plan.batches))

		for i = 1, #plan.batches do
			local b = plan.batches[i]
			local trees_x = b.x_end - b.x_idx + 1
			local trees_y = b.y_end - b.y_idx + 1
			local trees_z = b.z_end - b.z_idx + 1
			process_batch(map, b.minp, b.maxp, b.x_idx, b.y_idx, b.z_idx, trees_x, trees_y, trees_z, opts, uniform_counts)
			collectgarbage("collect")
		end

		t_voxelmanip = core.get_us_time()
	end

	local t_chunks = core.get_us_time()

	-- Sparsify BEFORE caching, so trees that are about to be dropped from the grid
	-- are never registered with octcache.
	local best_cid = pick_default_content_id(uniform_counts)
	if best_cid ~= nil then
		map.default_node = map.content_id_map[best_cid]
		map.default_content_id = best_cid

		-- replace uniform chunks matching default with nil (sparse)
		local sparse_count = 0
		matrix3d.iterate(map.trees, function(x, y, z)
			local tree = matrix3d.get(map.trees, x, y, z)
			if tree == nil then
				return -- already sparse
			end
			-- For blobs, deserialize temporarily to check
			local tree_table = type(tree) == "string" and octchunk.deserialize(tree, {use_cache = false}) or tree
			if is_uniform_tree(tree_table) and get_uniform_content_id(tree_table) == best_cid then
				matrix3d.set(map.trees, x, y, z, nil)
				sparse_count = sparse_count + 1
			end
		end)
		map_octree.debug(string.format("Sparse grid: %d/%d chunks omitted (default=%s)",
			sparse_count, chunk_count, map.default_node))
	end

	if not store_chunk_blobs then
		-- Single octcache pass over the surviving chunks (skip nil cells)
		matrix3d.iterate(map.trees, function(x, y, z)
			local tree = matrix3d.get(map.trees, x, y, z)
			if tree then octcache.create(tree) end
		end)
		matrix3d.iterate(map.trees, function(x, y, z)
			local tree = matrix3d.get(map.trees, x, y, z)
			if tree then octcache.use(tree) end
		end)
	end

	if map.default_content_id == nil then
		map.default_node = map.default_node or "air"
		map.default_content_id = core.get_content_id(map.default_node)
	end

	local t_end = core.get_us_time()
	map_octree.debug(string.format("Timing: VoxelManip=%.1fms, Chunks=%.1fms, Cache=%.1fms, Total=%.1fms",
		(t_voxelmanip - t_start) / 1000,
		(t_chunks - t_voxelmanip) / 1000,
		(t_end - t_chunks) / 1000,
		(t_end - t_start) / 1000))

	return octmap.attach_methods(map)
end



---Recursively record the node name of every content_id found in `octnode` and
---its descendants into `content_id_map` (content_id -> node name).
---Ids that already map to a name are not looked up again.
---Assigns the forward-declared local `track_tree_content_ids_impl`.
---@param octnode OctNode
---@param content_id_map table<integer, string>
track_tree_content_ids_impl = function(octnode, content_id_map)
	local cid = octnode[ID]
	content_id_map[cid] = content_id_map[cid] or core.get_name_from_content_id(cid)

	local kids = octnode[CHILDREN]
	if not kids then
		return
	end
	for _, kid in pairs(kids) do
		track_tree_content_ids_impl(kid, content_id_map)
	end
end
