* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef __ASCENDC_API_GATHER_H__
#define __ASCENDC_API_GATHER_H__
/**
 * Gather fallback that stages only the index vector in tmp_buf.
 *
 * Copies vectorized_axis_size indices of type T2 from src1 into the head of tmp_buf, then, one
 * element per loop iteration, reads src0 at each index with GlobalTensor::GetValue and stores the
 * result into dst with SetValue.
 *
 * @param dst                   Destination local tensor; positions [0, vectorized_axis_size) are written.
 * @param src0                  Global tensor the gathered values are read from.
 * @param src1                  Global tensor holding the indices into src0.
 * @param vectorized_axis_size  Number of indices to process.
 * @param tmp_buf               Byte scratch tensor; its beginning receives the copy of the indices.
 *                              NOTE(review): presumably has to hold the block-aligned index bytes --
 *                              confirm against the caller.
 */
template <typename T1, typename T2>
inline __aicore__ void GatherUbNotFullLoad(const LocalTensor<T1> &dst, const GlobalTensor<T1> &src0,
                                           const GlobalTensor<T2> &src1, uint32_t vectorized_axis_size,
                                           LocalTensor<uint8_t> &tmp_buf) {
  // View the start of the scratch buffer as T2 elements.
  auto index_ub = tmp_buf[0].template ReinterpretCast<T2>();
  index_ub.SetSize(vectorized_axis_size);

  // A single burst carrying every index byte, with no source or destination stride.
  AscendC::DataCopyExtParams index_copy_params;
  index_copy_params.blockCount = 1;
  index_copy_params.blockLen = vectorized_axis_size * sizeof(T2);
  index_copy_params.srcStride = 0;
  index_copy_params.dstStride = 0;
  AscendC::DataCopyPad(index_ub, src1, index_copy_params, AscendC::DataCopyPadExtParams<T2>());

  // Set and immediately wait on an MTE2_S event so the GetValue reads below are issued after the
  // copy above.
  int32_t mte2_s_event = static_cast<int32_t>(GetTPipePtr()->FetchEventID(AscendC::HardEvent::MTE2_S));
  AscendC::SetFlag<AscendC::HardEvent::MTE2_S>(mte2_s_event);
  AscendC::WaitFlag<AscendC::HardEvent::MTE2_S>(mte2_s_event);

  // Element-by-element lookup: dst[pos] = src0[index_ub[pos]].
  for (uint32_t pos = 0; pos < vectorized_axis_size; pos++) {
    T2 gather_index = index_ub.GetValue(pos);
    T1 gathered = src0.GetValue(static_cast<uint64_t>(gather_index));
    dst.SetValue(pos, gathered);
  }
}
/**
 * Gather path that stages both the source elements and the index vector in tmp_buf and performs
 * the lookup with one AscendC::Gather call.
 *
 * Scratch layout (byte offsets into tmp_buf):
 *   0                                              src0 elements, viewed as T1
 *   param_last_axis_bytes                          src1 indices, viewed as T2
 *   param_last_axis_bytes + vectorized_axis_bytes  int32_t copy of the indices (T2 == int64_t only)
 *
 * The indices are multiplied by sizeof(T1) before being reinterpreted as uint32_t and passed to
 * Gather, i.e. Gather receives byte offsets with a base address of 0.
 *
 * @param dst                    Destination local tensor receiving vectorized_axis_size elements.
 * @param src0                   Global tensor the gathered values come from.
 * @param src1                   Global tensor holding the indices into src0.
 * @param param_last_axis_size   Number of T1 elements copied from src0.
 * @param vectorized_axis_size   Number of indices copied from src1 / elements gathered.
 * @param param_last_axis_bytes  Byte size reserved in tmp_buf for the src0 copy.
 * @param vectorized_axis_bytes  Byte size reserved in tmp_buf for the src1 copy.
 * @param tmp_buf                Byte scratch tensor laid out as described above.
 *                               NOTE(review): capacity is not checked here; the caller is
 *                               presumably responsible -- confirm.
 */
