/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of 
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, 
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#ifndef OPTIMIZE_PLATFORM_COMMON_GRAPH_PASS_POW_EQUIV_SUBSTITUTION_PASS_H
#define OPTIMIZE_PLATFORM_COMMON_GRAPH_PASS_POW_EQUIV_SUBSTITUTION_PASS_H

#include <cstdint>
#include <functional>
#include <string>
#include <unordered_map>
#include <vector>

#include "optimize/graph_pass/base_graph_pass.h"
namespace optimize {
// Callable signature of a Pow substitution handler: takes the graph and a node, returns an af::Status.
// It is the mapped type of the table returned by PowEquivSubstitutionPass::GetGlobalSubstitutionMap().
using SubstitutionFunc = std::function<af::Status(af::AscGraph &, const af::AscNodePtr &)>;

// Exponent values of a Pow node that are recognised for equivalent substitution.
// Each enumerator is annotated with the exponent it stands for and the replacement it maps to.
enum class PatternType : uint8_t {
  // exponent 0.0: pow(x, 0) -> Brc of one
  kZero = 0U,
  // exponent 0.5: pow(x, 0.5) -> sqrt(x)
  kHalf = 1U,
  // exponent -0.5: pow(x, -0.5) -> 1 / sqrt(x)
  kNegHalf = 2U,
  // exponent 1.0: pow(x, 1) -> the Pow node is simply removed
  kOne = 3U,
  // exponent -1.0: pow(x, -1) -> reciprocal of x
  kNegOne = 4U,
  // exponent 2.0: pow(x, 2) -> mul(x, x)
  kTwo = 5U,
  // exponent -2.0: pow(x, -2) -> 1 / (x * x)
  kNegTwo = 6U,
  // exponent 3.0: pow(x, 3) -> mul(x, x) * x
  kThree = 7U,
  // exponent 4.0: pow(x, 4) -> mul(x * x, x * x)
  kFour = 8U,
  // Sentinel outside the contiguous range above; presumably "no pattern matched" -- confirm at use site.
  kNone = 100U
};

// Graph pass for Pow-equivalence substitution.
//
// The class is stateless: apart from the RunPass override, every member is a private static helper.
// The substitutions it is meant to perform are listed per exponent on the PatternType enumerators;
// how they are selected and applied is defined in the implementation file, not in this header.
class PowEquivSubstitutionPass final : public BaseGraphPass {
 public:
  PowEquivSubstitutionPass() = default;
  ~PowEquivSubstitutionPass() override = default;

  // Entry point of the pass (overrides the base-class RunPass). Operates on `graph` and reports an af::Status.
  af::Status RunPass(af::AscGraph &graph) override;

 private:
  // ---------------------------------------------------------------------------------------------
  // Lookup and inspection helpers
  // ---------------------------------------------------------------------------------------------

  // Returns the table that maps a PatternType to its SubstitutionFunc handler.
  static const std::unordered_map<PatternType, SubstitutionFunc> &GetGlobalSubstitutionMap();

  // Returns a list of nodes taken from `graph`; presumably the Pow nodes this pass should visit
  // (NOTE(review): confirm the exact filter in the definition).
  static std::vector<af::AscNodePtr> FilterPowNodes(af::AscGraph &graph);

  // Reads the scalar input of `pow_node` into `scalar_val` (string form). The bool result presumably
  // signals whether such an input was found -- confirm against the definition.
  static bool GetScalarInput(const af::AscNodePtr &pow_node, std::string &scalar_val);

  // ---------------------------------------------------------------------------------------------
  // Graph-editing helpers
  // ---------------------------------------------------------------------------------------------

  // Checks whether input0 of the Pow node is a Scalar and, if it is, inserts a Brc.
  static af::Status EnsureInputWithBrcIfNeeded(af::AscGraph &graph, const af::AscNodePtr &pow_node);

  // Relinks an input of `pow_node` to `target_node`. `in_idx` defaults to 0; by its name it selects
  // the input index involved (NOTE(review): confirm which side the index applies to).
  static af::Status RelinkPowInputToNode(
      const af::AscNodePtr &pow_node, const af::AscNodePtr &target_node, const int32_t in_idx = 0);

  // Relinks the output of `pow_node` to `target_node`.
  static af::Status RelinkPowOutputToNode(const af::AscNodePtr &pow_node, const af::AscNodePtr &target_node);

  // ---------------------------------------------------------------------------------------------
  // Substitution handlers
  //
  // Every function below has the SubstitutionFunc signature (graph, pow_node) -> af::Status.
  // Their names line up with the PatternType enumerators in declaration order (kZero .. kFour);
  // the authoritative binding is whatever GetGlobalSubstitutionMap() registers.
  // ---------------------------------------------------------------------------------------------
  static af::Status ReplaceWithScalarBrc(af::AscGraph &graph, const af::AscNodePtr &pow_node);
  static af::Status ReplaceWithSqrt(af::AscGraph &graph, const af::AscNodePtr &pow_node);
  static af::Status ReplaceWithInverseSqrt(af::AscGraph &graph, const af::AscNodePtr &pow_node);
  static af::Status RemoveUseLessPow(af::AscGraph &graph, const af::AscNodePtr &pow_node);
  static af::Status ReplaceWithInverseInput(af::AscGraph &graph, const af::AscNodePtr &pow_node);
  static af::Status ReplaceWithMul(af::AscGraph &graph, const af::AscNodePtr &pow_node);
  static af::Status ReplaceWithInverseMul(af::AscGraph &graph, const af::AscNodePtr &pow_node);
  static af::Status ReplaceWithCube(af::AscGraph &graph, const af::AscNodePtr &pow_node);
  static af::Status ReplaceWithFourthPower(af::AscGraph &graph, const af::AscNodePtr &pow_node);
};
}  // namespace optimize

#endif  // OPTIMIZE_PLATFORM_COMMON_GRAPH_PASS_POW_EQUIV_SUBSTITUTION_PASS_H