/*
 * Copyright (c) 2024 Huawei Technologies Co., Ltd.
 * openFuyao is licensed under Mulan PSL v2.
 * You can use this software according to the terms and conditions of the Mulan PSL v2.
 * You may obtain a copy of Mulan PSL v2 at:
 *          http://license.coscl.org.cn/MulanPSL2
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 * See the Mulan PSL v2 for more details.
 */

package softmaxpotc

import (
	"context"
	"encoding/json"
	"math"
	"math/rand"
	"strings"
	"sync"
	"testing"

	metav1types "k8s.io/apimachinery/pkg/types"
	fwkdl "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer"
	fwkscheduling "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling"

	"hermes-router/pkg/epp/internal/pdgroup"
)

// TestPickReturnsSingleCandidateDirectly verifies that Pick, given a slice
// holding exactly one scored endpoint, reports that endpoint as the single
// target.
func TestPickReturnsSingleCandidateDirectly(t *testing.T) {
	p := New("picker", Config{Temperature: 1.0})
	sole := newScoredEndpoint("only", "10.0.80.2", 0.4)
	candidates := []*fwkscheduling.ScoredEndpoint{sole}

	got := p.Pick(context.Background(), fwkscheduling.NewCycleState(), candidates)

	assertPickedEndpoint(t, got, sole.Endpoint)
}

// TestPickReturnsEmptyResultForEmptyInput verifies that Pick called with a nil
// candidate slice still returns a non-nil result whose TargetEndpoints is
// empty.
func TestPickReturnsEmptyResultForEmptyInput(t *testing.T) {
	p := New("picker", Config{Temperature: 1.0})

	got := p.Pick(context.Background(), fwkscheduling.NewCycleState(), nil)

	// Success path: a result object exists and carries no targets.
	if got != nil && len(got.TargetEndpoints) == 0 {
		return
	}
	t.Fatalf("expected empty result, got %+v", got)
}

// TestPickWithTwoCandidatesReturnsHigherScore verifies that, with exactly two
// candidates and a seeded random source, Pick returns the endpoint with the
// larger score even when it is listed second.
func TestPickWithTwoCandidatesReturnsHigherScore(t *testing.T) {
	p := New("picker", Config{Temperature: 1.0})
	// Fixed seed keeps the run reproducible.
	p.randFloat = rand.New(rand.NewSource(1)).Float64

	strong := newScoredEndpoint("higher", "10.0.81.2", 0.9)
	weak := newScoredEndpoint("lower", "10.0.81.3", 0.1)
	candidates := []*fwkscheduling.ScoredEndpoint{weak, strong}

	got := p.Pick(context.Background(), fwkscheduling.NewCycleState(), candidates)

	assertPickedEndpoint(t, got, strong.Endpoint)
}

// TestPickClampsNaNAndNegativeScores pairs an endpoint scored 0.5 with one
// whose score is NaN or negative and expects the 0.5-scored endpoint to be
// the one picked in each case.
func TestPickClampsNaNAndNegativeScores(t *testing.T) {
	type scenario struct {
		name     string
		badScore float64
	}
	scenarios := []scenario{
		{"nan", math.NaN()},
		{"negative", -1},
	}

	for idx := range scenarios {
		sc := scenarios[idx]
		t.Run(sc.name, func(t *testing.T) {
			p := New("picker", Config{Temperature: 1.0})
			// Fixed seed keeps each subtest reproducible.
			p.randFloat = rand.New(rand.NewSource(2)).Float64

			valid := newScoredEndpoint("good", "10.0.82.2", 0.5)
			invalid := newScoredEndpoint("bad", "10.0.82.3", sc.badScore)
			candidates := []*fwkscheduling.ScoredEndpoint{valid, invalid}

			got := p.Pick(context.Background(), fwkscheduling.NewCycleState(), candidates)

			assertPickedEndpoint(t, got, valid.Endpoint)
		})
	}
}

// TestFactoryValidatesTemperature verifies that Factory returns an error
// mentioning "temperature must be greater than 0" when the JSON parameters set
// temperature to 0.
func TestFactoryValidatesTemperature(t *testing.T) {
	const wantFragment = "temperature must be greater than 0"
	params := json.RawMessage(`{"temperature":0}`)

	_, err := Factory("picker", params, nil)

	if err != nil && strings.Contains(err.Error(), wantFragment) {
		return
	}
	t.Fatalf("expected temperature validation error, got %v", err)
}

