/**
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#pragma once
#include <any>
#include <atomic>
#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <tuple>
#include <type_traits>
#include <typeindex>
#include <unordered_map>
#include <utility>
#include <vector>

#include "core/any_cast.h"
#include "core/dtype.h"
#include "core/error.h"
#include "core/logging.h"
#include "ir/core.h"
#include "ir/pipe.h"
#include "ir/reflection/field_traits.h"
#include "ir/span.h"
#include "ir/type.h"

namespace pypto {
namespace ir {

/**
 * \brief Base class for all expressions in the IR
 *
 * This is the root base class for all expression types (scalar, tensor, etc).
 * Expressions represent computations that produce values.
 * All expressions are immutable.
 */
class Expr : public IRNode {
protected:
    TypePtr type_; // Type of the expression result

public:
    /**
     * \brief Create an expression
     *
     * \param span Source location
     * \param type Type of the expression result (defaults to UnknownType)
     */
    explicit Expr(Span s, TypePtr type = GetUnknownType()) : IRNode(std::move(s)), type_(std::move(type)) {}
    ~Expr() override = default;

    /**
     * \brief Get the type name of this expression
     *
     * \return Human-readable type name (e.g., "ScalarExpr", "Var", "Call")
     */
    [[nodiscard]] std::string TypeName() const override { return "Expr"; }

    /**
     * \brief Get the type of this expression
     *
     * \return Type pointer of the expression result
     */
    [[nodiscard]] const TypePtr& GetType() const { return type_; }

    static constexpr auto GetFieldDescriptors()
    {
        return std::tuple_cat(
            IRNode::GetFieldDescriptors(), std::make_tuple(reflection::UsualField(&Expr::type_, "type")));
    }
};

using ExprPtr = std::shared_ptr<const Expr>;

// Forward declaration for MemorySpace enum (defined in memory_space.h)
enum class MemorySpace;

/**
 * \brief Base class for operations/functions
 *
 * Represents callable operations in the IR.
 * Stores the schema of allowed kwargs (key -> expected type mapping).
 * Actual kwarg values are stored per-Call instance in Call::kwargs_.
 */
class Op {
public:
    std::string name_;

    explicit Op(std::string name) : name_(std::move(name)) {}
    virtual ~Op() = default;

    /**
     * \brief Register an allowed kwarg with its expected type
     *
     * Defines that this operator accepts a kwarg with the given key and type.
     * This is used for validation when creating Call expressions.
     *
     * Only specific types are allowed: bool, int, std::string, double, DataType, MemorySpace,
     * std::vector<int>. This is enforced at compile-time via static_assert.
     *
     * \tparam T Expected type of the kwarg value (must be one of the allowed types)
     * \param key Kwarg key (string identifier)
     */
    template <typename T>
    void SetAttrType(const std::string& key) const
    {
        // Compile-time check: only allow specific types
        static_assert(
            std::is_same_v<T, bool> || std::is_same_v<T, int> || std::is_same_v<T, std::string> ||
                std::is_same_v<T, double> || std::is_same_v<T, DataType> || std::is_same_v<T, MemorySpace> ||
                std::is_same_v<T, std::vector<int>>,
            "SetAttrType only accepts: bool, int, std::string, double, DataType, MemorySpace, std::vector<int>");

        attrs_.emplace(key, std::type_index(typeid(T)));
    }

    /**
     * \brief Get the expected type for a kwarg
     *
     * \param key Kwarg key
     * \return type_index of the expected type
     * \throws pypto::ir::ValueError if kwarg is not registered
     */
    [[nodiscard]] std::type_index GetAttrType(const std::string& key, const Span& span = Span::Unknown()) const
    {
        auto it = attrs_.find(key);
        IRCHECK(it != attrs_.end()) << "Attribute '" << key << "' not found in operator '" << name_ << "'"
                                    << " at " << span.ToString();
        return it->second;
    }

