package websocket

import (
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"

	"github.com/goodrain/rainbond/api/controller"
)

// TestChunkUploadPreflightAllowsConsoleHeaders sends a CORS preflight (OPTIONS)
// request to ChunkUploadController.HandleOptions and checks that the response
// has status 200, echoes the request Origin in Access-Control-Allow-Origin, and
// lists every requested header in Access-Control-Allow-Headers.
func TestChunkUploadPreflightAllowsConsoleHeaders(t *testing.T) {
	const origin = "http://console.example"

	recorder := httptest.NewRecorder()
	preflight := httptest.NewRequest(http.MethodOptions, "/component/events/event-id/upload/init", nil)
	preflight.Header.Set("Origin", origin)
	preflight.Header.Set("Access-Control-Request-Method", http.MethodPost)
	preflight.Header.Set("Access-Control-Request-Headers", "content-type,authorization,x-team-name,x-region-name,x-requested-with")

	handler := &controller.ChunkUploadController{}
	handler.HandleOptions(recorder, preflight)

	if status := recorder.Code; status != http.StatusOK {
		t.Fatalf("expected status 200, got %d", status)
	}
	if echoed := recorder.Header().Get("Access-Control-Allow-Origin"); echoed != origin {
		t.Fatalf("expected allow origin to echo request origin, got %q", echoed)
	}

	// Each requested header must appear as its own comma-separated token in the
	// allow list; the comparison is delegated to containsHeaderToken.
	requested := []string{"content-type", "authorization", "x-team-name", "x-region-name", "x-requested-with"}
	granted := recorder.Header().Get("Access-Control-Allow-Headers")
	for i := range requested {
		if containsHeaderToken(granted, requested[i]) {
			continue
		}
		t.Fatalf("expected Access-Control-Allow-Headers %q to contain %q", granted, requested[i])
	}
}

// TestPackageBuildCORSMiddlewareHandlesUnknownRoutePreflight wraps a handler
// that always answers 404 with packageBuildCORS and checks that an OPTIONS
// preflight still receives status 200 with the request Origin echoed back in
// Access-Control-Allow-Origin.
func TestPackageBuildCORSMiddlewareHandlesUnknownRoutePreflight(t *testing.T) {
	const origin = "http://console.example"

	// The inner handler stands in for a route the router does not know about.
	wrapped := packageBuildCORS(http.HandlerFunc(http.NotFound))

	recorder := httptest.NewRecorder()
	preflight := httptest.NewRequest(http.MethodOptions, "/component/events/event-id/package_build/component/events/event-id/upload/init", nil)
	preflight.Header.Set("Origin", origin)
	preflight.Header.Set("Access-Control-Request-Method", http.MethodPost)

	wrapped.ServeHTTP(recorder, preflight)

	if status := recorder.Code; status != http.StatusOK {
		t.Fatalf("expected unknown-route preflight status 200, got %d", status)
	}
	if echoed := recorder.Header().Get("Access-Control-Allow-Origin"); echoed != origin {
		t.Fatalf("expected CORS allow origin on unknown-route preflight, got %q", echoed)
	}
}

// containsHeaderToken reports whether target appears as one of the
// comma-separated tokens of headerValue. Each token is stripped of surrounding
// whitespace and compared to target case-insensitively.
func containsHeaderToken(headerValue, target string) bool {
	remaining := headerValue
	for {
		comma := strings.IndexByte(remaining, ',')
		if comma < 0 {
			// Last (or only) token: its result decides the outcome.
			return strings.EqualFold(strings.TrimSpace(remaining), target)
		}
		if strings.EqualFold(strings.TrimSpace(remaining[:comma]), target) {
			return true
		}
		remaining = remaining[comma+1:]
	}
}