* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef __ASCENDC_API_WHERE_H__
#define __ASCENDC_API_WHERE_H__
/**
 * Parameter bundle threaded through the Where/Select helpers in this header.
 *
 * The constructor zeroes every scalar field except src0_select_offset and src1_select_offset,
 * which start at 1. rpt_params and the LocalTensor members are left default-constructed.
 */
struct DoSelectParams {
  __aicore__ DoSelectParams()
      : do_size(0),
        mask(0),
        repeat_times(0),
        calc_size(0),
        src0_select_offset(1),
        src1_select_offset(1),
        src0_slice_offset(0),
        src1_slice_offset(0),
        output_stride(0),
        input0_stride(0),
        input1_stride(0),
        mask_stride(0),
        dst_offset(0),
        src0_offset(0),
        src1_offset(0),
        mask_offset(0) {}

  uint32_t do_size;                              // element count handled by one select call
  uint64_t mask;                                 // forwarded as the `mask` argument of Cast/Select
  uint8_t repeat_times;                          // forwarded as the repeat count of Cast/Select
  BinaryRepeatParams rpt_params;                 // block/repeat strides forwarded to Select
  uint32_t calc_size;                            // elements already processed (running offset)
  AscendC::LocalTensor<half> mask_cast_buf;      // scratch: mask converted to half, reused as bit mask
  AscendC::LocalTensor<int16_t> mask_shift_buf;  // scratch used by the int64 paths
  AscendC::LocalTensor<float> src0_cast_buf;     // scratch: src0 converted to float
  AscendC::LocalTensor<float> sel_res_buf;       // scratch: converted source and/or float select result
  uint32_t src0_select_offset;                   // multiplier on the src0 element offset (0 = always start at 0)
  uint32_t src1_select_offset;                   // multiplier on the src1 element offset (0 = always start at 0)
  uint32_t src0_slice_offset;
  uint32_t src1_slice_offset;
  uint32_t output_stride;                        // per-repeat dst stride used by the Normal-mode helpers
  uint32_t input0_stride;                        // per-repeat src0 stride used by the Normal-mode helpers
  uint32_t input1_stride;                        // per-repeat src1 stride used by the Normal-mode helpers
  uint32_t mask_stride;                          // per-repeat mask stride used by the Normal-mode helpers
  uint32_t dst_offset;                           // current dst element offset (Normal mode)
  uint32_t src0_offset;                          // current src0 element offset (Normal mode)
  uint32_t src1_offset;                          // current src1 element offset (Normal mode)
  uint32_t mask_offset;                          // current mask element offset (Normal mode)
};
/**
 * Casts `size` elements of `src` (type I) into `dst` (type O), picking the rounding mode from the
 * type pair: CAST_RINT for float -> int64/int32/int16, int64 -> float and half -> int16;
 * CAST_NONE for every other combination.
 */
template <typename O, typename I>
inline __aicore__ void SafeCast(const AscendC::LocalTensor<O> &dst,
                                const AscendC::LocalTensor<I> &src,
                                const uint32_t size) {
  constexpr bool float_to_integer =
      std::is_same<I, float>::value &&
      (std::is_same<O, int64_t>::value || std::is_same<O, int32_t>::value || std::is_same<O, int16_t>::value);
  constexpr bool int64_to_float = std::is_same<I, int64_t>::value && std::is_same<O, float>::value;
  constexpr bool half_to_int16 = std::is_same<I, half>::value && std::is_same<O, int16_t>::value;

  if constexpr (float_to_integer || int64_to_float || half_to_int16) {
    Cast(dst, src, RoundMode::CAST_RINT, size);
  } else {
    Cast(dst, src, RoundMode::CAST_NONE, size);
  }
}
/**
 * Runs one Select over `params.do_size` elements starting at element `params.calc_size`.
 *
 * The uint8 mask slice is cast to half into params.mask_cast_buf, compared for equality with 1.0,
 * and the comparison result is written over the same buffer (viewed as uint8) to serve as the
 * select mask. For float outputs the Select writes straight into `dst`; for any other output type
 * it writes into params.sel_res_buf and the result is then cast into `dst`.
 */
template <typename T>
inline __aicore__ void DoSelect(const AscendC::LocalTensor<T> &dst,
                                const AscendC::LocalTensor<uint8_t> &sel_mask,
                                const AscendC::LocalTensor<float> &src0,
                                const AscendC::LocalTensor<float> &src1,
                                const DoSelectParams &params) {
  const uint32_t base = params.calc_size;
  const uint32_t count = params.do_size;

  SafeCast(params.mask_cast_buf[0], sel_mask[base], count);

  // Round the compare length up to a whole number of half-precision repeats.
  const uint32_t cmp_size = (count + ONE_REPEAT_HALF_SIZE - 1) / ONE_REPEAT_HALF_SIZE * ONE_REPEAT_HALF_SIZE;
  LocalTensor<uint8_t> bit_mask = params.mask_cast_buf.ReinterpretCast<uint8_t>();
  CompareScalar(bit_mask, params.mask_cast_buf[0], static_cast<half>(1.0), CMPMODE::EQ, cmp_size);

  // A select-offset multiplier of 0 pins the source at element 0.
  const auto lhs = src0[params.src0_select_offset * base];
  const auto rhs = src1[params.src1_select_offset * base];

  if constexpr (std::is_same<T, float>::value) {
    Select(dst[base], bit_mask, lhs, rhs, SELMODE::VSEL_TENSOR_TENSOR_MODE,
           params.mask, params.repeat_times, params.rpt_params);
  } else {
    Select(params.sel_res_buf[0], bit_mask, lhs, rhs, SELMODE::VSEL_TENSOR_TENSOR_MODE,
           params.mask, params.repeat_times, params.rpt_params);
    SafeCast(dst[base], params.sel_res_buf[0], count);
  }
}
/**
 * Converts whichever of src0/src1 is not float into float scratch buffers, then calls DoSelect.
 *
 * A converted source is read from the start of its scratch buffer, so the matching
 * params.src*_select_offset is set to 0 before the select. When only one source needs conversion it
 * goes to params.sel_res_buf; when both do, src0 goes to params.src0_cast_buf and src1 to
 * params.sel_res_buf.
 */
