--[[
	Lag Compensation System - Math Functions
	Interpolation algorithms and collision math.
]]

-- Forward declaration: the implementation is assigned in the
-- "LOCAL FUNCTION IMPLEMENTATIONS" section at the bottom of this file.
-- get_interpolated_position captures this local as an upvalue. A plain nil
-- declaration (rather than an empty stub function) makes a missing
-- implementation fail loudly instead of silently returning nil, nil.
local find_bracket

local get_us_time = core.get_us_time
local get_player_by_name = core.get_player_by_name
local table_insert = table.insert
local table_remove = table.remove
local vector_add = vector.add
local vector_subtract = vector.subtract
local math_abs = math.abs

local DEFAULT_EYE_HEIGHT = weapons_lib.lagcomp.DEFAULT_EYE_HEIGHT
local DEFAULT_COLLISIONBOX = weapons_lib.lagcomp.DEFAULT_COLLISIONBOX
local HISTORY_DURATION = weapons_lib.lagcomp.HISTORY_DURATION
local RECENT_INTERPOLATION_THRESHOLD = weapons_lib.lagcomp.RECENT_INTERPOLATION_THRESHOLD



--- Record the player's current eye-level position and velocity in their history.
-- After appending, entries whose timestamp is older than HISTORY_DURATION
-- seconds before now are dropped from the front of the history.
-- @param player Player object (ObjectRef) to sample.
function weapons_lib.lagcomp.add_snapshot(player)
  local name = player:get_player_name()
  local now = get_us_time() / 1000000 -- microseconds -> seconds

  -- Positions are stored at eye level, not at the feet.
  local eye_pos = player:get_pos()
  local eye_offset = player:get_properties().eye_height or DEFAULT_EYE_HEIGHT
  eye_pos.y = eye_pos.y + eye_offset

  local all_histories = weapons_lib.lagcomp.player_history
  local history = all_histories[name]
  if not history then
    history = {}
    all_histories[name] = history
  end

  table_insert(history, {
    time = now,
    pos = eye_pos,
    velocity = player:get_velocity()
  })

  -- Trimming only looks at the front, so this relies on snapshots being
  -- appended in increasing time order.
  local cutoff = now - HISTORY_DURATION
  local oldest = history[1]
  while oldest and oldest.time < cutoff do
    table_remove(history, 1)
    oldest = history[1]
  end
end



-- Return a fresh copy of a stored position, so callers can never mutate a
-- snapshot in the history through the returned value. Any metatable carried
-- by the stored position is kept on the copy.
local function copy_snapshot_pos(p)
  return setmetatable({x = p.x, y = p.y, z = p.z}, getmetatable(p))
end

-- Constant-velocity extrapolation from one snapshot: pos + velocity * dt.
local function extrapolate_snapshot(snapshot, dt)
  local p, v = snapshot.pos, snapshot.velocity
  return {
    x = p.x + v.x * dt,
    y = p.y + v.y * dt,
    z = p.z + v.z * dt
  }
end

