* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "coll_operator.h"
#include <algorithm>
#include <functional>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include "op_type.h"
#include "string_util.h"
#include "not_support_exception.h"
#include "binary_stream.h"
namespace Hccl {
// Size, in bytes, of the fixed-length op-tag buffer built by opTagToVector(); tags longer than this are
// truncated, and vectorToOpTag() never scans past this many bytes.
constexpr u32 MAX_OP_TAG_LEN = 191;
// Length, in bytes, that DumpByteVector() pads every dumped byte vector up to (with '\0').
constexpr u32 MAX_HANDSHAKEMSGPACK_LEN = 1024;
// Builds "inputMem=..., outputMem=..., scratchMem=..." for the operator's three buffers.
// A buffer that is not set is rendered as the literal text "nullptr".
std::string MemBufferDesc(const BaseCollOperator &collOp)
{
    const std::string inputText = collOp.inputMem ? collOp.inputMem->Describe() : "nullptr";
    const std::string outputText = collOp.outputMem ? collOp.outputMem->Describe() : "nullptr";
    const std::string scratchText = collOp.scratchMem ? collOp.scratchMem->Describe() : "nullptr";

    std::string summary = "inputMem=" + inputText;
    summary += ", outputMem=" + outputText;
    summary += ", scratchMem=" + scratchText;
    return summary;
}
// Builds the field list shared by most operator descriptions:
// opType, opMode, dataType, sendRecvRemoteRank and the buffer summary from MemBufferDesc().
std::string OpDesc(const BaseCollOperator &collOp)
{
    const std::string typeText = collOp.opType.Describe();
    const std::string modeText = collOp.opMode.Describe();
    const std::string dataTypeText = collOp.dataType.Describe();
    const std::string remoteRankText = std::to_string(collOp.sendRecvRemoteRank);
    const std::string bufferText = MemBufferDesc(collOp);

    std::string summary = "opType=" + typeText + ", ";
    summary += "opMode=" + modeText + ", ";
    summary += "dataType=" + dataTypeText + ", ";
    summary += "sendRecvRemoteRank=" + remoteRankText + ", ";
    summary += "Buffers=[" + bufferText + "]";
    return summary;
}
// Description used for OpType::REDUCESCATTER: common fields plus reduceOp and dataCount.
std::string DescReduceScatter(const BaseCollOperator &collOp)
{
    const std::string commonText = OpDesc(collOp);
    const std::string reduceText = collOp.reduceOp.Describe();
    return StringFormat("BaseCollOperator[%s, reduceOp=%s, dataCount=%llu]",
                        commonText.c_str(), reduceText.c_str(), collOp.dataCount);
}
// Description used for OpType::ALLREDUCE: common fields plus reduceOp and dataCount.
std::string DescAllreduce(const BaseCollOperator &collOp)
{
    const std::string commonText = OpDesc(collOp);
    const std::string reduceText = collOp.reduceOp.Describe();
    return StringFormat("BaseCollOperator[%s, reduceOp=%s, dataCount=%llu]",
                        commonText.c_str(), reduceText.c_str(), collOp.dataCount);
}
// Description used for OpType::ALLGATHER: common fields plus dataCount.
std::string DescAllgather(const BaseCollOperator &collOp)
{
    const std::string commonText = OpDesc(collOp);
    return StringFormat("BaseCollOperator[%s, dataCount=%llu]", commonText.c_str(), collOp.dataCount);
}
// Description used for OpType::SCATTER: common fields plus dataCount and root.
std::string DescScatter(const BaseCollOperator &collOp)
{
    const std::string commonText = OpDesc(collOp);
    return StringFormat("BaseCollOperator[%s, dataCount=%llu, root=%u]",
                        commonText.c_str(), collOp.dataCount, collOp.root);
}
// Description used for OpType::ALLTOALL. Send/receive counts and types come from all2AllDataDes
// rather than from the common dataType/dataCount fields.
std::string DescAlltoall(const BaseCollOperator &collOp)
{
    const std::string typeText = collOp.opType.Describe();
    const std::string modeText = collOp.opMode.Describe();
    const std::string sendTypeText = collOp.all2AllDataDes.sendType.Describe();
    const std::string recvTypeText = collOp.all2AllDataDes.recvType.Describe();
    const std::string bufferText = MemBufferDesc(collOp);
    return StringFormat(
        "BaseCollOperator[opType=%s, opMode=%s, sendCount=%llu, recvCount=%llu, sendType=%s, recvType=%s, "
        "sendRecvRemoteRank=%u, Buffers=[%s]]",
        typeText.c_str(), modeText.c_str(),
        collOp.all2AllDataDes.sendCount, collOp.all2AllDataDes.recvCount,
        sendTypeText.c_str(), recvTypeText.c_str(),
        collOp.sendRecvRemoteRank, bufferText.c_str());
}
// Description used for OpType::ALLTOALLV. Send/receive types come from all2AllVDataDes.
std::string DescAlltoallV(const BaseCollOperator &collOp)
{
    const std::string typeText = collOp.opType.Describe();
    const std::string modeText = collOp.opMode.Describe();
    const std::string sendTypeText = collOp.all2AllVDataDes.sendType.Describe();
    const std::string recvTypeText = collOp.all2AllVDataDes.recvType.Describe();
    const std::string bufferText = MemBufferDesc(collOp);
    return StringFormat(
        "BaseCollOperator[opType=%s, opMode=%s, sendType=%s, recvType=%s, sendRecvRemoteRank=%u, Buffers=[%s]]",
        typeText.c_str(), modeText.c_str(), sendTypeText.c_str(), recvTypeText.c_str(),
        collOp.sendRecvRemoteRank, bufferText.c_str());
}
// Description used for OpType::ALLTOALLVC. Send/receive types come from all2AllVCDataDes.
std::string DescAlltoallVC(const BaseCollOperator &collOp)
{
    const std::string typeText = collOp.opType.Describe();
    const std::string modeText = collOp.opMode.Describe();
    const std::string sendTypeText = collOp.all2AllVCDataDes.sendType.Describe();
    const std::string recvTypeText = collOp.all2AllVCDataDes.recvType.Describe();
    const std::string bufferText = MemBufferDesc(collOp);
    return StringFormat(
        "BaseCollOperator[opType=%s, opMode=%s, sendType=%s, recvType=%s, sendRecvRemoteRank=%u, Buffers=[%s]]",
        typeText.c_str(), modeText.c_str(), sendTypeText.c_str(), recvTypeText.c_str(),
        collOp.sendRecvRemoteRank, bufferText.c_str());
}
// Description used for OpType::SEND: only the common fields from OpDesc().
std::string DescSend(const BaseCollOperator &collOp)
{
    const std::string commonText = OpDesc(collOp);
    return StringFormat("BaseCollOperator[%s]", commonText.c_str());
}
// Description used for OpType::RECV: only the common fields from OpDesc().
std::string DescRecv(const BaseCollOperator &collOp)
{
    const std::string commonText = OpDesc(collOp);
    return StringFormat("BaseCollOperator[%s]", commonText.c_str());
}
// Description used for OpType::REDUCE: common fields plus reduceOp, dataCount and root.
std::string DescReduce(const BaseCollOperator &collOp)
{
    const std::string commonText = OpDesc(collOp);
    const std::string reduceText = collOp.reduceOp.Describe();
    return StringFormat("BaseCollOperator[%s, reduceOp=%s, dataCount=%llu, root=%u]",
                        commonText.c_str(), reduceText.c_str(), collOp.dataCount, collOp.root);
}
// Description used for OpType::BROADCAST: common fields plus dataCount and root.
std::string DescBroadcast(const BaseCollOperator &collOp)
{
    const std::string commonText = OpDesc(collOp);
    return StringFormat("BaseCollOperator[%s, dataCount=%llu, root=%u]",
                        commonText.c_str(), collOp.dataCount, collOp.root);
}
// Description used for OpType::BATCHSENDRECV: common fields plus dataCount and root.
std::string DescBatchSendRecv(const BaseCollOperator &collOp)
{
    const std::string commonText = OpDesc(collOp);
    return StringFormat("BaseCollOperator[%s, dataCount=%llu, root=%u]",
                        commonText.c_str(), collOp.dataCount, collOp.root);
}
// Description used for OpType::HALFALLTOALLV: only the common fields from OpDesc().
std::string DescHalfAlltoAllV(const BaseCollOperator &collOp)
{
    const std::string commonText = OpDesc(collOp);
    return StringFormat("BaseCollOperator[%s]", commonText.c_str());
}
// Description used for OpType::REDUCESCATTERV: only the common fields from OpDesc().
std::string DescReduceScatterV(const BaseCollOperator &collOp)
{
    const std::string commonText = OpDesc(collOp);
    return StringFormat("BaseCollOperator[%s]", commonText.c_str());
}
// Description used for OpType::ALLGATHERV: only the common fields from OpDesc().
std::string DescAllGatherV(const BaseCollOperator &collOp)
{
    const std::string commonText = OpDesc(collOp);
    return StringFormat("BaseCollOperator[%s]", commonText.c_str());
}
// Dispatch table from operator type to the function that renders its description.
// CollOpToString() looks an operator up here; types without an entry are reported as "unknown".
// The free functions are stored directly: wrapping them in std::bind with a single placeholder
// adds nothing, because std::function already accepts a plain function.
std::unordered_map<OpType, std::function<std::string(const BaseCollOperator &)>, std::EnumClassHash> descOpMap{
    {OpType::REDUCESCATTER, DescReduceScatter},
    {OpType::ALLREDUCE, DescAllreduce},
    {OpType::ALLGATHER, DescAllgather},
    {OpType::SCATTER, DescScatter},
    {OpType::ALLTOALL, DescAlltoall},
    {OpType::ALLTOALLV, DescAlltoallV},
    {OpType::ALLTOALLVC, DescAlltoallVC},
    {OpType::SEND, DescSend},
    {OpType::RECV, DescRecv},
    {OpType::REDUCE, DescReduce},
    {OpType::BROADCAST, DescBroadcast},
    {OpType::BATCHSENDRECV, DescBatchSendRecv},
    {OpType::HALFALLTOALLV, DescHalfAlltoAllV},
    {OpType::REDUCESCATTERV, DescReduceScatterV},
    {OpType::ALLGATHERV, DescAllGatherV},
};
// Renders a human-readable description of collOp using the formatter registered in descOpMap
// for its opType. Returns "unknown" when no formatter is registered for that type.
std::string CollOpToString(const BaseCollOperator &collOp)
{
    const auto entry = descOpMap.find(collOp.opType);
    if (entry == descOpMap.end()) {
        return "unknown";
    }
    return entry->second(collOp);
}
// Dumps binaryStream into a byte vector and pads it with '\0' up to MAX_HANDSHAKEMSGPACK_LEN bytes.
//
// Throws std::length_error when the dumped data is already longer than MAX_HANDSHAKEMSGPACK_LEN.
// The padding length used to be computed as an unsigned subtraction, which wraps around to a huge
// value in that case; the check below reports the oversize condition explicitly instead.
inline std::vector<char> DumpByteVector(BinaryStream &binaryStream)
{
    std::vector<char> byteVector;
    binaryStream.Dump(byteVector);
    if (byteVector.size() > MAX_HANDSHAKEMSGPACK_LEN) {
        throw std::length_error("DumpByteVector: dumped size " + std::to_string(byteVector.size()) +
                                " exceeds MAX_HANDSHAKEMSGPACK_LEN " +
                                std::to_string(MAX_HANDSHAKEMSGPACK_LEN));
    }
    // resize() only grows here (size <= limit), appending '\0' bytes up to the fixed length.
    byteVector.resize(MAX_HANDSHAKEMSGPACK_LEN, '\0');
    return byteVector;
}
// Converts opTag into a buffer of exactly MAX_OP_TAG_LEN bytes: the tag's leading bytes followed by
// '\0' padding. Tags longer than MAX_OP_TAG_LEN are truncated, leaving no terminating '\0'.
std::vector<char> opTagToVector(const std::string &opTag)
{
    std::vector<char> padded(MAX_OP_TAG_LEN, '\0');
    const auto keepLen = std::min<std::string::size_type>(opTag.size(), MAX_OP_TAG_LEN);
    std::copy(opTag.begin(), opTag.begin() + keepLen, padded.begin());
    return padded;
}
// Recovers the op tag from a byte buffer: the bytes before the first '\0', looking at no more than
// the first MAX_OP_TAG_LEN bytes. If no '\0' is found in that window, the whole window is returned.
std::string vectorToOpTag(const std::vector<char> &opTagvector)
{
    const auto windowLen = std::min<std::vector<char>::size_type>(opTagvector.size(), MAX_OP_TAG_LEN);
    const auto windowEnd = opTagvector.begin() + windowLen;
    const auto tagEnd = std::find(opTagvector.begin(), windowEnd, '\0');
    return std::string(opTagvector.begin(), tagEnd);
}
std::vector<char> CollOperator::GetUniqueId() const
{
HCCL_INFO("[CollOperator::%s] opMode[%s], opType[%s], reduceOp[%s], dataType[%s], dataCount[%llu], root[%u], "
"myRank[%d], sendRecvRemoteRank[%d], opTag[%s], staticAddr[%d], staticShape[%d], outputDataType[%s], ",
__func__, opMode.Describe().c_str(), opType.Describe().c_str(), reduceOp.Describe().c_str(), dataType.Describe().c_str(),
dataCount, root, myRank, sendRecvRemoteRank, opTag.c_str(), staticAddr, staticShape, outputDataType.Describe().c_str());
BinaryStream binaryStream;
binaryStream << opMode;
binaryStream << opType;
binaryStream << reduceOp;
binaryStream << dataType;
binaryStream << dataCount;
binaryStream << root;
binaryStream << myRank;
binaryStream << sendRecvRemoteRank;
binaryStream << opTagToVector(opTag);
binaryStream << staticAddr;
binaryStream << staticShape;
binaryStream << outputDataType;
if (opType == OpType::BATCHSENDRECV) {
return DumpByteVector(binaryStream);
}
if (opType == OpType::ALLTOALL) {
binaryStream << all2AllDataDes.sendType;
binaryStream << all2AllDataDes.recvType;
binaryStream << all2AllDataDes.sendCount;
binaryStream << all2AllDataDes.recvCount;
return DumpByteVector(binaryStream);
}
if (opType == OpType::ALLTOALLV) {
binaryStream << all2AllVDataDes.sendType;
binaryStream << all2AllVDataDes.recvType;
return DumpByteVector(binaryStream);
}
if (opType == OpType::ALLTOALLVC) {
binaryStream << all2AllVCDataDes.sendType;
binaryStream << all2AllVCDataDes.recvType;
return DumpByteVector(binaryStream);
}
if (opType == OpType::ALLGATHERV || opType == OpType::REDUCESCATTERV) {
binaryStream << vDataDes.dataType;
return DumpByteVector(binaryStream);
}
binaryStream << dataDes.dataCount;
binaryStream << dataDes.dataType;
binaryStream << dataDes.strideCount;
return DumpByteVector(binaryStream);
}
// Rebuilds an operator from a byte vector, reading fields in the same order GetUniqueId() writes
// them: the common header first, then the fields selected by the decoded opType.
CollOperatorDef CollOperator::GetPackedData(std::vector<char> &byteVector)
{
    CollOperator decoded;
    BinaryStream reader(byteVector);
    std::vector<char> rawTag;

    // Common header.
    reader >> decoded.opMode;
    reader >> decoded.opType;
    reader >> decoded.reduceOp;
    reader >> decoded.dataType;
    reader >> decoded.dataCount;
    reader >> decoded.root;
    reader >> decoded.myRank;
    reader >> decoded.sendRecvRemoteRank;
    reader >> rawTag;
    reader >> decoded.staticAddr;
    reader >> decoded.staticShape;
    reader >> decoded.outputDataType;
    // The tag is stored as a '\0'-padded buffer; convert it back to a string.
    decoded.opTag = vectorToOpTag(rawTag);

    // Type-specific trailer.
    if (decoded.opType == OpType::BATCHSENDRECV) {
        // Nothing beyond the common header.
    } else if (decoded.opType == OpType::ALLTOALL) {
        reader >> decoded.all2AllDataDes.sendType;
        reader >> decoded.all2AllDataDes.recvType;
        reader >> decoded.all2AllDataDes.sendCount;
        reader >> decoded.all2AllDataDes.recvCount;
    } else if (decoded.opType == OpType::ALLTOALLV) {
        reader >> decoded.all2AllVDataDes.sendType;
        reader >> decoded.all2AllVDataDes.recvType;
    } else if (decoded.opType == OpType::ALLTOALLVC) {
        reader >> decoded.all2AllVCDataDes.sendType;
        reader >> decoded.all2AllVCDataDes.recvType;
    } else if (decoded.opType == OpType::ALLGATHERV || decoded.opType == OpType::REDUCESCATTERV) {
        reader >> decoded.vDataDes.dataType;
    } else {
        reader >> decoded.dataDes.dataCount;
        reader >> decoded.dataDes.dataType;
        reader >> decoded.dataDes.strideCount;
    }
    return decoded;
}
}