/*
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * openFuyao is licensed under Mulan PSL v2.
 * You can use this software according to the terms and conditions of the Mulan PSL v2.
 * You may obtain a copy of Mulan PSL v2 at:
 *          http://license.coscl.org.cn/MulanPSL2
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 * See the Mulan PSL v2 for more details.
 */

package warmupjob

import (
	"context"
	"fmt"
	"path/filepath"
	"testing"

	warmupv1alpha1 "github.com/openfuyao/weight-dispatcher/api/v1alpha1"
	"github.com/openfuyao/weight-dispatcher/pkg/node"
	"github.com/openfuyao/weight-dispatcher/pkg/planning/transferplanner"
	sharedtypes "github.com/openfuyao/weight-dispatcher/pkg/types"
	corev1 "k8s.io/api/core/v1"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	"k8s.io/apimachinery/pkg/runtime"
	"sigs.k8s.io/controller-runtime/pkg/client/fake"
)

// TestDeriveCollectiveStagingPath checks that the staging path derived for a
// ready path is the ready path's final element placed inside a ".staging"
// directory under the same parent directory.
func TestDeriveCollectiveStagingPath(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name      string
		readyPath string
		want      string
	}{
		{
			name:      "artifact root path",
			readyPath: "/var/lib/weight-dispatcher/cache/ready/perf-qwen3-8b-fanout-m0-n0n2-v15a",
			want:      filepath.Join(string(filepath.Separator), "var", "lib", "weight-dispatcher", "cache", "ready", ".staging", "perf-qwen3-8b-fanout-m0-n0n2-v15a"),
		},
		{
			name:      "nested target path resolves to artifact staging root",
			readyPath: "/dev/shm/fanout-job/fanout-rdma-v14a-blob4g",
			want:      filepath.Join(string(filepath.Separator), "dev", "shm", "fanout-job", ".staging", "fanout-rdma-v14a-blob4g"),
		},
	}

	for _, tt := range tests {
		// Give each parallel subtest its own copy of the case. Before Go 1.22
		// all iterations share one loop variable, so parallel subtests would
		// all observe the last case.
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			got := deriveCollectiveStagingPath(tt.readyPath)
			if got != tt.want {
				t.Fatalf("deriveCollectiveStagingPath(%q): expected staging path %q, got %q", tt.readyPath, tt.want, got)
			}
		})
	}
}

// TestAlignSourcesToManifestRewritesSingleFileRoots checks that sources whose
// Path points at a single file are rewritten to the manifest's RootPath (the
// directory that contains that file).
func TestAlignSourcesToManifestRewritesSingleFileRoots(t *testing.T) {
	t.Parallel()

	sources := []warmupv1alpha1.SourceSpec{
		{
			SourceType: "node",
			NodeName:   "m0",
			Path:       "/models/qwen/model-00001-of-00005.safetensors",
		},
		{
			SourceType: "node",
			NodeName:   "n2",
			Path:       "/models/qwen/model-00001-of-00005.safetensors",
		},
	}
	manifest := sharedtypes.LogicalManifest{
		RootPath: "/models/qwen",
		Files: []sharedtypes.ArtifactFile{{
			RelativePath: "model-00001-of-00005.safetensors",
		}},
	}

	aligned := alignSourcesToManifest(sources, manifest)
	// Without this guard an empty result would make the loop below check
	// nothing and the test would pass trivially.
	if len(aligned) != len(sources) {
		t.Fatalf("expected %d aligned sources, got %d", len(sources), len(aligned))
	}
	for _, source := range aligned {
		if source.Path != "/models/qwen" {
			t.Fatalf("expected aligned source root %q, got %q", "/models/qwen", source.Path)
		}
	}
}

