/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 * 
 * The code snippet comes from Ascend project.
 * 
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 * 
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef INC_EXTERNAL_GRAPH_TENSOR_H_
#define INC_EXTERNAL_GRAPH_TENSOR_H_

#include <atomic>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "./ge_error_codes.h"
#include "./types.h"
#include "ascend_string.h"

namespace ge {
class ShapeImpl;
/**
 * @brief Public handle describing the dimensions (dims) of a tensor.
 *
 * All state lives in a ShapeImpl held through `impl_` (a std::shared_ptr). No copy
 * constructor / copy assignment is declared, so the implicitly generated ones copy
 * `impl_`: copies of a Shape refer to the same ShapeImpl.
 * NOTE(review): whether SetDim() on one copy is visible through another depends on
 * ShapeImpl, which is not shown here. Confirm before relying on either behavior.
 *
 * The user-declared destructor suppresses the implicit move constructor / move
 * assignment, so "moving" a Shape falls back to copying it (a shared_ptr copy).
 */
class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Shape {
 public:
  // NOTE(review): presumably an empty (dim-less) shape; confirm in ShapeImpl.
  Shape();
  ~Shape() = default;
  // Builds a shape from explicit dim values. Dims of -1 / -2 mark an unknown shape
  // (see GetShapeSize()).
  explicit Shape(const std::vector<int64_t> &dims);

  /**
   * @brief Returns the number of *valid* dims. This is NOT equivalent to
   *        `GetDims().size()`; callers should pick whichever they need.
   *
   * Example: when the dims are [-2] (dim count unknown):
   *   - GetDimNum() returns 0;
   *   - GetDims().size() returns the number of stored dim values, i.e. 1.
   * To check whether the shape is a scalar, prefer `GetDims().size() == 0U`.
   * @return the number of valid dims
   */
  size_t GetDimNum() const;
  // Returns the dim at position `idx`. If the idx is invalid, returns 0.
  int64_t GetDim(size_t idx) const;
  // Sets the dim at position `idx` to `value`.
  // NOTE(review): presumably returns a failure status for an invalid idx; confirm.
  graphStatus SetDim(size_t idx, int64_t value);
  /**
   * @brief Returns all stored dim values (by value, i.e. a copy).
   *
   * The size of the returned vector is NOT equivalent to GetDimNum(); callers
   * should pick whichever they need. Example: for dims [-2] (dim count unknown),
   * GetDimNum() returns 0 while GetDims().size() returns 1.
   * To check whether the shape is a scalar, prefer `GetDims().size() == 0U`.
   * @return the dim values
   */
  std::vector<int64_t> GetDims() const;
  /**
   * @brief Returns the product of all dim values.
   * @return
   *   - -1 if any dim is -1 or -2, meaning unknown shape;
   *   - 0 if any dim is 0, meaning an empty tensor;
   *   - 0 if there are no dims, meaning a scalar (note: 0, not 1);
   *   - 0 if the product overflows int64.
   * Because 0 covers the empty-tensor, scalar and overflow cases alike, callers
   * that must tell them apart need to inspect GetDims() themselves.
   */
  int64_t GetShapeSize() const;

 private:
  // Shared implementation; copies of this Shape share it (see class comment).
  std::shared_ptr<ShapeImpl> impl_;
};

class TensorDescImpl;
/**
 * @brief Tensor metadata: shape and origin shape, format and origin format,
 *        data type, name, size, real dim count, placement, optional const data
 *        and an expand-dims rule.
 *
 * State lives in a TensorDescImpl held through `impl` (a std::shared_ptr). The copy
 * and move operations are declared here but defined out of line. Whether a copy
 * shares or duplicates the TensorDescImpl is therefore not visible from this header.
 * NOTE(review): confirm.
 */
class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TensorDesc {
 public:
  TensorDesc();
  ~TensorDesc() = default;
  // Builds a descriptor from a shape; format defaults to FORMAT_ND, data type to DT_FLOAT.
  explicit TensorDesc(Shape shape, Format format = FORMAT_ND, DataType dt = DT_FLOAT);
  // Copy
  TensorDesc(const TensorDesc &desc);
  // Move.
  // NOTE(review): the move operations are not declared noexcept, so containers such as
  // std::vector<TensorDesc> copy rather than move elements when they reallocate.
  // Adding noexcept requires changing the out-of-line definitions as well.
  TensorDesc(TensorDesc &&desc);
  // Copy
  TensorDesc &operator=(const TensorDesc &desc);
  // Move
  TensorDesc &operator=(TensorDesc &&desc);

