local sq = squill._internal
local reserved_words = sq.reserved_words
local valid_identifier = sq.valid_identifier

--- Compile an UPDATE statement into Lua code.
-- Parses, from the parser's current position:
--   <table> SET <col> = <expr> [, <col> = <expr> ...] WHERE <expr>
--   [RETURNING <result-column> [, ...]]
-- A WHERE clause is mandatory; UPDATE statements without one are rejected.
-- NOTE(review): assumes the caller has already consumed the UPDATE keyword,
-- so the first token seen here is the table name — confirm against caller.
-- @param parser the statement parser
-- @return the result of `code:compile(parser)` for the generated code
-- @return table whose keys are the returned column names (always includes
--   `affected_rows`)
return function(parser)
    -- A reserved word in table-name position is rejected; the message assumes
    -- it is the "OR" of "UPDATE OR <conflict-clause>".
    parser:assert(not reserved_words[parser:peek()], "UPDATE OR is not supported")

    local table_name = parser:next()
    parser:assert(valid_identifier(table_name), "%q is not a valid table name", table_name)
    parser:expect("set")

    local env = sq.Env.new(parser)
    env.table_lookup[table_name] = table_name

    -- Map of column name -> {ref, entire_column_ref, expr}; a column assigned
    -- more than once keeps only its last assignment.
    local modified_columns = {}
    -- Column names in the order they first appear in the SET clause. Code is
    -- generated in this order so the output does not depend on pairs() order.
    local column_order = {}
    repeat
        local col_name = parser:next()
        local ref, entire_column_ref = env:get_col_ref(table_name, col_name)
        parser:expect("=")
        if modified_columns[col_name] == nil then
            column_order[#column_order + 1] = col_name
        end
        modified_columns[col_name] = {
            ref = ref,
            entire_column_ref = entire_column_ref,
            expr = env:expr_to_lua(parser:parse_expr())
        }
    until not parser:pop_if_equals(",")

    parser:assert(parser:peek(),
        "UPDATE statements without a WHERE condition are not allowed.")
    parser:expect("where")

    local where_clause = parser:parse_expr()

    local returned_columns = {affected_rows = true}
    local returning = false
    local return_exprs_lua = {}
    if parser:pop_if_equals("returning") then
        repeat
            local name = sq.parse_select_result_column(parser, return_exprs_lua,
                returned_columns)
            return_exprs_lua[name] = env:expr_to_lua(return_exprs_lua[name])
        until not parser:pop_if_equals(",")
        returning = true
    end

    -- update_row() must be called before insert_var_refs so it can reference
    -- new columns
    local update_code = sq.CodeBuf.new()
    local checks = sq.CodeBuf.new()
    local checked_uniques = sq.CodeBuf.new()
    local index_updates = sq.CodeBuf.new()
    for _, col_name in ipairs(column_order) do
        local col = modified_columns[col_name]
        env:update_row(update_code, checks, table_name, col_name, col.expr,
            col.ref, col.entire_column_ref, true, checked_uniques, false, nil,
            index_updates, true)
    end

    -- Use the join API as it can do some optimisations
    local ordered_joins = env:get_optimal_join_order({
        {
            table = table_name,
            alias = table_name,
            cond = where_clause,
        }
    })

    -- Generate argument list
    local code = sq.CodeBuf.new()
    env:insert_var_refs(code, false, modified_columns)

    if returning then
        code:put("local res = {}")
    end

    code:put("local rows_updated = 0")

    -- write_joins opens the row loop(s); it returns how many "end"s are needed
    -- to close them again.
    local ends = env:write_joins(code, ordered_joins)

    code:insert_all(update_code)
    code:insert_all(checks)
    code:put("rows_updated = rows_updated + 1")
    code:insert_all(index_updates)

    if returning then
        code:put("res[#res + 1] = {")
        for name, expr in pairs(return_exprs_lua) do
            code:putf("[%q] = %s,", name, expr)
        end
        code:put("}")
    end

    for _ = 1, ends do
        code:put("end")
    end

    -- Only emit the set-column bookkeeping when at least one row changed.
    code:put("if rows_updated > 0 then")
    env:add_set_columns(code, modified_columns, "true")
    code:put("end")

    if returning then
        code:put("res.affected_rows = rows_updated")
        code:put("return res")
    else
        code:put("return {affected_rows = rows_updated}")
    end

    return code:compile(parser), returned_columns
end
