* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file test_all_gather.cpp
* \brief
*/
#include "distributed_op_test_common.h"
#include "tilefwk/tilefwk.h"
#include "interface/inner/tilefwk.h"
#include "interface/configs/config_manager.h"
#include "tilefwk/data_type.h"
#include "test_dev_func_runner.h"
namespace npu::tile_fwk {
namespace Distributed {
template <typename T>
void TestAllGather(OpTestParam& testParam, std::string& goldenDir)
{
constexpr size_t paramsSize = 7;
auto [row, col, validRow, validCol, typeNum, tileRow, tileCol] = GetParams<paramsSize>(goldenDir + "/params.bin");
DataType dType = GetDataTypeNum(typeNum);
Shape shape{row, col};
Shape outShape{testParam.rankSize * row, col};
Tensor in(dType, shape, "in");
Tensor out(dType, outShape, "out");
std::vector<T> inPtr =
ReadToVector<T>(goldenDir + "/input_rank_" + std::to_string(testParam.rankId) + ".bin", shape);
Shape shmemDataShape{testParam.rankSize * row, col};
FUNCTION("ALLGATHER", {in}, {out})
{
in.GetStorage()->UpdateDynValidShape(std::vector<SymbolicScalar>{validRow, validCol});
TileShape::Current().SetVecTile({tileRow, tileCol});
ShmemTensor shmemTensor;
LOOP("CreateShmemTensor", FunctionType::DYNAMIC_LOOP, index, LoopRange(1))
{
(void)index;
CreateShmemTensor(testParam.group, testParam.rankSize, dType, shmemDataShape, shmemTensor);
}
AllGather(in, in, shmemTensor, out);
}
ProgramData::GetInstance().AppendInputs({RawTensorData::CreateTensor<T>(in, inPtr)});
ProgramData::GetInstance().AppendOutputs({RawTensorData::CreateTensorZero(out)});
RunTest();
auto outPtr = ProgramData::GetInstance().GetOutputData(0)->GetDevPtr();
int32_t outSize = row * col * testParam.rankSize;
EXPECT_TRUE(CompareWithGolden<uint8_t*>(dType, goldenDir + "/allgather_out_rank_", outSize, outPtr, testParam));
}
template void TestAllGather<int32_t>(OpTestParam& testParam, std::string& goldenDir);
template void TestAllGather<float>(OpTestParam& testParam, std::string& goldenDir);
template void TestAllGather<float16>(OpTestParam& testParam, std::string& goldenDir);
template void TestAllGather<bfloat16>(OpTestParam& testParam, std::string& goldenDir);
}
}