* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License")
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file select_simt.h
* \brief SIMT kernel implementation of the Select operator (element-wise y = condition ? x1 : x2).
*/
#ifndef SELECT_SIMT_H
#define SELECT_SIMT_H
#include "kernel_operator.h"
#include "select_struct.h"
#ifdef __CCE_AICORE__
#include "simt_api/asc_simt.h"
#endif
namespace SelectOp {
using namespace AscendC;
// Thread-count bound handed to LAUNCH_BOUND(...) on the SIMT entry point OpSelectSimt.
constexpr uint32_t THREAD_NUM_LAUNCH_BOUND{2048};
/**
 * Kernel-side helper that performs an element-wise select
 * (y = condition ? x1 : x2) through a SIMT vector function.
 *
 * Usage: call Init() once to bind the global-memory buffers and the tiling
 * data, then call Process() to launch the SIMT function for the current core.
 *
 * @tparam T element type of x1, x2 and y. The condition buffer is always
 *           read as uint8_t.
 */
template <typename T>
class SelectSimt {
public:
    __aicore__ inline SelectSimt() {}

    // Binds the four global-memory buffers and stores the tiling pointer.
    __aicore__ inline void Init(GM_ADDR condition, GM_ADDR x1, GM_ADDR x2, GM_ADDR y,
                                const SelectSimtTilingData* tilingData);

    // Derives this core's element range from the tiling data and launches OpSelectSimt.
    __aicore__ inline void Process();

private:
    // SIMT entry point: every thread walks this core's element range with a
    // stride of blockDim.x and writes the selected value into y.
    static __simt_vf__ __aicore__ LAUNCH_BOUND(THREAD_NUM_LAUNCH_BOUND) inline void OpSelectSimt(
        int32_t needCoreNum, int32_t threadNum, int64_t aSize, uint64_t bSize, int64_t currentCoreElements,
        uint64_t m, uint64_t shift, uint64_t xyIndexBase, __gm__ uint8_t* condition, __gm__ T* x1,
        __gm__ T* x2, __gm__ T* y);

    // Global-memory views set up by Init().
    GlobalTensor<uint8_t> conditionGm_;
    GlobalTensor<T> x1Gm_;
    GlobalTensor<T> x2Gm_;
    GlobalTensor<T> yGm_;

