/**
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/* Generated By CANNBot */

/*!
 * \file test_inv_infershape.cpp
 * \brief Inv InferShape UT — verifies y.shape == x.shape across dtypes (fp32/fp16/bf16/int32) and ranks (0..8, including empty tensors); all cases use FORMAT_ND.
 */

#include <gtest/gtest.h>
#include <iostream>
#include "infershape_context_faker.h"
#include "infershape_case_executor.h"

// Fixture shared by all Inv InferShape cases. It holds no state; the
// suite-level hooks only log when the suite starts and finishes.
class InvInfershape : public testing::Test {
protected:
    static void SetUpTestCase() { std::cout << "InvInfershape SetUp" << std::endl; }

    static void TearDownTestCase() { std::cout << "InvInfershape TearDown" << std::endl; }
};

// FP32, rank 1: InferShape must copy the input shape {10} to the output unchanged.
TEST_F(InvInfershape, inv_infershape_1d_fp32_test)
{
    // Expected y shape: identical to x.
    std::vector<std::vector<int64_t>> expectedShapes = {{10}};

    gert::InfershapeContextPara ctxPara(
        "Inv",
        {{{{10}, {10}}, ge::DT_FLOAT, ge::FORMAT_ND}},  // x
        {{{{}, {}}, ge::DT_FLOAT, ge::FORMAT_ND}});     // y: empty placeholder, expected to be filled by InferShape
    ExecuteTestCase(ctxPara, ge::GRAPH_SUCCESS, expectedShapes);
}

// FP32, rank 3: output shape must equal the input shape {2, 3, 4}.
TEST_F(InvInfershape, inv_infershape_3d_fp32_test)
{
    // Expected y shape: identical to x.
    std::vector<std::vector<int64_t>> expectedShapes = {{2, 3, 4}};

    gert::InfershapeContextPara ctxPara(
        "Inv",
        {{{{2, 3, 4}, {2, 3, 4}}, ge::DT_FLOAT, ge::FORMAT_ND}},  // x
        {{{{}, {}}, ge::DT_FLOAT, ge::FORMAT_ND}});               // y: empty placeholder
    ExecuteTestCase(ctxPara, ge::GRAPH_SUCCESS, expectedShapes);
}

// FP16, rank 4: output shape must equal the input shape {1, 64, 2, 64}.
TEST_F(InvInfershape, inv_infershape_4d_fp16_test)
{
    // Expected y shape: identical to x.
    std::vector<std::vector<int64_t>> expectedShapes = {{1, 64, 2, 64}};

    gert::InfershapeContextPara ctxPara(
        "Inv",
        {{{{1, 64, 2, 64}, {1, 64, 2, 64}}, ge::DT_FLOAT16, ge::FORMAT_ND}},  // x
        {{{{}, {}}, ge::DT_FLOAT16, ge::FORMAT_ND}});                         // y: empty placeholder
    ExecuteTestCase(ctxPara, ge::GRAPH_SUCCESS, expectedShapes);
}

// BF16, rank 3: output shape must equal the input shape {2, 3, 4}.
TEST_F(InvInfershape, inv_infershape_bf16_test)
{
    // Expected y shape: identical to x.
    std::vector<std::vector<int64_t>> expectedShapes = {{2, 3, 4}};

    gert::InfershapeContextPara ctxPara(
        "Inv",
        {{{{2, 3, 4}, {2, 3, 4}}, ge::DT_BF16, ge::FORMAT_ND}},  // x
        {{{{}, {}}, ge::DT_BF16, ge::FORMAT_ND}});               // y: empty placeholder
    ExecuteTestCase(ctxPara, ge::GRAPH_SUCCESS, expectedShapes);
}

// FP32, rank 5: output shape must equal the input shape {1, 2, 3, 4, 5}.
TEST_F(InvInfershape, inv_infershape_5d_fp32_test)
{
    // Expected y shape: identical to x.
    std::vector<std::vector<int64_t>> expectedShapes = {{1, 2, 3, 4, 5}};

    gert::InfershapeContextPara ctxPara(
        "Inv",
        {{{{1, 2, 3, 4, 5}, {1, 2, 3, 4, 5}}, ge::DT_FLOAT, ge::FORMAT_ND}},  // x
        {{{{}, {}}, ge::DT_FLOAT, ge::FORMAT_ND}});                           // y: empty placeholder
    ExecuteTestCase(ctxPara, ge::GRAPH_SUCCESS, expectedShapes);
}

