package main

import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"log"
	"math/rand"
	"net/http"
	"sync"
	"time"

	"nhooyr.io/websocket"
)

// Session is one client's proxied connection to an upstream WebSocket.
//
// Outbox, Closed and LastActive are shared between the HTTP handlers, the
// per-session reader goroutine (readFromWebSocket) and the cleanup loop;
// access them while holding OutboxLock.
type Session struct {
	// ID is the key in the sessions registry and the value handed to the
	// client by /connect.
	ID         string
	// WS is the upstream WebSocket connection dialed in handleConnect.
	WS         *websocket.Conn
	// Outbox holds upstream messages not yet returned to the client by /poll.
	Outbox     [][]byte
	// OutboxLock guards Outbox, Closed and LastActive.
	OutboxLock sync.Mutex
	// NewMessage wakes a waiting poller. handleConnect creates it with a
	// buffer of 1 and the reader sends on it without blocking, so at most one
	// wake-up is pending at a time.
	NewMessage chan struct{}
	// LastActive is the time of the last client activity; cleanupSessions
	// closes sessions idle for more than 90 seconds.
	LastActive time.Time
	// Target is the upstream URL taken from the "target" query parameter.
	Target     string
	// Closed is set once the reader's Read fails; nothing more will be
	// appended to Outbox after that.
	Closed     bool
}

// sessions is the process-wide registry of live proxy sessions, keyed by
// session ID. The embedded RWMutex guards m.
var sessions = struct {
	sync.RWMutex
	m map[string]*Session
}{
	m: map[string]*Session{},
}

// main starts the background session reaper and serves the long-polling
// proxy API (/connect, /send, /poll, /disconnect) on :9839.
func main() {
	// math/rand (used for session IDs) is auto-seeded since Go 1.20; the
	// explicit seed keeps IDs from repeating across restarts on older
	// toolchains.
	rand.Seed(time.Now().UnixNano())
	go cleanupSessions()

	http.HandleFunc("/connect", handleConnect)
	http.HandleFunc("/send", handleSend)
	http.HandleFunc("/poll", handlePoll)
	http.HandleFunc("/disconnect", handleDisconnect)

	// http.ListenAndServe applies no timeouts at all, so a client could hold a
	// connection open indefinitely by trickling request headers. Only the
	// header phase is bounded here: /poll intentionally keeps responses open
	// for up to 20s, so a WriteTimeout would have to stay well above that.
	srv := &http.Server{
		Addr:              ":9839",
		ReadHeaderTimeout: 10 * time.Second,
	}

	log.Println("Proxy listening on :9839")
	log.Fatal(srv.ListenAndServe())
}

// handleConnect dials the upstream WebSocket named by the "target" query
// parameter, registers a new Session for it and starts its reader goroutine.
// On success it replies with JSON {"session": "<id>"}; it answers 400 when
// target is missing and 500 when the upstream dial fails.
func handleConnect(w http.ResponseWriter, r *http.Request) {
	target := r.URL.Query().Get("target")
	if target == "" {
		log.Printf("[CONNECT] ERROR: missing target parameter from %s", r.RemoteAddr)
		http.Error(w, "missing target parameter", 400)
		return
	}

	log.Printf("[CONNECT] Attempting to connect to %s from %s", target, r.RemoteAddr)

	// NOTE(review): target is an arbitrary client-supplied URL, so this
	// endpoint will dial any host the proxy can reach; restrict it if the
	// proxy is exposed to untrusted clients.
	ws, resp, err := websocket.Dial(r.Context(), target, &websocket.DialOptions{
		CompressionMode: websocket.CompressionContextTakeover,
	})
	if err != nil {
		log.Printf("[CONNECT] ERROR: failed to connect to %s: %v (resp: %+v)", target, err, resp)
		http.Error(w, "failed to connect to upstream WS", 500)
		return
	}

	// Set read limit to 16MB to handle larger messages
	ws.SetReadLimit(1 << 24)

	sess := &Session{
		WS:         ws,
		Outbox:     make([][]byte, 0),
		NewMessage: make(chan struct{}, 1),
		LastActive: time.Now(),
		Target:     target,
	}

	// Draw IDs until one is free. Overwriting an existing entry would orphan
	// that session: its upstream socket and reader goroutine would stay alive
	// and cleanupSessions could never find it to close it.
	// NOTE(review): these IDs are only 32 bits from math/rand and therefore
	// guessable, yet they are the only credential for /send and /poll;
	// consider crypto/rand.
	sessions.Lock()
	var id string
	for {
		id = fmt.Sprintf("%08x", rand.Uint32())
		if _, taken := sessions.m[id]; !taken {
			break
		}
	}
	sess.ID = id
	sessions.m[id] = sess
	sessions.Unlock()

	log.Printf("[CONNECT] SUCCESS: session=%s target=%s client=%s", id, target, r.RemoteAddr)

	// Start background reader
	go readFromWebSocket(sess)

	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(map[string]string{"session": id}); err != nil {
		log.Printf("[CONNECT] ERROR: session=%s failed to write response: %v", id, err)
	}
}

