/**

 * Copyright (c) 2025 Huawei Technologies Co., Ltd.

 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of 

 * CANN Open Software License Agreement Version 2.0 (the "License").

 * Please refer to the License for details. You may not use this file except in compliance with the License.

 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, 

 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.

 * See LICENSE in the root of the software repository for the full text of the License.

 */



#include "autofuse_stub.h"

#include "platform/platform_infos_def.h"

#include "exe_graph/runtime/tiling_context.h"

#include "exe_graph/runtime/infer_shape_context.h"



using namespace gert;

using namespace ge;



/**
 * Tiling context view used by the symbolic-evaluation stub functions in this file.
 *
 * Data-input tensors are looked up with a +1 offset. TilingFunc in this file reads input 0 as
 * the number of data inputs and places the tiling-parse data after them, so data tensors
 * presumably occupy inputs 1..N.
 *
 * NOTE(review): GetInputTensor here hides the base-class method of the same name (it is not an
 * override); callers going through a TilingContext pointer get the base behavior instead.
 */
class TilingSymbolEvalContext : public TilingContext {
 public:
  /**
   * @param data_index zero-based index of the data input.
   * @return the tensor stored at kernel input (data_index + 1), or nullptr if none is present.
   */
  const gert::Tensor *GetInputTensor(size_t data_index) const {
    return GetInputPointer<gert::Tensor>(data_index + 1);
  }
};



/**
 * Infer-shape context view used by the symbolic-evaluation stub functions in this file.
 *
 * Data-input tensors are looked up with a +1 offset, mirroring TilingSymbolEvalContext.
 * NOTE(review): assumes kernel input 0 is reserved by the runtime here as well; confirm
 * against the graph builder.
 */
class InferShapeSymbolEvalContext : public InferShapeContext {
 public:
  /**
   * @param data_index zero-based index of the data input.
   * @return the tensor stored at kernel input (data_index + 1), or nullptr if none is present.
   */
  const gert::Tensor *GetInputTensor(size_t data_index) const {
    const size_t kernel_input_index = data_index + 1;
    return GetInputPointer<gert::Tensor>(kernel_input_index);
  }
};



/**
 * Kernel context view handed to TilingParse; exposes the platform information input.
 */
class SymbolTilingParseContext : public KernelContext {
 public:
  /**
   * @return the platform-info pointer stored as kernel input 0 (may be nullptr if unset).
   */
  fe::PlatFormInfos *GetPlatFormInfos() {
    constexpr size_t kPlatformInfoInputIndex = 0;
    return GetInputValue<fe::PlatFormInfos *>(kPlatformInfoInputIndex);
  }
};



/**
 * Stub tiling function for the symbolic autofuse kernel.
 *
 * Appends the int64 values 1, 2 and 3 to the tiling data buffer, sets the block dim to the
 * `aiv_num` recorded in the tiling-parse data, and requests a 1024-entry value for workspace 0.
 *
 * Kernel input layout read here: input 0 holds the number of data inputs (N); input N + 1 holds
 * the AfTilingParseData pointer produced by TilingParse.
 *
 * @param context tiling context supplied by the runtime.
 * @return 0 (GRAPH_SUCCESS) on success; GRAPH_FAILED if a required pointer is missing or the
 *         tiling data buffer cannot hold the appended values.
 */
uint32_t TilingFunc(TilingSymbolEvalContext *context) {
  if (context == nullptr) {
    return GRAPH_FAILED;
  }
  // Raw KernelContext accessors are reached via reinterpret_cast, as elsewhere in this file.
  auto *kernel_context = reinterpret_cast<KernelContext *>(context);

  auto *tiling_data_ptr = kernel_context->GetOutputPointer<TilingData *>(TilingContext::kOutputTilingData);
  if (tiling_data_ptr == nullptr || *tiling_data_ptr == nullptr) {
    return GRAPH_FAILED;
  }
  const int64_t tiling_values[] = {1L, 2L, 3L};
  for (const int64_t value : tiling_values) {
    // Append reports failure when the buffer is too small; do not silently truncate.
    if ((*tiling_data_ptr)->Append(value) != GRAPH_SUCCESS) {
      return GRAPH_FAILED;
    }
  }

  const auto input_data_num = kernel_context->GetInputValue<size_t>(0);
  auto *tiling_parser = kernel_context->GetInputValue<AfTilingParseData *>(input_data_num + 1);
  if (tiling_parser == nullptr) {
    return GRAPH_FAILED;
  }
  context->SetBlockDim(tiling_parser->aiv_num);

  auto *workspace_sizes = context->GetWorkspaceSizes(1);
  if (workspace_sizes == nullptr) {
    return GRAPH_FAILED;
  }
  *workspace_sizes = 1024;
  return GRAPH_SUCCESS;
}



