* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef __ASCENDC_API_COMPARE_V2_H__
#define __ASCENDC_API_COMPARE_V2_H__
constexpr uint32_t kShortBitWidth = 16;
template <typename T, CMPMODE mode>
inline __aicore__ void CompareNormalNoLoop(const AscendC::LocalTensor<uint8_t> &dst,
const AscendC::LocalTensor<T> &src0,
const AscendC::LocalTensor<T> &src1,
const uint8_t repeat_times, const uint32_t last_axis,
const uint8_t dst_repeat_stride, const uint8_t src_repeat_stride,
AscendC::LocalTensor<uint8_t> &tmp_buf) {
uint32_t tmp_offset = 0;
LocalTensor<uint8_t> compare_out = tmp_buf[tmp_offset].template ReinterpretCast<uint8_t>();
uint8_t compare_element_num = 256 / sizeof(T);
compare_out.SetSize((repeat_times * compare_element_num) / 8);
tmp_offset += (repeat_times * 256 / sizeof(T)) * sizeof(uint8_t) / 8;
LocalTensor<half> select_out = tmp_buf[KernelUtils::BlkAlign<uint8_t>(tmp_offset)].template ReinterpretCast<half>();
uint32_t select_repeat_times = DivCeil(repeat_times * 2, sizeof(T));
select_out.SetSize(128 * select_repeat_times);
const uint64_t mask = last_axis;
BinaryRepeatParams repeat_params = {1, 1, 1, 8, src_repeat_stride, src_repeat_stride};
if (src1.GetSize() * sizeof(T) == 32){
repeat_params = {1, 1, 0, 8, src_repeat_stride, 0};
}
if constexpr ((AscendC::IsSameType<T, float>::value || AscendC::IsSameType<T, half>::value) && mode == CMPMODE::NE) {
uint32_t elements_per_repeat = ONE_REPEAT_BYTE_SIZE / sizeof(T);
uint32_t select_bits_per_repeat = ONE_REPEAT_BYTE_SIZE / sizeof(half);
uint8_t cast_dst_blk_stride = dst_repeat_stride;
uint32_t mask_int16_count = (compare_element_num + kShortBitWidth - 1) / kShortBitWidth;
LocalTensor<uint8_t> nan_buf = select_out.ReinterpretCast<uint8_t>();
LocalTensor<int16_t> nan_buf_int16 = select_out.ReinterpretCast<int16_t>();
LocalTensor<int16_t> compare_out_int16 = compare_out.ReinterpretCast<int16_t>();
BinaryRepeatParams src1_repeat_params = {1, 1, 1, 8, src_repeat_stride, src_repeat_stride};
if (src1.GetSize() * sizeof(T) == 32) {
src1_repeat_params = {1, 1, 0, 8, src_repeat_stride, 0};
}
for (uint8_t r = 0; r < repeat_times; r++) {
AscendC::PipeBarrier<PIPE_V>();
Compare(nan_buf, src0[r * elements_per_repeat], src0[r * elements_per_repeat], CMPMODE::EQ,
mask, 1, {1, 1, 1, 8, src_repeat_stride, src_repeat_stride});
AscendC::PipeBarrier<PIPE_V>();
AscendC::Not(nan_buf_int16, nan_buf_int16, mask_int16_count);
AscendC::PipeBarrier<PIPE_V>();
Compare(compare_out, src1[r * elements_per_repeat], src1[r * elements_per_repeat], CMPMODE::EQ,
mask, 1, src1_repeat_params);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Not(compare_out_int16, compare_out_int16, mask_int16_count);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Or(compare_out_int16, compare_out_int16, nan_buf_int16, mask_int16_count);
AscendC::PipeBarrier<PIPE_V>();
Compare(nan_buf, src0[r * elements_per_repeat], src1[r * elements_per_repeat], mode,
mask, 1, repeat_params);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Or(compare_out_int16, nan_buf_int16, compare_out_int16, mask_int16_count);
AscendC::PipeBarrier<PIPE_V>();
Duplicate<half, true>(select_out, (half)1, select_bits_per_repeat, 1, 1, 8);
AscendC::PipeBarrier<PIPE_V>();
Select<half, uint8_t, true>(select_out, compare_out, select_out, (half)0,
SELMODE::VSEL_TENSOR_SCALAR_MODE, select_bits_per_repeat, 1, {1, 1, 1, 8, 8, 8});
AscendC::PipeBarrier<PIPE_V>();
Cast<uint8_t, half, true>(dst[r * dst_repeat_stride * ONE_BLK_SIZE], select_out, RoundMode::CAST_NONE,
mask, 1, {1, 1, cast_dst_blk_stride, 8});
}
} else {
AscendC::PipeBarrier<PIPE_V>();
Compare(compare_out, src0, src1, mode, mask, repeat_times, repeat_params);
Duplicate<half, true>(select_out, 1, ONE_REPEAT_BYTE_SIZE / sizeof(half), select_repeat_times, 1, 8);
AscendC::PipeBarrier<PIPE_V>();
repeat_params = {1, 1, 1, 8, 8, 8};
Select<half, uint8_t, true>(select_out, compare_out, select_out, (half)0, SELMODE::VSEL_TENSOR_SCALAR_MODE,
ONE_REPEAT_BYTE_SIZE / sizeof(half), select_repeat_times, repeat_params);
uint8_t select_out_stride = compare_element_num * sizeof(half) / ONE_BLK_SIZE;
UnaryRepeatParams unary_repeat_params = {1, 1, dst_repeat_stride, select_out_stride};
AscendC::PipeBarrier<PIPE_V>();
Cast<uint8_t, half, true>(dst, select_out, RoundMode::CAST_NONE, mask, repeat_times, unary_repeat_params);
}
}
template <typename T, CMPMODE mode>
inline __aicore__ void CompareNormal(const AscendC::LocalTensor<uint8_t> &dst,
const AscendC::LocalTensor<T> &src0,
const AscendC::LocalTensor<T> &src1,
const uint8_t repeat_times, const uint32_t last_axis,
const uint32_t input_last_dim_stride, const uint32_t output_last_dim_stride,
AscendC::LocalTensor<uint8_t> &tmp_buf) {
const uint8_t dst_repeat_stride = output_last_dim_stride / ONE_BLK_SIZE;
const uint8_t src_repeat_stride = input_last_dim_stride * sizeof(T) / ONE_BLK_SIZE;
const uint32_t elem_in_one_repeat = ONE_REPEAT_BYTE_SIZE / sizeof(T);
if (last_axis < elem_in_one_repeat) {
CompareNormalNoLoop<T, mode>(dst, src0, src1, repeat_times, last_axis, dst_repeat_stride, src_repeat_stride,
tmp_buf);
} else {
uint32_t element_extent = last_axis / elem_in_one_repeat;
uint32_t element_reminder = last_axis - element_extent * elem_in_one_repeat;
if (element_extent <= repeat_times) {
uint32_t tmp_offset = 0;
LocalTensor<uint8_t> compare_out = tmp_buf[tmp_offset].template ReinterpretCast<uint8_t>();
compare_out.SetSize((repeat_times * elem_in_one_repeat) / 8);
tmp_offset += (repeat_times * elem_in_one_repeat) * sizeof(uint8_t) / 8;
LocalTensor<half> select_out =
tmp_buf[KernelUtils::BlkAlign<uint8_t>(tmp_offset)].template ReinterpretCast<half>();
select_out.SetSize(ONE_REPEAT_BYTE_SIZE / sizeof(half) * repeat_times);
for (uint32_t outer_for = 0; outer_for < element_extent; outer_for++) {
uint32_t mask = elem_in_one_repeat;
if constexpr ((AscendC::IsSameType<T, float>::value || AscendC::IsSameType<T, half>::value) && mode == CMPMODE::NE) {
uint32_t select_bits_per_repeat = ONE_REPEAT_BYTE_SIZE / sizeof(half);
uint32_t mask_int16_count = (elem_in_one_repeat + kShortBitWidth - 1) / kShortBitWidth;
LocalTensor<uint8_t> nan_buf = select_out.ReinterpretCast<uint8_t>();
LocalTensor<int16_t> nan_buf_int16 = select_out.ReinterpretCast<int16_t>();
LocalTensor<int16_t> compare_out_int16 = compare_out.ReinterpretCast<int16_t>();
BinaryRepeatParams src1_repeat_params = {1, 1, 1, 8, src_repeat_stride, src_repeat_stride};
if (src1.GetSize() * sizeof(T) == 32) {
src1_repeat_params = {1, 1, 0, 8, src_repeat_stride, 0};
}
for (uint8_t r = 0; r < repeat_times; r++) {
AscendC::PipeBarrier<PIPE_V>();
Compare(nan_buf, src0[outer_for * elem_in_one_repeat + r * elem_in_one_repeat],
src0[outer_for * elem_in_one_repeat + r * elem_in_one_repeat],
CMPMODE::EQ, mask, 1, {1, 1, 1, 8, src_repeat_stride, src_repeat_stride});
AscendC::PipeBarrier<PIPE_V>();
AscendC::Not(nan_buf_int16, nan_buf_int16, mask_int16_count);
AscendC::PipeBarrier<PIPE_V>();
if (src1.GetSize() * sizeof(T) == 32) {
Compare(compare_out, src1[0], src1[0], CMPMODE::EQ, mask, 1, src1_repeat_params);
} else {
Compare(compare_out, src1[outer_for * elem_in_one_repeat + r * elem_in_one_repeat],
src1[outer_for * elem_in_one_repeat + r * elem_in_one_repeat],
CMPMODE::EQ, mask, 1, src1_repeat_params);
}
AscendC::PipeBarrier<PIPE_V>();
AscendC::Not(compare_out_int16, compare_out_int16, mask_int16_count);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Or(compare_out_int16, compare_out_int16, nan_buf_int16, mask_int16_count);
AscendC::PipeBarrier<PIPE_V>();
if (src1.GetSize() * sizeof(T) == 32) {
Compare(nan_buf, src0[outer_for * elem_in_one_repeat + r * elem_in_one_repeat],
src1[0], mode, mask, 1, src1_repeat_params);
} else {
Compare(nan_buf, src0[outer_for * elem_in_one_repeat + r * elem_in_one_repeat],
src1[outer_for * elem_in_one_repeat + r * elem_in_one_repeat],
mode, mask, 1, {1, 1, 1, 8, src_repeat_stride, src_repeat_stride});
}
AscendC::PipeBarrier<PIPE_V>();
AscendC::Or(compare_out_int16, nan_buf_int16, compare_out_int16, mask_int16_count);
AscendC::PipeBarrier<PIPE_V>();
Duplicate<half, true>(select_out, (half)1, select_bits_per_repeat, 1, 1, 8);
AscendC::PipeBarrier<PIPE_V>();
Select<half, uint8_t, true>(select_out, compare_out, select_out, (half)0,
SELMODE::VSEL_TENSOR_SCALAR_MODE, select_bits_per_repeat, 1, {1, 1, 1, 8, 8, 8});
AscendC::PipeBarrier<PIPE_V>();
Cast<uint8_t, half, true>(dst[outer_for * elem_in_one_repeat + r * dst_repeat_stride * ONE_BLK_SIZE],
select_out, RoundMode::CAST_NONE, mask, 1, {1, 1, dst_repeat_stride, 8});
}
} else {
PipeBarrier<PIPE_V>();
BinaryRepeatParams repeat_params = {1, 1, 1, 8, src_repeat_stride, src_repeat_stride};
if (src1.GetSize() * sizeof(T) == 32) {
repeat_params = {1, 1, 0, 8, src_repeat_stride, 0};
Compare(compare_out, src0[outer_for * elem_in_one_repeat], src1[0], mode, mask, repeat_times, repeat_params);
} else {
Compare(compare_out, src0[outer_for * elem_in_one_repeat], src1[outer_for * elem_in_one_repeat], mode, mask,
repeat_times, repeat_params);
}
PipeBarrier<PIPE_V>();
Duplicate<half, true>(select_out, 1, ONE_REPEAT_BYTE_SIZE / sizeof(half), repeat_times, 1, 8);
AscendC::PipeBarrier<PIPE_V>();
repeat_params = {1, 1, 1, 8, 8, 8};
Select<half, uint8_t, true>(select_out, compare_out, select_out, (half)0, SELMODE::VSEL_TENSOR_SCALAR_MODE,
ONE_REPEAT_BYTE_SIZE / sizeof(uint8_t), repeat_times, repeat_params);
UnaryRepeatParams unary_repeat_params = {1, 1, dst_repeat_stride, 4};
AscendC::PipeBarrier<PIPE_V>();
Cast<uint8_t, half, true>(dst[outer_for * elem_in_one_repeat], select_out, RoundMode::CAST_NONE, mask,
repeat_times, unary_repeat_params);
}
}
if (element_reminder != 0) {
uint32_t mask = element_reminder;
if constexpr ((AscendC::IsSameType<T, float>::value || AscendC::IsSameType<T, half>::value) && mode == CMPMODE::NE) {
uint32_t select_bits_per_repeat = ONE_REPEAT_BYTE_SIZE / sizeof(half);
uint32_t mask_int16_count = (elem_in_one_repeat + kShortBitWidth - 1) / kShortBitWidth;
LocalTensor<uint8_t> nan_buf = select_out.ReinterpretCast<uint8_t>();
LocalTensor<int16_t> nan_buf_int16 = select_out.ReinterpretCast<int16_t>();
LocalTensor<int16_t> compare_out_int16 = compare_out.ReinterpretCast<int16_t>();
BinaryRepeatParams src1_repeat_params = {1, 1, 1, 8, src_repeat_stride, src_repeat_stride};
if (src1.GetSize() * sizeof(T) == 32) {
src1_repeat_params = {1, 1, 0, 8, src_repeat_stride, 0};
}
for (uint8_t r = 0; r < repeat_times; r++) {
AscendC::PipeBarrier<PIPE_V>();
Compare(nan_buf, src0[element_extent * elem_in_one_repeat + r * elem_in_one_repeat],
src0[element_extent * elem_in_one_repeat + r * elem_in_one_repeat],
CMPMODE::EQ, mask, 1, {1, 1, 1, 8, src_repeat_stride, src_repeat_stride});
AscendC::PipeBarrier<PIPE_V>();
AscendC::Not(nan_buf_int16, nan_buf_int16, mask_int16_count);
AscendC::PipeBarrier<PIPE_V>();
if (src1.GetSize() * sizeof(T) == 32) {
Compare(compare_out, src1[0], src1[0], CMPMODE::EQ, mask, 1, src1_repeat_params);
} else {
Compare(compare_out, src1[element_extent * elem_in_one_repeat + r * elem_in_one_repeat],
src1[element_extent * elem_in_one_repeat + r * elem_in_one_repeat],
CMPMODE::EQ, mask, 1, src1_repeat_params);
}
AscendC::PipeBarrier<PIPE_V>();
AscendC::Not(compare_out_int16, compare_out_int16, mask_int16_count);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Or(compare_out_int16, compare_out_int16, nan_buf_int16, mask_int16_count);
AscendC::PipeBarrier<PIPE_V>();
if (src1.GetSize() * sizeof(T) == 32) {
Compare(nan_buf, src0[element_extent * elem_in_one_repeat + r * elem_in_one_repeat],
src1[0], mode, mask, 1, src1_repeat_params);
} else {
Compare(nan_buf, src0[element_extent * elem_in_one_repeat + r * elem_in_one_repeat],
src1[element_extent * elem_in_one_repeat + r * elem_in_one_repeat],
mode, mask, 1, {1, 1, 1, 8, src_repeat_stride, src_repeat_stride});
}
AscendC::PipeBarrier<PIPE_V>();
AscendC::Or(compare_out_int16, nan_buf_int16, compare_out_int16, mask_int16_count);
AscendC::PipeBarrier<PIPE_V>();
Duplicate<half, true>(select_out, (half)1, select_bits_per_repeat, 1, 1, 8);
AscendC::PipeBarrier<PIPE_V>();
Select<half, uint8_t, true>(select_out, compare_out, select_out, (half)0,
SELMODE::VSEL_TENSOR_SCALAR_MODE, select_bits_per_repeat, 1, {1, 1, 1, 8, 8, 8});
AscendC::PipeBarrier<PIPE_V>();
Cast<uint8_t, half, true>(dst[element_extent * elem_in_one_repeat + r * dst_repeat_stride * ONE_BLK_SIZE],
select_out, RoundMode::CAST_NONE, mask, 1, {1, 1, dst_repeat_stride, 8});
}
} else {
AscendC::PipeBarrier<PIPE_V>();
BinaryRepeatParams repeat_params = {1, 1, 1, 8, src_repeat_stride, src_repeat_stride};
if (src1.GetSize() * sizeof(T) == 32) {
repeat_params = {1, 1, 0, 8, src_repeat_stride, 0};
Compare(compare_out, src0[element_extent * elem_in_one_repeat], src1[0], mode, mask, repeat_times,
repeat_params);
} else {
Compare(compare_out, src0[element_extent * elem_in_one_repeat], src1[element_extent * elem_in_one_repeat],
mode, mask, repeat_times, repeat_params);
}
PipeBarrier<PIPE_V>();
Duplicate<half, true>(select_out, 1, ONE_REPEAT_BYTE_SIZE / sizeof(half), repeat_times, 1, 8);
repeat_params = {1, 1, 1, 8, 8, 8};
AscendC::PipeBarrier<PIPE_V>();
Select<half, uint8_t, true>(select_out, compare_out, select_out, (half)0, SELMODE::VSEL_TENSOR_SCALAR_MODE,
ONE_REPEAT_BYTE_SIZE / sizeof(uint8_t), repeat_times, repeat_params);
UnaryRepeatParams unary_repeat_params = {1, 1, dst_repeat_stride, 4};
AscendC::PipeBarrier<PIPE_V>();
Cast<uint8_t, half, true>(dst[element_extent * elem_in_one_repeat], select_out, RoundMode::CAST_NONE, mask,
repeat_times, unary_repeat_params);
}
}
} else {
uint32_t tmp_offset = 0;
LocalTensor<uint8_t> compare_out = tmp_buf[tmp_offset].template ReinterpretCast<uint8_t>();
compare_out.SetSize(KernelUtils::BlkAlign<uint8_t>(last_axis) / 8);
tmp_offset += (KernelUtils::BlkAlign<uint8_t>(last_axis)) * sizeof(uint8_t) / 8;
LocalTensor<half> select_out =
tmp_buf[KernelUtils::BlkAlign<uint8_t>(tmp_offset)].template ReinterpretCast<half>();
select_out.SetSize(last_axis);
for (auto outer_for = 0; outer_for < repeat_times; outer_for++) {
if constexpr ((AscendC::IsSameType<T, float>::value || AscendC::IsSameType<T, half>::value) && mode == CMPMODE::NE) {
LocalTensor<uint8_t> nan_buf = select_out.ReinterpretCast<uint8_t>();
LocalTensor<int16_t> nan_buf_int16 = select_out.ReinterpretCast<int16_t>();
LocalTensor<int16_t> compare_out_int16 = compare_out.ReinterpretCast<int16_t>();
uint32_t mask_int16_count = (last_axis + kShortBitWidth - 1) / kShortBitWidth;
AscendC::PipeBarrier<PIPE_V>();
Compare(nan_buf, src0[outer_for * input_last_dim_stride],
src0[outer_for * input_last_dim_stride], CMPMODE::EQ, last_axis);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Not(nan_buf_int16, nan_buf_int16, mask_int16_count);
AscendC::PipeBarrier<PIPE_V>();
if (src1.GetSize() * sizeof(T) == 32) {
Compare(compare_out, src1, src1, CMPMODE::EQ, src1.GetSize());
} else {
Compare(compare_out, src1[outer_for * input_last_dim_stride],
src1[outer_for * input_last_dim_stride], CMPMODE::EQ, last_axis);
}
AscendC::PipeBarrier<PIPE_V>();
AscendC::Not(compare_out_int16, compare_out_int16, mask_int16_count);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Or(compare_out_int16, compare_out_int16, nan_buf_int16, mask_int16_count);
AscendC::PipeBarrier<PIPE_V>();
if (src1.GetSize() * sizeof(T) == 32) {
uint32_t calcount_aligned = (last_axis + elem_in_one_repeat - 1) / elem_in_one_repeat * elem_in_one_repeat;
Compare(nan_buf, src0[outer_for * input_last_dim_stride], src1, mode, calcount_aligned);
} else {
uint32_t calcount_aligned = (last_axis + elem_in_one_repeat - 1) / elem_in_one_repeat * elem_in_one_repeat;
Compare(nan_buf, src0[outer_for * input_last_dim_stride],
src1[outer_for * input_last_dim_stride], mode, calcount_aligned);
}
AscendC::PipeBarrier<PIPE_V>();
AscendC::Or(compare_out_int16, nan_buf_int16, compare_out_int16, mask_int16_count);
} else {
PipeBarrier<PIPE_V>();
if (src1.GetSize() * sizeof(T) == 32) {
event_t eventid_v_to_s = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
SetFlag<HardEvent::V_S>(eventid_v_to_s);
WaitFlag<HardEvent::V_S>(eventid_v_to_s);
auto scalar_value = src1.GetValue(0);
event_t eventid_s_to_v = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
SetFlag<HardEvent::S_V>(eventid_s_to_v);
WaitFlag<HardEvent::S_V>(eventid_s_to_v);
CompareScalar(compare_out, src0[outer_for * input_last_dim_stride], scalar_value, mode, last_axis);
} else {
uint32_t calcount_aligned = (last_axis + elem_in_one_repeat - 1) / elem_in_one_repeat * elem_in_one_repeat;
Compare(compare_out, src0[outer_for * input_last_dim_stride], src1[outer_for * input_last_dim_stride], mode,
calcount_aligned);
}
}
PipeBarrier<PIPE_V>();
Duplicate<half>(select_out, static_cast<half>(1), last_axis);
AscendC::PipeBarrier<PIPE_V>();
Select<half, uint8_t>(select_out, compare_out, select_out, (half)0, SELMODE::VSEL_TENSOR_SCALAR_MODE,
last_axis);
AscendC::PipeBarrier<PIPE_V>();
Cast<uint8_t, half>(dst[outer_for * output_last_dim_stride], select_out, RoundMode::CAST_NONE, last_axis);
}
}
}
}
template <CMPMODE mode>
inline __aicore__ void ApplyCompareModeNormal(LocalTensor<int32_t> &inter_buf, uint32_t mask, uint8_t repeat_times,
const UnaryRepeatParams &unary_repeat_params) {
switch (mode) {
case CMPMODE::GE:
AscendC::Adds(inter_buf, inter_buf, (int32_t)1, mask, repeat_times, unary_repeat_params);
AscendC::PipeBarrier<PIPE_V>();
case CMPMODE::GT:
AscendC::Maxs(inter_buf, inter_buf, (int32_t)0, mask, repeat_times, unary_repeat_params);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Mins(inter_buf, inter_buf, (int32_t)1, mask, repeat_times, unary_repeat_params);
break;
case CMPMODE::LE:
AscendC::Adds(inter_buf, inter_buf, (int32_t)(-1), mask, repeat_times, unary_repeat_params);
AscendC::PipeBarrier<PIPE_V>();
case CMPMODE::LT:
case CMPMODE::NE:
AscendC::Maxs(inter_buf, inter_buf, (int32_t)(-1), mask, repeat_times, unary_repeat_params);
AscendC::PipeBarrier<PIPE_V>();
if (mode == CMPMODE::NE) {
AscendC::Mins(inter_buf, inter_buf, (int32_t)1, mask, repeat_times, unary_repeat_params);
} else {
AscendC::Mins(inter_buf, inter_buf, (int32_t)0, mask, repeat_times, unary_repeat_params);
}
AscendC::PipeBarrier<PIPE_V>();
BinaryRepeatParams binary_repeat_params = {1, 1, 1, 8, 8, 8};
AscendC::Mul(inter_buf, inter_buf, inter_buf, mask, repeat_times, binary_repeat_params);
break;
}
AscendC::PipeBarrier<PIPE_V>();
}
template <CMPMODE mode>
inline __aicore__ void ApplyCompareModeCount(LocalTensor<int32_t> &inter_buf, uint32_t count) {
switch (mode) {
case CMPMODE::GE:
AscendC::Adds(inter_buf, inter_buf, (int32_t)1, count);
AscendC::PipeBarrier<PIPE_V>();
case CMPMODE::GT:
AscendC::Maxs(inter_buf, inter_buf, (int32_t)0, count);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Mins(inter_buf, inter_buf, (int32_t)1, count);
break;
case CMPMODE::LE:
AscendC::Adds(inter_buf, inter_buf, (int32_t)(-1), count);
AscendC::PipeBarrier<PIPE_V>();
case CMPMODE::LT:
case CMPMODE::NE:
AscendC::Maxs(inter_buf, inter_buf, (int32_t)(-1), count);
AscendC::PipeBarrier<PIPE_V>();
if (mode == CMPMODE::NE) {
AscendC::Mins(inter_buf, inter_buf, (int32_t)1, count);
} else {
AscendC::Mins(inter_buf, inter_buf, (int32_t)0, count);
}
AscendC::PipeBarrier<PIPE_V>();
AscendC::Mul(inter_buf, inter_buf, inter_buf, count);
break;
}
AscendC::PipeBarrier<PIPE_V>();
}
inline __aicore__ void PerformTypeConversionNormal(const AscendC::LocalTensor<uint8_t> &dst,
AscendC::LocalTensor<int16_t> &int16_buf,
AscendC::LocalTensor<half> &half_buf,
AscendC::LocalTensor<int32_t> &inter_buf, uint32_t mask,
uint8_t repeat_times, uint8_t dst_repeat_stride,
uint32_t dst_offset = 0) {
UnaryRepeatParams unary_repeat_params = {1, 1, 4, 8};
AscendC::Cast(int16_buf, inter_buf, RoundMode::CAST_NONE, mask, repeat_times, unary_repeat_params);
unary_repeat_params = {1, 1, 4, 4};
AscendC::PipeBarrier<PIPE_V>();
AscendC::Cast(half_buf, int16_buf, RoundMode::CAST_NONE, mask, repeat_times, unary_repeat_params);
unary_repeat_params = {1, 1, dst_repeat_stride, 4};
AscendC::PipeBarrier<PIPE_V>();
if (dst_offset == 0) {
AscendC::Cast(dst, half_buf, RoundMode::CAST_NONE, mask, repeat_times, unary_repeat_params);
} else {
AscendC::Cast(dst[dst_offset], half_buf, RoundMode::CAST_NONE, mask, repeat_times, unary_repeat_params);
}
}
inline __aicore__ void PerformTypeConversionCount(const AscendC::LocalTensor<uint8_t> &dst,
AscendC::LocalTensor<int16_t> &int16_buf,
AscendC::LocalTensor<half> &half_buf,
AscendC::LocalTensor<int32_t> &inter_buf, uint32_t num_elements,
uint32_t dst_offset = 0) {
AscendC::Cast(int16_buf, inter_buf, RoundMode::CAST_NONE, num_elements);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Cast(half_buf, int16_buf, RoundMode::CAST_NONE, num_elements);
AscendC::PipeBarrier<PIPE_V>();
if (dst_offset == 0) {
AscendC::Cast(dst, half_buf, RoundMode::CAST_NONE, num_elements);
} else {
AscendC::Cast(dst[dst_offset], half_buf, RoundMode::CAST_NONE, num_elements);
}
}
template <CMPMODE mode>
inline __aicore__ void ProcessSmallAxisCase(const AscendC::LocalTensor<uint8_t> &dst,
const AscendC::LocalTensor<int32_t> &src0,
const AscendC::LocalTensor<int32_t> &src1, const uint8_t repeat_times,
const uint64_t mask, const uint8_t src_repeat_stride,
const uint8_t dst_repeat_stride, AscendC::LocalTensor<int32_t> &inter_buf,
AscendC::LocalTensor<int16_t> &int16_buf,
AscendC::LocalTensor<half> &half_buf) {
AscendC::PipeBarrier<PIPE_V>();
BinaryRepeatParams repeat_params = {1, 1, 1, 8, src_repeat_stride, src_repeat_stride};
if (src1.GetSize() * sizeof(int32_t) == 32) {
repeat_params = {1, 1, 0, 8, src_repeat_stride, 0};
}
AscendC::Sub(inter_buf, src0, src1, mask, repeat_times, repeat_params);
AscendC::PipeBarrier<PIPE_V>();
UnaryRepeatParams unary_repeat_params = {1, 1, 8, 8};
ApplyCompareModeNormal<mode>(inter_buf, mask, repeat_times, unary_repeat_params);
PerformTypeConversionNormal(dst, int16_buf, half_buf, inter_buf, mask, repeat_times, dst_repeat_stride);
}
template <CMPMODE mode>
inline __aicore__ void ProcessBlock(const AscendC::LocalTensor<uint8_t> &dst, const AscendC::LocalTensor<int32_t> &src0,
const AscendC::LocalTensor<int32_t> &src1, const uint8_t repeat_times,
const uint32_t mask, const uint8_t src_repeat_stride,
const uint8_t dst_repeat_stride, AscendC::LocalTensor<int32_t> &inter_buf,
AscendC::LocalTensor<int16_t> &int16_buf, AscendC::LocalTensor<half> &half_buf,
uint32_t offset_src0, uint32_t offset_src1, uint32_t offset_dst) {
BinaryRepeatParams repeat_params = {1, 1, 1, 8, src_repeat_stride, src_repeat_stride};
AscendC::PipeBarrier<PIPE_V>();
if (src1.GetSize() * sizeof(int32_t) == 32) {
repeat_params = {1, 1, 0, 8, src_repeat_stride, 0};
AscendC::Sub(inter_buf, src0[offset_src0], src1[0], mask, repeat_times, repeat_params);
} else {
AscendC::Sub(inter_buf, src0[offset_src0], src1[offset_src1], mask, repeat_times, repeat_params);
}
UnaryRepeatParams unary_repeat_params = {1, 1, 8, 8};
AscendC::PipeBarrier<PIPE_V>();
ApplyCompareModeNormal<mode>(inter_buf, mask, repeat_times, unary_repeat_params);
PerformTypeConversionNormal(dst, int16_buf, half_buf, inter_buf, mask, repeat_times, dst_repeat_stride, offset_dst);
}
template <CMPMODE mode>
inline __aicore__ void ProcessMediumAxisCase(const AscendC::LocalTensor<uint8_t> &dst,
const AscendC::LocalTensor<int32_t> &src0,
const AscendC::LocalTensor<int32_t> &src1, const uint8_t repeat_times,
const uint32_t last_axis, const uint32_t input_last_dim_stride,
const uint32_t output_last_dim_stride,
AscendC::LocalTensor<int32_t> &inter_buf,
AscendC::LocalTensor<int16_t> &int16_buf,
AscendC::LocalTensor<half> &half_buf) {
constexpr uint32_t elem_in_one_repeat = ONE_REPEAT_BYTE_SIZE / sizeof(int32_t);
const uint8_t src_repeat_stride = input_last_dim_stride * sizeof(int32_t) / ONE_BLK_SIZE;
const uint8_t dst_repeat_stride = output_last_dim_stride / ONE_BLK_SIZE;
uint32_t element_extent = last_axis / elem_in_one_repeat;
uint32_t element_reminder = last_axis - element_extent * elem_in_one_repeat;
for (uint32_t outer_for = 0; outer_for < element_extent; outer_for++) {
constexpr uint32_t mask = elem_in_one_repeat;
ProcessBlock<mode>(dst, src0, src1, repeat_times, mask, src_repeat_stride, dst_repeat_stride,
inter_buf, int16_buf, half_buf, outer_for * elem_in_one_repeat,
outer_for * elem_in_one_repeat, outer_for * elem_in_one_repeat);
}
if (element_reminder != 0) {
uint32_t mask = element_reminder;
ProcessBlock<mode>(dst, src0, src1, repeat_times, mask, src_repeat_stride, dst_repeat_stride,
inter_buf, int16_buf, half_buf, element_extent * elem_in_one_repeat,
element_extent * elem_in_one_repeat, element_extent * elem_in_one_repeat);
}
}
template <CMPMODE mode>
inline __aicore__ void ProcessLargeAxisCase(const AscendC::LocalTensor<uint8_t> &dst,
const AscendC::LocalTensor<int32_t> &src0,
const AscendC::LocalTensor<int32_t> &src1, const uint8_t repeat_times,
const uint32_t last_axis, const uint32_t input_last_dim_stride,
const uint32_t output_last_dim_stride,
AscendC::LocalTensor<int32_t> &inter_buf,
AscendC::LocalTensor<int16_t> &int16_buf,
AscendC::LocalTensor<half> &half_buf) {
constexpr uint32_t elem_in_one_repeat = ONE_REPEAT_BYTE_SIZE / sizeof(int32_t);
uint32_t element_extent = last_axis / elem_in_one_repeat;
uint32_t element_reminder = last_axis - element_extent * elem_in_one_repeat;
const uint8_t src_repeat_stride = input_last_dim_stride * sizeof(int32_t) / ONE_BLK_SIZE;
for (auto outer_for = 0; outer_for < repeat_times; outer_for++) {
AscendC::PipeBarrier<PIPE_V>();
if (src1.GetSize() * sizeof(int32_t) == 32) {
BinaryRepeatParams repeat_params = {1, 1, 0, 8, src_repeat_stride, 0};
AscendC::Sub(inter_buf, src0[outer_for * input_last_dim_stride], src1[0], elem_in_one_repeat, element_extent, repeat_params);
if (element_reminder != 0) {
AscendC::Sub(inter_buf[element_extent * elem_in_one_repeat],
src0[outer_for * input_last_dim_stride + element_extent * elem_in_one_repeat],
src1[0], element_reminder, 1, repeat_params);
}
} else {
AscendC::Sub(inter_buf, src0[outer_for * input_last_dim_stride],
src1[outer_for * input_last_dim_stride], last_axis);
}
AscendC::PipeBarrier<PIPE_V>();
ApplyCompareModeCount<mode>(inter_buf, last_axis);
AscendC::Maxs(inter_buf, inter_buf, (int32_t)0, last_axis);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Mins(inter_buf, inter_buf, (int32_t)1, last_axis);
AscendC::PipeBarrier<PIPE_V>();
PerformTypeConversionCount(dst, int16_buf, half_buf, inter_buf, last_axis, outer_for * output_last_dim_stride);
}
}
template <CMPMODE mode>
inline __aicore__ void CompareExtendInt32(const AscendC::LocalTensor<uint8_t> &dst,
const AscendC::LocalTensor<int32_t> &src0,
const AscendC::LocalTensor<int32_t> &src1,
const uint8_t repeat_times,
const uint32_t last_axis,
const uint32_t input_last_dim_stride,
const uint32_t output_last_dim_stride,
AscendC::LocalTensor<uint8_t> &tmp_buf) {
const uint8_t dst_repeat_stride = output_last_dim_stride / ONE_BLK_SIZE;
const uint8_t src_repeat_stride = input_last_dim_stride * sizeof(int32_t) / ONE_BLK_SIZE;
constexpr uint32_t elem_in_one_repeat = ONE_REPEAT_BYTE_SIZE / sizeof(int32_t);
LocalTensor<int32_t> inter_buf = tmp_buf.ReinterpretCast<int32_t>();
LocalTensor<int16_t> int16_buf = inter_buf.ReinterpretCast<int16_t>();
LocalTensor<half> half_buf = inter_buf.ReinterpretCast<half>();
if (last_axis < elem_in_one_repeat) {
const uint64_t mask = last_axis;
ProcessSmallAxisCase<mode>(dst, src0, src1, repeat_times, mask, src_repeat_stride,
dst_repeat_stride, inter_buf, int16_buf, half_buf);
} else {
uint32_t element_extent = last_axis / elem_in_one_repeat;
if (element_extent <= repeat_times) {
ProcessMediumAxisCase<mode>(dst, src0, src1, repeat_times, last_axis, input_last_dim_stride,
output_last_dim_stride, inter_buf, int16_buf, half_buf);
} else {
ProcessLargeAxisCase<mode>(dst, src0, src1, repeat_times, last_axis, input_last_dim_stride,
output_last_dim_stride, inter_buf, int16_buf, half_buf);
}
}
}
template <CMPMODE mode>
inline __aicore__ void CompareExtendInt64EqNe(const AscendC::LocalTensor<uint8_t> &dst,
const AscendC::LocalTensor<int64_t> &src0,
const AscendC::LocalTensor<int64_t> &src1,
const uint8_t repeat_times, const uint32_t last_axis,
const uint32_t input_last_dim_stride, const uint32_t output_last_dim_stride,
AscendC::LocalTensor<uint8_t> &tmp_buf) {
const uint8_t dst_repeat_stride = output_last_dim_stride / ONE_BLK_SIZE;
const uint8_t src_repeat_stride = input_last_dim_stride * sizeof(int64_t) / ONE_BLK_SIZE;
LocalTensor<half> all_one_buf = KernelUtils::NewTensor<half>(tmp_buf, 0, KernelUtils::RptSize<half>());
Duplicate(all_one_buf, (half)1.0f, KernelUtils::RptSize<half>());
LocalTensor<int32_t> sub_res_buf = KernelUtils::NewTensor<int32_t>(
tmp_buf, ONE_REPEAT_BYTE_SIZE, ONE_REPEAT_BYTE_SIZE * repeat_times / sizeof(int32_t));
LocalTensor<int64_t> sub_res_64_buf = sub_res_buf.ReinterpretCast<int64_t>();
LocalTensor<uint8_t> sub_res_reuse_buf = sub_res_buf.template ReinterpretCast<uint8_t>();
LocalTensor<float> cast_res_buf = KernelUtils::NewTensor<float>(
tmp_buf, ONE_REPEAT_BYTE_SIZE * (1 + repeat_times), ONE_REPEAT_BYTE_SIZE * repeat_times / sizeof(float));
LocalTensor<uint8_t> cmp_res_buf = KernelUtils::NewTensor<uint8_t>(sub_res_reuse_buf, 0,
(ONE_REPEAT_BYTE_SIZE * repeat_times) / 8);
LocalTensor<half> sel_res_buf = KernelUtils::NewTensor<half>(
sub_res_reuse_buf, (ONE_REPEAT_BYTE_SIZE * repeat_times) / 8, ONE_REPEAT_BYTE_SIZE * repeat_times / sizeof(half));
constexpr uint32_t elem_in_one_repeat = ONE_REPEAT_BYTE_SIZE / sizeof(int64_t);
LocalTensor<int32_t> src0_new = src0.ReinterpretCast<int32_t>();
LocalTensor<int32_t> src1_new = src1.ReinterpretCast<int32_t>();
if (last_axis <= elem_in_one_repeat) {
const uint64_t mask = last_axis;
BinaryRepeatParams repeat_params = {1, 1, 1, 8, src_repeat_stride, src_repeat_stride};
AscendC::PipeBarrier<PIPE_V>();
if (src1.GetSize() * sizeof(int64_t) == 32) {
repeat_params = {1, 1, 0, 8, src_repeat_stride, 0};
}
AscendC::Sub(sub_res_buf, src0_new, src1_new, mask * 2, repeat_times, repeat_params);
UnaryRepeatParams unary_repeat_params = {1, 1, 8, 8};
AscendC::PipeBarrier<PIPE_V>();
AscendC::Cast(cast_res_buf, sub_res_64_buf, AscendC::RoundMode::CAST_RINT, mask, repeat_times, unary_repeat_params);
AscendC::PipeBarrier<PIPE_V>();
AscendC::CompareScalar(cmp_res_buf, cast_res_buf, 0.0f, mode, mask, repeat_times, unary_repeat_params);
repeat_params = {1, 1, 1, 8, 0, 8};
AscendC::PipeBarrier<PIPE_V>();
AscendC::Select<half, uint8_t, true>(sel_res_buf, cmp_res_buf, all_one_buf, (half)0,
AscendC::SELMODE::VSEL_TENSOR_SCALAR_MODE,
KernelUtils::RptSize<half>(), repeat_times, repeat_params);
AscendC::PipeBarrier<PIPE_V>();
unary_repeat_params = {1, 1, dst_repeat_stride, 4};
AscendC::Cast(dst, sel_res_buf, AscendC::RoundMode::CAST_NONE, mask, repeat_times, unary_repeat_params);
AscendC::PipeBarrier<PIPE_V>();
} else {
uint32_t element_extent = last_axis / elem_in_one_repeat;
uint32_t element_reminder = last_axis - element_extent * elem_in_one_repeat;
if (element_extent <= repeat_times) {
for (uint32_t outer_for = 0; outer_for < element_extent; outer_for++) {
constexpr uint64_t mask = elem_in_one_repeat;
BinaryRepeatParams repeat_params = {1, 1, 1, 8, src_repeat_stride, src_repeat_stride};
AscendC::PipeBarrier<PIPE_V>();
if (src1.GetSize() * sizeof(int64_t) == 32) {
repeat_params = {1, 1, 0, 8, src_repeat_stride, 0};
AscendC::Sub(sub_res_buf, src0_new[outer_for * elem_in_one_repeat * 2], src1_new[0], mask * 2, repeat_times,
repeat_params);
} else {
AscendC::Sub(sub_res_buf, src0_new[outer_for * elem_in_one_repeat * 2],
src1_new[outer_for * elem_in_one_repeat * 2], mask * 2, repeat_times, repeat_params);
}
UnaryRepeatParams unary_repeat_params = {1, 1, 4, 8};
AscendC::PipeBarrier<PIPE_V>();
AscendC::Cast(cast_res_buf, sub_res_64_buf, AscendC::RoundMode::CAST_RINT, mask, repeat_times,
unary_repeat_params);
unary_repeat_params = {1, 1, 8, 8};
AscendC::PipeBarrier<PIPE_V>();
AscendC::CompareScalar(cmp_res_buf, cast_res_buf, 0.0f, mode, mask, DivCeil(repeat_times, 2), unary_repeat_params);
repeat_params = {1, 0, 0, 8, 0, 0};
AscendC::PipeBarrier<PIPE_V>();
AscendC::Select<half, uint8_t, true>(sel_res_buf, cmp_res_buf, all_one_buf, (half)0,
AscendC::SELMODE::VSEL_TENSOR_SCALAR_MODE,
128, DivCeil(repeat_times, 4), repeat_params);
AscendC::PipeBarrier<PIPE_V>();
unary_repeat_params = {1, 1, dst_repeat_stride, 2};
AscendC::Cast(dst[outer_for * elem_in_one_repeat], sel_res_buf, AscendC::RoundMode::CAST_NONE, mask,
repeat_times, unary_repeat_params);
AscendC::PipeBarrier<PIPE_V>();
}
if (element_reminder != 0) {
uint32_t mask = element_reminder;
BinaryRepeatParams repeat_params = {1, 1, 1, 8, src_repeat_stride, src_repeat_stride};
AscendC::PipeBarrier<PIPE_V>();
if (src1.GetSize() * sizeof(int64_t) == 32) {
repeat_params = {1, 1, 0, 8, src_repeat_stride, 0};
AscendC::Sub(sub_res_buf, src0_new[element_extent * elem_in_one_repeat * 2], src1_new[0], mask * 2,
repeat_times, repeat_params);
} else {
AscendC::Sub(sub_res_buf, src0_new[element_extent * elem_in_one_repeat * 2],
src1_new[element_extent * elem_in_one_repeat * 2], mask * 2, repeat_times, repeat_params);
}
UnaryRepeatParams unary_repeat_params = {1, 1, 4, 8};
AscendC::PipeBarrier<PIPE_V>();
AscendC::Cast(cast_res_buf, sub_res_64_buf, AscendC::RoundMode::CAST_RINT, mask, repeat_times,
unary_repeat_params);
unary_repeat_params = {1, 1, 8, 8};
AscendC::PipeBarrier<PIPE_V>();
AscendC::CompareScalar(cmp_res_buf, cast_res_buf, 0.0f, mode, mask, DivCeil(repeat_times, 2), unary_repeat_params);
repeat_params = {1, 0, 0, 8, 0, 0};
AscendC::PipeBarrier<PIPE_V>();
AscendC::Select<half, uint8_t, true>(sel_res_buf, cmp_res_buf, all_one_buf, (half)0,
AscendC::SELMODE::VSEL_TENSOR_SCALAR_MODE, 128, DivCeil(repeat_times, 4),
repeat_params);
AscendC::PipeBarrier<PIPE_V>();
unary_repeat_params = {1, 1, dst_repeat_stride, 2};
AscendC::Cast(dst[element_extent * elem_in_one_repeat], sel_res_buf, AscendC::RoundMode::CAST_NONE, mask,
repeat_times, unary_repeat_params);
AscendC::PipeBarrier<PIPE_V>();
}
} else {
element_reminder = last_axis - element_extent * elem_in_one_repeat;
for (uint32_t outer_for = 0; outer_for < repeat_times; outer_for++) {
uint64_t mask = elem_in_one_repeat;
BinaryRepeatParams repeat_params = {1, 1, 1, 8, 8, 8};
AscendC::PipeBarrier<PIPE_V>();
if (src1.GetSize() * sizeof(int64_t) == 32) {
repeat_params = {1, 1, 0, 8, 8, 0};
AscendC::Sub(sub_res_buf, src0_new[outer_for * input_last_dim_stride * 2], src1_new[0], mask * 2,
element_extent, repeat_params);
} else {
AscendC::Sub(sub_res_buf, src0_new[outer_for * input_last_dim_stride * 2],
src1_new[outer_for * input_last_dim_stride * 2], mask * 2, element_extent, repeat_params);
}
UnaryRepeatParams unary_repeat_params = {1, 1, 4, 8};
AscendC::PipeBarrier<PIPE_V>();
AscendC::Cast(cast_res_buf, sub_res_64_buf, AscendC::RoundMode::CAST_RINT, mask, element_extent,
unary_repeat_params);
unary_repeat_params = {1, 1, 8, 8};
AscendC::PipeBarrier<PIPE_V>();
AscendC::CompareScalar(cmp_res_buf, cast_res_buf, 0.0f, mode, mask, DivCeil(element_extent, 2),
unary_repeat_params);
repeat_params = {1, 0, 0, 8, 0, 0};
AscendC::PipeBarrier<PIPE_V>();
AscendC::Select<half, uint8_t, true>(sel_res_buf, cmp_res_buf, all_one_buf, (half)0,
AscendC::SELMODE::VSEL_TENSOR_SCALAR_MODE, 128,
DivCeil(element_extent, 4), repeat_params);
AscendC::PipeBarrier<PIPE_V>();
unary_repeat_params = {1, 1, 1, 2};
AscendC::Cast(dst[outer_for * output_last_dim_stride], sel_res_buf, AscendC::RoundMode::CAST_NONE, mask,
element_extent, unary_repeat_params);
AscendC::PipeBarrier<PIPE_V>();
if (element_reminder != 0) {
mask = element_reminder;
BinaryRepeatParams repeat_params = {1, 1, 1, 8, 8, 8};
AscendC::PipeBarrier<PIPE_V>();
if (src1.GetSize() * sizeof(int64_t) == 32) {
repeat_params = {1, 1, 0, 8, 8, 0};
AscendC::Sub(sub_res_buf, src0_new[outer_for * input_last_dim_stride * 2 + element_extent *
elem_in_one_repeat * 2], src1_new[0], mask * 2, 1, repeat_params);
} else {
AscendC::Sub(sub_res_buf, src0_new[outer_for * input_last_dim_stride * 2 + element_extent *
elem_in_one_repeat * 2], src1_new[outer_for * input_last_dim_stride * 2 + element_extent *
elem_in_one_repeat * 2], mask * 2, 1, repeat_params);
}
UnaryRepeatParams unary_repeat_params = {1, 1, 4, 8};
AscendC::PipeBarrier<PIPE_V>();
AscendC::Cast(cast_res_buf, sub_res_64_buf, AscendC::RoundMode::CAST_RINT, mask, 1, unary_repeat_params);
unary_repeat_params = {1, 1, 8, 8};
AscendC::PipeBarrier<PIPE_V>();
AscendC::CompareScalar(cmp_res_buf, cast_res_buf, 0.0f, mode, mask, 1, unary_repeat_params);
repeat_params = {1, 0, 0, 8, 0, 0};
AscendC::PipeBarrier<PIPE_V>();
AscendC::Select<half, uint8_t, true>(sel_res_buf, cmp_res_buf, all_one_buf, (half)0,
AscendC::SELMODE::VSEL_TENSOR_SCALAR_MODE, 128, 1, repeat_params);
AscendC::PipeBarrier<PIPE_V>();
unary_repeat_params = {1, 1, 1, 2};
AscendC::Cast(dst[outer_for * output_last_dim_stride + element_extent * elem_in_one_repeat], sel_res_buf,
AscendC::RoundMode::CAST_NONE, mask, 1, unary_repeat_params);
AscendC::PipeBarrier<PIPE_V>();
}
}
}
}
}
inline __aicore__ void GetSignBitTensorNormal(const AscendC::LocalTensor<uint16_t> &dst,
const AscendC::LocalTensor<int64_t> &src, const AscendC::LocalTensor<uint32_t> &inner_dup, const uint32_t mask,
const uint32_t repeat_times, const uint8_t src_repeat_stride, const uint8_t dst_stride)
{
AscendC::PipeBarrier<PIPE_V>();
AscendC::Duplicate(inner_dup, uint32_t(0x80000000), inner_dup.GetSize());
AscendC::LocalTensor<uint16_t> inner_dup_tmp = inner_dup.ReinterpretCast<uint16_t>();
AscendC::LocalTensor<uint16_t> src_tmp = src.ReinterpretCast<uint16_t>();
AscendC::PipeBarrier<PIPE_V>();
BinaryRepeatParams repeat_params = {1, 1, 1, dst_stride, src_repeat_stride, 0};
if (src.GetSize() * sizeof(int64_t) == 32) {
repeat_params = {1, 0, 0, dst_stride, 0, 0};
}
AscendC::And(dst, src_tmp, inner_dup_tmp, mask * 4, repeat_times, repeat_params);
AscendC::LocalTensor<int32_t> dst_tmp = dst.ReinterpretCast<int32_t>();
AscendC::PipeBarrier<PIPE_V>();
UnaryRepeatParams unary_repeat_params = {1, 1, dst_stride, dst_stride};
AscendC::Maxs(dst_tmp, dst_tmp, -1, mask * 2, repeat_times, unary_repeat_params);
AscendC::Duplicate(inner_dup, uint32_t(0xFFFFFFFF), inner_dup.GetSize());
uint64_t mask_half[2] = { uint64_t(0x5555555555555555), 0 };
AscendC::Duplicate(inner_dup, 1U, mask_half, repeat_times, 1, 8);
repeat_params = {1, 1, 1, dst_stride, dst_stride, dst_stride};
AscendC::PipeBarrier<PIPE_V>();
AscendC::And(dst, dst, inner_dup_tmp, mask * 4, repeat_times, repeat_params);
}
inline __aicore__ void CastTensorToHalfNormal(const AscendC::LocalTensor<int32_t> &dst,
const AscendC::LocalTensor<uint16_t> &src0_bits, const AscendC::LocalTensor<uint16_t> &src1_bits,
const AscendC::LocalTensor<uint32_t> &inner_dup, const uint32_t mask, const uint32_t repeat_times,
const uint8_t dst_stride)
{
AscendC::PipeBarrier<PIPE_V>();
AscendC::Duplicate(inner_dup, uint32_t(0x0000BC00), inner_dup.GetSize());
AscendC::LocalTensor<uint16_t> inner_dup_tmp = inner_dup.ReinterpretCast<uint16_t>();
AscendC::PipeBarrier<PIPE_V>();
BinaryRepeatParams repeat_params = {1, 1, 1, dst_stride, dst_stride, dst_stride};
AscendC::And(src1_bits, src0_bits, inner_dup_tmp, mask * 4, repeat_times, repeat_params);
AscendC::LocalTensor<int32_t> sub_tmp = inner_dup.ReinterpretCast<int32_t>();
AscendC::Duplicate(sub_tmp, 0, sub_tmp.GetSize());
AscendC::Sub(dst, sub_tmp, src0_bits.ReinterpretCast<int32_t>(), mask * 2, repeat_times, repeat_params);
AscendC::Duplicate(inner_dup, uint32_t(0x00003C00), inner_dup.GetSize());
AscendC::PipeBarrier<PIPE_V>();
AscendC::And(src0_bits, dst.ReinterpretCast<uint16_t>(), inner_dup_tmp, mask * 4, repeat_times, repeat_params);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Or(src0_bits, src1_bits, src0_bits, mask * 4, repeat_times, repeat_params);
AscendC::LocalTensor<half> dst_tmp = dst.ReinterpretCast<half>();
AscendC::LocalTensor<half> sign_tmp = src0_bits.ReinterpretCast<half>();
AscendC::PipeBarrier<PIPE_V>();
AscendC::PairReduceSum(dst_tmp, sign_tmp, repeat_times, mask * 4, 1, 1, dst_stride);
}
inline __aicore__ void CalcWeightedTensorNormal(const AscendC::LocalTensor<half> &dst,
const AscendC::LocalTensor<half> &src, const AscendC::LocalTensor<uint32_t> &inner_dup, const uint32_t mask,
const uint32_t repeat_times, const half weight0, const half weight1, const uint8_t dst_stride)
{
AscendC::PipeBarrier<PIPE_V>();
AscendC::LocalTensor<half> inner_dup_tmp = inner_dup.ReinterpretCast<half>();
AscendC::Duplicate(inner_dup_tmp, weight0, inner_dup_tmp.GetSize());
uint64_t mask_half[2] = { uint64_t(0x5555555555555555), uint64_t(0x5555555555555555) };
AscendC::Duplicate(inner_dup_tmp, weight1, mask_half, repeat_times, 1, 8);
AscendC::PipeBarrier<PIPE_V>();
BinaryRepeatParams repeat_params = {1, 1, 1, dst_stride, dst_stride, dst_stride};
AscendC::Mul(src, src, inner_dup_tmp, mask * 2, repeat_times, repeat_params);
AscendC::PipeBarrier<PIPE_V>();
AscendC::PairReduceSum(dst, src, repeat_times, mask * 2, 1, 1, 4);
}
template <CMPMODE mode>
inline __aicore__ void CompareExtendInt64GtGeLe(const AscendC::LocalTensor<uint8_t> &dst,
const AscendC::LocalTensor<int64_t> &src0,
const AscendC::LocalTensor<int64_t> &src1,
const uint8_t repeat_times, const uint32_t last_axis,
const uint32_t input_last_dim_stride, const uint32_t output_last_dim_stride,
AscendC::LocalTensor<uint8_t> &tmp_buf) {
const uint8_t dst_repeat_stride = output_last_dim_stride / ONE_BLK_SIZE;
const uint8_t src_repeat_stride = input_last_dim_stride * sizeof(int64_t) / ONE_BLK_SIZE;
constexpr uint32_t elem_in_one_repeat = ONE_REPEAT_BYTE_SIZE / sizeof(int64_t);
uint32_t quadruple_cal_cnt = 4 * repeat_times * (last_axis < elem_in_one_repeat ? elem_in_one_repeat : last_axis);
uint32_t double_cal_cnt = 2 * repeat_times * (last_axis < elem_in_one_repeat ? elem_in_one_repeat : last_axis);
uint32_t offset = 0;
AscendC::LocalTensor<uint16_t> src0_bits = KernelUtils::NewTensor<uint16_t>(tmp_buf, 0, quadruple_cal_cnt);
offset += DivCeil(quadruple_cal_cnt * sizeof(uint16_t), ONE_BLK_SIZE) * ONE_BLK_SIZE;
AscendC::LocalTensor<uint16_t> src1_bits = KernelUtils::NewTensor<uint16_t>(tmp_buf, offset, quadruple_cal_cnt);
offset += DivCeil(quadruple_cal_cnt * sizeof(uint16_t), ONE_BLK_SIZE) * ONE_BLK_SIZE;
AscendC::LocalTensor<int32_t> sub_res = KernelUtils::NewTensor<int32_t>(tmp_buf, offset, double_cal_cnt);
offset += DivCeil(double_cal_cnt * sizeof(int32_t), ONE_BLK_SIZE) * ONE_BLK_SIZE;
AscendC::LocalTensor<half> sign_res = KernelUtils::NewTensor<half>(tmp_buf, offset, double_cal_cnt);
offset += DivCeil(double_cal_cnt * sizeof(half), ONE_BLK_SIZE) * ONE_BLK_SIZE;
AscendC::LocalTensor<half> value_res = KernelUtils::NewTensor<half>(tmp_buf, offset, double_cal_cnt);
offset += DivCeil(double_cal_cnt * sizeof(half), ONE_BLK_SIZE) * ONE_BLK_SIZE;
AscendC::LocalTensor<uint32_t> inner_dup = KernelUtils::NewTensor<uint32_t>(tmp_buf, offset, double_cal_cnt);
AscendC::LocalTensor<uint16_t> inner_dup_tmp = inner_dup.ReinterpretCast<uint16_t>();
if (last_axis < elem_in_one_repeat) {
uint32_t mask = last_axis;
uint8_t dst_stride = DivCeil(mask, (ONE_BLK_SIZE / sizeof(uint64_t)));
GetSignBitTensorNormal(src0_bits, src0, inner_dup, mask, repeat_times, src_repeat_stride, dst_stride);
GetSignBitTensorNormal(src1_bits, src1, inner_dup, mask, repeat_times, src_repeat_stride, dst_stride);
AscendC::LocalTensor<int32_t> src0_bits_tmp = src0_bits.ReinterpretCast<int32_t>();
AscendC::LocalTensor<int32_t> src1_bits_tmp = src1_bits.ReinterpretCast<int32_t>();
AscendC::PipeBarrier<PIPE_V>();
BinaryRepeatParams repeat_params = {1, 1, 1, dst_stride, dst_stride, dst_stride};
AscendC::Sub(src0_bits_tmp, src0_bits_tmp, src1_bits_tmp, mask * 2, repeat_times, repeat_params);
CastTensorToHalfNormal(sub_res, src0_bits, src1_bits, inner_dup, mask, repeat_times, dst_stride);
CalcWeightedTensorNormal(sign_res, sub_res.ReinterpretCast<half>(), inner_dup, mask, repeat_times, half(8), half(2),
dst_stride);
AscendC::Duplicate(inner_dup, uint32_t(0x7FFFFFFF), inner_dup.GetSize());
AscendC::LocalTensor<uint16_t> src0_tmp = src0.ReinterpretCast<uint16_t>();
AscendC::LocalTensor<uint16_t> src1_tmp = src1.ReinterpretCast<uint16_t>();
AscendC::PipeBarrier<PIPE_V>();
repeat_params = {1, 1, 1, dst_stride, src_repeat_stride, src_repeat_stride};
AscendC::And(src0_bits, src0_tmp, inner_dup_tmp, mask * 4, repeat_times, repeat_params);
if (src1.GetSize() * sizeof(int64_t) == 32) {
repeat_params = {1, 0, 1, dst_stride, 0, src_repeat_stride};
}
AscendC::And(src1_bits, src1_tmp, inner_dup_tmp, mask * 4, repeat_times, repeat_params);
repeat_params = {1, 1, 1, dst_stride, src_repeat_stride, src_repeat_stride};
AscendC::PipeBarrier<PIPE_V>();
AscendC::Sub(src0_bits_tmp, src0_bits_tmp, src1_bits_tmp, mask * 2, repeat_times, repeat_params);
UnaryRepeatParams unary_repeat_params = {1, 1, src_repeat_stride, src_repeat_stride};
AscendC::PipeBarrier<PIPE_V>();
AscendC::Maxs(src0_bits_tmp, src0_bits_tmp, -1, mask * 2, repeat_times, unary_repeat_params);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Mins(src0_bits_tmp, src0_bits_tmp, 1, mask * 2, repeat_times, unary_repeat_params);
CastTensorToHalfNormal(sub_res, src0_bits, src1_bits, inner_dup, mask, repeat_times, dst_stride);
CalcWeightedTensorNormal(value_res, sub_res.ReinterpretCast<half>(), inner_dup, mask, repeat_times, half(4),
half(1), dst_stride);
repeat_params = {1, 1, 1, 1, 4, 4};
AscendC::PipeBarrier<PIPE_V>();
AscendC::Add(sign_res, sign_res, value_res, mask, repeat_times, repeat_params);
AscendC::PipeBarrier<PIPE_V>();
unary_repeat_params = {1, 1, 1, 1};
if (mode == CMPMODE::GT) {
AscendC::Maxs(sign_res, sign_res, half(0), mask, repeat_times, unary_repeat_params);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Mins(sign_res, sign_res, half(1), mask, repeat_times, unary_repeat_params);
unary_repeat_params = {1, 1, dst_repeat_stride, 1};
AscendC::PipeBarrier<PIPE_V>();
AscendC::Cast(dst, sign_res, AscendC::RoundMode::CAST_NONE, mask, repeat_times, unary_repeat_params);
} else {
AscendC::Maxs(sign_res, sign_res, half(-1), mask, repeat_times, unary_repeat_params);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Mins(sign_res, sign_res, half(0), mask, repeat_times, unary_repeat_params);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Adds(sign_res, sign_res, half(1), mask, repeat_times, unary_repeat_params);
unary_repeat_params = {1, 1, dst_repeat_stride, 1};
AscendC::PipeBarrier<PIPE_V>();
AscendC::Cast(dst, sign_res, AscendC::RoundMode::CAST_NONE, mask, repeat_times, unary_repeat_params);
}
} else {
uint32_t element_extent = last_axis / elem_in_one_repeat;
uint32_t element_reminder = last_axis - element_extent * elem_in_one_repeat;
for (uint32_t outer_for = 0; outer_for < element_extent; outer_for++) {
uint32_t mask = elem_in_one_repeat;
uint8_t dst_stride = DivCeil(mask, (ONE_BLK_SIZE / sizeof(uint64_t)));
GetSignBitTensorNormal(src0_bits, src0[outer_for * elem_in_one_repeat], inner_dup, mask, repeat_times,
src_repeat_stride, dst_stride);
if (src1.GetSize() * sizeof(int64_t) == 32) {
GetSignBitTensorNormal(src1_bits, src1, inner_dup, mask, repeat_times, src_repeat_stride, dst_stride);
} else {
GetSignBitTensorNormal(src1_bits, src1[outer_for * elem_in_one_repeat], inner_dup, mask, repeat_times,
src_repeat_stride, dst_stride);
}
AscendC::LocalTensor<int32_t> src0_bits_tmp = src0_bits.ReinterpretCast<int32_t>();
AscendC::LocalTensor<int32_t> src1_bits_tmp = src1_bits.ReinterpretCast<int32_t>();
AscendC::PipeBarrier<PIPE_V>();
BinaryRepeatParams repeat_params = {1, 1, 1, dst_stride, dst_stride, dst_stride};
AscendC::Sub(src0_bits_tmp, src0_bits_tmp, src1_bits_tmp, mask * 2, repeat_times, repeat_params);
CastTensorToHalfNormal(sub_res, src0_bits, src1_bits, inner_dup, mask, repeat_times, dst_stride);
CalcWeightedTensorNormal(sign_res, sub_res.ReinterpretCast<half>(), inner_dup, mask, repeat_times, half(8),
half(2), dst_stride);
AscendC::Duplicate(inner_dup, uint32_t(0x7FFFFFFF), inner_dup.GetSize());
AscendC::PipeBarrier<PIPE_V>();
repeat_params = {1, 1, 1, dst_stride, src_repeat_stride, src_repeat_stride};
AscendC::LocalTensor<uint16_t> src0_tmp = src0[outer_for * elem_in_one_repeat].ReinterpretCast<uint16_t>();
AscendC::And(src0_bits, src0_tmp, inner_dup_tmp, mask * 4, repeat_times, repeat_params);
if (src1.GetSize() * sizeof(int64_t) == 32) {
repeat_params = {1, 0, 1, dst_stride, 0, src_repeat_stride};
AscendC::LocalTensor<uint16_t> src1_tmp = src1.ReinterpretCast<uint16_t>();
AscendC::And(src1_bits, src1_tmp, inner_dup_tmp, mask * 4, repeat_times, repeat_params);
} else {
AscendC::LocalTensor<uint16_t> src1_tmp = src1[outer_for * elem_in_one_repeat].ReinterpretCast<uint16_t>();
AscendC::And(src1_bits, src1_tmp, inner_dup_tmp, mask * 4, repeat_times, repeat_params);
}
repeat_params = {1, 1, 1, dst_stride, dst_stride, dst_stride};
AscendC::PipeBarrier<PIPE_V>();
AscendC::Sub(src0_bits_tmp, src0_bits_tmp, src1_bits_tmp, mask * 2, repeat_times, repeat_params);
UnaryRepeatParams unary_repeat_params = {1, 1, dst_stride, dst_stride};
AscendC::PipeBarrier<PIPE_V>();
AscendC::Maxs(src0_bits_tmp, src0_bits_tmp, -1, mask * 2, repeat_times, unary_repeat_params);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Mins(src0_bits_tmp, src0_bits_tmp, 1, mask * 2, repeat_times, unary_repeat_params);
CastTensorToHalfNormal(sub_res, src0_bits, src1_bits, inner_dup, mask, repeat_times, dst_stride);
CalcWeightedTensorNormal(value_res, sub_res.ReinterpretCast<half>(), inner_dup, mask, repeat_times, half(4),
half(1), dst_stride);
repeat_params = {1, 1, 1, 2, 4, 4};
AscendC::PipeBarrier<PIPE_V>();
AscendC::Add(sign_res, sign_res, value_res, mask, repeat_times, repeat_params);
AscendC::PipeBarrier<PIPE_V>();
unary_repeat_params = {1, 1, 2, 2};
if (mode == CMPMODE::GT) {
AscendC::Maxs(sign_res, sign_res, half(0), mask, repeat_times, unary_repeat_params);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Mins(sign_res, sign_res, half(1), mask, repeat_times, unary_repeat_params);
unary_repeat_params = {1, 1, dst_repeat_stride, 2};
AscendC::PipeBarrier<PIPE_V>();
AscendC::Cast(dst[outer_for * elem_in_one_repeat], sign_res, AscendC::RoundMode::CAST_NONE, mask, repeat_times,
unary_repeat_params);
} else {
AscendC::Maxs(sign_res, sign_res, half(-1), mask, repeat_times, unary_repeat_params);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Mins(sign_res, sign_res, half(0), mask, repeat_times, unary_repeat_params);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Adds(sign_res, sign_res, half(1), mask, repeat_times, unary_repeat_params);
unary_repeat_params = {1, 1, dst_repeat_stride, 2};
AscendC::PipeBarrier<PIPE_V>();
AscendC::Cast(dst[outer_for * elem_in_one_repeat], sign_res, AscendC::RoundMode::CAST_NONE, mask, repeat_times,
unary_repeat_params);
}
}
if (element_reminder != 0) {
uint32_t mask = element_reminder;
uint8_t dst_stride = DivCeil(mask, (ONE_BLK_SIZE / sizeof(uint64_t)));
GetSignBitTensorNormal(src0_bits, src0[element_extent * elem_in_one_repeat], inner_dup, mask, repeat_times,
src_repeat_stride, dst_stride);
if (src1.GetSize() * sizeof(int64_t) == 32) {
GetSignBitTensorNormal(src1_bits, src1, inner_dup, mask, repeat_times, src_repeat_stride, dst_stride);
} else {
GetSignBitTensorNormal(src1_bits, src1[element_extent * elem_in_one_repeat], inner_dup, mask, repeat_times,
src_repeat_stride, dst_stride);
}
AscendC::LocalTensor<int32_t> src0_bits_tmp = src0_bits.ReinterpretCast<int32_t>();
AscendC::LocalTensor<int32_t> src1_bits_tmp = src1_bits.ReinterpretCast<int32_t>();
AscendC::PipeBarrier<PIPE_V>();
BinaryRepeatParams repeat_params = {1, 1, 1, dst_stride, dst_stride, dst_stride};
AscendC::Sub(src0_bits_tmp, src0_bits_tmp, src1_bits_tmp, mask * 2, repeat_times, repeat_params);
CastTensorToHalfNormal(sub_res, src0_bits, src1_bits, inner_dup, mask, repeat_times, dst_stride);
CalcWeightedTensorNormal(sign_res, sub_res.ReinterpretCast<half>(), inner_dup, mask, repeat_times, half(8),
half(2), dst_stride);
AscendC::Duplicate(inner_dup, uint32_t(0x7FFFFFFF), inner_dup.GetSize());
AscendC::PipeBarrier<PIPE_V>();
repeat_params = {1, 1, 1, dst_stride, src_repeat_stride, src_repeat_stride};
AscendC::LocalTensor<uint16_t> src0_tmp = src0[element_extent * elem_in_one_repeat].ReinterpretCast<uint16_t>();
AscendC::And(src0_bits, src0_tmp, inner_dup_tmp, mask * 4, repeat_times, repeat_params);
if (src1.GetSize() * sizeof(int64_t) == 32) {
repeat_params = {1, 0, 1, dst_stride, 0, src_repeat_stride};
AscendC::LocalTensor<uint16_t> src1_tmp = src1.ReinterpretCast<uint16_t>();
AscendC::And(src1_bits, src1_tmp, inner_dup_tmp, mask * 4, repeat_times, repeat_params);
} else {
AscendC::LocalTensor<uint16_t> src1_tmp = src1[element_extent * elem_in_one_repeat].ReinterpretCast<uint16_t>();
AscendC::And(src1_bits, src1_tmp, inner_dup_tmp, mask * 4, repeat_times, repeat_params);
}
repeat_params = {1, 1, 1, dst_stride, dst_stride, dst_stride};
AscendC::PipeBarrier<PIPE_V>();
AscendC::Sub(src0_bits_tmp, src0_bits_tmp, src1_bits_tmp, mask * 2, repeat_times, repeat_params);
UnaryRepeatParams unary_repeat_params = {1, 1, dst_stride, dst_stride};
AscendC::PipeBarrier<PIPE_V>();
AscendC::Maxs(src0_bits_tmp, src0_bits_tmp, -1, mask * 2, repeat_times, unary_repeat_params);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Mins(src0_bits_tmp, src0_bits_tmp, 1, mask * 2, repeat_times, unary_repeat_params);
CastTensorToHalfNormal(sub_res, src0_bits, src1_bits, inner_dup, mask, repeat_times, dst_stride);
CalcWeightedTensorNormal(value_res, sub_res.ReinterpretCast<half>(), inner_dup, mask, repeat_times, half(4),
half(1), dst_stride);
repeat_params = {1, 1, 1, 2, 4, 4};
AscendC::PipeBarrier<PIPE_V>();
AscendC::Add(sign_res, sign_res, value_res, mask, repeat_times, repeat_params);
AscendC::PipeBarrier<PIPE_V>();
unary_repeat_params = {1, 1, 2, 2};
if (mode == CMPMODE::GT) {
AscendC::Maxs(sign_res, sign_res, half(0), mask, repeat_times, unary_repeat_params);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Mins(sign_res, sign_res, half(1), mask, repeat_times, unary_repeat_params);
unary_repeat_params = {1, 1, dst_repeat_stride, 2};
AscendC::PipeBarrier<PIPE_V>();
AscendC::Cast(dst[element_extent * elem_in_one_repeat], sign_res, AscendC::RoundMode::CAST_NONE, mask,
repeat_times, unary_repeat_params);
} else {
AscendC::Maxs(sign_res, sign_res, half(-1), mask, repeat_times, unary_repeat_params);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Mins(sign_res, sign_res, half(0), mask, repeat_times, unary_repeat_params);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Adds(sign_res, sign_res, half(1), mask, repeat_times, unary_repeat_params);
unary_repeat_params = {1, 1, dst_repeat_stride, 2};
AscendC::PipeBarrier<PIPE_V>();
AscendC::Cast(dst[element_extent * elem_in_one_repeat], sign_res, AscendC::RoundMode::CAST_NONE, mask,
repeat_times, unary_repeat_params);
}
}
}
}
template <typename T, CMPMODE mode>
inline __aicore__ void CompareExtendRepeat(const AscendC::LocalTensor<uint8_t> &dst,
const AscendC::LocalTensor<T> &src0,
const AscendC::LocalTensor<T> &src1,
const uint8_t repeat_times, const uint32_t last_axis,
const uint32_t input_last_dim_stride, const uint32_t output_last_dim_stride,
AscendC::LocalTensor<uint8_t> &tmp_buf) {
const uint32_t max_stride = 255;
if constexpr (AscendC::IsSameType<T, int32_t>::value &&
(mode == CMPMODE::GT || mode == CMPMODE::NE || mode == CMPMODE::GE || mode == CMPMODE::LE ||
mode == CMPMODE::LT)) {
if (input_last_dim_stride * sizeof(int32_t) / ONE_BLK_SIZE <= max_stride &&
output_last_dim_stride / ONE_BLK_SIZE <= max_stride) {
CompareExtendInt32<mode>(dst, src0, src1, repeat_times, last_axis, input_last_dim_stride, output_last_dim_stride,
tmp_buf);
} else {
for (uint32_t i = 0; i < repeat_times; ++i) {
CompareExtendInt32<mode>(dst[i * output_last_dim_stride], src0[i * input_last_dim_stride],
src1[i * input_last_dim_stride], 1, last_axis, input_last_dim_stride,
output_last_dim_stride, tmp_buf);
}
}
} else if constexpr (AscendC::IsSameType<T, int64_t>::value) {
if constexpr (mode != CMPMODE::EQ && mode != CMPMODE::NE && mode != CMPMODE::GT && mode != CMPMODE::GE &&
mode != CMPMODE::LE) {
ASSERT(false && "CompareExtend mode only support EQ/NE/GT/GE/LE when DataType is int64.");
}
if constexpr (mode == CMPMODE::EQ || mode == CMPMODE::NE) {
if (input_last_dim_stride * sizeof(int64_t) / ONE_BLK_SIZE <= max_stride &&
output_last_dim_stride / ONE_BLK_SIZE <= max_stride) {
CompareExtendInt64EqNe<mode>(dst, src0, src1,
repeat_times, last_axis, input_last_dim_stride, output_last_dim_stride, tmp_buf);
} else {
for (uint8_t i = 0; i < repeat_times; ++i) {
CompareExtendInt64EqNe<mode>(dst[i * output_last_dim_stride],
src0[i * input_last_dim_stride],
src1[i * input_last_dim_stride],
1, 0, input_last_dim_stride, output_last_dim_stride, tmp_buf);
}
}
} else if constexpr (mode == CMPMODE::LE || mode == CMPMODE::GT || mode == CMPMODE::GE) {
if (input_last_dim_stride * sizeof(int64_t) / ONE_BLK_SIZE <= max_stride &&
output_last_dim_stride / ONE_BLK_SIZE <= max_stride) {
CompareExtendInt64GtGeLe<mode>(dst, src0, src1,
repeat_times, last_axis, input_last_dim_stride, output_last_dim_stride, tmp_buf);
} else {
for (uint8_t i = 0; i < repeat_times; ++i) {
CompareExtendInt64GtGeLe<mode>(dst[i * output_last_dim_stride],
src0[i * input_last_dim_stride],
src1[i * input_last_dim_stride],
1, 0, input_last_dim_stride, output_last_dim_stride, tmp_buf);
}
}
}
} else {
if (input_last_dim_stride * sizeof(T) / ONE_BLK_SIZE <= max_stride &&
output_last_dim_stride / ONE_BLK_SIZE <= max_stride) {
CompareNormal<T, mode>(dst, src0, src1,
repeat_times, last_axis, input_last_dim_stride, output_last_dim_stride, tmp_buf);
} else {
for (uint8_t i = 0; i < repeat_times; ++i) {
CompareNormal<T, mode>(dst[i * output_last_dim_stride],
src0[i * input_last_dim_stride],
src1[i * input_last_dim_stride],
1, last_axis, input_last_dim_stride, output_last_dim_stride, tmp_buf);
}
}
}
}
template <typename T, CMPMODE mode>
inline __aicore__ void CompareExtend(const AscendC::LocalTensor<uint8_t> &dst,
const AscendC::LocalTensor<T> &src0,
const AscendC::LocalTensor<T> &src1,
const uint32_t first_axis, const uint32_t last_axis,
const uint32_t input_last_dim_stride, const uint32_t output_last_dim_stride,
AscendC::LocalTensor<uint8_t> &tmp_buf) {
uint32_t calc_repeat_time = 0;
uint64_t inputOffset = calc_repeat_time * input_last_dim_stride;
uint64_t outputOffset = calc_repeat_time * output_last_dim_stride;
if (first_axis > MAX_REPEAT_TIME) {
for (; (calc_repeat_time + MAX_REPEAT_TIME) < first_axis; calc_repeat_time += MAX_REPEAT_TIME) {
inputOffset = calc_repeat_time * input_last_dim_stride;
outputOffset = calc_repeat_time * output_last_dim_stride;
if (src1.GetSize() * sizeof(T) == 32) {
CompareExtendRepeat<T, mode>(dst[outputOffset], src0[inputOffset], src1, MAX_REPEAT_TIME, last_axis,
input_last_dim_stride, output_last_dim_stride, tmp_buf);
} else {
CompareExtendRepeat<T, mode>(dst[outputOffset], src0[inputOffset], src1[inputOffset], MAX_REPEAT_TIME, last_axis,
input_last_dim_stride, output_last_dim_stride, tmp_buf);
}
}
}
uint8_t remain_repeat_time = (uint8_t)(first_axis - calc_repeat_time);
inputOffset = calc_repeat_time * input_last_dim_stride;
outputOffset = calc_repeat_time * output_last_dim_stride;
if (src1.GetSize() * sizeof(T) == 32) {
CompareExtendRepeat<T, mode>(dst[outputOffset], src0[inputOffset], src1, remain_repeat_time, last_axis,
input_last_dim_stride, output_last_dim_stride, tmp_buf);
} else {
CompareExtendRepeat<T, mode>(dst[outputOffset], src0[inputOffset], src1[inputOffset], remain_repeat_time, last_axis,
input_last_dim_stride, output_last_dim_stride, tmp_buf);
}
}
#endif