// TestAlignSourcesToManifestRewritesHFHubRootsToResolvedSnapshotRoot checks
// that sources pointing at a Hugging Face hub repository directory are
// rewritten to the manifest's RootPath, which here is a snapshot directory
// beneath that repository.
func TestAlignSourcesToManifestRewritesHFHubRootsToResolvedSnapshotRoot(t *testing.T) {
	t.Parallel()

	sources := []warmupv1alpha1.SourceSpec{
		{
			SourceType: "node",
			NodeName:   "m0",
			Path:       "/home/llm_cache/huggingface/hub/models--Qwen--Qwen3-32B",
		},
		{
			SourceType: "node",
			NodeName:   "n2",
			Path:       "/home/llm_cache/huggingface/hub/models--Qwen--Qwen3-32B",
		},
	}
	manifest := sharedtypes.LogicalManifest{
		RootPath: "/home/llm_cache/huggingface/hub/models--Qwen--Qwen3-32B/snapshots/rev-123",
		Files: []sharedtypes.ArtifactFile{
			{RelativePath: "config.json"},
			{RelativePath: "model-00001-of-00017.safetensors"},
		},
	}

	aligned := alignSourcesToManifest(sources, manifest)
	// Without this guard an empty result would make the loop below check
	// nothing and the test would pass trivially.
	if len(aligned) != len(sources) {
		t.Fatalf("expected %d aligned sources, got %d", len(sources), len(aligned))
	}
	for _, source := range aligned {
		if source.Path != manifest.RootPath {
			t.Fatalf("expected aligned source root %q, got %q", manifest.RootPath, source.Path)
		}
	}
}

// TestResolveAndValidateSourceManifestsAlignsPerSourceRootsAndRejectsMismatches
// checks two things about resolveAndValidateSourceManifests:
//   - when every node reports the same file set, each source's Path is aligned
//     to the RootPath that source's own node reported (m0 matches the manifest
//     root, n0 reports a different root);
//   - when a node reports a file set that differs from the expected manifest,
//     an error is returned.
func TestResolveAndValidateSourceManifestsAlignsPerSourceRootsAndRejectsMismatches(t *testing.T) {
	t.Parallel()

	scheme := runtime.NewScheme()
	if err := corev1.AddToScheme(scheme); err != nil {
		t.Fatalf("add core scheme: %v", err)
	}

	nodes := []corev1.Node{
		{
			ObjectMeta: metav1.ObjectMeta{Name: "m0"},
			Status: corev1.NodeStatus{Addresses: []corev1.NodeAddress{{
				Type:    corev1.NodeInternalIP,
				Address: "192.168.1.10",
			}}},
		},
		{
			ObjectMeta: metav1.ObjectMeta{Name: "n0"},
			Status: corev1.NodeStatus{Addresses: []corev1.NodeAddress{{
				Type:    corev1.NodeInternalIP,
				Address: "192.168.1.20",
			}}},
		},
	}
	client := fake.NewClientBuilder().WithScheme(scheme).WithObjects(&nodes[0], &nodes[1]).Build()
	reconciler := &Reconciler{
		Resolver: node.NewResolver(client),
		Agent: fakeManifestAgent{
			responses: map[string]sharedtypes.BuildManifestResponse{
				"m0": {
					Manifest: sharedtypes.LogicalManifest{
						RootPath: "/home/llm_cache/huggingface/hub/models--Qwen--Qwen3-32B/snapshots/rev-a",
						Files: []sharedtypes.ArtifactFile{
							{RelativePath: "config.json", SizeBytes: 128, Kind: sharedtypes.ArtifactFileKindJSON, Required: true},
							{RelativePath: "model-00001-of-00017.safetensors", SizeBytes: 4096, Kind: sharedtypes.ArtifactFileKindSafeTensors, Chunkable: true, Required: true},
						},
					},
				},
				"n0": {
					Manifest: sharedtypes.LogicalManifest{
						RootPath: "/mnt/cache/qwen32b/snapshots/rev-a",
						Files: []sharedtypes.ArtifactFile{
							{RelativePath: "config.json", SizeBytes: 128, Kind: sharedtypes.ArtifactFileKindJSON, Required: true},
							{RelativePath: "model-00001-of-00017.safetensors", SizeBytes: 4096, Kind: sharedtypes.ArtifactFileKindSafeTensors, Chunkable: true, Required: true},
						},
					},
				},
			},
		},
	}

	spec := warmupv1alpha1.ModelWarmupJobSpec{
		Artifact: warmupv1alpha1.ArtifactRefSpec{Key: "qwen32b"},
		Policy:   warmupv1alpha1.PolicySpec{ChunkSizeMB: 64},
	}
	sources := []warmupv1alpha1.SourceSpec{
		{SourceType: "node", NodeName: "m0", Endpoint: "192.168.1.10", Path: "/home/llm_cache/Qwen32B"},
		{SourceType: "node", NodeName: "n0", Endpoint: "192.168.1.20", Path: "/home/llm_cache/Qwen32B"},
	}
	manifest := sharedtypes.LogicalManifest{
		RootPath: "/home/llm_cache/huggingface/hub/models--Qwen--Qwen3-32B/snapshots/rev-a",
		Files: []sharedtypes.ArtifactFile{
			{RelativePath: "config.json", SizeBytes: 128, Kind: sharedtypes.ArtifactFileKindJSON, Required: true},
			{RelativePath: "model-00001-of-00017.safetensors", SizeBytes: 4096, Kind: sharedtypes.ArtifactFileKindSafeTensors, Chunkable: true, Required: true},
		},
	}

	aligned, err := reconciler.resolveAndValidateSourceManifests(context.Background(), sources, spec, manifest)
	if err != nil {
		t.Fatalf("resolve and validate manifests: %v", err)
	}
	// Guard the positional checks below so a short result is reported as a
	// test failure rather than an index-out-of-range panic.
	if len(aligned) != len(sources) {
		t.Fatalf("expected %d aligned sources, got %d", len(sources), len(aligned))
	}
	if aligned[0].Path != "/home/llm_cache/huggingface/hub/models--Qwen--Qwen3-32B/snapshots/rev-a" {
		t.Fatalf("expected m0 path to align to manifest root, got %q", aligned[0].Path)
	}
	if aligned[1].Path != "/mnt/cache/qwen32b/snapshots/rev-a" {
		t.Fatalf("expected n0 path to align to source-specific root, got %q", aligned[1].Path)
	}

	// n0 now reports only config.json, so its file set no longer matches the
	// expected manifest and validation must fail.
	reconciler.Agent = fakeManifestAgent{
		responses: map[string]sharedtypes.BuildManifestResponse{
			"m0": {Manifest: manifest},
			"n0": {
				Manifest: sharedtypes.LogicalManifest{
					RootPath: "/mnt/cache/qwen32b/snapshots/rev-b",
					Files: []sharedtypes.ArtifactFile{
						{RelativePath: "config.json", SizeBytes: 128, Kind: sharedtypes.ArtifactFileKindJSON, Required: true},
					},
				},
			},
		},
	}
	if _, err := reconciler.resolveAndValidateSourceManifests(context.Background(), sources, spec, manifest); err == nil {
		t.Fatalf("expected manifest mismatch error")
	}
}