template <typename O, typename I0, typename I1>
inline __aicore__ void CastBeforeSelect(const AscendC::LocalTensor<O> &dst,
                                        const AscendC::LocalTensor<uint8_t> &sel_mask,
                                        const AscendC::LocalTensor<I0> &src0,
                                        const AscendC::LocalTensor<I1> &src1,
                                        DoSelectParams &params) {
  constexpr bool lhs_is_f32 = std::is_same<I0, float>::value;
  constexpr bool rhs_is_f32 = std::is_same<I1, float>::value;

  if constexpr (lhs_is_f32 && rhs_is_f32) {
    DoSelect(dst, sel_mask, src0, src1, params);
  } else if constexpr (rhs_is_f32) {
    // Only src0 needs conversion.
    const uint32_t base = params.calc_size;
    SafeCast(params.sel_res_buf[0], src0[base], params.do_size);
    params.src0_select_offset = 0;
    DoSelect(dst, sel_mask, params.sel_res_buf, src1, params);
  } else if constexpr (lhs_is_f32) {
    // Only src1 needs conversion.
    const uint32_t base = params.calc_size;
    SafeCast(params.sel_res_buf[0], src1[base], params.do_size);
    params.src1_select_offset = 0;
    DoSelect(dst, sel_mask, src0, params.sel_res_buf, params);
  } else {
    // Both sources need conversion.
    const uint32_t base = params.calc_size;
    SafeCast(params.src0_cast_buf[0], src0[base], params.do_size);
    SafeCast(params.sel_res_buf[0], src1[base], params.do_size);
    params.src0_select_offset = 0;
    params.src1_select_offset = 0;
    DoSelect(dst, sel_mask, params.src0_cast_buf, params.sel_res_buf, params);
  }
}
/**
 * int64 select step: the int64 tensors are reinterpreted as float tensors and selected as such.
 *
 * params.calc_size / params.do_size are counted in reinterpreted (float-sized) elements, so the
 * mask offset and mask element count are half of them. Each mask byte m is expanded to
 * (m | m << 8) as int16; viewed as bytes this yields two copies of every mask byte, which are then
 * cast to half and compared with 1.0 to produce the select bit mask.
 */
inline __aicore__ void CastAndSelect(const AscendC::LocalTensor<int64_t> &dst,
                                     const AscendC::LocalTensor<uint8_t> &sel_mask,
                                     const AscendC::LocalTensor<int64_t> &src0,
                                     const AscendC::LocalTensor<int64_t> &src1,
                                     DoSelectParams &params) {
  const uint32_t word_offset = params.calc_size;
  const uint32_t word_count = params.do_size;
  const uint32_t flag_offset = word_offset >> 1;
  const uint32_t flag_count = word_count >> 1;

  const AscendC::LocalTensor<float> dst_words = dst.ReinterpretCast<float>();
  const AscendC::LocalTensor<float> src0_words = src0.ReinterpretCast<float>();
  const AscendC::LocalTensor<float> src1_words = src1.ReinterpretCast<float>();

  // uint8 -> half -> int16, then duplicate every mask byte via (m << 8) | m.
  SafeCast(params.mask_cast_buf[0], sel_mask[flag_offset], flag_count);
  SafeCast(params.mask_shift_buf[0], params.mask_cast_buf[0], flag_count);
  const AscendC::LocalTensor<int16_t> shifted = params.mask_cast_buf.ReinterpretCast<int16_t>();
  constexpr int16_t shift_left_size = 8;
  ShiftLeft(shifted, params.mask_shift_buf[0], shift_left_size, static_cast<int32_t>(flag_count));
  Or(params.mask_shift_buf[0], shifted, params.mask_shift_buf[0], static_cast<int32_t>(flag_count));

  // Cast the doubled byte stream back to half; the destination view spans both scratch buffers.
  const AscendC::LocalTensor<uint8_t> doubled_bytes = params.mask_shift_buf.ReinterpretCast<uint8_t>();
  AscendC::LocalTensor<half> wide_mask = params.mask_cast_buf.ReinterpretCast<half>();
  wide_mask.SetSize(params.mask_cast_buf.GetSize() + params.mask_shift_buf.GetSize());
  SafeCast(wide_mask, doubled_bytes, word_count);

  // Round the compare length up to a whole number of half-precision repeats.
  const uint32_t cmp_size =
      (word_count + ONE_REPEAT_HALF_SIZE - 1) / ONE_REPEAT_HALF_SIZE * ONE_REPEAT_HALF_SIZE;
  LocalTensor<uint8_t> bit_mask = wide_mask.ReinterpretCast<uint8_t>();
  CompareScalar(bit_mask, wide_mask, static_cast<half>(1.0), CMPMODE::EQ, cmp_size);

  Select(dst_words[word_offset],
         bit_mask,
         src0_words[params.src0_select_offset * word_offset],
         src1_words[params.src1_select_offset * word_offset],
         SELMODE::VSEL_TENSOR_TENSOR_MODE,
         params.mask,
         params.repeat_times,
         params.rpt_params);
}
/**
 * Drives CastBeforeSelect over `size` elements in three phases:
 *   1. full chunks of max_do_size elements (bounded by MAX_REPEAT_TIME and the sel_res_buf capacity);
 *   2. one call covering the remaining whole repeats;
 *   3. one single-repeat call for the tail shorter than a repeat, with the mask rounded up to
 *      whole blocks.
 * params.calc_size tracks how many elements have been issued so far.
 */
template <typename T, typename T1, typename T2>
inline __aicore__ void WhereBase(const AscendC::LocalTensor<T> &dst,
                                 const AscendC::LocalTensor<uint8_t> &sel_mask,
                                 const AscendC::LocalTensor<T1> &src0,
                                 const AscendC::LocalTensor<T2> &src1,
                                 const uint32_t size,
                                 DoSelectParams &params) {
  constexpr uint32_t ONE_RPT_SIZE = KernelUtils::RptSize<float>();
  params.rpt_params.blockNumber = KernelUtils::BlkNum<float>(size);
  params.mask = ONE_RPT_SIZE;

  uint32_t buf_elems = params.sel_res_buf.GetSize() / sizeof(float);
  uint32_t buf_rpts = KernelUtils::RptNum<float>(buf_elems);
  uint32_t chunk_rpts = KernelUtils::Min(MAX_REPEAT_TIME, buf_rpts);
  uint32_t chunk_elems = chunk_rpts * ONE_RPT_SIZE;

  // Phase 1: full-size chunks.
  // NOTE(review): if the scratch buffer holds less than one repeat, chunk_elems is 0 and this loop
  // never advances - presumably callers guarantee enough tmp space; confirm.
  if (chunk_elems <= size) {
    params.repeat_times = chunk_rpts;
    params.do_size = chunk_elems;
    while (params.calc_size + params.do_size < size) {
      CastBeforeSelect(dst, sel_mask, src0, src1, params);
      params.calc_size += params.do_size;
    }
  }

  // Phase 2: whatever whole repeats are left.
  if (params.calc_size + ONE_RPT_SIZE <= size) {
    uint32_t remaining_rpts = KernelUtils::RptNum<float>(size - params.calc_size);
    params.repeat_times = remaining_rpts;
    params.do_size = remaining_rpts * KernelUtils::RptSize<float>();
    CastBeforeSelect(dst, sel_mask, src0, src1, params);
    params.calc_size += params.do_size;
  }

  // Phase 3: sub-repeat tail, mask rounded up to whole blocks.
  if (params.calc_size < size) {
    constexpr uint32_t redundant_size = KernelUtils::BlkSize<float>() - 1;
    uint32_t tail_elems = size - params.calc_size;
    params.rpt_params.blockNumber = KernelUtils::BlkNum<float>(tail_elems + redundant_size);
    params.do_size = tail_elems;
    params.mask = params.rpt_params.blockNumber * KernelUtils::BlkSize<float>();
    params.repeat_times = 1;
    CastBeforeSelect(dst, sel_mask, src0, src1, params);
  }
}
/**
 * int64 overload of WhereBase. `size_` counts int64 elements; the work is issued in float-sized
 * units (twice as many), in the same three phases as the generic version: full chunks, remaining
 * whole repeats, then a sub-repeat tail. The chunk size is bounded by MAX_REPEAT_TIME and by the
 * element count of params.mask_shift_buf.
 */