/**
 * Returns the tiling data size requested by this stub: 128 (presumably bytes; confirm against
 * the caller). TilingFunc appends three int64_t values, which fit within it.
 */
size_t GetTilingDataSize() {
  constexpr size_t kStubTilingDataSize = 128U;
  return kStubTilingDataSize;
}



/**
 * Stub shape inference for the symbolic autofuse kernel.
 *
 * Reads symbolic dims from the origin shapes of data inputs 0-3 and writes a rank-3 output shape
 * [s4 + s7, s5 * s6, s0 + s2] to output 0. A missing input contributes -1 for each of its dims.
 *
 * @param context infer-shape context supplied by the runtime.
 * @return GRAPH_SUCCESS on success; GRAPH_FAILED if the context or output shape is missing.
 */
graphStatus InferShape(InferShapeSymbolEvalContext *context) {
  if (context == nullptr) {
    return GRAPH_FAILED;
  }
  // Dim `axis` of the origin shape of data input `input_index`, or -1 if that input is absent.
  const auto origin_dim = [context](size_t input_index, size_t axis) -> int64_t {
    const auto *tensor = context->GetInputTensor(input_index);
    if (tensor == nullptr) {
      return -1;
    }
    return tensor->GetOriginShape().GetDim(axis);
  };

  const int64_t s0 = origin_dim(0U, 0U);
  const int64_t s2 = origin_dim(1U, 0U);
  const int64_t s4 = origin_dim(2U, 0U);
  const int64_t s5 = origin_dim(2U, 1U);
  const int64_t s6 = origin_dim(3U, 0U);
  const int64_t s7 = origin_dim(3U, 1U);

  auto *output_shape = context->GetOutputShape(0);
  if (output_shape == nullptr) {
    return GRAPH_FAILED;
  }
  output_shape->SetDimNum(0);
  output_shape->AppendDim(s4 + s7);
  output_shape->AppendDim(s5 * s6);
  output_shape->AppendDim(s0 + s2);

  std::cout << "InferShape success from stub" << std::endl;
  return GRAPH_SUCCESS;
}



/**
 * Fills output 0 with the symbol values that form the tiling cache key.
 *
 * Writes three int64 values: dims 0 and 1 of data input 0, and dim 0 of data input 1 (origin
 * shapes). A missing input contributes -1. The vector size is then set to 3.
 *
 * @param context tiling context supplied by the runtime.
 * @return GRAPH_SUCCESS on success; GRAPH_FAILED if the context or output vector is missing, or
 *         the output vector cannot hold three elements.
 */
graphStatus GetSymbolTilingCacheKey(TilingSymbolEvalContext *context) {
  if (context == nullptr) {
    return GRAPH_FAILED;
  }
  constexpr size_t kSymbolNum = 3U;
  auto *kernel_context = reinterpret_cast<KernelContext *>(context);
  auto *symbol_source_vector = kernel_context->GetOutputPointer<TypedContinuousVector<int64_t>>(0U);
  // Guard the raw writes below against a missing or undersized output buffer.
  if (symbol_source_vector == nullptr || symbol_source_vector->GetCapacity() < kSymbolNum) {
    return GRAPH_FAILED;
  }

  // Dim `axis` of the origin shape of data input `input_index`, or -1 if that input is absent.
  const auto origin_dim = [context](size_t input_index, size_t axis) -> int64_t {
    const auto *tensor = context->GetInputTensor(input_index);
    if (tensor == nullptr) {
      return -1;
    }
    return tensor->GetOriginShape().GetDim(axis);
  };

  int64_t *symbols = symbol_source_vector->MutableData();
  symbols[0] = origin_dim(0U, 0U);  // s0
  symbols[1] = origin_dim(0U, 1U);  // s1
  symbols[2] = origin_dim(1U, 0U);  // s2
  symbol_source_vector->SetSize(kSymbolNum);

  return GRAPH_SUCCESS;
}