// fakeManifestAgent is a test double for the reconciler's Agent. BuildManifest
// serves canned responses keyed by node name; every other method fails with a
// "not implemented" error.
type fakeManifestAgent struct {
	// responses maps a node name to the response BuildManifest returns for it.
	responses map[string]sharedtypes.BuildManifestResponse
}

// SubmitWarmup is not supported by this fake and always returns an error.
func (fakeManifestAgent) SubmitWarmup(context.Context, corev1.Node, sharedtypes.SubmitWarmupRequest) (sharedtypes.TaskHandle, error) {
	return sharedtypes.TaskHandle{}, fmt.Errorf("not implemented")
}

// GetWarmupTaskStatus is not supported by this fake and always returns an error.
func (fakeManifestAgent) GetWarmupTaskStatus(context.Context, corev1.Node, sharedtypes.GetWarmupTaskStatusRequest) (sharedtypes.TaskStatus, error) {
	return sharedtypes.TaskStatus{}, fmt.Errorf("not implemented")
}

// BuildManifest returns the canned response registered for target.Name. It
// returns an error when no response was registered for that node. The request
// is ignored.
//
// The parameter is named target (not node) so it does not shadow the imported
// node package.
func (f fakeManifestAgent) BuildManifest(_ context.Context, target corev1.Node, _ sharedtypes.BuildManifestRequest) (sharedtypes.BuildManifestResponse, error) {
	resp, ok := f.responses[target.Name]
	if !ok {
		return sharedtypes.BuildManifestResponse{}, fmt.Errorf("no manifest for node %s", target.Name)
	}
	return resp, nil
}

