/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/

#ifndef ASCENDC_MODULE_SIMT_WARP_LEVEL_IMPL_H
#define ASCENDC_MODULE_SIMT_WARP_LEVEL_IMPL_H

#if defined(ASCENDC_CPU_DEBUG)
#include <cmath>

#include "kernel_utils.h"
#include "simt_api/cpp/kernel_simt_utils.h"
#include "stub_def.h"
#endif
#include "kernel_simt_common_impl.h"
namespace AscendC {
namespace Simt {

#if defined(ASCENDC_CPU_DEBUG)
constexpr int32_t MAX_SHLF_OFFSET = 31;
constexpr int32_t WARP_SIZE = 32;
#else
// clamp the max source lane, except up mode
constexpr int32_t MAX_OFFSET_OF_MODE = 0x1f;

// clamp the max source lane for up mode
constexpr int32_t MAX_OFFSET_OF_UP_MODE = 0;

// the start position of max offset and lane mask
constexpr int32_t MAX_OFFSET_START_POS = 8;
constexpr int32_t LANE_MASK_START_POS = 16;
#endif

__aicore__ inline int32_t AllSyncImpl(int32_t predicate)
{
#if defined(ASCENDC_CPU_DEBUG)
    uint32_t warpId = GetWarpId();
    Warp &warp = ThreadBlock::GetBlockInstance().warps_[warpId];
    auto customFunc = [](int32_t a, int32_t b) -> int32_t { return (a != 0) && (b != 0); };
    return warp.WarpOp(predicate, customFunc);
#else
    return bisheng::cce::simt::__all(predicate);
#endif
}

__aicore__ inline int32_t AnySyncImpl(int32_t predicate)
{
#if defined(ASCENDC_CPU_DEBUG)
    uint32_t warpId = GetWarpId();
    Warp &warp = ThreadBlock::GetBlockInstance().warps_[warpId];
    auto customFunc = [](int32_t a, int32_t b) -> int32_t { return (a != 0) || (b != 0); };
    return warp.WarpOp(predicate, customFunc);
#else
    return bisheng::cce::simt::__any(predicate);
#endif
}

__aicore__ inline int32_t UniImpl(int32_t predicate)
{
#if defined(ASCENDC_CPU_DEBUG)
    ASCENDC_ASSERT(false, { KERNEL_LOG(KERNEL_ERROR, "Uni CPU: not supported yet"); });
    return 0;
#else
    return bisheng::cce::simt::__uni(predicate);
#endif
}

__aicore__ inline uint32_t BallotSyncImpl(int32_t predicate)
{
#if defined(ASCENDC_CPU_DEBUG)
    uint32_t warpId = GetWarpId();
    uint32_t laneId = GetLaneId();
    uint32_t one = 1;
    uint32_t bitNum = 0;
    if (predicate != 0) {
        bitNum = one << laneId;
    }
    Warp &warp = ThreadBlock::GetBlockInstance().warps_[warpId];
    auto customFunc = [](uint32_t a, uint32_t b) -> uint32_t { return a + b; };
    return warp.WarpOp(bitNum, customFunc);
#else
    return bisheng::cce::simt::__ballot(predicate);
#endif
}

__aicore__ inline uint32_t ActiveMaskImpl()
{
#if defined(ASCENDC_CPU_DEBUG)
    uint32_t warpId = GetWarpId();
    uint32_t laneId = GetLaneId();
    uint32_t one = 1;
    uint32_t bitNum = one << laneId;
    Warp &warp = ThreadBlock::GetBlockInstance().warps_[warpId];
    auto customFunc = [](uint32_t a, uint32_t b) -> uint32_t { return a + b; };
    return warp.WarpOp(bitNum, customFunc);
#else
    return bisheng::cce::simt::__activemask();
#endif
}

template <typename T>
__aicore__ inline T ShflSyncImpl(T var, int32_t srcLane, int32_t width = warpSize)
{
    static_assert(SupportType<T, int32_t, uint32_t, int64_t, uint64_t, half, half2, float>(),
                 "Input type of var only supports int32_t, uint32_t, int64_t, uint64_t, half, half2, float.");
#if defined(ASCENDC_CPU_DEBUG)
    ASCENDC_ASSERT((width <= WARP_SIZE && width > 0 && WARP_SIZE % width == 0),
                   { KERNEL_LOG(KERNEL_ERROR, "width must be a power of 2 and in (0, 32]"); });
    ASCENDC_ASSERT((srcLane < WARP_SIZE && srcLane >= 0), { KERNEL_LOG(KERNEL_ERROR, "srcLane must be in [0, 32)"); });

    uint32_t warpId = GetWarpId();
    uint32_t laneId = GetLaneId();
    Warp &warp = ThreadBlock::GetBlockInstance().warps_[warpId];
    int32_t minLane = laneId / width * width;
    srcLane = minLane + srcLane % width;

    return warp.WarpShuffleOp(var, laneId, srcLane);
#else
    return bisheng::cce::simt::__shfl(var,
        ((warpSize - width) << LANE_MASK_START_POS) | (MAX_OFFSET_OF_MODE << MAX_OFFSET_START_POS) | (srcLane));
#endif
}

template <typename T>
__aicore__ inline T ShflUpSyncImpl(T var, uint32_t delta, int32_t width = warpSize)
{
    static_assert(SupportType<T, int32_t, uint32_t, int64_t, uint64_t, half, half2, float>(),
              "Input type of var only supports int32_t, uint32_t, int64_t, uint64_t, half, half2, float.");
#if defined(ASCENDC_CPU_DEBUG)
    ASCENDC_ASSERT((width <= WARP_SIZE && width > 0 && WARP_SIZE % width == 0),
                   { KERNEL_LOG(KERNEL_ERROR, "width must be a power of 2 and in (0, 32]"); });
    ASCENDC_ASSERT((delta < WARP_SIZE && delta >= 0), { KERNEL_LOG(KERNEL_ERROR, "delta must be in [0, 32)"); });

    uint32_t warpId = GetWarpId();
    uint32_t laneId = GetLaneId();
    Warp &warp = ThreadBlock::GetBlockInstance().warps_[warpId];
    int32_t minLane = laneId / width * width;
    int32_t srcLane = laneId - delta;

    if (srcLane < minLane) {
        srcLane = laneId;
    }

    return warp.WarpShuffleOp(var, laneId, srcLane);
#else
    return bisheng::cce::simt::__shfl_up(var,
        ((warpSize - width) << LANE_MASK_START_POS) | (MAX_OFFSET_OF_UP_MODE << MAX_OFFSET_START_POS) | (delta));
#endif
}

template <typename T>
__aicore__ inline T ShflDownSyncImpl(T var, uint32_t delta, int32_t width = warpSize)
{
    static_assert(SupportType<T, int32_t, uint32_t, int64_t, uint64_t, half, half2, float>(),
              "Input type of var only supports int32_t, uint32_t, int64_t, uint64_t, half, half2, float.");
#if defined(ASCENDC_CPU_DEBUG)
    ASCENDC_ASSERT((width <= WARP_SIZE && width > 0 && WARP_SIZE % width == 0),
                   { KERNEL_LOG(KERNEL_ERROR, "width must be a power of 2 and in (0, 32]"); });
    ASCENDC_ASSERT((delta < WARP_SIZE && delta >= 0), { KERNEL_LOG(KERNEL_ERROR, "delta must be in [0, 32)"); });

    uint32_t warpId = GetWarpId();
    uint32_t laneId = GetLaneId();
    Warp &warp = ThreadBlock::GetBlockInstance().warps_[warpId];
    int32_t minLane = laneId / width * width;
    int32_t maxLane = minLane + MAX_SHLF_OFFSET % width;
    int32_t srcLane = laneId + delta;

    if (srcLane > maxLane) {
        srcLane = laneId;
    }

    return warp.WarpShuffleOp(var, laneId, srcLane);
#else
    return bisheng::cce::simt::__shfl_down(var,
                       ((warpSize - width) << LANE_MASK_START_POS) | (MAX_OFFSET_OF_MODE << MAX_OFFSET_START_POS) | (delta));
#endif
}

template <typename T>
__aicore__ inline T ShflXorSyncImpl(T var, int32_t laneMask, int32_t width = warpSize)
{
    static_assert(SupportType<T, int32_t, uint32_t, int64_t, uint64_t, half, half2, float>(),
              "Input type of var only supports int32_t, uint32_t, int64_t, uint64_t, half, half2, float.");
#if defined(ASCENDC_CPU_DEBUG)
    ASCENDC_ASSERT((width <= WARP_SIZE && width > 0 && WARP_SIZE % width == 0),
                   { KERNEL_LOG(KERNEL_ERROR, "width must be a power of 2 and in (0, 32]"); });
    ASCENDC_ASSERT((laneMask < WARP_SIZE && laneMask >= 0),
                   { KERNEL_LOG(KERNEL_ERROR, "laneMask must be in [0, 32)"); });

    uint32_t warpId = GetWarpId();
    uint32_t laneId = GetLaneId();
    Warp &warp = ThreadBlock::GetBlockInstance().warps_[warpId];
    int32_t minLane = laneId / width * width;
    int32_t maxLane = minLane + MAX_SHLF_OFFSET % width;
    int32_t srcLane = laneId ^ laneMask;

    if (srcLane < 0 || srcLane > maxLane) {
        srcLane = laneId;
    }

    return warp.WarpShuffleOp(var, laneId, srcLane);
#else
    return bisheng::cce::simt::__shfl_xor(var,
                      ((warpSize - width) << LANE_MASK_START_POS) | (MAX_OFFSET_OF_MODE << MAX_OFFSET_START_POS) | (laneMask));
#endif
}

template <typename T>
__aicore__ inline T ReduceAddSyncImpl(T val)
{
    static_assert(SupportType<T, int32_t, uint32_t, half, float>(),
              "Input type of val only supports int32_t, uint32_t, half, float.");
#if defined(ASCENDC_CPU_DEBUG)
    uint32_t warpId = GetWarpId();
    Warp &warp = ThreadBlock::GetBlockInstance().warps_[warpId];
    auto customFunc = [](T a, T b) -> T { return a + b; };
    return warp.WarpOp(val, customFunc);
#else
    return bisheng::cce::simt::__reduce_add(val);
#endif
}

template <typename T>
__aicore__ inline T ReduceMaxSyncImpl(T val)
{
    static_assert(SupportType<T, int32_t, uint32_t, half, float>(),
              "Input type of val only supports int32_t, uint32_t, half, float.");
#if defined(ASCENDC_CPU_DEBUG)
    uint32_t warpId = GetWarpId();
    Warp &warp = ThreadBlock::GetBlockInstance().warps_[warpId];
    auto customFunc = [](T a, T b) -> T { return a > b ? a : b; };
    return warp.WarpOp(val, customFunc);
#else
    return bisheng::cce::simt::__reduce_max(val);
#endif
}

template <typename T>
__aicore__ inline T ReduceMinSyncImpl(T val)
{
    static_assert(SupportType<T, int32_t, uint32_t, half, float>(),
                  "Input type of val only supports int32_t, uint32_t, half, float.");
#if defined(ASCENDC_CPU_DEBUG)
    uint32_t warpId = GetWarpId();
    Warp &warp = ThreadBlock::GetBlockInstance().warps_[warpId];
    auto customFunc = [](T a, T b) -> T { return a < b ? a : b; };
    return warp.WarpOp(val, customFunc);
#else
    return bisheng::cce::simt::__reduce_min(val);
#endif
}

__aicore__ inline void ThreadBarrierImpl()
{
#if defined(ASCENDC_CPU_DEBUG)
    Sync();
#else
    __sync_workitems();
#endif
}

__aicore__ inline void ThreadFenceImpl()
{
#if defined(ASCENDC_CPU_DEBUG)
#else
    bisheng::cce::simt::__threadfence();
#endif
}
}  // namespace Simt
}  // namespace AscendC
#endif  // ASCENDC_MODULE_SIMT_WARP_LEVEL_IMPL_H