/*
* -------------------------------------------------------------------------
*  This file is part of the Vision SDK project.
* Copyright (c) 2025 Huawei Technologies Co.,Ltd.
*
* Vision SDK is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
*           http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* -------------------------------------------------------------------------
 * Description: Rotate operator host file.
 * Author: MindX SDK
 * Create: 2024
 * History: NA
 */

#include "rotate_tiling.h"
#include "register/op_def_registry.h"

using namespace std;

namespace optiling {

constexpr uint32_t BLOCK_DIM = 8;
constexpr uint64_t TILING_KEY_HALF = 1;
constexpr uint64_t TILING_KEY_FLOAT = 2;
constexpr uint64_t MIN_SHAPE_SIZE = 2;

static ge::graphStatus TilingFunc(gert::TilingContext* context)
{
    if (context == nullptr) {
        return ge::GRAPH_FAILED;
    }
    RotateTilingData tiling;
    auto tensorX = context->GetInputTensor(0);
    if (tensorX == nullptr) {
        return ge::GRAPH_FAILED;
    }
    uint32_t totalLength = tensorX->GetShapeSize();
    tiling.set_size(totalLength);

    if (context->GetInputShape(0) == nullptr || context->GetInputShape(1) == nullptr) {
        return ge::GRAPH_FAILED;
    }
    auto srcShape = context->GetInputShape(0)->GetStorageShape();
    auto offsetShape = context->GetInputShape(1)->GetStorageShape();
    if (srcShape.GetDimNum() < MIN_SHAPE_SIZE || offsetShape.GetDimNum() < MIN_SHAPE_SIZE) {
        return ge::GRAPH_FAILED;
    }
    tiling.set_height(srcShape[0]);
    tiling.set_width(srcShape[1]);
    tiling.set_offsetHeight(offsetShape[0]);
    tiling.set_offsetWidth(offsetShape[1]);

    auto attrs = context->GetAttrs();
    if (attrs == nullptr) {
        return ge::GRAPH_FAILED;
    }
    const int* angle = attrs->GetAttrPointer<int>(0);
    const int* needBlockNum = attrs->GetAttrPointer<int>(1);
    if (angle == nullptr || needBlockNum == nullptr) {
        return ge::GRAPH_FAILED;
    }
    tiling.set_angle(*angle);
    tiling.set_needBlockNum(*needBlockNum);

    uint64_t tilingKey = 0;
    auto xDataType = tensorX->GetDataType();
    if (xDataType == ge::DT_FLOAT16) {
        tilingKey = TILING_KEY_HALF;
    } else if (xDataType == ge::DT_FLOAT) {
        tilingKey = TILING_KEY_FLOAT;
    }
    context->SetTilingKey(tilingKey);

    context->SetBlockDim(BLOCK_DIM);
    if (context->GetRawTilingData() == nullptr) {
        return ge::GRAPH_FAILED;
    }
    tiling.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity());
    context->GetRawTilingData()->SetDataSize(tiling.GetDataSize());

    return ge::GRAPH_SUCCESS;
}
}


namespace ge {
static ge::graphStatus InferShape(gert::InferShapeContext* context)
{
    if (context == nullptr) {
        return ge::GRAPH_FAILED;
    }
    const gert::Shape* x1_shape = context->GetInputShape(0);
    gert::Shape* y_shape = context->GetOutputShape(0);
    if (x1_shape == nullptr || y_shape == nullptr) {
        return ge::GRAPH_FAILED;
    }
    
    *y_shape = *x1_shape;
    return GRAPH_SUCCESS;
}
}


namespace ops {
class Rotate : public OpDef {
public:
    explicit Rotate(const char* name) : OpDef(name)
    {
        this->Input("x")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT, ge::DT_FLOAT16})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
        this->Input("y")
            .ParamType(REQUIRED)
            .DataType({ge::DT_UINT32, ge::DT_UINT32})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
        this->Output("z")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT, ge::DT_FLOAT16})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});

        this->Attr("angle").AttrType(REQUIRED).Int();
        this->Attr("needBlockNum").AttrType(REQUIRED).Int();

        this->SetInferShape(ge::InferShape);

        this->AICore()
            .SetTiling(optiling::TilingFunc);
        this->AICore().AddConfig("ascend310p");
    }
};

OP_ADD(Rotate);
}