/*
Copyright(C) 2026. Huawei Technologies Co.,Ltd. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package workload
import (
"context"
"encoding/json"
"fmt"
"strconv"
appsv1 "k8s.io/api/apps/v1"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/types"
"sigs.k8s.io/controller-runtime/pkg/client"
"ascend-common/common-utils/hwlog"
"infer-operator/pkg/api/v1"
"infer-operator/pkg/common"
)
// StatefulSetWorkLoad wraps an apps/v1 StatefulSet so it can be handled as a
// WorkLoadInterface (ListWorkLoad returns values of this type in that form).
// The StatefulSet pointer is embedded, so its fields and methods are promoted.
// The methods guard against a nil receiver but not against a nil embedded
// StatefulSet.
type StatefulSetWorkLoad struct {
	*appsv1.StatefulSet
}
// SetWorkLoadObjMeta overwrites the wrapped StatefulSet's ObjectMeta with
// objectMeta. Calling it on a nil receiver does nothing.
func (s *StatefulSetWorkLoad) SetWorkLoadObjMeta(objectMeta metav1.ObjectMeta) {
	if s != nil {
		s.ObjectMeta = objectMeta
	}
}
// GetWorkLoadObjMeta returns the wrapped StatefulSet's ObjectMeta by value
// (maps and slices inside it are not deep-copied), or a zero ObjectMeta when
// the receiver is nil.
func (s *StatefulSetWorkLoad) GetWorkLoadObjMeta() metav1.ObjectMeta {
	var meta metav1.ObjectMeta
	if s != nil {
		meta = s.ObjectMeta
	}
	return meta
}
// IsWorkLoadReady reports whether the StatefulSet has settled. It requires
// that the controller has observed the latest generation, that every desired
// replica (common.DefaultReplicas when spec.replicas is unset) is both ready
// and updated, and that the current and update revisions match whenever both
// are recorded. A nil receiver is never ready.
func (s *StatefulSetWorkLoad) IsWorkLoadReady() bool {
	if s == nil {
		return false
	}
	want := common.DefaultReplicas
	if r := s.Spec.Replicas; r != nil {
		want = *r
	}
	st := &s.Status
	switch {
	case s.Generation > 0 && st.ObservedGeneration < s.Generation:
		// The status still describes an older spec.
		return false
	case st.ReadyReplicas != want, st.UpdatedReplicas != want:
		return false
	case st.CurrentRevision != "" && st.UpdateRevision != "" &&
		st.CurrentRevision != st.UpdateRevision:
		// A rolling update is still in progress.
		return false
	}
	return true
}
// GetWorkLoadReplicas returns the StatefulSet's desired replica count. It
// falls back to common.DefaultReplicas when the receiver is nil or
// spec.replicas is unset.
func (s *StatefulSetWorkLoad) GetWorkLoadReplicas() int32 {
	if s != nil && s.Spec.Replicas != nil {
		return *s.Spec.Replicas
	}
	return common.DefaultReplicas
}
// StatefulSetHandler manages the StatefulSets of InstanceSets, together with
// their headless Services and snapshot-metadata ConfigMaps, through a
// controller-runtime client.
type StatefulSetHandler struct {
	client client.Client
}

// NewStatefulSetHandler returns a StatefulSetHandler that issues every API
// call through c.
func NewStatefulSetHandler(c client.Client) *StatefulSetHandler {
	h := &StatefulSetHandler{client: c}
	return h
}
// CheckOrCreateWorkLoad ensures that the headless Service and the StatefulSet
// for the instance identified by indexer exist, and creates whichever is
// missing.
//
// Failing to read or create the Service is returned as a requeue error.
// Errors from listing or creating the StatefulSet are returned as produced by
// ListWorkLoads / createStatefulSet. Finding more than one matching
// StatefulSet is only logged as a warning.
func (s *StatefulSetHandler) CheckOrCreateWorkLoad(
	ctx context.Context,
	instanceSet *v1.InstanceSet,
	indexer common.InstanceIndexer) error {
	serviceKey := types.NamespacedName{
		Name:      common.GetServiceNameFromIndexer(indexer),
		Namespace: instanceSet.Namespace,
	}
	service := &corev1.Service{}
	err := s.client.Get(ctx, serviceKey, service)
	switch {
	case err == nil:
		// The Service already exists; nothing to do.
	case errors.IsNotFound(err):
		hwlog.RunLog.Infof("service of <%v> not exist, try to create", indexer)
		if createErr := s.createService(ctx, instanceSet, indexer); createErr != nil {
			return common.NewRequeueError(createErr.Error())
		}
	default:
		// Log the Service that could not be read, not the owning InstanceSet.
		hwlog.RunLog.Errorf("Failed to get service %s/%s: %v",
			serviceKey.Namespace, serviceKey.Name, err)
		return common.NewRequeueError(err.Error())
	}

	selectLabels := common.AddLabelsFromIndexer(make(map[string]string), indexer)
	statefulsetList, err := s.ListWorkLoads(ctx, selectLabels, indexer.Namespace)
	if err != nil {
		return err
	}
	switch n := len(statefulsetList.Items); {
	case n == 0:
		hwlog.RunLog.Infof("statefulset of <%v> not exist, try to create", indexer)
		return s.createStatefulSet(ctx, instanceSet, indexer)
	case n > 1:
		hwlog.RunLog.Warnf("More than one StatefulSet exists in InstanceSet<%s>", instanceSet.Name)
	}
	return nil
}
// createStatefulSet builds the StatefulSet for the instance identified by
// indexer from instanceSet.Spec.InstanceSpec and creates it, controller-owned
// by instanceSet.
//
// Object labels are built from these sources:
//   - the InstanceSet labels,
//   - overlaid with Spec.WorkloadObjectMeta.Labels,
//   - the indexer labels,
//   - the fault-scheduling label when the pod template carries one.
//
// Object annotations are the InstanceSet annotations overlaid with
// Spec.WorkloadObjectMeta.Annotations.
//
// The pod template is changed as follows:
//   - it receives the indexer labels,
//   - it receives a gang-scheduling group-name annotation when the InstanceSet
//     has the gang-schedule label set to true,
//   - it is passed through common.AddEnvToPodTemplate,
//     common.AddSnapshotInfoToPodTemplate and common.AddMetadataVolume.
//
// A failure to prepare the snapshot ConfigMap is only logged. A Create
// failure is returned as a requeue error.
func (s *StatefulSetHandler) createStatefulSet(
	ctx context.Context,
	instanceSet *v1.InstanceSet,
	indexer common.InstanceIndexer) error {
	spec, err := s.parseStatefulSetWithScheme(instanceSet.Spec.InstanceSpec)
	if err != nil {
		return err
	}

	objLabels := common.DeepCopyLabelsMap(instanceSet.Labels)
	for key, value := range instanceSet.Spec.WorkloadObjectMeta.Labels {
		objLabels[key] = value
	}
	objLabels = common.AddLabelsFromIndexer(objLabels, indexer)
	// Read the template's fault-scheduling label before the template labels
	// are augmented below.
	if value, found := spec.Template.Labels[common.FaultSchedulingLabelKey]; found {
		objLabels[common.FaultSchedulingLabelKey] = value
	}

	objAnnotations := common.DeepCopyLabelsMap(instanceSet.Annotations)
	for key, value := range instanceSet.Spec.WorkloadObjectMeta.Annotations {
		objAnnotations[key] = value
	}

	tmpl := &spec.Template
	tmpl.Labels = common.AddLabelsFromIndexer(tmpl.Labels, indexer)
	if tmpl.Annotations == nil {
		tmpl.Annotations = map[string]string{}
	}
	if instanceSet.Labels[common.GangScheduleLabelKey] == common.TrueBool {
		tmpl.Annotations[common.GroupNameAnnotationKey] = common.GetPGNameFromIndexer(indexer)
	}
	spec.ServiceName = common.GetServiceNameFromIndexer(indexer)
	common.AddEnvToPodTemplate(tmpl, indexer)

	workloadName := common.GetWorkLoadNameFromIndexer(indexer)
	snapshotCMName := common.SnapshotMetadataPrefix + workloadName
	if cmErr := s.createCMForSnapshot(ctx, instanceSet, workloadName); cmErr != nil {
		// Best effort: the StatefulSet is still created.
		hwlog.RunLog.Errorf("createCMForSnapshot failed: %v", cmErr)
	}
	common.AddSnapshotInfoToPodTemplate(tmpl, instanceSet, snapshotCMName)
	common.AddMetadataVolume(tmpl, snapshotCMName, instanceSet)

	statefulSet := &appsv1.StatefulSet{
		ObjectMeta: metav1.ObjectMeta{
			Name:        workloadName,
			Namespace:   instanceSet.Namespace,
			Annotations: objAnnotations,
			Labels:      objLabels,
			OwnerReferences: []metav1.OwnerReference{
				*metav1.NewControllerRef(instanceSet, instanceSet.GroupVersionKind()),
			},
		},
		Spec: *spec,
	}
	if createErr := s.client.Create(ctx, statefulSet); createErr != nil {
		hwlog.RunLog.Errorf("Failed to create StatefulSet<%s>: %v", statefulSet.Name, createErr)
		return common.NewRequeueError(createErr.Error())
	}
	return nil
}
// createService creates the headless Service (ClusterIP "None") for the
// instance identified by indexer, controller-owned by instanceSet.
//
// The Service carries the indexer labels and a copy of the InstanceSet's
// annotations. Its selector is the indexer labels; when container snapshot is
// enabled, the selector additionally requires the active label. It exposes
// common.DefaultPort. A Create failure is returned as a requeue error.
func (s *StatefulSetHandler) createService(
	ctx context.Context,
	instanceSet *v1.InstanceSet,
	indexer common.InstanceIndexer) error {
	labels := common.AddLabelsFromIndexer(make(map[string]string), indexer)
	selectLabels := common.DeepCopyLabelsMap(labels)
	if common.IsContainerSnapshotOn(instanceSet) {
		selectLabels[common.ActiveLabelKey] = common.TrueBool
	}
	newService := &corev1.Service{
		ObjectMeta: metav1.ObjectMeta{
			Name:      common.GetServiceNameFromIndexer(indexer),
			Namespace: instanceSet.Namespace,
			// Copy the map rather than share it. client.Create decodes the
			// server response back into newService, and a shared map would be
			// written into instanceSet.Annotations. This also matches how
			// createStatefulSet builds its annotations.
			Annotations: common.DeepCopyLabelsMap(instanceSet.Annotations),
			Labels:      labels,
			OwnerReferences: []metav1.OwnerReference{
				*metav1.NewControllerRef(instanceSet, instanceSet.GroupVersionKind()),
			},
		},
		Spec: corev1.ServiceSpec{
			ClusterIP: "None",
			Selector:  selectLabels,
			Ports: []corev1.ServicePort{
				{
					Name: common.DefaultPortName,
					Port: common.DefaultPort,
				},
			},
		},
	}
	if err := s.client.Create(ctx, newService); err != nil {
		hwlog.RunLog.Errorf("Failed to create Service<%s>: %v", newService.Name, err)
		return common.NewRequeueError(err.Error())
	}
	return nil
}
// DeleteExtraWorkLoad removes the StatefulSets, then the Services, that belong
// to the same instance group as indexer but whose instance index lies outside
// [0, indexLimit). The instance group is all indexer labels except the
// instance index.
//
// Objects without an index label, or with a non-numeric one, are skipped. The
// first delete failure is returned.
func (s *StatefulSetHandler) DeleteExtraWorkLoad(
	ctx context.Context,
	indexer common.InstanceIndexer, indexLimit int) error {
	selectLabels := common.AddLabelsFromIndexer(make(map[string]string), indexer)
	// Drop the index so the selector matches every instance of the group.
	delete(selectLabels, common.InstanceIndexLabelKey)
	statefulsetList, err := s.ListWorkLoads(ctx, selectLabels, indexer.Namespace)
	if err != nil {
		return err
	}
	for i := range statefulsetList.Items {
		statefulset := &statefulsetList.Items[i]
		instanceIndexStr, ok := statefulset.Labels[common.InstanceIndexLabelKey]
		if !ok {
			continue
		}
		instanceIndex, convErr := strconv.Atoi(instanceIndexStr)
		if convErr != nil {
			hwlog.RunLog.Warnf("StatefulSet<%s> Failed to convert instance index %q to int: %v",
				statefulset.Name, instanceIndexStr, convErr)
			continue
		}
		if instanceIndex >= 0 && instanceIndex < indexLimit {
			continue
		}
		if delErr := s.client.Delete(ctx, statefulset); delErr != nil {
			hwlog.RunLog.Errorf("Failed to delete StatefulSet<%s>: %v", statefulset.Name, delErr)
			return delErr
		}
		hwlog.RunLog.Infof("Delete Extra StatefulSet<%s>", statefulset.Name)
	}
	// Scope the Service cleanup to the same namespace as the StatefulSets.
	// Without it, same-labelled Services in other namespaces would be deleted.
	return s.deleteExtraService(ctx, selectLabels, indexLimit, client.InNamespace(indexer.Namespace))
}

// GetWorkLoadReadyReplicas counts the StatefulSets of indexer's instance group
// (indexer labels minus the instance index) that satisfy isStatefulsetReady.
// Despite its name, it counts StatefulSets, not pods. On a list error it
// returns 0 together with the error.
func (s *StatefulSetHandler) GetWorkLoadReadyReplicas(
	ctx context.Context,
	indexer common.InstanceIndexer) (int, error) {
	readyReplicas := 0
	selectLabels := common.AddLabelsFromIndexer(make(map[string]string), indexer)
	delete(selectLabels, common.InstanceIndexLabelKey)
	statefulsetList, err := s.ListWorkLoads(ctx, selectLabels, indexer.Namespace)
	if err != nil {
		return readyReplicas, err
	}
	for _, statefulset := range statefulsetList.Items {
		if isStatefulsetReady(statefulset) {
			readyReplicas++
		}
	}
	return readyReplicas, nil
}

// deleteExtraService deletes the Services matching selectLabels whose instance
// index label lies outside [0, indexLimit). Services with no index label, or a
// non-numeric one, are skipped.
//
// opts are appended to the List call. Pass client.InNamespace to restrict the
// search; with no options, Services in every namespace are considered. All
// failures are returned as requeue errors.
func (s *StatefulSetHandler) deleteExtraService(
	ctx context.Context,
	selectLabels map[string]string,
	indexLimit int,
	opts ...client.ListOption) error {
	selector, err := metav1.LabelSelectorAsSelector(&metav1.LabelSelector{
		MatchLabels: selectLabels,
	})
	if err != nil {
		hwlog.RunLog.Errorf("Failed to convert label selector to selector: %v", err)
		return common.NewRequeueError(err.Error())
	}
	listOpts := make([]client.ListOption, 0, len(opts)+1)
	listOpts = append(listOpts, client.MatchingLabelsSelector{Selector: selector})
	listOpts = append(listOpts, opts...)
	serviceList := &corev1.ServiceList{}
	if err = s.client.List(ctx, serviceList, listOpts...); err != nil {
		hwlog.RunLog.Errorf("Failed to list extra services: %v", err)
		return common.NewRequeueError(err.Error())
	}
	for i := range serviceList.Items {
		service := &serviceList.Items[i]
		instanceIndexStr, ok := service.Labels[common.InstanceIndexLabelKey]
		if !ok {
			continue
		}
		instanceIndex, convErr := strconv.Atoi(instanceIndexStr)
		if convErr != nil {
			hwlog.RunLog.Warnf("service<%s> Failed to convert instance index %q to int: %v",
				service.Name, instanceIndexStr, convErr)
			continue
		}
		if instanceIndex >= 0 && instanceIndex < indexLimit {
			continue
		}
		if delErr := s.client.Delete(ctx, service); delErr != nil {
			hwlog.RunLog.Errorf("Failed to delete Extra Service<%s>: %v", service.Name, delErr)
			return common.NewRequeueError(delErr.Error())
		}
	}
	return nil
}
// ListWorkLoads returns the StatefulSets in namespace whose labels match every
// entry of selectLabels.
//
// Both a selector-construction failure and a List failure are logged and
// returned as requeue errors. In both cases the returned list is nil, since it
// carries no meaning on error.
func (s *StatefulSetHandler) ListWorkLoads(
	ctx context.Context,
	selectLabels map[string]string,
	namespace string) (*appsv1.StatefulSetList, error) {
	selector, err := metav1.LabelSelectorAsSelector(&metav1.LabelSelector{
		MatchLabels: selectLabels,
	})
	if err != nil {
		hwlog.RunLog.Errorf("Failed to convert label selector to selector: %v", err)
		return nil, common.NewRequeueError(err.Error())
	}
	statefulsetList := &appsv1.StatefulSetList{}
	if err = s.client.List(ctx, statefulsetList,
		client.MatchingLabelsSelector{Selector: selector}, client.InNamespace(namespace)); err != nil {
		hwlog.RunLog.Errorf("Failed to list extra statefulsets: %v", err)
		return nil, common.NewRequeueError(err.Error())
	}
	return statefulsetList, nil
}
// Validate checks that spec decodes into an apps/v1 StatefulSetSpec (see
// parseStatefulSetWithScheme). The decoded value is discarded; only the
// decode error, if any, is returned.
func (s *StatefulSetHandler) Validate(spec runtime.RawExtension) error {
	_, err := s.parseStatefulSetWithScheme(spec)
	return err
}
// GetReplicas decodes spec as a StatefulSetSpec and returns its replica count.
// It returns common.DefaultReplicas when spec.replicas is unset. On a decode
// failure it returns common.DefaultReplicas together with the error.
func (s *StatefulSetHandler) GetReplicas(spec runtime.RawExtension) (int32, error) {
	parsed, err := s.parseStatefulSetWithScheme(spec)
	switch {
	case err != nil:
		return common.DefaultReplicas, err
	case parsed.Replicas == nil:
		return common.DefaultReplicas, nil
	default:
		return *parsed.Replicas, nil
	}
}
// isStatefulsetReady reports whether sts has settled. It requires that the
// latest generation has been observed, that all desired replicas (1 when
// spec.replicas is unset) are ready and updated, and that the current and
// update revisions agree whenever both are recorded.
func isStatefulsetReady(sts appsv1.StatefulSet) bool {
	var desired int32 = 1
	if sts.Spec.Replicas != nil {
		desired = *sts.Spec.Replicas
	}
	st := &sts.Status
	generationObserved := sts.Generation <= 0 || st.ObservedGeneration >= sts.Generation
	replicasSettled := st.ReadyReplicas == desired && st.UpdatedReplicas == desired
	revisionsAgree := st.CurrentRevision == "" || st.UpdateRevision == "" ||
		st.CurrentRevision == st.UpdateRevision
	return generationObserved && replicasSettled && revisionsAgree
}
func (s *StatefulSetHandler) parseStatefulSetWithScheme(raw runtime.RawExtension) (*appsv1.StatefulSetSpec, error) {
if len(raw.Raw) == 0 {
return nil, fmt.Errorf("raw extension is empty")
}
var spec appsv1.StatefulSetSpec
if err := json.Unmarshal(raw.Raw, &spec); err != nil {
return nil, fmt.Errorf("failed to unmarshal RawExtension to StatefulSetSpec: %w", err)
}
return &spec, nil
}
// ListWorkLoad lists the StatefulSets in namespace that match selectLabels.
// It returns each one that every filter accepts (all of them when filters is
// empty), wrapped as a *StatefulSetWorkLoad. Each returned workload points at
// its own StatefulSet. Selector and List failures are logged and returned as
// requeue errors.
func (s *StatefulSetHandler) ListWorkLoad(
	ctx context.Context,
	selectLabels map[string]string,
	namespace string,
	filters ...WorkLoadFilter) ([]WorkLoadInterface, error) {
	selector, err := metav1.LabelSelectorAsSelector(&metav1.LabelSelector{
		MatchLabels: selectLabels,
	})
	if err != nil {
		hwlog.RunLog.Errorf("Failed to create selector: %v", err)
		return nil, common.NewRequeueError(err.Error())
	}
	statefulsetList := &appsv1.StatefulSetList{}
	if err = s.client.List(ctx, statefulsetList,
		client.MatchingLabelsSelector{Selector: selector}, client.InNamespace(namespace)); err != nil {
		hwlog.RunLog.Errorf("Failed to list StatefulSets: %v", err)
		return nil, common.NewRequeueError(err.Error())
	}
	workloads := make([]WorkLoadInterface, 0, len(statefulsetList.Items))
nextStatefulSet:
	for i := range statefulsetList.Items {
		// Items[i] is a distinct element of a freshly listed slice, so the
		// workloads never alias one another and no extra copy is needed.
		workload := &StatefulSetWorkLoad{StatefulSet: &statefulsetList.Items[i]}
		for _, filter := range filters {
			if !filter(workload) {
				continue nextStatefulSet
			}
		}
		workloads = append(workloads, workload)
	}
	return workloads, nil
}
func (s *StatefulSetHandler) DeleteWorkLoad(
ctx context.Context,
selectLabels map[string]string,
namespace string,
filters ...WorkLoadFilter) error {
statefulsetList, err := s.ListWorkLoads(ctx, selectLabels, namespace)
if err != nil {
return fmt.Errorf("failed to list statefulset work loads: %w", err)
}
var workloadList []*StatefulSetWorkLoad
for _, statefulset := range statefulsetList.Items {
ok := true
statefulsetCopy := statefulset
workload := &StatefulSetWorkLoad{StatefulSet: &statefulsetCopy}
for _, filter := range filters {
ok = ok && filter(workload)
if !ok {
break
}
}
if ok {
workloadList = append(workloadList, workload)
}
}
for _, workload := range workloadList {
if err := s.client.Delete(ctx, workload.StatefulSet); err != nil {
return fmt.Errorf("failed to delete statefulset work load %s/%s: %w", workload.Namespace, workload.Name, err)
}
}
return nil
}
func (s *StatefulSetHandler) UpdateWorkLoad(
ctx context.Context,
selectLabels map[string]string,
namespace string,
updater WorkloadUpdater,
filters ...WorkLoadFilter) error {
statefulsetList, err := s.ListWorkLoads(ctx, selectLabels, namespace)
if err != nil {
return fmt.Errorf("failed to list statefulset work loads: %w", err)
}
var workloadList []*StatefulSetWorkLoad
for _, statefulset := range statefulsetList.Items {
ok := true
statefulsetCopy := statefulset
workload := &StatefulSetWorkLoad{StatefulSet: &statefulsetCopy}
for _, filter := range filters {
ok = ok && filter(workload)
if !ok {
break
}
}
if ok {
workloadList = append(workloadList, workload)
}
}
for _, workload := range workloadList {
updater(workload)
if err := s.client.Update(ctx, workload.StatefulSet); err != nil {
return fmt.Errorf("failed to update statefulset work load %s/%s: %w", workload.Namespace, workload.Name, err)
}
}
return nil
}
// createCMForSnapshot ensures that the snapshot-metadata ConfigMap
// (common.SnapshotMetadataPrefix + instanceName) exists in the InstanceSet's
// namespace. It does nothing unless container snapshot is enabled for
// instanceSet.
//
// A missing ConfigMap is created, controller-owned by instanceSet. For an
// existing ConfigMap:
//   - its Data is replaced, which resets the restored flag to "false";
//   - the operator label is ensured;
//   - its other labels and owner references are left untouched.
//
// Errors are wrapped with %w so callers can inspect the underlying cause.
func (s *StatefulSetHandler) createCMForSnapshot(ctx context.Context, instanceSet *v1.InstanceSet, instanceName string) error {
	if !common.IsContainerSnapshotOn(instanceSet) {
		return nil
	}
	data := common.SnapshotMetaData{
		InstanceName: instanceName,
		Namespace:    instanceSet.Namespace,
	}
	dataBytes, err := json.MarshalIndent(data, "", "  ")
	if err != nil {
		return fmt.Errorf("failed to marshal snapshot metadata: %w", err)
	}
	cm := &corev1.ConfigMap{
		ObjectMeta: metav1.ObjectMeta{
			Name:      common.SnapshotMetadataPrefix + instanceName,
			Namespace: instanceSet.Namespace,
			Labels: map[string]string{
				common.OperatorNameKey: common.TrueBool,
			},
			OwnerReferences: []metav1.OwnerReference{
				*metav1.NewControllerRef(instanceSet, instanceSet.GroupVersionKind()),
			},
		},
		Data: map[string]string{
			"snapshot_metadata.json":           string(dataBytes),
			common.GrusSnapshotRestoredFlagKey: "false",
		},
	}
	existCM := &corev1.ConfigMap{}
	err = s.client.Get(ctx, client.ObjectKeyFromObject(cm), existCM)
	if errors.IsNotFound(err) {
		if err = s.client.Create(ctx, cm); err != nil {
			return fmt.Errorf("create configmap failed: %w", err)
		}
		return nil
	}
	if err != nil {
		return fmt.Errorf("get configmap failed: %w", err)
	}
	existCM.Data = cm.Data
	if existCM.Labels == nil {
		existCM.Labels = make(map[string]string)
	}
	existCM.Labels[common.OperatorNameKey] = common.TrueBool
	if err = s.client.Update(ctx, existCM); err != nil {
		return fmt.Errorf("update configmap failed: %w", err)
	}
	return nil
}