inline __aicore__ void WhereBase(const AscendC::LocalTensor<int64_t> &dst,
                                 const AscendC::LocalTensor<uint8_t> &sel_mask,
                                 const AscendC::LocalTensor<int64_t> &src0,
                                 const AscendC::LocalTensor<int64_t> &src1,
                                 const uint32_t size_,
                                 DoSelectParams &params) {
  const uint32_t word_total = size_ << 1;
  constexpr uint32_t ONE_RPT_SIZE = KernelUtils::RptSize<float>();
  params.rpt_params.blockNumber = KernelUtils::BlkNum<float>(word_total);
  params.mask = ONE_RPT_SIZE;

  uint32_t buf_rpts = KernelUtils::RptNum<float>(params.mask_shift_buf.GetSize());
  uint32_t chunk_rpts = KernelUtils::Min(MAX_REPEAT_TIME, buf_rpts);
  uint32_t chunk_words = chunk_rpts * ONE_RPT_SIZE;

  // Phase 1: full-size chunks.
  // NOTE(review): if the scratch buffer holds less than one repeat, chunk_words is 0 and this loop
  // never advances - presumably callers guarantee enough tmp space; confirm.
  if (chunk_words <= word_total) {
    params.repeat_times = chunk_rpts;
    params.do_size = chunk_words;
    while (params.calc_size + params.do_size < word_total) {
      CastAndSelect(dst, sel_mask, src0, src1, params);
      params.calc_size += params.do_size;
    }
  }

  // Phase 2: whatever whole repeats are left.
  if (params.calc_size + ONE_RPT_SIZE <= word_total) {
    uint32_t remaining_rpts = KernelUtils::RptNum<float>(word_total - params.calc_size);
    params.repeat_times = remaining_rpts;
    params.do_size = remaining_rpts * KernelUtils::RptSize<float>();
    CastAndSelect(dst, sel_mask, src0, src1, params);
    params.calc_size += params.do_size;
  }

  // Phase 3: sub-repeat tail, mask rounded up to whole blocks.
  if (params.calc_size < word_total) {
    constexpr uint32_t redundant_size = KernelUtils::BlkSize<float>() - 1;
    uint32_t tail_words = word_total - params.calc_size;
    params.rpt_params.blockNumber = KernelUtils::BlkNum<float>(tail_words + redundant_size);
    params.do_size = tail_words;
    params.mask = params.rpt_params.blockNumber * KernelUtils::BlkSize<float>();
    params.repeat_times = 1;
    CastAndSelect(dst, sel_mask, src0, src1, params);
  }
}
/**
 * Carves two scratch views out of `tmp_buf`, skipping its first `used_size` bytes:
 * params.mask_cast_buf (half) followed by params.sel_res_buf (float).
 *
 * The unit size is the free byte count divided by (sizeof(half) + sizeof(float)), rounded down to a
 * multiple of ONE_REPEAT_BYTE_SIZE. Each view starts one unit after the previous one and has its
 * size set to that unit value.
 */
inline __aicore__ void WherePartition2Buffer(AscendC::LocalTensor<uint8_t> &tmp_buf, DoSelectParams &params,
                                             uint32_t used_size) {
  constexpr uint32_t partition_number = sizeof(half) + sizeof(float);
  const uint32_t free_bytes = tmp_buf.GetSize() - used_size;
  const uint32_t unit = free_bytes / partition_number / ONE_REPEAT_BYTE_SIZE * ONE_REPEAT_BYTE_SIZE;

  const uint32_t mask_begin = used_size;
  const uint32_t result_begin = used_size + unit;

  params.mask_cast_buf = tmp_buf[mask_begin].template ReinterpretCast<half>();
  params.mask_cast_buf.SetSize(unit);
  params.sel_res_buf = tmp_buf[result_begin].template ReinterpretCast<float>();
  params.sel_res_buf.SetSize(unit);
}
/**
 * Carves three scratch views out of `tmp_buf`, skipping its first `used_size` bytes:
 * params.mask_cast_buf (half), params.sel_res_buf (float) and params.src0_cast_buf (float).
 *
 * The unit size is the free byte count divided by (sizeof(half) + 2 * sizeof(float)), rounded down
 * to a multiple of ONE_REPEAT_BYTE_SIZE. Consecutive views start one unit apart and each has its
 * size set to that unit value.
 */
inline __aicore__ void WherePartition3Buffer(AscendC::LocalTensor<uint8_t> &tmp_buf, DoSelectParams &params,
                                             uint32_t used_size) {
  constexpr uint32_t partition_number = sizeof(half) + sizeof(float) + sizeof(float);
  const uint32_t free_bytes = tmp_buf.GetSize() - used_size;
  const uint32_t unit = free_bytes / partition_number / ONE_REPEAT_BYTE_SIZE * ONE_REPEAT_BYTE_SIZE;

  uint32_t cursor = used_size;
  params.mask_cast_buf = tmp_buf[cursor].template ReinterpretCast<half>();
  params.mask_cast_buf.SetSize(unit);

  cursor = used_size + unit;
  params.sel_res_buf = tmp_buf[cursor].template ReinterpretCast<float>();
  params.sel_res_buf.SetSize(unit);

  cursor += unit;
  params.src0_cast_buf = tmp_buf[cursor].template ReinterpretCast<float>();
  params.src0_cast_buf.SetSize(unit);
}
/**
 * Scratch layout for the int64 paths: params.mask_cast_buf (half) followed by
 * params.mask_shift_buf (int16), carved from `tmp_buf` after its first `used_size` bytes.
 *
 * The unit size is the free byte count divided by (sizeof(half) + sizeof(int16_t)), rounded down to
 * a multiple of ONE_REPEAT_BYTE_SIZE. Unlike the float partitions above, the second view starts
 * unit * sizeof(half) bytes after the first. Both views have their size set to the unit value.
 */