template <typename T1, typename T2>
inline __aicore__ void GatherUbFullLoad(const LocalTensor<T1> &dst, const GlobalTensor<T1> &src0,
                                        const GlobalTensor<T2> &src1, uint32_t param_last_axis_size,
                                        uint32_t vectorized_axis_size, uint32_t param_last_axis_bytes,
                                        uint32_t vectorized_axis_bytes, LocalTensor<uint8_t> &tmp_buf) {
  // Carve typed views out of the scratch buffer.
  auto table_ub = tmp_buf[0].template ReinterpretCast<T1>();
  table_ub.SetSize(param_last_axis_size);
  auto index_ub = tmp_buf[param_last_axis_bytes].template ReinterpretCast<T2>();
  index_ub.SetSize(vectorized_axis_size);

  // Both loads are a single contiguous burst without strides.
  AscendC::DataCopyExtParams table_copy_params;
  table_copy_params.blockCount = 1;
  table_copy_params.blockLen = param_last_axis_size * sizeof(T1);
  table_copy_params.srcStride = 0;
  table_copy_params.dstStride = 0;

  AscendC::DataCopyExtParams index_copy_params;
  index_copy_params.blockCount = 1;
  index_copy_params.blockLen = vectorized_axis_size * sizeof(T2);
  index_copy_params.srcStride = 0;
  index_copy_params.dstStride = 0;

  // Event ids for every set/wait pair used below.
  int32_t mte2_v_event = static_cast<int32_t>(GetTPipePtr()->FetchEventID(AscendC::HardEvent::MTE2_V));
  int32_t v_mte2_event = static_cast<int32_t>(GetTPipePtr()->FetchEventID(AscendC::HardEvent::V_MTE2));
  int32_t v_mte3_event = static_cast<int32_t>(GetTPipePtr()->FetchEventID(AscendC::HardEvent::V_MTE3));
  int32_t mte3_v_event = static_cast<int32_t>(GetTPipePtr()->FetchEventID(AscendC::HardEvent::MTE3_V));
  int32_t mte3_mte2_event = static_cast<int32_t>(GetTPipePtr()->FetchEventID(AscendC::HardEvent::MTE3_MTE2));
  int32_t mte2_mte3_event = static_cast<int32_t>(GetTPipePtr()->FetchEventID(AscendC::HardEvent::MTE2_MTE3));

  // V_MTE2 and MTE3_MTE2 set/wait pairs are issued before the loads into the scratch buffer.
  // NOTE(review): presumably guards against earlier vector / MTE3 work still using these
  // buffers -- confirm.
  AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(v_mte2_event);
  AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(v_mte2_event);
  AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(mte3_mte2_event);
  AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(mte3_mte2_event);

  AscendC::DataCopyPad(table_ub, src0, table_copy_params, AscendC::DataCopyPadExtParams<T1>());
  AscendC::DataCopyPad(index_ub, src1, index_copy_params, AscendC::DataCopyPadExtParams<T2>());

  // MTE2_V and MTE3_V set/wait pairs precede the vector instructions for every index type.
  AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(mte2_v_event);
  AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(mte2_v_event);
  AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(mte3_v_event);
  AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(mte3_v_event);

  // Scale factor that turns an element index into a byte offset.
  int32_t elem_bytes = static_cast<int32_t>(sizeof(T1));
  if constexpr (AscendC::IsSameType<T2, int64_t>::value) {
    // 64-bit indices are first cast to int32_t in the region that follows the index copy.
    auto narrowed_ub = tmp_buf[param_last_axis_bytes + vectorized_axis_bytes].template ReinterpretCast<int32_t>();
    narrowed_ub.SetSize(vectorized_axis_size);
    AscendC::Cast(narrowed_ub, index_ub, AscendC::RoundMode::CAST_NONE, vectorized_axis_size);
    AscendC::Muls(narrowed_ub, narrowed_ub, elem_bytes, vectorized_axis_size);
    auto byte_offsets_ub = narrowed_ub.template ReinterpretCast<uint32_t>();
    AscendC::Gather(dst, table_ub, byte_offsets_ub, 0, vectorized_axis_size);
  } else {
    // Other index types are scaled in place and reinterpreted directly.
    AscendC::Muls(index_ub, index_ub, elem_bytes, vectorized_axis_size);
    auto byte_offsets_ub = index_ub.template ReinterpretCast<uint32_t>();
    AscendC::Gather(dst, table_ub, byte_offsets_ub, 0, vectorized_axis_size);
  }

  // Trailing V_MTE3 and MTE2_MTE3 set/wait pairs, issued after the Gather.
  AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(v_mte3_event);
  AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(v_mte3_event);
  AscendC::SetFlag<AscendC::HardEvent::MTE2_MTE3>(mte2_mte3_event);
  AscendC::WaitFlag<AscendC::HardEvent::MTE2_MTE3>(mte2_mte3_event);
}
/**
 * Entry point of the gather: dst[i] = src0[src1[i]] for i in [0, vectorized_axis_size).
 *
 * Chooses one of three strategies from the capacity of tmp_buf (GetSize() of the uint8_t tensor):
 *   1. tmp_buf cannot hold the block-aligned index vector: element-by-element loop that reads
 *      both src1 and src0 with GlobalTensor::GetValue; tmp_buf is not touched.
 *   2. tmp_buf holds the indices but not src0 plus the indices (plus an int32_t copy of the
 *      indices when T2 is int64_t): GatherUbNotFullLoad.
 *   3. Otherwise: GatherUbFullLoad.
 *
 * @param dst                   Destination local tensor receiving vectorized_axis_size elements.
 * @param src0                  Global tensor the gathered values come from.
 * @param src1                  Global tensor holding the indices into src0.
 * @param param_last_axis_size  Number of T1 elements of src0 staged by the full-load path.
 * @param vectorized_axis_size  Number of indices / gathered elements.
 * @param tmp_buf               Byte scratch tensor used for staging by strategies 2 and 3.
 */
