/*
Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved.
*/

/*
Package v1 implements the controller logic used to reconcile AscendJob.
*/

package v1

import (
	"fmt"
	"strconv"
	"strings"

	commonv1 "github.com/kubeflow/common/pkg/apis/common/v1"
	corev1 "k8s.io/api/core/v1"
	"k8s.io/apimachinery/pkg/util/sets"

	"ascend-common/api"
	"ascend-common/common-utils/hwlog"
	mindxdlv1 "ascend-operator/pkg/api/v1"
	"ascend-operator/pkg/utils"
)

// Names of environment variables injected into the default container.
const (
	// taskIDEnvKey carries the task identifier: the job UID in setCommonEnv, the job-id label in setInferEnv.
	taskIDEnvKey = "MINDX_TASK_ID"
	// appTypeEnvKey carries the job's app label; it is set by setInferEnv.
	appTypeEnvKey = "APP_TYPE"
)

// logEnvPattern is the debug-log format used after a container's env has been populated
// (arguments: pod template name, env slice).
const logEnvPattern = "set pod<%s> env: %v"

// addEnvValue appends a literal envKey=envValue entry to the env list of container `index` of pod.
// It never looks for an existing entry with the same name, so repeated calls produce duplicates;
// addEnvValueWithDedup is the updating variant.
func addEnvValue(pod *corev1.PodTemplateSpec, envKey, envValue string, index int) {
	container := &pod.Spec.Containers[index]
	container.Env = append(container.Env, corev1.EnvVar{Name: envKey, Value: envValue})
}

// addEnvValueWithDedup sets envKey=envValue on container `index` of pod without creating a
// duplicate entry.
//
// If an entry named envKey already exists, only the first such entry is updated:
//   - for api.HighAvailableEnv and api.MsRecoverEnv the old and new values are merged and
//     deduplicated via mergeEnvValue (both hold comma-separated strategy lists, so strategies that
//     are already present must be kept);
//   - for any other key the old value is overwritten with envValue.
//
// If no entry exists, a new literal entry is appended.
func addEnvValueWithDedup(pod *corev1.PodTemplateSpec, envKey, envValue string, index int) {
	envs := pod.Spec.Containers[index].Env
	for i := range envs {
		if envs[i].Name != envKey {
			continue
		}
		if envKey == api.HighAvailableEnv || envKey == api.MsRecoverEnv {
			envs[i].Value = mergeEnvValue(envKey, envs[i].Value, envValue)
		} else {
			envs[i].Value = envValue
		}
		return
	}
	pod.Spec.Containers[index].Env = append(envs, corev1.EnvVar{Name: envKey, Value: envValue})
}

// mergeEnvValue combines the existing value oldValue of envKey with newValue.
//
// api.HighAvailableEnv holds a plain comma-separated list, and api.MsRecoverEnv holds a
// comma-separated list wrapped in "{...}". For both, the result is the sorted, deduplicated
// union of the items, re-wrapped in braces for api.MsRecoverEnv. Any other key simply
// yields newValue.
func mergeEnvValue(envKey, oldValue, newValue string) string {
	switch envKey {
	case api.HighAvailableEnv:
		return mergeCommaSeparatedValues(oldValue, newValue)
	case api.MsRecoverEnv:
		inner := mergeCommaSeparatedValues(extractMsRecoverContent(oldValue), extractMsRecoverContent(newValue))
		return "{" + inner + "}"
	default:
		return newValue
	}
}

// mergeCommaSeparatedValues returns the union of the items of two comma-separated lists, joined
// with ",". Items are whitespace-trimmed, empty items are dropped, duplicates collapse, and the
// output is sorted (so the original item order of oldValue is not preserved).
func mergeCommaSeparatedValues(oldValue, newValue string) string {
	union := parseCommaSeparatedValues(oldValue).Union(parseCommaSeparatedValues(newValue))
	return strings.Join(union.List(), ",")
}

// parseCommaSeparatedValues splits value on "," and returns the set of whitespace-trimmed,
// non-empty items. An empty or all-blank value yields an empty, non-nil set.
func parseCommaSeparatedValues(value string) sets.String {
	items := sets.NewString()
	for _, raw := range strings.Split(value, ",") {
		item := strings.TrimSpace(raw)
		if item == "" {
			continue
		}
		items.Insert(item)
	}
	return items
}