    /**
     * \brief Check if a kwarg is registered
     *
     * \param key Kwarg key
     * \return true if the kwarg is registered
     */
    [[nodiscard]] bool HasAttr(const std::string& key) const { return attrs_.find(key) != attrs_.end(); }

    /**
     * \brief Get all registered kwarg keys
     *
     * \return Vector of all kwarg keys
     */
    [[nodiscard]] std::vector<std::string> GetAttrKeys() const
    {
        std::vector<std::string> keys;
        keys.reserve(attrs_.size());
        for (const auto& pair : attrs_) {
            keys.push_back(pair.first);
        }
        return keys;
    }

    /**
     * \brief Get all registered kwargs as a map
     *
     * \return Map of kwarg keys to expected types
     */
    [[nodiscard]] const std::unordered_map<std::string, std::type_index>& GetAttrs() const { return attrs_; }

    /**
     * \brief Set the pipeline type for this operator
     *
     * \param pipe Pipeline type (e.g., MTE2, V)
     */
    void SetPipe(PipeType pipe) const { pipe_ = pipe; }

    /**
     * \brief Get the pipeline type for this operator
     *
     * \return Optional pipeline type (nullopt if not set)
     */
    [[nodiscard]] std::optional<PipeType> GetPipe() const { return pipe_; }

    [[nodiscard]] virtual ObjectKind GetKind() const { return ObjectKind::Op; }
    [[nodiscard]] virtual std::string TypeName() const { return "Op"; }

private:
    mutable std::unordered_map<std::string, std::type_index> attrs_; ///< Kwarg schema (key -> type)
    mutable std::optional<PipeType> pipe_;                           ///< Pipeline type
};

using OpPtr = std::shared_ptr<const Op>;

/**
 * \brief Variable reference expression
 *
 * Represents a reference to a named variable.
 * Can represent both scalar and tensor variables based on its type.
 */
class Var : public Expr {
public:
    std::string name_;

    /**
     * \brief Create a variable reference
     *
     * \param name Variable name
     * \param type Type of the variable (ScalarType, TensorType, or TileType)
     *             Memory reference information is stored in ShapedType for Tensor/Tile types
     * \param span Source location
     * \return Shared pointer to const Var expression
     */
    Var(std::string name, TypePtr type, Span span) : Expr(std::move(span), std::move(type)), name_(std::move(name)) {}

    [[nodiscard]] ObjectKind GetKind() const override { return ObjectKind::Var; }
    [[nodiscard]] std::string TypeName() const override { return "Var"; }

    /**
     * \brief Get field descriptors for reflection-based visitation
     *
     * \return Tuple of field descriptors (name_ as IGNORE field)
     */
    static constexpr auto GetFieldDescriptors()
    {
        return std::tuple_cat(
            Expr::GetFieldDescriptors(), std::make_tuple(reflection::IgnoreField(&Var::name_, "name")));
    }
};

using VarPtr = std::shared_ptr<const Var>;

/**
 * \brief Iteration argument variable
 *
 * Represents an iteration argument (loop-carried value) in for loops.
 * IterArgs implement SSA-style loop-carried dependencies where values are
 * carried from one iteration to the next via yield statements.
 *
 * **Scoping Rules:**
 * - IterArg variables are scoped to the loop body only
 * - Cannot be directly accessed outside the loop
 * - Must use return_vars to expose final values after the loop
 *
 * **Usage Pattern:**
 * 1. Create IterArg with initial value
 * 2. Use in ForStmt's iter_args list
 * 3. Update via YieldStmt in loop body
 * 4. Capture final value in ForStmt's return_vars
 *
 * \example
 * // for i, (sum,) in pl.range(0, n, 1, init_values=[0]):
 * //     sum = pl.yield_(sum + i)
 * // sum_final = sum
 * auto sum_iter = std::make_shared<IterArg>("sum", type, init_val, span);
 * auto sum_final = std::make_shared<Var>("sum_final", type, span);
 * auto for_stmt = std::make_shared<ForStmt>(
 *     i, start, stop, step,
 *     std::vector{sum_iter},  // iter_args (loop-scoped)
 *     body,
 *     std::vector{sum_final}, // return_vars (accessible after loop)
 *     span
 * );
 */