inline __aicore__ void WherePartitionBufferInt64(AscendC::LocalTensor<uint8_t> &tmp_buf, DoSelectParams &params,
                                                 uint32_t used_size) {
  constexpr uint32_t partition_number = sizeof(half) + sizeof(int16_t);
  const uint32_t free_bytes = tmp_buf.GetSize() - used_size;
  const uint32_t unit = free_bytes / partition_number / ONE_REPEAT_BYTE_SIZE * ONE_REPEAT_BYTE_SIZE;

  params.mask_cast_buf = tmp_buf[used_size].template ReinterpretCast<half>();
  params.mask_cast_buf.SetSize(unit);

  const uint32_t shift_begin = used_size + unit * sizeof(half);
  params.mask_shift_buf = tmp_buf[shift_begin].template ReinterpretCast<int16_t>();
  params.mask_shift_buf.SetSize(unit);
}
* 场景1: src0和src1都是标量,输出Shape与mask相同
*/
template <typename T>
inline __aicore__ void Where(const AscendC::LocalTensor<T> &dst, const AscendC::LocalTensor<uint8_t> &mask, T src0,
T src1, const uint32_t size, AscendC::LocalTensor<uint8_t> &tmp_buf) {
LocalTensor<float> src0_buf = tmp_buf[0].template ReinterpretCast<float>();
Duplicate(src0_buf[0], (float)src0, KernelUtils::BlkSize<float>());
LocalTensor<float> src1_buf = tmp_buf[ONE_BLK_SIZE].template ReinterpretCast<float>();
Duplicate(src1_buf[0], (float)src1, KernelUtils::BlkSize<float>());
DoSelectParams params;
WherePartition2Buffer(tmp_buf, params, ONE_BLK_SIZE * 2);
params.rpt_params.src0BlkStride = 0;
params.rpt_params.src0RepStride = 0;
params.rpt_params.src1BlkStride = 0;
params.rpt_params.src1RepStride = 0;
params.src0_select_offset = 0;
params.src1_select_offset = 0;
WhereBase(dst, mask, src0_buf, src1_buf, size, params);
}
* 场景2: src0是标量,src1及输出的Shape与mask相同
*/
template <typename T>
inline __aicore__ void Where(const AscendC::LocalTensor<T> &dst, const AscendC::LocalTensor<uint8_t> &mask, T src0,
const AscendC::LocalTensor<T> &src1, const uint32_t size,
AscendC::LocalTensor<uint8_t> &tmp_buf) {
LocalTensor<float> src0_buf = tmp_buf[0].template ReinterpretCast<float>();
Duplicate(src0_buf[0], (float)src0, KernelUtils::BlkSize<float>());
DoSelectParams params;
WherePartition2Buffer(tmp_buf, params, ONE_BLK_SIZE);
params.rpt_params.src0BlkStride = 0;
params.rpt_params.src0RepStride = 0;
params.src0_select_offset = 0;
WhereBase(dst, mask, src0_buf, src1, size, params);
}
/**
 * Convenience overload: int64 output and src1 with a float scalar src0. The scalar is converted to
 * int64_t and the call is forwarded to the scalar-src0 Where overload.
 *
 * NOTE(review): the non-template int64_t scalar overload is defined further down in this file, so
 * unless it is forward-declared elsewhere this call presumably binds to the generic template
 * overload declared above - confirm that is the intended path.
 */
inline __aicore__ void Where(const AscendC::LocalTensor<int64_t> &dst, const AscendC::LocalTensor<uint8_t> &mask,
                             float src0, const AscendC::LocalTensor<int64_t> &src1, const uint32_t size,
                             AscendC::LocalTensor<uint8_t> &tmp_buf) {
  const int64_t src0_as_int64 = static_cast<int64_t>(src0);
  Where(dst, mask, src0_as_int64, src1, size, tmp_buf);
}
* 场景3: src1是标量,src0及输出的Shape与mask相同
*/
template <typename T>
inline __aicore__ void Where(const AscendC::LocalTensor<T> &dst, const AscendC::LocalTensor<uint8_t> &mask,
const AscendC::LocalTensor<T> &src0, T src1, const uint32_t size,
AscendC::LocalTensor<uint8_t> &tmp_buf) {
LocalTensor<float> src1_buf = tmp_buf[0].template ReinterpretCast<float>();
Duplicate(src1_buf[0], (float)src1, KernelUtils::BlkSize<float>());
DoSelectParams params;
WherePartition2Buffer(tmp_buf, params, ONE_BLK_SIZE);
params.rpt_params.src1BlkStride = 0;
params.rpt_params.src1RepStride = 0;
params.src1_select_offset = 0;
WhereBase(dst, mask, src0, src1_buf, size, params);
}
* 场景4: src0和src1都不是标量,且Shape均与mask相同,且不需要广播。
*/
template <typename T>
inline __aicore__ void Where(const AscendC::LocalTensor<T> &dst, const AscendC::LocalTensor<uint8_t> &mask,
const AscendC::LocalTensor<T> &src0, const AscendC::LocalTensor<T> &src1,
const uint32_t size, AscendC::LocalTensor<uint8_t> &tmp_buf) {
DoSelectParams params;
WherePartition3Buffer(tmp_buf, params, 0);
WhereBase(dst, mask, src0, src1, size, params);
}
* 场景5: int64+场景1:src0和src1都是标量,输出Shape与mask相同
*/
inline __aicore__ void Where(const AscendC::LocalTensor<int64_t> &dst, const AscendC::LocalTensor<uint8_t> &mask,
int64_t src0, int64_t src1, const uint32_t size, AscendC::LocalTensor<uint8_t> &tmp_buf) {
LocalTensor<int64_t> src0_buf = tmp_buf[0].template ReinterpretCast<int64_t>();
LocalTensor<uint8_t> left_tmp_buf = tmp_buf[ONE_BLK_SIZE].template ReinterpretCast<uint8_t>();
Duplicate(src0_buf[0], src0, KernelUtils::BlkSize<int64_t>(), left_tmp_buf);
LocalTensor<int64_t> src1_buf = left_tmp_buf[0].template ReinterpretCast<int64_t>();
left_tmp_buf = left_tmp_buf[ONE_BLK_SIZE].template ReinterpretCast<uint8_t>();
Duplicate(src1_buf[0], src1, KernelUtils::BlkSize<int64_t>(), left_tmp_buf);
DoSelectParams params;
WherePartitionBufferInt64(tmp_buf, params, ONE_BLK_SIZE * 2);
params.rpt_params.src0BlkStride = 0;
params.rpt_params.src0RepStride = 0;
params.rpt_params.src1BlkStride = 0;
params.rpt_params.src1RepStride = 0;
params.src0_select_offset = 0;
params.src1_select_offset = 0;
WhereBase(dst, mask, src0_buf, src1_buf, size, params);
}
* 场景6: int64+场景2: src0是标量,src1及输出的Shape与mask相同
*/
inline __aicore__ void Where(const AscendC::LocalTensor<int64_t> &dst, const AscendC::LocalTensor<uint8_t> &mask,
int64_t src0, const AscendC::LocalTensor<int64_t> &src1, const uint32_t size,
AscendC::LocalTensor<uint8_t> &tmp_buf) {
LocalTensor<int64_t> src0_buf = tmp_buf[0].template ReinterpretCast<int64_t>();
LocalTensor<uint8_t> left_tmp_buf = tmp_buf[ONE_BLK_SIZE].template ReinterpretCast<uint8_t>();
Duplicate(src0_buf[0], src0, KernelUtils::BlkSize<int64_t>(), left_tmp_buf);
DoSelectParams params;
WherePartitionBufferInt64(tmp_buf, params, ONE_BLK_SIZE);
params.rpt_params.src0BlkStride = 0;
params.rpt_params.src0RepStride = 0;
params.src0_select_offset = 0;
WhereBase(dst, mask, src0_buf, src1, size, params);
}
* 场景7: int64+场景3: src1是标量,src0及输出的Shape与mask相同
*/
inline __aicore__ void Where(const AscendC::LocalTensor<int64_t> &dst, const AscendC::LocalTensor<uint8_t> &mask,
const AscendC::LocalTensor<int64_t> &src0, int64_t src1, const uint32_t size,
AscendC::LocalTensor<uint8_t> &tmp_buf) {
LocalTensor<int64_t> src1_buf = tmp_buf[0].template ReinterpretCast<int64_t>();
LocalTensor<uint8_t> left_tmp_buf = tmp_buf[ONE_BLK_SIZE].template ReinterpretCast<uint8_t>();
Duplicate(src1_buf[0], src1, KernelUtils::BlkSize<int64_t>(), left_tmp_buf);
DoSelectParams params;
WherePartitionBufferInt64(tmp_buf, params, ONE_BLK_SIZE);
params.rpt_params.src1BlkStride = 0;
params.rpt_params.src1RepStride = 0;
params.src1_select_offset = 0;
WhereBase(dst, mask, src0, src1_buf, size, params);
}
* 场景8: int64+场景4:src0和src1都不是标量,且Shape均与mask相同,且不需要广播。
*/
inline __aicore__ void Where(const AscendC::LocalTensor<int64_t> &dst, const AscendC::LocalTensor<uint8_t> &mask,
const AscendC::LocalTensor<int64_t> &src0, const AscendC::LocalTensor<int64_t> &src1,
const uint32_t size, AscendC::LocalTensor<uint8_t> &tmp_buf) {
DoSelectParams params;
WherePartitionBufferInt64(tmp_buf, params, 0);
WhereBase(dst, mask, src0, src1, size, params);
}
* Normal 模式:两根轴
*/
template <typename O, typename I, bool isSlice2Buf>
inline __aicore__ void SafeCastNormal(const AscendC::LocalTensor<O> &dst,
const AscendC::LocalTensor<I> &src,
const uint64_t mask, const uint8_t repeat_times, const uint32_t stride) {
uint8_t dstRepStride = 8;
uint8_t srcRepStride = 8;
if constexpr (isSlice2Buf) {
dstRepStride = static_cast<uint8_t>(mask * sizeof(O) / ONE_BLK_SIZE);
srcRepStride = static_cast<uint8_t>(stride * sizeof(I) / ONE_BLK_SIZE);
} else {
dstRepStride = static_cast<uint8_t>(stride * sizeof(O) / ONE_BLK_SIZE);
srcRepStride = static_cast<uint8_t>(mask * sizeof(I) / ONE_BLK_SIZE);
}
if constexpr (std::is_same<I, float>::value and
(std::is_same<O, int64_t>::value || std::is_same<O, int32_t>::value ||
std::is_same<O, int16_t>::value)) {
Cast(dst, src, RoundMode::CAST_RINT, mask, repeat_times, {1, 1, dstRepStride, srcRepStride});
} else if constexpr (std::is_same<I, int64_t>::value and std::is_same<O, float>::value) {
Cast(dst, src, RoundMode::CAST_RINT, mask, repeat_times, {1, 1, dstRepStride, srcRepStride});
} else if constexpr (std::is_same<I, half>::value and std::is_same<O, int16_t>::value) {
Cast(dst, src, RoundMode::CAST_RINT, mask, repeat_times, {1, 1, dstRepStride, srcRepStride});
} else {
Cast(dst, src, RoundMode::CAST_NONE, mask, repeat_times, {1, 1, dstRepStride, srcRepStride});
}
}
/**
 * Strided uint8 -> half Cast of the mask: `repeat_times` repeats of `mask` elements. The source
 * advances by `stride` bytes per repeat, the destination by one float-repeat worth of half
 * elements. Both strides are expressed in ONE_BLK_SIZE-byte blocks and narrowed to uint8_t.
 */
