* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef __ASCENDC_API_ARGMAX_WITH_VALUE_H__
#define __ASCENDC_API_ARGMAX_WITH_VALUE_H__
namespace AscendC {
* @brief UpdateMaxIndexAndValue - 更新最大值及其索引(tensor版本)
* @note 使用CompareExtend和Where算子来更新保存的最大值和索引
* @param max_index_temp 当次计算的索引tensor(局部索引,会被修改为加上offset后的值)
* @param max_value_temp 当次计算的值tensor
* @param max_index_saved 之前保存的索引tensor(会被更新)
* @param max_value_saved 之前保存的值tensor(会被更新)
* @param offset 索引偏移量(标量),用于将局部索引转换为全局索引
* @param tmp_buf 临时buffer
* @param cal_cnt 计算的数据量
*
* @note max_index_temp 会被原地修改为 max_index_temp + offset,调用后不应再使用原始值
* @note 仅支持 cal_cnt > 1 的 tensor 情况
*/
template <typename T>
inline __aicore__ void
UpdateMaxIndexAndValue(const LocalTensor<int64_t> &max_index_temp,
const LocalTensor<T> &max_value_temp,
const LocalTensor<int64_t> &max_index_saved,
const LocalTensor<T> &max_value_saved,
const int64_t offset, LocalTensor<uint8_t> &tmp_buf,
const uint32_t cal_cnt) {
uint32_t mask_size = ((cal_cnt + 7) / 8 + 31) / 32 * 32;
uint32_t index32_size = cal_cnt * sizeof(int32_t);
LocalTensor<uint8_t> index32_tmp_uint8 = tmp_buf[0];
index32_tmp_uint8.SetSize(index32_size);
LocalTensor<int32_t> index32_tmp = index32_tmp_uint8.ReinterpretCast<int32_t>();
LocalTensor<uint8_t> mask = tmp_buf[index32_size];
mask.SetSize(mask_size);
LocalTensor<uint8_t> common_tmp_buf = tmp_buf[index32_size + mask_size];
Cast(index32_tmp, max_index_temp, RoundMode::CAST_NONE, cal_cnt);
Adds(index32_tmp, index32_tmp, static_cast<int32_t>(offset), cal_cnt);
Cast(max_index_temp, index32_tmp, RoundMode::CAST_NONE, cal_cnt);
CompareExtend(mask, max_value_temp, max_value_saved, CMPMODE::GT, cal_cnt, common_tmp_buf);
Where(max_value_saved, mask, max_value_temp, max_value_saved, cal_cnt, common_tmp_buf);
Where(max_index_saved, mask, max_index_temp, max_index_saved, cal_cnt, common_tmp_buf);
}
* @brief 在64个元素中找到最大值及其索引(标量归约)
*
* @details
* 算法逻辑:
* - 遍历64个元素,逐个比较
* - 如果 current > max_value,则更新 max_value 和 max_index
* - 特殊处理NaN:使用 !(current <= max_value) && (max_value == max_value)
* - 如果 max_value 是 NaN,则 max_value == max_value 为 false,不更新
* - 否则,如果 current 是 NaN,则 !(current <= max_value) 为 true,会更新
* - 这样可以正确处理包含NaN情况
*
* 数学公式:
* max_index = argmax_i src[i] (i ∈ [0, 63])
* max_value = max_i src[i] (i ∈ [0, 63])
*
* @param dst_index 输出:最大值的索引(作为float返回)
* @param dst_value 输出:最大值
* @param src 输入:64个元素的tensor
*
* @note 正确处理NaN:如果所有元素都是NaN,返回第一个NaN
*/
template <typename T>
inline __aicore__ void ScalarReduceMax64(float &dst_index, T &dst_value,
const LocalTensor<T> &src) {
T max_value = src(0);
int32_t max_index = 0;
for (uint32_t i = 1; i < 64; ++i) {
T current = src(i);
max_index = (!(current <= max_value) && (max_value == max_value)) ? i : max_index;
max_value = (!(current <= max_value) && (max_value == max_value)) ? current : max_value;
}
dst_index = static_cast<float>(max_index);
dst_value = max_value;
}
* @brief 在64个元素中根据mask找到最小值及其索引(标量归约)
*
* @details
* 算法逻辑:
* - mask用于标记哪些元素是有效的(mask的第i位为1表示src[i]有效)
* - 只考虑mask中为1的位对应的元素
* - 如果min_index == -1,说明还没有找到有效元素,直接使用当前元素
* - 否则,如果 current < min_value,则更新
*
* 数学公式:
* 有效索引集合 S = {i | ((mask >> i) & 1) == 1, i ∈ [0, 63]}
* min_index = argmin_i src[i] (i ∈ S)
* min_value = min_i src[i] (i ∈ S)
*
* @param dst_index 输出:最小值的索引(作为float返回)
* @param dst_value 输出:最小值
* @param src 输入:64个元素的tensor
* @param mask 输入:64位掩码,标记哪些元素有效
*
* @note src不会包含NaN值,故不需要特殊处理
*/
template <typename T>
inline __aicore__ void ScalarReduceMin64(float &dst_index, T &dst_value,
const LocalTensor<T> &src,
const uint64_t mask) {
T min_value;
int32_t min_index = -1;
for (uint32_t i = 0; i < 64; ++i) {
T current = src(i);
min_value = (((mask >> i) & 1) && (min_index == -1)) ? current : min_value;
min_index = (((mask >> i) & 1) && (min_index == -1 || current < min_value)) ? i : min_index;
min_value = (((mask >> i) & 1) && (current < min_value)) ? current : min_value;
}
dst_index = static_cast<float>(min_index);
dst_value = min_value;
}
* @brief Argmax with value - AR模式,last维度 <= 64
*
* @details
* 算法逻辑:
* - 对每个first维度的元素,在last维度中找最大值及其索引
* - 对于float类型,需要特殊处理±0的情况(+0和-0在比较时相等)
*
* 步骤:
* 1. 处理±0:将所有0值统一为+0(通过比较 src != 0,然后选择)
* 2. 初始化tail_tmp为最小值(-INFINITY或INT32_MIN)
* 3. 复制当前first维度的数据到tail_tmp
* 4. 使用ScalarReduceMax64在64个元素中找最大值及其索引
* 5. 将结果存入dst_index和dst_value
*
* 数学公式:
* 对于每个 i ∈ [0, repeat_times-1]:
* dst_index[i] = argmax_j src[i*last + j] (j ∈ [0, last-1])
* dst_value[i] = max_j src[i*last + j] (j ∈ [0, last-1])
*
* @param dst_index 输出:每个first维度的最大值索引
* @param dst_value 输出:每个first维度的最大值
* @param src 输入:2D tensor,形状为 [repeat_times, last]
* @param shared_tmp_buffer 临时buffer
* @param repeat_times first维度的元素个数
* @param last last维度的元素个数(<= 64)
*
* @note 对于float类型,正确处理±0和NaN
*/
template <typename T, bool isReuseSource = false>
inline __aicore__ void ArgmaxLEOneRepeatWithValue(
const LocalTensor<int64_t> &dst_index, const LocalTensor<T> &dst_value,
const LocalTensor<T> &src, const LocalTensor<uint8_t> &shared_tmp_buffer,
const uint32_t repeat_times, const uint32_t last) {
LocalTensor<T> tail_tmp = shared_tmp_buffer.ReinterpretCast<T>();
LocalTensor<uint8_t> tmp_mask = shared_tmp_buffer[256];
uint32_t copy_num = last;
for (uint32_t i = 0; i < repeat_times; ++i) {
if constexpr (std::is_same_v<T, float>) {
AscendC::PipeBarrier<PIPE_V>();
Duplicate(tail_tmp, 0.0f, 64);
AscendC::PipeBarrier<PIPE_V>();
Compare(tmp_mask, src, tail_tmp, AscendC::CMPMODE::EQ, 64);
AscendC::PipeBarrier<PIPE_V>();
Select(src, tmp_mask, tail_tmp, src, AscendC::SELMODE::VSEL_TENSOR_TENSOR_MODE, 64);
AscendC::PipeBarrier<PIPE_V>();
}
AscendC::PipeBarrier<PIPE_V>();
if constexpr (std::is_same_v<T, float>) {
Duplicate(tail_tmp, -INFINITY, 64);
} else if constexpr (std::is_same_v<T, int32_t>) {
Duplicate(tail_tmp, INT32_MIN, 64);
}
AscendC::PipeBarrier<PIPE_V>();
DataCopy(tail_tmp, src[i * last], copy_num);
AscendC::PipeBarrier<PIPE_V>();
AscendC::SetFlag<AscendC::HardEvent::V_S>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::V_S>(EVENT_ID0);
float max_index;
T max_value;
ScalarReduceMax64<T>(max_index, max_value, tail_tmp);
int32_t index = static_cast<int32_t>(max_index);
AscendC::SetFlag<AscendC::HardEvent::S_V>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::S_V>(EVENT_ID0);
dst_index(i) = static_cast<int64_t>(index);
dst_value(i) = static_cast<T>(max_value);
}
}
* @brief Argmax with value - AR模式,last维度 > 64
*
* @details
* 算法逻辑:
* - 通过滑动窗口的方式,在每64元素的块间逐块比较并跟踪最大值和索引
* - 使用mask机制记录最大值出现的位置,避免丢失索引信息
*
* 步骤:
* 1. 处理±0:将所有0值统一为+0(通过比较 src != 0,然后选择)
* 2. 初始化max_before为当前行的第一个64元素块
* 3. 遍历后续块(repeat_times_one_row - 2个完整块):
* a. max_after = Max(max_before, next_block)
* b. 生成mask:mask1 = (max_before == max_before), mask2 = (max_after == max_before)
* c. 计算有效mask:mask2 = mask2 | (~mask1),表示max_after中是从max_before继承的位置(或者NaN的位置)
* d. 更新索引:tmp_dst_tensor = mask2 ? tmp_dst_tensor : current_block_index
* e. max_before = max_after
* 4. 处理尾部(用最小值补齐不足64个元素的元素块)
* 5. 在tmp_dst_tensor中找索引最大值中block_index最小值对应的索引,得到相对索引 relative_index
* 6. 计算最终索引:final_index = block_index + relative_index
*
* 数学公式:
* 对于每个 i ∈ [0, first-1]:
* dst_index[i] = argmax_j src[i*last + j] (j ∈ [0, last-1])
* dst_value[i] = max_j src[i*last + j] (j ∈ [0, last-1])
*
* 内存布局说明:
* - 最小处理单位:64个元素
* - float/int32_t:64×4=256B(精确匹配)
* - uint8_t:64×1=64B(分配256B保持一致性)
* - 满足32字节对齐要求,DMA传输高效
*
* @param dst_index 输出:每个first维度的最大值全局索引
* @param dst_value 输出:每个first维度的最大值
* @param src 输入:2D tensor,形状为 [first, last]
* @param shared_tmp_buffer 临时buffer
* @param first first维度的元素个数
* @param last last维度的元素个数(> 64)
* @param repeat_times_one_row last维度按64分块的块数
*
* @note 使用mask机制确保在有多个最大值时,选择第一个出现的
*/
template <typename T, class pattern, bool isReuseSource = false>
__aicore__ void ArgmaxGTOneRepeatWithValue(
const LocalTensor<int64_t> &dst_index, const LocalTensor<T> &dst_value,
const LocalTensor<T> &src, const LocalTensor<uint8_t> &shared_tmp_buffer,
const uint32_t first, const uint32_t last,
const uint32_t repeat_times_one_row) {
LocalTensor<float> tmp_dst_tensor = shared_tmp_buffer.ReinterpretCast<float>();
LocalTensor<float> dup_index_tensor = shared_tmp_buffer[256].ReinterpretCast<float>();
LocalTensor<uint8_t> tmp_mask = shared_tmp_buffer[512].ReinterpretCast<uint8_t>();
LocalTensor<uint8_t> tmp_mask1 = shared_tmp_buffer[768].ReinterpretCast<uint8_t>();
LocalTensor<uint8_t> tmp_mask2 = shared_tmp_buffer[1024].ReinterpretCast<uint8_t>();
LocalTensor<T> max_before = shared_tmp_buffer[1280].ReinterpretCast<T>();
LocalTensor<T> max_after = shared_tmp_buffer[1536].ReinterpretCast<T>();
LocalTensor<T> tail_tmp = shared_tmp_buffer[1792].ReinterpretCast<T>();
uint32_t tail_num = last % 64;
tail_num = (tail_num == 0) ? 64 : tail_num;
uint32_t copy_num = tail_num;
uint32_t tail_ctrl = 1;
for (uint32_t i = 0; i < first; ++i) {
for (uint32_t j = 0; j < repeat_times_one_row; ++j) {
if constexpr (std::is_same_v<T, float>) {
AscendC::PipeBarrier<PIPE_V>();
Duplicate(tail_tmp, 0.0f, 64);
AscendC::PipeBarrier<PIPE_V>();
Compare(tmp_mask, src[i * last + j * 64], tail_tmp, AscendC::CMPMODE::EQ, 64);
AscendC::PipeBarrier<PIPE_V>();
Select(src[i * last + j * 64], tmp_mask, tail_tmp, src[i * last + j * 64], AscendC::SELMODE::VSEL_TENSOR_TENSOR_MODE, 64);
AscendC::PipeBarrier<PIPE_V>();
}
}
}
for (uint32_t i = 0; i < first; ++i) {
Duplicate(tmp_dst_tensor, 0.0f, tmp_dst_tensor.GetSize());
AscendC::PipeBarrier<PIPE_V>();
DataCopy(max_before, src[i * last], 64);
AscendC::PipeBarrier<PIPE_V>();
for (uint32_t j = 0; j < repeat_times_one_row - 2; ++j) {
AscendC::PipeBarrier<PIPE_V>();
Duplicate(dup_index_tensor, (j + 1) * 64.0f, 64);
AscendC::PipeBarrier<PIPE_V>();
Max(max_after, max_before, src[i * last + (j + 1) * 64], 64);
AscendC::PipeBarrier<PIPE_V>();
Compare(tmp_mask1, max_before, max_before, AscendC::CMPMODE::EQ, 64);
AscendC::PipeBarrier<PIPE_V>();
Compare(tmp_mask2, max_after, max_before, AscendC::CMPMODE::EQ, 64);
AscendC::SetFlag<AscendC::HardEvent::V_S>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::V_S>(EVENT_ID0);
for (uint32_t k = 0; k < 8; ++k) {
uint8_t val1 = tmp_mask1.GetValue(k);
uint8_t val2 = tmp_mask2.GetValue(k);
tmp_mask2.SetValue(k, val2 | (~val1));
}
AscendC::SetFlag<AscendC::HardEvent::S_V>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::S_V>(EVENT_ID0);
Select(tmp_dst_tensor, tmp_mask2, tmp_dst_tensor, dup_index_tensor,
AscendC::SELMODE::VSEL_TENSOR_TENSOR_MODE, 64);
AscendC::PipeBarrier<PIPE_V>();
DataCopy(max_before, max_after, 64);
AscendC::PipeBarrier<PIPE_V>();
}
AscendC::PipeBarrier<PIPE_V>();
for (uint32_t j = 0; j < tail_ctrl; ++j) {
if constexpr (std::is_same_v<T, float>) {
Duplicate(tail_tmp, -INFINITY, 64);
} else if constexpr (std::is_same_v<T, int32_t>) {
Duplicate(tail_tmp, INT32_MIN, 64);
}
AscendC::PipeBarrier<PIPE_V>();
DataCopy(tail_tmp, src[(i + 1) * last - copy_num], copy_num);
AscendC::PipeBarrier<PIPE_V>();
Duplicate(dup_index_tensor, (repeat_times_one_row - 1) * 64.0f, 64);
AscendC::PipeBarrier<PIPE_V>();
Max(max_after, max_before, tail_tmp, 64);
AscendC::PipeBarrier<PIPE_V>();
Compare(tmp_mask1, max_before, max_before, AscendC::CMPMODE::EQ, 64);
AscendC::PipeBarrier<PIPE_V>();
Compare(tmp_mask2, max_after, max_before, AscendC::CMPMODE::EQ, 64);
AscendC::SetFlag<AscendC::HardEvent::V_S>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::V_S>(EVENT_ID0);
for (uint32_t k = 0; k < 8; ++k) {
uint8_t val1 = tmp_mask1.GetValue(k);
uint8_t val2 = tmp_mask2.GetValue(k);
tmp_mask2.SetValue(k, val2 | (~val1));
}
AscendC::SetFlag<AscendC::HardEvent::S_V>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::S_V>(EVENT_ID0);
Select(tmp_dst_tensor, tmp_mask2, tmp_dst_tensor, dup_index_tensor,
AscendC::SELMODE::VSEL_TENSOR_TENSOR_MODE, 64);
AscendC::PipeBarrier<PIPE_V>();
}
AscendC::PipeBarrier<PIPE_V>();
AscendC::SetFlag<AscendC::HardEvent::V_S>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::V_S>(EVENT_ID0);
float max_index;
T max_value;
ScalarReduceMax64<T>(max_index, max_value, max_after);
AscendC::SetFlag<AscendC::HardEvent::S_V>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::S_V>(EVENT_ID0);
Compare(tmp_mask1, max_after, max_after, AscendC::CMPMODE::EQ, 64);
AscendC::PipeBarrier<PIPE_V>();
Compares(tmp_mask2, max_after, max_value, AscendC::CMPMODE::EQ, 64);
AscendC::SetFlag<AscendC::HardEvent::V_S>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::V_S>(EVENT_ID0);
for (uint32_t k = 0; k < 8; ++k) {
uint8_t val1 = tmp_mask1.GetValue(k);
uint8_t val2 = tmp_mask2.GetValue(k);
tmp_mask2.SetValue(k, val2 | (~val1));
}
uint64_t min_mask[1] = {(static_cast<uint64_t>(tmp_mask2(0)) << 0) |
(static_cast<uint64_t>(tmp_mask2(1)) << 8) |
(static_cast<uint64_t>(tmp_mask2(2)) << 16) |
(static_cast<uint64_t>(tmp_mask2(3)) << 24) |
(static_cast<uint64_t>(tmp_mask2(4)) << 32) |
(static_cast<uint64_t>(tmp_mask2(5)) << 40) |
(static_cast<uint64_t>(tmp_mask2(6)) << 48) |
(static_cast<uint64_t>(tmp_mask2(7)) << 56)};
float tmp_index, tmp_value;
ScalarReduceMin64<float>(tmp_index, tmp_value, tmp_dst_tensor, min_mask[0]);
int32_t relative_index = static_cast<int32_t>(tmp_index);
AscendC::SetFlag<AscendC::HardEvent::S_V>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::S_V>(EVENT_ID0);
float final_index = tmp_dst_tensor(relative_index) + relative_index * 1.0f;
dst_index(i) = static_cast<int64_t>(final_index);
dst_value(i) = static_cast<T>(max_value);
}
}
* @brief Argmax with value - RA模式(Reduce Axis)
*
* @details
* 算法逻辑:
* - 在last维度上计算argmax,结果形状为 [first, last]
* - 对每个last维度的位置,在first维度上找最大值及其索引
*
* 步骤:
* 1. 处理±0:将所有0值统一为+0(通过比较 src != 0,然后选择)
* 2. 特殊处理:如果first == 1,直接复制数据
* 3. 遍历每个64长度的元素块(除了尾部):
* a. 初始化max_before为当前块
* b. 遍历first维度:
* i. max_after = Max(max_before, next_element)
* ii. 生成mask并更新索引(与ArgmaxGTOneRepeatWithValue类似)
* iii. 将索引结果存入dst_index
* c. 将max_after存入dst_value
* 4. 处理尾部(不足64的元素)
*
* 数学公式:
* 对于每个 j ∈ [0, last-1]:
* dst_index[j] = argmax_i src[i*last + j] (i ∈ [0, first-1])
* dst_value[j] = max_i src[i*last + j] (i ∈ [0, first-1])
*
* 内存布局说明:
* - 最小处理单位:64个元素
* - float/int32_t:64×4=256B(精确匹配)
* - uint8_t:64×1=64B(分配256B保持一致性)
* - 满足32字节对齐要求,DMA传输高效
*
* @param dst_index 输出:2D tensor,形状为 [first, last],存储每个位置的argmax索引
* @param dst_value 输出:2D tensor,形状为 [first, last],存储每个位置的最大值
* @param src 输入:2D tensor,形状为 [first, last]
* @param shared_tmp_buffer 临时buffer
* @param first first维度的元素个数
* @param last last维度的元素个数
* @param repeat_times_one_row last维度按64分块的块数
*/
template <typename T, class pattern, bool isReuseSource = false>
__aicore__ void ArgmaxRAWithValue(const LocalTensor<int64_t> &dst_index,
const LocalTensor<T> &dst_value,
const LocalTensor<T> &src,
const LocalTensor<uint8_t> &shared_tmp_buffer,
const uint32_t first, const uint32_t last,
const uint32_t repeat_times_one_row) {
LocalTensor<float> tmp_dst_tensor = shared_tmp_buffer.ReinterpretCast<float>();
LocalTensor<float> dup_index_tensor = shared_tmp_buffer[256].ReinterpretCast<float>();
LocalTensor<uint8_t> tmp_mask = shared_tmp_buffer[512].ReinterpretCast<uint8_t>();
LocalTensor<uint8_t> tmp_mask1 = shared_tmp_buffer[768].ReinterpretCast<uint8_t>();
LocalTensor<uint8_t> tmp_mask2 = shared_tmp_buffer[1024].ReinterpretCast<uint8_t>();
LocalTensor<T> max_before = shared_tmp_buffer[1280].ReinterpretCast<T>();
LocalTensor<T> max_after = shared_tmp_buffer[1536].ReinterpretCast<T>();
LocalTensor<T> tail_tmp = shared_tmp_buffer[1792].ReinterpretCast<T>();
Duplicate(tmp_dst_tensor, 0.0f, tmp_dst_tensor.GetSize());
uint32_t tail_num = last % 64;
tail_num = (tail_num == 0) ? 64 : tail_num;
uint32_t copy_num = tail_num;
uint32_t tail_ctrl = 1;
uint32_t tail_offset = last - tail_num;
for (uint32_t i = 0; i < first; ++i) {
for (uint32_t j = 0; j < repeat_times_one_row; ++j) {
if constexpr (std::is_same_v<T, float>) {
AscendC::PipeBarrier<PIPE_V>();
Duplicate(tail_tmp, 0.0f, 64);
AscendC::PipeBarrier<PIPE_V>();
Compare(tmp_mask, src[i * last + j * 64], tail_tmp, AscendC::CMPMODE::EQ, 64);
AscendC::PipeBarrier<PIPE_V>();
Select(src[i * last + j * 64], tmp_mask, tail_tmp, src[i * last + j * 64], AscendC::SELMODE::VSEL_TENSOR_TENSOR_MODE, 64);
AscendC::PipeBarrier<PIPE_V>();
}
}
}
if (first == 1) {
Duplicate(tmp_dst_tensor, 0.0f, dst_index.GetSize());
AscendC::PipeBarrier<PIPE_V>();
Cast<int64_t, float>(dst_index, tmp_dst_tensor, AscendC::RoundMode::CAST_RINT,
dst_index.GetSize());
DataCopy(dst_value, src, last);
} else {
for (uint32_t i = 0; i < repeat_times_one_row - 1; ++i) {
Duplicate(tmp_dst_tensor, 0.0f, 64);
AscendC::PipeBarrier<PIPE_V>();
DataCopy(max_before, src[i * 64], 64);
AscendC::PipeBarrier<PIPE_V>();
for (uint32_t j = 0; j < first - 1; ++j) {
Duplicate(dup_index_tensor, (j + 1) * 1.0f, 64);
AscendC::PipeBarrier<PIPE_V>();
Max(max_after, max_before, src[i * 64 + (j + 1) * last], 64);
AscendC::PipeBarrier<PIPE_V>();
Compare(tmp_mask1, max_before, max_before, AscendC::CMPMODE::EQ, 64);
AscendC::PipeBarrier<PIPE_V>();
Compare(tmp_mask2, max_after, max_before, AscendC::CMPMODE::EQ, 64);
AscendC::SetFlag<AscendC::HardEvent::V_S>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::V_S>(EVENT_ID0);
for (uint32_t k = 0; k < 8; ++k) {
uint8_t val1 = tmp_mask1.GetValue(k);
uint8_t val2 = tmp_mask2.GetValue(k);
tmp_mask2.SetValue(k, val2 | (~val1));
}
AscendC::SetFlag<AscendC::HardEvent::S_V>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::S_V>(EVENT_ID0);
Select(tmp_dst_tensor, tmp_mask2, tmp_dst_tensor, dup_index_tensor,
AscendC::SELMODE::VSEL_TENSOR_TENSOR_MODE, 64);
AscendC::PipeBarrier<PIPE_V>();
DataCopy(max_before, max_after, 64);
AscendC::PipeBarrier<PIPE_V>();
Cast<int64_t, float>(dst_index[i * 64], tmp_dst_tensor,
AscendC::RoundMode::CAST_RINT, 64);
AscendC::PipeBarrier<PIPE_V>();
}
AscendC::PipeBarrier<PIPE_V>();
DataCopy(dst_value[i * 64], max_after, 64);
AscendC::PipeBarrier<PIPE_V>();
}
for (uint32_t j = 0; j < tail_ctrl; ++j) {
Duplicate(tmp_dst_tensor, 0.0f, 64);
if constexpr (std::is_same_v<T, float>) {
Duplicate(tail_tmp, -INFINITY, 64);
} else if constexpr (std::is_same_v<T, int32_t>) {
Duplicate(tail_tmp, INT32_MIN, 64);
}
AscendC::PipeBarrier<PIPE_V>();
DataCopy(tail_tmp, src[tail_offset], tail_num);
AscendC::PipeBarrier<PIPE_V>();
DataCopy(max_before, tail_tmp, 64);
AscendC::PipeBarrier<PIPE_V>();
for (uint32_t k = 0; k < first - 1; ++k) {
AscendC::PipeBarrier<PIPE_V>();
Duplicate(dup_index_tensor, (k + 1) * 1.0f, 64);
AscendC::PipeBarrier<PIPE_V>();
if constexpr (std::is_same_v<T, float>) {
Duplicate(tail_tmp, -INFINITY, 64);
} else if constexpr (std::is_same_v<T, int32_t>) {
Duplicate(tail_tmp, INT32_MIN, 64);
}
AscendC::PipeBarrier<PIPE_V>();
DataCopy(tail_tmp, src[(k + 1) * last + tail_offset], tail_num);
AscendC::PipeBarrier<PIPE_V>();
Max(max_after, max_before, tail_tmp, 64);
AscendC::PipeBarrier<PIPE_V>();
Compare(tmp_mask1, max_before, max_before, AscendC::CMPMODE::EQ, 64);
AscendC::PipeBarrier<PIPE_V>();
Compare(tmp_mask2, max_after, max_before, AscendC::CMPMODE::EQ, 64);
AscendC::SetFlag<AscendC::HardEvent::V_S>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::V_S>(EVENT_ID0);
for (uint32_t k = 0; k < 8; ++k) {
uint8_t val1 = tmp_mask1.GetValue(k);
uint8_t val2 = tmp_mask2.GetValue(k);
tmp_mask2.SetValue(k, val2 | (~val1));
}
AscendC::SetFlag<AscendC::HardEvent::S_V>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::S_V>(EVENT_ID0);
Select(tmp_dst_tensor, tmp_mask2, tmp_dst_tensor, dup_index_tensor,
AscendC::SELMODE::VSEL_TENSOR_TENSOR_MODE, 64);
AscendC::PipeBarrier<PIPE_V>();
DataCopy(max_before, max_after, 64);
AscendC::PipeBarrier<PIPE_V>();
Cast<int64_t, float>(dst_index[tail_offset], tmp_dst_tensor,
AscendC::RoundMode::CAST_RINT, tail_num);
AscendC::PipeBarrier<PIPE_V>();
}
AscendC::PipeBarrier<PIPE_V>();
DataCopy(dst_value[tail_offset], max_after, tail_num);
AscendC::PipeBarrier<PIPE_V>();
}
}
}
* @brief Argmax with value 扩展函数 - 支持AR和RA两种模式
*
* @details
* 算法逻辑:
* - 根据reduce模式(AR或RA)和last维度大小,选择合适的实现
* - AR模式:在last维度上reduce,输出形状为 [first]
* - RA模式:在first维度上reduce,输出形状为 [last]
*
* 模式选择:
* 1. AR模式:
* a. 如果last <= 64:调用ArgmaxLEOneRepeatWithValue
* b. 如果last > 64:调用ArgmaxGTOneRepeatWithValue
* 2. RA模式:
* a. 调用ArgmaxRAWithValue
*
* @param dst_index 输出:argmax的索引
* @param dst_value 输出:最大值
* @param src 输入:2D tensor,形状为 [first, last]
* @param shared_tmp_buffer_buffer 临时buffer
* @param src_shape 输入tensor的形状,src_shape[0]=first, src_shape[1]=last
* @param src_inner_pad 是否内部padding(未使用)
*
* @tparam T 索引类型(必须是int64_t)
* @tparam U 值类型(float或int32_t)
* @tparam pattern reduce模式(Pattern::Reduce::AR或Pattern::Reduce::RA)
*
* @note 仅支持AICORE平台
* @note src仅支持float和int32_t数据类型
* @note 索引类型必须是int64_t
*/
template <typename T, typename U, class pattern>
__aicore__ inline void
ArgMaxWithValueExtend(const LocalTensor<T> &dst_index,
const LocalTensor<U> &dst_value, const LocalTensor<U> &src,
const LocalTensor<uint8_t> &shared_tmp_buffer,
const uint32_t src_shape[], bool src_inner_pad) {
if ASCEND_IS_AIC {
return;
}
static_assert(std::is_same_v<U, float> || std::is_same_v<U, int32_t>,
"ArgMaxWithValue src only support float and int32_t");
static_assert(std::is_same_v<T, int64_t>,
"ArgMaxWithValue dst_index only support int64_t dst type");
static_assert(
SupportType<pattern, Pattern::Reduce::AR, Pattern::Reduce::RA>(),
"failed to check Reduce pattern, only support AR/RA pattern!");
uint32_t first = src_shape[0];
uint32_t last = src_shape[1];
if constexpr (SupportType<pattern, Pattern::Reduce::AR>()) {
if (last <= 64) {
uint32_t repeat_times = first;
ArgmaxLEOneRepeatWithValue<U, false>(dst_index, dst_value, src,
shared_tmp_buffer, repeat_times, last);
} else {
uint32_t repeat_times_one_row = CeilDivision(last, 64);
ArgmaxGTOneRepeatWithValue<U, Pattern::Reduce::AR, false>(
dst_index, dst_value, src, shared_tmp_buffer, first, last,
repeat_times_one_row);
}
} else if constexpr (SupportType<pattern, Pattern::Reduce::RA>()) {
uint32_t repeat_times_one_row = CeilDivision(last, 64);
ArgmaxRAWithValue<U, Pattern::Reduce::RA, false>(dst_index, dst_value, src,
shared_tmp_buffer, first,
last, repeat_times_one_row);
}
}
}
#endif