// extractMsRecoverContent returns the inner text of an MsRecoverEnv value of the form "{...}".
// Surrounding whitespace is ignored, at most one leading "{" and one trailing "}" are removed
// (each independently, so unbalanced input is tolerated), and the remainder is trimmed again.
// An empty input yields "".
func extractMsRecoverContent(value string) string {
	inner := strings.TrimSpace(value)
	inner = strings.TrimPrefix(inner, "{")
	inner = strings.TrimSuffix(inner, "}")
	return strings.TrimSpace(inner)
}

// addEnvValueForSoftStrategy appends env var envKey to container `index` of pod. Instead of a
// literal value, the value is read through the downward API field path utils.SuperPodEnvPath.
func addEnvValueForSoftStrategy(pod *corev1.PodTemplateSpec, envKey string, index int) {
	source := &corev1.EnvVarSource{
		FieldRef: &corev1.ObjectFieldSelector{FieldPath: utils.SuperPodEnvPath},
	}
	container := &pod.Spec.Containers[index]
	container.Env = append(container.Env, corev1.EnvVar{Name: envKey, ValueFrom: source})
}

// isVirtualResourceReq reports whether requests asks for a virtual NPU resource. That is a
// resource whose name contains api.ResourceNamePrefix but is not one of the non-virtual names
// api.HuaweiNPU, api.HuaweiAscend310, api.HuaweiAscend310P or api.HuaweiAscend910.
//
// Resources without api.ResourceNamePrefix (cpu, memory, ephemeral-storage, ...) are ignored.
// They say nothing about NPU virtualisation, and counting them as virtual would wrongly suppress
// the visible-devices env in setCommonEnv for ordinary NPU pods that also request CPU/memory.
// A nil requests pointer yields false.
func (r *ASJobReconciler) isVirtualResourceReq(requests *corev1.ResourceList) bool {
	if requests == nil {
		return false
	}
	for name := range *requests {
		if !strings.Contains(string(name), api.ResourceNamePrefix) {
			continue
		}
		switch name {
		case api.HuaweiNPU, api.HuaweiAscend310, api.HuaweiAscend310P, api.HuaweiAscend910:
			continue
		}
		hwlog.RunLog.Debugf("virtual resource name detected: %s", name)
		return true
	}
	return false
}

// setInferEnv populates every container of podTemplate named api.DefaultContainerName with the
// inference-job env. It sets the job-id label as taskIDEnvKey and the app label as appTypeEnvKey,
// plus the mindx server IP (pi.clusterdSvcIp) and the default mindx server domain.
func (r *ASJobReconciler) setInferEnv(pi *podInfo, podTemplate *corev1.PodTemplateSpec) {
	for idx := range podTemplate.Spec.Containers {
		c := &podTemplate.Spec.Containers[idx]
		if c.Name != api.DefaultContainerName {
			continue
		}
		if len(c.Env) == 0 {
			// normalise a nil env list to an empty, non-nil one
			c.Env = make([]corev1.EnvVar, 0)
		}
		jobLabels := pi.job.Labels
		addEnvValue(podTemplate, taskIDEnvKey, jobLabels[mindxdlv1.JobIdLabelKey], idx)
		addEnvValue(podTemplate, appTypeEnvKey, jobLabels[mindxdlv1.AppLabelKey], idx)
		addEnvValue(podTemplate, mindxServerIPEnv, pi.clusterdSvcIp, idx)
		addEnvValue(podTemplate, mindxServerDomainEnv, mindxDefaultServerDomain, idx)
		hwlog.RunLog.Debugf(logEnvPattern, podTemplate.Name, c.Env)
	}
}