// TestFactoryRejectsNullParameters verifies that Factory returns an error
// mentioning "explicit null" when the raw JSON parameters are the literal
// null.
func TestFactoryRejectsNullParameters(t *testing.T) {
	const wantFragment = "explicit null"
	params := json.RawMessage(`null`)

	_, err := Factory("picker", params, nil)

	if err != nil && strings.Contains(err.Error(), wantFragment) {
		return
	}
	t.Fatalf("expected explicit null parameters error, got %v", err)
}

// TestLowTemperatureConcentratesOnBestEndpoint runs 2000 seeded picks at
// temperature 0.01 over three endpoints scored 0.9, 0.5 and 0.1, and requires
// the 0.9-scored endpoint to be selected at least 1900 times.
func TestLowTemperatureConcentratesOnBestEndpoint(t *testing.T) {
	const (
		rounds      = 2000
		minBestWins = 1900
	)

	p := New("picker", Config{Temperature: 0.01})
	p.randFloat = rand.New(rand.NewSource(7)).Float64

	top := newScoredEndpoint("best", "10.0.83.2", 0.9)
	mid := newScoredEndpoint("middle", "10.0.83.3", 0.5)
	bottom := newScoredEndpoint("worst", "10.0.83.4", 0.1)

	wins := 0
	for round := 0; round < rounds; round++ {
		// A fresh slice is handed to every call, as each pick is independent.
		got := p.Pick(context.Background(), fwkscheduling.NewCycleState(), []*fwkscheduling.ScoredEndpoint{top, mid, bottom})
		if got == nil || len(got.TargetEndpoints) != 1 {
			// Malformed results simply do not count toward the tally.
			continue
		}
		if fwkscheduling.EndpointComparer(got.TargetEndpoints[0], top.Endpoint) {
			wins++
		}
	}

	if wins < minBestWins {
		t.Fatalf("expected best endpoint to dominate at low temperature, selected %d times", wins)
	}
}

// TestPickReturnsOneEndpointWhenScoresAreEqual verifies that Pick still yields
// exactly one target endpoint when both candidates carry the same score. The
// test does not care which of the two is chosen.
func TestPickReturnsOneEndpointWhenScoresAreEqual(t *testing.T) {
	p := New("picker", Config{Temperature: 1.0})
	p.randFloat = rand.New(rand.NewSource(9)).Float64

	a := newScoredEndpoint("left", "10.0.84.2", 0.5)
	b := newScoredEndpoint("right", "10.0.84.3", 0.5)
	candidates := []*fwkscheduling.ScoredEndpoint{a, b}

	got := p.Pick(context.Background(), fwkscheduling.NewCycleState(), candidates)

	if got != nil && len(got.TargetEndpoints) == 1 {
		return
	}
	t.Fatalf("expected exactly one endpoint, got %+v", got)
}

// TestHighTemperatureApproachesUniformSoftmaxWeights verifies that
// softmaxWeights, called with temperature 100 on three endpoints scored 0.9,
// 0.5 and 0.1, yields one weight per endpoint and that every weight lies
// within 0.01 of 1/3.
func TestHighTemperatureApproachesUniformSoftmaxWeights(t *testing.T) {
	first := newScoredEndpoint("first", "10.0.85.2", 0.9)
	second := newScoredEndpoint("second", "10.0.85.3", 0.5)
	third := newScoredEndpoint("third", "10.0.85.4", 0.1)
	candidates := []*fwkscheduling.ScoredEndpoint{first, second, third}

	weights := softmaxWeights(candidates, 100.0)

	// Without this check the loop below would pass vacuously on an empty or
	// truncated result.
	if len(weights) != len(candidates) {
		t.Fatalf("expected %d weights, got %d", len(candidates), len(weights))
	}
	for index, weight := range weights {
		if math.Abs(weight-(1.0/3.0)) > 0.01 {
			t.Fatalf("expected weight %d near 1/3, got %v", index, weight)
		}
	}
}

