* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file test_operation.cpp
*/
#include "gtest/gtest.h"
#include "tilefwk/tilefwk.h"
#include "interface/inner/tilefwk.h"
#include "interface/function/function.h"
#include "interface/operation/operation.h"
#include "interface/program/program.h"
#include "tilefwk/data_type.h"
#include "tilefwk/platform.h"
using namespace npu::tile_fwk;
class OperationOpsTest : public testing::Test {};
TEST_F(OperationOpsTest, CheckIndexAddParamsInvalid_FP16_Overflow)
{
std::vector<int64_t> selfShape = {10, 10};
std::vector<int64_t> srcShape = {5, 10};
std::vector<int64_t> indicesShape = {5};
int axis = 0;
Tensor self(DT_FP16, selfShape);
Tensor src(DT_FP16, srcShape);
Tensor indices(DT_INT32, indicesShape);
Element alpha(DT_FP16, 65505.0f);
EXPECT_THROW(IndexAdd_(self, src, indices, axis, alpha), std::exception);
}
TEST_F(OperationOpsTest, Range_UnsupportedStartDataType)
{
Element start(DT_INT8, 0);
Element end(DT_INT32, 10);
Element step(DT_INT32, 1);
EXPECT_THROW(Range(start, end, step), std::exception);
}
TEST_F(OperationOpsTest, Range_UnsupportedEndDataType)
{
Element start(DT_INT32, 0);
Element end(DT_INT8, 10);
Element step(DT_INT32, 1);
EXPECT_THROW(Range(start, end, step), std::exception);
}
TEST_F(OperationOpsTest, Range_UnsupportedStepDataType)
{
Element start(DT_INT32, 0);
Element end(DT_INT32, 10);
Element step(DT_INT8, 1);
EXPECT_THROW(Range(start, end, step), std::exception);
}
TEST_F(OperationOpsTest, Range_UnsupportedOutputDataType)
{
Element start(DT_INT64, 0);
Element end(DT_INT64, INT64_MAX);
Element step(DT_INT64, 1);
EXPECT_THROW(Range(start, end, step), std::exception);
}
TEST_F(OperationOpsTest, LogicalNot_UnsupportedDataType)
{
std::vector<int64_t> shape = {10, 10};
Tensor input(DT_INT32, shape);
EXPECT_THROW(LogicalNot(input), std::exception);
}
TEST_F(OperationOpsTest, QuantMX_DefaultRoundDownFp8Output)
{
Platform::Instance().GetSoc().SetNPUArch(NPUArch::DAV_3510);
TileShape::Current().SetVecTile(8, 64);
Tensor input(DT_FP32, {8, 64});
FUNCTION("QuantMXDefaultFp8", {input})
{
auto defaultRes = QuantMX(input);
EXPECT_EQ(std::get<0>(defaultRes).GetDataType(), DT_FP8E4M3);
EXPECT_EQ(std::get<1>(defaultRes).GetDataType(), DT_FP8E8M0);
EXPECT_EQ(std::get<1>(defaultRes).GetShape(), std::vector<int64_t>({8, 1, 2}));
}
Platform::Instance().GetSoc().SetNPUArch(NPUArch::DAV_UNKNOWN);
}
TEST_F(OperationOpsTest, QuantMX_RoundUpFp8AndFp4Output)
{
Platform::Instance().GetSoc().SetNPUArch(NPUArch::DAV_3510);
TileShape::Current().SetVecTile(8, 128);
Tensor fp16Input(DT_FP16, {8, 128});
Tensor fp32Input(DT_FP32, {8, 128});
FUNCTION("QuantMXRoundUp", {fp16Input, fp32Input})
{
auto fp8Res = QuantMX(fp32Input, DT_FP8E4M3, DequantScaleRoundingMode::ROUND_UP, -1, true);
EXPECT_EQ(std::get<0>(fp8Res).GetDataType(), DT_FP8E4M3);
EXPECT_EQ(std::get<1>(fp8Res).GetDataType(), DT_FP8E8M0);
auto fp4Res = QuantMX(fp16Input, DT_FP4_E2M1X2, DequantScaleRoundingMode::ROUND_UP, -1, true);
EXPECT_EQ(std::get<0>(fp4Res).GetDataType(), DT_FP4_E2M1X2);
EXPECT_EQ(std::get<1>(fp4Res).GetDataType(), DT_FP8E8M0);
}
Platform::Instance().GetSoc().SetNPUArch(NPUArch::DAV_UNKNOWN);
}
TEST_F(OperationOpsTest, QuantMX_Fp32ToFp4Unsupported)
{
Platform::Instance().GetSoc().SetNPUArch(NPUArch::DAV_3510);
TileShape::Current().SetVecTile(8, 64);
Tensor input(DT_FP32, {8, 64});
FUNCTION("QuantMXFp32ToFp4Unsupported", {input})
{
EXPECT_THROW(QuantMX(input, DT_FP4_E2M1X2, DequantScaleRoundingMode::ROUND_DOWN, -1, true), std::exception);
EXPECT_THROW(QuantMX(input, DT_FP4_E2M1X2, DequantScaleRoundingMode::ROUND_UP, -1, true), std::exception);
}
Platform::Instance().GetSoc().SetNPUArch(NPUArch::DAV_UNKNOWN);
}