    // Non-owning pointer to the tiling data; stays nullptr until Init() is called.
    const SelectSimtTilingData* tilingData_ = nullptr;
};
/**
 * SIMT body of the select: for each element handled by this core, writes
 * y[i] = condition[c] ? x1[i] : x2[i], where i = xyIndexBase + local index and
 * c = Simt::UintDiv(i, m, shift).
 *
 * Each thread starts at threadIdx.x and advances by blockDim.x until it has
 * passed currentCoreElements.
 *
 * @param needCoreNum          not read by this function.
 * @param threadNum            not read by this function.
 * @param aSize                not read by this function.
 * @param bSize                not read by this function (the divisor is carried by m/shift).
 * @param currentCoreElements  number of elements this core processes; values <= 0 mean no work.
 * @param m                    magic multiplier forwarded to Simt::UintDiv.
 * @param shift                shift amount forwarded to Simt::UintDiv.
 * @param xyIndexBase          offset of this core's first element in x1/x2/y.
 * @param condition            condition buffer, one uint8_t per condition entry.
 * @param x1                   values taken where the condition is non-zero.
 * @param x2                   values taken where the condition is zero.
 * @param y                    output buffer.
 *
 * NOTE(review): m/shift are produced by GetUintDivMagicAndShift(m, shift, bSize)
 * in Process(), so UintDiv presumably evaluates xyIndex / bSize - confirm
 * against the Simt API.
 */
template <typename T>
__simt_vf__ __aicore__ LAUNCH_BOUND(THREAD_NUM_LAUNCH_BOUND) inline void SelectSimt<T>::OpSelectSimt(int32_t needCoreNum,
    int32_t threadNum, int64_t aSize, uint64_t bSize, int64_t currentCoreElements, uint64_t m, uint64_t shift,
    uint64_t xyIndexBase, __gm__ uint8_t* condition,__gm__ T* x1, __gm__ T* x2, __gm__ T* y) {
    // Convert the signed count exactly once. Comparing a uint64_t index with a
    // negative int64_t would implicitly turn the count into a huge unsigned
    // value and let the loop run far out of bounds.
    const uint64_t elementCount =
        currentCoreElements > 0 ? static_cast<uint64_t>(currentCoreElements) : static_cast<uint64_t>(0);
    // Loop-invariant stride: one step per thread of the launch.
    const uint64_t stride = static_cast<uint64_t>(static_cast<uint32_t>(blockDim.x));

    for (uint64_t index = static_cast<uint64_t>(threadIdx.x); index < elementCount; index += stride) {
        const uint64_t xyIndex = xyIndexBase + index;
        const uint64_t conditionIdx = Simt::UintDiv(xyIndex, m, shift);
        y[xyIndex] = condition[conditionIdx] ? x1[xyIndex] : x2[xyIndex];
    }
}
/**
 * Binds the operator's global-memory buffers and remembers the tiling data.
 *
 * @param condition   global address of the condition buffer, viewed as uint8_t.
 * @param x1          global address of the first input, viewed as T.
 * @param x2          global address of the second input, viewed as T.
 * @param y           global address of the output, viewed as T.
 * @param tilingData  tiling description; only the pointer is stored, so it
 *                    must remain valid until Process() has finished.
 */
template <typename T>
__aicore__ inline void SelectSimt<T>::Init(GM_ADDR condition, GM_ADDR x1, GM_ADDR x2,
                                           GM_ADDR y, const SelectSimtTilingData* tilingData)
{
    // Inputs and output share element type T.
    x1Gm_.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(x1));
    x2Gm_.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(x2));
    yGm_.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(y));

    // The condition is always accessed byte-wise.
    conditionGm_.SetGlobalBuffer(reinterpret_cast<__gm__ uint8_t*>(condition));

    tilingData_ = tilingData;
}
/**
 * Launches the SIMT select for the slice of elements owned by the current core.
 *
 * Reads needCoreNum, threadNum, aSize, bSize, perCoreElements and
 * lastCoreElements from the tiling data. Core number needCoreNum - 1 processes
 * lastCoreElements elements, every other core processes perCoreElements.
 * The slice of core k starts at element k * perCoreElements.
 *
 * Precondition: Init() has been called; tilingData_ is dereferenced
 * unconditionally.
 */
template <typename T>
__aicore__ inline void SelectSimt<T>::Process() {
    const int32_t blockIdx = static_cast<int32_t>(GetBlockIdx());
    const int32_t needCoreNum = static_cast<int32_t>(tilingData_->needCoreNum);

    // A core outside the tiled range has no slice of its own; without this
    // guard it would process perCoreElements elements beyond the intended range.
    // NOTE(review): harmless if the launch already uses exactly needCoreNum cores.
    if (blockIdx >= needCoreNum) {
        return;
    }

    const int32_t threadNum = static_cast<int32_t>(tilingData_->threadNum);
    const int64_t aSize = static_cast<int64_t>(tilingData_->aSize);
    const uint64_t bSize = static_cast<uint64_t>(tilingData_->bSize);

    // The last participating core gets the (possibly shorter) tail.
    const bool isLastCore = (blockIdx == needCoreNum - 1);
    const int64_t currentCoreElements = isLastCore ? static_cast<int64_t>(tilingData_->lastCoreElements)
                                                   : static_cast<int64_t>(tilingData_->perCoreElements);

    // Widen both factors before multiplying so the product is computed in
    // 64 bits regardless of the width of the tiling field.
    const uint64_t xyIndexBase =
        static_cast<uint64_t>(blockIdx) * static_cast<uint64_t>(tilingData_->perCoreElements);

    // Magic multiplier / shift pair for bSize, consumed by Simt::UintDiv in OpSelectSimt.
    uint64_t m {0};
    uint64_t shift {0};
    GetUintDivMagicAndShift(m, shift, bSize);

    asc_vf_call<OpSelectSimt>(dim3(threadNum), needCoreNum, threadNum, aSize, bSize, currentCoreElements, m, shift,
        xyIndexBase, (__gm__ uint8_t*) (conditionGm_.GetPhyAddr()), (__gm__ T*) (x1Gm_.GetPhyAddr()),
        (__gm__ T*) (x2Gm_.GetPhyAddr()), (__gm__ T*) (yGm_.GetPhyAddr()));
}
}
#endif