* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "exe_graph/runtime/infer_shape_context.h"
#include "graph/ge_error_codes.h"
#include <gtest/gtest.h>
#include "faker/kernel_run_context_faker.h"
#include "exe_graph/runtime/storage_shape.h"
#include "register/kernel_registry.h"
namespace gert {
// Fixture for the InferShapeContext unit tests in this file. It holds no state
// and defines no SetUp/TearDown of its own.
class InferShapeContextUT : public testing::Test {
};
// GetInputShape(i) must hand back the origin shape of each of the two faked
// inputs, and nullptr for index 2, which is past the inputs that were supplied.
TEST_F(InferShapeContextUT, GetInputShapeOk) {
  gert::StorageShape first_input = {{8, 3, 224, 224}, {8, 1, 224, 224, 16}};
  gert::StorageShape second_input = {{2, 2, 3, 8}, {8, 1, 2, 2, 16}};
  gert::StorageShape output = {{8, 3, 224, 224}, {8, 1, 224, 224, 16}};

  auto holder = InferShapeContextFaker()
                    .IrInputNum(2)
                    .NodeIoNum(2, 1)
                    .NodeInputTd(0, ge::DT_FLOAT16, ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0)
                    .NodeInputTd(1, ge::DT_FLOAT16, ge::FORMAT_HWCN, ge::FORMAT_FRACTAL_Z)
                    .InputShapes({&first_input, &second_input})
                    .OutputShapes({&output})
                    .Build();
  auto ctx = holder.GetContext<InferShapeContext>();
  ASSERT_NE(ctx, nullptr);

  // Fetch each pointer once; the ASSERT guards the dereference that follows.
  const auto *shape0 = ctx->GetInputShape(0);
  ASSERT_NE(shape0, nullptr);
  EXPECT_EQ(*shape0, first_input.GetOriginShape());

  const auto *shape1 = ctx->GetInputShape(1);
  ASSERT_NE(shape1, nullptr);
  EXPECT_EQ(*shape1, second_input.GetOriginShape());

  // Only two inputs were registered, so index 2 has nothing to return.
  EXPECT_EQ(ctx->GetInputShape(2), nullptr);
}
// Exercises IR-index based shape lookup on a context built with
// IrInstanceNum({1, 2, 0, 1}) and four node inputs:
//   - IR index 0 and 3 are read through GetOptionalInputShape,
//   - IR index 1 is read through GetDynamicInputShape with relative index 0 and 1,
//   - IR index 2 (instance count 0) must yield nullptr.
TEST_F(InferShapeContextUT, GetDynamicInputShapeOk) {
  gert::StorageShape shape_a = {{8, 3, 224, 224}, {8, 1, 224, 224, 16}};
  gert::StorageShape shape_b = {{2, 2, 3, 8}, {2, 2, 3, 8}};
  gert::StorageShape shape_c = {{3, 2, 3, 8}, {3, 2, 3, 8}};
  gert::StorageShape shape_d = {{4, 2, 3, 8}, {4, 2, 3, 8}};
  gert::StorageShape output = {{8, 3, 224, 224}, {8, 1, 224, 224, 16}};

  auto holder = InferShapeContextFaker()
                    .IrInstanceNum({1, 2, 0, 1})
                    .NodeIoNum(4, 1)
                    .NodeInputTd(0, ge::DT_FLOAT16, ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0)
                    .NodeInputTd(1, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ)
                    .NodeInputTd(2, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ)
                    .NodeInputTd(3, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ)
                    .InputShapes({&shape_a, &shape_b, &shape_c, &shape_d})
                    .OutputShapes({&output})
                    .Build();
  auto ctx = holder.GetContext<InferShapeContext>();
  ASSERT_NE(ctx, nullptr);

  // IR input 0: single instance.
  const auto *ir0 = ctx->GetOptionalInputShape(0);
  ASSERT_NE(ir0, nullptr);
  EXPECT_EQ(*ir0, shape_a.GetOriginShape());

  // IR input 1: two instances, addressed by their relative index.
  const auto *ir1_first = ctx->GetDynamicInputShape(1, 0);
  ASSERT_NE(ir1_first, nullptr);
  EXPECT_EQ(*ir1_first, shape_b.GetOriginShape());

  const auto *ir1_second = ctx->GetDynamicInputShape(1, 1);
  ASSERT_NE(ir1_second, nullptr);
  EXPECT_EQ(*ir1_second, shape_c.GetOriginShape());

  // IR input 2: zero instances, so there is no shape to return.
  EXPECT_EQ(ctx->GetOptionalInputShape(2), nullptr);

  // IR input 3: single instance, backed by the fourth supplied shape.
  const auto *ir3 = ctx->GetOptionalInputShape(3);
  ASSERT_NE(ir3, nullptr);
  EXPECT_EQ(*ir3, shape_d.GetOriginShape());
}
// GetOutputShape(0) must return the origin shape of the single faked output,
// and GetOutputShape(1) must return nullptr because only one output exists.
TEST_F(InferShapeContextUT, GetOutShapeOk) {
  gert::StorageShape first_input = {{8, 3, 224, 224}, {8, 1, 224, 224, 16}};
  gert::StorageShape second_input = {{2, 2, 3, 8}, {8, 1, 2, 2, 16}};
  gert::StorageShape output = {{8, 3, 224, 224}, {8, 1, 224, 224, 16}};

  auto holder = InferShapeContextFaker()
                    .IrInputNum(2)
                    .NodeIoNum(2, 1)
                    .NodeInputTd(0, ge::DT_FLOAT16, ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0)
                    .NodeInputTd(1, ge::DT_FLOAT16, ge::FORMAT_HWCN, ge::FORMAT_FRACTAL_Z)
                    .InputShapes({&first_input, &second_input})
                    .OutputShapes({&output})
                    .Build();
  auto ctx = holder.GetContext<InferShapeContext>();
  ASSERT_NE(ctx, nullptr);

  // Fetch once; the ASSERT guards the dereference on the next line.
  auto *out0 = ctx->GetOutputShape(0);
  ASSERT_NE(out0, nullptr);
  EXPECT_EQ(*out0, output.GetOriginShape());

  // The node was declared with exactly one output.
  EXPECT_EQ(ctx->GetOutputShape(1), nullptr);
}
// With IrInstanceNum({1, 0}) the second IR input has no instance, so
// GetOptionalInputTensor(1) must report nullptr. The kernel still receives two
// raw inputs (the shape plus a placeholder address), so GetInputTensor(1) is
// expected to be non-null.
TEST_F(InferShapeContextUT, GetOptionalInputTensorFailed_NotSetOptionalInput) {
  gert::StorageShape required_input = {{8, 3, 224, 224}, {8, 1, 224, 224, 16}};
  // Arbitrary non-null placeholder occupying the second kernel input slot.
  auto placeholder_addr = reinterpret_cast<void *>(0x11);

  auto holder = KernelRunContextFaker()
                    .IrInputNum(2)
                    .IrInstanceNum({1, 0})
                    .KernelIONum(2, 0)
                    .NodeIoNum(1, 0)
                    .Inputs({&required_input, placeholder_addr})
                    .Build();
  auto ctx = holder.GetContext<InferShapeContext>();
  ASSERT_NE(ctx, nullptr);

  const auto *shape0 = ctx->GetInputShape(0);
  ASSERT_NE(shape0, nullptr);
  EXPECT_EQ(*shape0, required_input.GetOriginShape());

  // IR-index lookup: the optional input was not provided.
  EXPECT_EQ(ctx->GetOptionalInputTensor(1), nullptr);
  // Raw-index lookup: slot 1 holds the placeholder, so it is non-null.
  EXPECT_NE(ctx->GetInputTensor(1), nullptr);
}
// With IrInstanceNum({1, 1}) the second IR input has one instance, backed by
// in_tensor_2, so GetOptionalInputTensor(1) must be non-null. The kernel gets
// three raw inputs (shape, tensor, placeholder address); GetInputTensor(2) is
// expected to be non-null as well.
TEST_F(InferShapeContextUT, GetOptionalInputTensorOK_SetOptionalInput) {
  gert::StorageShape in_shape1 = {{8, 3, 224, 224}, {8, 1, 224, 224, 16}};
  // Named cast instead of a C-style cast, matching the placeholder address below.
  gert::Tensor in_tensor_2 = {{{1, 16, 256}, {1, 16, 256}},
                              {ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, {}},
                              kOnDeviceHbm,
                              ge::DT_FLOAT16,
                              reinterpret_cast<void *>(0xabc)};
  // Arbitrary non-null placeholder occupying the third kernel input slot.
  auto infer_shape_func_addr = reinterpret_cast<void *>(0x11);

  auto context_holder = KernelRunContextFaker()
                            .IrInputNum(2)
                            .IrInstanceNum({1, 1})
                            .KernelIONum(3, 0)
                            .NodeIoNum(2, 0)
                            .Inputs({&in_shape1, &in_tensor_2, infer_shape_func_addr})
                            .Build();
  auto context = context_holder.GetContext<InferShapeContext>();
  ASSERT_NE(context, nullptr);

  // Fetch once; the ASSERT guards the dereference on the next line.
  const auto *shape0 = context->GetInputShape(0);
  ASSERT_NE(shape0, nullptr);
  EXPECT_EQ(*shape0, in_shape1.GetOriginShape());

  // IR-index lookup: the optional input is present this time.
  EXPECT_NE(context->GetOptionalInputTensor(1), nullptr);
  // Raw-index lookup: slot 2 holds the placeholder address.
  EXPECT_NE(context->GetInputTensor(2), nullptr);
}
// Verifies the required/dynamic shape and tensor accessors on a context built
// with IrInstanceNum({1, 2, 0, 1}) whose four inputs are gert::Tensor objects:
//   - IR index 0 and 3 via GetRequiredInputShape / GetRequiredInputTensor,
//   - IR index 1 via GetDynamicInputShape / GetDynamicInputTensor (relative 0 and 1),
//   - IR index 2 (instance count 0) must yield nullptr.
// Each tensor's origin shape and address are compared with the values it was
// constructed with.
TEST_F(InferShapeContextUT, GetRequiredInputTensorOk) {
  // Typed null address shared by the tensors that carry no data pointer;
  // kept as void * so the constructor argument type is unchanged.
  void *const null_addr = nullptr;

  gert::Tensor in_tensor_1 = {{{8, 3, 224, 224}, {8, 1, 224, 224, 16}},
                              {ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, {}},
                              kOnDeviceHbm,
                              ge::DT_FLOAT16,
                              null_addr};
  gert::Tensor in_tensor_2 = {{{2, 2, 3, 8}, {2, 2, 3, 8}},
                              {ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, {}},
                              kOnDeviceHbm,
                              ge::DT_FLOAT16,
                              null_addr};
  gert::Tensor in_tensor_3 = {{{3, 2, 3, 8}, {3, 2, 3, 8}},
                              {ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, {}},
                              kOnDeviceHbm,
                              ge::DT_FLOAT16,
                              null_addr};
  // Only this tensor carries a distinctive (fake) address, checked at the end.
  gert::Tensor in_tensor_4 = {{{4, 2, 3, 8}, {4, 2, 3, 8}},
                              {ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, {}},
                              kOnDeviceHbm,
                              ge::DT_FLOAT16,
                              reinterpret_cast<void *>(0x12345)};
  gert::Tensor out_tensor = {{{8, 3, 224, 224}, {8, 1, 224, 224, 16}},
                             {ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, {}},
                             kOnDeviceHbm,
                             ge::DT_FLOAT16,
                             null_addr};

  auto context_holder = InferShapeContextFaker()
                            .IrInstanceNum({1, 2, 0, 1})
                            .NodeIoNum(4, 1)
                            .NodeInputTd(0, ge::DT_FLOAT16, ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0)
                            .NodeInputTd(1, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ)
                            .NodeInputTd(2, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ)
                            .NodeInputTd(3, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ)
                            .InputShapes({&in_tensor_1, &in_tensor_2, &in_tensor_3, &in_tensor_4})
                            .OutputShapes({&out_tensor})
                            .Build();
  auto context = context_holder.GetContext<InferShapeContext>();
  ASSERT_NE(context, nullptr);

  // IR input 0: required, single instance, null address.
  const auto *shape0 = context->GetRequiredInputShape(0);
  ASSERT_NE(shape0, nullptr);
  EXPECT_EQ(*shape0, in_tensor_1.GetOriginShape());
  const auto *tensor0 = context->GetRequiredInputTensor(0);
  ASSERT_NE(tensor0, nullptr);
  EXPECT_EQ(tensor0->GetOriginShape(), in_tensor_1.GetOriginShape());
  EXPECT_EQ(tensor0->GetAddr(), nullptr);

  // IR input 1, first instance.
  const auto *shape1_0 = context->GetDynamicInputShape(1, 0);
  ASSERT_NE(shape1_0, nullptr);
  EXPECT_EQ(*shape1_0, in_tensor_2.GetOriginShape());
  const auto *tensor1_0 = context->GetDynamicInputTensor(1, 0);
  ASSERT_NE(tensor1_0, nullptr);
  EXPECT_EQ(tensor1_0->GetOriginShape(), in_tensor_2.GetOriginShape());
  EXPECT_EQ(tensor1_0->GetAddr(), nullptr);

  // IR input 1, second instance.
  const auto *shape1_1 = context->GetDynamicInputShape(1, 1);
  ASSERT_NE(shape1_1, nullptr);
  EXPECT_EQ(*shape1_1, in_tensor_3.GetOriginShape());
  const auto *tensor1_1 = context->GetDynamicInputTensor(1, 1);
  ASSERT_NE(tensor1_1, nullptr);
  EXPECT_EQ(tensor1_1->GetOriginShape(), in_tensor_3.GetOriginShape());
  EXPECT_EQ(tensor1_1->GetAddr(), nullptr);

  // IR input 2: zero instances, nothing to return.
  EXPECT_EQ(context->GetOptionalInputShape(2), nullptr);

  // IR input 3: required, single instance, carries the fake address.
  const auto *shape3 = context->GetRequiredInputShape(3);
  ASSERT_NE(shape3, nullptr);
  EXPECT_EQ(*shape3, in_tensor_4.GetOriginShape());
  const auto *tensor3 = context->GetRequiredInputTensor(3);
  ASSERT_NE(tensor3, nullptr);
  EXPECT_EQ(tensor3->GetOriginShape(), in_tensor_4.GetOriginShape());
  EXPECT_EQ(tensor3->GetAddr(), in_tensor_4.GetAddr());
}
}