// setCommonEnv populates every container of podTemplate named api.DefaultContainerName with the
// framework-independent env.
//
// The visible-devices env (via setAscendVisibleDevicesEnv) is added unless isVirtualResourceReq
// reports a virtual request or the job is a soft-share-device job. Then literal entries follow
// for the job UID, the mindx server IP and domain, the pod's hostNetwork flag, the HCCL logic
// super-pod id (addHcclSuperPodIdEnv) and the replica type.
func (r *ASJobReconciler) setCommonEnv(pi *podInfo, podTemplate *corev1.PodTemplateSpec) {
	for idx := range podTemplate.Spec.Containers {
		c := &podTemplate.Spec.Containers[idx]
		if c.Name != api.DefaultContainerName {
			continue
		}
		if len(c.Env) == 0 {
			c.Env = make([]corev1.EnvVar, 0)
		}
		// isSoftShareDevJob is only consulted when the request is not virtual (short-circuit)
		exposeDevices := !r.isVirtualResourceReq(&c.Resources.Requests) && !isSoftShareDevJob(pi.job)
		if exposeDevices {
			r.setAscendVisibleDevicesEnv(c)
		}
		addEnvValue(podTemplate, taskIDEnvKey, string(pi.job.UID), idx)
		addEnvValue(podTemplate, mindxServerIPEnv, pi.clusterdSvcIp, idx)
		addEnvValue(podTemplate, mindxServerDomainEnv, mindxDefaultServerDomain, idx)
		addEnvValue(podTemplate, hostNetwork, strconv.FormatBool(pi.spec.Template.Spec.HostNetwork), idx)
		addHcclSuperPodIdEnv(pi, podTemplate, idx)
		addEnvValue(podTemplate, replicaType, string(pi.rtype), idx)
		hwlog.RunLog.Debugf(logEnvPattern, podTemplate.Name, c.Env)
	}
}

// setAscendVisibleDevicesEnv appends api.AscendVisibleDevicesEnv to container. Its value is read
// through the downward API from the pod annotation named after a requested resource containing
// api.ResourceNamePrefix. Only one entry is added; if several such resources are requested,
// which one is used depends on map iteration order. Nothing is added when none match.
func (r *ASJobReconciler) setAscendVisibleDevicesEnv(container *corev1.Container) {
	for name := range container.Resources.Requests {
		if !strings.Contains(string(name), api.ResourceNamePrefix) {
			continue
		}
		annotationPath := fmt.Sprintf("metadata.annotations['%s']", name)
		container.Env = append(container.Env, corev1.EnvVar{
			Name: api.AscendVisibleDevicesEnv,
			ValueFrom: &corev1.EnvVarSource{
				FieldRef: &corev1.ObjectFieldSelector{FieldPath: annotationPath},
			},
		})
		return
	}
}

// setMindSporeEnv populates every container of podTemplate named api.DefaultContainerName with the
// MindSpore training env.
//
// The scheduler replica reads msSchedHost from its own status via the downward API
// (statusPodIPDownwardAPI); other replicas get pi.ip. Local/total worker counts come from
// pi.npuReplicas with one local worker for soft-share jobs, from pi.ctReq per pod for regular
// jobs, and are omitted for dynamic-cut jobs. Node rank, scheduler port, server count "0", role,
// the NPU-pod flag and the process-recover / pod-schedule / sub-health env follow.
func (r *ASJobReconciler) setMindSporeEnv(pi *podInfo, podTemplate *corev1.PodTemplateSpec) {
	roleByReplica := map[commonv1.ReplicaType]string{
		mindxdlv1.MindSporeReplicaTypeScheduler: msSchedulerRole,
		mindxdlv1.ReplicaTypeWorker:             msWorkerRole,
	}
	for idx := range podTemplate.Spec.Containers {
		c := &podTemplate.Spec.Containers[idx]
		if c.Name != api.DefaultContainerName {
			continue
		}
		if len(c.Env) == 0 {
			c.Env = make([]corev1.EnvVar, 0)
		}
		if pi.rtype == mindxdlv1.MindSporeReplicaTypeScheduler {
			schedHostSource := &corev1.EnvVarSource{
				FieldRef: &corev1.ObjectFieldSelector{FieldPath: statusPodIPDownwardAPI},
			}
			c.Env = append(c.Env, corev1.EnvVar{Name: msSchedHost, ValueFrom: schedHostSource})
		} else {
			addEnvValue(podTemplate, msSchedHost, pi.ip, idx)
		}
		switch {
		case pi.isSoftShareDevJob:
			addEnvValue(podTemplate, api.MsLocalWorkerEnv, strconv.Itoa(1), idx)
			addEnvValue(podTemplate, api.MsWorkerNumEnv, strconv.Itoa(pi.npuReplicas), idx)
		case !pi.isDynamicCutJob:
			perPod := pi.ctReq
			addEnvValue(podTemplate, api.MsLocalWorkerEnv, strconv.Itoa(perPod), idx)
			addEnvValue(podTemplate, api.MsWorkerNumEnv, strconv.Itoa(perPod*pi.npuReplicas), idx)
		}
		addEnvValue(podTemplate, msNodeRank, strconv.Itoa(pi.rank), idx)
		addEnvValue(podTemplate, msSchedPort, pi.port, idx)
		addEnvValue(podTemplate, msServerNum, "0", idx)
		addEnvValue(podTemplate, msRole, roleByReplica[pi.rtype], idx)
		addEnvValue(podTemplate, npuPod, strconv.FormatBool(checkNpuPod(pi)), idx)
		addProcessRecoverEnv(pi, podTemplate, idx, api.MindSporeFramework)
		addMSPodScheduleEnv(pi, podTemplate, idx)
		addSubHealthyEnv(pi, podTemplate, idx, api.MindSporeFramework)
		hwlog.RunLog.Debugf(logEnvPattern, podTemplate.Name, c.Env)
	}
}