// handleSend forwards the request body, as one text message, to the upstream
// WebSocket of the session named by the "session" query parameter. It replies
// 204 on success, 400 for a missing session parameter or unreadable body, 404
// for an unknown or closed session and 500 when the upstream write fails.
func handleSend(w http.ResponseWriter, r *http.Request) {
	sid := r.URL.Query().Get("session")
	if sid == "" {
		log.Printf("[SEND] ERROR: missing session parameter from %s", r.RemoteAddr)
		http.Error(w, "missing session", 400)
		return
	}
	body, err := io.ReadAll(r.Body)
	if err != nil {
		log.Printf("[SEND] ERROR: session=%s failed to read body: %v", sid, err)
		http.Error(w, "failed to read body", 400)
		return
	}

	sess := getSession(sid)
	closed := true
	if sess != nil {
		// Closed is written by the reader goroutine and LastActive is read by
		// cleanupSessions, so both are accessed under OutboxLock.
		sess.OutboxLock.Lock()
		closed = sess.Closed
		if !closed {
			sess.LastActive = time.Now()
		}
		sess.OutboxLock.Unlock()
	}
	if closed {
		log.Printf("[SEND] ERROR: session=%s not found or closed", sid)
		http.Error(w, "invalid session", 404)
		return
	}

	log.Printf("[SEND] session=%s target=%s bytes=%d", sid, sess.Target, len(body))

	// Forward to WebSocket.
	// NOTE(review): the write is bound to the request context; per the
	// websocket library docs a context cancelled mid-write closes the whole
	// connection, so a client aborting /send may tear down the upstream —
	// confirm this is intended.
	err = sess.WS.Write(r.Context(), websocket.MessageText, body)
	if err != nil {
		log.Printf("[SEND] ERROR: session=%s write failed: %v", sid, err)
		http.Error(w, "send failed", 500)
		return
	}
	w.WriteHeader(204)
}

// handlePoll long-polls the session named by the "session" query parameter.
// It waits up to 20 seconds for readFromWebSocket to signal NewMessage, then
// returns every queued upstream message as JSON {"messages": [...]}, one
// string per message. It replies 204 when nothing is available and 404 for an
// unknown session. Once the upstream is closed and the outbox is drained, the
// session is removed, so the client's next poll gets 404.
func handlePoll(w http.ResponseWriter, r *http.Request) {
	sid := r.URL.Query().Get("session")
	sess := getSession(sid)
	if sess == nil {
		log.Printf("[POLL] ERROR: session=%s not found", sid)
		http.Error(w, "invalid session", 404)
		return
	}

	// A client that keeps polling is active even when the upstream is quiet.
	// Refreshing LastActive here stops cleanupSessions from reaping a session
	// that is merely waiting for upstream traffic.
	sess.OutboxLock.Lock()
	sess.LastActive = time.Now()
	closed := sess.Closed
	pending := len(sess.Outbox)
	sess.OutboxLock.Unlock()

	log.Printf("[POLL] session=%s waiting for messages (closed=%v)", sid, closed)

	// Upstream already gone and nothing left to deliver: report it now rather
	// than making the client sit out the full timeout.
	if closed && pending == 0 {
		log.Printf("[POLL] session=%s was closed, removing", sid)
		removeSession(sess.ID)
		w.WriteHeader(204)
		return
	}

	timeout := time.NewTimer(20 * time.Second)
	defer timeout.Stop()

	// Wait for messages, timeout, or the client going away.
	select {
	case <-sess.NewMessage:
		// Messages available, fall through to collect them
	case <-r.Context().Done():
		// The client disconnected. Leave the outbox untouched so the queued
		// messages go to the next poll instead of into a dead connection.
		log.Printf("[POLL] session=%s client went away: %v", sid, r.Context().Err())
		return
	case <-timeout.C:
		log.Printf("[POLL] session=%s timeout (no messages)", sid)
		// Check if connection closed during wait
		sess.OutboxLock.Lock()
		closed = sess.Closed
		hasMessages := len(sess.Outbox) > 0
		sess.OutboxLock.Unlock()

		if closed && !hasMessages {
			log.Printf("[POLL] session=%s was closed during timeout, removing", sid)
			removeSession(sess.ID)
		}
		w.WriteHeader(204) // no content
		return
	}

	// Collect all pending messages
	sess.OutboxLock.Lock()
	if len(sess.Outbox) == 0 {
		// Signaled with an empty outbox: another poll drained it first, or the
		// reader signaled closure.
		closed = sess.Closed
		sess.OutboxLock.Unlock()

		log.Printf("[POLL] session=%s signaled but no messages", sid)
		if closed {
			log.Printf("[POLL] session=%s was closed, removing", sid)
			removeSession(sess.ID)
		}
		w.WriteHeader(204)
		return
	}

	messages := sess.Outbox
	sess.Outbox = make([][]byte, 0)
	closed = sess.Closed
	sess.LastActive = time.Now()
	sess.OutboxLock.Unlock()

	log.Printf("[POLL] session=%s returning %d messages", sid, len(messages))

	// Convert to string array for JSON response.
	// NOTE(review): the reader discards the frame type, so binary messages
	// that are not valid UTF-8 will have bytes replaced during JSON encoding.
	messageStrings := make([]string, len(messages))
	for i, msg := range messages {
		messageStrings[i] = string(msg)
	}

	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(map[string]interface{}{
		"messages": messageStrings,
	}); err != nil {
		log.Printf("[POLL] ERROR: session=%s failed to write response: %v", sid, err)
	}

	// If connection is closed and outbox is now empty, clean up
	if closed {
		log.Printf("[POLL] session=%s drained, removing", sid)
		removeSession(sess.ID)
	}
}