// OpenCollective is not supported by this fake and always returns an error.
func (fakeManifestAgent) OpenCollective(context.Context, corev1.Node, sharedtypes.OpenCollectiveRequest) (sharedtypes.OpenCollectiveResponse, error) {
	return sharedtypes.OpenCollectiveResponse{}, fmt.Errorf("not implemented")
}

// StepCollective is not supported by this fake and always returns an error.
func (fakeManifestAgent) StepCollective(context.Context, corev1.Node, sharedtypes.CollectiveStepRequest) (sharedtypes.CollectiveStepResponse, error) {
	return sharedtypes.CollectiveStepResponse{}, fmt.Errorf("not implemented")
}

// CompleteCollective is not supported by this fake and always returns an error.
func (fakeManifestAgent) CompleteCollective(context.Context, corev1.Node, sharedtypes.CompleteCollectiveRequest) error {
	return fmt.Errorf("not implemented")
}

// TestEnrichCollectivePlanBuildsSymmetricTwoTargetFanoutPeers checks that
// enrichCollectivePlan, applied to the two-node ring fixture, does three things:
//   - gives both intents a peer list of length 2;
//   - reports node-a's staging path under ".staging" next to its ready path,
//     as seen from node-b's intent;
//   - leaves both peers with owned ranges.
func TestEnrichCollectivePlanBuildsSymmetricTwoTargetFanoutPeers(t *testing.T) {
	t.Parallel()

	reconciler := &Reconciler{}
	cachePlan := transferPlanFixture()
	nodes := []corev1.Node{
		{
			ObjectMeta: metav1.ObjectMeta{Name: "node-a"},
			Status: corev1.NodeStatus{Addresses: []corev1.NodeAddress{{
				Type:    corev1.NodeInternalIP,
				Address: "10.0.0.1",
			}}},
		},
		{
			ObjectMeta: metav1.ObjectMeta{Name: "node-b"},
			Status: corev1.NodeStatus{Addresses: []corev1.NodeAddress{{
				Type:    corev1.NodeInternalIP,
				Address: "10.0.0.2",
			}}},
		},
	}

	if err := reconciler.enrichCollectivePlan(context.Background(), nodes, &cachePlan); err != nil {
		t.Fatalf("enrichCollectivePlan returned error: %v", err)
	}
	if len(cachePlan.NodeIntents) != 2 {
		t.Fatalf("expected 2 intents, got %d", len(cachePlan.NodeIntents))
	}
	if got := len(cachePlan.NodeIntents[0].Collective.Peers); got != 2 {
		t.Fatalf("expected 2 collective peers on relay root, got %d", got)
	}
	if got := len(cachePlan.NodeIntents[1].Collective.Peers); got != 2 {
		t.Fatalf("expected 2 collective peers on passive peer, got %d", got)
	}

	// The fixture places node-a's ready path at .../cache/ready/artifact-a, so
	// its staging path is expected at .../cache/ready/.staging/artifact-a.
	// Computed once; it does not depend on the peer being inspected.
	expectedStagingPath := filepath.Join(string(filepath.Separator), "var", "lib", "weight-dispatcher", "cache", "ready", ".staging", "artifact-a")
	passivePeers := cachePlan.NodeIntents[1].Collective.Peers
	ownedByNode := make(map[string]int, len(passivePeers))
	for _, peer := range passivePeers {
		ownedByNode[peer.NodeName] = len(peer.OwnedRanges)
		if peer.NodeName == "node-a" && peer.StagingPath != expectedStagingPath {
			t.Fatalf("expected relay root staging path %q, got %q", expectedStagingPath, peer.StagingPath)
		}
	}
	if ownedByNode["node-a"] == 0 || ownedByNode["node-b"] == 0 {
		t.Fatalf("expected both peers to keep owned ranges, got %#v", ownedByNode)
	}
}