inline __aicore__ void MaskSafeCastNormal(const AscendC::LocalTensor<half> &dst,
                                          const AscendC::LocalTensor<uint8_t> &src,
                                          const uint64_t mask, const uint8_t repeat_times, const uint32_t stride) {
  const uint8_t dst_rpt_blocks = static_cast<uint8_t>(KernelUtils::RptSize<float>() * sizeof(half)/ ONE_BLK_SIZE);
  const uint8_t src_rpt_blocks = static_cast<uint8_t>(stride * sizeof(uint8_t) / ONE_BLK_SIZE);
  Cast(dst, src, RoundMode::CAST_NONE, mask, repeat_times, {1, 1, dst_rpt_blocks, src_rpt_blocks});
}
/**
 * int64 flavour of MaskSafeCastNormal: strided uint8 -> half Cast where the destination advances by
 * half of a float repeat (in half elements) per repeat and the source by `stride` bytes. Both
 * strides are expressed in ONE_BLK_SIZE-byte blocks and narrowed to uint8_t.
 */
inline __aicore__ void MaskSafeCastInt64Normal(const AscendC::LocalTensor<half> &dst,
                                               const AscendC::LocalTensor<uint8_t> &src,
                                               const uint64_t mask, const uint8_t repeat_times, const uint32_t stride) {
  const uint8_t dst_rpt_blocks = static_cast<uint8_t>(KernelUtils::RptSize<float>()/2 * sizeof(half)/ ONE_BLK_SIZE);
  const uint8_t src_rpt_blocks = static_cast<uint8_t>(stride * sizeof(uint8_t) / ONE_BLK_SIZE);
  Cast(dst, src, RoundMode::CAST_NONE, mask, repeat_times, {1, 1, dst_rpt_blocks, src_rpt_blocks});
}
/**
 * Normal-mode select step: processes `params.repeat_times` repeats of `params.mask` elements.
 *
 * The mask rows starting at params.mask_offset are cast to half into params.mask_cast_buf (one
 * float repeat apart), compared with 1.0, and the compare result overwrites the same buffer as the
 * select bit mask. Float outputs are selected directly into `dst` at params.dst_offset. For other
 * output types the Select writes densely into params.sel_res_buf (params.rpt_params.dstRepStride
 * is overwritten for that purpose) and the result is cast into `dst` with params.output_stride.
 */
template <typename T>
inline __aicore__ void DoSelectNormal(const AscendC::LocalTensor<T> &dst,
                                      const AscendC::LocalTensor<uint8_t> &sel_mask,
                                      const AscendC::LocalTensor<float> &src0,
                                      const AscendC::LocalTensor<float> &src1,
                                      DoSelectParams &params) {
  MaskSafeCastNormal(params.mask_cast_buf[0], sel_mask[params.mask_offset], params.mask,
                     params.repeat_times, params.mask_stride);

  // Compare length: all repeats, rounded up to a whole number of half-precision repeats.
  const uint32_t cmp_size = (KernelUtils::RptSize<float>() * params.repeat_times + ONE_REPEAT_HALF_SIZE - 1)
                            / ONE_REPEAT_HALF_SIZE * ONE_REPEAT_HALF_SIZE;
  LocalTensor<uint8_t> bit_mask = params.mask_cast_buf.ReinterpretCast<uint8_t>();
  CompareScalar(bit_mask, params.mask_cast_buf[0], static_cast<half>(1.0), CMPMODE::EQ, cmp_size);

  if constexpr (std::is_same<T, float>::value) {
    Select(dst[params.dst_offset],
           bit_mask,
           src0[params.src0_select_offset * params.src0_offset],
           src1[params.src1_select_offset * params.src1_offset],
           SELMODE::VSEL_TENSOR_TENSOR_MODE,
           params.mask,
           params.repeat_times,
           params.rpt_params);
  } else {
    // The float result is packed densely into the scratch buffer before the final cast.
    params.rpt_params.dstRepStride = params.mask * sizeof(float) / ONE_BLK_SIZE;
    Select(params.sel_res_buf[0],
           bit_mask,
           src0[params.src0_select_offset * params.src0_offset],
           src1[params.src1_select_offset * params.src1_offset],
           SELMODE::VSEL_TENSOR_TENSOR_MODE,
           params.mask,
           params.repeat_times,
           params.rpt_params);
    SafeCastNormal<T, float, false>(dst[params.dst_offset], params.sel_res_buf[0], params.mask,
                                    params.repeat_times, params.output_stride);
  }
}
/**
 * Normal-mode counterpart of CastBeforeSelect: gathers every non-float source into a dense float
 * scratch buffer, then calls DoSelectNormal.
 *
 * For each converted source the select-offset multiplier is set to 0 and the matching repeat stride
 * in params.rpt_params is set to `mask` floats (in ONE_BLK_SIZE-byte blocks), so the select reads
 * the dense scratch copy. When only one source needs conversion it goes to params.sel_res_buf; when
 * both do, src0 goes to params.src0_cast_buf and src1 to params.sel_res_buf.
 */
