-- Copyright (C) 2024 rstcxk
-- 
-- This program is free software: you can redistribute it and/or modify it under the terms of
-- the GNU Affero General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version.
-- 
-- This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
-- without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more details.
-- 
-- You should have received a copy of the GNU Affero General Public License along with this program. If not, see <https://www.gnu.org/licenses/>. 

-- @classmod LanguageConstruct_Expression
--	Describes an expression: a single value, or several values combined by operators

local StringExpression
local NumberLiteral
local InlineLua
local VectorExpression
local RangeExpression
local SelectorExpression local TableExpression
local ProcessSubstitution
local ParameterExpansion
local ExpressionWrapper = require("language_constructs.expression_wrapper")
local Bool
local FunctionDefinition 


local operators = require("operators")
local helpers = require("helpers")

-- Operator names, split by kind: unary operators on one side, everything else on the other.
local operator_list = {}
local unary_operator_list = {}

for name, operator in pairs(operators) do
	-- both candidates are tables (always truthy), so the and/or selection is safe
	local target = operator.type == "unary" and unary_operator_list or operator_list
	target[#target + 1] = name
end

-- The intermediate lists exist only because helpers.create_trie takes a list of strings;
-- there is no helper that builds a trie straight from the keys of the operators table.
local operator_trie = helpers.create_trie(operator_list)
local unary_operator_trie = helpers.create_trie(unary_operator_list)

-- Prototype table shared by every Expression instance; it doubles as their metatable.
local Expression = {}

-- tag used by duck-typed checks to tell values (language constructs) from operators
Expression.instance_of = "LanguageConstruct_Expression"

-- method and field lookups on instances fall back to the prototype
Expression.__index = Expression

-- Creates an empty Expression.
-- Nothing besides the instance table itself is allocated here: many expressions turn out
-- to be a single plain value and never need the extra tables, so those are created
-- lazily, only once they are actually required.
function Expression:new()
	local instance = {}
	setmetatable(instance, Expression)
	return instance
end

-- Evaluates the expression stored in self.expression (reverse Polish notation).
-- Elements carrying an `instance_of` field are values: they are evaluated with `ctx`
-- and pushed on a stack. Every other element is an operator: a "unary" operator
-- replaces the top of the stack with `action(top)`, any other operator pops two
-- values and pushes `action(left, right)`.
-- @param ctx context handed unchanged to the `evaluate` method of every value
-- @return whatever is left at the bottom of the stack once all elements are processed
function Expression:evaluate(ctx)
	-- The stack table is recycled between calls to avoid reallocating it. It is detached
	-- from the instance while in use: evaluating a value runs arbitrary code, and if that
	-- code ends up evaluating this very expression again (e.g. through a recursive
	-- function), the nested call must not overwrite the operands of the outer call.
	-- A nested call finds no stack on the instance and simply allocates its own.
	local stack = self.stack or {}
	self.stack = nil
	stack[1] = nil

	local expression = self.expression
	local stack_head = 0
	for i = 1, #expression do
		local element = expression[i]
		if element.instance_of then
			-- the element is a value
			stack_head = stack_head + 1
			stack[stack_head] = element:evaluate(ctx)
		elseif element.type == "unary" then
			-- unary operator: transform the top of the stack in place
			stack[stack_head] = element.action(stack[stack_head])
		else
			-- binary operator: combine the two topmost values into one
			local result = element.action(stack[stack_head - 1], stack[stack_head])
			stack[stack_head - 1] = result
			stack[stack_head] = nil
			stack_head = stack_head - 1
		end
	end

	-- hand the table back so the next evaluation can reuse it
	self.stack = stack
	return stack[1]
end

-- Terminators used while parsing the inside of a parenthesised sub expression.
local subexpression_terminators = helpers.create_trie({
	")",
	"\n",
	" ",
	"\t",
})

-- Plain whitespace terminators.
local default_terminators = helpers.create_trie({
	"\n",
	" ",
	"\t",
})

-- Extra terminators for unquoted strings.
-- This looks clumsy, but it is the only way inputs such as $tab.member.x get parsed
-- correctly. Terminating on every operator instead is not an option: the word
-- "world", for example, should not be split into "w", the operator "or" and "ld".
local unquoted_string_terminators = helpers.create_trie({
	".",
	"[",
	" ",
})

-- Caches of merged terminator tries, so they do not have to be rebuilt on every parse.
local actual_terminator_cache = {}
local terminator_cache = {}

-- Parses an expression starting at the current position of the parser.
--
-- Depending on what is found, `self` ends up in one of three shapes:
--  * no operator at all: `self` receives every field and the metatable of the single
--    construct that was parsed, i.e. it becomes that construct
--  * operators whose operands are all static: the whole expression is folded at parse
--    time and `self` becomes an ExpressionWrapper holding the result in `self.value`
--  * otherwise: `self.expression` holds the values and operators in reverse Polish
--    notation, ready for Expression:evaluate
--
-- @param parser_ctx the ParserContext to read from (checked through `instance_of`)
-- @param terminators trie of the strings that end the expression; it is also used as
--	a cache key, so it must not be nil
function Expression:parse(parser_ctx, terminators)
	-- {{{ type checking
	-- tostring keeps the message intact when instance_of is missing; a bare concatenation
	-- with nil would raise an unrelated error before assert could report anything
	assert(parser_ctx.instance_of == "ParserContext", "parser_ctx should be an instance of ParserContext. instead of " .. tostring(parser_ctx.instance_of))
	-- }}}

	local expression_stack
	local operator_stack
	local initialized = false

	-- Allocates the parsing stacks the first time an operator is encountered.
	-- `expression` is the value parsed before that operator (nil for a leading unary
	-- operator); it only seeds the expression stack when the stack does not exist yet.
	local function full_initialize(expression)
		initialized = true
		self.expression = self.expression or {}
		expression_stack = expression_stack or {expression}
		operator_stack = operator_stack or {}
	end

	-- Merges expression_stack and operator_stack into self.expression (reverse Polish
	-- notation) honouring operator priority, then folds constant sub expressions.
	local function sortRPN()
		local expression_stack_head = 2
		table.insert(self.expression, expression_stack[1])

		-- A lower priority number means the operator binds tighter, i.e. it is executed
		-- before the operators around it that have a higher number.
		--
		-- Imagine the operations 2 + 3 + 3 * 4
		-- the stacks are as follows:
		-- expression stack: 2 3 3 4
		-- operator stack (the priority is in parentheses): +(2) +(2) *(1)
		--
		-- The output starts with the first expression. Then, for every operator in
		-- source order: its right operand is appended (unary operators have none), the
		-- operator is put on a pending stack, and pending operators are moved to the
		-- output, newest first, for as long as they bind at least as tightly as the
		-- operator that comes next. At the end everything still pending is flushed.
		--
		-- iteration 0 (before the loop even begins): 2
		-- iteration 1: output = 2 3 +          (the next + has equal priority: flush)
		-- iteration 2: output = 2 3 + 3        (the next * binds tighter: + has to wait)
		-- iteration 3: output = 2 3 + 3 4 * +  (end reached: flush everything, newest first)
		--
		-- Only operators that bind at least as tightly as the next one are flushed.
		-- Flushing the whole pending stack at once would also emit operators that still
		-- have to wait, e.g. "a + -b * c" would come out as (a + -b) * c.
		local pending = {}
		local pending_count = 0
		for i = 1, #operator_stack do
			local operator = operator_stack[i]
			-- nil for the last operator, which makes the loop below flush everything
			local next_operator = operator_stack[i + 1]

			if operator.type ~= "unary" then
				table.insert(self.expression, expression_stack[expression_stack_head])
				expression_stack_head = expression_stack_head + 1
			end

			pending_count = pending_count + 1
			pending[pending_count] = operator

			while pending_count > 0
				and (next_operator == nil or pending[pending_count].priority <= next_operator.priority) do
				table.insert(self.expression, pending[pending_count])
				pending[pending_count] = nil
				pending_count = pending_count - 1
			end
		end

		-- constant folding optimization: a binary operator directly preceded by two
		-- static values is replaced by its precomputed result
		local i = 1
		while i <= #self.expression do
			local v = self.expression[i]
			if not v.instance_of then
				local exp1 = self.expression[i - 2]
				local exp2 = self.expression[i - 1]
				if v.type ~= "unary" and exp1.instance_of and exp1.static
					and exp2.instance_of and exp2.static then
						local result = self.expression[i].action(exp1:static_eval(), exp2:static_eval())
						self.expression[i - 2] = ExpressionWrapper:new(result)
						-- shifting all the elements backwards. table.remove is inefficient here, since it would shift the
						-- elements twice
						for j = i - 1, #self.expression do
							self.expression[j] = self.expression[j + 2]
						end
						-- step back onto the folded value, it may be foldable again with what follows
						i = i - 3
				end
			end
			i = i + 1
		end

		-- everything folded into a single value: turn self into a plain wrapper around it
		if #self.expression == 1 then
			self.value = self.expression[1]:static_eval()
			self.expression = nil
			setmetatable(self, ExpressionWrapper)
		end
	end

	-- Parses a single value (no operators) at the current position and returns the
	-- language construct that represents it.
	local function parse_expression(terminators)
		-- terminators for constructs that must also stop at any operator
		local actual_terminators = actual_terminator_cache[terminators]
		if not actual_terminators then
			actual_terminators = helpers.add_tries(terminators, operator_trie)
			actual_terminator_cache[terminators] = actual_terminators
		end

		local char = parser_ctx:peek()
		local expression

		if char == "`" then
			InlineLua = InlineLua or require("language_constructs.inline_lua")
			expression = InlineLua:new()
			expression:parse(parser_ctx)
		elseif char == "@" then
			SelectorExpression = SelectorExpression or require("language_constructs.selector_expression")
			expression = SelectorExpression:new()
			expression:parse(parser_ctx)
		elseif char == "{" then
			TableExpression = TableExpression or require("language_constructs.table_expression")
			expression = TableExpression:new()
			expression:parse(parser_ctx)
		else
			-- for more complex types where the first character alone is not enough,
			-- use pattern matching
			-- this is a temporary thing until all language constructs get a match function that does a better job

			-- vector
			if string.find(parser_ctx.text, [[^^?%([^)]*,[^)]*,[^)]*%)]], parser_ctx.character_index) then
				VectorExpression = VectorExpression or require("language_constructs.vector_expression")
				expression = VectorExpression:new()
				expression:parse(parser_ctx)
			-- range
			elseif string.find(parser_ctx.text, "^[<(][^)>]*:[^)>]*[)>]", parser_ctx.character_index) then
				RangeExpression = RangeExpression or require("language_constructs.range_expression")
				expression = RangeExpression:new()
				expression:parse(parser_ctx)

			-- sub expression
			elseif char == "(" then
				-- skip the opening parenthesis, parse the inside, skip one more character
				parser_ctx:advance()
				expression = Expression:new()
				expression:parse(parser_ctx, subexpression_terminators)
				parser_ctx:advance()

			-- number literal
			elseif helpers.check_if_number(parser_ctx, actual_terminators) then
				NumberLiteral = NumberLiteral or require("language_constructs.number_literal")
				expression = NumberLiteral:new()
				expression:parse(parser_ctx, actual_terminators)

			-- process substitution
			elseif string.find(parser_ctx.text, "^%$%(", parser_ctx.character_index) then
				ProcessSubstitution = ProcessSubstitution or require("language_constructs.process_substitution")
				expression = ProcessSubstitution:new()
				expression:parse(parser_ctx)

			-- parameter expansion
			elseif string.find(parser_ctx.text, "^%$", parser_ctx.character_index) then
				ParameterExpansion = ParameterExpansion or require("language_constructs.parameter_expansion")
				expression = ParameterExpansion:new()
				expression:parse(parser_ctx)

			-- bools (highly inefficient...)
			elseif string.find(parser_ctx.text, "^true", parser_ctx.character_index)
				or string.find(parser_ctx.text, "^false", parser_ctx.character_index) then
				Bool = Bool or require("language_constructs.bool")
				expression = Bool:new()
				expression:parse(parser_ctx)
			-- functions
			elseif string.find(parser_ctx.text, "^function", parser_ctx.character_index) then
				FunctionDefinition = FunctionDefinition or require("language_constructs.function_definition")
				expression = FunctionDefinition:new()
				expression:parse(parser_ctx)

			-- strings
			else
				StringExpression = StringExpression or require("language_constructs.string_expression")
				expression = StringExpression:new()
				-- very special terminators: the caller's terminators merged with the
				-- unquoted string ones. The cache entry is stored under the caller's
				-- trie, so the merge is computed only once per distinct terminator set.
				local string_terminators = terminator_cache[terminators]
				if not string_terminators then
					string_terminators = helpers.add_tries(terminators, unquoted_string_terminators)
					terminator_cache[terminators] = string_terminators
				end
				expression:parse(parser_ctx, string_terminators)
			end
		end

		return expression
	end

	-- Parses one operand (with an optional leading unary operator) and, when a
	-- non-unary operator follows it, records the operator and recurses for the next
	-- operand.
	-- Returns the parsed value and the position right after it when no operator
	-- follows; returns nil and the current parser position otherwise.
	local function recursive_expression_parse()
		local expression

		local unary_operator_matched, unary_operator = parser_ctx:match(unary_operator_trie, false)

		-- make sure its actually unary
		if unary_operator_matched then
			full_initialize()
			table.insert(operator_stack, operators[unary_operator])
			parser_ctx:advance(#unary_operator)
			parser_ctx:consume_whitespaces()
		end

		if initialized and operator_stack[#operator_stack].type == "around" then
			-- Only the ending terminal of the operator is used here, and not the expression terminators, because it
			-- doesn't make sense to exit early from an expression like this (think about what it means to parse
			-- "<1:(2+3>": the "(" has to be terminated first)
			local op = operator_stack[#operator_stack]
			local expression_terminators = terminator_cache[op.ending_terminal]
			if not expression_terminators then
				expression_terminators = helpers.add_tries(default_terminators, helpers.create_trie({op.ending_terminal}))
				terminator_cache[op.ending_terminal] = expression_terminators
			end

			expression = Expression:new()
			expression:parse(parser_ctx, expression_terminators)
			parser_ctx:consume_whitespaces()
			parser_ctx:advance(#op.ending_terminal)
		else
			expression = parse_expression(terminators)
		end

		if initialized then
			table.insert(expression_stack, expression)
		end

		-- try parsing forward and check if there are operators. Otherwise return to the safe pos
		local last_safe_pos = parser_ctx.character_index

		parser_ctx:consume_whitespaces()

		local matched, operator = parser_ctx:match(operator_trie, false)
		local matched_terminator, _ = parser_ctx:match(terminators, false)

		-- the extra >>not matched_terminator<< check is here for edge cases like <1:2>, so that the second member would
		-- get parsed as a simple 2, instead of trying to parse 2 > (larger than) and look for another expression
		if matched and not matched_terminator then
			full_initialize(expression)
			table.insert(operator_stack, operators[operator])

			parser_ctx:advance(#operator)
			parser_ctx:consume_whitespaces()

			recursive_expression_parse()

			-- if operators[operator].type == "around" then
			-- 	parser_ctx:consume_whitespaces()
			-- 	parser_ctx:expect(operators[operator].ending_terminal)
			-- end
			-- this doesn't matter
			-- NOTE(review): the position reported by the nested call is dropped, so once an
			-- operator has been seen the parser stays after any trailing whitespace, while
			-- a lone value rewinds to just after itself - confirm this is intended
			return nil, parser_ctx.character_index
		else
			return expression, last_safe_pos
		end
	end

	local expression, pos = recursive_expression_parse()

	parser_ctx.character_index = pos

	if initialized then
		sortRPN()
	else
		-- no operator involved: become the parsed construct instead of wrapping it
		for key, value in pairs(expression) do
			self[key] = value
		end

		setmetatable(self, getmetatable(expression))
	end
end

-- Writes a human readable listing of the expression to `dump_ctx`.
-- The elements of self.expression are listed in stored (reverse Polish notation) order
-- between square brackets: values are dumped through their own `dump` method,
-- operators are printed as "(Operator) " followed by their name.
-- @param dump_ctx output context providing write_text, color, new_line and indent
function Expression:dump(dump_ctx)
	dump_ctx:write_text(dump_ctx:color("(Expression)", "ConstructSpecifier"))
	dump_ctx:new_line()
	dump_ctx:write_text("[")
	dump_ctx:indent(1)
	dump_ctx:new_line()

	-- ipairs, not pairs: the element order is what gives the listing its meaning, and
	-- pairs makes no promise about the order in which it visits keys
	for _, v in ipairs(self.expression) do
		dump_ctx:new_line()
		if v.instance_of then
			v:dump(dump_ctx)
		else
			dump_ctx:write_text("(Operator) ")
			dump_ctx:write_text(v.name)
		end
	end

	dump_ctx:indent(-1)
	dump_ctx:new_line()
	dump_ctx:write_text("]")
end

return Expression