class IterArg {
public:
    VarPtr iterVar_;    // Variable reference for the iteration argument
    ExprPtr initValue_; // Initial value expression for first iteration

    /**
     * \brief Create an iteration argument
     *
     * \param name Variable name (scoped to loop body)
     * \param type Type of the variable (ScalarType, TensorType, or TileType)
     *             Memory reference information is stored in ShapedType for Tensor/Tile types
     * \param initValue Initial value expression for first iteration
     * \param span Source location
     */
    IterArg(std::string name, TypePtr type, ExprPtr initValue, Span span)
        : iterVar_(std::make_shared<Var>(name, std::move(type), std::move(span))), initValue_(std::move(initValue))
    {}

    /**
     * \brief Create an iteration argument with existing variable reference
     *
     * \param var Variable reference for the iteration argument
     * \param initValue Initial value expression for first iteration
     */
    IterArg(VarPtr var, ExprPtr initValue) : iterVar_(std::move(var)), initValue_(std::move(initValue)) {}

    /**
     * \brief Get field descriptors for reflection-based visitation
     *
     * \return Tuple of field descriptors (initValue_ as USUAL field)
     */
    static constexpr auto GetFieldDescriptors()
    {
        return std::make_tuple(
            reflection::UsualField(&IterArg::initValue_, "initValue"),
            reflection::UsualField(&IterArg::iterVar_, "iterVar"));
    }
};

using IterArgPtr = std::shared_ptr<const IterArg>;

// MemRefPtr is already declared in memref.h, just reuse it here

/**
 * \brief Function call expression
 *
 * Represents a function call with an operation and arguments.
 * Can accept any Expr as arguments, not just scalar expressions.
 * Supports keyword arguments (kwargs) for operator metadata.
 */
class Call : public Expr {
public:
    std::string name_;                                     // Operation name
    std::vector<ExprPtr> args_;                            // Positional arguments
    std::vector<std::pair<std::string, std::any>> kwargs_; // Keyword arguments (metadata, ordered)

    /**
     * \brief Create a function call expression
     *
     * \param op Operation/function to call
     * \param args List of argument expressions
     * \param span Source location
     */
    Call(std::string name, std::vector<ExprPtr> args, Span span)
        : Expr(std::move(span)), name_(std::move(name)), args_(std::move(args)), kwargs_()
    {}

    /**
     * \brief Create a function call expression with explicit type
     *
     * \param op Operation/function to call
     * \param args List of argument expressions
     * \param type Result type of the call
     * \param span Source location
     */
    Call(std::string name, std::vector<ExprPtr> args, TypePtr type, Span span)
        : Expr(std::move(span), std::move(type)), name_(std::move(name)), args_(std::move(args)), kwargs_()
    {}

    Call(std::string name, std::vector<ExprPtr> args, std::vector<std::pair<std::string, std::any>> kwargs, Span span)
        : Expr(std::move(span)), name_(std::move(name)), args_(std::move(args)), kwargs_(std::move(kwargs))
    {}

    Call(
        std::string name, std::vector<ExprPtr> args, std::vector<std::pair<std::string, std::any>> kwargs, TypePtr type,
        Span span)
        : Expr(std::move(span), std::move(type)),
          name_(std::move(name)),
          args_(std::move(args)),
          kwargs_(std::move(kwargs))
    {}

    /**
     * \brief Get a kwarg value with type checking
     *
     * \tparam T Type of the kwarg value
     * \param key Kwarg key
     * \param default_value Default value if key doesn't exist
     * \return The kwarg value or default
     */
    template <typename T>
    T GetKwarg(const std::string& key, const T& default_value = T{}) const
    {
        for (const auto& [k, v] : kwargs_) {
            if (k == key) {
                return AnyCast<T>(v, "kwarg key: " + key);
            }
        }
        return default_value;
    }

