* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef ASCENDC_MODULE_SIMT_WARP_LEVEL_IMPL_H
#define ASCENDC_MODULE_SIMT_WARP_LEVEL_IMPL_H
#if defined(ASCENDC_CPU_DEBUG)
#include <cmath>
#include "kernel_utils.h"
#include "simt_api/cpp/kernel_simt_utils.h"
#include "stub_def.h"
#endif
#include "kernel_simt_common_impl.h"
namespace AscendC {
namespace Simt {
#if defined(ASCENDC_CPU_DEBUG)
constexpr int32_t MAX_SHLF_OFFSET = 31;
constexpr int32_t WARP_SIZE = 32;
#else
constexpr int32_t MAX_OFFSET_OF_MODE = 0x1f;
constexpr int32_t MAX_OFFSET_OF_UP_MODE = 0;
constexpr int32_t MAX_OFFSET_START_POS = 8;
constexpr int32_t LANE_MASK_START_POS = 16;
#endif
__aicore__ inline int32_t AllSyncImpl(int32_t predicate)
{
#if defined(ASCENDC_CPU_DEBUG)
uint32_t warpId = GetWarpId();
Warp &warp = ThreadBlock::GetBlockInstance().warps_[warpId];
auto customFunc = [](int32_t a, int32_t b) -> int32_t { return (a != 0) && (b != 0); };
return warp.WarpOp(predicate, customFunc);
#else
return bisheng::cce::simt::__all(predicate);
#endif
}
__aicore__ inline int32_t AnySyncImpl(int32_t predicate)
{
#if defined(ASCENDC_CPU_DEBUG)
uint32_t warpId = GetWarpId();
Warp &warp = ThreadBlock::GetBlockInstance().warps_[warpId];
auto customFunc = [](int32_t a, int32_t b) -> int32_t { return (a != 0) || (b != 0); };
return warp.WarpOp(predicate, customFunc);
#else
return bisheng::cce::simt::__any(predicate);
#endif
}
__aicore__ inline int32_t UniImpl(int32_t predicate)
{
#if defined(ASCENDC_CPU_DEBUG)
ASCENDC_ASSERT(false, { KERNEL_LOG(KERNEL_ERROR, "Uni CPU: not supported yet"); });
return 0;
#else
return bisheng::cce::simt::__uni(predicate);
#endif
}
__aicore__ inline uint32_t BallotSyncImpl(int32_t predicate)
{
#if defined(ASCENDC_CPU_DEBUG)
uint32_t warpId = GetWarpId();
uint32_t laneId = GetLaneId();
uint32_t one = 1;
uint32_t bitNum = 0;
if (predicate != 0) {
bitNum = one << laneId;
}
Warp &warp = ThreadBlock::GetBlockInstance().warps_[warpId];
auto customFunc = [](uint32_t a, uint32_t b) -> uint32_t { return a + b; };
return warp.WarpOp(bitNum, customFunc);
#else
return bisheng::cce::simt::__ballot(predicate);
#endif
}
__aicore__ inline uint32_t ActiveMaskImpl()
{
#if defined(ASCENDC_CPU_DEBUG)
uint32_t warpId = GetWarpId();
uint32_t laneId = GetLaneId();
uint32_t one = 1;
uint32_t bitNum = one << laneId;
Warp &warp = ThreadBlock::GetBlockInstance().warps_[warpId];
auto customFunc = [](uint32_t a, uint32_t b) -> uint32_t { return a + b; };
return warp.WarpOp(bitNum, customFunc);
#else
return bisheng::cce::simt::__activemask();
#endif
}
template <typename T>
__aicore__ inline T ShflSyncImpl(T var, int32_t srcLane, int32_t width = warpSize)
{
static_assert(SupportType<T, int32_t, uint32_t, int64_t, uint64_t, half, half2, float>(),
"Input type of var only supports int32_t, uint32_t, int64_t, uint64_t, half, half2, float.");
#if defined(ASCENDC_CPU_DEBUG)
ASCENDC_ASSERT((width <= WARP_SIZE && width > 0 && WARP_SIZE % width == 0),
{ KERNEL_LOG(KERNEL_ERROR, "width must be a power of 2 and in (0, 32]"); });
ASCENDC_ASSERT((srcLane < WARP_SIZE && srcLane >= 0), { KERNEL_LOG(KERNEL_ERROR, "srcLane must be in [0, 32)"); });
uint32_t warpId = GetWarpId();
uint32_t laneId = GetLaneId();
Warp &warp = ThreadBlock::GetBlockInstance().warps_[warpId];
int32_t minLane = laneId / width * width;
srcLane = minLane + srcLane % width;
return warp.WarpShuffleOp(var, laneId, srcLane);
#else
return bisheng::cce::simt::__shfl(var,
((warpSize - width) << LANE_MASK_START_POS) | (MAX_OFFSET_OF_MODE << MAX_OFFSET_START_POS) | (srcLane));
#endif
}
template <typename T>
__aicore__ inline T ShflUpSyncImpl(T var, uint32_t delta, int32_t width = warpSize)
{
static_assert(SupportType<T, int32_t, uint32_t, int64_t, uint64_t, half, half2, float>(),
"Input type of var only supports int32_t, uint32_t, int64_t, uint64_t, half, half2, float.");
#if defined(ASCENDC_CPU_DEBUG)
ASCENDC_ASSERT((width <= WARP_SIZE && width > 0 && WARP_SIZE % width == 0),
{ KERNEL_LOG(KERNEL_ERROR, "width must be a power of 2 and in (0, 32]"); });
ASCENDC_ASSERT((delta < WARP_SIZE && delta >= 0), { KERNEL_LOG(KERNEL_ERROR, "delta must be in [0, 32)"); });
uint32_t warpId = GetWarpId();
uint32_t laneId = GetLaneId();
Warp &warp = ThreadBlock::GetBlockInstance().warps_[warpId];
int32_t minLane = laneId / width * width;
int32_t srcLane = laneId - delta;
if (srcLane < minLane) {
srcLane = laneId;
}
return warp.WarpShuffleOp(var, laneId, srcLane);
#else
return bisheng::cce::simt::__shfl_up(var,
((warpSize - width) << LANE_MASK_START_POS) | (MAX_OFFSET_OF_UP_MODE << MAX_OFFSET_START_POS) | (delta));
#endif
}
template <typename T>
__aicore__ inline T ShflDownSyncImpl(T var, uint32_t delta, int32_t width = warpSize)
{
static_assert(SupportType<T, int32_t, uint32_t, int64_t, uint64_t, half, half2, float>(),
"Input type of var only supports int32_t, uint32_t, int64_t, uint64_t, half, half2, float.");
#if defined(ASCENDC_CPU_DEBUG)
ASCENDC_ASSERT((width <= WARP_SIZE && width > 0 && WARP_SIZE % width == 0),
{ KERNEL_LOG(KERNEL_ERROR, "width must be a power of 2 and in (0, 32]"); });
ASCENDC_ASSERT((delta < WARP_SIZE && delta >= 0), { KERNEL_LOG(KERNEL_ERROR, "delta must be in [0, 32)"); });
uint32_t warpId = GetWarpId();
uint32_t laneId = GetLaneId();
Warp &warp = ThreadBlock::GetBlockInstance().warps_[warpId];
int32_t minLane = laneId / width * width;
int32_t maxLane = minLane + MAX_SHLF_OFFSET % width;
int32_t srcLane = laneId + delta;
if (srcLane > maxLane) {
srcLane = laneId;
}
return warp.WarpShuffleOp(var, laneId, srcLane);
#else
return bisheng::cce::simt::__shfl_down(var,
((warpSize - width) << LANE_MASK_START_POS) | (MAX_OFFSET_OF_MODE << MAX_OFFSET_START_POS) | (delta));
#endif
}
template <typename T>
__aicore__ inline T ShflXorSyncImpl(T var, int32_t laneMask, int32_t width = warpSize)
{
static_assert(SupportType<T, int32_t, uint32_t, int64_t, uint64_t, half, half2, float>(),
"Input type of var only supports int32_t, uint32_t, int64_t, uint64_t, half, half2, float.");
#if defined(ASCENDC_CPU_DEBUG)
ASCENDC_ASSERT((width <= WARP_SIZE && width > 0 && WARP_SIZE % width == 0),
{ KERNEL_LOG(KERNEL_ERROR, "width must be a power of 2 and in (0, 32]"); });
ASCENDC_ASSERT((laneMask < WARP_SIZE && laneMask >= 0),
{ KERNEL_LOG(KERNEL_ERROR, "laneMask must be in [0, 32)"); });
uint32_t warpId = GetWarpId();
uint32_t laneId = GetLaneId();
Warp &warp = ThreadBlock::GetBlockInstance().warps_[warpId];
int32_t minLane = laneId / width * width;
int32_t maxLane = minLane + MAX_SHLF_OFFSET % width;
int32_t srcLane = laneId ^ laneMask;
if (srcLane < 0 || srcLane > maxLane) {
srcLane = laneId;
}
return warp.WarpShuffleOp(var, laneId, srcLane);
#else
return bisheng::cce::simt::__shfl_xor(var,
((warpSize - width) << LANE_MASK_START_POS) | (MAX_OFFSET_OF_MODE << MAX_OFFSET_START_POS) | (laneMask));
#endif
}
template <typename T>
__aicore__ inline T ReduceAddSyncImpl(T val)
{
static_assert(SupportType<T, int32_t, uint32_t, half, float>(),
"Input type of val only supports int32_t, uint32_t, half, float.");
#if defined(ASCENDC_CPU_DEBUG)
uint32_t warpId = GetWarpId();
Warp &warp = ThreadBlock::GetBlockInstance().warps_[warpId];
auto customFunc = [](T a, T b) -> T { return a + b; };
return warp.WarpOp(val, customFunc);
#else
return bisheng::cce::simt::__reduce_add(val);
#endif
}
template <typename T>
__aicore__ inline T ReduceMaxSyncImpl(T val)
{
static_assert(SupportType<T, int32_t, uint32_t, half, float>(),
"Input type of val only supports int32_t, uint32_t, half, float.");
#if defined(ASCENDC_CPU_DEBUG)
uint32_t warpId = GetWarpId();
Warp &warp = ThreadBlock::GetBlockInstance().warps_[warpId];
auto customFunc = [](T a, T b) -> T { return a > b ? a : b; };
return warp.WarpOp(val, customFunc);
#else
return bisheng::cce::simt::__reduce_max(val);
#endif
}
template <typename T>
__aicore__ inline T ReduceMinSyncImpl(T val)
{
static_assert(SupportType<T, int32_t, uint32_t, half, float>(),
"Input type of val only supports int32_t, uint32_t, half, float.");
#if defined(ASCENDC_CPU_DEBUG)
uint32_t warpId = GetWarpId();
Warp &warp = ThreadBlock::GetBlockInstance().warps_[warpId];
auto customFunc = [](T a, T b) -> T { return a < b ? a : b; };
return warp.WarpOp(val, customFunc);
#else
return bisheng::cce::simt::__reduce_min(val);
#endif
}
__aicore__ inline void ThreadBarrierImpl()
{
#if defined(ASCENDC_CPU_DEBUG)
Sync();
#else
__sync_workitems();
#endif
}
__aicore__ inline void ThreadFenceImpl()
{
#if defined(ASCENDC_CPU_DEBUG)
#else
bisheng::cce::simt::__threadfence();
#endif
}
}
}
#endif