// TestAllZeroScoresProduceUniformSoftmaxWeights verifies that softmaxWeights,
// called with temperature 1 on three endpoints that all score 0, yields one
// weight per endpoint and that every weight equals 1/3 to within 1e-9.
func TestAllZeroScoresProduceUniformSoftmaxWeights(t *testing.T) {
	first := newScoredEndpoint("first", "10.0.86.2", 0)
	second := newScoredEndpoint("second", "10.0.86.3", 0)
	third := newScoredEndpoint("third", "10.0.86.4", 0)
	candidates := []*fwkscheduling.ScoredEndpoint{first, second, third}

	weights := softmaxWeights(candidates, 1.0)

	// Without this check the loop below would pass vacuously on an empty or
	// truncated result.
	if len(weights) != len(candidates) {
		t.Fatalf("expected %d weights, got %d", len(candidates), len(weights))
	}
	for index, weight := range weights {
		if math.Abs(weight-(1.0/3.0)) > 1e-9 {
			t.Fatalf("expected weight %d to equal 1/3, got %v", index, weight)
		}
	}
}

// TestPickDistributionMatchesSoftmaxPOTC runs 10000 picks at temperature 1
// over three endpoints scored 0.9, 0.5 and 0.1 and checks the shape of the
// resulting selection counts: the 0.1-scored endpoint is never selected, the
// 0.5-scored endpoint is selected at least once, and the 0.9-scored endpoint
// is selected at least as often as the 0.5-scored one.
//
// NOTE(review): the expectations presumably follow from the picker sampling
// two distinct candidates and keeping the higher-scored one; confirm against
// the Pick implementation.
func TestPickDistributionMatchesSoftmaxPOTC(t *testing.T) {
	high := newScoredEndpoint("high", "10.0.90.2", 0.9)
	mid := newScoredEndpoint("mid", "10.0.90.3", 0.5)
	low := newScoredEndpoint("low", "10.0.90.4", 0.1)
	candidates := []*fwkscheduling.ScoredEndpoint{high, mid, low}

	picker := New("picker", Config{Temperature: 1.0})
	// Seed the sampler, as the other statistical tests in this file do, so a
	// failure can be reproduced exactly.
	picker.randFloat = rand.New(rand.NewSource(11)).Float64

	counts := map[string]int{"high": 0, "mid": 0, "low": 0}
	const trials = 10000

	for i := 0; i < trials; i++ {
		result := picker.Pick(context.Background(), fwkscheduling.NewCycleState(), candidates)
		// Fail with a readable message instead of panicking on a nil or empty
		// result when indexing TargetEndpoints below.
		if result == nil || len(result.TargetEndpoints) != 1 {
			t.Fatalf("trial %d: expected exactly one picked endpoint, got %+v", i, result)
		}
		name := result.TargetEndpoints[0].GetMetadata().PodName
		counts[name]++
	}

	t.Logf("distribution: high=%d mid=%d low=%d", counts["high"], counts["mid"], counts["low"])

	if counts["low"] != 0 {
		t.Fatalf("expected low to never win in POTC comparison, got %d selections", counts["low"])
	}
	if counts["mid"] == 0 {
		t.Fatal("expected mid-score endpoint to be picked at least once (beats low)")
	}
	if counts["high"] < counts["mid"] {
		t.Fatalf("expected high >= mid in selection count, got high=%d mid=%d", counts["high"], counts["mid"])
	}
}

// TestPickSupportsConcurrentUse calls Pick on one shared picker and one shared
// candidate slice from 8 goroutines, 1000 times each, and expects every call
// to return exactly one target endpoint. The picker's randFloat is left at
// whatever New installs; no seeded source is injected here.
func TestPickSupportsConcurrentUse(t *testing.T) {
	picker := New("picker", Config{Temperature: 1.0})
	first := newScoredEndpoint("first", "10.0.86.5", 0.9)
	second := newScoredEndpoint("second", "10.0.86.6", 0.5)
	third := newScoredEndpoint("third", "10.0.86.7", 0.1)
	inputs := []*fwkscheduling.ScoredEndpoint{first, second, third}

	var wg sync.WaitGroup
	for i := 0; i < 8; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for j := 0; j < 1000; j++ {
				result := picker.Pick(context.Background(), fwkscheduling.NewCycleState(), inputs)
				if result == nil || len(result.TargetEndpoints) != 1 {
					// Fatalf/FailNow may only be called from the goroutine
					// running the test function; from a worker goroutine,
					// record the failure with Errorf and stop this worker.
					t.Errorf("expected one target endpoint, got %+v", result)
					return
				}
			}
		}()
	}
	wg.Wait()
}