template <typename O, typename I0, typename I1>
inline __aicore__ void CastSelectNormal(const AscendC::LocalTensor<O> &dst,
                                        const AscendC::LocalTensor<uint8_t> &sel_mask,
                                        const AscendC::LocalTensor<I0> &src0,
                                        const AscendC::LocalTensor<I1> &src1,
                                        DoSelectParams &params) {
  constexpr bool lhs_is_f32 = std::is_same<I0, float>::value;
  constexpr bool rhs_is_f32 = std::is_same<I1, float>::value;

  if constexpr (lhs_is_f32 && rhs_is_f32) {
    DoSelectNormal(dst, sel_mask, src0, src1, params);
  } else if constexpr (rhs_is_f32) {
    // Only src0 needs conversion.
    SafeCastNormal<float, I0, true>(params.sel_res_buf[0], src0[params.src0_offset], params.mask,
                                    params.repeat_times, params.input0_stride);
    params.src0_select_offset = 0;
    params.rpt_params.src0RepStride = params.mask * sizeof(float) / ONE_BLK_SIZE;
    DoSelectNormal(dst, sel_mask, params.sel_res_buf, src1, params);
  } else if constexpr (lhs_is_f32) {
    // Only src1 needs conversion.
    SafeCastNormal<float, I1, true>(params.sel_res_buf[0], src1[params.src1_offset], params.mask,
                                    params.repeat_times, params.input1_stride);
    params.src1_select_offset = 0;
    params.rpt_params.src1RepStride = params.mask * sizeof(float) / ONE_BLK_SIZE;
    DoSelectNormal(dst, sel_mask, src0, params.sel_res_buf, params);
  } else {
    // Both sources need conversion.
    SafeCastNormal<float, I0, true>(params.src0_cast_buf[0], src0[params.src0_offset], params.mask,
                                    params.repeat_times, params.input0_stride);
    SafeCastNormal<float, I1, true>(params.sel_res_buf[0], src1[params.src1_offset], params.mask,
                                    params.repeat_times, params.input1_stride);
    params.src0_select_offset = 0;
    params.src1_select_offset = 0;
    params.rpt_params.src0RepStride = params.mask * sizeof(float) / ONE_BLK_SIZE;
    params.rpt_params.src1RepStride = params.mask * sizeof(float) / ONE_BLK_SIZE;
    DoSelectNormal(dst, sel_mask, params.src0_cast_buf, params.sel_res_buf, params);
  }
}
/**
 * Normal-mode int64 select step. The int64 tensors are reinterpreted as float tensors, and the mask
 * is expanded so each mask byte appears twice:
 *   1. uint8 -> half for params.mask / 2 elements per repeat;
 *   2. half -> int16 (CAST_RINT);
 *   3. (m << 8) | m, giving two copies of every mask byte when viewed as uint8;
 *   4. uint8 -> half for params.mask elements per repeat into a view spanning both scratch buffers;
 *   5. compare with 1.0 to obtain the select bit mask, then Select on the float views.
 * All repeat strides below are expressed in ONE_BLK_SIZE-byte blocks.
 */
inline __aicore__ void CastSelectInt64Normal(const AscendC::LocalTensor<int64_t> &dst,
                                             const AscendC::LocalTensor<uint8_t> &sel_mask,
                                             const AscendC::LocalTensor<int64_t> &src0,
                                             const AscendC::LocalTensor<int64_t> &src1,
                                             DoSelectParams &params) {
  const AscendC::LocalTensor<float> dst_words = dst.ReinterpretCast<float>();
  const AscendC::LocalTensor<float> src0_words = src0.ReinterpretCast<float>();
  const AscendC::LocalTensor<float> src1_words = src1.ReinterpretCast<float>();

  // Repeat strides for rows of half a float repeat (int16 / half elements) ...
  const uint8_t half_row_i16 = static_cast<uint8_t>(KernelUtils::RptSize<float>()/2 * sizeof(int16_t) / ONE_BLK_SIZE);
  const uint8_t half_row_f16 = static_cast<uint8_t>(KernelUtils::RptSize<float>()/2 * sizeof(half) / ONE_BLK_SIZE);
  // ... and for rows of a full float repeat (half / uint8 elements).
  const uint8_t full_row_f16 = static_cast<uint8_t>(KernelUtils::RptSize<float>() * sizeof(half)/ ONE_BLK_SIZE);
  const uint8_t full_row_u8 = static_cast<uint8_t>(KernelUtils::RptSize<float>() * sizeof(uint8_t) / ONE_BLK_SIZE);

  // Step 1: uint8 -> half.
  MaskSafeCastInt64Normal(params.mask_cast_buf[0], sel_mask[params.mask_offset], params.mask / 2,
                          params.repeat_times, params.mask_stride);
  // Step 2: half -> int16.
  Cast(params.mask_shift_buf[0], params.mask_cast_buf[0], RoundMode::CAST_RINT, params.mask/2,
       params.repeat_times, {1, 1, half_row_i16, half_row_f16});
  // Step 3: duplicate each mask byte.
  const AscendC::LocalTensor<int16_t> shifted = params.mask_cast_buf.ReinterpretCast<int16_t>();
  constexpr int16_t shift_left_size = 8;
  ShiftLeft(shifted, params.mask_shift_buf[0], shift_left_size, params.mask/2, params.repeat_times,
            {1, 1, half_row_i16, half_row_i16});
  Or(params.mask_shift_buf[0], shifted, params.mask_shift_buf[0], params.mask/2, params.repeat_times,
     {1, 1, 1, half_row_i16, half_row_i16, half_row_i16});

  // Step 4: doubled bytes -> half.
  const AscendC::LocalTensor<uint8_t> doubled_bytes = params.mask_shift_buf.ReinterpretCast<uint8_t>();
  AscendC::LocalTensor<half> wide_mask = params.mask_cast_buf.ReinterpretCast<half>();
  wide_mask.SetSize(params.mask_cast_buf.GetSize() + params.mask_shift_buf.GetSize());
  Cast(wide_mask, doubled_bytes, RoundMode::CAST_NONE, params.mask, params.repeat_times,
       {1, 1, full_row_f16, full_row_u8});

  // Step 5: build the bit mask and select.
  const uint32_t cmp_size =
      (KernelUtils::RptSize<float>() * params.repeat_times + ONE_REPEAT_HALF_SIZE - 1) / ONE_REPEAT_HALF_SIZE * ONE_REPEAT_HALF_SIZE;
  LocalTensor<uint8_t> bit_mask = wide_mask.ReinterpretCast<uint8_t>();
  CompareScalar(bit_mask, wide_mask, static_cast<half>(1.0), CMPMODE::EQ, cmp_size);
  Select(dst_words[params.dst_offset],
         bit_mask,
         src0_words[params.src0_select_offset * params.src0_offset],
         src1_words[params.src1_select_offset * params.src1_offset],
         SELMODE::VSEL_TENSOR_TENSOR_MODE,
         params.mask,
         params.repeat_times,
         params.rpt_params);
}
/**
 * Normal-mode (two-axis) driver for the generic select path.
 *
 * The last axis is cut into columns of one float repeat (ONE_RPT_SIZE elements) plus an optional
 * tail column whose mask is rounded up to whole blocks. Within a column the first axis is walked in
 * chunks of at most max_do_rpt_num rows, followed by one call for the leftover rows. Before every
 * CastSelectNormal call the running mask/dst/src offsets and repeat count are stored in `params`.
 *
 * @param dst         output tensor
 * @param sel_mask    uint8 mask tensor
 * @param src0, src1  source tensors (any mix of float / non-float element types)
 * @param first_axis  number of rows (one repeat per row)
 * @param last_axis   number of elements per row
 * @param params      strides and scratch buffers; offsets, mask and repeat_times are overwritten
 */
