* Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved.
* Description: 内存冲突校验/语义校验辅助函数实现
* Author: yinding
* Create: 2024-02-04
*/
#include "check_utils.h"
#include <vector>
#include "task_stub.h"
#include "mem_layout.h"
using namespace Hccl;
namespace checker {
// Indentation string exported by name; it is not referenced within the functions shown in this file section.
// NOTE(review): the name implies four spaces, but the literal as shown holds a single space -- confirm whether the
// value is intended or was shortened.
const std::string FOUR_INDENT_SPACE = " ";
/**
 * @brief Returns the type of the task attached to a task-graph node.
 * @param node node whose task is queried; must be non-null and carry a valid task.
 * @return the value reported by the task's GetType().
 */
TaskTypeStub GetNodeType(const TaskNode *node)
{
    const auto &attachedTask = node->task;
    return attachedTask->GetType();
}
/**
 * @brief Tells whether an operator belongs to the all-to-all family.
 * @param opType operator type to classify.
 * @return true for ALLTOALL, ALLTOALLV and ALLTOALLVC; false otherwise.
 */
bool IsAllToAllSeries(CheckerOpType opType)
{
    switch (opType) {
        case CheckerOpType::ALLTOALL:
        case CheckerOpType::ALLTOALLV:
        case CheckerOpType::ALLTOALLVC:
            return true;
        default:
            return false;
    }
}
/**
 * @brief Tells whether an operator is a point-to-point transfer.
 * @param opType operator type to classify.
 * @return true for SEND and RECEIVE; false otherwise.
 */
bool IsSendRecvType(CheckerOpType opType)
{
    switch (opType) {
        case CheckerOpType::SEND:
        case CheckerOpType::RECEIVE:
            return true;
        default:
            return false;
    }
}
void CalcInputOutputSize(const CheckerOpParam &opParam, u32 ranksize, u64 &inputSize, u64 &outputSize, RankId myRank)
{
u32 unitSize = 0;
if (!IsAllToAllSeries(opParam.opType) && opParam.opType != CheckerOpType::BATCH_SEND_RECV &&
opParam.opType != CheckerOpType::REDUCE_SCATTER_V && opParam.opType != CheckerOpType::ALLGATHER_V) {
unitSize = CHECK_SIZE_TABLE[opParam.DataDes.dataType];
}
u64 count = opParam.DataDes.count;
if (opParam.opType == CheckerOpType::ALLREDUCE) {
inputSize = count * unitSize;
outputSize = count * unitSize;
} else if (opParam.opType == CheckerOpType::BROADCAST) {
inputSize = count * unitSize;
outputSize = count * unitSize;
} else if (IsSendRecvType(opParam.opType) && myRank == opParam.srcRank) {
inputSize = count * unitSize;
outputSize = 0;
} else if (IsSendRecvType(opParam.opType) && myRank == opParam.dstRank) {
inputSize = 0;
outputSize = count * unitSize;
} else if (opParam.opType == CheckerOpType::REDUCE) {
if (myRank == opParam.root) {
outputSize = count * unitSize;
} else {
outputSize = count * unitSize;
}
inputSize = count * unitSize;
} else if (opParam.opType == CheckerOpType::ALLGATHER) {
inputSize = count * unitSize;
outputSize = count * unitSize * ranksize;
} else if (opParam.opType == CheckerOpType::REDUCE_SCATTER) {
inputSize = count * unitSize * ranksize;
outputSize = count * unitSize;
} else if (opParam.opType == CheckerOpType::ALLTOALL || opParam.opType == CheckerOpType::ALLTOALLVC) {
u64 curSendOffset = 0;
u64 curRecvOffset = 0;
void *sendCountMatrix = static_cast<void *>(const_cast<u64*>(opParam.All2AllDataDes.sendCountMatrix.data()));
RankId curRank = 0;
for (u32 j = 0; j < ranksize; j++) {
u64 curSendCounts = *(static_cast<const u64 *>(sendCountMatrix) + curRank * ranksize + j);
u64 curSendLength = curSendCounts * CHECK_SIZE_TABLE[opParam.All2AllDataDes.sendType];
curSendOffset += curSendLength;
u64 curRecvCounts = *(static_cast<const u64 *>(sendCountMatrix) + curRank + ranksize * j);
u64 curRecvLength = curRecvCounts * CHECK_SIZE_TABLE[opParam.All2AllDataDes.recvType];
curRecvOffset += curRecvLength;
}
inputSize = curSendOffset;
outputSize = curRecvOffset;
} else if (opParam.opType == CheckerOpType::ALLTOALLV) {
void* sendCounts = static_cast<void *>(const_cast<u64*>(opParam.All2AllDataDes.sendCounts.data()));
void* recvCounts = static_cast<void *>(const_cast<u64*>(opParam.All2AllDataDes.recvCounts.data()));
u64 curSendOffset = 0;
u64 curRecvOffset = 0;
for (u32 i = 0; i < ranksize; i++) {
u64 curSendCounts = *(static_cast<const u64 *>(sendCounts) + i);
u64 curSendLength = curSendCounts * CHECK_SIZE_TABLE[opParam.All2AllDataDes.sendType];
curSendOffset += curSendLength;
u64 curRecvCounts = *(static_cast<const u64 *>(recvCounts) + i);
u64 curRecvLength = curRecvCounts * CHECK_SIZE_TABLE[opParam.All2AllDataDes.recvType];
curRecvOffset += curRecvLength;
}
inputSize = curSendOffset;
outputSize = curRecvOffset;
} else if (opParam.opType == CheckerOpType::SCATTER) {
inputSize = count * unitSize * ranksize;
outputSize = count * unitSize;
} else if (opParam.opType == CheckerOpType::BATCH_SEND_RECV) {
if (opParam.allRanksSendRecvInfoVec.size() == 0 || opParam.allRanksSendRecvInfoVec[0].size() == 0) {
HCCL_ERROR("BatchSendRecv allRanksSendRecvInfoVec is empty.");
return;
}
u32 unitSizePerTask = CHECK_SIZE_TABLE[opParam.allRanksSendRecvInfoVec[0][0].dataType];
u64 countPerTask = opParam.allRanksSendRecvInfoVec[0][0].count;
inputSize = ranksize * countPerTask * unitSizePerTask;
outputSize = ranksize * countPerTask * unitSizePerTask;
} else if (opParam.opType == CheckerOpType::REDUCE_SCATTER_V) {
void* counts = static_cast<void *>(const_cast<u64*>(opParam.VDataDes.counts.data()));
inputSize = 0;
for (u32 i = 0; i < ranksize; i++) {
u64 curCounts = *(static_cast<const u64 *>(counts) + i);
u64 curLength = curCounts * CHECK_SIZE_TABLE[opParam.VDataDes.dataType];
inputSize += curLength;
}
outputSize = static_cast<const u64 *>(counts)[myRank] * CHECK_SIZE_TABLE[opParam.VDataDes.dataType];
} else if (opParam.opType == CheckerOpType::ALLGATHER_V) {
void* counts = static_cast<void *>(const_cast<u64*>(opParam.VDataDes.counts.data()));
outputSize = 0;
for (u32 i = 0; i < ranksize; i++) {
u64 curCounts = *(static_cast<const u64 *>(counts) + i);
u64 curLength = curCounts * CHECK_SIZE_TABLE[opParam.VDataDes.dataType];
outputSize += curLength;
}
inputSize = static_cast<const u64 *>(counts)[myRank] * CHECK_SIZE_TABLE[opParam.VDataDes.dataType];
}
return;
}
void CalcDataSize(const CheckerOpParam &opParam, u64 &dataSize)
{
if (opParam.opType == CheckerOpType::BATCH_SEND_RECV) {
u32 unitSize = CHECK_SIZE_TABLE[opParam.allRanksSendRecvInfoVec[0][0].dataType];
u64 count = opParam.allRanksSendRecvInfoVec[0][0].count;
dataSize = count * unitSize;
return;
}
if (!IsAllToAllSeries(opParam.opType) && opParam.opType != CheckerOpType::REDUCE_SCATTER_V &&
opParam.opType != CheckerOpType::ALLGATHER_V) {
u32 unitSize = CHECK_SIZE_TABLE[opParam.DataDes.dataType];
u64 count = opParam.DataDes.count;
dataSize = count * unitSize;
}
return;
}
/**
 * @brief Splits @p str on every occurrence of the delimiter @p c.
 *
 * Empty pieces before a delimiter are kept (",a" -> {"", "a"}, "a,,b" -> {"a", "", "b"}), but nothing is
 * emitted after a trailing delimiter ("a," -> {"a"}), and an empty string yields an empty vector.
 *
 * @param str input text.
 * @param c   delimiter character.
 * @return the pieces in order of appearance.
 */
std::vector<std::string> SplitString(const std::string &str, const char c)
{
    std::vector<std::string> pieces;
    std::string::size_type tokenBegin = 0;
    for (std::string::size_type delim = str.find(c); delim != std::string::npos; delim = str.find(c, tokenBegin)) {
        pieces.emplace_back(str, tokenBegin, delim - tokenBegin);
        tokenBegin = delim + 1;
    }
    // Keep the tail only when something follows the last delimiter.
    if (tokenBegin < str.size()) {
        pieces.emplace_back(str, tokenBegin);
    }
    return pieces;
}
/**
 * @brief Compares the sizes of two data slices.
 * @param a first slice (must be non-null).
 * @param b second slice (must be non-null).
 * @return true when a->GetSize() equals b->GetSize().
 */
bool DataSliceSizeIsEqual(std::unique_ptr<DataSlice> &a, std::unique_ptr<DataSlice> &b)
{
    const auto lhsSize = a->GetSize();
    const auto rhsSize = b->GetSize();
    return lhsSize == rhsSize;
}
/**
 * @brief Checks that three data slices all have the same size.
 *
 * Compares a with b first; c is only queried when that comparison succeeds.
 *
 * @param a first slice (must be non-null).
 * @param b second slice (must be non-null).
 * @param c third slice (must be non-null when a and b match).
 * @return true when all three GetSize() values are equal.
 */
bool DataSliceSizeIsEqual(std::unique_ptr<DataSlice> &a, std::unique_ptr<DataSlice> &b, std::unique_ptr<DataSlice> &c)
{
    if (a->GetSize() != b->GetSize()) {
        return false;
    }
    return b->GetSize() == c->GetSize();
}
}