// FP32 empty tensor (leading dim is 0): InferShape must still succeed and
// pass the shape {0, 3, 4} through unchanged.
TEST_F(InvInfershape, inv_infershape_empty_tensor_test)
{
    // Expected y shape: identical to x, zero-sized dim included.
    std::vector<std::vector<int64_t>> expectedShapes = {{0, 3, 4}};

    gert::InfershapeContextPara ctxPara(
        "Inv",
        {{{{0, 3, 4}, {0, 3, 4}}, ge::DT_FLOAT, ge::FORMAT_ND}},  // x
        {{{{}, {}}, ge::DT_FLOAT, ge::FORMAT_ND}});               // y: empty placeholder
    ExecuteTestCase(ctxPara, ge::GRAPH_SUCCESS, expectedShapes);
}

// FP32 scalar (rank 0): InferShape copies the empty dim list as-is; the
// expected output is also rank 0, so no reshape happens.
TEST_F(InvInfershape, inv_infershape_scalar_test)
{
    // Expected y shape: one output with no dims.
    std::vector<std::vector<int64_t>> expectedShapes = {{}};

    gert::InfershapeContextPara ctxPara(
        "Inv",
        {{{{}, {}}, ge::DT_FLOAT, ge::FORMAT_ND}},   // x: rank-0
        {{{{}, {}}, ge::DT_FLOAT, ge::FORMAT_ND}});  // y: empty placeholder
    ExecuteTestCase(ctxPara, ge::GRAPH_SUCCESS, expectedShapes);
}

// ====================================================================
// int32-path InferShape (A2 / iteration 1) — int32 → int32 shape passthrough.
// The output dtype follows the input position-by-position via the OpDef
// Input/Output DataType lists; InferShape itself only copies the shape.
// ====================================================================

// INT32, rank 1: output shape must equal the input shape {10}.
TEST_F(InvInfershape, inv_infershape_1d_int32_test)
{
    // Expected y shape: identical to x.
    std::vector<std::vector<int64_t>> expectedShapes = {{10}};

    gert::InfershapeContextPara ctxPara(
        "Inv",
        {{{{10}, {10}}, ge::DT_INT32, ge::FORMAT_ND}},  // x
        {{{{}, {}}, ge::DT_INT32, ge::FORMAT_ND}});     // y: empty placeholder
    ExecuteTestCase(ctxPara, ge::GRAPH_SUCCESS, expectedShapes);
}

// INT32, rank 3: output shape must equal the input shape {2, 3, 4}.
TEST_F(InvInfershape, inv_infershape_3d_int32_test)
{
    // Expected y shape: identical to x.
    std::vector<std::vector<int64_t>> expectedShapes = {{2, 3, 4}};

    gert::InfershapeContextPara ctxPara(
        "Inv",
        {{{{2, 3, 4}, {2, 3, 4}}, ge::DT_INT32, ge::FORMAT_ND}},  // x
        {{{{}, {}}, ge::DT_INT32, ge::FORMAT_ND}});               // y: empty placeholder
    ExecuteTestCase(ctxPara, ge::GRAPH_SUCCESS, expectedShapes);
}

// INT32 empty tensor (leading dim is 0): shape {0, 3, 4} must pass through unchanged.
TEST_F(InvInfershape, inv_infershape_empty_int32_test)
{
    // Expected y shape: identical to x, zero-sized dim included.
    std::vector<std::vector<int64_t>> expectedShapes = {{0, 3, 4}};

    gert::InfershapeContextPara ctxPara(
        "Inv",
        {{{{0, 3, 4}, {0, 3, 4}}, ge::DT_INT32, ge::FORMAT_ND}},  // x
        {{{{}, {}}, ge::DT_INT32, ge::FORMAT_ND}});               // y: empty placeholder
    ExecuteTestCase(ctxPara, ge::GRAPH_SUCCESS, expectedShapes);
}