template <typename T, typename T1, typename T2>
inline __aicore__ void WhereExtend(const AscendC::LocalTensor<T> &dst,
                                   const AscendC::LocalTensor<uint8_t> &sel_mask,
                                   const AscendC::LocalTensor<T1> &src0,
                                   const AscendC::LocalTensor<T2> &src1,
                                   const uint32_t first_axis, const uint32_t last_axis, DoSelectParams &params) {
  constexpr uint32_t ONE_RPT_SIZE = KernelUtils::RptSize<float>();
  // Rows per chunk are rounded down to a multiple of sizeof(float) / sizeof(T). For element types
  // wider than float that quotient is 0, which used to make the rounding below divide by zero;
  // clamp it to 1 (no extra rounding) in that case.
  constexpr uint32_t rpt_align =
      (sizeof(T) < sizeof(float)) ? static_cast<uint32_t>(sizeof(float) / sizeof(T)) : 1U;

  const uint32_t element_extent = last_axis / ONE_RPT_SIZE;
  const uint32_t element_reminder = last_axis - element_extent * ONE_RPT_SIZE;

  uint32_t max_do_rpt_num =
      KernelUtils::Min(MAX_REPEAT_TIME, params.sel_res_buf.GetSize() / sizeof(float) / ONE_RPT_SIZE);
  max_do_rpt_num = max_do_rpt_num / rpt_align * rpt_align;
  // NOTE(review): max_do_rpt_num is 0 when the scratch buffer holds fewer than rpt_align repeats,
  // and the division below is then undefined - presumably callers guarantee enough tmp space; confirm.
  const uint32_t repeat_throw_for_extent = first_axis / max_do_rpt_num;
  const uint32_t repeat_reminder = first_axis - repeat_throw_for_extent * max_do_rpt_num;

  params.mask = ONE_RPT_SIZE;
  // Full columns first, then (if present) the tail column at index element_extent.
  const uint32_t column_count = element_extent + (element_reminder != 0 ? 1U : 0U);
  for (uint32_t column = 0; column < column_count; column++) {
    if (column == element_extent) {
      // Tail column: only the remaining elements, rounded up to whole blocks.
      params.mask = KernelUtils::BlkNum<float>(element_reminder + KernelUtils::BlkSize<float>() - 1) *
                    KernelUtils::BlkSize<float>();
    }
    const uint32_t column_offset = column * ONE_RPT_SIZE;
    params.mask_offset = column_offset;
    params.dst_offset = column_offset;
    params.src0_offset = column_offset;
    params.src1_offset = column_offset;
    params.repeat_times = max_do_rpt_num;
    for (uint32_t chunk = 0; chunk < repeat_throw_for_extent; chunk++) {
      CastSelectNormal(dst, sel_mask, src0, src1, params);
      params.mask_offset += max_do_rpt_num * params.mask_stride;
      params.dst_offset += max_do_rpt_num * params.output_stride;
      params.src0_offset += max_do_rpt_num * params.input0_stride;
      params.src1_offset += max_do_rpt_num * params.input1_stride;
    }
    if (repeat_reminder != 0) {
      // Leftover rows that do not fill a whole chunk.
      params.repeat_times = repeat_reminder;
      CastSelectNormal(dst, sel_mask, src0, src1, params);
    }
  }
}
/**
 * Normal-mode (two-axis) driver for the int64 select path.
 *
 * The last axis is counted in float-sized units (last_axis * 2) and cut into columns of one float
 * repeat plus an optional tail column whose mask is rounded up to whole blocks. Mask offsets are
 * half of the data offsets, and the dst/src row strides are doubled when advancing between chunks.
 * Within a column the first axis is walked in chunks of at most max_do_rpt_num rows, followed by
 * one call for the leftover rows.
 */
