* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "parallel_task_loader.h"
#include "profiling_manager_pub.h"
namespace hccl {
ParallelTaskLoader::ParallelTaskLoader(const s32 deviceLogicId, const HcclDispatcher dispatcher)
: deviceLogicId_(deviceLogicId), dispatcher_(dispatcher), taskLoaderNum_(0)
{}
ParallelTaskLoader::~ParallelTaskLoader()
{}
HcclResult ParallelTaskLoader::Prepare(std::vector<Stream *> streamsPtr, SubCommInfo level0CommInfo)
{
streamsPtr_.resize(streamsPtr.size());
for (u32 streamIndex = 0; streamIndex < streamsPtr.size(); streamIndex++) {
streamsPtr_[streamIndex] = streamsPtr[streamIndex];
}
commInfo_ = level0CommInfo;
HCCL_INFO("[ParallelTaskLoader]Prepare streams size[%u], taskLoaderNum_[%u]", streamsPtr_.size(), taskLoaderNum_);
if (taskLoaderNum_ >= streamsPtr_.size()) {
HCCL_INFO("[ParallelTaskLoader] taskloaderNum [%u]", taskLoaderNum_);
return HCCL_SUCCESS;
}
streamTaskLoader_.resize(streamsPtr_.size());
for (u32 streamIndex = taskLoaderNum_; streamIndex < streamsPtr_.size(); streamIndex++) {
streamTaskLoader_[streamIndex].reset(new (std::nothrow) TaskLoader(deviceLogicId_, dispatcher_));
CHK_SMART_PTR_NULL(streamTaskLoader_[streamIndex]);
HcclResult ret = streamTaskLoader_[streamIndex]->Init();
CHK_PRT_RET(ret != HCCL_SUCCESS,
HCCL_ERROR("[ParallelTaskLoader][Init]streamIndex[%u] TaskLoader failed, return[%d]", streamIndex, ret),
ret);
}
taskLoaderNum_ = streamsPtr_.size();
HCCL_INFO("[ParallelTaskLoader] Prepare success taskLoaderNum[%u]", taskLoaderNum_);
return HCCL_SUCCESS;
}
HcclResult ParallelTaskLoader::StartTaskLoad()
{
tidInfo_.resize(streamsPtr_.size());
for (u32 streamIndex = 0; streamIndex < streamsPtr_.size(); streamIndex++) {
streamTaskLoader_[streamIndex]->Prepare(streamsPtr_[streamIndex], commInfo_);
tidInfo_[streamIndex] = streamTaskLoader_[streamIndex]->GetTid();
}
#ifndef OPEN_HCCL_TEST
CHK_RET(hccl::ProfilingManagerPub::CallMsprofReportMultiThreadInfo(tidInfo_));
#endif
for (u32 streamIndex = 0; streamIndex < streamsPtr_.size(); streamIndex++) {
streamTaskLoader_[streamIndex]->NotifyStart();
}
return HCCL_SUCCESS;
}
HcclResult ParallelTaskLoader::WaitTaskLoadFinish()
{
for (u32 streamIndex = 0; streamIndex < streamsPtr_.size(); streamIndex++) {
streamTaskLoader_[streamIndex]->WaitDone();
CHK_RET(streamTaskLoader_[streamIndex]->GetExecuteResult());
}
return HCCL_SUCCESS;
}
HcclResult ParallelTaskLoader::ClearTagCommInfo()
{
commInfo_ = SubCommInfo{};
for (u32 streamIndex = 0; streamIndex < streamsPtr_.size(); streamIndex++) {
CHK_RET(streamTaskLoader_[streamIndex]->ClearTagCommInfo());
}
return HCCL_SUCCESS;
}
}