// TestPickUsesReadyPDLeadersOnly registers pdgroup info on two leader
// endpoints. The 0.99-scored leader's info lists prefill and decode pods but
// has no SelectedPrefillPod or SelectedDecodePod. The 0.5-scored leader's info
// has both selections filled in. The test expects Pick to return the
// 0.5-scored leader despite its lower score.
//
// NOTE(review): the selected-pod fields are presumably what marks a group as
// "ready"; confirm against the picker implementation.
func TestPickUsesReadyPDLeadersOnly(t *testing.T) {
	p := New("picker", Config{Temperature: 1.0})
	p.randFloat = rand.New(rand.NewSource(1)).Float64

	// Leader whose group has candidates but no selections yet.
	unready := newScoredEndpoint("pending", "10.0.87.2", 0.99)
	unreadyInfo := &pdgroup.Info{
		GroupID: "group-pending",
		PrefillPods: []fwkscheduling.ScoredEndpoint{
			{Endpoint: newEndpoint("prefill-pending", "10.0.87.3"), Score: 0.1},
		},
		DecodePods: []fwkscheduling.ScoredEndpoint{
			{Endpoint: newEndpoint("decode-pending", "10.0.87.4"), Score: 0.2},
		},
	}
	pdgroup.Set(unready.Endpoint, unreadyInfo)

	// Leader whose group has both a prefill and a decode pod selected.
	ready := newScoredEndpoint("ready", "10.0.87.5", 0.5)
	readyInfo := &pdgroup.Info{
		GroupID: "group-ready",
		PrefillPods: []fwkscheduling.ScoredEndpoint{
			{Endpoint: newEndpoint("prefill-ready", "10.0.87.6"), Score: 0.1},
		},
		DecodePods: []fwkscheduling.ScoredEndpoint{
			{Endpoint: newEndpoint("decode-ready", "10.0.87.7"), Score: 0.2},
		},
		SelectedPrefillPod: &fwkscheduling.ScoredEndpoint{
			Endpoint: newEndpoint("prefill-selected", "10.0.87.8"),
			Score:    0.1,
		},
		SelectedDecodePod: &fwkscheduling.ScoredEndpoint{
			Endpoint: newEndpoint("decode-selected", "10.0.87.9"),
			Score:    0.2,
		},
	}
	pdgroup.Set(ready.Endpoint, readyInfo)

	candidates := []*fwkscheduling.ScoredEndpoint{unready, ready}
	got := p.Pick(context.Background(), fwkscheduling.NewCycleState(), candidates)

	assertPickedEndpoint(t, got, ready.Endpoint)
}

// newScoredEndpoint builds a test endpoint via newEndpoint and wraps it,
// together with the given score, in a freshly allocated ScoredEndpoint.
func newScoredEndpoint(name, address string, score float64) *fwkscheduling.ScoredEndpoint {
	scored := new(fwkscheduling.ScoredEndpoint)
	scored.Endpoint = newEndpoint(name, address)
	scored.Score = score
	return scored
}

// newEndpoint builds a test endpoint in the "default" namespace. The given
// name is used as both the namespaced name and the pod name, the port is fixed
// at "8000", and metrics and attributes come from fwkdl.NewMetrics and
// fwkdl.NewAttributes.
func newEndpoint(name, address string) fwkscheduling.Endpoint {
	meta := &fwkdl.EndpointMetadata{
		NamespacedName: metav1types.NamespacedName{Namespace: "default", Name: name},
		PodName:        name,
		Address:        address,
		Port:           "8000",
	}
	metrics := fwkdl.NewMetrics()
	attrs := fwkdl.NewAttributes()
	return fwkscheduling.NewEndpoint(meta, metrics, attrs)
}

// assertPickedEndpoint fails the test immediately unless result is non-nil,
// carries exactly one target endpoint, and that endpoint matches want
// according to fwkscheduling.EndpointComparer. It is marked as a helper so
// failures are attributed to the caller's line.
func assertPickedEndpoint(t *testing.T, result *fwkscheduling.ProfileRunResult, want fwkscheduling.Endpoint) {
	t.Helper()

	// Treat a nil result as having zero targets.
	targets := 0
	if result != nil {
		targets = len(result.TargetEndpoints)
	}
	if targets != 1 {
		t.Fatalf("expected exactly one picked endpoint, got %+v", result)
	}

	got := result.TargetEndpoints[0]
	if fwkscheduling.EndpointComparer(got, want) {
		return
	}
	t.Fatalf("unexpected picked endpoint: got %v want %v", got, want)
}