/**
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#ifndef INCLUDE_SIMT_API_COOPERATIVE_GROUPS_H
#define INCLUDE_SIMT_API_COOPERATIVE_GROUPS_H

#if !defined(__ASCENDC_INCLUDE_INTERNAL_HEADERS__)
#define __ASCENDC_INCLUDE_INTERNAL_HEADERS__
#define __UNDEF_ASCENDC_INCLUDE_INTERNAL_HEADERS_COOPERATIVE_GROUPS_H__
#endif

#include "simt_api/common_functions.h"
#include "simt_api/device_types.h"
#include "simt_api/asc_fp16.h"
#include "simt_api/asc_bf16.h"

namespace cooperative_groups {

// Forward declarations used by the friend declarations further down.
// _coalesced_group_data_access is only named as a friend in this header; it is not defined here.
// thread_block is defined below, but is named earlier as a friend and as a parameter type.
class _coalesced_group_data_access;
class thread_block;

// Tag stored in thread_group::_type. The enumerator names match the group classes declared in
// this header (thread_block, tiled_group, coalesced_group). Values are consecutive from zero.
enum class group_type : unsigned int {
    thread_block_type = 0U,
    tiled_group_type = 1U,
    coalesced_group_type = 2U,
};

// Base class shared by the group types declared in this header: tiled_group, thread_block and
// coalesced_group all derive from it. Every member function is only declared here; the bodies
// are not in this file (presumably in impl/simt_api/cooperative_groups_impl.h, which this
// header includes after the namespace closes — confirm).
// None of the members is virtual, so same-named members in derived classes hide these rather
// than override them.
class thread_group {
public:
    // Size and rank queries. NOTE(review): size() and num_threads() look like two spellings of
    // the same query — confirm against the implementation.
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline unsigned long long size() const;
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline unsigned long long num_threads() const;
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline unsigned long long thread_rank() const;
    // Presumably synchronizes the threads of this group; the exact scope is set by the
    // implementation, which is not visible here.
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline void sync() const;
    // Accessor for the group kind; presumably returns _type — confirm.
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline group_type get_type() const;

protected:
    // Kind of group this object represents (see group_type).
    group_type _type;

    // Bookkeeping record embedded in every thread_group. The fields are written by code outside
    // this header; the per-field notes are assumptions drawn from the names only.
    struct tiled_info {
        // presumably true when the group was produced by tiling a parent group — TODO confirm
        bool is_tiled;
        // presumably a 32-bit lane mask of the member threads — TODO confirm
        unsigned int mask;
        // presumably the number of threads in the group — TODO confirm
        unsigned int num_threads;
        // presumably this tile's index among its sibling tiles, and the sibling count — TODO confirm
        unsigned int meta_group_rank;
        unsigned int meta_group_size;
    } _tiled_info;

    // thread_block is granted access to the protected state above.
    friend class thread_block;

    // Protected: only derived classes and friends can construct a thread_group directly.
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline thread_group(group_type type);
};

// Group type derived from thread_group. It redeclares sync(), num_threads(), size() and
// thread_rank(); because the base versions are non-virtual, these hide them when called
// through a tiled_group. Member bodies are not in this header.
class tiled_group : public thread_group {
    // The constructor is protected, so only these friends (and derived classes) can create one.
    friend class thread_block;
    friend __SIMT_DEVICE_FUNCTIONS_DECL__ inline thread_group tiled_partition(
        const thread_group& parent, unsigned int tilesz);
    friend __SIMT_DEVICE_FUNCTIONS_DECL__ inline thread_group tiled_partition(
        const thread_block& parent, unsigned int tilesz);

public:
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline void sync() const;
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline unsigned long long num_threads() const;
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline unsigned long long size() const;
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline unsigned long long thread_rank() const;

protected:
    // Takes the thread count of the tile; how it is stored is up to the implementation.
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline tiled_group(unsigned int num_threads);
};

// Group type derived from thread_group whose whole public interface is static, so the queries
// can be called without an instance. The constructor is private: objects can only be created by
// the class itself and by the friends listed below. Member bodies are not in this header.
// Note that thread_rank(), num_threads() and size() return unsigned int here, whereas the
// thread_group versions they hide return unsigned long long.
class thread_block : public thread_group {
    friend __SIMT_DEVICE_FUNCTIONS_DECL__ inline thread_block this_thread_block();
    friend __SIMT_DEVICE_FUNCTIONS_DECL__ inline thread_group tiled_partition(
        const thread_group& parent, unsigned int tilesz);
    friend __SIMT_DEVICE_FUNCTIONS_DECL__ inline thread_group tiled_partition(
        const thread_block& parent, unsigned int tilesz);

public:
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline static void sync();
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline static unsigned int thread_rank();
    // dim3 is not declared in this header; presumably it comes from one of the simt_api includes.
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline static dim3 group_index();
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline static dim3 thread_index();
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline static dim3 dim_threads();
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline static unsigned int num_threads();
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline static unsigned int size();
    // NOTE(review): group_dim() and dim_threads() both return dim3; whether they differ is
    // decided by the implementation — confirm.
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline static dim3 group_dim();

private:
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline thread_block();
    // Private helper returning a thread_group for the given tile size; presumably used by the
    // tiled_partition friends — confirm.
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline thread_group create_tiled_group(unsigned int tile_size) const;
};

// Returns a thread_block by value. thread_block's constructor is private and this function is
// one of its friends, so this is the public way to obtain one.
__SIMT_DEVICE_FUNCTIONS_DECL__ inline thread_block this_thread_block();

// Forward declaration so that friend declarations below can name thread_block_tile before its
// definition. The default for ParentT is supplied on the later definition, not here.
template <unsigned int Size, typename ParentT>
class thread_block_tile;

// Mix-in base used by the primary thread_block_tile_impl template (the case where ParentT is a
// real type rather than void). It contributes the meta_group_rank()/meta_group_size() queries.
// No data members are declared, so the results are presumably derived from Size and ParentT —
// confirm against the implementation.
template <unsigned int Size, typename ParentT>
class _static_parent_thread_block_tile_base {
public:
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline unsigned int meta_group_rank() const;
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline unsigned int meta_group_size() const;
};

// Group type derived from thread_group that is constructed from an unsigned int mask. The
// constructor is protected, so instances come from the friends listed below (coalesced_threads,
// the tiled_partition and binary_partition overloads, and _coalesced_group_data_access).
// Member bodies are not in this header. The non-virtual queries redeclared here hide the
// thread_group versions.
class coalesced_group : public thread_group {
    friend __SIMT_DEVICE_FUNCTIONS_DECL__ inline coalesced_group coalesced_threads();
    friend __SIMT_DEVICE_FUNCTIONS_DECL__ inline thread_group tiled_partition(
        const thread_group& parent, unsigned int tilesz);
    friend __SIMT_DEVICE_FUNCTIONS_DECL__ inline coalesced_group tiled_partition(
        const coalesced_group& parent, unsigned int tilesz);
    friend __SIMT_DEVICE_FUNCTIONS_DECL__ inline coalesced_group binary_partition(const coalesced_group& g, bool pred);
    friend class _coalesced_group_data_access;

    // Every specialization of this binary_partition template is a friend.
    template <unsigned int Size, typename ParentT>
    friend __SIMT_DEVICE_FUNCTIONS_DECL__ inline coalesced_group binary_partition(
        const thread_block_tile<Size, ParentT>& g, bool pred);

public:
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline void sync() const;
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline unsigned long long num_threads() const;
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline unsigned long long size() const;
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline unsigned long long thread_rank() const;
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline unsigned long long meta_group_size() const;
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline unsigned long long meta_group_rank() const;

    // Shuffle-style exchanges templated on the value type. src_rank/delta are presumably
    // expressed in group-relative ranks rather than hardware lane ids — TODO confirm.
    template <typename T>
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline T shfl(T var, int src_rank) const;

    template <typename T>
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline T shfl_up(T var, unsigned int delta) const;

    template <typename T>
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline T shfl_down(T var, unsigned int delta) const;

    // Vote-style queries over an int predicate; ballot() returns an unsigned int, presumably a
    // bit mask of the members whose predicate was non-zero — TODO confirm bit ordering.
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline int any(int predicate) const;
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline int all(int predicate) const;
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline unsigned int ballot(int predicate) const;

protected:
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline coalesced_group(unsigned int mask);
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline unsigned int get_mask() const;

private:
    // Internal helpers. _packLanes presumably remaps a lane mask into group-relative positions
    // — confirm in the implementation.
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline unsigned int _packLanes(unsigned int laneMask) const;
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline coalesced_group create_tiled_group(unsigned int tile_size) const;
};

// Returns a coalesced_group by value. It is a friend of coalesced_group, whose constructor is
// protected. Which threads end up in the group is decided by the implementation (not visible here).
__SIMT_DEVICE_FUNCTIONS_DECL__ inline coalesced_group coalesced_threads();

// Carrier for the four compile-time constants attached to a tile size. The template arguments
// are simply re-exposed as static members so callers can write tile_helpers<N>::lane_mask etc.
template <unsigned int Tiles, unsigned int TileBits, unsigned int LaneBits, unsigned int Shift>
struct _tile_helpers {
    static constexpr unsigned int tile_count{Tiles};
    static constexpr unsigned int tile_mask{TileBits};
    static constexpr unsigned int lane_mask{LaneBits};
    static constexpr unsigned int shift_count{Shift};
};

// Lookup table keyed by tile size. Only 1, 2, 4, 8, 16 and 32 are specialized; any other size
// leaves tile_helpers incomplete. For every row:
//   tile_count  == 32 / Size
//   tile_mask   == the low Size bits set
//   lane_mask   == Size - 1
//   shift_count == log2(Size)
template <unsigned int>
struct tile_helpers;

// thread_block_tile<32> config
template <>
struct tile_helpers<32> : _tile_helpers<1u, 0xFFFFFFFFu, 0x1Fu, 5u> {};

// thread_block_tile<16> config
template <>
struct tile_helpers<16> : _tile_helpers<2u, 0x0000FFFFu, 0x0Fu, 4u> {};

// thread_block_tile<8> config
template <>
struct tile_helpers<8> : _tile_helpers<4u, 0x000000FFu, 0x07u, 3u> {};

// thread_block_tile<4> config
template <>
struct tile_helpers<4> : _tile_helpers<8u, 0x0000000Fu, 0x03u, 2u> {};

// thread_block_tile<2> config
template <>
struct tile_helpers<2> : _tile_helpers<16u, 0x00000003u, 0x01u, 1u> {};

// thread_block_tile<1> config
template <>
struct tile_helpers<1> : _tile_helpers<32u, 0x00000001u, 0x00u, 0u> {};

// Compile-time predicate: value is true exactly for Size in {1, 2, 4, 8, 16, 32}, i.e. for the
// non-zero powers of two that do not exceed 32, and false for everything else.
template <unsigned int Size>
struct _is_valid_thread_block_tile_size {
    static constexpr bool value = (Size != 0u) && (Size <= 32u) && ((Size & (Size - 1u)) == 0u);
};

// Size-only part of a statically sized tile: the operations declared here depend on the
// compile-time Size but not on the parent type. Both thread_block_tile_impl variants derive
// from it. Apart from get_mask(), member bodies are not in this header.
template <unsigned int Size>
class thread_block_tile_base {
    // Rejects, at compile time, any Size other than 1, 2, 4, 8, 16 or 32.
    static_assert(_is_valid_thread_block_tile_size<Size>::value, "Size must be one of 1/2/4/8/16/32");

    // Private shorthands: the constant table for this Size, and Size itself.
    using th = tile_helpers<Size>;
    static constexpr unsigned int numThreads = Size;

    // Every specialization of this binary_partition template is a friend, as is
    // _coalesced_group_data_access; both can therefore reach the protected mask helpers below.
    template <unsigned int Sz, typename ParentT>
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline friend coalesced_group binary_partition(
        const thread_block_tile<Sz, ParentT>& g, bool pred);
    friend class _coalesced_group_data_access;

public:
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline void sync() const;
    // Static queries; num_threads() and size() are constexpr, presumably both yielding Size —
    // confirm against the implementation.
    __SIMT_DEVICE_FUNCTIONS_DECL__ static inline unsigned long long thread_rank();
    __SIMT_DEVICE_FUNCTIONS_DECL__ static constexpr inline unsigned long long num_threads();
    __SIMT_DEVICE_FUNCTIONS_DECL__ static constexpr inline unsigned long long size();

    // Shuffle-style exchanges templated on the value type. src_rank/delta/lane_mask are
    // presumably relative to the tile, not to the full 32-lane unit — TODO confirm.
    template <typename T>
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline T shfl(T var, int src_rank) const;

    template <typename T>
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline T shfl_up(T var, unsigned int delta) const;

    template <typename T>
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline T shfl_down(T var, unsigned int delta) const;

    template <typename T>
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline T shfl_xor(T var, unsigned int lane_mask) const;

    // Vote-style queries over an int predicate.
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline int any(int predicate) const;
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline int all(int predicate) const;
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline unsigned int ballot(int predicate) const;

protected:
    // build_mask() is declared here and defined elsewhere; get_mask() simply forwards to it.
    __SIMT_DEVICE_FUNCTIONS_DECL__ static inline unsigned int build_mask();
    __SIMT_DEVICE_FUNCTIONS_DECL__ static inline unsigned int get_mask() { return build_mask(); }
};

// Primary template, selected whenever ParentT is not void (the void case has its own partial
// specialization). It combines the Size-only operations from thread_block_tile_base with the
// meta_group_rank()/meta_group_size() queries from _static_parent_thread_block_tile_base.
// Unlike the void specialization it does not derive from tiled_group.
template <unsigned int Size, typename ParentT = void>
class thread_block_tile_impl : public thread_block_tile_base<Size>,
                               public _static_parent_thread_block_tile_base<Size, ParentT> {
public:
    // Re-export the static queries of thread_block_tile_base as members of this class.
    using thread_block_tile_base<Size>::thread_rank;
    using thread_block_tile_base<Size>::num_threads;
    using thread_block_tile_base<Size>::size;

protected:
    // Constructible only from derived classes. The two-argument overload leaves its parameters
    // unnamed; presumably it exists for signature parity with the void specialization, which
    // takes (meta_group_rank, meta_group_size) — confirm.
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline thread_block_tile_impl();
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline thread_block_tile_impl(unsigned int, unsigned int);
};

// Partial specialization for ParentT == void, i.e. a tile whose parent type is not part of its
// static type. It derives from both thread_block_tile_base<Size> and tiled_group (and therefore
// from thread_group), and declares its own meta_group_rank()/meta_group_size().
template <unsigned int Size>
class thread_block_tile_impl<Size, void> : public thread_block_tile_base<Size>, public tiled_group {
public:
    // Both bases declare sync/thread_rank/num_threads/size; these using-declarations select the
    // thread_block_tile_base versions so that the names are unambiguous.
    using thread_block_tile_base<Size>::sync;
    using thread_block_tile_base<Size>::thread_rank;
    using thread_block_tile_base<Size>::num_threads;
    using thread_block_tile_base<Size>::size;
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline unsigned int meta_group_rank() const;
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline unsigned int meta_group_size() const;

protected:
    // Doubles as the default constructor: rank defaults to 0 and size to 1.
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline thread_block_tile_impl(
        unsigned int meta_group_rank = 0, unsigned int meta_group_size = 1);
};

// Public tile type for a statically known parent (ParentT defaults to void, which is handled by
// the partial specialization that follows).
template <unsigned int Size, typename ParentT = void>
class thread_block_tile : public thread_block_tile_impl<Size, ParentT> {
public:
    // Non-explicit, so a ParentT object converts implicitly to a tile of it.
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline thread_block_tile(const ParentT& g);
    // Implicit conversion to the parent-erased form thread_block_tile<Size, void>.
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline operator thread_block_tile<Size, void>() const;
};

// Parent-erased tile: thread_block_tile<Size> with ParentT == void.
template <unsigned int Size>
class thread_block_tile<Size, void> : public thread_block_tile_impl<Size, void> {
    // Every thread_block_tile specialization is a friend, so tiles of other sizes and parents
    // can use the protected constructor below.
    template <unsigned int, typename ParentT>
    friend class thread_block_tile;

public:
    // Public, non-explicit: builds the erased tile from a same-Size tile with any parent type.
    template <typename ParentT>
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline thread_block_tile(const thread_block_tile<Size, ParentT>& g);

protected:
    // Protected: accepts a tile of any size and parent; reachable only by friends and derived
    // classes. NOTE(review): presumably used when partitioning a larger tile — confirm.
    template <unsigned int OtherSize, typename OtherParentT>
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline thread_block_tile(const thread_block_tile<OtherSize, OtherParentT>& g);
};

// Primary template is declared but never defined in this header; only the partial
// specializations below (parent is thread_block, or parent is a thread_block_tile) are usable.
template <unsigned int Size, typename ParentT>
struct tiled_partition_impl;

// Partition helper for a thread_block parent: it is-a thread_block_tile<Size, thread_block>
// and is constructed from the parent block.
template <unsigned int Size>
struct tiled_partition_impl<Size, thread_block> : public thread_block_tile<Size, thread_block> {
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline tiled_partition_impl(const thread_block& g);
};

// Partition helper for a parent that is itself a thread_block_tile. The new tile must be
// strictly smaller than the parent tile; equal or larger sizes fail to compile.
template <unsigned int Size, unsigned int ParentSize, typename GrandParent>
struct tiled_partition_impl<Size, thread_block_tile<ParentSize, GrandParent>>
    : public thread_block_tile<Size, thread_block_tile<ParentSize, GrandParent>> {
    static_assert(Size < ParentSize, "Tile size bigger or equal to the parent group size");
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline tiled_partition_impl(const thread_block_tile<ParentSize, GrandParent>& g);
};

// Compile-time partition: Size must be given explicitly (it does not appear in the parameter
// list), while ParentT can be deduced from the argument. Returns a tile typed on its parent.
template <unsigned int Size, typename ParentT>
__SIMT_DEVICE_FUNCTIONS_DECL__ inline thread_block_tile<Size, ParentT> tiled_partition(const ParentT& g);

// Run-time partitions: the tile size is an ordinary argument. The thread_group and thread_block
// overloads return a thread_group by value; the coalesced_group overload returns a
// coalesced_group. Which tilesz values are accepted is decided by the implementation (not
// visible here).
__SIMT_DEVICE_FUNCTIONS_DECL__ inline thread_group tiled_partition(const thread_group& parent, unsigned int tilesz);
__SIMT_DEVICE_FUNCTIONS_DECL__ inline thread_group tiled_partition(const thread_block& parent, unsigned int tilesz);
__SIMT_DEVICE_FUNCTIONS_DECL__ inline coalesced_group tiled_partition(
    const coalesced_group& parent, unsigned int tilesz);

// binary_partition takes a group and a per-thread bool and returns a coalesced_group;
// presumably threads are grouped by the value of pred — confirm against the implementation.
__SIMT_DEVICE_FUNCTIONS_DECL__ inline coalesced_group binary_partition(const coalesced_group& g, bool pred);

template <unsigned int Size, typename ParentT>
__SIMT_DEVICE_FUNCTIONS_DECL__ inline coalesced_group binary_partition(
    const thread_block_tile<Size, ParentT>& g, bool pred);

} // namespace cooperative_groups

#include "impl/simt_api/cooperative_groups_impl.h"

#if defined(__UNDEF_ASCENDC_INCLUDE_INTERNAL_HEADERS_COOPERATIVE_GROUPS_H__)
#undef __ASCENDC_INCLUDE_INTERNAL_HEADERS__
#undef __UNDEF_ASCENDC_INCLUDE_INTERNAL_HEADERS_COOPERATIVE_GROUPS_H__
#endif

#endif