/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
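/*!
 * \file test_shmem_wait_until.cpp
 * \brief Smoke test that drives ShmemWaitUntilImpl through task preparation, cache loading, enqueue and polling.
 */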
#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <memory>
#include <new>
#include <numeric>
#include <tuple>
#include <utility>
#include <vector>

#include <gtest/gtest.h>

#include "machine/device/distributed/common.h"
#include "machine/device/distributed/shmem_wait_until.h"
#include "machine/utils/dynamic/dev_workspace.h"
#include "machine/runtime/distributed/hccl_context.h"
#include "machine/device/dynamic/aicore_manager.h"
namespace {
/*!
 * \brief Tiling layout of the trailing dimensions of a shmem signal tensor, as produced by
 *        CalculateTileIndexInfo.
 *
 * Scalar members default to 0 so that a default-constructed instance never carries indeterminate values;
 * the type remains an aggregate and can still be brace-initialized member by member.
 */
struct TileIndexInfo {
    uint32_t tileShapeDim{0};               // number of tiled dimensions (tileShape.size())
    uint32_t startDim{0};                   // index of the first tiled dimension inside the signal shape
    std::vector<uint32_t> viewshapes;       // view extent of each tiled dimension
    std::vector<uint32_t> dimTileNums;      // ceil(view extent / tile extent) per tiled dimension
    std::vector<uint32_t> viewTileStrides;  // running product of dimTileNums over the preceding tiled dimensions
    std::vector<uint32_t> viewIndexStrides; // running product of (raw extent / view extent) over preceding dimensions
    uint32_t viewTileNum{0};                // product of all dimTileNums
    uint32_t viewIndexNum{0};               // product of (raw extent / view extent) over the tiled dimensions
    uint32_t totalTileNum{0};               // viewTileNum * viewIndexNum
};
/*!
 * \brief Derives the tiling layout of the trailing tileShape.size() dimensions of a shmem signal tensor.
 *
 * Tiled dimension i corresponds to signal dimension startDim + i, where
 * startDim = shmemSignalRawShape.size() - tileShape.size().
 *
 * \param shmemSignalRawShape Raw extent of every signal dimension.
 * \param shmemSignalShape    View extent of every signal dimension (indexed like shmemSignalRawShape).
 * \param tileShape           Tile extent of each tiled (trailing) dimension.
 * \return The populated TileIndexInfo. An empty tileShape yields empty vectors and all three counts equal to 1.
 * \throws std::out_of_range (raised by std::vector::at) when a tiled dimension lies outside either shape
 *         vector, e.g. when tileShape has more entries than shmemSignalRawShape.
 * \note The tile extents and the view extents of the tiled dimensions are used as divisors and must be
 *       non-zero.
 */
TileIndexInfo CalculateTileIndexInfo(
    const std::vector<uint32_t>& shmemSignalRawShape, const std::vector<uint32_t>& shmemSignalShape,
    const std::vector<uint32_t>& tileShape)
{
    TileIndexInfo info{};
    info.tileShapeDim = static_cast<uint32_t>(tileShape.size());
    info.startDim = static_cast<uint32_t>(shmemSignalRawShape.size()) - info.tileShapeDim;
    info.viewshapes.assign(info.tileShapeDim, 0);
    info.dimTileNums.assign(info.tileShapeDim, 0);
    // Pre-filling with 1 provides the stride of the first tiled dimension without ever writing to
    // element 0 of an empty vector (the case tileShape.empty()).
    info.viewTileStrides.assign(info.tileShapeDim, 1);
    info.viewIndexStrides.assign(info.tileShapeDim, 1);
    info.viewTileNum = 1;
    info.viewIndexNum = 1;

    uint32_t prevIndexRatio = 1; // raw/view ratio of the previous tiled dimension
    for (uint32_t i = 0; i < info.tileShapeDim; ++i) {
        const uint32_t curDim = info.startDim + i;
        // at() turns a wrapped-around startDim or mismatched shape lengths into an exception
        // instead of an out-of-bounds read.
        const uint32_t rawExtent = shmemSignalRawShape.at(curDim);
        const uint32_t viewExtent = shmemSignalShape.at(curDim);
        const uint32_t tileExtent = tileShape[i];
        const uint32_t indexRatio = rawExtent / viewExtent;

        info.viewshapes[i] = viewExtent;
        info.dimTileNums[i] = (viewExtent + tileExtent - 1) / tileExtent; // ceiling division
        info.viewTileNum *= info.dimTileNums[i];
        info.viewIndexNum *= indexRatio;
        if (i > 0) {
            info.viewTileStrides[i] = info.viewTileStrides[i - 1] * info.dimTileNums[i - 1];
            info.viewIndexStrides[i] = info.viewIndexStrides[i - 1] * prevIndexRatio;
        }
        prevIndexRatio = indexRatio;
    }
    info.totalTileNum = info.viewTileNum * info.viewIndexNum;
    return info;
}
/*!
 * \brief Maps a per-dimension signal offset to a flat tile index using a precomputed TileIndexInfo.
 *
 * For every tiled dimension the offset is split into a view ordinal (offset / view extent) and a tile
 * ordinal inside that view ((offset % view extent) / tile extent). Both are weighted by their strides,
 * and the view part is scaled by tileInfo.viewTileNum before the two are added.
 *
 * \param tileInfo          Layout produced by CalculateTileIndexInfo.
 * \param shmemSignalOffset Offset per signal dimension; entries from tileInfo.startDim onwards are read.
 * \param tileShape         Tile extent of each tiled dimension.
 * \return Flat tile index.
 */
uint32_t CalculateTileIndex(
    const TileIndexInfo& tileInfo, const std::vector<uint32_t>& shmemSignalOffset,
    const std::vector<uint32_t>& tileShape)
{
    uint32_t tileInView = 0;
    uint32_t viewOrdinal = 0;
    const uint32_t dimCount = tileInfo.tileShapeDim;
    for (uint32_t d = 0; d != dimCount; ++d) {
        const uint32_t extent = tileInfo.viewshapes[d];
        const uint32_t offsetInDim = shmemSignalOffset[tileInfo.startDim + d];
        const uint32_t whichView = offsetInDim / extent;
        const uint32_t whichTile = (offsetInDim % extent) / tileShape[d];
        tileInView += whichTile * tileInfo.viewTileStrides[d];
        viewOrdinal += whichView * tileInfo.viewIndexStrides[d];
    }
    return tileInView + viewOrdinal * tileInfo.viewTileNum;
}
/*!
 * \brief Builds a zero-filled signal buffer in which the slot addressed by shmemSignalOffset holds
 *        expectedValue.
 *
 * The buffer has as many elements as the product of shmemSignalRawShape[1..]. The written slot is
 * (totalTileNum * shmemSignalOffset[0] + tileIndex) * shmemSignalStride, where the tile layout is computed
 * with the raw shape used as both raw and view shape.
 *
 * \param shmemSignalRawShape Raw signal shape; must not be empty.
 * \param shmemSignalOffset   Offset per signal dimension.
 * \param tileShape           Tile extent of each trailing dimension.
 * \param shmemSignalStride   Multiplier applied to the flat tile index to obtain the element index.
 * \param expectedValue       Value stored in the addressed slot.
 * \return The initialized buffer.
 * \throws std::out_of_range when shmemSignalOffset is empty or the computed slot lies outside the buffer.
 */
std::vector<int32_t> InitializeShmemSignal(
    const std::vector<uint32_t>& shmemSignalRawShape, const std::vector<uint32_t>& shmemSignalOffset,
    const std::vector<uint32_t>& tileShape, uint32_t shmemSignalStride, int32_t expectedValue)
{
    const auto tileInfo = CalculateTileIndexInfo(shmemSignalRawShape, shmemSignalRawShape, tileShape);
    const uint32_t tileIndex = CalculateTileIndex(tileInfo, shmemSignalOffset, tileShape);
    // std::accumulate folds in the type of its seed; an int seed would round-trip every partial product
    // through int, so seed with uint32_t to keep the whole fold unsigned.
    const uint32_t size = std::accumulate(
        shmemSignalRawShape.begin() + 1, shmemSignalRawShape.end(), uint32_t{1}, std::multiplies<uint32_t>());
    std::vector<int32_t> shmemSignal(size, 0);
    const uint32_t index = (tileInfo.totalTileNum * shmemSignalOffset.at(0) + tileIndex) * shmemSignalStride;
    // Bounds-checked store: an inconsistent offset/stride combination must not scribble past the buffer.
    shmemSignal.at(index) = expectedValue;
    return shmemSignal;
}
// Number of 32-bit words in the AICPU code array assembled by BuildAicpuCodeData.
constexpr size_t codeSize{35};
/*!
 * \brief Scalar fields that PrepareAicpuCodeParams fills in and BuildAicpuCodeData copies into the AICPU
 *        code array.
 *
 * Every member defaults to 0 so a default-constructed instance never carries indeterminate values; the
 * type remains an aggregate.
 */
struct AicpuCodeParams {
    uint32_t opcode{0};
    uint32_t oOperandTotalParamNum{0}; // parameters per operand * number of output operands
    uint32_t outDim{0};
    uint32_t outAttrOffset{0};
    uint32_t iOperandTotalParamNum{0}; // parameters per operand * number of input operands
    uint32_t predTokenDim{0};
    uint32_t predTokenAttrOffset{0};
    uint32_t shmemSignalDim{0};
    uint32_t shmemSignalShapeNum{0};   // set to shmemSignalDim * 2 (raw shape words followed by view shape words)
    uint32_t attrSize{0};
};
/*!
 * \brief Produces the fixed scalar fields of the AICPU code used by this test.
 *
 * Only attrSize depends on the input: 3 + tileShapeDim + tileShapeDim * 3 + 2.
 *
 * NOTE(review): for a 2-D tile this evaluates to 13, while BuildAicpuCodeData emits 14 words after the
 * attrSize word (the tileShapeDim word included) - confirm whether that word is meant to be counted.
 *
 * \param tileInfo Tiling layout; only tileShapeDim is read.
 * \return The populated AicpuCodeParams.
 */
AicpuCodeParams PrepareAicpuCodeParams(const TileIndexInfo& tileInfo)
{
    constexpr uint32_t kParamsPerOperand = 2;
    constexpr uint32_t kOutputOperandCount = 1;
    constexpr uint32_t kInputOperandCount = 2;
    constexpr uint32_t kSignalDimCount = 5;
    constexpr uint32_t kAllOnes = static_cast<uint32_t>(-1);

    const uint32_t tileDims = tileInfo.tileShapeDim;

    AicpuCodeParams result{};
    result.opcode = kAllOnes;
    result.oOperandTotalParamNum = kParamsPerOperand * kOutputOperandCount;
    result.outDim = 2;
    result.outAttrOffset = kAllOnes;
    result.iOperandTotalParamNum = kParamsPerOperand * kInputOperandCount;
    result.predTokenDim = 2;
    result.predTokenAttrOffset = kAllOnes;
    result.shmemSignalDim = kSignalDimCount;
    result.shmemSignalShapeNum = kSignalDimCount * 2;
    result.attrSize = 3 + tileDims + tileDims * 3 + 2;
    return result;
}
std::array<uint32_t, codeSize> BuildAicpuCodeData(
const AicpuCodeParams& params, const std::vector<uint32_t>& shmemSignalRawShape,
const std::vector<uint32_t>& shmemSignalShape, const TileIndexInfo& tileInfo, uint32_t shmemSignalAttrOffset,
uint32_t shmemSignalStride, int32_t expectedValue, const std::vector<uint32_t>& tileShape)
{
uint32_t resetSignal = 0;
return std::array<uint32_t, codeSize>{
params.opcode,
params.oOperandTotalParamNum,
params.outDim,
params.outAttrOffset,
params.iOperandTotalParamNum,
params.predTokenDim,
params.predTokenAttrOffset,
params.shmemSignalDim,
shmemSignalAttrOffset,
params.shmemSignalShapeNum,
shmemSignalRawShape[0],
shmemSignalRawShape[1],
shmemSignalRawShape[2],
shmemSignalRawShape[3],
shmemSignalRawShape[4],
shmemSignalShape[0],
shmemSignalShape[1],
shmemSignalShape[2],
shmemSignalShape[3],
shmemSignalShape[4],
params.attrSize,
static_cast<uint32_t>(expectedValue),
shmemSignalStride,
resetSignal,
tileInfo.tileShapeDim,
tileShape[0],
tileShape[1],
tileInfo.viewshapes[0],
tileInfo.viewshapes[1],
tileInfo.viewTileStrides[0],
tileInfo.viewTileStrides[1],
tileInfo.viewIndexStrides[0],
tileInfo.viewIndexStrides[1],
tileInfo.viewTileNum,
tileInfo.totalTileNum};
}
/*!
 * \brief Assembles the AICPU code words for the given signal/tile configuration and wraps them in a
 *        DevRelocVector.
 *
 * The view shape is taken to be identical to the raw shape. The vectors are only read, so they are
 * accepted by const reference (like the sibling helpers) and the view shape is an alias rather than a copy.
 *
 * \param shmemSignalRawShape   Raw signal shape (also used as the view shape).
 * \param tileShape             Tile extent of each trailing dimension.
 * \param shmemSignalStride     Stride word stored in the code.
 * \param expectedValue         Expected-value word stored in the code.
 * \param shmemSignalAttrOffset Attribute-offset word stored in the code.
 * \return Tuple of (owning int32_t buffer of codeSize words, DevRelocVector constructed over that buffer).
 *         The buffer must be kept alive for as long as the DevRelocVector is used.
 */
auto InitializeAicpuCode(
    const std::vector<uint32_t>& shmemSignalRawShape, const std::vector<uint32_t>& tileShape,
    uint32_t shmemSignalStride, int32_t expectedValue, uint32_t shmemSignalAttrOffset)
{
    const std::vector<uint32_t>& shmemSignalShape = shmemSignalRawShape;
    const auto tileInfo = CalculateTileIndexInfo(shmemSignalRawShape, shmemSignalShape, tileShape);
    const auto params = PrepareAicpuCodeParams(tileInfo);
    const auto initData = BuildAicpuCodeData(
        params, shmemSignalRawShape, shmemSignalShape, tileInfo, shmemSignalAttrOffset, shmemSignalStride,
        expectedValue, tileShape);

    // The words are produced as uint32_t but consumed as int32_t; the element-wise copy converts them.
    auto data = std::make_unique<int32_t[]>(codeSize);
    std::copy(initData.begin(), initData.end(), data.get());
    npu::tile_fwk::dynamic::DevRelocVector<int32_t> aicpuCode(codeSize, data.get());
    return std::make_tuple(std::move(data), std::move(aicpuCode));
}
/*!
 * \brief Allocates one contiguous buffer holding a DynFuncHeader, a DynFuncData, a DevStartArgsBase and
 *        a single int64_t comm-context slot, and attaches it to the task.
 *
 * The objects are placement-constructed back to back in that order. startArgs->commContexts points at
 * the int64_t slot, funcData->startArgs at the start args, and task->dynFuncDataList at the header.
 *
 * NOTE(review): the back-to-back placement assumes each object's size keeps the following object
 * suitably aligned - TODO confirm for the project types.
 *
 * \param task Task whose dynFuncDataList is pointed at the new header; must not be null.
 * \return Tuple of (owning buffer released with free, pointer to the DynFuncData inside it).
 */
auto InitializeTaskData(npu::tile_fwk::dynamic::DynDeviceTask* task)
{
    const size_t totalSize = sizeof(npu::tile_fwk::DynFuncHeader) + sizeof(npu::tile_fwk::DynFuncData) +
                             sizeof(npu::tile_fwk::DevStartArgsBase) + sizeof(int64_t);
    std::unique_ptr<void, decltype(&free)> buffer(malloc(totalSize), free);
    if (buffer == nullptr) {
        // Placement-new into a null buffer would be undefined behaviour; stop the test run instead.
        abort();
    }

    auto* header = new (buffer.get()) npu::tile_fwk::DynFuncHeader();
    auto* funcData = new (header + 1) npu::tile_fwk::DynFuncData();
    auto* startArgs = new (funcData + 1) npu::tile_fwk::DevStartArgsBase();
    // Value-initialize the slot so it never holds an indeterminate value before it is assigned.
    auto* commContext = new (startArgs + 1) int64_t{0};

    startArgs->commContexts = commContext;
    funcData->startArgs = startArgs;

    task->dynFuncDataList = header;
    task->dynFuncDataList[0].seqNo = 1;
    task->dynFuncDataList[0].funcNum = 1;
    task->dynFuncDataList[0].funcSize = 1u;
    task->dynFuncDataList[0].cceBinary = nullptr;
    return std::make_tuple(std::move(buffer), funcData);
}
/*!
 * \brief Allocates the auxiliary tables referenced by a DynFuncData and wires their addresses into it.
 *
 * Allocated and attached: a 50-entry expression table, an HcclCombinOpParam (rankNum 0, windowsIn[0] set
 * to rawAddr), one raw-tensor address, one raw-tensor descriptor, 17 zeroed op attributes and a single
 * op-attribute offset of 0. The address of the HcclCombinOpParam is stored in
 * funcData->startArgs->commContexts[0], and commGroupNum is set to 1.
 *
 * \param funcData Function data to configure; its startArgs and commContexts must already be valid.
 * \param rawAddr  Address stored in windowsIn[0] of the HcclCombinOpParam.
 * \return Tuple owning every allocation, in the order: expression table, HcclCombinOpParam, raw-tensor
 *         addresses, raw-tensor descriptors, op attributes, op-attribute offsets. funcData keeps raw
 *         pointers into them, so the tuple must outlive any use of funcData.
 */
auto ConfigureFuncData(npu::tile_fwk::DynFuncData* funcData, uint64_t rawAddr)
{
    constexpr size_t exprTblSize = 50;
    constexpr size_t opAttrsLength = 17;

    // Allocate everything first ...
    auto exprTable = std::make_unique<uint64_t[]>(exprTblSize);
    auto commParam = std::make_unique<npu::tile_fwk::HcclCombinOpParam>();
    auto tensorAddrs = std::make_unique<uint64_t[]>(1);
    auto tensorDescs = std::make_unique<npu::tile_fwk::DevRawTensorDesc[]>(1);
    auto attrs = std::make_unique<uint64_t[]>(opAttrsLength);
    auto attrOffsets = std::make_unique<int32_t[]>(1);

    // ... then fill in the contents ...
    commParam->rankNum = 0;
    commParam->windowsIn[0] = rawAddr;
    tensorAddrs[0] = 0;
    tensorDescs[0] = {0, 0};
    for (size_t i = 0; i < opAttrsLength; ++i) {
        attrs[i] = 0;
    }
    attrOffsets[0] = 0;

    // ... and finally publish the raw pointers through funcData.
    funcData->exprTbl = exprTable.get();
    funcData->rawTensorAddr = tensorAddrs.get();
    funcData->rawTensorDesc = tensorDescs.get();
    funcData->startArgs->commContexts[0] = reinterpret_cast<int64_t>(commParam.get());
    funcData->startArgs->commGroupNum = 1;
    funcData->opAttrs = attrs.get();
    funcData->opAtrrOffsets = attrOffsets.get();

    return std::make_tuple(
        std::move(exprTable), std::move(commParam), std::move(tensorAddrs), std::move(tensorDescs),
        std::move(attrs), std::move(attrOffsets));
}
/*!
 * \brief Creates every object the test needs and returns them in one tuple so they share a lifetime.
 *
 * Fixed configuration: raw signal shape {4, 4, 1, 64, 5120}, an all-zero offset, tile shape {1, 5120},
 * stride 8, expected value 8 and attribute offset 0.
 *
 * NOTE(review): the task is constructed from a reference to *allocator, and the order in which a
 * std::tuple destroys its elements is not specified - confirm that DynDeviceTask's destructor does not
 * use the allocator.
 *
 * \return Tuple of (signal buffer, AICPU code storage, workspace allocator, device task,
 *         ShmemWaitUntilImpl, ShmemWaitUntilCache, task data buffer, expression table, HcclCombinOpParam,
 *         raw-tensor addresses, raw-tensor descriptors, op attributes, op-attribute offsets,
 *         DevRelocVector over the AICPU code, DynFuncData pointer).
 */
auto InitializeTestEnvironment()
{
    constexpr uint32_t kWorldSize = 4;
    constexpr uint32_t kSignalDim2 = 1;
    constexpr uint32_t kSignalDim3 = 64;
    constexpr uint32_t kSignalDim4 = 5120;
    constexpr uint32_t kSignalStride = 8;
    constexpr int32_t kExpectedValue = 8;
    constexpr uint32_t kSignalAttrOffset = 0;

    const std::vector<uint32_t> signalRawShape{kWorldSize, kWorldSize, kSignalDim2, kSignalDim3, kSignalDim4};
    const std::vector<uint32_t> signalOffset(signalRawShape.size());
    const std::vector<uint32_t> tileShape{1, kSignalDim4};

    // Signal memory with the addressed slot already holding the expected value.
    std::vector<int32_t> signalBuffer =
        InitializeShmemSignal(signalRawShape, signalOffset, tileShape, kSignalStride, kExpectedValue);
    auto [codeStorage, aicpuCode] =
        InitializeAicpuCode(signalRawShape, tileShape, kSignalStride, kExpectedValue, kSignalAttrOffset);

    auto workspaceAllocator = std::make_unique<npu::tile_fwk::dynamic::DeviceWorkspaceAllocator>();
    auto deviceTask = std::make_unique<npu::tile_fwk::dynamic::DynDeviceTask>(*workspaceAllocator);
    auto waitUntilImpl = std::make_unique<npu::tile_fwk::Distributed::ShmemWaitUntilImpl>();
    auto waitUntilCache = std::make_unique<npu::tile_fwk::Distributed::ShmemWaitUntilCache>();

    auto [taskBuffer, funcData] = InitializeTaskData(deviceTask.get());
    // The signal buffer's address is handed over as an integer; moving the vector below keeps its storage.
    auto [exprTable, commParam, tensorAddrs, tensorDescs, attrs, attrOffsets] =
        ConfigureFuncData(funcData, reinterpret_cast<uint64_t>(signalBuffer.data()));
    deviceTask->shmemWaitUntilCacheBackup = waitUntilCache.get();

    return std::make_tuple(
        std::move(signalBuffer), std::move(codeStorage), std::move(workspaceAllocator), std::move(deviceTask),
        std::move(waitUntilImpl), std::move(waitUntilCache), std::move(taskBuffer), std::move(exprTable),
        std::move(commParam), std::move(tensorAddrs), std::move(tensorDescs), std::move(attrs),
        std::move(attrOffsets), std::move(aicpuCode), funcData);
}
/*!
 * \brief Prepares tileOpCount tasks in the cache and then builds the cache's hash table.
 *
 * ShmemWaitUntilImpl::PrepareTask is invoked once per task, with the task id passed for both its first
 * and fourth argument. Afterwards cache->taskCount is set and BuildHashTable is called.
 *
 * \param tileOpCount     Number of tasks to prepare.
 * \param cache           Cache whose taskArray and taskCount are filled.
 * \param aicpuCode       AICPU code handed to every PrepareTask call.
 * \param funcData        Function data handed to every PrepareTask call.
 * \param hcclContextAddr Comm-context array handed to every PrepareTask call.
 */
void PrepareTasks(
    uint32_t tileOpCount, npu::tile_fwk::Distributed::ShmemWaitUntilCache* cache,
    const npu::tile_fwk::dynamic::DevRelocVector<int32_t>& aicpuCode, npu::tile_fwk::DynFuncData* funcData,
    int64_t* hcclContextAddr)
{
    uint32_t nextTask = 0;
    while (nextTask < tileOpCount) {
        npu::tile_fwk::Distributed::ShmemWaitUntilImpl::PrepareTask(
            nextTask, aicpuCode, cache->taskArray, nextTask, funcData, hcclContextAddr);
        ++nextTask;
    }

    cache->taskCount = tileOpCount;
    npu::tile_fwk::Distributed::ShmemWaitUntilImpl::BuildHashTable(cache, tileOpCount);
}
void RunTests(
uint32_t tileOpCount, npu::tile_fwk::Distributed::ShmemWaitUntilImpl* shmemWaitUntil, uint32_t parallelIdx = 0)
{
TaskStat* taskStat{nullptr};
for (uint32_t taskId = 0; taskId < tileOpCount; ++taskId) {
shmemWaitUntil->EnqueueOp(taskId, parallelIdx, taskStat);
shmemWaitUntil->PollCompleted(nullptr, parallelIdx);
}
}
/*!
 * \brief Runs the whole flow once: build the environment, prepare the tasks, load the cache, then
 *        enqueue and poll every task.
 *
 * Every element of the environment tuple is bound to a name so that all of them stay alive until the
 * function returns, including those that are only referenced through raw pointers.
 * No explicit expectations are checked here; the flow just has to run to completion.
 *
 * \param tileOpCount Number of tile ops (tasks) to prepare and run.
 */
void TestShmemWaitUntil(const uint32_t tileOpCount)
{
    constexpr uint32_t parallelIdx = 0;

    auto
        [signalBuffer, codeStorage, workspaceAllocator, deviceTask, waitUntilImpl, waitUntilCache, taskBuffer,
         exprTable, commParam, tensorAddrs, tensorDescs, attrs, attrOffsets, aicpuCode, funcData] =
            InitializeTestEnvironment();

    PrepareTasks(tileOpCount, waitUntilCache.get(), aicpuCode, funcData, funcData->startArgs->commContexts);
    waitUntilImpl->LoadCache(waitUntilCache.get(), parallelIdx);
    RunTests(tileOpCount, waitUntilImpl.get(), parallelIdx);
}
// Runs the complete wait-until flow for a single tile op. The count is declared as uint32_t to match
// the parameter type of TestShmemWaitUntil and avoid a signed-to-unsigned conversion at the call.
TEST(ShmemWaitUntilTest, BasicFunctionality)
{
    constexpr uint32_t tileOpCount = 1;
    TestShmemWaitUntil(tileOpCount);
}
}