local sq = squill._internal
local format = string.format
local reserved_words = sq.reserved_words
local valid_identifier = sq.valid_identifier

-- Compiles an UPDATE statement into a Lua function. The parser is presumably
-- positioned just after the UPDATE keyword (see the "UPDATE OR" check).
--
-- Syntax accepted below:
--   <table> SET <col> = <expr> [, <col> = <expr> ...]
--       WHERE <expr> [RETURNING <result-column> [, ...]]
--
-- A WHERE clause is mandatory. If the same column is assigned more than once,
-- the rightmost assignment's expression is used.
--
-- Returns two values: the result of self:compile_lua() on the generated code,
-- and the table of column names produced by the generated function
-- (always including affected_rows). The generated function returns a table
-- with an affected_rows field; with RETURNING, one table per updated row is
-- also stored in the array part of that result.
return function(self)
    -- Reserved words cannot start a table name. Give "UPDATE OR ..." its own
    -- message and report any other reserved word as a bad table name.
    local first_token = self:peek()
    self:assert(first_token ~= "or", "UPDATE OR is not supported")
    self:assert(not reserved_words[first_token],
        "%q is not a valid table name", first_token)

    local table_name = self:next()
    self:assert(valid_identifier(table_name), "%q is not a valid table name",
        table_name)
    self:expect("set")

    self.table_lookup = {[table_name] = table_name}

    -- modified_columns maps column name -> assignment info and is passed to
    -- insert_var_refs / add_set_columns below. column_order keeps the order
    -- in which columns first appear, so code generation does not depend on
    -- pairs() traversal order (which Lua leaves unspecified).
    local modified_columns, column_order = {}, {}
    repeat
        local col_name = self:next()
        local ref, entire_column_ref = self:get_col_ref(table_name, col_name)
        self:expect("=")
        if modified_columns[col_name] == nil then
            column_order[#column_order + 1] = col_name
        end
        -- A repeated column overwrites its earlier entry (rightmost wins).
        modified_columns[col_name] = {
            ref = ref,
            entire_column_ref = entire_column_ref,
            expr = self:expr_to_lua(self:parse_expr())
        }
    until not self:pop_if_equals(",")

    -- A missing WHERE clause shows up either as end of input or as a
    -- RETURNING clause directly after the SET list; report both clearly.
    local after_set = self:peek()
    self:assert(after_set and after_set ~= "returning",
        "UPDATE statements without a WHERE condition are not allowed.")
    self:expect("where")

    local where_clause = self:parse_expr()

    local returned_columns = {affected_rows = true}
    local returning = false
    local return_exprs_lua = {}
    if self:pop_if_equals("returning") then
        repeat
            local name = sq.parse_select_result_column(self, return_exprs_lua,
                returned_columns)
            -- parse_select_result_column stores a parsed expression under
            -- `name`; convert it to Lua source in place.
            return_exprs_lua[name] = self:expr_to_lua(return_exprs_lua[name])
        until not self:pop_if_equals(",")
        returning = true
    end

    -- update_row() must be called before insert_var_refs so it can reference
    -- new columns
    local update_code, checks, checked_uniques, index_updates = {}, {}, {}, {}
    for _, col_name in ipairs(column_order) do
        local col = modified_columns[col_name]
        self:update_row(update_code, checks, table_name, col_name, col.expr,
            col.ref, col.entire_column_ref, true, checked_uniques, false, nil,
            index_updates, true)
    end

    -- Use the join API as it can do some optimisations
    local ordered_joins = self:get_optimal_join_order({
        {
            table = table_name,
            alias = table_name,
            cond = where_clause,
        }
    })

    local code = {}

    -- Generate argument list
    code[#code + 1] = self:create_function_def()
    self:insert_var_refs(code, false, modified_columns)

    if returning then
        code[#code + 1] = "local res = {}"
    end

    code[#code + 1] = "local rows_updated = 0"

    -- write_joins opens `ends` nested blocks; the per-row code below is
    -- emitted inside them and the blocks are closed afterwards.
    local ends = self:write_joins(code, ordered_joins)

    table.insert_all(code, update_code)
    table.insert_all(code, checks)
    code[#code + 1] = "rows_updated = rows_updated + 1"
    table.insert_all(code, index_updates)

    if returning then
        code[#code + 1] = "res[#res + 1] = {"
        for name, expr in pairs(return_exprs_lua) do
            code[#code + 1] = format("[%q] = %s,", name, expr)
        end
        code[#code + 1] = "}"
    end

    for _ = 1, ends do
        code[#code + 1] = "end"
    end

    -- Only emit the add_set_columns code path when at least one row changed.
    code[#code + 1] = "if rows_updated > 0 then"
    self:add_set_columns(code, modified_columns, "true")
    code[#code + 1] = "end"

    if returning then
        code[#code + 1] = "res.affected_rows = rows_updated"
        code[#code + 1] = "return res"
    else
        code[#code + 1] = "return {affected_rows = rows_updated}"
    end

    -- Closes the function opened by create_function_def().
    code[#code + 1] = "end"

    return self:compile_lua(table.concat(code, "\n")), returned_columns
end