// setPytorchEnv populates every container of podTemplate named api.DefaultContainerName with the
// PyTorch training env.
//
// Local world size, world size and local ranks are set to 1 / pi.npuReplicas for soft-share jobs
// and to pi.ctReq / pi.ctReq*pi.npuReplicas for regular jobs. They are omitted for dynamic-cut
// jobs. Master address/port, node rank and the process-recover and sub-health env follow.
func (r *ASJobReconciler) setPytorchEnv(pi *podInfo, podTemplate *corev1.PodTemplateSpec) {
	for idx := range podTemplate.Spec.Containers {
		c := &podTemplate.Spec.Containers[idx]
		if c.Name != api.DefaultContainerName {
			continue
		}
		if len(c.Env) == 0 {
			c.Env = make([]corev1.EnvVar, 0)
		}
		switch {
		case pi.isSoftShareDevJob:
			addEnvValue(podTemplate, api.PtLocalWorldSizeEnv, strconv.Itoa(1), idx)
			addEnvValue(podTemplate, api.PtWorldSizeEnv, strconv.Itoa(pi.npuReplicas), idx)
			addEnvValue(podTemplate, api.PtLocalRankEnv, localRankStr(1), idx)
		case !pi.isDynamicCutJob:
			perPod := pi.ctReq
			addEnvValue(podTemplate, api.PtLocalWorldSizeEnv, strconv.Itoa(perPod), idx)
			addEnvValue(podTemplate, api.PtWorldSizeEnv, strconv.Itoa(perPod*pi.npuReplicas), idx)
			addEnvValue(podTemplate, api.PtLocalRankEnv, localRankStr(perPod), idx)
		}
		addEnvValue(podTemplate, ptMasterAddr, pi.ip, idx)
		addEnvValue(podTemplate, ptMasterPort, pi.port, idx)
		addEnvValue(podTemplate, ptRank, strconv.Itoa(pi.rank), idx)
		addProcessRecoverEnv(pi, podTemplate, idx, api.PytorchFramework)
		addSubHealthyEnv(pi, podTemplate, idx, api.PytorchFramework)
		hwlog.RunLog.Debugf(logEnvPattern, podTemplate.Name, c.Env)
	}
}

