/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
 * \file tile_shape.h
 * \brief
 */

#pragma once

#include <algorithm>
#include <array>
#include <cstdint>
#include <sstream>
#include <string>
#include <type_traits>
#include <vector>
#include "tilefwk/tilefwk_op.h"

/// Number of entries in each of DistTile::row, DistTile::col and DistTile::rank.
#define MAX_DIST_DIM_SIZE 3

/// Number of entries in CubeTile::m (m[0]: L0 Cache tile, m[1]: L1 Cache tile).
#define MAX_M_DIM_SIZE 2
/// Number of entries in CubeTile::k (k[0]: L0 Cache, k[1]: L1 Cache; the role of k[2] is not documented here).
#define MAX_K_DIM_SIZE 3
/// Number of entries in CubeTile::n (n[0]: L0 Cache tile, n[1]: L1 Cache tile).
#define MAX_N_DIM_SIZE 2

/**
 * @brief Tile extents used by element-wise (vector) operations.
 *
 * One extent per dimension is stored in `tile`; the accessors below are thin,
 * unchecked forwarders onto that vector.
 */
struct VecTile {
    /// Per-dimension tile extents.
    std::vector<int64_t> tile;

    /// Validity check; defined out of line.
    bool valid() const;

    /// Read-only, unchecked access to the extent of dimension @p i.
    int64_t operator[](int i) const
    {
        return tile.cbegin()[i];
    }

    /// Mutable, unchecked access to the extent of dimension @p i.
    int64_t& operator[](int i)
    {
        return tile.begin()[i];
    }

    /// Number of dimensions currently stored in `tile`.
    size_t size() const
    {
        const std::vector<int64_t>& extents = tile;
        return extents.size();
    }
};

/**
 * @brief CubeTile tile for matmul operation.
 *
 * m[0], k[0], n[0] describe the L0 Cache tile; m[1], k[1], n[1] describe the L1 Cache tile.
 * k has a third entry (k[2]) whose meaning is not documented here — NOTE(review): confirm.
 *
 * Every extent defaults to 0 and enableSplitK defaults to false. As a result, a default-initialized
 * CubeTile (e.g. a plain `CubeTile c;` local or a non-initialized member) never holds indeterminate
 * values. The struct remains an aggregate, so brace initialization such as
 * `{{0, 0}, {0, 0, 0}, {0, 0}}` keeps working.
 */
struct CubeTile {
    std::array<int64_t, MAX_M_DIM_SIZE> m{};  ///< M extents, [0] = L0, [1] = L1
    std::array<int64_t, MAX_K_DIM_SIZE> k{};  ///< K extents, [0] = L0, [1] = L1
    std::array<int64_t, MAX_N_DIM_SIZE> n{};  ///< N extents, [0] = L0, [1] = L1
    bool enableSplitK{false};

    /// Validity check; defined out of line.
    bool valid() const;

    /// Human-readable description; defined out of line.
    std::string ToString() const;
};
/**
 * @brief Tiling configuration for convolution operations.
 *
 * Bundles the L1 and L0 tiling descriptors from npu::tile_fwk::Conv together with the
 * setL0Tile flag (NOTE(review): presumably records whether the L0 tile was supplied
 * explicitly via TileShape::SetConvTile — confirm).
 */
struct ConvTile {
    npu::tile_fwk::Conv::TileL1Info tileL1Info; ///< L1 tiling descriptor
    npu::tile_fwk::Conv::TileL0Info tileL0Info; ///< L0 tiling descriptor
    bool setL0Tile = false;                     ///< defaults to false

    /// Validity check; defined out of line.
    bool valid() const;

    /// Human-readable description; defined out of line.
    std::string ToString() const;
};

/**
 * @brief DistTile tile for distributed operation
 *
 * row, col and rank each hold MAX_DIST_DIM_SIZE entries that default to 0. A default-initialized
 * DistTile therefore never holds indeterminate values. rankId defaults to INT16_MAX
 * (NOTE(review): presumably an "unset" sentinel — confirm). The struct remains an aggregate,
 * so brace initialization such as `{{0, 0, 0}, {0, 0, 0}, {0, 0, 0}, INT16_MAX}` keeps working.
 */
struct DistTile {
    std::array<int, MAX_DIST_DIM_SIZE> row{};
    std::array<int, MAX_DIST_DIM_SIZE> col{};
    std::array<int, MAX_DIST_DIM_SIZE> rank{};
    int rankId{INT16_MAX};

    /// Validity check; defined out of line.
    bool valid() const;

    /// Human-readable description; defined out of line.
    std::string ToString() const;
};

/**
 * @brief Identifies one tile category held by a TileShape.
 *
 * Enumerator values are the implicit 0..4. MAX is the default argument of
 * TileShape::ToString (NOTE(review): presumably meaning "all categories" — confirm).
 */
enum class TileType {
    VEC = 0,  ///< element-wise tile (VecTile)
    CUBE = 1, ///< matmul tile (CubeTile)
    CONV = 2, ///< convolution tile (ConvTile)
    DIST = 3, ///< distributed tile (DistTile)
    MAX = 4,  ///< sentinel past the last category
};