// TestCollectiveCompletionTargetsWaitForAllPeers checks that
// collectiveCompletionTargets yields nothing while any peer is still running.
// Once every peer has reached a terminal phase, it must yield one target per
// node, in plan order, carrying that node's own success/failure outcome.
func TestCollectiveCompletionTargetsWaitForAllPeers(t *testing.T) {
	t.Parallel()

	plan := transferPlanFixture()
	nodeStates := []warmupv1alpha1.WarmupNodeState{
		{NodeName: "node-a", TaskID: "plan-node-a", Phase: warmupv1alpha1.JobPhaseSucceeded, Message: "root finished"},
		{NodeName: "node-b", TaskID: "plan-node-b", Phase: warmupv1alpha1.JobPhaseRunning, Message: "peer still fetching"},
	}

	// Before: node-b is still running, so no cleanup may be scheduled yet.
	pending := collectiveCompletionTargets(plan, nodeStates)
	if n := len(pending); n != 0 {
		t.Fatalf("expected no collective cleanup before all peers finish, got %d targets", n)
	}

	// After: node-b fails, which makes every peer terminal.
	peerState := &nodeStates[1]
	peerState.Phase = warmupv1alpha1.JobPhaseFailed
	peerState.Message = "peer failed"

	done := collectiveCompletionTargets(plan, nodeStates)
	if n := len(done); n != 2 {
		t.Fatalf("expected 2 collective cleanup targets after all peers finish, got %d", n)
	}

	root, peer := done[0], done[1]
	if root.nodeName != "node-a" || !root.success {
		t.Fatalf("expected root cleanup target to preserve success state, got %#v", root)
	}
	wantRootStaging := filepath.Join(string(filepath.Separator), "var", "lib", "weight-dispatcher", "cache", "ready", ".staging", "artifact-a")
	if root.stagingPath != wantRootStaging {
		t.Fatalf("expected root cleanup target staging path %q, got %q", wantRootStaging, root.stagingPath)
	}
	if peer.nodeName != "node-b" || peer.success {
		t.Fatalf("expected peer cleanup target to preserve failure state, got %#v", peer)
	}
}

// transferPlanFixture returns a two-node ring CacheBuildPlan shared by the
// collective tests.
//
// Both intents use TransferModePartialPullAllGather and pull from "source-a":
//   - node-a (ring rank 0) covers weights.bin bytes Start 0 to End 13 and
//     has ready path .../cache/ready/artifact-a;
//   - node-b (ring rank 1) covers Start 13 to End 26 and has ready path
//     .../cache/ready/artifact-b.
//
// The ring world size is 2, and each node's prev and next neighbor is the
// other node.
func transferPlanFixture() transferplanner.CacheBuildPlan {
	const (
		sourceID    = "source-a"
		weightsFile = "weights.bin"
	)

	rootIntent := transferplanner.WarmupNodeIntent{
		TargetNode: "node-a",
		TargetPath: "/var/lib/weight-dispatcher/cache/ready/artifact-a",
		TargetPlan: sharedtypes.TargetTransferPlan{
			TransferMode: sharedtypes.TransferModePartialPullAllGather,
			SourceSegments: []sharedtypes.SourceSegmentPlan{{
				SourceID:   sourceID,
				ByteRanges: []sharedtypes.ByteRange{{RelativePath: weightsFile, Start: 0, End: 13}},
			}},
		},
		Collective: sharedtypes.CollectiveSpec{
			Mode: sharedtypes.CollectiveModeRing,
			Ring: &sharedtypes.RingPeerPlan{
				SelfNode:  "node-a",
				PrevNode:  "node-b",
				NextNode:  "node-b",
				Rank:      0,
				WorldSize: 2,
			},
		},
	}

	peerIntent := transferplanner.WarmupNodeIntent{
		TargetNode: "node-b",
		TargetPath: "/var/lib/weight-dispatcher/cache/ready/artifact-b",
		TargetPlan: sharedtypes.TargetTransferPlan{
			TransferMode: sharedtypes.TransferModePartialPullAllGather,
			SourceSegments: []sharedtypes.SourceSegmentPlan{{
				SourceID:   sourceID,
				ByteRanges: []sharedtypes.ByteRange{{RelativePath: weightsFile, Start: 13, End: 26}},
			}},
		},
		Collective: sharedtypes.CollectiveSpec{
			Mode: sharedtypes.CollectiveModeRing,
			Ring: &sharedtypes.RingPeerPlan{
				SelfNode:  "node-b",
				PrevNode:  "node-a",
				NextNode:  "node-a",
				Rank:      1,
				WorldSize: 2,
			},
		},
	}

	return transferplanner.CacheBuildPlan{
		NodeIntents: []transferplanner.WarmupNodeIntent{rootIntent, peerIntent},
	}
}