// setTensorflowEnv populates every container of podTemplate named api.DefaultContainerName with the
// TensorFlow training env.
//
// The chief replica reads tfChiefIP from its own status via the downward API
// (statusPodIPDownwardAPI); other replicas get pi.ip. Worker counts follow the same soft-share /
// regular / dynamic-cut rules as the other frameworks. Chief port, rank and chief device "0" are
// literal values, and tfWorkerIP is always read from the pod's own status.
func (r *ASJobReconciler) setTensorflowEnv(pi *podInfo, podTemplate *corev1.PodTemplateSpec) {
	// selfIPEnv builds an env var whose value is resolved from statusPodIPDownwardAPI at runtime.
	selfIPEnv := func(name string) corev1.EnvVar {
		return corev1.EnvVar{
			Name: name,
			ValueFrom: &corev1.EnvVarSource{
				FieldRef: &corev1.ObjectFieldSelector{FieldPath: statusPodIPDownwardAPI},
			},
		}
	}
	for idx := range podTemplate.Spec.Containers {
		c := &podTemplate.Spec.Containers[idx]
		if c.Name != api.DefaultContainerName {
			continue
		}
		if len(c.Env) == 0 {
			c.Env = make([]corev1.EnvVar, 0)
		}
		if pi.rtype == mindxdlv1.TensorflowReplicaTypeChief {
			c.Env = append(c.Env, selfIPEnv(tfChiefIP))
		} else {
			addEnvValue(podTemplate, tfChiefIP, pi.ip, idx)
		}
		switch {
		case pi.isSoftShareDevJob:
			addEnvValue(podTemplate, api.TfLocalWorkerEnv, strconv.Itoa(1), idx)
			addEnvValue(podTemplate, api.TfWorkerSizeEnv, strconv.Itoa(pi.npuReplicas), idx)
		case !pi.isDynamicCutJob:
			perPod := pi.ctReq
			addEnvValue(podTemplate, api.TfLocalWorkerEnv, strconv.Itoa(perPod), idx)
			addEnvValue(podTemplate, api.TfWorkerSizeEnv, strconv.Itoa(perPod*pi.npuReplicas), idx)
		}
		addEnvValue(podTemplate, tfChiefPort, pi.port, idx)
		addEnvValue(podTemplate, tfRank, strconv.Itoa(pi.rank), idx)
		addEnvValue(podTemplate, tfChiefDevice, "0", idx)
		c.Env = append(c.Env, selfIPEnv(tfWorkerIP))
		hwlog.RunLog.Debugf(logEnvPattern, podTemplate.Name, c.Env)
	}
}

// addHcclSuperPodIdEnv adds the HCCL_LOGIC_SUPERPOD_ID env (hcclSuperPodLogicId), used to build
// the HCCS network, to container `index` of pod.
//
// It uses one requested resource whose name contains api.ResourceNamePrefix (chosen by map
// iteration order if there are several). The id comes from utils.GetLogicSuperPodId using the pod
// rank, the job's sp-block and the requested quantity as chips per node. When the job's
// utils.SuperPodAffinity label equals utils.SoftStrategy and the resource is not api.HuaweiNPU,
// the value is instead taken from the downward API via addEnvValueForSoftStrategy.
func addHcclSuperPodIdEnv(pi *podInfo, pod *corev1.PodTemplateSpec, index int) {
	for resName, quantity := range pod.Spec.Containers[index].Resources.Requests {
		if !strings.Contains(string(resName), api.ResourceNamePrefix) {
			continue
		}
		chipsPerNode := int(quantity.Value())
		logicID := strconv.Itoa(utils.GetLogicSuperPodId(pi.rank, utils.GetSpBlock(pi.job), chipsPerNode))
		hwlog.RunLog.Debugf("pod<%s> resource<%v=%v> pod-rank=%v sp-block=%v set %s=%v",
			pod.Name, resName, chipsPerNode, pi.rank, utils.GetSpBlock(pi.job), hcclSuperPodLogicId, logicID)
		isSoft := pi.job.Labels[utils.SuperPodAffinity] == utils.SoftStrategy
		if isSoft && string(resName) != api.HuaweiNPU {
			addEnvValueForSoftStrategy(pod, hcclSuperPodLogicId, index)
		} else {
			addEnvValue(pod, hcclSuperPodLogicId, logicID, index)
		}
		return
	}
}

// addMSPodScheduleEnv adds api.MsRecoverEnv = "{" + api.MsRscStrategy + "}" to container
// containerIndex when the job uses the pod-schedule strategy (isPodScheduleStrategy). It does
// nothing if the container already has an api.MsRecoverEnv entry.
func addMSPodScheduleEnv(pi *podInfo, pod *corev1.PodTemplateSpec, containerIndex int) {
	if !isPodScheduleStrategy(pi.job) {
		return
	}
	envs := pod.Spec.Containers[containerIndex].Env
	alreadySet := false
	for i := range envs {
		if envs[i].Name == api.MsRecoverEnv {
			alreadySet = true
			break
		}
	}
	if !alreadySet {
		addEnvValue(pod, api.MsRecoverEnv, "{"+api.MsRscStrategy+"}", containerIndex)
	}
}

