/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
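/*!
 * \file feeds_repeat.h
 * \brief Kernel that copies every row of `feeds` into `y` as many times as the matching entry of
 *        `feeds_repeat_times` requests, and zero-fills the output space left after the last copied row.
 */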
#ifndef FEEDS_REPEAT_H
#define FEEDS_REPEAT_H
#include "kernel_operator.h"
namespace FeedsRepeat {
using namespace AscendC;
/**
 * Kernel object that copies rows from the `feeds` global tensor into the `y` global tensor,
 * writing each row as many times as the corresponding entry of `feeds_repeat_times` says.
 *
 * @tparam T1 element type of `feeds` and `y`.
 * @tparam T2 element type of `feeds_repeat_times`.
 *
 * Usage, as far as this file shows: call Init() once, then Process() once.
 */
template <typename T1, typename T2>
class FeedsRepeatND {
public:
__aicore__ inline FeedsRepeatND()
{}
// Copies the tiling values into members, binds the three global tensors and reserves the pipe buffers.
__aicore__ inline void Init(
GM_ADDR feeds, GM_ADDR feeds_repeat_times, GM_ADDR y, const FeedsRepeatTilingData* __restrict tiling_data);
// Works out this core's share of rows, then runs the cast / clear / offset / copy stages in order.
__aicore__ inline void Process();
private:
// Loads the repeat counts from global memory and casts them to float for the Sum reductions.
__aicore__ inline void RepeatTimesCast();
// Sums all repeat counts into end_index and zero-fills the part of y behind that row, split across cores.
__aicore__ inline void ClearOutputSpace();
// Sums the repeat counts of the rows before start_row and adds the result to y_start_index.
__aicore__ inline void ComputeStartDest();
// Copy path used when core_per_group != 0: the cores of one group share the repeats of a single row.
__aicore__ inline void RepeatSingleRow();
// Copy path used when core_per_group == 0: this core handles rows [start_row, end_row) on its own.
__aicore__ inline void RepeatMultiRow();
protected:
TPipe pipe;
// Global-memory views bound in Init().
GlobalTensor<T1> feeds_gm;
GlobalTensor<T2> feeds_repeat_times_gm;
GlobalTensor<T1> y_gm;
// in_queue stages the repeat counts; in_out_queue stages row chunks that are copied in and straight back out.
TQue<QuePosition::VECIN, 1> in_queue;
TQueBind<QuePosition::VECIN, QuePosition::VECOUT, 1> in_out_queue;
// Scratch buffers: the float copy of the repeat counts plus the float / int64 results of the two sums.
TBuf<TPosition::VECCALC> feeds_repeat_times_float_buf;
TBuf<TPosition::VECCALC> end_sum_buf;
TBuf<TPosition::VECCALC> end_sum_int64_buf;
TBuf<TPosition::VECCALC> sum_result_buf;
TBuf<TPosition::VECCALC> sum_result_int64_buf;
// Local tensor handles obtained from the queues / buffers above.
LocalTensor<T1> row;
LocalTensor<T2> feeds_repeat_times_ub;
LocalTensor<float> feeds_repeat_times_float;
LocalTensor<float> end_sum;
LocalTensor<float> sum_result;
LocalTensor<int64_t> end_sum_int64;
LocalTensor<int64_t> sum_result_int64;
// length: number of repeat-count entries copied in; length_aligned: element count used for buffer sizing and Cast.
uint32_t length;
uint32_t length_aligned;
// Half-open range of feeds rows handled by this core (or by this core's group), set in Process().
uint32_t start_row;
uint32_t end_row;
// elem_row: elements per row; elem_per_loop: elements moved per copy iteration (also the in_out_queue buffer size).
int64_t elem_row;
int64_t elem_per_loop;
// core_index comes from GetBlockIdx(); max_core_num comes from tiling and divides the clearing work.
int64_t core_index;
int64_t max_core_num;
// Grouping parameters: core_per_group != 0 selects the grouped path; the first core_moreover groups
// have core_per_group + 1 cores (see Process()).
int64_t core_per_group = 0;
int64_t core_moreover = 0;
// Used as a row count in ClearOutputSpace(): rows [end_index, empty_size) of y are zero-filled.
// NOTE(review): presumably the total number of output rows - confirm against the tiling code.
int64_t empty_size;
// Row split: every index gets row_per_core rows and the first row_left indices get one more.
int64_t row_per_core;
int64_t row_left;
// Repeat count of the row currently being copied.
int64_t repeat_times = 0;
// Output row at which this core starts writing.
int64_t y_start_index = 0;
// Sum of all repeat counts, i.e. the first output row that is zero-filled.
int64_t end_index = 0;
// Group bookkeeping filled in by Process() on the grouped path.
int64_t group_num = 0;
int64_t group_id = -1;
int64_t block_in_group = 1;
int64_t id_in_group = -1;
// Work index used for the row split: group_id on the grouped path, core_index otherwise.
int64_t index;
// start_row rounded up so that start_row * sizeof(float) is a multiple of align_num bytes.
uint32_t start_aligned;
// A row is copied as loop_num chunks of elem_per_loop elements plus a tail of loop_left elements.
int64_t loop_num;
int64_t loop_left;
// Element offsets into feeds: start of start_row, size of the full-chunk part, and start of the tail.
int64_t start_index;
int64_t elem_loop;
int64_t left_index;
// align_num: byte size used for the small scratch buffers and for rounding in start_aligned.
const int64_t align_num = 32;
// cast_num: element count passed to the float -> int64 Cast in ClearOutputSpace().
const int64_t cast_num = 4;
// Event ids fetched in Init() for the V->S and MTE3->V flag pairs.
event_t event_v_to_s;
event_t event_mte3_to_v;
// Copy descriptors; only blockLen is changed later (full chunk, repeat counts, tail chunk respectively).
DataCopyExtParams copyParams{1, 0, 0, 0, 0};
DataCopyExtParams copyParams_repeat{1, 0, 0, 0, 0};
DataCopyExtParams copyParams_left{1, 0, 0, 0, 0};
// Padding is disabled for both element types.
DataCopyPadExtParams<T1> padParams{false, 0, 0, 0};
DataCopyPadExtParams<T2> padParams_repeat{false, 0, 0, 0};
// Reduction descriptors; inner and n are filled in just before each Sum call.
SumParams sumParams{1, 0, 0};
SumParams sumParams_total{1, 0, 0};
};
/**
 * Prepares the kernel object: records the tiling values, binds the global tensors and
 * reserves every queue / scratch buffer used later by Process().
 *
 * @param feeds               global address of the input rows (element type T1).
 * @param feeds_repeat_times  global address of the per-row repeat counts (element type T2).
 * @param y                   global address of the output (element type T1).
 * @param tiling_data         tiling values; each field read here is copied into the member of the same name.
 */
template <typename T1, typename T2>
__aicore__ inline void FeedsRepeatND<T1, T2>::Init(
    GM_ADDR feeds, GM_ADDR feeds_repeat_times, GM_ADDR y, const FeedsRepeatTilingData* __restrict tiling_data)
{
    // Which core this instance runs on.
    core_index = GetBlockIdx();

    // Row geometry.
    elem_row = tiling_data->elem_row;
    elem_per_loop = tiling_data->elem_per_loop;
    length = tiling_data->length;
    length_aligned = tiling_data->length_aligned;

    // Work distribution across cores.
    max_core_num = tiling_data->max_core_num;
    core_per_group = tiling_data->core_per_group;
    core_moreover = tiling_data->core_moreover;
    row_per_core = tiling_data->row_per_core;
    row_left = tiling_data->row_left;
    empty_size = tiling_data->empty_size;

    // Global-memory views.
    feeds_gm.SetGlobalBuffer(reinterpret_cast<__gm__ T1*>(feeds));
    feeds_repeat_times_gm.SetGlobalBuffer(reinterpret_cast<__gm__ T2*>(feeds_repeat_times));
    y_gm.SetGlobalBuffer(reinterpret_cast<__gm__ T1*>(y));

    // On-chip buffers. The reservation order is kept exactly as before on purpose.
    pipe.InitBuffer(in_out_queue, 2, elem_per_loop * sizeof(T1));
    pipe.InitBuffer(in_queue, 1, length_aligned * sizeof(T2));
    pipe.InitBuffer(feeds_repeat_times_float_buf, length_aligned * sizeof(float));
    pipe.InitBuffer(end_sum_buf, align_num);
    pipe.InitBuffer(end_sum_int64_buf, align_num);
    pipe.InitBuffer(sum_result_buf, align_num);
    pipe.InitBuffer(sum_result_int64_buf, align_num);

    // Event ids for the flag pairs used in the reduction / clearing stages.
    event_mte3_to_v = static_cast<event_t>(pipe.FetchEventID<HardEvent::MTE3_V>());
    event_v_to_s = static_cast<event_t>(pipe.FetchEventID<HardEvent::V_S>());
}
/**
 * Brings the `length` repeat counts from global memory into `feeds_repeat_times_ub` and writes their
 * float conversion (over `length_aligned` elements) into `feeds_repeat_times_float`.
 *
 * NOTE(review): the tensor is handed back to the queue at the end, yet RepeatSingleRow() and
 * RepeatMultiRow() still read `feeds_repeat_times_ub` afterwards - confirm that this is intended.
 */
template <typename T1, typename T2>
__aicore__ inline void FeedsRepeatND<T1, T2>::RepeatTimesCast()
{
    const uint32_t repeat_bytes = length * static_cast<uint32_t>(sizeof(T2));
    copyParams_repeat.blockLen = repeat_bytes;

    // Stage the counts through the input queue.
    LocalTensor<T2> staged = in_queue.AllocTensor<T2>();
    DataCopyPad(staged, feeds_repeat_times_gm, copyParams_repeat, padParams_repeat);
    in_queue.EnQue(staged);

    // Pick them up again and convert to float for the Sum reductions.
    feeds_repeat_times_ub = in_queue.DeQue<T2>();
    Cast(feeds_repeat_times_float, feeds_repeat_times_ub, RoundMode::CAST_RINT, length_aligned);
    in_queue.FreeTensor(feeds_repeat_times_ub);
}
/**
 * Adds the sum of all repeat counts to `end_index` and zero-fills the part of `y` that lies behind
 * output row `end_index`, up to row `empty_size`.
 *
 * The region is split evenly over `max_core_num` cores; core 0 additionally takes the remainder.
 * Nothing is cleared when the region is empty or when the repeat counts add up to more than
 * `empty_size` rows. Without that guard a negative size would be passed to InitOutput.
 */
template <typename T1, typename T2>
__aicore__ inline void FeedsRepeatND<T1, T2>::ClearOutputSpace()
{
    end_sum = end_sum_buf.Get<float>();
    end_sum_int64 = end_sum_int64_buf.Get<int64_t>();

    // Reduce the first `length` repeat counts to a single value and convert it to an integer.
    sumParams_total.inner = length_aligned;
    sumParams_total.n = length;
    Sum(end_sum, feeds_repeat_times_float, sumParams_total);
    Cast(end_sum_int64, end_sum, RoundMode::CAST_RINT, cast_num);

    // The scalar read below must observe the vector results.
    SetFlag<HardEvent::V_S>(event_v_to_s);
    WaitFlag<HardEvent::V_S>(event_v_to_s);
    end_index += end_sum_int64.GetValue(0);

    // Number of elements between the last copied row and the end of the output.
    const int64_t empty_total = (empty_size - end_index) * elem_row;
    if (empty_total > 0) {
        const int64_t empty_per_core = empty_total / max_core_num;
        const int64_t empty_left = empty_total % max_core_num;
        const int64_t clear_base = end_index * elem_row;
        if (core_index == 0) {
            // Core 0 takes its own share plus the remainder; this is > 0 because empty_total > 0.
            InitOutput<T1>(y_gm[clear_base], empty_per_core + empty_left, 0);
        } else if (empty_per_core > 0) {
            // Every other core starts behind core 0's enlarged share.
            InitOutput<T1>(y_gm[clear_base + core_index * empty_per_core + empty_left], empty_per_core, 0);
        }
    }

    SetFlag<HardEvent::MTE3_V>(event_mte3_to_v);
    WaitFlag<HardEvent::MTE3_V>(event_mte3_to_v);
}
/**
 * Determines the first output row this core writes to: the sum of the repeat counts of all rows
 * before `start_row` is added to `y_start_index`. When `start_row` is 0 nothing is added.
 */
template <typename T1, typename T2>
__aicore__ inline void FeedsRepeatND<T1, T2>::ComputeStartDest()
{
    sum_result = sum_result_buf.Get<float>();
    sum_result_int64 = sum_result_int64_buf.Get<int64_t>();
    if (start_row == 0) {
        return;
    }

    // Reduce the first `start_row` repeat counts.
    sumParams.inner = start_aligned;
    sumParams.n = start_row;
    Sum(sum_result, feeds_repeat_times_float, sumParams);

    // Only element 0 is consumed. Both scratch buffers are align_num (32) bytes, which is
    // 8 floats / 4 int64 values, so the cast must cover cast_num elements, not align_num elements;
    // the larger count would read and write past the buffers. This matches ClearOutputSpace().
    Cast(sum_result_int64, sum_result, RoundMode::CAST_RINT, cast_num);

    // The scalar read below must observe the vector results.
    SetFlag<HardEvent::V_S>(event_v_to_s);
    WaitFlag<HardEvent::V_S>(event_v_to_s);
    y_start_index += sum_result_int64.GetValue(0);
}
/**
 * Grouped copy path: the `block_in_group` cores of one group share the repeats of row `start_row`.
 *
 * The row's repeat count is split as evenly as possible; the first `repeat_times % block_in_group`
 * cores of the group each take one extra repeat. This core writes repeats
 * [repeat_start, repeat_end) of the row, chunk by chunk, to output rows `y_start_index + k`.
 */
template <typename T1, typename T2>
__aicore__ inline void FeedsRepeatND<T1, T2>::RepeatSingleRow()
{
    repeat_times = feeds_repeat_times_ub.GetValue(start_row);

    // This core's slice of the repeats.
    const int64_t repeat_base = repeat_times / block_in_group;
    const int64_t repeat_left = repeat_times % block_in_group;
    const bool takes_extra = id_in_group < repeat_left;
    const int64_t repeat_start = id_in_group * repeat_base + (takes_extra ? id_in_group : repeat_left);
    const int64_t repeat_end = repeat_start + repeat_base + (takes_extra ? 1 : 0);

    // Full chunks of elem_per_loop elements. Counters are 64-bit so the int64_t bounds are not narrowed.
    for (int64_t j = 0; j < loop_num; j++) {
        const int64_t chunk_offset = elem_per_loop * j;
        row = in_out_queue.AllocTensor<T1>();
        DataCopyPad(row, feeds_gm[start_index + chunk_offset], copyParams, padParams);
        in_out_queue.EnQue<QuePosition::VECIN, QuePosition::VECOUT, T1>(row);
        row = in_out_queue.DeQue<QuePosition::VECIN, QuePosition::VECOUT, T1>();
        // The chunk is read once and written once per repeat.
        for (int64_t k = repeat_start; k < repeat_end; k++) {
            DataCopyPad(y_gm[(y_start_index + k) * elem_row + chunk_offset], row, copyParams);
        }
        in_out_queue.FreeTensor(row);
    }

    // Tail chunk of loop_left elements, if the row length is not a multiple of elem_per_loop.
    if (loop_left != 0) {
        copyParams_left.blockLen = static_cast<uint32_t>(loop_left * sizeof(T1));
        row = in_out_queue.AllocTensor<T1>();
        DataCopyPad(row, feeds_gm[left_index], copyParams_left, padParams);
        in_out_queue.EnQue<QuePosition::VECIN, QuePosition::VECOUT, T1>(row);
        row = in_out_queue.DeQue<QuePosition::VECIN, QuePosition::VECOUT, T1>();
        for (int64_t k = repeat_start; k < repeat_end; k++) {
            DataCopyPad(y_gm[(y_start_index + k) * elem_row + elem_loop], row, copyParams_left);
        }
        in_out_queue.FreeTensor(row);
    }
}
/**
 * Ungrouped copy path: this core copies rows [start_row, end_row) of `feeds`.
 *
 * Each row is written `repeat_times` times to consecutive output rows starting at `y_start_index`,
 * after which `y_start_index` advances by that repeat count.
 */
template <typename T1, typename T2>
__aicore__ inline void FeedsRepeatND<T1, T2>::RepeatMultiRow()
{
    // The tail descriptor does not depend on the row, so set it once.
    if (loop_left != 0) {
        copyParams_left.blockLen = static_cast<uint32_t>(loop_left * sizeof(T1));
    }

    // The row counter has the same type as start_row / end_row, avoiding a signed/unsigned comparison.
    for (uint32_t i = start_row; i < end_row; i++) {
        const int64_t loop_start = elem_row * i;
        repeat_times = feeds_repeat_times_ub.GetValue(i);

        // Full chunks of elem_per_loop elements: read once, write once per repeat.
        for (int64_t j = 0; j < loop_num; j++) {
            const int64_t chunk_offset = elem_per_loop * j;
            row = in_out_queue.AllocTensor<T1>();
            DataCopyPad(row, feeds_gm[loop_start + chunk_offset], copyParams, padParams);
            in_out_queue.EnQue<QuePosition::VECIN, QuePosition::VECOUT, T1>(row);
            row = in_out_queue.DeQue<QuePosition::VECIN, QuePosition::VECOUT, T1>();
            for (int64_t k = 0; k < repeat_times; k++) {
                DataCopyPad(y_gm[(y_start_index + k) * elem_row + chunk_offset], row, copyParams);
            }
            in_out_queue.FreeTensor(row);
        }

        // Tail chunk of loop_left elements.
        if (loop_left != 0) {
            row = in_out_queue.AllocTensor<T1>();
            DataCopyPad(row, feeds_gm[loop_start + elem_loop], copyParams_left, padParams);
            in_out_queue.EnQue<QuePosition::VECIN, QuePosition::VECOUT, T1>(row);
            row = in_out_queue.DeQue<QuePosition::VECIN, QuePosition::VECOUT, T1>();
            for (int64_t k = 0; k < repeat_times; k++) {
                DataCopyPad(y_gm[(y_start_index + k) * elem_row + elem_loop], row, copyParams_left);
            }
            in_out_queue.FreeTensor(row);
        }

        // The next row starts right behind this row's repeats.
        y_start_index += repeat_times;
    }
}
/**
 * Kernel entry point after Init().
 *
 * Steps:
 *  1. On the grouped path (`core_per_group != 0`) derive this core's group and its position in it;
 *     the first `core_moreover` groups have `core_per_group + 1` cores, the rest `core_per_group`.
 *  2. Split the rows: every work index gets `row_per_core` rows, the first `row_left` get one more.
 *  3. Derive the chunking of a row (`loop_num` full chunks plus `loop_left` tail elements).
 *  4. Run RepeatTimesCast(), ClearOutputSpace(), ComputeStartDest(), then the matching copy path.
 */
template <typename T1, typename T2>
__aicore__ inline void FeedsRepeatND<T1, T2>::Process()
{
    const bool grouped = (core_per_group != 0);

    if (grouped) {
        group_num = (max_core_num - core_moreover) / core_per_group;
        const int64_t wide_group_size = core_per_group + 1;
        if (core_index < core_moreover * wide_group_size) {
            // This core belongs to one of the enlarged leading groups.
            block_in_group = wide_group_size;
            group_id = core_index / wide_group_size;
            id_in_group = core_index % wide_group_size;
        } else {
            // Regular group: shifting by core_moreover accounts for the extra cores in front.
            const int64_t shifted_core = core_index - core_moreover;
            block_in_group = core_per_group;
            group_id = shifted_core / core_per_group;
            id_in_group = shifted_core % core_per_group;
        }
    }

    // Work index for the row split.
    index = grouped ? group_id : core_index;

    const bool takes_extra_row = index < row_left;
    start_row = index * row_per_core + (takes_extra_row ? index : row_left);
    end_row = start_row + row_per_core + (takes_extra_row ? 1 : 0);

    // Round start_row up so that the float prefix spans a whole number of align_num-byte blocks.
    const auto start_bytes = start_row * sizeof(float);
    start_aligned = (start_bytes + align_num - 1) / align_num * align_num / sizeof(float);

    // Chunking of one row.
    loop_num = elem_row / elem_per_loop;
    loop_left = elem_row % elem_per_loop;
    elem_loop = elem_per_loop * loop_num;
    start_index = elem_row * start_row;
    left_index = start_index + elem_loop;

    feeds_repeat_times_float = feeds_repeat_times_float_buf.Get<float>();
    RepeatTimesCast();
    ClearOutputSpace();
    ComputeStartDest();

    // A full chunk never exceeds one row.
    const int64_t chunk_elems = (elem_row <= elem_per_loop) ? elem_row : elem_per_loop;
    copyParams.blockLen = static_cast<uint32_t>(chunk_elems * sizeof(T1));

    if (grouped) {
        RepeatSingleRow();
    } else {
        RepeatMultiRow();
    }
}
}
#endif