template <typename T1, typename T2>
inline __aicore__ void GatherExtend(const LocalTensor<T1> &dst, const GlobalTensor<T1> &src0,
                                    const GlobalTensor<T2> &src1, uint32_t param_last_axis_size,
                                    uint32_t vectorized_axis_size, LocalTensor<uint8_t> &tmp_buf) {
  // Compare against the ONE_BLK_SIZE-aligned byte count rather than the raw one: the staged
  // index copy done by GatherUbNotFullLoad goes through DataCopyPad, whose destination footprint
  // is block aligned (per the Ascend C DataCopyPad documentation), so a scratch buffer that only
  // fits the unaligned bytes must take the scalar path instead of being written out of bounds.
  uint32_t vectorized_axis_bytes = AscendC::AlignUp(vectorized_axis_size * sizeof(T2), ONE_BLK_SIZE);
  if (tmp_buf.GetSize() < vectorized_axis_bytes) {
    // Strategy 1: no staging at all.
    for (uint32_t i = 0; i < vectorized_axis_size; i++) {
      T2 index = src1.GetValue(i);
      T1 value = src0.GetValue(static_cast<uint64_t>(index));
      dst.SetValue(i, value);
    }
    return;
  }

  uint32_t param_last_axis_bytes = AscendC::AlignUp(param_last_axis_size * sizeof(T1), ONE_BLK_SIZE);
  // Bytes needed to stage src0 and the indices together.
  uint32_t ub_full_load_bytes = param_last_axis_bytes + vectorized_axis_bytes;
  if constexpr (AscendC::IsSameType<T2, int64_t>::value) {
    // The full-load path additionally keeps a 32-bit copy of 64-bit indices.
    ub_full_load_bytes += vectorized_axis_size * sizeof(uint32_t);
  }

  if (tmp_buf.GetSize() < ub_full_load_bytes) {
    // Strategy 2: only the indices fit.
    GatherUbNotFullLoad(dst, src0, src1, vectorized_axis_size, tmp_buf);
  } else {
    // Strategy 3: everything fits.
    GatherUbFullLoad(dst, src0, src1, param_last_axis_size, vectorized_axis_size, param_last_axis_bytes,
                     vectorized_axis_bytes, tmp_buf);
  }
}
#endif