package recover
import (
"context"
"fmt"
"slices"
"ascend-common/common-utils/hwlog"
"clusterd/pkg/application/faultmanager"
"clusterd/pkg/common/constant"
"clusterd/pkg/domain/common"
"clusterd/pkg/domain/job"
"clusterd/pkg/interface/grpc/recover"
)
var (
stressOps = []int64{0, 1}
)
func (s *FaultRecoverService) SwitchNicTrack(ctx context.Context, nics *pb.SwitchNics) (*pb.Status, error) {
if ok, msg := s.checkNicsParam(nics); !ok {
hwlog.RunLog.Errorf("check param failed: %s", msg)
return &pb.Status{
Code: int32(common.OMParamInvalid),
Info: msg,
}, nil
}
ctl, _ := s.getController(nics.JobID)
if !ctl.canDoSwitchingNic() {
hwlog.RunLog.Errorf("jobId=%s nic is swtiching, or job recovering", nics.JobID)
return &pb.Status{
Code: int32(common.OMIsRunning),
Info: fmt.Sprintf("jobId=%s nic is swtiching, or job recovering", nics.JobID),
}, nil
}
globalSwitchRankIDs, globalOps := s.getGlobalRankIDAndOp(nics)
ctl.setSwitchNicParam(globalSwitchRankIDs, globalOps)
ctl.addEvent(common.StartSwitchNic)
hwlog.RunLog.Infof("jobId=%s nic swtich: %v, %v", nics.JobID, globalSwitchRankIDs, globalOps)
return &pb.Status{Code: int32(common.OK), Info: "switching operation was successfully distributed"}, nil
}
func (s *FaultRecoverService) SubscribeSwitchNicSignal(req *pb.SwitchNicRequest,
stream pb.Recover_SubscribeSwitchNicSignalServer) error {
if req == nil {
return fmt.Errorf("request is nil")
}
hwlog.RunLog.Infof("receive Subscribe signal request, jobID: %s", req.JobID)
controller, exist := s.getController(req.JobID)
if !exist {
return fmt.Errorf("jobId=%s not registed", req.JobID)
}
if !controller.isSwitchingNic() {
return fmt.Errorf("jobId=%s is not swtiching nic", req.JobID)
}
controller.listenSwitchNicChannel(stream)
return nil
}
func (s *FaultRecoverService) SubscribeNotifySwitch(req *pb.ClientInfo, stream pb.Recover_SubscribeNotifySwitchServer) error {
if req == nil {
return fmt.Errorf("request is nil")
}
controller, exist := s.getController(req.JobId)
if !exist {
hwlog.RunLog.Debugf("jobId=%s not registed, wait job running", req.JobId)
return fmt.Errorf("jobId=%s not registed, please wait agent register", req.JobId)
}
hwlog.RunLog.Infof("receive Subscribe notify switch signal request, jobID: %s", req.JobId)
controller.listenSwitchNicNotifyChannel(stream)
return nil
}
func (s *FaultRecoverService) ReplySwitchNicResult(ctx context.Context, res *pb.SwitchResult) (*pb.Status, error) {
if res == nil {
return &pb.Status{Code: int32(common.OMParamInvalid), Info: "request is nil"}, nil
}
controller, exist := s.getController(res.JobId)
if !exist {
hwlog.RunLog.Errorf("jobId=%s not registed", res.JobId)
return &pb.Status{
Code: int32(common.UnRegistry),
Info: fmt.Sprintf("jobId=%s not registed", res.JobId)}, nil
}
controller.setSwitchNicResult(res)
hwlog.RunLog.Infof("jobId=%s nic swtich result: %v", res.JobId, res.Result)
return &pb.Status{Code: int32(common.OK), Info: "reply success"}, nil
}
func (s *FaultRecoverService) checkNicsParam(nics *pb.SwitchNics) (bool, string) {
if nics == nil {
return false, "nics is nil"
}
jobInfo, ok := job.GetJobCache(nics.JobID)
if !ok {
return false, fmt.Sprintf("job:%s not exist", nics.JobID)
}
if !s.registered(nics.JobID) {
return false, fmt.Sprintf("job:%s is not registered", nics.JobID)
}
if jobInfo.Status != job.StatusJobRunning {
return false, fmt.Sprintf("job:%s is not running", nics.JobID)
}
jobServerMap := s.getNodeDeviceMap(jobInfo.JobRankTable.ServerList)
deviceInfos := faultmanager.QueryDeviceInfoToReport()
for node, devs := range nics.NicOps {
if _, ok := jobServerMap[node]; !ok {
return false, fmt.Sprintf("node:%s not exist in job:%s", node, nics.JobID)
}
if _, ok := deviceInfos[node]; !ok {
return false, fmt.Sprintf("node:%s not exist in job:%s", node, nics.JobID)
}
if msg, ok := s.checkDevsValid(devs, jobServerMap[node], node); !ok {
return false, msg
}
if deviceInfos[node].SuperPodID < 0 {
return false, fmt.Sprintf("node:%s should operate in superPodID", node)
}
}
return true, ""
}
func (s *FaultRecoverService) checkDevsValid(switchDev *pb.DeviceList, devs []string, node string) (string, bool) {
if len(switchDev.Dev) != len(switchDev.Op) {
return "dev and op length is not equal", false
}
if len(switchDev.Dev) == 0 || len(switchDev.Op) == 0 {
return "dev or op is empty", false
}
for _, dev := range switchDev.Dev {
if !slices.Contains(devs, dev) {
return fmt.Sprintf("device:%s not exist in node:%s", dev, node), false
}
}
return "", true
}
func (s *FaultRecoverService) getNodeDeviceMap(serverList []constant.ServerHccl) map[string][]string {
serverMap := make(map[string][]string)
for _, server := range serverList {
devs := make([]string, 0)
for _, dev := range server.DeviceList {
devs = append(devs, dev.DeviceID)
}
serverMap[server.ServerName] = devs
}
return serverMap
}
func (s *FaultRecoverService) getGlobalRankIDAndOp(nics *pb.SwitchNics) ([]string, []bool) {
globalRankIDs := make([]string, 0)
globalOps := make([]bool, 0)
jobInfo, ok := job.GetJobCache(nics.JobID)
if !ok {
hwlog.RunLog.Errorf("get job cache failed, jobId=%s", nics.JobID)
return globalRankIDs, globalOps
}
serverMap := make(map[string]map[string]string)
for _, server := range jobInfo.JobRankTable.ServerList {
devs := make(map[string]string)
for _, dev := range server.DeviceList {
devs[dev.DeviceID] = dev.RankID
}
serverMap[server.ServerName] = devs
}
for node, devs := range nics.NicOps {
for _, dev := range devs.Dev {
globalRankIDs = append(globalRankIDs, serverMap[node][dev])
}
for _, op := range devs.Op {
globalOps = append(globalOps, op)
}
}
return globalRankIDs, globalOps
}
func (s *FaultRecoverService) StressTest(ctx context.Context, params *pb.StressTestParam) (*pb.Status, error) {
if ok, msg := s.checkStressTestParam(params); !ok {
hwlog.RunLog.Errorf("check param failed: %s", msg)
return &pb.Status{
Code: int32(common.OMParamInvalid),
Info: msg,
}, nil
}
ctl, _ := s.getController(params.JobID)
if !ctl.canDoStressTest() {
hwlog.RunLog.Errorf("jobId=%s om is running, or job recovering", params.JobID)
return &pb.Status{
Code: int32(common.OMIsRunning),
Info: fmt.Sprintf("jobId=%s om is running, or job recovering", params.JobID),
}, nil
}
globalRankIDs := s.getNodeRankOpsMap(params)
ctl.setStressTestParam(globalRankIDs)
ctl.addEvent(common.StartStressTest)
hwlog.RunLog.Infof("jobId=%s stress test param: %v, global ranks: %v", params.JobID, params, globalRankIDs)
return &pb.Status{Code: int32(common.OK), Info: "stress test operation was successfully distributed"}, nil
}
func (s *FaultRecoverService) SubscribeNotifyExecStressTest(req *pb.ClientInfo, stream pb.Recover_SubscribeNotifyExecStressTestServer) error {
if req == nil {
return fmt.Errorf("request is nil")
}
controller, exist := s.getController(req.JobId)
if !exist {
hwlog.RunLog.Debugf("jobId=%s not registed, wait job running", req.JobId)
return fmt.Errorf("jobId=%s not registed, please wait agent register", req.JobId)
}
hwlog.RunLog.Infof("receive Subscribe notify stress test signal request, jobID: %s", req.JobId)
controller.listenStressTestNotifyChannel(stream)
return nil
}
func (s *FaultRecoverService) ReplyStressTestResult(ctx context.Context, res *pb.StressTestResult) (*pb.Status, error) {
if res == nil {
return &pb.Status{Code: int32(common.OMParamInvalid), Info: "request is nil"}, nil
}
controller, exist := s.getController(res.JobId)
if !exist {
hwlog.RunLog.Errorf("jobId=%s not registed", res.JobId)
return &pb.Status{
Code: int32(common.UnRegistry),
Info: fmt.Sprintf("jobId=%s not registed", res.JobId)}, nil
}
controller.setStressTestResult(res)
hwlog.RunLog.Infof("jobId=%s stress test result: %v", res.JobId, res.StressResult)
return &pb.Status{Code: int32(common.OK), Info: "reply success"}, nil
}
func (s *FaultRecoverService) SubscribeStressTestResponse(req *pb.StressTestRequest,
stream pb.Recover_SubscribeStressTestResponseServer) error {
if req == nil {
return fmt.Errorf("request is nil")
}
hwlog.RunLog.Infof("receive Subscribe signal request, jobID: %s", req.JobID)
controller, exist := s.getController(req.JobID)
if !exist {
return fmt.Errorf("jobId=%s not registed", req.JobID)
}
if !controller.isStressTest() {
return fmt.Errorf("jobId=%s is not stress test", req.JobID)
}
controller.listenStressTestChannel(stream)
return nil
}
func (s *FaultRecoverService) getNodeRankOpsMap(param *pb.StressTestParam) common.StressTestParam {
nodeRankOpsMap := make(common.StressTestParam)
jobInfo, ok := job.GetJobCache(param.JobID)
if !ok {
hwlog.RunLog.Errorf("get job cache failed, jobId=%s", param.JobID)
return nil
}
nodeRankMap := s.getNodeRankMap(jobInfo.JobRankTable.ServerList)
if len(param.AllNodesOps) != 0 {
for node, ranks := range nodeRankMap {
nodeRankOpsMap[node] = make(map[string][]int64, len(ranks))
for _, rank := range ranks {
nodeRankOpsMap[node][rank] = param.AllNodesOps
}
}
return nodeRankOpsMap
}
for nodeName, _ := range param.StressParam {
nodeRankOpsMap[nodeName] = map[string][]int64{}
for _, rank := range nodeRankMap[nodeName] {
nodeRankOpsMap[nodeName][rank] = param.StressParam[nodeName].Ops
}
}
return nodeRankOpsMap
}
func (s *FaultRecoverService) checkStressTestParam(params *pb.StressTestParam) (bool, string) {
if params == nil {
return false, "param is nil"
}
jobInfo, ok := job.GetJobCache(params.JobID)
if !ok {
return false, fmt.Sprintf("job:%s not exist", params.JobID)
}
if !s.registered(params.JobID) {
return false, fmt.Sprintf("job:%s is not registered", params.JobID)
}
if len(params.AllNodesOps) != 0 {
if ok, msg := s.validateStressTestOps(params.AllNodesOps, "AllNodes"); !ok {
return false, msg
}
return true, ""
}
jobServerMap := s.getNodeRankMap(jobInfo.JobRankTable.ServerList)
if len(params.StressParam) == 0 {
return false, "stress test node is nil"
}
for node, ops := range params.StressParam {
if ops == nil {
return false, fmt.Sprintf("node:%s stress test ops is nil", node)
}
if len(ops.Ops) == 0 {
return false, fmt.Sprintf("node:%s stress test ops is 0", node)
}
if _, ok := jobServerMap[node]; !ok {
return false, fmt.Sprintf("node:%s not exist in job:%s", node, params.JobID)
}
if ok, msg := s.validateStressTestOps(ops.Ops, node); !ok {
return false, msg
}
}
return true, ""
}
func (s *FaultRecoverService) validateStressTestOps(ops []int64, node string) (bool, string) {
opMap := make(map[int64]struct{})
for _, op := range ops {
opMap[op] = struct{}{}
if !slices.Contains(stressOps, op) {
return false, fmt.Sprintf("op:%v not exist in support operation:%v", op, stressOps)
}
}
if len(opMap) != len(ops) {
return false, fmt.Sprintf("node:%s stress test ops should not repeat", node)
}
return true, ""
}
func (ctl *EventController) canDoStressTest() bool {
return ctl.state.GetState() == common.InitState
}
func (s *FaultRecoverService) getNodeRankMap(serverList []constant.ServerHccl) map[string][]string {
serverMap := make(map[string][]string)
for _, server := range serverList {
ranks := make([]string, 0)
for _, dev := range server.DeviceList {
ranks = append(ranks, dev.RankID)
}
serverMap[server.ServerName] = ranks
}
return serverMap
}