* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef __ASCENDC_API_COMPARE_H__
#define __ASCENDC_API_COMPARE_H__
inline __aicore__ void TmpDupInt64(int64_t src, AscendC::LocalTensor<uint8_t> &tmp_buf) {
uint16_t bit_00 = src & 0xFFFF;
uint16_t bit_16 = src >> 16 & 0xFFFF;
uint16_t bit_32 = src >> 32 & 0xFFFF;
uint16_t bit_48 = src >> 48 & 0xFFFF;
constexpr uint32_t TRANSPOSE_MIN_NUM = 16;
constexpr uint32_t TRANSPOSE_MIN_CUBE_BYTE_SIZE = TRANSPOSE_MIN_NUM * TRANSPOSE_MIN_NUM * sizeof(uint16_t);
AscendC::LocalTensor<uint16_t> init_buf = KernelUtils::NewTensor<uint16_t>(tmp_buf, 0, TRANSPOSE_MIN_CUBE_BYTE_SIZE);
Duplicate(init_buf[TRANSPOSE_MIN_NUM * 0], bit_00, TRANSPOSE_MIN_NUM);
Duplicate(init_buf[TRANSPOSE_MIN_NUM * 1], bit_16, TRANSPOSE_MIN_NUM);
Duplicate(init_buf[TRANSPOSE_MIN_NUM * 2], bit_32, TRANSPOSE_MIN_NUM);
Duplicate(init_buf[TRANSPOSE_MIN_NUM * 3], bit_48, TRANSPOSE_MIN_NUM);
Duplicate(init_buf[TRANSPOSE_MIN_NUM * 4], bit_00, TRANSPOSE_MIN_NUM);
Duplicate(init_buf[TRANSPOSE_MIN_NUM * 5], bit_16, TRANSPOSE_MIN_NUM);
Duplicate(init_buf[TRANSPOSE_MIN_NUM * 6], bit_32, TRANSPOSE_MIN_NUM);
Duplicate(init_buf[TRANSPOSE_MIN_NUM * 7], bit_48, TRANSPOSE_MIN_NUM);
Duplicate(init_buf[TRANSPOSE_MIN_NUM * 8], bit_00, TRANSPOSE_MIN_NUM);
Duplicate(init_buf[TRANSPOSE_MIN_NUM * 9], bit_16, TRANSPOSE_MIN_NUM);
Duplicate(init_buf[TRANSPOSE_MIN_NUM * 10], bit_32, TRANSPOSE_MIN_NUM);
Duplicate(init_buf[TRANSPOSE_MIN_NUM * 11], bit_48, TRANSPOSE_MIN_NUM);
Duplicate(init_buf[TRANSPOSE_MIN_NUM * 12], bit_00, TRANSPOSE_MIN_NUM);
Duplicate(init_buf[TRANSPOSE_MIN_NUM * 13], bit_16, TRANSPOSE_MIN_NUM);
Duplicate(init_buf[TRANSPOSE_MIN_NUM * 14], bit_32, TRANSPOSE_MIN_NUM);
Duplicate(init_buf[TRANSPOSE_MIN_NUM * 15], bit_48, TRANSPOSE_MIN_NUM);
AscendC::PipeBarrier<PIPE_V>();
Transpose(init_buf, init_buf);
}
template <typename T>
inline __aicore__ void CompareScalarExtend(const AscendC::LocalTensor<T> &dst,
const AscendC::LocalTensor<int64_t> &src0,
const int64_t src1,
CMPMODE mode,
const uint32_t cal_cnt,
AscendC::LocalTensor<uint8_t> &tmp_buf) {
if (mode != CMPMODE::EQ && mode != CMPMODE::NE) {
ASSERT(false && "CompareScalarExtend mode only support EQ/NE when DataType is int64.");
}
TmpDupInt64(src1, tmp_buf);
LocalTensor<int64_t> src1_buf = KernelUtils::NewTensor<int64_t>(tmp_buf, 0, KernelUtils::BlkSize<int64_t>());
LocalTensor<uint8_t> rest_buf = tmp_buf[ONE_BLK_SIZE].template ReinterpretCast<uint8_t>();
LocalTensor<half> all_one_buf = KernelUtils::NewTensor<half>(tmp_buf, ONE_BLK_SIZE, KernelUtils::RptSize<half>());
Duplicate(all_one_buf, (half)1.0f, KernelUtils::RptSize<half>());
constexpr uint32_t RATIO = sizeof(int64_t) / sizeof(int32_t);
constexpr uint32_t ONE_RPT_SIZE = KernelUtils::RptSize<int32_t>();
constexpr uint32_t PARTITION_NUMBER = sizeof(int32_t) + sizeof(float);
const uint32_t rest_buf_byte_size = tmp_buf.GetSize() - ONE_BLK_SIZE - ONE_REPEAT_BYTE_SIZE;
const uint32_t each_part_buf_byte_size =
rest_buf_byte_size / PARTITION_NUMBER / ONE_REPEAT_BYTE_SIZE * ONE_REPEAT_BYTE_SIZE;
const uint32_t size = cal_cnt * RATIO;
LocalTensor<int32_t> sub_res_buf =
KernelUtils::NewTensor<int32_t>(tmp_buf, ONE_BLK_SIZE + ONE_REPEAT_BYTE_SIZE, each_part_buf_byte_size);
LocalTensor<int64_t> sub_res_64_buf = sub_res_buf.template ReinterpretCast<int64_t>();
uint32_t offset = ONE_BLK_SIZE + ONE_REPEAT_BYTE_SIZE + each_part_buf_byte_size * sizeof(int32_t);
LocalTensor<float> cast_res_buf = KernelUtils::NewTensor<float>(tmp_buf, offset, each_part_buf_byte_size);
LocalTensor<uint8_t> sub_res_reuse_buf = sub_res_buf.template ReinterpretCast<uint8_t>();
LocalTensor<uint8_t> cmp_res_buf = KernelUtils::NewTensor<uint8_t>(sub_res_reuse_buf, 0, each_part_buf_byte_size);
LocalTensor<half> sel_res_buf =
KernelUtils::NewTensor<half>(sub_res_reuse_buf, each_part_buf_byte_size, each_part_buf_byte_size);
LocalTensor<int32_t> src0_new = src0.ReinterpretCast<int32_t>();
LocalTensor<int32_t> src1_new = src1_buf.ReinterpretCast<int32_t>();
BinaryRepeatParams sub_params;
sub_params.src1BlkStride = 0;
sub_params.src1RepStride = 0;
BinaryRepeatParams sel_params;
sel_params.src0BlkStride = 0;
sel_params.src0RepStride = 0;
sel_params.src1RepStride = 0;
sel_params.src1RepStride = 0;
uint32_t max_buf_size = sub_res_buf.GetSize();
uint32_t max_buf_rpt_num = KernelUtils::RptNum<int32_t>(max_buf_size);
uint32_t max_do_rpt_num = KernelUtils::Min(MAX_REPEAT_TIME, max_buf_rpt_num);
uint32_t max_do_size = max_do_rpt_num * ONE_RPT_SIZE;
sub_params.blockNumber = KernelUtils::BlkNum<int32_t>(max_do_size);
uint32_t max_dst_rpt_num = max_do_rpt_num / RATIO;
uint32_t max_dst_size = max_do_size / RATIO;
sel_params.blockNumber = KernelUtils::BlkNum<half>(max_dst_size);
uint32_t calc_size = 0;
for (; calc_size + max_do_size <= size; calc_size += max_do_size) {
AscendC::Sub(sub_res_buf[0], src0_new[calc_size], src1_new[0], ONE_RPT_SIZE, max_do_rpt_num, sub_params);
AscendC::Cast(cast_res_buf[0], sub_res_64_buf[0], AscendC::RoundMode::CAST_RINT, max_dst_size);
AscendC::CompareScalar(cmp_res_buf[0], cast_res_buf[0], 0.0f, mode, max_dst_size);
AscendC::Select(sel_res_buf[0],
cmp_res_buf[0],
all_one_buf[0],
(half)0,
AscendC::SELMODE::VSEL_TENSOR_SCALAR_MODE,
KernelUtils::RptSize<half>(),
max_dst_rpt_num,
sel_params);
AscendC::Cast(dst[calc_size / RATIO], sel_res_buf[0], AscendC::RoundMode::CAST_NONE, max_dst_size);
}
if (calc_size + ONE_RPT_SIZE <= size) {
uint32_t left_rpt_num = KernelUtils::RptNum<int32_t>(size - calc_size);
uint32_t do_size = left_rpt_num * KernelUtils::RptSize<int32_t>();
uint32_t dst_size = do_size / RATIO;
AscendC::Sub(sub_res_buf[0], src0_new[calc_size], src1_new[0], ONE_RPT_SIZE, left_rpt_num, sub_params);
AscendC::Cast(cast_res_buf[0], sub_res_64_buf[0], AscendC::RoundMode::CAST_RINT, dst_size);
AscendC::CompareScalar(cmp_res_buf[0], cast_res_buf[0], 0.0f, mode, do_size);
AscendC::Select(sel_res_buf[0],
cmp_res_buf[0],
all_one_buf[0],
(half)0,
AscendC::SELMODE::VSEL_TENSOR_SCALAR_MODE,
KernelUtils::RptSize<half>(),
left_rpt_num,
sel_params);
AscendC::Cast(dst[calc_size / RATIO], sel_res_buf[0], AscendC::RoundMode::CAST_NONE, dst_size);
calc_size += do_size;
}
if (calc_size < size) {
uint32_t left_size = size - calc_size;
uint32_t dst_size = left_size / RATIO;
uint32_t mask = KernelUtils::SizeAlign(left_size, KernelUtils::BlkSize<int32_t>());
sub_params.blockNumber = KernelUtils::BlkNum<int32_t>(mask);
sel_params.blockNumber = KernelUtils::SizeAlign(dst_size, KernelUtils::BlkSize<half>());
AscendC::Sub(sub_res_buf[0], src0_new[calc_size], src1_new[0], mask, 1, sub_params);
AscendC::Cast(cast_res_buf[0], sub_res_64_buf[0], AscendC::RoundMode::CAST_RINT, dst_size);
AscendC::CompareScalar(cmp_res_buf[0], cast_res_buf[0], 0.0f, mode, ONE_RPT_SIZE);
AscendC::Select(sel_res_buf[0],
cmp_res_buf[0],
all_one_buf[0],
(half)0,
AscendC::SELMODE::VSEL_TENSOR_SCALAR_MODE,
KernelUtils::RptSize<half>(),
1,
sel_params);
AscendC::Cast(dst[calc_size / RATIO], sel_res_buf[0], AscendC::RoundMode::CAST_NONE, dst_size);
}
}
inline __aicore__ void ApplyCompareMode(LocalTensor<int32_t> &inter_buf, CMPMODE mode, uint32_t num_elements) {
switch (mode) {
case CMPMODE::GE:
AscendC::Adds(inter_buf, inter_buf, (int32_t)1, num_elements);
AscendC::PipeBarrier<PIPE_V>();
case CMPMODE::GT:
AscendC::Maxs(inter_buf, inter_buf, (int32_t)0, num_elements);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Mins(inter_buf, inter_buf, (int32_t)1, num_elements);
break;
case CMPMODE::LE:
AscendC::Adds(inter_buf, inter_buf, (int32_t)(-1), num_elements);
AscendC::PipeBarrier<PIPE_V>();
case CMPMODE::LT:
case CMPMODE::NE:
AscendC::Maxs(inter_buf, inter_buf, (int32_t)(-1), num_elements);
AscendC::PipeBarrier<PIPE_V>();
if (mode == CMPMODE::NE) {
AscendC::Mins(inter_buf, inter_buf, (int32_t)1, num_elements);
} else {
AscendC::Mins(inter_buf, inter_buf, (int32_t)0, num_elements);
}
AscendC::PipeBarrier<PIPE_V>();
AscendC::Mul(inter_buf, inter_buf, inter_buf, num_elements);
break;
default:
break;
}
AscendC::PipeBarrier<PIPE_V>();
}
template <typename OutT>
inline __aicore__ void PerformTypeConversion(const LocalTensor<OutT> &dst, LocalTensor<int32_t> &inter_buf,
uint32_t offset, uint32_t num_elements) {
LocalTensor<int16_t> int16_buf = inter_buf.template ReinterpretCast<int16_t>();
AscendC::Cast(int16_buf, inter_buf, RoundMode::CAST_NONE, num_elements);
LocalTensor<half> half_buf = inter_buf.template ReinterpretCast<half>();
AscendC::PipeBarrier<PIPE_V>();
AscendC::Cast(half_buf, int16_buf, RoundMode::CAST_NONE, num_elements);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Cast(dst[offset], half_buf, RoundMode::CAST_NONE, num_elements);
}
template <typename OutT>
inline __aicore__ void ProcessScalarCompareBlock(const LocalTensor<OutT> &dst, const LocalTensor<int32_t> &src0,
const LocalTensor<int32_t> &tensor_src1,
LocalTensor<int32_t> &inter_buf, uint32_t offset, uint32_t mask,
uint32_t repeat_times, uint32_t num_elements, CMPMODE mode,
const BinaryRepeatParams &binary_param) {
AscendC::Sub(inter_buf, src0[offset], tensor_src1[0], mask, repeat_times, binary_param);
AscendC::PipeBarrier<PIPE_V>();
ApplyCompareMode(inter_buf, mode, num_elements);
PerformTypeConversion(dst, inter_buf, offset, num_elements);
}
template <typename OutT>
inline __aicore__ void CompareScalarExtendInt32(const LocalTensor<OutT> &dst, const LocalTensor<int32_t> &src0,
const int32_t scalar_src1, CMPMODE mode, const uint32_t cal_cnt,
LocalTensor<uint8_t> &tmp_buf) {
uint32_t tmp_buf_size = tmp_buf.GetSize() * sizeof(uint8_t) - ONE_BLK_SIZE;
uint32_t loop_cnt;
uint32_t repeat_times = MAX_REPEAT_TIME;
uint32_t one_blk_num = KernelUtils::BlkSize<int32_t>();
uint32_t one_repeat_num = KernelUtils::RptSize<int32_t>();
if (tmp_buf_size >= (ONE_REPEAT_BYTE_SIZE * MAX_REPEAT_TIME)) {
loop_cnt = cal_cnt / (one_repeat_num * MAX_REPEAT_TIME);
} else {
repeat_times = tmp_buf_size / ONE_REPEAT_BYTE_SIZE;
loop_cnt = cal_cnt / (repeat_times * one_repeat_num);
}
uint32_t one_step_num = repeat_times * one_repeat_num;
LocalTensor<int32_t> tensor_src1 = tmp_buf[0].template ReinterpretCast<int32_t>();
tensor_src1.SetSize(one_blk_num);
Duplicate(tensor_src1, scalar_src1, one_blk_num);
uint32_t inter_buf_offset = KernelUtils::BlkAlign<uint8_t>(one_blk_num * sizeof(int32_t));
LocalTensor<int32_t> inter_buf = tmp_buf[inter_buf_offset].template ReinterpretCast<int32_t>();
inter_buf.SetSize(one_step_num);
BinaryRepeatParams binary_param(1, 1, 0, 8, 8, 0);
uint64_t mask = one_repeat_num;
uint32_t offset = 0;
for (uint32_t i = 0; i < loop_cnt; i++) {
ProcessScalarCompareBlock(dst, src0, tensor_src1, inter_buf, offset, mask,
repeat_times, one_step_num, mode, binary_param);
offset += one_step_num;
}
uint32_t remain_rpt_times = (cal_cnt - offset) / one_repeat_num;
uint32_t remain_nums = remain_rpt_times * one_repeat_num;
if (remain_rpt_times != 0) {
ProcessScalarCompareBlock(dst, src0, tensor_src1, inter_buf, offset, mask,
remain_rpt_times, remain_nums, mode, binary_param);
offset += remain_nums;
}
uint32_t calc_tail = cal_cnt - offset;
if (calc_tail != 0) {
uint32_t aligned_tail = KernelUtils::BlkAlign<int32_t>(calc_tail);
BinaryRepeatParams tail_param(1, 1, 0, 8, 8, 0);
tail_param.src0RepStride = aligned_tail;
tail_param.dstRepStride = aligned_tail;
ProcessScalarCompareBlock(dst, src0, tensor_src1, inter_buf, offset, calc_tail,
1, calc_tail, mode, tail_param);
}
}
template <typename InT, typename OutT>
inline __aicore__ void CompareScalarExtend(const LocalTensor<OutT> &dst, const LocalTensor<InT> &src,
const InT constant_y, CMPMODE mode, const uint32_t cal_cnt,
LocalTensor<uint8_t> &tmp_buf) {
if constexpr (AscendC::IsSameType<InT, int32_t>::value) {
if (mode == CMPMODE::GT || mode == CMPMODE::GE || mode == CMPMODE::LE || mode == CMPMODE::LT ||
mode == CMPMODE::NE) {
return CompareScalarExtendInt32(dst, src, constant_y, mode, cal_cnt, tmp_buf);
}
}
constexpr int32_t one_block_cnt = ONE_BLK_SIZE / sizeof(InT);
constexpr uint16_t compare_size = KernelUtils::BlkAlign<uint8_t>((64 * MAX_REPEAT_TIMES) / 8);
constexpr uint16_t max_compare_size = (compare_size > 3072) ? 3072 : compare_size;
constexpr uint16_t max_block_cnt = max_compare_size / sizeof(InT);
const int32_t loop_num = cal_cnt / max_block_cnt;
uint32_t tmp_offset = 0;
LocalTensor<half> tensor_src1 = tmp_buf[0].template ReinterpretCast<half>();
tensor_src1.SetSize(KernelUtils::BlkSize<half>());
Duplicate(tensor_src1, (half)1.0, KernelUtils::BlkSize<half>());
tmp_offset += ONE_BLK_SIZE;
LocalTensor<uint8_t> compare_out = tmp_buf[tmp_offset].template ReinterpretCast<uint8_t>();
compare_out.SetSize(max_block_cnt / 8);
tmp_offset += max_block_cnt * sizeof(uint8_t) / 8;
LocalTensor<half> select_out = tmp_buf[tmp_offset].template ReinterpretCast<half>();
select_out.SetSize(max_block_cnt);
tmp_offset += max_block_cnt * sizeof(half);
BinaryRepeatParams binary_param(1, 0, 0, 8, 0, 0);
constexpr uint16_t one_repeat_cnt_half = KernelUtils::RptSize<half>();
constexpr int32_t one_repeat_cnt_in = KernelUtils::RptSize<InT>();
constexpr int32_t one_repeat_cnt_out = KernelUtils::RptSize<OutT>();
constexpr uint64_t cast_mask = sizeof(OutT) > sizeof(half) ? one_repeat_cnt_out : one_repeat_cnt_half;
UnaryRepeatParams cast_param{1, 1, 8 * cast_mask / one_repeat_cnt_out, 8 * cast_mask / one_repeat_cnt_half};
int32_t cnt = 0;
int loop = 0;
for (loop = 0; loop < loop_num; loop++) {
AscendC::PipeBarrier<PIPE_V>();
CompareScalar(compare_out[0], src[cnt], constant_y, mode, one_repeat_cnt_in, max_block_cnt / one_repeat_cnt_in,
{1, 1, 8, 8});
AscendC::PipeBarrier<PIPE_V>();
Select(select_out[0], compare_out[0], tensor_src1[0], (half)0.0, SELMODE::VSEL_TENSOR_SCALAR_MODE,
one_repeat_cnt_half, max_block_cnt / one_repeat_cnt_half, {1, 0, 0, 8, 0, 0});
AscendC::PipeBarrier<PIPE_V>();
Cast(dst[cnt], select_out[0], RoundMode::CAST_NONE, cast_mask, max_block_cnt / cast_mask, cast_param);
cnt += max_block_cnt;
}
int32_t left_cnt_num = cal_cnt % max_block_cnt;
if (left_cnt_num >= one_repeat_cnt_in) {
int32_t do_block_cnt = left_cnt_num / one_repeat_cnt_in * one_repeat_cnt_in;
AscendC::PipeBarrier<PIPE_V>();
CompareScalar(compare_out[0], src[cnt], constant_y, mode, one_repeat_cnt_in, do_block_cnt / one_repeat_cnt_in,
{1, 1, 8, 8});
AscendC::PipeBarrier<PIPE_V>();
Select(select_out[0], compare_out[0], tensor_src1[0], (half)0.0, SELMODE::VSEL_TENSOR_SCALAR_MODE,
one_repeat_cnt_half, (do_block_cnt + one_repeat_cnt_half - 1) / one_repeat_cnt_half, binary_param);
AscendC::PipeBarrier<PIPE_V>();
Cast(dst[cnt], select_out[0], RoundMode::CAST_NONE, do_block_cnt);
cnt += do_block_cnt;
}
int32_t block_cnt = DivCeil((int32_t)cal_cnt - cnt, one_block_cnt);
if (block_cnt > 0) {
const int32_t left_cnt = block_cnt * one_block_cnt;
AscendC::PipeBarrier<PIPE_V>();
CompareScalar(compare_out[0], src[cnt], constant_y, mode, left_cnt, 1, {1, 1, 8, 8});
AscendC::PipeBarrier<PIPE_V>();
Select(select_out[0], compare_out[0], tensor_src1[0], (half)0.0, SELMODE::VSEL_TENSOR_SCALAR_MODE, left_cnt, 1,
{1, 0, 0, 8, 0, 0});
AscendC::PipeBarrier<PIPE_V>();
Cast(dst[cnt], select_out[0], RoundMode::CAST_NONE, left_cnt, 1, cast_param);
}
}
template <typename T>
inline __aicore__ void CompareScalarExtend(const LocalTensor<T> &dst, const LocalTensor<T> &src, const T constant_y,
CMPMODE mode, const uint32_t cal_cnt, LocalTensor<uint8_t> &tmp_buf) {
const int32_t one_block_cnt = ONE_BLK_SIZE / sizeof(T);
const uint16_t compare_size = KernelUtils::BlkAlign<uint8_t>((64 * MAX_REPEAT_TIMES) / 8);
const uint16_t max_compare_size = (compare_size > 49152) ? 49152 : compare_size;
const uint16_t max_block_cnt = max_compare_size / sizeof(T);
const int32_t loop_num = cal_cnt / max_block_cnt;
uint32_t tmp_offset = 0;
LocalTensor<uint8_t> compare_out = tmp_buf[tmp_offset].template ReinterpretCast<uint8_t>();
compare_out.SetSize(max_block_cnt / 8);
Duplicate(dst[0], (T)1.0, max_block_cnt);
tmp_offset += max_block_cnt * sizeof(T) / 8;
LocalTensor<T> src_tmp = tmp_buf[tmp_offset].template ReinterpretCast<T>();
int32_t cnt = 0;
for (int loop = 0; loop < loop_num; loop++) {
AscendC::PipeBarrier<PIPE_V>();
CompareScalar(compare_out[0], src[cnt], constant_y, mode, max_block_cnt);
AscendC::PipeBarrier<PIPE_V>();
Select(dst[cnt], compare_out[0], dst[cnt], (T)0.0, SELMODE::VSEL_TENSOR_SCALAR_MODE, max_block_cnt);
cnt += max_block_cnt;
}
const int32_t block_cnt = ((int32_t)cal_cnt - cnt) / one_block_cnt;
if (block_cnt > 0) {
int32_t left_cnt = block_cnt * one_block_cnt;
if (left_cnt * sizeof(T) < 256) {
AscendC::PipeBarrier<PIPE_V>();
DataCopy(src_tmp[0], src[cnt], left_cnt);
AscendC::PipeBarrier<PIPE_V>();
CompareScalar(compare_out[0], src_tmp[cnt], constant_y, mode, 256 / sizeof(T));
} else {
AscendC::PipeBarrier<PIPE_V>();
CompareScalar(compare_out[0], src[cnt], constant_y, mode, left_cnt);
}
AscendC::PipeBarrier<PIPE_V>();
Select(dst[cnt], compare_out[0], dst[cnt], (T)0.0, SELMODE::VSEL_TENSOR_SCALAR_MODE, left_cnt);
cnt += left_cnt;
}
const int32_t tail_size = cal_cnt - cnt;
if (tail_size > 0) {
ASSERT(false && "CompareScalarExtend size not support.");
}
}
template <typename T>
inline __aicore__ void CompareExtendEQ(const AscendC::LocalTensor<T> &dst,
const AscendC::LocalTensor<int64_t> &src0,
const AscendC::LocalTensor<int64_t> &src1,
CMPMODE mode,
const uint32_t cal_cnt,
AscendC::LocalTensor<uint8_t> &tmp_buf) {
LocalTensor<half> all_one_buf = KernelUtils::NewTensor<half>(tmp_buf, ONE_BLK_SIZE, KernelUtils::RptSize<half>());
Duplicate(all_one_buf, (half)1.0f, KernelUtils::RptSize<half>());
constexpr uint32_t RATIO = sizeof(int64_t) / sizeof(int32_t);
constexpr uint32_t ONE_RPT_SIZE = KernelUtils::RptSize<int32_t>();
constexpr uint32_t PARTITION_NUMBER = sizeof(int32_t) + sizeof(float);
const uint32_t rest_buf_byte_size = tmp_buf.GetSize() - ONE_BLK_SIZE - ONE_REPEAT_BYTE_SIZE;
const uint32_t each_part_buf_byte_size =
rest_buf_byte_size / PARTITION_NUMBER / ONE_REPEAT_BYTE_SIZE * ONE_REPEAT_BYTE_SIZE;
const uint32_t size = cal_cnt * RATIO;
LocalTensor<int32_t> sub_res_buf =
KernelUtils::NewTensor<int32_t>(tmp_buf, ONE_BLK_SIZE + ONE_REPEAT_BYTE_SIZE, each_part_buf_byte_size);
LocalTensor<int64_t> sub_res_64_buf = sub_res_buf.ReinterpretCast<int64_t>();
uint32_t offset = ONE_BLK_SIZE + ONE_REPEAT_BYTE_SIZE + each_part_buf_byte_size * sizeof(int32_t);
LocalTensor<float> cast_res_buf = KernelUtils::NewTensor<float>(tmp_buf, offset, each_part_buf_byte_size);
LocalTensor<uint8_t> sub_res_reuse_buf = sub_res_buf.template ReinterpretCast<uint8_t>();
LocalTensor<uint8_t> cmp_res_buf = KernelUtils::NewTensor<uint8_t>(sub_res_reuse_buf, 0, each_part_buf_byte_size);
LocalTensor<half> sel_res_buf =
KernelUtils::NewTensor<half>(sub_res_reuse_buf, each_part_buf_byte_size, each_part_buf_byte_size);
LocalTensor<int32_t> src0_new = src0.ReinterpretCast<int32_t>();
LocalTensor<int32_t> src1_new = src1.ReinterpretCast<int32_t>();
BinaryRepeatParams sel_params;
sel_params.src0BlkStride = 0;
sel_params.src0RepStride = 0;
sel_params.src1RepStride = 0;
sel_params.src1RepStride = 0;
uint32_t max_buf_size = sub_res_buf.GetSize();
uint32_t max_buf_rpt_num = KernelUtils::RptNum<int32_t>(max_buf_size);
uint32_t max_do_rpt_num = KernelUtils::Min(MAX_REPEAT_TIME, max_buf_rpt_num);
uint32_t max_do_size = max_do_rpt_num * ONE_RPT_SIZE;
uint32_t max_dst_rpt_num = max_do_rpt_num / RATIO;
uint32_t max_dst_size = max_do_size / RATIO;
sel_params.blockNumber = KernelUtils::BlkNum<half>(max_dst_size);
uint32_t calc_size = 0;
for (; calc_size + max_do_size <= size; calc_size += max_do_size) {
AscendC::Sub(sub_res_buf[0], src0_new[calc_size], src1_new[calc_size], max_do_size);
AscendC::Cast(cast_res_buf[0], sub_res_64_buf[0], AscendC::RoundMode::CAST_RINT, max_dst_size);
AscendC::CompareScalar(cmp_res_buf[0], cast_res_buf[0], 0.0f, mode, max_dst_size);
AscendC::Select(sel_res_buf[0],
cmp_res_buf[0],
all_one_buf[0],
(half)0,
AscendC::SELMODE::VSEL_TENSOR_SCALAR_MODE,
KernelUtils::RptSize<half>(),
max_dst_rpt_num,
sel_params);
AscendC::Cast(dst[calc_size / RATIO], sel_res_buf[0], AscendC::RoundMode::CAST_NONE, max_dst_size);
}
if (calc_size + ONE_RPT_SIZE <= size) {
uint32_t left_rpt_num = KernelUtils::RptNum<int32_t>(size - calc_size);
uint32_t do_size = left_rpt_num * KernelUtils::RptSize<int32_t>();
uint32_t dst_size = do_size / RATIO;
AscendC::Sub(sub_res_buf[0], src0_new[calc_size], src1_new[calc_size], (ONE_RPT_SIZE * left_rpt_num));
AscendC::Cast(cast_res_buf[0], sub_res_64_buf[0], AscendC::RoundMode::CAST_RINT, dst_size);
AscendC::CompareScalar(cmp_res_buf[0], cast_res_buf[0], 0.0f, mode, do_size);
AscendC::Select(sel_res_buf[0],
cmp_res_buf[0],
all_one_buf[0],
(half)0,
AscendC::SELMODE::VSEL_TENSOR_SCALAR_MODE,
KernelUtils::RptSize<half>(),
left_rpt_num,
sel_params);
AscendC::Cast(dst[calc_size / RATIO], sel_res_buf[0], AscendC::RoundMode::CAST_NONE, dst_size);
calc_size += do_size;
}
if (calc_size < size) {
uint32_t left_size = size - calc_size;
uint32_t dst_size = left_size / RATIO;
uint32_t mask = KernelUtils::SizeAlign(left_size, KernelUtils::BlkSize<int32_t>());
sel_params.blockNumber = KernelUtils::SizeAlign(dst_size, KernelUtils::BlkSize<half>());
AscendC::Sub(sub_res_buf[0], src0_new[calc_size], src1_new[calc_size], mask);
AscendC::Cast(cast_res_buf[0], sub_res_64_buf[0], AscendC::RoundMode::CAST_RINT, dst_size);
AscendC::CompareScalar(cmp_res_buf[0], cast_res_buf[0], 0.0f, mode, ONE_RPT_SIZE);
AscendC::Select(sel_res_buf[0],
cmp_res_buf[0],
all_one_buf[0],
(half)0,
AscendC::SELMODE::VSEL_TENSOR_SCALAR_MODE,
KernelUtils::RptSize<half>(),
1,
sel_params);
AscendC::Cast(dst[calc_size / RATIO], sel_res_buf[0], AscendC::RoundMode::CAST_NONE, dst_size);
}
}
template <typename OutT>
inline __aicore__ void ProcessTensorCompareBlock(const LocalTensor<OutT> &dst, const LocalTensor<int32_t> &src0,
const LocalTensor<int32_t> &src1, LocalTensor<int32_t> &inter_buf,
uint32_t offset, uint32_t num_elements, CMPMODE mode) {
AscendC::Sub(inter_buf, src0[offset], src1[offset], num_elements);
AscendC::PipeBarrier<PIPE_V>();
ApplyCompareMode(inter_buf, mode, num_elements);
PerformTypeConversion(dst, inter_buf, offset, num_elements);
}
template <typename OutT>
inline __aicore__ void CompareExtendInt32(const LocalTensor<OutT> &dst, const LocalTensor<int32_t> &src0,
const LocalTensor<int32_t> &src1, CMPMODE mode, const uint32_t cal_cnt,
LocalTensor<uint8_t> &tmp_buf) {
uint32_t buf_max_cnt = tmp_buf.GetSize() * sizeof(uint8_t) / sizeof(int32_t);
uint32_t loop_cnt = cal_cnt / buf_max_cnt;
uint32_t max_rpt_cnts = buf_max_cnt * sizeof(int32_t) / ONE_REPEAT_BYTE_SIZE;
LocalTensor<int32_t> inter_buf = tmp_buf.ReinterpretCast<int32_t>();
uint32_t offset = 0;
for (uint32_t i = 0; i < loop_cnt; i++) {
ProcessTensorCompareBlock(dst, src0, src1, inter_buf, offset, buf_max_cnt, mode);
offset += buf_max_cnt;
}
uint32_t tail_cnt = cal_cnt - offset;
if (tail_cnt > 0) {
ProcessTensorCompareBlock(dst, src0, src1, inter_buf, offset, tail_cnt, mode);
}
}
template <typename InT, typename OutT>
inline __aicore__ void CompareExtend(const LocalTensor<OutT> &dst, const LocalTensor<InT> &src0,
const LocalTensor<InT> &src1, CMPMODE mode, const uint32_t cal_cnt,
LocalTensor<uint8_t> &tmp_buf) {
if constexpr (AscendC::IsSameType<InT, int32_t>::value && AscendC::IsSameType<OutT, uint8_t>::value) {
if (mode != CMPMODE::GT && mode != CMPMODE::GE && mode != CMPMODE::LE && mode != CMPMODE::LT &&
mode != CMPMODE::NE && mode != CMPMODE::EQ) {
ASSERT(false && "CompareExtend mode only support EQ/GT/GE/LE/LT/NE when DataType is int32.");
}
if (mode == CMPMODE::GT || mode == CMPMODE::GE || mode == CMPMODE::LE || mode == CMPMODE::LT ||
mode == CMPMODE::NE) {
return CompareExtendInt32(dst, src0, src1, mode, cal_cnt, tmp_buf);
}
}
constexpr int32_t one_block_cnt = ONE_BLK_SIZE / sizeof(half);
constexpr int32_t one_repeat_cnt_in = KernelUtils::RptSize<InT>();
uint32_t max_block_cnt = (tmp_buf.GetSize() - ONE_BLK_SIZE) / 3;
max_block_cnt = max_block_cnt * sizeof(half) / ONE_REPEAT_BYTE_SIZE * ONE_REPEAT_BYTE_SIZE / sizeof(half);
const int32_t loop_num = max_block_cnt == 0 ? 0 : cal_cnt / max_block_cnt;
uint32_t tmp_offset = 0;
LocalTensor<half> ones = tmp_buf.template ReinterpretCast<half>();
Duplicate(ones, (half)1.0, one_block_cnt);
tmp_offset += ONE_BLK_SIZE;
LocalTensor<half> select_out = tmp_buf[tmp_offset].template ReinterpretCast<half>();
tmp_offset += max_block_cnt == 0 ? ONE_REPEAT_BYTE_SIZE : max_block_cnt * sizeof(half);
LocalTensor<uint8_t> compare_out = tmp_buf[tmp_offset].template ReinterpretCast<uint8_t>();
int32_t cnt = 0;
int loop = 0;
BinaryRepeatParams select_param{1, 0, 0, 8, 0, 0};
BinaryRepeatParams compare_param{1, 1, 1, 8, 8, 8};
constexpr uint16_t one_repeat_cnt_half = KernelUtils::RptSize<half>();
constexpr int32_t one_repeat_cnt_out = KernelUtils::RptSize<OutT>();
const uint32_t compare_max_rpt_times = max_block_cnt / one_repeat_cnt_in;
const uint32_t select_max_rpt_times = max_block_cnt / one_repeat_cnt_half;
for (loop = 0; loop < loop_num; loop++) {
Compare(compare_out, src0[cnt], src1[cnt], mode, one_repeat_cnt_in, compare_max_rpt_times, compare_param);
AscendC::PipeBarrier<PIPE_V>();
Select(select_out, compare_out, ones, (half)0.0, SELMODE::VSEL_TENSOR_SCALAR_MODE, one_repeat_cnt_half,
select_max_rpt_times, select_param);
AscendC::PipeBarrier<PIPE_V>();
Cast(dst[cnt], select_out, RoundMode::CAST_NONE, max_block_cnt);
cnt += max_block_cnt;
}
int32_t left_cnt_num = cal_cnt - cnt;
if (left_cnt_num >= one_repeat_cnt_in) {
const int32_t repeat_times = left_cnt_num / one_repeat_cnt_in;
const int32_t do_block_cnt = repeat_times * one_repeat_cnt_in;
Compare(compare_out, src0[cnt], src1[cnt], mode, one_repeat_cnt_in, repeat_times, compare_param);
AscendC::PipeBarrier<PIPE_V>();
Select(select_out, compare_out, ones, (half)0.0, SELMODE::VSEL_TENSOR_SCALAR_MODE, one_repeat_cnt_half,
(do_block_cnt + one_repeat_cnt_half - 1) / one_repeat_cnt_half, select_param);
AscendC::PipeBarrier<PIPE_V>();
Cast(dst[cnt], select_out, RoundMode::CAST_NONE, do_block_cnt);
cnt += do_block_cnt;
}
const int32_t tail_cnt = cal_cnt - cnt;
if (tail_cnt > 0) {
Compare(compare_out, src0[cnt], src1[cnt], mode, tail_cnt, 1, compare_param);
AscendC::PipeBarrier<PIPE_V>();
Select(select_out, compare_out, ones, (half)0.0, SELMODE::VSEL_TENSOR_SCALAR_MODE, tail_cnt, 1, select_param);
AscendC::PipeBarrier<PIPE_V>();
Cast(dst[cnt], select_out, RoundMode::CAST_NONE, tail_cnt);
}
}
template <typename T>
inline __aicore__ void CompareExtend(const LocalTensor<T> &dst, const LocalTensor<T> &src0, const LocalTensor<T> &src1,
CMPMODE mode, const uint32_t cal_cnt, LocalTensor<uint8_t> &tmp_buf) {
const int32_t one_block_cnt = ONE_BLK_SIZE / sizeof(T);
const uint16_t compare_size = KernelUtils::BlkAlign<uint8_t>((64 * MAX_REPEAT_TIMES) / 8);
const uint16_t max_compare_size = (compare_size > 49152) ? 49152 : compare_size;
const uint16_t max_block_cnt = max_compare_size / sizeof(T);
const int32_t loop_num = cal_cnt / max_block_cnt;
uint32_t tmp_offset = 0;
LocalTensor<uint8_t> compare_out = tmp_buf[tmp_offset].template ReinterpretCast<uint8_t>();
compare_out.SetSize(max_block_cnt / 8);
Duplicate(dst[0], (T)1.0, max_block_cnt);
tmp_offset += max_block_cnt * sizeof(T) / 8;
LocalTensor<T> src_tmp = tmp_buf[tmp_offset].template ReinterpretCast<T>();
int32_t cnt = 0;
int loop = 0;
for (loop = 0; loop < loop_num; loop++) {
AscendC::PipeBarrier<PIPE_V>();
Compare(compare_out[0], src0[cnt], src1[cnt], mode, max_block_cnt);
AscendC::PipeBarrier<PIPE_V>();
Duplicate(dst[cnt], (T)1.0, max_block_cnt);
AscendC::PipeBarrier<PIPE_V>();
Select(dst[cnt], compare_out[0], dst[cnt], (T)0.0, SELMODE::VSEL_TENSOR_SCALAR_MODE, max_block_cnt);
cnt += max_block_cnt;
}
const int32_t block_cnt = ((int32_t)cal_cnt - cnt) / one_block_cnt;
if (block_cnt > 0) {
int32_t left_cnt = block_cnt * one_block_cnt;
if (left_cnt * sizeof(T) < 256) {
AscendC::PipeBarrier<PIPE_V>();
DataCopy(src_tmp[0], src0[cnt], left_cnt);
AscendC::PipeBarrier<PIPE_V>();
DataCopy(src_tmp[256 / sizeof(T)], src1[cnt], left_cnt);
AscendC::PipeBarrier<PIPE_V>();
Compare(compare_out[0], src_tmp[cnt], src_tmp[256 / sizeof(T)], mode, 256 / sizeof(T));
} else {
AscendC::PipeBarrier<PIPE_V>();
Compare(compare_out[0], src0[cnt], src1[cnt], mode, left_cnt);
}
AscendC::PipeBarrier<PIPE_V>();
Duplicate(dst[cnt], (T)1.0, left_cnt);
AscendC::PipeBarrier<PIPE_V>();
Select(dst[cnt], compare_out[0], dst[cnt], (T)0.0, SELMODE::VSEL_TENSOR_SCALAR_MODE, left_cnt);
cnt += left_cnt;
}
const int32_t tail_size = cal_cnt - cnt;
if (tail_size > 0) {
ASSERT(false && "CompareExtend size not support.");
}
}
inline __aicore__ void GetSignBitTensor(const AscendC::LocalTensor<uint16_t> &dst0, const AscendC::LocalTensor<uint16_t> &dst1,
const AscendC::LocalTensor<int64_t> &src0, const AscendC::LocalTensor<int64_t> &src1,
const AscendC::LocalTensor<uint32_t> &inner_dup, const uint32_t cal_cnt) {
AscendC::PipeBarrier<PIPE_V>();
uint32_t quadruple_cal_cnt = 4 * cal_cnt;
AscendC::Duplicate(inner_dup, uint32_t(0x80000000), 2 * cal_cnt);
AscendC::LocalTensor<uint16_t> inner_dup_tmp = inner_dup.ReinterpretCast<uint16_t>();
AscendC::LocalTensor<uint16_t> src0_tmp = src0.ReinterpretCast<uint16_t>();
AscendC::LocalTensor<uint16_t> src1_tmp = src1.ReinterpretCast<uint16_t>();
AscendC::PipeBarrier<PIPE_V>();
AscendC::And(dst0, src0_tmp, inner_dup_tmp, quadruple_cal_cnt);
AscendC::And(dst1, src1_tmp, inner_dup_tmp, quadruple_cal_cnt);
AscendC::LocalTensor<int32_t> dst0_tmp = dst0.ReinterpretCast<int32_t>();
AscendC::LocalTensor<int32_t> dst1_tmp = dst1.ReinterpretCast<int32_t>();
AscendC::PipeBarrier<PIPE_V>();
AscendC::Maxs(dst0_tmp, dst0_tmp, -1, 2 * cal_cnt);
AscendC::Maxs(dst1_tmp, dst1_tmp, -1, 2 * cal_cnt);
uint32_t one_rpt_cnt = KernelUtils::RptSize<int64_t>();
uint32_t rpt_times = cal_cnt / one_rpt_cnt;
uint32_t tail_rpt_cnt = cal_cnt - rpt_times * one_rpt_cnt;
uint64_t mask[2] = {uint64_t(0x5555555555555555), 0};
AscendC::Duplicate(inner_dup, 1U, mask, rpt_times, 1, 8);
if (tail_rpt_cnt != 0) {
uint64_t mask_tail = 0b01;
for (uint32_t i = 1; i < tail_rpt_cnt; i++) {
mask_tail += (0b01 << (2 * i));
}
mask[0] = mask_tail;
AscendC::Duplicate(inner_dup[2 * (cal_cnt - tail_rpt_cnt)], 1U, mask, 1, 1, 8);
}
AscendC::PipeBarrier<PIPE_V>();
AscendC::And(dst0, dst0, inner_dup_tmp, quadruple_cal_cnt);
AscendC::And(dst1, dst1, inner_dup_tmp, quadruple_cal_cnt);
}
inline __aicore__ void CastTensorToHalf(const AscendC::LocalTensor<int32_t> &dst,
const AscendC::LocalTensor<uint16_t> &src0_bits,
const AscendC::LocalTensor<uint16_t> &src1_bits,
const AscendC::LocalTensor<uint32_t> &inner_dup, uint32_t cal_cnt) {
AscendC::PipeBarrier<PIPE_V>();
uint32_t quadruple_cal_cnt = 4 * cal_cnt;
uint32_t double_cal_cnt = 2 * cal_cnt;
int32_t repeat_times = AscendC::DivCeil(cal_cnt * sizeof(int64_t), ONE_REPEAT_BYTE_SIZE);
AscendC::Duplicate(inner_dup, uint32_t(0x0000BC00), double_cal_cnt);
AscendC::LocalTensor<uint16_t> inner_dup_tmp = inner_dup.ReinterpretCast<uint16_t>();
AscendC::PipeBarrier<PIPE_V>();
AscendC::And(src1_bits, src0_bits, inner_dup_tmp, quadruple_cal_cnt);
AscendC::LocalTensor<int32_t> sub_tmp = inner_dup.ReinterpretCast<int32_t>();
AscendC::Duplicate(sub_tmp, 0, double_cal_cnt);
AscendC::Sub(dst, sub_tmp, src0_bits.ReinterpretCast<int32_t>(), double_cal_cnt);
AscendC::Duplicate(inner_dup, uint32_t(0x00003C00), double_cal_cnt);
AscendC::PipeBarrier<PIPE_V>();
AscendC::And(src0_bits, dst.ReinterpretCast<uint16_t>(), inner_dup_tmp, quadruple_cal_cnt);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Or(src0_bits, src1_bits, src0_bits, quadruple_cal_cnt);
uint32_t one_rpt_cnt = KernelUtils::RptSize<half>();
uint32_t rpt_cnt = quadruple_cal_cnt / one_rpt_cnt;
uint32_t tail_cnt = quadruple_cal_cnt - rpt_cnt * one_rpt_cnt;
AscendC::LocalTensor<half> dst_tmp = dst.ReinterpretCast<half>();
AscendC::LocalTensor<half> sign_tmp = src0_bits.ReinterpretCast<half>();
AscendC::PipeBarrier<PIPE_V>();
AscendC::PairReduceSum(dst_tmp, sign_tmp, rpt_cnt, 128, 1, 1, 8);
if (tail_cnt != 0) {
AscendC::PairReduceSum(dst_tmp[rpt_cnt * one_rpt_cnt / 2], sign_tmp[rpt_cnt * one_rpt_cnt], 1, tail_cnt, 1, 1, 8);
}
}
inline __aicore__ void CalcWeightedTensor(const AscendC::LocalTensor<half> &dst, const AscendC::LocalTensor<half> &src,
const AscendC::LocalTensor<uint32_t> &inner_dup, const uint32_t cal_cnt,
const half weight0, const half weight1) {
AscendC::PipeBarrier<PIPE_V>();
AscendC::LocalTensor<half> inner_dup_tmp = inner_dup.ReinterpretCast<half>();
AscendC::Duplicate(inner_dup_tmp, weight0, 2 * cal_cnt);
uint32_t one_rpt_cnt = KernelUtils::RptSize<half>();
uint32_t rpt_times = 2 * cal_cnt / one_rpt_cnt;
uint32_t tail_rpt_cnt = 2 * cal_cnt - rpt_times * one_rpt_cnt;
uint64_t mask[2] = {uint64_t(0x5555555555555555), uint64_t(0x5555555555555555)};
AscendC::Duplicate(inner_dup_tmp, weight1, mask, rpt_times, 1, 8);
if (tail_rpt_cnt != 0) {
if (tail_rpt_cnt * sizeof(half) < ONE_REPEAT_BYTE_SIZE / 2) {
uint64_t mask_tail = 0b01;
for (uint32_t i = 1; i < tail_rpt_cnt; i++) {
mask_tail += (0b01 << (2 * i));
}
mask[0] = mask_tail;
mask[1] = 0;
} else {
mask[0] = uint64_t(0x5555555555555555);
uint32_t tail_cnt = tail_rpt_cnt - 128 / sizeof(half);
uint64_t mask_tail = 0b01;
for (uint32_t i = 1; i < tail_cnt; i++) {
mask_tail += (0b01 << (2 * i));
}
mask[1] = mask_tail;
}
AscendC::Duplicate(inner_dup_tmp[rpt_times * one_rpt_cnt], weight1, mask, 1, 1, 8);
}
AscendC::PipeBarrier<PIPE_V>();
AscendC::Mul(src, src, inner_dup_tmp, 2 * cal_cnt);
AscendC::PipeBarrier<PIPE_V>();
AscendC::PairReduceSum(dst, src, rpt_times, 128, 1, 1, 8);
if (tail_rpt_cnt != 0) {
AscendC::PairReduceSum(dst[rpt_times * one_rpt_cnt / 2], src[rpt_times * one_rpt_cnt], 1, tail_rpt_cnt, 1, 1, 8);
}
}
template <typename T>
inline __aicore__ void CompareSingleLoopGT(const AscendC::LocalTensor<T> &dst,
const AscendC::LocalTensor<int64_t> &src0,
const AscendC::LocalTensor<int64_t> &src1,
CMPMODE mode,
const uint32_t cal_cnt,
AscendC::LocalTensor<uint8_t> &tmp_buf) {
uint32_t quadruple_cal_cnt = 4 * cal_cnt;
uint32_t double_cal_cnt = 2 * cal_cnt;
uint32_t offset_aligned = DivCeil(cal_cnt * sizeof(int64_t), ONE_BLK_SIZE) * ONE_BLK_SIZE;
uint32_t offset = 0;
AscendC::LocalTensor<uint16_t> src0_bits = KernelUtils::NewTensor<uint16_t>(tmp_buf, 0, quadruple_cal_cnt);
offset += offset_aligned;
AscendC::LocalTensor<uint16_t> src1_bits = KernelUtils::NewTensor<uint16_t>(tmp_buf, offset, quadruple_cal_cnt);
offset += offset_aligned;
AscendC::LocalTensor<int32_t> sub_res = KernelUtils::NewTensor<int32_t>(tmp_buf, offset, double_cal_cnt);
offset += offset_aligned;
AscendC::LocalTensor<half> sign_res = KernelUtils::NewTensor<half>(tmp_buf, offset, cal_cnt);
offset += DivCeil(offset_aligned / 4, ONE_BLK_SIZE) * ONE_BLK_SIZE;
AscendC::LocalTensor<half> value_res = KernelUtils::NewTensor<half>(tmp_buf, offset, cal_cnt);
offset += DivCeil(offset_aligned / 4, ONE_BLK_SIZE) * ONE_BLK_SIZE;
AscendC::LocalTensor<uint32_t> inner_dup = KernelUtils::NewTensor<uint32_t>(tmp_buf, offset, double_cal_cnt);
AscendC::LocalTensor<uint16_t> inner_dup_tmp = inner_dup.ReinterpretCast<uint16_t>();
GetSignBitTensor(src0_bits, src1_bits, src0, src1, inner_dup, cal_cnt);
AscendC::LocalTensor<int32_t> src0_bits_tmp = src0_bits.ReinterpretCast<int32_t>();
AscendC::LocalTensor<int32_t> src1_bits_tmp = src1_bits.ReinterpretCast<int32_t>();
AscendC::PipeBarrier<PIPE_V>();
AscendC::Sub(src0_bits_tmp, src0_bits_tmp, src1_bits_tmp, double_cal_cnt);
CastTensorToHalf(sub_res, src0_bits, src1_bits, inner_dup, cal_cnt);
CalcWeightedTensor(sign_res, sub_res.ReinterpretCast<half>(), inner_dup, cal_cnt, half(8), half(2));
AscendC::Duplicate(inner_dup, uint32_t(0x7FFFFFFF), 2 * cal_cnt);
AscendC::LocalTensor<uint16_t> src0_tmp = src0.ReinterpretCast<uint16_t>();
AscendC::LocalTensor<uint16_t> src1_tmp = src1.ReinterpretCast<uint16_t>();
AscendC::PipeBarrier<PIPE_V>();
AscendC::And(src0_bits, src0_tmp, inner_dup_tmp, quadruple_cal_cnt);
AscendC::And(src1_bits, src1_tmp, inner_dup_tmp, quadruple_cal_cnt);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Sub(src0_bits_tmp, src0_bits_tmp, src1_bits_tmp, double_cal_cnt);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Maxs(src0_bits_tmp, src0_bits_tmp, -1, double_cal_cnt);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Mins(src0_bits_tmp, src0_bits_tmp, 1, double_cal_cnt);
CastTensorToHalf(sub_res, src0_bits, src1_bits, inner_dup, cal_cnt);
CalcWeightedTensor(value_res, sub_res.ReinterpretCast<half>(), inner_dup, cal_cnt, half(4), half(1));
AscendC::PipeBarrier<PIPE_V>();
AscendC::Add(sign_res, sign_res, value_res, cal_cnt);
AscendC::PipeBarrier<PIPE_V>();
if (mode == CMPMODE::GT) {
AscendC::Maxs(sign_res, sign_res, half(0), cal_cnt);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Mins(sign_res, sign_res, half(1), cal_cnt);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Cast(dst, sign_res, AscendC::RoundMode::CAST_NONE, cal_cnt);
} else {
AscendC::Maxs(sign_res, sign_res, half(-1), cal_cnt);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Mins(sign_res, sign_res, half(0), cal_cnt);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Adds(sign_res, sign_res, half(1), cal_cnt);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Cast(dst, sign_res, AscendC::RoundMode::CAST_NONE, cal_cnt);
}
}
template <typename T>
inline __aicore__ void CompareExtendGT(const AscendC::LocalTensor<T> &dst,
const AscendC::LocalTensor<int64_t> &src0,
const AscendC::LocalTensor<int64_t> &src1,
CMPMODE mode,
const uint32_t cal_cnt,
AscendC::LocalTensor<uint8_t> &tmp_buf) {
const uint32_t max_buf_cnt = tmp_buf.GetSize() / (5 * sizeof(uint64_t));
const uint32_t one_max_rpt_cal_cnt = MAX_REPEAT_TIME * ONE_REPEAT_BYTE_SIZE / sizeof(int64_t);
uint32_t max_cnt = KernelUtils::Min(max_buf_cnt, one_max_rpt_cal_cnt);
max_cnt = max_cnt / ONE_BLK_SIZE * ONE_BLK_SIZE;
const uint32_t loop_num = cal_cnt / max_cnt;
uint32_t offset = 0;
for (uint32_t loop = 0; loop < loop_num; loop++) {
CompareSingleLoopGT(dst[offset], src0[offset], src1[offset], mode, max_cnt, tmp_buf);
offset += max_cnt;
}
uint32_t tail_cnt = cal_cnt - offset;
if (tail_cnt > 0) {
CompareSingleLoopGT(dst[offset], src0[offset], src1[offset], mode, tail_cnt, tmp_buf);
}
}
template <typename T>
inline __aicore__ void CompareExtend(const AscendC::LocalTensor<T> &dst,
const AscendC::LocalTensor<int64_t> &src0,
const AscendC::LocalTensor<int64_t> &src1,
CMPMODE mode,
const uint32_t cal_cnt,
AscendC::LocalTensor<uint8_t> &tmp_buf) {
if (mode != CMPMODE::EQ && mode != CMPMODE::NE && mode != CMPMODE::GT && mode != CMPMODE::GE && mode != CMPMODE::LE) {
ASSERT(false && "CompareExtend mode only support EQ/NE/GT/GE/LE when DataType is int64.");
}
if (mode == CMPMODE::EQ || mode == CMPMODE::NE) {
CompareExtendEQ(dst, src0, src1, mode, cal_cnt, tmp_buf);
} else if (mode == CMPMODE::LE) {
CompareExtendGT(dst, src1, src0, CMPMODE::GE, cal_cnt, tmp_buf);
} else {
CompareExtendGT(dst, src0, src1, mode, cal_cnt, tmp_buf);
}
}
#endif