  // Updates shape, format and data type together; defaults match the constructor's.
  void Update(const Shape &shape, Format format = FORMAT_ND, DataType dt = DT_FLOAT);
  Shape GetShape() const;
  void SetShape(const Shape &shape);
  // Sets the shape to [-2], which stands for a shape whose dim count is unknown.
  graphStatus SetUnknownDimNumShape();
  // Shape range, used for unknown (dynamic) shapes.
  // NOTE(review): presumably one (min, max) pair per dim; confirm.
  graphStatus SetShapeRange(const std::vector<std::pair<int64_t, int64_t>> &range);
  graphStatus GetShapeRange(std::vector<std::pair<int64_t, int64_t>> &range) const;

  Format GetFormat() const;
  void SetFormat(Format format);

  // NOTE(review): "origin" shape/format presumably describe the tensor before any
  // format transformation; confirm.
  Shape GetOriginShape() const;
  void SetOriginShape(const Shape &origin_shape);

  Format GetOriginFormat() const;
  void SetOriginFormat(Format origin_format);

  DataType GetDataType() const;
  void SetDataType(DataType dt);

  // Deprecated: use GetName(AscendString &) instead.
  ATTRIBUTED_DEPRECATED(graphStatus GetName(AscendString &))
  std::string GetName() const;
  graphStatus GetName(AscendString &name);
  graphStatus GetName(AscendString &name) const;

  // Deprecated: use SetName(const char_t *) instead.
  ATTRIBUTED_DEPRECATED(void SetName(const char_t *))
  void SetName(const std::string &name);
  void SetName(const char_t *name);

  // Attr access
  // NOTE(review): the unit of `size` (presumably bytes) is not visible here; confirm.
  void SetSize(int64_t size);
  int64_t GetSize() const;

  int64_t GetRealDimCnt() const;
  void SetRealDimCnt(const int64_t real_dim_cnt);

  void SetPlacement(Placement placement);
  Placement GetPlacement() const;

  // Stores a const data buffer. Ownership of `const_data_buffer` is transferred
  // (std::unique_ptr taken by value); `const_data_len` is its length.
  void SetConstData(std::unique_ptr<uint8_t[]> const_data_buffer, const size_t &const_data_len);
  // Reports the stored const data through the out-params as a raw pointer, so the
  // caller does not receive ownership.
  // NOTE(review): presumably returns false when no const data has been set; confirm.
  bool GetConstData(uint8_t **const_data_buffer, size_t &const_data_len) const;
  /*
   * Dimension expansion is similar to the ExpandDims operator: one or more dims are
   * inserted into the original shape. For example, shape [2,2] has two axes;
   * inserting two dims between them gives [2,1,1,2]. In the expanded shape, axes 0
   * and 3 are called original axes and axes 1 and 2 are called expanded axes.
   *
   * The rule is a string of '1' and '0', one character per axis of the expanded
   * shape, read left to right: '1' marks an expanded axis, '0' an original axis.
   * Examples:
   * | rule     | shape before | shape after                                                     |
   * | -------- | ------------ | --------------------------------------------------------------- |
   * | 0110     | [2, 2]       | [2, 1, 1, 2]                                                    |
   * | 100      | [2, 3]       | [1, 2, 3]                                                       |
   * | 1000     | [2, 3]       | error: the rule specifies 3 original axes but the shape has 2   |
   *
   */
  void SetExpandDimsRule(const AscendString &expand_dims_rule);
  graphStatus GetExpandDimsRule(AscendString &expand_dims_rule) const;

  // NOTE(review): presumably records the index of an input whose memory this tensor
  // may reuse; confirm in the implementation.
  void SetReuseInputIndex(const uint32_t idx);

 private:
  std::shared_ptr<TensorDescImpl> impl;
  // Grants TensorAdapter access to `impl`.
  friend class TensorAdapter;
};

class TensorImpl;
/**
 * @brief A tensor: a TensorDesc plus a data buffer.
 *
 * State lives in a TensorImpl held through `impl` (a std::shared_ptr). No copy
 * constructor / copy assignment is declared, so the implicitly generated ones copy
 * `impl`: copies of a Tensor refer to the same TensorImpl. Use Clone() when an
 * independent tensor is needed (NOTE(review): presumably a deep copy; confirm).
 *
 * The user-declared destructor suppresses the implicit move constructor / move
 * assignment, so "moving" a Tensor falls back to copying it (a shared_ptr copy).
 */
class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Tensor {
 public:
  // Deleter type for data buffers, e.g. the deleter of the buffer returned by ResetData().
  // Requires <functional>.
  using DeleteFunc = std::function<void(uint8_t *)>;
  Tensor();
  ~Tensor() = default;
  // Tensor with the given descriptor; no data is supplied by this constructor.
  explicit Tensor(const TensorDesc &tensor_desc);
  // `data` is taken by const reference. NOTE(review): presumably copied; confirm.
  Tensor(const TensorDesc &tensor_desc, const std::vector<uint8_t> &data);
  // Raw buffer of `size` bytes with no deleter. NOTE(review): presumably copied and not
  // owned by the tensor; confirm.
  Tensor(const TensorDesc &tensor_desc, const uint8_t *data, size_t size);
  // Rvalue overload so that the descriptor and the byte vector can be moved in.
  Tensor(TensorDesc &&tensor_desc, std::vector<uint8_t> &&data);