// INT32 scalar (rank 0): the empty dim list must be copied through as-is.
TEST_F(InvInfershape, inv_infershape_scalar_int32_test)
{
    // Expected y shape: one output with no dims.
    std::vector<std::vector<int64_t>> expectedShapes = {{}};

    gert::InfershapeContextPara ctxPara(
        "Inv",
        {{{{}, {}}, ge::DT_INT32, ge::FORMAT_ND}},   // x: rank-0
        {{{{}, {}}, ge::DT_INT32, ge::FORMAT_ND}});  // y: empty placeholder
    ExecuteTestCase(ctxPara, ge::GRAPH_SUCCESS, expectedShapes);
}

// ====================================================================
// Iteration 3 (A2): coverage wrap-up — fill gaps in the InferShape dtype × rank matrix.
// Already covered above: fp32 (1D/3D/5D + empty + scalar), fp16 (4D), bf16 (3D),
//   int32 (1D/3D + empty + scalar).
// Added below: fp16/bf16 rank-0 scalars, plus the 8D maximum rank for fp16 and
//   int32. Together these sample ranks 0..8 across the supported dtypes to check
//   that InferShape is a pure shape passthrough.
// ====================================================================

// FP16 scalar (rank 0): the empty dim list must be copied through as-is.
TEST_F(InvInfershape, inv_infershape_scalar_fp16_test)
{
    // Expected y shape: one output with no dims.
    std::vector<std::vector<int64_t>> expectedShapes = {{}};

    gert::InfershapeContextPara ctxPara(
        "Inv",
        {{{{}, {}}, ge::DT_FLOAT16, ge::FORMAT_ND}},   // x: rank-0
        {{{{}, {}}, ge::DT_FLOAT16, ge::FORMAT_ND}});  // y: empty placeholder
    ExecuteTestCase(ctxPara, ge::GRAPH_SUCCESS, expectedShapes);
}

// BF16 scalar (rank 0): the empty dim list must be copied through as-is.
TEST_F(InvInfershape, inv_infershape_scalar_bf16_test)
{
    // Expected y shape: one output with no dims.
    std::vector<std::vector<int64_t>> expectedShapes = {{}};

    gert::InfershapeContextPara ctxPara(
        "Inv",
        {{{{}, {}}, ge::DT_BF16, ge::FORMAT_ND}},   // x: rank-0
        {{{{}, {}}, ge::DT_BF16, ge::FORMAT_ND}});  // y: empty placeholder
    ExecuteTestCase(ctxPara, ge::GRAPH_SUCCESS, expectedShapes);
}

// FP16 at rank 8 (the maximum rank exercised here): all eight dims must pass through.
TEST_F(InvInfershape, inv_infershape_8d_fp16_test)
{
    // Expected y shape: identical to x.
    std::vector<std::vector<int64_t>> expectedShapes = {{2, 2, 2, 2, 2, 2, 2, 2}};

    gert::InfershapeContextPara ctxPara(
        "Inv",
        {{{{2, 2, 2, 2, 2, 2, 2, 2}, {2, 2, 2, 2, 2, 2, 2, 2}}, ge::DT_FLOAT16, ge::FORMAT_ND}},  // x
        {{{{}, {}}, ge::DT_FLOAT16, ge::FORMAT_ND}});  // y: empty placeholder
    ExecuteTestCase(ctxPara, ge::GRAPH_SUCCESS, expectedShapes);
}

// INT32 at rank 8 (the maximum rank exercised here): all eight dims must pass through.
TEST_F(InvInfershape, inv_infershape_8d_int32_test)
{
    // Expected y shape: identical to x.
    std::vector<std::vector<int64_t>> expectedShapes = {{2, 2, 2, 2, 2, 2, 2, 2}};

    gert::InfershapeContextPara ctxPara(
        "Inv",
        {{{{2, 2, 2, 2, 2, 2, 2, 2}, {2, 2, 2, 2, 2, 2, 2, 2}}, ge::DT_INT32, ge::FORMAT_ND}},  // x
        {{{{}, {}}, ge::DT_INT32, ge::FORMAT_ND}});  // y: empty placeholder
    ExecuteTestCase(ctxPara, ge::GRAPH_SUCCESS, expectedShapes);
}