// addSubHealthyEnv adds the sub-health hot-switch env to container containerIndex, but only when
// the job's sub-healthy strategy is api.SubHealthyHotSwitch.
//
// PyTorch merges api.RecoverStrategy into api.HighAvailableEnv. MindSpore merges
// api.MsArfStrategy into api.MsRecoverEnv and sets api.EnableMS. Both frameworks then enable
// api.ProcessRecoverEnv and api.ElasticRecoverEnv. All writes go through addEnvValueWithDedup, so
// existing entries are updated rather than duplicated. Other frameworks only log a warning.
func addSubHealthyEnv(pi *podInfo, pod *corev1.PodTemplateSpec, containerIndex int, framework string) {
	if getSubHealthyStrategy(pi.job) != api.SubHealthyHotSwitch {
		return
	}
	switch framework {
	case api.PytorchFramework:
		addEnvValueWithDedup(pod, api.HighAvailableEnv, api.RecoverStrategy, containerIndex)
	case api.MindSporeFramework:
		addEnvValueWithDedup(pod, api.MsRecoverEnv, "{"+api.MsArfStrategy+"}", containerIndex)
		addEnvValueWithDedup(pod, api.EnableMS, api.EnableFlag, containerIndex)
	default:
		hwlog.RunLog.Warnf("subhealth hotswitch only support pytorch and mindspore framework,current: %v", framework)
		return
	}
	addEnvValueWithDedup(pod, api.ProcessRecoverEnv, api.EnableFunc, containerIndex)
	addEnvValueWithDedup(pod, api.ElasticRecoverEnv, api.EnableFlag, containerIndex)
}

// addProcessRecoverEnv adds the process-recover env for the job's recover strategies
// (getJobRecoverStrategy) to container containerIndex. It does nothing when the job has no
// strategies.
func addProcessRecoverEnv(pi *podInfo, pod *corev1.PodTemplateSpec, containerIndex int, framework string) {
	if strategies := getJobRecoverStrategy(pi.job); strategies != "" {
		doAddProcessRecoverEnv(pi, pod, containerIndex, framework, strategies)
	}
}

// doAddProcessRecoverEnv converts strategies, a comma-separated list of recover strategies, into
// env vars on container containerIndex of pod.
//
// Each item is whitespace-trimmed and passed to addEnvByStrategy. That fills env with plain
// variables and trainEnv with framework-specific strategy names; unknown or empty items are
// ignored there. After that:
//   - PyTorch adds the watchdog switch and api.HighAvailableEnv = sorted strategy names joined by ",";
//   - MindSpore adds api.MsRscStrategy when isPodScheduleStrategy holds, then sets
//     api.MsRecoverEnv = "{<sorted names>}", api.EnableMS and the watchdog switch.
//
// Variables are appended with addEnvValue (no dedup) in sorted name order, so the generated pod
// spec does not depend on Go's randomised map iteration order.
func doAddProcessRecoverEnv(pi *podInfo, pod *corev1.PodTemplateSpec, containerIndex int, framework string,
	strategies string) {
	env := make(map[string]string)
	trainEnv := sets.NewString()
	for _, strategy := range strings.Split(strategies, ",") {
		addEnvByStrategy(env, trainEnv, strings.TrimSpace(strategy), framework)
	}
	switch framework {
	case api.PytorchFramework:
		env[api.PtCloseWatchDogKey] = api.PtCloseWatchDogValue
		env[api.HighAvailableEnv] = strings.Join(trainEnv.List(), ",")
	case api.MindSporeFramework:
		if isPodScheduleStrategy(pi.job) {
			trainEnv.Insert(api.MsRscStrategy)
		}
		env[api.MsRecoverEnv] = "{" + strings.Join(trainEnv.List(), ",") + "}"
		env[api.EnableMS] = api.EnableFlag
		env[api.MsCloseWatchDogKey] = api.MsCloseWatchDogValue
	}
	names := sets.NewString()
	for name := range env {
		names.Insert(name)
	}
	for _, name := range names.List() {
		addEnvValue(pod, name, env[name], containerIndex)
	}
	hwlog.RunLog.Infof("set process reschedule pod<%s> env: %v", pod.Name, env)
}

// addEnvByStrategy dispatches one recover strategy name to the helper that records its env
// variables in env and its framework-specific strategy names in trainEnv. api.RecoverStrategy and
// api.ExitStrategy share the same handling. Unknown names are ignored.
func addEnvByStrategy(env map[string]string, trainEnv sets.String, strategy string, framework string) {
	switch strategy {
	case api.RecoverStrategy, api.ExitStrategy:
		addRecoverEnv(env, trainEnv, framework)
	case api.RetryStrategy:
		addRetryEnv(env, trainEnv, framework)
	case api.DumpStrategy:
		addDumpEnv(env, trainEnv, framework)
	case api.InPlaceStrategy:
		addRecoverInPlaceEnv(env, trainEnv, framework)
	case api.ElasticTraining:
		addElasticTrainingEnv(env, trainEnv, framework)
	}
}