  // Descriptor returned by value.
  TensorDesc GetTensorDesc() const;
  graphStatus SetTensorDesc(const TensorDesc &tensor_desc);

  // Read-only / mutable access to the data buffer, and its size.
  const uint8_t *GetData() const;
  uint8_t *GetData();
  size_t GetSize() const;
  // Returns the data buffer wrapped in a std::unique_ptr that frees it with a DeleteFunc.
  // NOTE(review): presumably detaches the buffer from this tensor and hands ownership to
  // the caller; confirm.
  std::unique_ptr<uint8_t[], Tensor::DeleteFunc> ResetData();

  graphStatus SetData(std::vector<uint8_t> &&data);
  graphStatus SetData(const std::vector<uint8_t> &data);
  graphStatus SetData(const uint8_t *data, size_t size);
  // Deprecated: use SetData(const char_t *) instead.
  ATTRIBUTED_DEPRECATED(graphStatus SetData(const char_t *data))
  graphStatus SetData(const std::string &data);
  graphStatus SetData(const char_t *data);
  // Deprecated: use SetData(const std::vector<AscendString> &) instead.
  ATTRIBUTED_DEPRECATED(graphStatus SetData(const std::vector<AscendString> &))
  graphStatus SetData(const std::vector<std::string> &data);
  graphStatus SetData(const std::vector<AscendString> &datas);
  // Installs an externally allocated buffer together with the deleter used to free it.
  // NOTE(review): presumably the tensor takes ownership of `data`; confirm.
  graphStatus SetData(uint8_t *data, size_t size, const Tensor::DeleteFunc &deleter_func);
  // NOTE(review): the validity criteria live in the implementation, not visible here.
  graphStatus IsValid();

  // Origin-shape dim count / per-dim accessors.
  graphStatus SetOriginShapeDimNum(const size_t dim_num);
  size_t GetOriginShapeDimNum() const;

  graphStatus SetOriginShapeDim(const size_t idx, const int64_t dim_value);
  int64_t GetOriginShapeDim(const size_t idx) const;

  graphStatus SetOriginFormat(const ge::Format &format);
  ge::Format GetOriginFormat() const;

  // Shape dim count / per-dim accessors.
  graphStatus SetShapeDimNum(const size_t dim_num);
  size_t GetShapeDimNum() const;

  graphStatus SetShapeDim(const size_t idx, const int64_t dim_value);
  int64_t GetShapeDim(const size_t idx) const;

  graphStatus SetFormat(const ge::Format &format);
  ge::Format GetFormat() const;

  graphStatus SetDataType(const ge::DataType &dtype);
  ge::DataType GetDataType() const;

  graphStatus SetPlacement(const ge::Placement &placement);
  ge::Placement GetPlacement() const;

  /*
   * Dimension expansion is similar to the ExpandDims operator: one or more dims are
   * inserted into the original shape. For example, shape [2,2] has two axes;
   * inserting two dims between them gives [2,1,1,2]. In the expanded shape, axes 0
   * and 3 are called original axes and axes 1 and 2 are called expanded axes.
   *
   * The rule is a string of '1' and '0', one character per axis of the expanded
   * shape, read left to right: '1' marks an expanded axis, '0' an original axis.
   * Examples:
   * | rule     | shape before | shape after                                                     |
   * | -------- | ------------ | --------------------------------------------------------------- |
   * | 0110     | [2, 2]       | [2, 1, 1, 2]                                                    |
   * | 100      | [2, 3]       | [1, 2, 3]                                                       |
   * | 1000     | [2, 3]       | error: the rule specifies 3 original axes but the shape has 2   |
   *
   */
  graphStatus SetExpandDimsRule(const AscendString &expand_dims_rule);
  graphStatus GetExpandDimsRule(AscendString &expand_dims_rule) const;

  // High-performance interface. Unlike SetData(), it avoids a repeated make_shared, so the
  // caller must guarantee that this memory is used only by this tensor (exclusive ownership).
  graphStatus ResetData(uint8_t *data, size_t size, const Tensor::DeleteFunc &deleter_func);

  // Returns a new Tensor. NOTE(review): presumably a deep copy of descriptor and data; confirm.
  Tensor Clone() const;

 private:
  // Shared implementation; copies of this Tensor share it (see class comment).
  std::shared_ptr<TensorImpl> impl;
  // Grants TensorAdapter access to `impl`.
  friend class TensorAdapter;
};
}  // namespace ge

#endif  // INC_EXTERNAL_GRAPH_TENSOR_H_