/**
 * Stub tiling-parse function: allocates an AfTilingParseData buffer, stores it as output 0, and
 * fills `aiv_num` from the platform core count and `ub_size` with a fixed 184 * 1024.
 *
 * The buffer is owned by output 0 via SetWithDefaultDeleter<uint8_t[]> and is zero-initialized,
 * so fields not set here do not hold garbage.
 *
 * @param context tiling-parse context supplied by the runtime.
 * @return GRAPH_SUCCESS on success; GRAPH_FAILED if the context, platform info or output slot is
 *         missing, or the allocation fails.
 */
graphStatus TilingParse(SymbolTilingParseContext *context) {
  if (context == nullptr) {
    return GRAPH_FAILED;
  }
  auto *platform = context->GetPlatFormInfos();
  if (platform == nullptr) {
    return GRAPH_FAILED;
  }

  // SymbolTilingParseContext derives publicly from KernelContext, so no cast is needed.
  KernelContext *kernel_context = context;
  auto *tiling_parse_av = kernel_context->GetOutput(0);
  if (tiling_parse_av == nullptr) {
    return GRAPH_FAILED;
  }
  // Trailing () value-initializes (zeroes) the bytes; nothrow new must be checked for nullptr.
  auto *tiling_parse_data_ptr = new (std::nothrow) uint8_t[sizeof(AfTilingParseData)]();
  if (tiling_parse_data_ptr == nullptr) {
    return GRAPH_FAILED;
  }
  tiling_parse_av->SetWithDefaultDeleter<uint8_t[]>(tiling_parse_data_ptr);

  auto *tiling_parse_data = kernel_context->GetOutputPointer<AfTilingParseData *>(0);
  if (tiling_parse_data == nullptr || *tiling_parse_data == nullptr) {
    return GRAPH_FAILED;
  }
  // Fixed ub_size used by this stub; presumably 184 KiB of unified buffer — confirm per target.
  constexpr auto kStubUbSize = 184 * 1024;
  (*tiling_parse_data)->aiv_num = platform->GetCoreNum();
  (*tiling_parse_data)->ub_size = kStubUbSize;

  return GRAPH_SUCCESS;
}



/**
 * Writes a human-readable summary of the input symbol values into `out_symbol_info`.
 *
 * Output format: "s0: <v>, s1: <v>, s2: <v>." where s0/s1 are dims 0/1 of data input 0 and s2 is
 * dim 0 of data input 1 (origin shapes); a missing input reports -1. The text is truncated to
 * fit `size` bytes including the terminating NUL.
 *
 * @param context tiling context supplied by the runtime.
 * @param out_symbol_info destination buffer; nothing is written when it is nullptr.
 * @param size capacity of `out_symbol_info` in bytes; nothing is written when it is 0.
 * @return GRAPH_SUCCESS on success or when there is no destination buffer; GRAPH_FAILED if
 *         `context` is null (an empty string is written) or the copy fails.
 */
ge::graphStatus DfxInputSymbolInfo(TilingSymbolEvalContext *context, char *out_symbol_info, size_t size)
{
  if (out_symbol_info == nullptr || size == 0) {
    return GRAPH_SUCCESS;
  }
  if (context == nullptr) {
    out_symbol_info[0] = '\0';
    return GRAPH_FAILED;
  }

  // Where each reported symbol comes from: (name, data input index, origin-shape axis).
  struct SymbolSource {
    const char *name;
    size_t input_index;
    size_t axis;
  };
  const SymbolSource sources[] = {{"s0", 0U, 0U}, {"s1", 0U, 1U}, {"s2", 1U, 0U}};

  std::string symbol_info;
  for (const auto &source : sources) {
    const auto *tensor = context->GetInputTensor(source.input_index);
    const int64_t value = (tensor == nullptr) ? -1 : tensor->GetOriginShape().GetDim(source.axis);
    if (!symbol_info.empty()) {
      symbol_info += ", ";
    }
    symbol_info += source.name;
    symbol_info += ": ";
    symbol_info += std::to_string(value);
  }
  symbol_info += ".";

  // Copy at most size - 1 characters; strncpy_s NUL-terminates the destination.
  if (strncpy_s(out_symbol_info, size, symbol_info.c_str(), std::min(symbol_info.size(), size - 1)) != 0) {
    return GRAPH_FAILED;
  }
  return GRAPH_SUCCESS;
}