/* -------------------------------------------------------------------------

 *  This file is part of the MindStudio project.

 * Copyright (c) 2026 Huawei Technologies Co.,Ltd.

 *

 * MindStudio is licensed under Mulan PSL v2.

 * You can use this software according to the terms and conditions of the Mulan PSL v2.

 * You may obtain a copy of Mulan PSL v2 at:

 *

 *          http://license.coscl.org.cn/MulanPSL2

 *

 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,

 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,

 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.

 * See the Mulan PSL v2 for more details.

 * ------------------------------------------------------------------------- */



#include <cmath>
#include <cstdint>
#include <iostream>
#include <limits>

#include "nan_test_tiling.h"
#include "register/op_def_registry.h"
#include "register/op_impl_registry.h"
#include "register/tilingdata_base.h"
#include "tiling/tiling_api.h"



namespace optiling
{
/**
 * Tiling entry point for the NanTest operator.
 *
 * Computes the element count of input 0 from its storage shape and stores it in
 * NanTestTilingData::size. It also requests one workspace buffer sized by
 * PlatformAscendC::GetLibApiWorkSpaceSize(), sets the block dim to 1, and
 * serializes the tiling data into the context's raw tiling buffer.
 *
 * @param context tiling context supplied by the framework.
 * @return ge::GRAPH_SUCCESS on success. Returns ge::GRAPH_FAILED when a required
 *         context pointer is null, when the element count is < 1, or when the
 *         count does not fit in int32_t.
 */
static ge::graphStatus TilingFunc(gert::TilingContext* context)
{
    if (context == nullptr) {
        return ge::GRAPH_FAILED;
    }

    NanTestTilingData tiling;

    // The element count is taken from input 0 ("in").
    const gert::StorageShape* inShape = context->GetInputShape(0);
    if (inShape == nullptr) {
        std::cout << "ERROR! INPUT SHAPE IS NULL" << std::endl;
        return ge::GRAPH_FAILED;
    }

    // The tiling field is filled from an int32_t. GetDim() yields 64-bit values,
    // so the product is built in int64_t and checked against INT32_MAX before each
    // multiply. That rules out both a truncating narrowing and signed overflow.
    constexpr int64_t kMaxElemCount = std::numeric_limits<int32_t>::max();
    const auto& storageShape = inShape->GetStorageShape();
    int64_t elemCount = 1;
    for (size_t i = 0; i < storageShape.GetDimNum(); ++i) {
        const int64_t dim = storageShape.GetDim(i);
        if (dim <= 0) {
            // A zero dimension means the tensor holds no elements. A negative
            // dimension is treated the same way as a defensive measure.
            elemCount = 0;
            break;
        }
        if (elemCount > kMaxElemCount / dim) {
            std::cout << "ERROR! DATA SIZE EXCEEDS INT32 RANGE" << std::endl;
            return ge::GRAPH_FAILED;
        }
        elemCount *= dim;
    }

    if (elemCount < 1) {
        // Previously this was only logged and tiling still reported success.
        std::cout << "ERROR! DATA IS NOT ENOUGH, SIZE < 1" << std::endl;
        return ge::GRAPH_FAILED;
    }
    tiling.set_size(static_cast<int32_t>(elemCount));

    // Reserve the workspace required by the Ascend C library APIs.
    const auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo());
    size_t* workspaceSizes = context->GetWorkspaceSizes(1);
    if (workspaceSizes == nullptr) {
        std::cout << "ERROR! WORKSPACE SIZES IS NULL" << std::endl;
        return ge::GRAPH_FAILED;
    }
    workspaceSizes[0] = ascendcPlatform.GetLibApiWorkSpaceSize();

    context->SetBlockDim(1);

    auto* rawTilingData = context->GetRawTilingData();
    if (rawTilingData == nullptr) {
        std::cout << "ERROR! RAW TILING DATA IS NULL" << std::endl;
        return ge::GRAPH_FAILED;
    }
    tiling.SaveToBuffer(rawTilingData->GetData(), rawTilingData->GetCapacity());
    rawTilingData->SetDataSize(tiling.GetDataSize());

    return ge::GRAPH_SUCCESS;
}
}  // namespace optiling



namespace ge
{
/**
 * Shape-inference hook registered for NanTest.
 *
 * This is a no-op that always reports success. It does not read the context
 * and does not write any output shape.
 * NOTE(review): "out" therefore gets no shape from this hook. Confirm this is
 * intentional before relying on graph-mode shape inference.
 */
static ge::graphStatus InferShape(gert::InferShapeContext* context)
{
    static_cast<void>(context);  // intentionally unused
    return ge::GRAPH_SUCCESS;
}

/**
 * Data-type-inference hook registered for NanTest.
 *
 * This is a no-op that always reports success without setting any output data type.
 * NOTE(review): the OpDef declares "out" as INT64 only. Confirm that no explicit
 * propagation is needed here.
 */
static ge::graphStatus InferDataType(gert::InferDataTypeContext* context)
{
    static_cast<void>(context);  // intentionally unused
    return ge::GRAPH_SUCCESS;
}
}  // namespace ge



namespace ops

{

class NanTest : public OpDef

{

   public:

    explicit NanTest(const char* name) : OpDef(name)

    {

        this->Input("in")

            .ParamType(REQUIRED)

            .DataType({ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64,

                       ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64,

                       ge::DT_INT64})

            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,

                     ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,

                     ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})

            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,

                                 ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,

                                 ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})

            .AutoContiguous();

        this->Input("tensorlist")

            .ParamType(DYNAMIC)

            .DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_INT8, ge::DT_INT16, ge::DT_INT32, ge::DT_INT64,

                       ge::DT_UINT8, ge::DT_UINT16, ge::DT_UINT32, ge::DT_UINT64, ge::DT_BOOL, ge::DT_INT4, ge::DT_BF16,

                       ge::DT_HIFLOAT8, ge::DT_DOUBLE})

            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,

                     ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,

                     ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})

            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,

                                 ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,

                                 ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})

            .AutoContiguous();

        this->Output("out")

            .ParamType(REQUIRED)

            .DataType({ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64,

                       ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64,

                       ge::DT_INT64})

            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,

                     ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,

                     ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})

            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,

                                 ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,

                                 ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});



        this->SetInferShape(ge::InferShape).SetInferDataType(ge::InferDataType);



        this->AICore().SetTiling(optiling::TilingFunc);

        this->AICore().AddConfig("@NAN_TEST_SOC_CONFIG@");

    }

};



OP_ADD(NanTest);

}  // namespace ops