/**
 * @brief TileShape tile shape for operation
 *
 */
struct TileShape {
    TileShape();

    TileShape(
        const std::vector<int64_t>& vTile, const CubeTile& cTile, const ConvTile& cvTile, const DistTile& dTile,
        const std::vector<int64_t>& mSize);

    /**
     * \brief Set the Vec Tile
     *
     * \param tile
     */
    void SetVecTile(const std::vector<int64_t>& tile);
    void SetVecTile(const VecTile& tile);

    template <typename... Args, typename = std::enable_if_t<std::conjunction_v<std::is_integral<Args>...>>>
    inline void SetVecTile(Args... args)
    {
        SetVecTile(std::vector<int64_t>{args...});
    }

    /**
     * \brief Get the Vec Tile
     *
     * \return const std::vector<int64_t>&
     */
    const VecTile& GetVecTile() const { return vecTile; }
    VecTile& GetVecTile() { return vecTile; }

    /**
     * \brief Set the Cube Tile
     *
     * \param m
     * \param k
     * \param n
     */
    void SetCubeTile(
        const std::array<int64_t, MAX_M_DIM_SIZE>& m, const std::array<int64_t, MAX_K_DIM_SIZE>& k,
        const std::array<int64_t, MAX_N_DIM_SIZE>& n, bool enableSplitK = false);

    /**
     * \brief Get the Cube Tile
     */
    const CubeTile& GetCubeTile() const { return cubeTile; }
    CubeTile& GetCubeTile() { return cubeTile; }

    /**
     * \brief Set the Conv Tile
     *
     * \param tileL1Info
     * \param tileL0Info
     * \param setL0Tile
     */
    void SetConvTile(
        const npu::tile_fwk::Conv::TileL1Info& tileL1Info, const npu::tile_fwk::Conv::TileL0Info& tileL0Info,
        bool setL0Tile = false);

    /**
     * \brief Get the Conv Tile
     */
    const ConvTile& GetConvTile() const { return convTile; }
    ConvTile& GetConvTile() { return convTile; }

    /**
     * \brief Set the Dist Tile
     *
     * \param row
     * \param col
     * \param rank
     */
    void SetDistTile(
        const std::array<int, MAX_DIST_DIM_SIZE>& row, const std::array<int, MAX_DIST_DIM_SIZE>& col,
        const std::array<int, MAX_DIST_DIM_SIZE>& rank);

    /**
     * \brief Get the Dist Tile
     *
     * \return const DistTile&
     */
    const DistTile& GetDistTile() const { return distTile; }
    DistTile& GetDistTile() { return distTile; }

    /**
     * @brief Set the Dist Rank Id
     *
     * @param rankId
     */
    void SetDistRankId(int64_t rankId);
    /**
     * @brief Get the Dist Rank Id
     *
     * @return int64_t
     */
    int64_t GetDistRankId() const { return distTile.rankId; }

    /**
     * @brief Set the Dist Col
     *
     * @param col
     */
    void SetDistTileCol(const std::array<int, MAX_DIST_DIM_SIZE>& col);

    /**
     * @brief Get the Dist Col
     *
     * @return const std::vector<int64_t>&
     */
    const std::array<int, MAX_DIST_DIM_SIZE>& GetDistTileCol() const { return distTile.col; }

    /**
     * @brief Set the Dist Row
     *
     * @param row
     */
    void SetDistTileRow(const std::array<int, MAX_DIST_DIM_SIZE>& row);

    /**
     * @brief Get the Dist Row
     *
     * @return const std::vector<int64_t>&
     */
    const std::array<int, MAX_DIST_DIM_SIZE>& GetDistTileRow() const { return distTile.row; }

    /**
     * @brief Set the Dist Rank
     *
     * @param rank
     */
    void SetDistTileRank(const std::array<int, MAX_DIST_DIM_SIZE>& rank);

    /**
     * @brief Get the Dist Rank
     *
     * @return const std::vector<int64_t>&
     */
    const std::array<int, MAX_DIST_DIM_SIZE>& GetDistTileRank() const { return distTile.rank; }

    /**
     * @brief Global tile shape
     *
     * @return TileShape&
     */
    static TileShape& Current();

    /**
     * @brief Reset the tile shape
     *
     */
    void Reset()
    {
        vecTile = {};
        cubeTile = {{0, 0}, {0, 0, 0}, {0, 0}};
        distTile = {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}, INT16_MAX};
    }

    void SetMatrixSize(const std::vector<int64_t>& size);

    const std::vector<int64_t>& GetMatrixSize() const { return matrixSize; }

    void UpdateScopeDistTile();

    std::string ToString(TileType type = TileType::MAX) const;

private:
    VecTile vecTile;
    CubeTile cubeTile;
    ConvTile convTile;
    DistTile distTile;
    std::vector<int64_t> matrixSize;
};