// handleDisconnect unregisters the session named by the "session" query
// parameter and closes its upstream WebSocket. It always answers 204, even
// for unknown sessions, so repeating a disconnect is harmless.
func handleDisconnect(w http.ResponseWriter, r *http.Request) {
	sid := r.URL.Query().Get("session")

	if victim := popSession(sid); victim != nil {
		log.Printf("[DISCONNECT] session=%s target=%s", sid, victim.Target)
		victim.WS.Close(websocket.StatusNormalClosure, "client disconnected")
	} else {
		log.Printf("[DISCONNECT] session=%s not found (already closed?)", sid)
	}

	w.WriteHeader(204)
}

// readFromWebSocket is the reader goroutine that handleConnect starts for each
// session. It appends every upstream message to sess.Outbox and wakes a
// waiting poller through sess.NewMessage. When Read returns an error, it marks
// the session Closed, wakes the poller one last time and exits.
func readFromWebSocket(sess *Session) {
	log.Printf("[READER] started for session=%s target=%s", sess.ID, sess.Target)
	// The loop ends only when Read returns an error (presumably also once
	// handleDisconnect or cleanupSessions closes the connection).
	ctx := context.Background()
	for {
		// NOTE(review): the message type is discarded; handlePoll returns every
		// message as a JSON string.
		_, data, err := sess.WS.Read(ctx)
		if err != nil {
			log.Printf("[READER] session=%s read error: %v", sess.ID, err)
			sess.OutboxLock.Lock()
			sess.Closed = true
			messageCount := len(sess.Outbox)
			sess.OutboxLock.Unlock()

			log.Printf("[READER] session=%s marked closed, outbox has %d messages", sess.ID, messageCount)

			// Signal even when the outbox is empty. A poller blocked on
			// NewMessage then sees the closure immediately (handlePoll removes
			// a closed, drained session) instead of waiting out its timeout.
			select {
			case sess.NewMessage <- struct{}{}:
			default:
				// A wake-up is already pending.
			}
			return
		}
		log.Printf("[READER] session=%s received bytes=%d", sess.ID, len(data))

		sess.OutboxLock.Lock()
		sess.Outbox = append(sess.Outbox, data)
		sess.OutboxLock.Unlock()

		// Signal that a message is available (non-blocking)
		select {
		case sess.NewMessage <- struct{}{}:
		default:
			// Already signaled, no need to signal again
		}
	}
}

func cleanupSessions() {
	ticker := time.NewTicker(60 * time.Second)
	for range ticker.C {
		sessions.Lock()
		var cleaned []string
		for id, s := range sessions.m {
			if time.Since(s.LastActive) > 90*time.Second {
				s.WS.Close(websocket.StatusNormalClosure, "timeout")
				delete(sessions.m, id)
				cleaned = append(cleaned, id)
			}
		}
		sessions.Unlock()
		if len(cleaned) > 0 {
			log.Printf("[CLEANUP] closed %d inactive sessions: %v", len(cleaned), cleaned)
		}
	}
}

// getSession returns the registered session for id, or nil when there is
// none. It holds only the registry's read lock.
func getSession(id string) *Session {
	sessions.RLock()
	found, ok := sessions.m[id]
	sessions.RUnlock()

	if !ok {
		return nil
	}
	return found
}

// popSession atomically unregisters id and returns the session it referred
// to, or nil when id was not registered.
func popSession(id string) *Session {
	sessions.Lock()
	defer sessions.Unlock()

	if s, present := sessions.m[id]; present {
		delete(sessions.m, id)
		return s
	}
	return nil
}

// removeSession unregisters id; removing an unknown ID is a no-op. It does
// not close the session's WebSocket.
func removeSession(id string) {
	sessions.Lock()
	defer sessions.Unlock()
	delete(sessions.m, id)
}