--- Main interpolation function (with optional debug info).
-- Estimates a player's position at `target_time` from their snapshot history.
-- @param history Array of snapshots `{time, pos, velocity}` (as built by
--   add_snapshot), or nil. Assumed sorted by ascending `time`, since
--   find_bracket performs a binary search over it.
-- @param target_time Time in seconds, on the same clock as snapshot `time`.
-- @param return_debug_info When truthy, a second table describing the method
--   used (`method`, `description`, plus method-specific fields) is returned.
-- @return A fresh position table (never a stored snapshot's own `pos`), and
--   the debug table or nil. Returns nothing when history is nil or empty.
function weapons_lib.lagcomp.get_interpolated_position(history, target_time, return_debug_info)
  if not history or #history < 1 then
    return
  end

  if #history == 1 then
    local debug_info = return_debug_info and {
      method = "DIRECT_POSITION",
      description = "single snapshot"
    } or nil
    return copy_snapshot_pos(history[1].pos), debug_info
  end

  -- Use binary search for bracketing snapshots
  local snapshot_before, snapshot_after = find_bracket(history, target_time)

  -- Outside the recorded window: extrapolate from the nearest snapshot.
  if not snapshot_before then
    local earliest = history[1]
    -- target_time precedes the earliest snapshot, so this is negative.
    local time_diff = target_time - earliest.time
    local debug_info = return_debug_info and {
      method = "BACKWARD_EXTRAPOLATION",
      description = string.format("%.3fs before earliest snapshot", -time_diff),
      time_diff = time_diff,
      used_snapshot = earliest
    } or nil
    return extrapolate_snapshot(earliest, time_diff), debug_info
  end

  if not snapshot_after then
    local latest = history[#history]
    local time_diff = target_time - latest.time
    local debug_info = return_debug_info and {
      method = "FORWARD_EXTRAPOLATION",
      description = string.format("%.3fs after latest snapshot", time_diff),
      time_diff = time_diff,
      used_snapshot = latest
    } or nil
    return extrapolate_snapshot(latest, time_diff), debug_info
  end

  local total_time_diff = snapshot_after.time - snapshot_before.time
  if total_time_diff <= 0 then
    -- find_bracket uses strict comparisons, so in practice this branch is
    -- reached when target_time equals a snapshot's time exactly and the same
    -- snapshot is returned as both bounds.
    local debug_info = return_debug_info and {
      method = "DIRECT_POSITION",
      description = "identical timestamps"
    } or nil
    return copy_snapshot_pos(snapshot_before.pos), debug_info
  end

  local latest_snapshot = history[#history]
  local time_from_latest = math_abs(target_time - latest_snapshot.time)
  local target_time_diff = target_time - snapshot_before.time
  -- Normalized position of target_time between the two snapshots, in [0, 1].
  local alpha = target_time_diff / total_time_diff
  local debug_info = return_debug_info and {
    snapshot_before = snapshot_before,
    snapshot_after = snapshot_after,
    target_time = target_time,
    total_time_diff = total_time_diff,
    target_time_diff = target_time_diff,
    alpha = alpha,
    time_from_latest = time_from_latest
  } or nil

  local p0, p1 = snapshot_before.pos, snapshot_after.pos

  if time_from_latest <= RECENT_INTERPOLATION_THRESHOLD then
    local result = {
      x = p0.x + alpha * (p1.x - p0.x),
      y = p0.y + alpha * (p1.y - p0.y),
      z = p0.z + alpha * (p1.z - p0.z)
    }
    if debug_info then
      debug_info.method = "LINEAR_INTERPOLATION"
      debug_info.description = "recent data"
    end
    return result, debug_info
  end

  -- Cubic Hermite spline for older data. h1..h4 are the Hermite basis
  -- functions on the unit interval; the snapshot velocities are scaled by the
  -- interval length to turn them into tangents for that unit parameter.
  local alpha2 = alpha * alpha
  local alpha3 = alpha2 * alpha
  local h1 = 2 * alpha3 - 3 * alpha2 + 1
  local h2 = -2 * alpha3 + 3 * alpha2
  local h3 = alpha3 - 2 * alpha2 + alpha
  local h4 = alpha3 - alpha2
  local v0, v1 = snapshot_before.velocity, snapshot_after.velocity
  local m0 = {
    x = v0.x * total_time_diff,
    y = v0.y * total_time_diff,
    z = v0.z * total_time_diff
  }
  local m1 = {
    x = v1.x * total_time_diff,
    y = v1.y * total_time_diff,
    z = v1.z * total_time_diff
  }
  local result = {
    x = p0.x * h1 + p1.x * h2 + m0.x * h3 + m1.x * h4,
    y = p0.y * h1 + p1.y * h2 + m0.y * h3 + m1.y * h4,
    z = p0.z * h1 + p1.z * h2 + m0.z * h3 + m1.z * h4
  }

  if debug_info then
    debug_info.method = "HERMITE_SPLINE_INTERPOLATION"
    debug_info.description = "older data"
  end

  return result, debug_info
end



-- Current eye-level position of the named player, or nil when
-- get_player_by_name finds no such player object.
local function get_current_eye_pos(p_name)
  local target_player = get_player_by_name(p_name)
  if not target_player then
    return nil
  end
  local current_pos = target_player:get_pos()
  local eye_height = target_player:get_properties().eye_height or DEFAULT_EYE_HEIGHT
  current_pos.y = current_pos.y + eye_height
  return current_pos
end

--- Get interpolated position with debug output.
-- Calls get_interpolated_position on the stored history of `p_name` and sends
-- a description of the method used (and, when the target player object can
-- be found, how far the result is from their current eye position) to
-- `shooter_name` via weapons_lib.print_debug.
-- @param p_name Name of the player whose past position is wanted.
-- @param target_time Time in seconds to evaluate the history at.
-- @param shooter_name Recipient of the debug messages.
-- @return The interpolated position, or nothing when there is no history.
function weapons_lib.lagcomp.get_interpolated_position_and_print_debug(p_name, target_time, shooter_name)
  local history = weapons_lib.lagcomp.player_history[p_name]

  -- Get result and debug info in one call
  local result, debug_info = weapons_lib.lagcomp.get_interpolated_position(history, target_time, true)

  if not result then
    return
  end

  if not debug_info then
    return result
  end

  local method = debug_info.method
  -- e.g. "FORWARD_EXTRAPOLATION" -> "Method: FORWARD EXTRAPOLATION (...)".
  -- The parentheses drop gsub's second return value (the match count).
  local method_line = string.format("Method: %s (%s)", (method:gsub("_", " ")), debug_info.description)

  if method == "DIRECT_POSITION" then
    weapons_lib.print_debug(true, shooter_name, method_line)
    return result
  end

  if method == "BACKWARD_EXTRAPOLATION" or method == "FORWARD_EXTRAPOLATION" then
    weapons_lib.print_debug(true, shooter_name, method_line)

    local current_pos = get_current_eye_pos(p_name)
    if current_pos then
      local distance_diff = vector.distance(result, current_pos)
      weapons_lib.print_debug(true, shooter_name, string.format("Position difference: %.3f nodes", distance_diff))
    end
    return result
  end

  -- Linear or Hermite interpolation.
  -- NOTE(review): the "rewind" value printed here is the absolute target
  -- time, not a rewind amount; confirm whether the label is intended.
  weapons_lib.print_debug(true, shooter_name,
    string.format("Using snapshots: %.6fs to %.6fs (rewind: %.6fs)",
      debug_info.snapshot_before.time, debug_info.snapshot_after.time, debug_info.target_time))
  weapons_lib.print_debug(true, shooter_name,
    string.format("Alpha calc: (%.6f - %.6f) / %.6f = %.6f",
      debug_info.target_time, debug_info.snapshot_before.time, debug_info.total_time_diff, debug_info.alpha))
  weapons_lib.print_debug(true, shooter_name, method_line)

  local current_pos = get_current_eye_pos(p_name)
  if current_pos then
    local distance_diff = vector.distance(result, current_pos)
    weapons_lib.print_debug(true, shooter_name, string.format("Result position: (%.3f, %.3f, %.3f)", result.x, result.y, result.z))
    weapons_lib.print_debug(true, shooter_name, string.format("Current position: (%.3f, %.3f, %.3f)", current_pos.x, current_pos.y, current_pos.z))
    weapons_lib.print_debug(true, shooter_name, string.format("Position difference: %.3f nodes", distance_diff))
  end

  return result
end



----------------------------------------------
---------------COLLISION MATH----------------
----------------------------------------------

-- Narrow the ray's parametric interval [tmin, tmax] to one axis slab
-- [lo, hi]. Returns the narrowed interval, or nil once it becomes empty.
local function clip_slab(tmin, tmax, origin, dir, lo, hi)
  if dir == 0 then
    -- Ray parallel to this slab: it is inside the slab everywhere or nowhere.
    -- Handled explicitly to avoid inf/NaN arithmetic; (lo - origin) * (1/0)
    -- is NaN when the origin lies exactly on a face, and NaN fails every
    -- comparison.
    if origin < lo or origin > hi then
      return nil
    end
    return tmin, tmax
  end

  local inv = 1 / dir
  local t1 = (lo - origin) * inv
  local t2 = (hi - origin) * inv
  if t1 > t2 then
    t1, t2 = t2, t1
  end
  if t1 > tmin then tmin = t1 end
  if t2 < tmax then tmax = t2 end
  if tmin > tmax then
    return nil
  end
  return tmin, tmax
end

--- Intersect the ray segment o + d*t, t in [0, len], with an axis-aligned box
-- (slab method).
-- @param o Ray origin {x, y, z}.
-- @param d Ray direction {x, y, z}. Distances are in multiples of |d|.
--   NOTE(review): they are in nodes only if callers pass a unit vector.
-- @param len Maximum distance along the ray.
-- @param bmin Minimum corner of the box.
-- @param bmax Maximum corner of the box.
-- @return Distance along the ray to the first point of the segment inside the
--   box, and that point; nil on a miss. If the origin is already inside the
--   box, the hit is at distance 0 (the origin) rather than at a negative
--   distance behind the ray start.
function weapons_lib.lagcomp.ray_aabb(o, d, len, bmin, bmax)
  -- Start from the segment itself; each slab can only shrink the interval.
  local tmin, tmax = 0, len

  tmin, tmax = clip_slab(tmin, tmax, o.x, d.x, bmin.x, bmax.x)
  if not tmin then return nil end

  tmin, tmax = clip_slab(tmin, tmax, o.y, d.y, bmin.y, bmax.y)
  if not tmin then return nil end

  tmin, tmax = clip_slab(tmin, tmax, o.z, d.z, bmin.z, bmax.z)
  if not tmin then return nil end

  -- Return the intersection point at tmin
  return tmin, {x = o.x + d.x * tmin, y = o.y + d.y * tmin, z = o.z + d.z * tmin}
end



--- Test ray collision against a specific player position.
-- Builds the player's collision box around `test_position` and intersects
-- the ray with it via ray_aabb.
-- @param ray_origin Ray start {x, y, z}.
-- @param ray_dir Ray direction {x, y, z}.
-- @param ray_length Maximum distance along the ray.
-- @param player_obj Player object whose collisionbox/eye_height are used.
-- @param test_position Eye-level position to test the player at.
-- @return `{object, distance, intersection_point}` on a hit, nothing otherwise.
function weapons_lib.lagcomp.test_ray_collision(ray_origin, ray_dir, ray_length, player_obj, test_position)
  local properties = player_obj:get_properties()
  local box = properties.collisionbox or DEFAULT_COLLISIONBOX
  local eye_offset = {x = 0, y = properties.eye_height or DEFAULT_EYE_HEIGHT, z = 0}

  -- The collision box is relative to the feet, so step down from eye level first.
  local feet = vector_subtract(test_position, eye_offset)
  local lower_corner = vector_add(feet, {x = box[1], y = box[2], z = box[3]})
  local upper_corner = vector_add(feet, {x = box[4], y = box[5], z = box[6]})

  local hit_distance, hit_point = weapons_lib.lagcomp.ray_aabb(
    ray_origin, ray_dir, ray_length, lower_corner, upper_corner)
  if not hit_distance then
    return
  end

  return {
    object = player_obj,
    distance = hit_distance,
    intersection_point = hit_point
  }
end



----------------------------------------------
----------LOCAL FUNCTION IMPLEMENTATIONS-----
----------------------------------------------

--- Binary-search a time-sorted snapshot history for the snapshots around
-- `target_time`.
-- @return The latest snapshot with time < target_time and the earliest with
--   time > target_time (either may be nil). If a snapshot's time equals
--   target_time exactly, that snapshot is returned for both.
find_bracket = function(history, target_time)
  local lo, hi = 1, #history
  local below, above

  while hi >= lo do
    local probe = math.floor((lo + hi) / 2)
    local entry = history[probe]
    local entry_time = entry.time

    if entry_time == target_time then
      return entry, entry
    end

    if entry_time < target_time then
      below, lo = entry, probe + 1
    else
      above, hi = entry, probe - 1
    end
  end

  return below, above
end
