* Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "datasystem/common/device/comm_wrapper.h"
#include "datasystem/common/inject/inject_point.h"
#include "datasystem/common/util/status_helper.h"
namespace datasystem {
/**
 * @brief Constructs a point-to-point communicator wrapper.
 *
 * All arguments are handed unchanged to CommWrapperBase; this class adds no state of its own here.
 * @param commId Identifier of the communication domain.
 * @param localDeviceId Device id on this side of the link.
 * @param remoteDeviceId Device id on the peer side of the link.
 * @param threadControl Shared thread/comm manager, forwarded to the base class.
 * @param resourceMgr Device resource manager, forwarded to the base class (not owned here).
 */
CommWrapper::CommWrapper(const std::string &commId,
                         int localDeviceId,
                         int remoteDeviceId,
                         std::shared_ptr<HcclCommMagr> &threadControl,
                         DeviceResourceManager *resourceMgr)
    : CommWrapperBase(commId, localDeviceId, remoteDeviceId, threadControl, resourceMgr)
{
    // Intentionally empty: initialization is fully delegated to CommWrapperBase.
}
void CommWrapper::ShutDown()
{
if ((commState_ != CommState::DESTROY)) {
commState_ = CommState::DESTROY;
std::lock_guard<std::mutex> lock(mutex_);
if (hasShutDown_) {
return;
}
if (pool_) {
auto traceId = Trace::Instance().GetTraceID();
pool_->Submit([this, resource = resource_, traceId]() {
TraceGuard traceGuard = Trace::Instance().SetTraceNewID(traceId);
LOG_IF_ERROR(
deviceImpl_->SynchronizeStreamWithTimeout(resource->PrimaryStream(), SYNC_STREAM_WAIT_TIMEOUT_MS),
"Timed out waiting for all tasks in Stream to complete, check that Recv is not called");
resource->Release();
deviceImpl_->CommDestroy(GetRef());
LOG(INFO) << "Destroy Comm ok, commId: " << commId_;
});
}
(void)commThreadControl_->RemoveThreadPoolCommRecord(bindThreadId_, commId_);
hasShutDown_ = true;
pool_.reset();
commThreadControl_.reset();
}
}
/**
 * @brief Destructor: runs ShutDown(), which is guarded so an earlier explicit ShutDown() makes this a no-op.
 */
CommWrapper::~CommWrapper()
{
    // Inside a destructor a virtual call resolves to this class's override anyway; qualify it to make that explicit.
    CommWrapper::ShutDown();
}
/**
 * @brief Creates the communicator from shared root info.
 *
 * Records the connect timestamp and moves the state to CREATING. It then asks the device backend to
 * initialize the comm handle in place (via GetRef()). The resulting status is logged on error and
 * stored with SetStatus().
 * @param numRanks Number of ranks in the communication domain.
 * @param rootInfo Root info shared by all ranks of the domain.
 * @param rank This side's rank in the domain.
 * @return The backend's initialization status.
 */
Status CommWrapper::InitComm(int numRanks, CommRootInfo &rootInfo, int rank)
{
    LOG(INFO) << "InitComm";
    commConnectTimestamp_ = std::chrono::steady_clock::now();
    commState_ = CommState::CREATING;

    // The backend writes the new handle directly into the slot referenced by GetRef().
    void **commSlot = reinterpret_cast<void **>(&GetRef());
    auto initRc = deviceImpl_->CommInitRootInfo(numRanks, &rootInfo, rank, commSlot);
    LOG_IF_ERROR(initRc, "CommInitRootInfo failed.");
    SetStatus(initRc);
    return initRc;
}
/**
 * @brief Sends every blob to the peer rank on the given stream.
 *
 * Each blob is issued as one CommSend of `size` INT8 elements to P2P_RECV_RANK.
 * @param blobs Buffers to send, in order.
 * @param event Unused by this implementation.
 * @param stream Device stream the sends are issued on.
 * @return OK on success; a CheckCommPtr error if the comm handle is null; otherwise the first failing CommSend status.
 */
Status CommWrapper::P2PSend(const std::vector<Blob> &blobs,
                            const std::shared_ptr<DeviceRtEventWrapper> &event,
                            aclrtStream stream)
{
    const std::string firstBlobSize = blobs.empty() ? "" : std::to_string(blobs[0].size);
    LOG(INFO) << "comm start to send " << firstBlobSize << ", info num: " << blobs.size();
    (void)event;

    auto &comm = GetRef();
    RETURN_IF_NOT_OK(CheckCommPtr(comm));
    for (const auto &blob : blobs) {
        RETURN_IF_NOT_OK(
            deviceImpl_->CommSend(blob.pointer, blob.size, CommDataType::INT8, P2P_RECV_RANK, comm, stream));
    }
    VLOG(1) << "Send comm ok";
    return Status::OK();
}
/**
 * @brief Queries the backend for an asynchronous communicator error.
 *
 * Before the comm exists (state UNCREATE or CREATING) there is nothing to query, so OK is returned.
 * @return OK while not yet created, otherwise the backend's async-error status for the comm handle.
 */
Status CommWrapper::GetCommAsyncError()
{
    const bool commNotReady = (commState_ == CommState::CREATING) || (commState_ == CommState::UNCREATE);
    if (!commNotReady) {
        return deviceImpl_->CommGetAsyncError(GetRef());
    }
    return Status::OK();
}
/**
 * @brief Receives every blob from the peer rank on the given stream.
 *
 * Each blob is issued as one CommRecv of `size` INT8 elements from P2P_SEND_RANK.
 * @param blobs Destination buffers, filled in order.
 * @param event Unused by this implementation.
 * @param stream Device stream the receives are issued on.
 * @return OK on success; a CheckCommPtr error if the comm handle is null; otherwise the first failing CommRecv status.
 */
Status CommWrapper::P2PRecv(const std::vector<Blob> &blobs,
                            const std::shared_ptr<DeviceRtEventWrapper> &event,
                            aclrtStream stream)
{
    const std::string firstBlobSize = blobs.empty() ? "" : std::to_string(blobs[0].size);
    LOG(INFO) << "comm receiving " << firstBlobSize << ", info num: " << blobs.size();
    (void)event;

    auto &comm = GetRef();
    RETURN_IF_NOT_OK(CheckCommPtr(comm));
    for (const auto &blob : blobs) {
        RETURN_IF_NOT_OK(
            deviceImpl_->CommRecv(blob.pointer, blob.size, CommDataType::INT8, P2P_SEND_RANK, comm, stream));
    }
    VLOG(1) << "Recv comm ok";
    return Status::OK();
}
/**
 * @brief Sets up the pipeline for `direction` and joins the two-rank P2P domain.
 *
 * The sending side takes rank P2P_SEND_RANK; any other direction takes P2P_RECV_RANK.
 * @param rootInfo Root info shared by both peers.
 * @param direction Whether this side sends or receives.
 * @param isSameNode Unused by this implementation.
 * @return Status of InitComm().
 */
Status CommWrapper::InitCommunicator(CommRootInfo &rootInfo, const CommDirection direction, bool isSameNode)
{
    (void)isSameNode;
    InitPipeline(direction);
    const int localRank = (direction == CommDirection::SEND) ? P2P_SEND_RANK : P2P_RECV_RANK;
    return InitComm(P2P_RANK_NUM, rootInfo, localRank);
}
/**
 * @brief Performs a tiny send or receive so the communicator is fully established before real traffic.
 *
 * Checks the comm handle, then allocates a scratch device buffer large enough for WARM_UP_DATA_COUNT INT8
 * elements. It issues one CommSend (SEND) or CommRecv (RECV) on the comm's stream and synchronizes that
 * stream. The scratch buffer is freed on every exit path by the Raii guard.
 * @param eventType Direction of the warm-up transfer; any other value only synchronizes the stream.
 * @return OK on success; a CheckCommPtr error if the comm handle is null; otherwise the first failing
 *         allocation / transfer / synchronization status.
 */
Status CommWrapper::WarmUpComm(CommDirection eventType)
{
    // Fail early with a descriptive status (as P2PSend/P2PRecv do) instead of handing a null comm to the backend.
    RETURN_IF_NOT_OK(CheckCommPtr(Get()));

    // Size the scratch buffer from the element count actually transferred (assuming CommDataType::INT8 is one
    // byte per element) rather than a fixed single byte, so the transfer can never exceed the allocation.
    // At least one byte is always allocated, matching the previous behaviour for a zero count.
    const size_t warmUpCount = static_cast<size_t>(WARM_UP_DATA_COUNT);
    const size_t warmUpBytes = (warmUpCount > 0 ? warmUpCount : 1) * sizeof(char);

    void *devPtr = nullptr;
    RETURN_IF_NOT_OK(deviceImpl_->MallocDeviceMemory(warmUpBytes, devPtr));
    Raii raii([this, &devPtr]() { deviceImpl_->FreeDeviceMemory(devPtr); });

    if (eventType == CommDirection::SEND) {
        RETURN_IF_NOT_OK(
            deviceImpl_->CommSend(devPtr, WARM_UP_DATA_COUNT, CommDataType::INT8, P2P_RECV_RANK, Get(), GetStream()));
    } else if (eventType == CommDirection::RECV) {
        RETURN_IF_NOT_OK(
            deviceImpl_->CommRecv(devPtr, WARM_UP_DATA_COUNT, CommDataType::INT8, P2P_SEND_RANK, Get(), GetStream()));
    }
    RETURN_IF_NOT_OK(deviceImpl_->SynchronizeStream(GetStream()));
    LOG(INFO) << "communicator warmup ok";
    return Status::OK();
}
/**
 * @brief Asks the device backend to generate root info for a new communication domain.
 * @param rootInfo Output; filled by the backend.
 * @return OK on success, otherwise the backend's error (which is also logged with "GetRootInfo failed.").
 */
Status CommWrapper::CreateRootInfo(CommRootInfo &rootInfo)
{
    auto rootInfoRc = deviceImpl_->CommGetRootInfo(&rootInfo);
    RETURN_IF_NOT_OK_PRINT_ERROR_MSG(rootInfoRc, "GetRootInfo failed.");
    return Status::OK();
}
/**
 * @brief Verifies that the communicator handle exists.
 * @param ptr Comm handle to check.
 * @return OK when `ptr` is non-null; otherwise K_RUNTIME_ERROR whose message embeds GetDetailStatus().
 */
Status CommWrapper::CheckCommPtr(const void *ptr)
{
    if (ptr != nullptr) {
        return Status::OK();
    }
    auto detail = GetDetailStatus();
    return Status(K_RUNTIME_ERROR,
                  FormatString("Comm is nullptr, create communication domain failed. Detail:%s", detail));
}
}