// addRecoverEnv records the "recover" strategy: api.RecoverStrategy for PyTorch or
// api.MsArfStrategy for MindSpore goes into trainEnv. It also enables api.ProcessRecoverEnv and
// api.ElasticRecoverEnv in env. It is a no-op if either collection is nil.
func addRecoverEnv(env map[string]string, trainEnv sets.String, framework string) {
	if env == nil || trainEnv == nil {
		return
	}
	switch framework {
	case api.PytorchFramework:
		trainEnv.Insert(api.RecoverStrategy)
	case api.MindSporeFramework:
		trainEnv.Insert(api.MsArfStrategy)
	}
	env[api.ProcessRecoverEnv] = api.EnableFunc
	env[api.ElasticRecoverEnv] = api.EnableFlag
}

// addRetryEnv records the "retry" strategy: api.RetryStrategy for PyTorch, or api.MsUceStrategy
// and api.MsHcceStrategy for MindSpore, go into trainEnv. It also enables api.ProcessRecoverEnv and
// api.ElasticRecoverEnv in env. It is a no-op if either collection is nil.
func addRetryEnv(env map[string]string, trainEnv sets.String, framework string) {
	if env == nil || trainEnv == nil {
		return
	}
	switch framework {
	case api.PytorchFramework:
		trainEnv.Insert(api.RetryStrategy)
	case api.MindSporeFramework:
		trainEnv.Insert(api.MsUceStrategy, api.MsHcceStrategy)
	}
	env[api.ProcessRecoverEnv] = api.EnableFunc
	env[api.ElasticRecoverEnv] = api.EnableFlag
}

// addRecoverInPlaceEnv records the in-place recover strategy. Its strategy names match
// addRecoverEnv (api.RecoverStrategy / api.MsArfStrategy). Besides api.ProcessRecoverEnv and
// api.ElasticRecoverEnv, it also enables api.EnableRestartEnv. It is a no-op if either
// collection is nil.
func addRecoverInPlaceEnv(env map[string]string, trainEnv sets.String, framework string) {
	if env == nil || trainEnv == nil {
		return
	}
	switch framework {
	case api.PytorchFramework:
		trainEnv.Insert(api.RecoverStrategy)
	case api.MindSporeFramework:
		trainEnv.Insert(api.MsArfStrategy)
	}
	env[api.ProcessRecoverEnv] = api.EnableFunc
	env[api.EnableRestartEnv] = api.EnableFunc
	env[api.ElasticRecoverEnv] = api.EnableFlag
}

// addDumpEnv records the "dump" strategy: api.DumpStrategy for PyTorch or api.MsDumpStrategy for
// MindSpore goes into trainEnv. It also enables api.ElasticRecoverEnv and api.ProcessRecoverEnv in
// env. It is a no-op if either collection is nil.
func addDumpEnv(env map[string]string, trainEnv sets.String, framework string) {
	if env == nil || trainEnv == nil {
		return
	}
	switch framework {
	case api.PytorchFramework:
		trainEnv.Insert(api.DumpStrategy)
	case api.MindSporeFramework:
		trainEnv.Insert(api.MsDumpStrategy)
	}
	env[api.ElasticRecoverEnv] = api.EnableFlag
	env[api.ProcessRecoverEnv] = api.EnableFunc
}

// addElasticTrainingEnv records the elastic-training strategy, which is supported for PyTorch
// only. It enables api.ProcessRecoverEnv and adds api.ElasticTraining to trainEnv. Nil
// collections and non-PyTorch frameworks log a warning and change nothing.
func addElasticTrainingEnv(env map[string]string, trainEnv sets.String, framework string) {
	switch {
	case env == nil || trainEnv == nil:
		hwlog.RunLog.Warnf("env or trainEnv is nil, env: %v, trainEnv: %v", env, trainEnv)
	case framework != api.PytorchFramework:
		hwlog.RunLog.Warn("elastic-training strategy only support pytorch framework")
	default:
		env[api.ProcessRecoverEnv] = api.EnableFunc
		trainEnv.Insert(api.ElasticTraining)
	}
}