* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file test_add_operation.cpp
* \brief
*/
#include "test_operation.h"
using namespace tile_fwk::test_operation;
namespace {
struct ConcatOpFuncArgs : public OpFuncArgs {
ConcatOpFuncArgs(int axis, const std::vector<int64_t>& viewShape, const std::vector<int64_t> tileShape)
: axis_(axis), viewShape_(viewShape), tileShape_(tileShape)
{}
int axis_;
std::vector<int64_t> viewShape_;
std::vector<int64_t> tileShape_;
};
struct ConcatOpMetaData {
explicit ConcatOpMetaData(const OpFunc& opFunc, const nlohmann::json& test_data)
: opFunc_(opFunc), test_data_(test_data)
{}
OpFunc opFunc_;
nlohmann::json test_data_;
};
static std::vector<int64_t> GetConcatViewShape(Tensor tensor, const std::vector<int64_t>& viewShape, int axis)
{
std::vector<int64_t> resultViewShape = {};
for (size_t i = 0; i < viewShape.size(); i++) {
if (static_cast<int>(i) != axis) {
resultViewShape.push_back(std::min(viewShape[i], tensor.GetShape()[i]));
} else {
resultViewShape.push_back(tensor.GetShape()[i]);
}
}
return resultViewShape;
}
static std::vector<SymbolicScalar> GetNoAxisDims(const std::vector<Tensor>& inputs, int axis)
{
Tensor input0 = inputs[0];
std::vector<SymbolicScalar> dimensions = {};
for (size_t i = 0; i < input0.GetShape().size(); i++) {
if (static_cast<int>(i) == axis) {
continue;
}
dimensions.push_back(input0.GetShape()[i]);
}
return dimensions;
}
static std::vector<int64_t> GetNoAxisViewShapes(const std::vector<int64_t>& viewShapes, int axis)
{
std::vector<int64_t> shapes = {};
for (size_t i = 0; i < viewShapes.size(); i++) {
if (static_cast<int>(i) == axis) {
continue;
}
shapes.push_back(viewShapes[i]);
}
return shapes;
}
static std::vector<std::reference_wrapper<const Tensor>> AsRef(const std::vector<Tensor>& tensors)
{
std::vector<std::reference_wrapper<const Tensor>> results;
results.reserve(tensors.size());
for (const Tensor& tensor : tensors) {
results.emplace_back(std::cref(tensor));
}
return results;
}
static void ConcatOperationExeFuncDoubleCut(
const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs, const OpFuncArgs* opArgs)
{
auto inputRefs = AsRef(inputs);
auto args = static_cast<const ConcatOpFuncArgs*>(opArgs);
int axis = 0;
if (args->axis_ < 0) {
axis = args->axis_ + inputs[0].GetShape().size();
} else {
axis = args->axis_;
}
FUNCTION("main", inputRefs, {outputs[0]})
{
std::vector<SymbolicScalar> noConcatAxisDimensions = GetNoAxisDims(inputs, axis);
std::vector<int64_t> noConcatAxisViewShapes = GetNoAxisViewShapes(args->viewShape_, axis);
const int bloop = CeilDiv(noConcatAxisDimensions[0], noConcatAxisViewShapes[0]);
LOOP("LOOP_L1_bIdx", FunctionType::DYNAMIC_LOOP, bIdx, LoopRange(0, bloop, 1))
{
std::vector<SymbolicScalar> indices = {bIdx};
indices.insert(indices.begin() + axis, 0);
std::vector<Tensor> concatTensors = {};
for (size_t tensorId = 0; tensorId < inputs.size(); tensorId++) {
const std::vector<int64_t> tensorViewShape =
GetConcatViewShape(inputs[tensorId], args->viewShape_, axis);
SymbolicScalar validShape0 = std::min(
SymbolicScalar(inputs[tensorId].GetShape()[0]) - indices[0] * tensorViewShape[0],
tensorViewShape[0]);
SymbolicScalar validShape1 = std::min(
SymbolicScalar(inputs[tensorId].GetShape()[1]) - indices[1] * tensorViewShape[1],
tensorViewShape[1]);
auto tileTensor = View(
inputs[tensorId], tensorViewShape, {validShape0, validShape1},
{indices[0] * tensorViewShape[0], indices[1] * tensorViewShape[1]});
concatTensors.push_back(tileTensor);
}
TileShape::Current().SetVecTile(args->tileShape_);
auto res = Cat(concatTensors, axis);
Assemble(res, {indices[0] * args->viewShape_[0], indices[1] * args->viewShape_[1]}, outputs[0]);
}
}
}
static void ConcatOperationExeFuncTripleCut(
const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs, const OpFuncArgs* opArgs)
{
auto inputRefs = AsRef(inputs);
auto args = static_cast<const ConcatOpFuncArgs*>(opArgs);
int axis = 0;
if (args->axis_ < 0) {
axis = args->axis_ + inputs[0].GetShape().size();
} else {
axis = args->axis_;
}
FUNCTION("main", inputRefs, {outputs[0]})
{
std::vector<SymbolicScalar> noConcatAxisDimensions = GetNoAxisDims(inputs, axis);
std::vector<int64_t> noConcatAxisViewShapes = GetNoAxisViewShapes(args->viewShape_, axis);
const int bloop = CeilDiv(noConcatAxisDimensions[0], noConcatAxisViewShapes[0]);
const int sloop = CeilDiv(noConcatAxisDimensions[1], noConcatAxisViewShapes[1]);
LOOP("LOOP_L1_bIdx", FunctionType::DYNAMIC_LOOP, bIdx, LoopRange(0, bloop, 1))
{
LOOP("LOOP_L1_sIdx", FunctionType::DYNAMIC_LOOP, sIdx, LoopRange(0, sloop, 1))
{
std::vector<SymbolicScalar> indices = {bIdx, sIdx};
indices.insert(indices.begin() + axis, 0);
std::vector<Tensor> concatTensors = {};
for (size_t tensorId = 0; tensorId < inputs.size(); tensorId++) {
const std::vector<int64_t> tensorViewShape =
GetConcatViewShape(inputs[tensorId], args->viewShape_, axis);
SymbolicScalar validShape0 = std::min(
SymbolicScalar(inputs[tensorId].GetShape()[0]) - indices[0] * tensorViewShape[0],
tensorViewShape[0]);
SymbolicScalar validShape1 = std::min(
SymbolicScalar(inputs[tensorId].GetShape()[1]) - indices[1] * tensorViewShape[1],
tensorViewShape[1]);
SymbolicScalar validShape2 = std::min(
SymbolicScalar(inputs[tensorId].GetShape()[2]) - indices[2] * tensorViewShape[2],
tensorViewShape[2]);
auto tileTensor = View(
inputs[tensorId], tensorViewShape, {validShape0, validShape1, validShape2},
{
indices[0] * tensorViewShape[0],
indices[1] * tensorViewShape[1],
indices[2] * tensorViewShape[2],
});
concatTensors.push_back(tileTensor);
}
TileShape::Current().SetVecTile(args->tileShape_);
auto res = Cat(concatTensors, axis);
Assemble(
res,
{
indices[0] * args->viewShape_[0],
indices[1] * args->viewShape_[1],
indices[2] * args->viewShape_[2],
},
outputs[0]);
}
}
}
}
static void ConcatOperationExeFuncQuadraticCut(
const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs, const OpFuncArgs* opArgs)
{
auto inputRefs = AsRef(inputs);
int axis = 0;
auto args = static_cast<const ConcatOpFuncArgs*>(opArgs);
if (args->axis_ < 0) {
axis = args->axis_ + inputs[0].GetShape().size();
} else {
axis = args->axis_;
}
FUNCTION("main", inputRefs, {outputs[0]})
{
std::vector<SymbolicScalar> noConcatAxisDimensions = GetNoAxisDims(inputs, axis);
std::vector<int64_t> noConcatAxisViewShapes = GetNoAxisViewShapes(args->viewShape_, axis);
const int bloop = CeilDiv(noConcatAxisDimensions[0], noConcatAxisViewShapes[0]);
const int sloop = CeilDiv(noConcatAxisDimensions[1], noConcatAxisViewShapes[1]);
const int kloop = CeilDiv(noConcatAxisDimensions[2], noConcatAxisViewShapes[2]);
LOOP("LOOP_L1_bIdx", FunctionType::DYNAMIC_LOOP, bIdx, LoopRange(0, bloop, 1))
{
LOOP("LOOP_L1_sIdx", FunctionType::DYNAMIC_LOOP, sIdx, LoopRange(0, sloop, 1))
{
LOOP("LOOP_L1_kIdx", FunctionType::DYNAMIC_LOOP, kIdx, LoopRange(0, kloop, 1))
{
std::vector<SymbolicScalar> indices = {bIdx, sIdx, kIdx};
indices.insert(indices.begin() + axis, 0);
std::vector<Tensor> concatTensors = {};
for (size_t tensorId = 0; tensorId < inputs.size(); tensorId++) {
const std::vector<int64_t> tensorViewShape =
GetConcatViewShape(inputs[tensorId], args->viewShape_, axis);
SymbolicScalar validShape0 = std::min(
SymbolicScalar(inputs[tensorId].GetShape()[0]) - indices[0] * tensorViewShape[0],
tensorViewShape[0]);
SymbolicScalar validShape1 = std::min(
SymbolicScalar(inputs[tensorId].GetShape()[1]) - indices[1] * tensorViewShape[1],
tensorViewShape[1]);
SymbolicScalar validShape2 = std::min(
SymbolicScalar(inputs[tensorId].GetShape()[2]) - indices[2] * tensorViewShape[2],
tensorViewShape[2]);
SymbolicScalar validShape3 = std::min(
SymbolicScalar(inputs[tensorId].GetShape()[3]) - indices[3] * tensorViewShape[3],
tensorViewShape[3]);
auto tileTensor = View(
inputs[tensorId], tensorViewShape, {validShape0, validShape1, validShape2, validShape3},
{
indices[0] * tensorViewShape[0],
indices[1] * tensorViewShape[1],
indices[2] * tensorViewShape[2],
indices[3] * tensorViewShape[3],
});
concatTensors.push_back(tileTensor);
}
TileShape::Current().SetVecTile(args->tileShape_);
auto res = Cat(concatTensors, axis);
Assemble(
res,
{
indices[0] * args->viewShape_[0],
indices[1] * args->viewShape_[1],
indices[2] * args->viewShape_[2],
indices[3] * args->viewShape_[3],
},
outputs[0]);
}
}
}
}
}
class ConcatOperationTest : public npu::tile_fwk::stest::TestSuite_STest_Ops_Aihac_param<ConcatOpMetaData> {};
INSTANTIATE_TEST_SUITE_P(
TestConcat, ConcatOperationTest,
::testing::ValuesIn(GetOpMetaData<ConcatOpMetaData>(
{ConcatOperationExeFuncDoubleCut, ConcatOperationExeFuncTripleCut, ConcatOperationExeFuncQuadraticCut},
"Concat")));
TEST_P(ConcatOperationTest, TestConcat)
{
auto test_data = GetParam().test_data_;
int axis = GetValueByName<int>(test_data, "axis");
auto args = ConcatOpFuncArgs(axis, GetViewShape(test_data), GetTileShape(test_data));
auto testCase = CreateTestCaseDesc<ConcatOpMetaData>(GetParam(), &args);
TestExecutor::runTest(testCase);
}
}