/**
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#include "ascir_common.h"
#include "symbolizer/symbolic_utils.h"
#include "graph/ascendc_ir/utils/asc_tensor_utils.h"

namespace af {
namespace ascir {

bool OnlySecondInputSupportScalar(const std::vector<bool> &is_scalar_list) {
  GE_ASSERT_EQ(is_scalar_list.size(), 2UL);
  return !is_scalar_list[0] && is_scalar_list[1];
}

[[nodiscard]] std::pair<std::vector<ge::DataType>, std::vector<ge::DataType>>
GetConversionFromDtypeMap(const AscNode &node, const std::map<ge::DataType, ge::DataType> &dtype_conversion_map) {
  std::pair<std::vector<ge::DataType>, std::vector<ge::DataType>> conversion_dtype;
  AscNodeInputs node_inputs = node.inputs;
  AscNodeOutputs node_outputs = node.outputs;
  for (size_t i = 0; i < node_inputs().size(); i++) {
    auto it = dtype_conversion_map.find(node_inputs[i].attr.dtype);
    if (it != dtype_conversion_map.end()) {
        conversion_dtype.first.emplace_back(it->second);  // 使用迭代器访问
    } else {
        conversion_dtype.first.emplace_back(node_inputs[i].attr.dtype);
    }
  }
  for (size_t i = 0; i < node_outputs().size(); i++) {
    auto it = dtype_conversion_map.find(node_outputs[i].attr.dtype);
    if (it != dtype_conversion_map.end()) {
        conversion_dtype.second.emplace_back(it->second);  // 使用迭代器访问
    } else {
        conversion_dtype.second.emplace_back(node_outputs[i].attr.dtype);
    }
  }
  return conversion_dtype;
}

bool IsAllVecAxisContinuous(const AscNode &node) {
  AscNodeInputs node_inputs = node.inputs;
  AscNodeOutputs node_outputs = node.outputs;
  for (size_t i = 0; i < node_inputs().size(); i++) {
    if (node_inputs[i].attr.vectorized_axis.size() == 1) {
      continue;
    }
    auto &attr = node_inputs[i].attr;
    for (size_t j = 1; j < attr.vectorized_axis.size(); j++) {
      auto it = std::find(attr.axis.begin(), attr.axis.end(), attr.vectorized_axis[j]);
      GE_ASSERT_TRUE(it != attr.axis.end(), "Incorrect axis ID in node: %s input %zu vectorized_axis: %zu",
                     node.GetName(), i, j);
      auto axis_id = static_cast<uint64_t>(std::distance(attr.axis.begin(), it));
      Expression cur_axis_stride = attr.repeats[axis_id] * attr.vectorized_strides[j];
      if (SymbolicUtils::StaticCheckEq(cur_axis_stride, attr.vectorized_strides[j - 1]) != TriBool::kTrue) {
        return false;
      }
    }
  }
  for (size_t i = 0; i < node_outputs().size(); i++) {
    if (node_outputs[i].attr.vectorized_axis.size() == 1) {
      continue;
    }
    auto &attr = node_outputs[i].attr;
    for (size_t j = 1; j < attr.vectorized_axis.size(); j++) {
      auto it = std::find(attr.axis.begin(), attr.axis.end(), attr.vectorized_axis[j]);
      GE_ASSERT_TRUE(it != attr.axis.end(), "Incorrect axis ID in node: %s output %zu vectorized_axis: %zu",
                     node.GetName(), i, j);
      auto axis_id = static_cast<uint64_t>(std::distance(attr.axis.begin(), it));
      Expression cur_axis_stride = attr.repeats[axis_id] * attr.vectorized_strides[j];
      if (SymbolicUtils::StaticCheckEq(cur_axis_stride, attr.vectorized_strides[j - 1]) != TriBool::kTrue) {
        return false;
      }
    }
  }
  return true;
}

bool IsUBScalarTensor(const AscTensor &tensor) {
  auto &attr = tensor.attr;
  uint64_t axis_id = UINT64_MAX;
  for (size_t i = 0; i < attr.vectorized_axis.size(); i++) {
    auto it = std::find(attr.axis.begin(), attr.axis.end(), attr.vectorized_axis[i]);
    GE_ASSERT_TRUE(it != attr.axis.end(), "Incorrect axis ID in vectorized_axis");
    axis_id = static_cast<uint64_t>(std::distance(attr.axis.begin(), it));
    if (SymbolicUtils::StaticCheckEq(attr.repeats[axis_id], Symbol(1)) != TriBool::kTrue) {
      return false;
    }
  }
  return true;
}

bool IsVectorizedAxisSupportBrc(const AscNode &node, size_t input_id,
                                const BroadcastCapability &broadcast_capability) {
  // 如果是vectorized_axis轴存在不相等,则必须是ub_scalar或node支持brc_inline,且input_id在support_brc_list中才认为合法。
  AscNodeInputs node_inputs = node.inputs;
  if ((IsUBScalarTensor(node_inputs[input_id]) || broadcast_capability.support_inline_brc) &&
      (std::find(broadcast_capability.scalar_inputs.begin(), broadcast_capability.scalar_inputs.end(), input_id) !=
       broadcast_capability.scalar_inputs.end())) {
    return true;
  }
  return false;
}

Status ValidateInputTensorLoopAxis(const AscNode &node, size_t input_id, size_t input_axis_id) {
  AscNodeInputs node_inputs = node.inputs;
  AscNodeOutputs node_outputs = node.outputs;
  auto input_attr = node_inputs[input_id].attr;
  auto output_attr = node_outputs[0].attr;

  auto it = std::find(output_attr.axis.begin(), output_attr.axis.end(), input_attr.axis[input_axis_id]);
  GE_ASSERT_TRUE(it != output_attr.axis.end(), "Node %s[%s]: input tensor %zu loop axis %zu is not in output tensor "
                 "axis", node.GetTypePtr(), node.GetNamePtr(), input_id, input_axis_id);
  auto output_axis_id = static_cast<uint64_t>(std::distance(output_attr.axis.begin(), it));
  if ((SymbolicUtils::StaticCheckEq(output_attr.repeats[output_axis_id], input_attr.repeats[input_axis_id]) ==
      TriBool::kTrue) || (SymbolicUtils::StaticCheckEq(input_attr.repeats[input_axis_id], Symbol(1)) ==
      TriBool::kTrue)) {
    return ge::SUCCESS;
  } else if (SymbolicUtils::StaticCheckEq(output_attr.repeats[output_axis_id], input_attr.repeats[input_axis_id]) ==
             TriBool::kUnknown) {
    GELOGW("Node %s[%s]: input tensor %zu loop axis %zu repeat %s and output tensor 0 loop axis %zu repeat %s may not "
           "be equal or broadcastable(relation cannot be determined)", node.GetTypePtr(), node.GetNamePtr(), input_id,
           input_axis_id, input_attr.repeats[input_axis_id].Str().get(), output_axis_id,
           output_attr.repeats[output_axis_id].Str().get());
    return ge::SUCCESS;
  }

  GELOGE(ge::FAILED, "Node %s[%s]: input tensor %zu loop axis %zu repeat %s and output tensor 0 loop axis %zu repeat "
         "%s are not equal or broadcastable", node.GetTypePtr(), node.GetNamePtr(), input_id, input_axis_id,
         input_attr.repeats[input_axis_id].Str().get(), output_axis_id,
         output_attr.repeats[output_axis_id].Str().get());
  return ge::FAILED;
}

Status ValidateInputTensorVectorizedAxis(const AscNode &node, size_t input_id, size_t input_axis_id,
                                         const BroadcastCapability &broadcast_capability) {
  AscNodeInputs node_inputs = node.inputs;
  AscNodeOutputs node_outputs = node.outputs;
  auto input_attr = node_inputs[input_id].attr;
  auto output_attr = node_outputs[0].attr;

  auto it = std::find(output_attr.axis.begin(), output_attr.axis.end(), input_attr.axis[input_axis_id]);
  GE_ASSERT_TRUE(it != output_attr.axis.end(), "Node %s[%s]: input tensor %zu vectorized_axis %zu is not in output "
                 "tensor axis", node.GetTypePtr(), node.GetNamePtr(), input_id, input_axis_id);
  auto output_axis_id = static_cast<uint64_t>(std::distance(output_attr.axis.begin(), it));
  if ((SymbolicUtils::StaticCheckEq(output_attr.repeats[output_axis_id], input_attr.repeats[input_axis_id]) ==
      TriBool::kTrue) || IsVectorizedAxisSupportBrc(node, input_id, broadcast_capability)) {
    return ge::SUCCESS;
  } else if (SymbolicUtils::StaticCheckEq(output_attr.repeats[output_axis_id], input_attr.repeats[input_axis_id]) ==
             TriBool::kUnknown) {
    GELOGW("Node %s[%s]: input tensor %zu vectorized_axis %zu repeat: %s and output tensor 0 vectorized_axis %zu "
           "repeat: %s may not be equal or broadcastable(relation cannot be determined)", node.GetTypePtr(),
           node.GetNamePtr(), input_id, input_axis_id, input_attr.repeats[input_axis_id].Str().get(), output_axis_id,
           output_attr.repeats[output_axis_id].Str().get());
    return ge::SUCCESS;
  }

  GELOGE(ge::FAILED, "Node %s[%s]: input tensor %zu vectorized_axis %zu repeat: %s and output tensor 0 vectorized_axis "
         "%zu repeat: %s are not equal or broadcastable", node.GetTypePtr(), node.GetNamePtr(), input_id, input_axis_id,
         input_attr.repeats[input_axis_id].Str().get(), output_axis_id,
         output_attr.repeats[output_axis_id].Str().get());
  return ge::FAILED;
}

Status ValidateShapeConsistencyWithSingleOutput(const AscNode &node,
                                                const BroadcastCapability &broadcast_capability) {
  AscNodeInputs node_inputs = node.inputs;
  AscNodeOutputs node_outputs = node.outputs;
  GE_ASSERT_TRUE(!(node_outputs().size() != 1), "Node %s[%s]: output tensor size is not equal with 1",
                 node.GetTypePtr(), node.GetNamePtr());
  GE_ASSERT_TRUE(!node_outputs[0].attr.vectorized_axis.empty(), "Node %s[%s]: output tensor has empty vectorized_axis",
                 node.GetTypePtr(), node.GetNamePtr());
  std::vector<Expression> output_repeats = node_outputs[0].attr.repeats;

  for (size_t i = 0; i < node_inputs().size(); i++) {
    auto &input = node_inputs[i];
    // 如果输入tensor的向量化轴为空,则不考虑该tensor,在是否支持Scalar输入节点的流程中进行处理。
    // 仅处理能够通过向量化轴获取tensor大小的情况。
    if (input.attr.vectorized_axis.empty()) {
      continue;
    }
    for (size_t j = 0; j < input.attr.repeats.size(); j++) {
      if (std::find(input.attr.vectorized_axis.begin(), input.attr.vectorized_axis.end(), input.attr.axis[j]) !=
          input.attr.vectorized_axis.end()) {
        GE_ASSERT_SUCCESS(ValidateInputTensorVectorizedAxis(node, i, j, broadcast_capability), "Node %s[%s]: input "
                          "tensor %zu axis %zu validate vectorized_axis consistency failed", node.GetTypePtr(),
                          node.GetNamePtr(), i, j);
      } else {
        GE_ASSERT_SUCCESS(ValidateInputTensorLoopAxis(node, i, j), "Node %s[%s]: input tensor %zu "
                          "%zu axis %zu validate loop axis consistency failed", node.GetTypePtr(),
                          node.GetNamePtr(), i, j);
      }
    }
  }
  return ge::SUCCESS;
}

bool IsNodeHasScalarInput(const AscNode &node) {
  AscNodeInputs node_inputs = node.inputs;
  for (size_t i = 0; i < node_inputs().size(); i++) {
    if (node.GetInDataNodes().at(i)->GetType() == "Scalar") {
      return true;
    }
  }
  return false;
}

bool IsNodeFirstInputScalar(const AscNode &node) {
  return node.GetInDataNodes().at(0)->GetType() == "Scalar";
}
} // namespace ascir
} // namespace af