inline __aicore__ void WhereExtend(const AscendC::LocalTensor<int64_t> &dst,
                                   const AscendC::LocalTensor<uint8_t> &sel_mask,
                                   const AscendC::LocalTensor<int64_t> &src0,
                                   const AscendC::LocalTensor<int64_t> &src1,
                                   const uint32_t first_axis, const uint32_t last_axis, DoSelectParams &params) {
  const uint32_t row_words = last_axis << 1;
  constexpr uint32_t ONE_RPT_SIZE = KernelUtils::RptSize<float>();
  const uint32_t full_columns = row_words / ONE_RPT_SIZE;
  const uint32_t tail_words = row_words - full_columns * ONE_RPT_SIZE;

  uint32_t max_do_rpt_num =
      KernelUtils::Min(MAX_REPEAT_TIME, params.mask_shift_buf.GetSize() / sizeof(float) / ONE_RPT_SIZE);
  // NOTE(review): max_do_rpt_num is 0 when the scratch buffer holds less than one repeat, and the
  // division below is then undefined - presumably callers guarantee enough tmp space; confirm.
  const uint32_t full_chunks = first_axis / max_do_rpt_num;
  const uint32_t leftover_rows = first_axis - full_chunks * max_do_rpt_num;

  params.mask = ONE_RPT_SIZE;
  // Full columns first, then (if present) the tail column at index full_columns.
  const uint32_t column_count = full_columns + (tail_words != 0 ? 1U : 0U);
  for (uint32_t column = 0; column < column_count; column++) {
    if (column == full_columns) {
      // Tail column: only the remaining words, rounded up to whole blocks.
      params.mask = KernelUtils::BlkNum<float>(tail_words + KernelUtils::BlkSize<float>() - 1) *
                    KernelUtils::BlkSize<float>();
    }
    params.mask_offset = column * ONE_RPT_SIZE / 2;
    params.dst_offset = column * ONE_RPT_SIZE;
    params.src0_offset = column * ONE_RPT_SIZE;
    params.src1_offset = column * ONE_RPT_SIZE;
    params.repeat_times = max_do_rpt_num;
    for (uint32_t chunk = 0; chunk < full_chunks; chunk++) {
      CastSelectInt64Normal(dst, sel_mask, src0, src1, params);
      params.mask_offset += max_do_rpt_num * params.mask_stride;
      params.dst_offset += max_do_rpt_num * params.output_stride * 2;
      params.src0_offset += max_do_rpt_num * params.input0_stride * 2;
      params.src1_offset += max_do_rpt_num * params.input1_stride * 2;
    }
    if (leftover_rows != 0) {
      // Leftover rows that do not fill a whole chunk.
      params.repeat_times = leftover_rows;
      CastSelectInt64Normal(dst, sel_mask, src0, src1, params);
    }
  }
}
/**
 * Entry point of the Where operation over a (first_axis x last_axis) region.
 *
 * The function partitions tmp_buf into the working buffers held by DoSelectParams,
 * configures the repeat parameters of the two inputs and then dispatches to one of
 * two implementations:
 *   - WhereExtend, called once for the whole region, when every per-row stride is
 *     small enough to be expressed as a repeat stride;
 *   - WhereBase, called once per index along first_axis, otherwise.
 *
 * All *_last_axis_stride arguments are counted in elements of the corresponding
 * tensor (they are multiplied by sizeof(type) to obtain bytes and are used directly
 * as tensor indices).
 *
 * @tparam isBcastSrc0  src0 is broadcast: its block/repeat strides and row offset are forced to 0.
 * @tparam isBcastSrc1  src1 is broadcast: its block/repeat strides and row offset are forced to 0.
 * @param dst                      output tensor.
 * @param mask                     uint8_t mask tensor.
 * @param src0                     first input tensor.
 * @param src1                     second input tensor.
 * @param first_axis               number of rows (outer extent).
 * @param last_axis                extent of the inner axis handed to WhereBase / WhereExtend.
 * @param output_last_axis_stride  distance between consecutive rows of dst.
 * @param mask_last_axis_stride    distance between consecutive rows of mask.
 * @param input0_last_axis_stride  distance between consecutive rows of src0.
 * @param input1_last_axis_stride  distance between consecutive rows of src1.
 * @param tmp_buf                  scratch buffer that is split into the working buffers.
 * @param used_size                forwarded unchanged to the buffer partition helper.
 */
template <bool isBcastSrc0 = false, bool isBcastSrc1 = false, typename T, typename T1, typename T2>
inline __aicore__ void Where(const AscendC::LocalTensor<T> &dst, const AscendC::LocalTensor<uint8_t> &mask,
                             const AscendC::LocalTensor<T1> &src0, const AscendC::LocalTensor<T2> &src1,
                             const uint32_t first_axis, const uint32_t last_axis,
                             const uint32_t output_last_axis_stride,
                             const uint32_t mask_last_axis_stride,
                             const uint32_t input0_last_axis_stride,
                             const uint32_t input1_last_axis_stride,
                             AscendC::LocalTensor<uint8_t> &tmp_buf, const uint32_t used_size) {
  // Repeat strides below are computed as stride_bytes / ONE_BLK_SIZE, so a row stride
  // must stay below 256 blocks.
  // NOTE(review): presumably because the repeat-stride fields of BinaryRepeatParams are 8-bit - confirm.
  constexpr uint32_t MAX_VALID_STRIDE_BYTES = ONE_BLK_SIZE * 256;
  // Each stride is checked both in its native element size and in sizeof(float).
  // NOTE(review): the float check presumably covers intermediate data that is cast to float - confirm.
  const bool useWhereExtend =
      input0_last_axis_stride * sizeof(T1) < MAX_VALID_STRIDE_BYTES &&
      input0_last_axis_stride * sizeof(float) < MAX_VALID_STRIDE_BYTES &&
      input1_last_axis_stride * sizeof(T2) < MAX_VALID_STRIDE_BYTES &&
      input1_last_axis_stride * sizeof(float) < MAX_VALID_STRIDE_BYTES &&
      output_last_axis_stride * sizeof(T) < MAX_VALID_STRIDE_BYTES &&
      output_last_axis_stride * sizeof(float) < MAX_VALID_STRIDE_BYTES &&
      mask_last_axis_stride * sizeof(uint8_t) < MAX_VALID_STRIDE_BYTES;

  DoSelectParams params;
  constexpr bool is_int64_scene =
      std::is_same<T, int64_t>::value && std::is_same<T1, int64_t>::value && std::is_same<T2, int64_t>::value;

  // Split the scratch buffer first; the stride configuration below is applied afterwards,
  // exactly as in every branch of the previous implementation.
  if constexpr (is_int64_scene) {
    WherePartitionBufferInt64(tmp_buf, params, used_size);
  } else if constexpr (isBcastSrc0 || isBcastSrc1) {
    WherePartition2Buffer(tmp_buf, params, used_size);
  } else {
    WherePartition3Buffer(tmp_buf, params, used_size);
  }

  // src0: a broadcast operand never advances; otherwise advance one row per repeat
  // (only meaningful on the WhereExtend path).
  if constexpr (isBcastSrc0) {
    params.rpt_params.src0BlkStride = 0;
    params.rpt_params.src0RepStride = 0;
    params.src0_select_offset = 0;
  } else {
    if (useWhereExtend) {
      params.rpt_params.src0BlkStride = 1;
      params.rpt_params.src0RepStride = input0_last_axis_stride * sizeof(T1) / ONE_BLK_SIZE;
    }
  }

  // src1: same rule as src0.
  if constexpr (isBcastSrc1) {
    params.rpt_params.src1BlkStride = 0;
    params.rpt_params.src1RepStride = 0;
    params.src1_select_offset = 0;
  } else {
    if (useWhereExtend) {
      params.rpt_params.src1BlkStride = 1;
      params.rpt_params.src1RepStride = input1_last_axis_stride * sizeof(T2) / ONE_BLK_SIZE;
    }
  }

  if (useWhereExtend) {
    params.rpt_params.dstRepStride = output_last_axis_stride * sizeof(T) / ONE_BLK_SIZE;
    params.output_stride = output_last_axis_stride;
    params.mask_stride = mask_last_axis_stride;
    params.input0_stride = input0_last_axis_stride;
    params.input1_stride = input1_last_axis_stride;
    WhereExtend(dst, mask, src0, src1, first_axis, last_axis, params);
    return;
  }

  // Fallback: process one row at a time.
  // The row offsets of src0/src1 must come from the stride arguments: params.input0_stride and
  // params.input1_stride are still 0 on this path (they are only assigned for WhereExtend), so
  // indexing with them made every row read row 0 of the inputs.
  // The *_select_offset factor stays so that a broadcast operand keeps offset 0.
  for (uint32_t row = 0; row < first_axis; row++) {
    // WhereBase gets a fresh copy each iteration so the shared configuration is not modified.
    DoSelectParams tempParams = params;
    WhereBase(dst[row * output_last_axis_stride], mask[row * mask_last_axis_stride],
              src0[row * input0_last_axis_stride * params.src0_select_offset],
              src1[row * input1_last_axis_stride * params.src1_select_offset], last_axis, tempParams);
  }
}
#endif