    /**
     * \brief Check if a kwarg exists
     *
     * \param key Kwarg key
     * \return true if the kwarg exists
     */
    [[nodiscard]] bool HasKwarg(const std::string& key) const
    {
        for (const auto& kwarg : kwargs_) {
            if (kwarg.first == key) {
                return true;
            }
        }
        return false;
    }

    [[nodiscard]] ObjectKind GetKind() const override { return ObjectKind::Call; }
    [[nodiscard]] std::string TypeName() const override { return "Call"; }

    /**
     * \brief Get field descriptors for reflection-based visitation
     *
     * \return Tuple of field descriptors (op, args, and kwargs as USUAL fields)
     */
    static constexpr auto GetFieldDescriptors()
    {
        return std::tuple_cat(
            Expr::GetFieldDescriptors(),
            std::make_tuple(
                reflection::UsualField(&Call::name_, "name"), reflection::UsualField(&Call::args_, "args"),
                reflection::UsualField(&Call::kwargs_, "kwargs")));
    }
};

using CallPtr = std::shared_ptr<const Call>;

/**
 * \brief Expression to create a tuple from multiple expressions
 *
 * Takes a list of expressions and creates a tuple value.
 * The result type is TupleType containing the types of all input expressions.
 */
class MakeTuple : public Expr {
public:
    std::vector<ExprPtr> elements_; // Elements of the tuple

    /**
     * \brief Create a tuple construction expression
     *
     * \param elements Expressions to be tuple elements
     * \param span Source location
     */
    MakeTuple(std::vector<ExprPtr> elements, Span span);

    [[nodiscard]] ObjectKind GetKind() const override { return ObjectKind::MakeTuple; }
    [[nodiscard]] std::string TypeName() const override { return "MakeTuple"; }

    /**
     * \brief Get field descriptors for reflection-based visitation
     *
     * \return Tuple of field descriptors
     */
    static constexpr auto GetFieldDescriptors()
    {
        return std::tuple_cat(
            Expr::GetFieldDescriptors(), std::make_tuple(reflection::UsualField(&MakeTuple::elements_, "elements")));
    }
};

using MakeTuplePtr = std::shared_ptr<const MakeTuple>;

/**
 * \brief Unified subscript expression: value[slice]
 *
 * Represents Python subscript syntax for both tuple element access and tile offset.
 * The concrete semantics is determined by the static type of `value_`:
 *   - `value_` is `TupleType`: tuple element access. `slice_` must be a `ConstInt`
 *     whose value is the 0-based element index. Result type is the element type.
 *   - `value_` is `TileType`: tile element offset. `slice_` is an integer expression
 *     (static or dynamic). Result type is the same TileType as the base tile,
 *     with physical address shifted by `slice * sizeof(dtype)` bytes.
 */
class GetItemExpr : public Expr {
public:
    ExprPtr value_; ///< Base expression (must have TupleType or TileType)
    ExprPtr slice_; ///< Subscript expression

    /**
     * \brief Create a subscript expression
     *
     * \param value Base expression (must have TupleType or TileType)
     * \param slice Subscript expression (for TupleType, must be a ConstInt)
     * \param span Source location
     */
    GetItemExpr(ExprPtr value, ExprPtr slice, Span span);

    [[nodiscard]] ObjectKind GetKind() const override { return ObjectKind::GetItemExpr; }
    [[nodiscard]] std::string TypeName() const override { return "GetItemExpr"; }

    static constexpr auto GetFieldDescriptors()
    {
        return std::tuple_cat(
            Expr::GetFieldDescriptors(), std::make_tuple(
                                             reflection::UsualField(&GetItemExpr::value_, "value"),
                                             reflection::UsualField(&GetItemExpr::slice_, "slice")));
    }
};

using GetItemExprPtr = std::shared_ptr<const GetItemExpr>;

} // namespace ir
} // namespace pypto