* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef __ASCENDC_API_BROADCAST_H__
#define __ASCENDC_API_BROADCAST_H__
// Source elements consumed by one Brcb repeat; BrcbSrcToOneBlock advances the source by this many
// elements and the destination by this many blocks per repeat.
constexpr uint32_t BRCB_ONE_SIZE_A{8};
// Per-call Brcb repeat caps used by BrcbSrcToOneBlock for 2-byte and 4-byte element types.
constexpr uint32_t BRCB_HALF_MAX_REPEATE_TIMES_A{254};
constexpr uint32_t BRCB_FLOAT_MAX_REPEATE_TIMES_A{255};
// Built-in GatherMask pattern id used by TwoDimBroadCastLastDimNotAlign (presumably "keep every
// element" — confirm against the GatherMask documentation).
constexpr uint32_t GATHER_MASK_PATTERN_A{7};
/**
 * Replicates one source row of dst_k * max(src_z, 1) elements dst_m times into dst
 * (shape [1, k, z] -> [m, k, z]) using counter-mode Copy.
 *
 * src_m, src_k, dst_z and tmp_buf are accepted for signature uniformity but are not read.
 * On return the mask is back in normal mode and reset.
 */
template <typename T>
inline __aicore__ void BroadcastFirstDim(const LocalTensor<T> &dst, const LocalTensor<T> &src, const uint32_t src_m,
                                         const uint32_t src_k, const uint32_t src_z, const uint32_t dst_m,
                                         const uint32_t dst_k, const uint32_t dst_z, LocalTensor<uint8_t> &tmp_buf) {
  const uint32_t row_elems = dst_k * ((src_z == 0) ? 1U : src_z);
  const uint32_t elems_per_blk = KernelUtils::BlkSize<T>();
  AscendC::SetMaskCount();
  if (row_elems > elems_per_blk) {
    // One repeat per destination row: the source repeat stride is 0, so every repeat re-reads the same row,
    // while the destination advances by one row (expressed in 32-byte blocks) per repeat.
    AscendC::SetVectorMask<T, MaskMode::COUNTER>(row_elems);
    const uint16_t row_blks = row_elems * sizeof(T) / ONE_BLK_SIZE;
    AscendC::CopyRepeatParams row_params{1, 1, row_blks, 0};
    const uint32_t full_chunks = dst_m / MAX_REPEAT_TIME;
    const uint32_t chunk_elems = row_elems * MAX_REPEAT_TIME;
    for (uint32_t chunk = 0; chunk < full_chunks; ++chunk) {
      AscendC::Copy<T, false>(dst[chunk * chunk_elems], src, MASK_PLACEHOLDER, MAX_REPEAT_TIME, row_params);
    }
    const uint32_t rest_rows = dst_m % MAX_REPEAT_TIME;
    if (rest_rows != 0) {
      AscendC::Copy<T, false>(dst[full_chunks * chunk_elems], src, MASK_PLACEHOLDER, rest_rows, row_params);
    }
  } else {
    // Row fits in one block: a single Copy covering dst_m rows, source block stride 0.
    // NOTE(review): appears to rely on the row being exactly one block wide — confirm with callers.
    AscendC::SetVectorMask<T, MaskMode::COUNTER>(dst_m * row_elems);
    AscendC::CopyRepeatParams single_blk_params{1, 0, 1, 0};
    AscendC::Copy<T, false>(dst, src, MASK_PLACEHOLDER, 1, single_blk_params);
  }
  AscendC::SetMaskNorm();
  AscendC::ResetMask();
}
// Broadcasts one source row of src_k * max(src_z, 1) elements into dst_m destination rows of
// dst_k * max(src_z, 1) elements each (dst row i starts at dst[i * dst_inner_offset]).
// BroadcastCommon calls it per outer index with (src_m, src_k, src_z, dst_m, dst_k, dst_z) =
// (src_k, src_z, 0, dst_k, dst_z, 0), i.e. to expand a [1, z] slice into [k, z].
// src_m, dst_z are not read. tmp_buf's first 32 bytes are overwritten with zeros on the Or path.
// On the Or path the mask is left in normal mode and reset on return.
template <typename T>
inline __aicore__ void BroadcastMiddleDim(const LocalTensor<T> &dst, const LocalTensor<T> &src, const uint32_t src_m,
const uint32_t src_k, const uint32_t src_z, const uint32_t dst_m,
const uint32_t dst_k, const uint32_t dst_z, LocalTensor<uint8_t> &tmp_buf) {
uint32_t inner_offset = src_z == 0 ? 1 : src_z;
uint32_t dst_inner_offset = dst_k * inner_offset;
constexpr uint32_t one_blk_num = ONE_BLK_SIZE / sizeof(T);
// Largest row (in elements) the Or path can address with one repeat sequence.
constexpr uint32_t outer_dim_max_loop = MAX_REPEAT_TIME * one_blk_num;
const uint32_t src_inner_offset = src_k * inner_offset;
if (src_inner_offset >= outer_dim_max_loop) {
// Long rows: copy the whole row once per destination row.
for (uint32_t i = 0; i < dst_m; i++) {
AscendC::DataCopy(dst[i * dst_inner_offset], src, dst_inner_offset);
}
AscendC::PipeBarrier<PIPE_V>();
} else {
// Short rows: dst = src | 0, viewed as uint16. One zero block in tmp_buf serves as src1.
AscendC::Duplicate(tmp_buf.ReinterpretCast<uint16_t>(), (uint16_t)0, ONE_BLK_SIZE / sizeof(uint16_t));
AscendC::PipeBarrier<PIPE_V>();
// Number of uint16 lanes per element of T (1 for 2-byte T, 2 for 4-byte T).
int32_t dtype_count = 1;
if constexpr (sizeof(T) == sizeof(float)) {
dtype_count = 2;
}
// Each Or repeat writes 8 blocks (128 uint16 lanes), one block per destination row, so rows go 8 at a time.
uint32_t loop_cnt = dst_m / 8U;
// One repeat per 32-byte block of a row. NOTE(review): truncates, so rows are assumed block-aligned — confirm.
uint32_t rpt = dst_inner_offset / one_blk_num;
AscendC::SetMaskNorm();
AscendC::SetVectorMask<uint16_t, MaskMode::NORMAL>(128);
// Row length in 32-byte blocks (16 uint16 per block); used as the dst block stride so the 8 blocks of a
// repeat land in 8 consecutive rows. The uint8_t cast is safe here because rows are < outer_dim_max_loop.
uint32_t dst_blk_stride = dst_inner_offset * dtype_count / 16U;
// Presumably (dstBlk, src0Blk, src1Blk, dstRep, src0Rep, src1Rep) strides: src blocks are re-read (stride 0)
// and each repeat advances one block along the row — confirm against BinaryRepeatParams.
BinaryRepeatParams binary_params((uint8_t)dst_blk_stride, 0, 0, 1, 1, 0);
uint32_t dst_offset = 0;
for (uint32_t i = 0; i < loop_cnt; i++) {
AscendC::Or<uint16_t, false>(dst[dst_offset].template ReinterpretCast<uint16_t>(),
src.template ReinterpretCast<uint16_t>(),
tmp_buf.template ReinterpretCast<uint16_t>(), MASK_PLACEHOLDER, rpt, binary_params);
dst_offset += dst_inner_offset * 8;
}
// Remaining (< 8) rows: narrow the mask to one block (16 uint16 lanes) per remaining row.
uint32_t loop_tail = dst_m - loop_cnt * 8;
if (loop_tail > 0) {
AscendC::SetMaskNorm();
AscendC::SetVectorMask<uint16_t, MaskMode::NORMAL>(loop_tail * 16);
AscendC::Or<uint16_t, false>(dst[dst_offset].template ReinterpretCast<uint16_t>(),
src.template ReinterpretCast<uint16_t>(),
tmp_buf.template ReinterpretCast<uint16_t>(), MASK_PLACEHOLDER, rpt, binary_params);
}
AscendC::SetMaskNorm();
AscendC::ResetMask();
}
}
/**
 * Expands the first `first_dim` elements of src_local so that each element fills its own 32-byte block
 * in brcb_buf (Brcb consumes BRCB_ONE_SIZE_A elements and emits BRCB_ONE_SIZE_A blocks per repeat).
 * The repeat count is rounded up to whole groups of BRCB_ONE_SIZE_A, so brcb_buf receives
 * ceil(first_dim / 8) * 8 * one_blk_num elements. Calls are split at a per-dtype repeat cap.
 * Ends with a PIPE_V barrier.
 */
template <typename T>
__aicore__ inline void BrcbSrcToOneBlock(const LocalTensor<T> &src_local, const uint32_t first_dim,
                                         uint32_t one_blk_num, LocalTensor<T> &brcb_buf) {
  const uint32_t rpt_cap =
      (sizeof(T) == sizeof(float)) ? BRCB_FLOAT_MAX_REPEATE_TIMES_A : BRCB_HALF_MAX_REPEATE_TIMES_A;
  uint32_t rpt_left = (first_dim + BRCB_ONE_SIZE_A - 1) / BRCB_ONE_SIZE_A;
  const uint32_t src_step = rpt_cap * BRCB_ONE_SIZE_A;
  const uint32_t dst_step = src_step * one_blk_num;
  uint32_t src_pos = 0;
  uint32_t dst_pos = 0;
  while (rpt_left >= rpt_cap) {
    Brcb(brcb_buf[dst_pos], src_local[src_pos], rpt_cap, {1, DEFAULT_REPEAT_STRIDE});
    src_pos += src_step;
    dst_pos += dst_step;
    rpt_left -= rpt_cap;
  }
  if (rpt_left != 0) {
    Brcb(brcb_buf[dst_pos], src_local[src_pos], rpt_left, {1, DEFAULT_REPEAT_STRIDE});
  }
  AscendC::PipeBarrier<PIPE_V>();
}
/**
 * [first_dim, 1] -> [first_dim, block_count] broadcast where block_count is a whole number of 32-byte blocks.
 * Step 1 expands each source element to a full block in tmp_buf (via BrcbSrcToOneBlock).
 * Step 2 copies block r of tmp_buf across destination row r (source block stride 0, one block per repeat).
 * Expects the caller to have put the mask in counter mode; ends with a PIPE_V barrier.
 */
template <typename T>
__aicore__ inline void TwoDimBroadCastLastDimAlign(const LocalTensor<T> &dst_local, const LocalTensor<T> &src_local,
                                                   LocalTensor<T> &tmp_buf, const uint32_t first_dim,
                                                   const uint32_t block_count) {
  constexpr uint32_t elems_per_blk = ONE_BLK_SIZE / sizeof(T);
  BrcbSrcToOneBlock(src_local, first_dim, elems_per_blk, tmp_buf);
  SetVectorMask<T, MaskMode::COUNTER>(block_count);
  const CopyRepeatParams row_params = {1, 0, static_cast<uint16_t>(block_count / elems_per_blk), 1};
  uint32_t rows_left = first_dim;
  uint32_t dst_pos = 0;
  uint32_t src_pos = 0;
  while (rows_left >= MAX_REPEAT_TIMES) {
    Copy<T, false>(dst_local[dst_pos], tmp_buf[src_pos], MASK_PLACEHOLDER, MAX_REPEAT_TIMES, row_params);
    dst_pos += MAX_REPEAT_TIMES * block_count;
    src_pos += MAX_REPEAT_TIMES * elems_per_blk;
    rows_left -= MAX_REPEAT_TIMES;
  }
  if (rows_left != 0) {
    Copy<T, false>(dst_local[dst_pos], tmp_buf[src_pos], MASK_PLACEHOLDER, rows_left, row_params);
  }
  AscendC::PipeBarrier<PIPE_V>();
}
/**
 * [first_dim, 1] -> [first_dim, block_count] broadcast where block_count is NOT a whole number of blocks.
 *
 * 1. Each source element is expanded to a full block at the start of tmp_buf (BrcbSrcToOneBlock).
 * 2. Each block is copied across a row padded up to block_count_align elements, written to
 *    tmp_buf[first_dim * one_blk_num ...].
 * 3. GatherMask compacts the padded rows into dst_local, keeping block_count elements per row.
 *
 * GatherMask takes at most MAX_REPEAT_TIMES repeats per call (one repeat per row), and first_dim can be
 * much larger (it is a whole tmp-buffer pass from BroadcastWithStride). The gather is therefore issued
 * in chunks. Chunk sizes are a multiple of one_blk_num rows so that each chunk's dst offset
 * (rows * block_count elements) remains 32-byte aligned even though block_count itself is not.
 *
 * Expects the caller to have put the mask in counter mode; re-enters counter mode after GatherMask and
 * ends with a PIPE_V barrier.
 */
template <typename T>
__aicore__ inline void TwoDimBroadCastLastDimNotAlign(const LocalTensor<T> &dst_local, const LocalTensor<T> &src_local,
                                                      LocalTensor<T> &tmp_buf, const uint32_t first_dim,
                                                      const uint32_t block_count) {
  constexpr uint32_t one_blk_num = ONE_BLK_SIZE / sizeof(T);
  BrcbSrcToOneBlock(src_local, first_dim, one_blk_num, tmp_buf);
  const uint32_t align_blk_num = (block_count + one_blk_num - 1) / one_blk_num;
  const uint32_t block_count_align = align_blk_num * one_blk_num;
  SetVectorMask<T, MaskMode::COUNTER>(block_count_align);
  const CopyRepeatParams copy_params = {1, 0, (uint16_t)align_blk_num, 1};
  uint32_t copy_counts = first_dim / MAX_REPEAT_TIMES;
  uint32_t dst_offset = 0;
  uint32_t src_offset = 0;
  // Padded rows live right after the Brcb output region.
  auto copy_tmp_buffer = tmp_buf[first_dim * one_blk_num];
  for (uint32_t i = 0; i < copy_counts; i++) {
    Copy<T, false>(copy_tmp_buffer[dst_offset], tmp_buf[src_offset], MASK_PLACEHOLDER, MAX_REPEAT_TIMES, copy_params);
    dst_offset += MAX_REPEAT_TIMES * block_count_align;
    src_offset += MAX_REPEAT_TIMES * one_blk_num;
  }
  uint32_t tail_rpt = first_dim % MAX_REPEAT_TIMES;
  if (tail_rpt != 0) {
    Copy<T, false>(copy_tmp_buffer[dst_offset], tmp_buf[src_offset], MASK_PLACEHOLDER, tail_rpt, copy_params);
  }
  AscendC::PipeBarrier<PIPE_V>();
  // Largest row count per GatherMask call that respects the repeat cap and keeps dst offsets block-aligned.
  const uint32_t gather_max_rows = static_cast<uint32_t>(MAX_REPEAT_TIMES) / one_blk_num * one_blk_num;
  uint64_t rsvd_cnt = 0;
  uint32_t rows_done = 0;
  while (rows_done < first_dim) {
    const uint32_t rows_left = first_dim - rows_done;
    const uint32_t rows = (rows_left < gather_max_rows) ? rows_left : gather_max_rows;
    const GatherMaskParams gather_params = {1, (uint16_t)rows, (uint16_t)align_blk_num, 0};
    GatherMask(dst_local[rows_done * block_count], copy_tmp_buffer[rows_done * block_count_align],
               GATHER_MASK_PATTERN_A, true, block_count, gather_params, rsvd_cnt);
    rows_done += rows;
  }
  SetMaskCount();
  AscendC::PipeBarrier<PIPE_V>();
}
/**
 * Splits first_dim rows into tmp-buffer-sized passes for the block-aligned broadcast path.
 *
 * Every group of one-block's-worth of rows needs elems_per_blk staging slots plus elems_per_blk
 * Brcb-expanded blocks (elems_per_blk * elems_per_blk elements), as laid out by BroadcastWithStride.
 *
 * @param first_dim       total number of rows to broadcast.
 * @param block_count     row length (unused on this path; kept for signature symmetry).
 * @param tmp_buf_size    temporary buffer capacity in elements of T.
 * @param one_repeat_size [out] rows handled per full pass (a multiple of elems_per_blk).
 * @param range_m         [out] number of full passes.
 * @param tail_m          [out] rows left for a final partial pass.
 */
template <typename T>
__aicore__ inline void GetBrcAlignLoopNumbers(const uint32_t first_dim, const uint32_t block_count,
                                              const uint32_t tmp_buf_size, uint32_t &one_repeat_size,
                                              uint32_t &range_m, uint32_t &tail_m) {
  constexpr uint32_t elems_per_blk = ONE_BLK_SIZE / sizeof(T);
  constexpr uint32_t min_tmp_buf_size = elems_per_blk * (elems_per_blk + 1);
  ASCENDC_ASSERT((tmp_buf_size >= min_tmp_buf_size), {
    KERNEL_LOG(KERNEL_ERROR,
               "tmp_buf_size can't smaller than min_tmp_buf_size, tmp_buf_size is %u, min_tmp_buf_size is %u!",
               tmp_buf_size, min_tmp_buf_size);
  });
  const uint32_t row_groups = tmp_buf_size / min_tmp_buf_size;
  one_repeat_size = row_groups * elems_per_blk;
  range_m = first_dim / one_repeat_size;
  tail_m = first_dim % one_repeat_size;
}
/**
 * Splits first_dim rows into tmp-buffer-sized passes for the non-aligned broadcast path.
 *
 * Per group of elems_per_blk rows the path needs: elems_per_blk staging slots, elems_per_blk Brcb'd blocks,
 * and elems_per_blk rows padded up to a whole number of blocks (for the GatherMask compaction).
 *
 * @param first_dim       total number of rows to broadcast.
 * @param block_count     destination row length in elements (not block-aligned on this path).
 * @param tmp_buf_size    temporary buffer capacity in elements of T.
 * @param one_repeat_size [out] rows handled per full pass (a multiple of elems_per_blk).
 * @param range_m         [out] number of full passes.
 * @param tail_m          [out] rows left for a final partial pass.
 */
template <typename T>
__aicore__ inline void GetBrcNotAlignLoopNumbers(const uint32_t first_dim, const uint32_t block_count,
                                                 const uint32_t tmp_buf_size, uint32_t &one_repeat_size,
                                                 uint32_t &range_m, uint32_t &tail_m) {
  constexpr uint32_t elems_per_blk = ONE_BLK_SIZE / sizeof(T);
  const uint32_t padded_row = (block_count + elems_per_blk - 1) / elems_per_blk * elems_per_blk;
  const uint32_t min_tmp_buf_size = elems_per_blk * (elems_per_blk + 1) + elems_per_blk * padded_row;
  ASCENDC_ASSERT((tmp_buf_size >= min_tmp_buf_size), {
    KERNEL_LOG(KERNEL_ERROR,
               "tmp_buf_size can't smaller than min_tmp_buf_size, tmp_buf_size is %u, min_tmp_buf_size is %u!",
               tmp_buf_size, min_tmp_buf_size);
  });
  const uint32_t row_groups = tmp_buf_size / min_tmp_buf_size;
  one_repeat_size = row_groups * elems_per_blk;
  range_m = first_dim / one_repeat_size;
  tail_m = first_dim % one_repeat_size;
}
/**
 * Scalar gather: tmp_src[i] = src[i * offset] for i in [0, cal_cnt).
 *
 * Waits on a V_S event before the scalar reads and on an S_V event after the scalar writes
 * (presumably so pending vector writes to src are visible and later vector reads of tmp_src
 * see the gathered values).
 */
template <typename T>
inline __aicore__ void GetSrcTensorWithoutStride(const LocalTensor<T> &tmp_src, const LocalTensor<T> &src,
                                                 const uint32_t cal_cnt, const uint32_t offset) {
  const event_t vec_done = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
  const event_t scalar_done = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
  AscendC::SetFlag<HardEvent::V_S>(vec_done);
  AscendC::WaitFlag<HardEvent::V_S>(vec_done);
  for (uint32_t idx = 0; idx < cal_cnt; ++idx) {
    const uint64_t strided_idx = static_cast<uint64_t>(idx) * offset;
    tmp_src.SetValue(idx, src.GetValue(strided_idx));
  }
  AscendC::SetFlag<HardEvent::S_V>(scalar_done);
  AscendC::WaitFlag<HardEvent::S_V>(scalar_done);
}
/**
 * [dst_m, 1] -> [dst_m, dst_k] broadcast where consecutive source elements are src_k apart.
 *
 * tmp_buf is split into a compact staging area (rows_per_pass elements) and a work area for the
 * per-pass broadcast. Each pass gathers the strided source values with scalar reads, then dispatches
 * to the aligned or non-aligned two-dimensional broadcast depending on whether a dst row is a whole
 * number of 32-byte blocks. src_m, src_z and dst_z are not read.
 * Expects counter mask mode to be set by the caller.
 */
template <typename T>
inline __aicore__ void BroadcastWithStride(const LocalTensor<T> &dst, const LocalTensor<T> &src, const uint32_t src_m,
                                           const uint32_t src_k, const uint32_t src_z, const uint32_t dst_m,
                                           const uint32_t dst_k, const uint32_t dst_z, LocalTensor<T> &tmp_buf) {
  const bool rows_block_aligned = (dst_k * sizeof(T) % ONE_BLK_SIZE == 0);
  uint32_t rows_per_pass = 0;
  uint32_t full_passes = 0;
  uint32_t leftover_rows = 0;
  if (rows_block_aligned) {
    GetBrcAlignLoopNumbers<T>(dst_m, dst_k, tmp_buf.GetSize(), rows_per_pass, full_passes, leftover_rows);
  } else {
    GetBrcNotAlignLoopNumbers<T>(dst_m, dst_k, tmp_buf.GetSize(), rows_per_pass, full_passes, leftover_rows);
  }
  LocalTensor<T> gathered = tmp_buf;
  gathered.SetSize(rows_per_pass);
  LocalTensor<T> work = tmp_buf[rows_per_pass];
  uint64_t dst_pos = 0;
  uint64_t src_pos = 0;
  // Passes [0, full_passes) cover rows_per_pass rows each; pass full_passes covers the leftover (if any).
  for (uint32_t pass = 0; pass <= full_passes; ++pass) {
    const uint32_t rows = (pass < full_passes) ? rows_per_pass : leftover_rows;
    if (rows == 0) {
      continue;
    }
    GetSrcTensorWithoutStride(gathered, src[src_pos], rows, src_k);
    if (rows_block_aligned) {
      TwoDimBroadCastLastDimAlign(dst[dst_pos], gathered, work, rows, dst_k);
    } else {
      TwoDimBroadCastLastDimNotAlign(dst[dst_pos], gathered, work, rows, dst_k);
    }
    dst_pos += rows * dst_k;
    src_pos += rows * src_k;
  }
}
/**
 * [m, 1, z] -> [m, k, z] broadcast for small z using counter-mode Copy: repeat r reads source block r
 * (source repeat stride = 1 block, block stride 0) and writes it across destination row r of dst_k * dst_z
 * elements.
 *
 * Because each repeat advances the source by one block, every chunk of MAX_REPEAT_TIME repeats must
 * start MAX_REPEAT_TIME blocks further into src. Otherwise rows >= MAX_REPEAT_TIME would re-read the
 * first rows' data. NOTE(review): the one-block-per-row source layout implies z is padded to a block — confirm.
 *
 * src_m, src_k, src_z, tmp_buf and last_dim_stride are not read. The mask is back in normal mode and reset on return.
 */
template <typename T>
inline __aicore__ void BroadcastMiddleDimWithCopy(const LocalTensor<T> &dst, const LocalTensor<T> &src,
                                                  const uint32_t src_m, const uint32_t src_k, const uint32_t src_z,
                                                  const uint32_t dst_m, const uint32_t dst_k, const uint32_t dst_z,
                                                  LocalTensor<uint8_t> &tmp_buf, const uint32_t last_dim_stride = 1) {
  const uint32_t max_rpt_cnt = dst_m / MAX_REPEAT_TIME;
  const uint32_t dst_inner_offset = dst_k * dst_z;
  AscendC::SetMaskCount();
  AscendC::SetVectorMask<T, MaskMode::COUNTER>(dst_inner_offset);
  const uint16_t dst_rpt_stride = dst_inner_offset * sizeof(T) / ONE_BLK_SIZE;
  // Elements consumed per full chunk on each side.
  const uint32_t max_rpt_dst_num = dst_inner_offset * MAX_REPEAT_TIME;
  const uint32_t max_rpt_src_num = MAX_REPEAT_TIME * (ONE_BLK_SIZE / sizeof(T));
  AscendC::CopyRepeatParams repeat_params{1, 0, dst_rpt_stride, 1};
  uint32_t dst_offset = 0;
  uint32_t src_offset = 0;
  for (uint32_t idx = 0; idx < max_rpt_cnt; idx++) {
    AscendC::Copy<T, false>(dst[dst_offset], src[src_offset], MASK_PLACEHOLDER, MAX_REPEAT_TIME, repeat_params);
    dst_offset += max_rpt_dst_num;
    src_offset += max_rpt_src_num;
  }
  uint32_t tail_rpt_times = dst_m - max_rpt_cnt * MAX_REPEAT_TIME;
  if (tail_rpt_times != 0) {
    AscendC::Copy<T, false>(dst[dst_offset], src[src_offset], MASK_PLACEHOLDER, tail_rpt_times, repeat_params);
  }
  AscendC::SetMaskNorm();
  AscendC::ResetMask();
}
/**
 * Shape dispatcher for 2- and 4-byte element types. Shapes are [m, k, z]; z == 0 means "no inner axis".
 *
 *  - [1, k, z] -> [m, k, z]          : BroadcastFirstDim
 *  - [m, 1, z] -> [m, k, z], z != 0  : BroadcastMiddleDimWithCopy (z fits one block) or per-row BroadcastMiddleDim
 *  - [m, 1]    -> [m, k] (z == 0)    : strided gather path when last_dim_stride > 1, else AscendC::Broadcast
 *  - [1, 1]    -> [m, k]             : AscendC::Broadcast along axis 1
 * Any other combination trips ASSERT.
 */
template <typename T>
inline __aicore__ void BroadcastCommon(const LocalTensor<T> &dst, const LocalTensor<T> &src, const uint32_t src_m,
                                       const uint32_t src_k, const uint32_t src_z, const uint32_t dst_m,
                                       const uint32_t dst_k, const uint32_t dst_z, LocalTensor<uint8_t> &tmp_buf,
                                       const uint32_t last_dim_stride = 1) {
  const bool inner_axis_matches = (src_z == dst_z);
  if ((src_m == 1) && (src_k == dst_k) && inner_axis_matches) {
    BroadcastFirstDim(dst, src, src_m, src_k, src_z, dst_m, dst_k, dst_z, tmp_buf);
    return;
  }
  if ((src_k == 1) && (src_m == dst_m) && inner_axis_matches) {
    if (src_z != 0) {
      if (dst_z <= KernelUtils::BlkSize<T>()) {
        BroadcastMiddleDimWithCopy(dst, src, src_m, src_k, src_z, dst_m, dst_k, dst_z, tmp_buf);
      } else {
        const uint32_t dst_row_elems = dst_k * src_z;
        for (uint32_t row = 0; row < src_m; ++row) {
          BroadcastMiddleDim(dst[row * dst_row_elems], src[row * src_z], src_k, src_z, 0, dst_k, dst_z, 0, tmp_buf);
        }
      }
    } else if (last_dim_stride > 1) {
      LocalTensor<T> typed_tmp = tmp_buf.template ReinterpretCast<T>();
      AscendC::SetMaskCount();
      BroadcastWithStride(dst, src, src_m, last_dim_stride, src_z, dst_m, dst_k, dst_z, typed_tmp);
      AscendC::SetMaskNorm();
      AscendC::ResetMask();
    } else {
      const uint32_t dst_shape[2]{dst_m, dst_k};
      const uint32_t src_shape[2]{src_m, src_k};
      AscendC::Broadcast<T, 2, 1>(dst, src, dst_shape, src_shape, tmp_buf);
    }
    return;
  }
  if ((src_k == 1) && (src_m == 1) && (src_z == 0) && (dst_z == 0)) {
    // NOTE(review): only reached when dst_m != 1, yet this broadcasts along axis 1 only; confirm that
    // AscendC::Broadcast<T, 2, 1> also expands a size-1 axis 0 (BroadcastInt64 uses Duplicate here).
    const uint32_t dst_shape[2]{dst_m, dst_k};
    const uint32_t src_shape[2]{src_m, src_k};
    AscendC::Broadcast<T, 2, 1>(dst, src, dst_shape, src_shape, tmp_buf);
    return;
  }
  ASSERT(false && "Broadcast size not support.");
}
/**
 * Shape dispatcher for 1-byte element types (int8_t / uint8_t). Shapes are [m, k, z]; z == 0 means
 * "no inner axis". src_z must equal dst_z.
 *
 *  - [1, k, z] -> [m, k, z] : rows reinterpreted as uint16 and replicated with BroadcastFirstDim
 *    (NOTE(review): assumes dst_k * z is even — an odd byte count would be truncated; confirm).
 *  - [m, 1]    -> [m, k]    : AscendC::Broadcast along axis 1.
 *  - [m, 1, z] -> [m, k, z] : per outer row, AscendC::Broadcast along axis 0 of a [k, z] slice.
 *  - [1, 1]    -> [m, k]    : AscendC::Broadcast along axis 1.
 * last_dim_stride is not read on any path.
 *
 * Changes vs. the previous version: the unused locals `dup_tmp` and `scalar` were removed, the loop
 * index is unsigned like src_m, and the loop-invariant shapes are computed once.
 */
template <typename T>
inline __aicore__ void BroadcastWithCast(const LocalTensor<T> &dst, const LocalTensor<T> &src, const uint32_t src_m,
                                         const uint32_t src_k, const uint32_t src_z, const uint32_t dst_m,
                                         const uint32_t dst_k, const uint32_t dst_z, LocalTensor<uint8_t> &tmp_buf,
                                         const uint32_t last_dim_stride = 1) {
  if (src_z != dst_z) {
    ASSERT(false && "Broadcast inner_axis mismatch is not supported.");
  }
  const uint32_t inner_offset = src_z == 0 ? 1 : src_z;
  if (src_m == 1 && src_k == dst_k) {
    // Number of T elements packed into one uint16.
    constexpr uint32_t kRatio = sizeof(uint16_t) / sizeof(T);
    const uint32_t dst_inner_offset = dst_k * inner_offset / kRatio;
    LocalTensor<uint16_t> src_tmp = src.template ReinterpretCast<uint16_t>();
    LocalTensor<uint16_t> dst_tmp = dst.template ReinterpretCast<uint16_t>();
    BroadcastFirstDim(dst_tmp, src_tmp, src_m, dst_inner_offset, 0, dst_m, dst_inner_offset, 0, tmp_buf);
  } else if (src_k == 1 && src_m == dst_m) {
    if (src_z == 0) {
      const uint32_t dst_shape[2]{dst_m, dst_k};
      const uint32_t src_shape[2]{src_m, src_k};
      AscendC::Broadcast<T, 2, 1>(dst, src, dst_shape, src_shape, tmp_buf);
    } else {
      const uint32_t dst_shape[2]{dst_k, dst_z};
      const uint32_t src_shape[2]{src_k, src_z};
      const uint32_t dst_inner_offset = dst_k * inner_offset;
      for (uint32_t i = 0; i < src_m; i++) {
        AscendC::Broadcast<T, 2, 0>(dst[i * dst_inner_offset], src[i * inner_offset], dst_shape, src_shape, tmp_buf);
      }
    }
  } else if (src_k == 1 && src_m == 1 && src_z == 0 && dst_z == 0) {
    const uint32_t src_shape[2]{src_m, src_k};
    const uint32_t dst_shape[2]{dst_m, dst_k};
    AscendC::Broadcast<T, 2, 1>(dst, src, dst_shape, src_shape, tmp_buf);
  } else {
    ASSERT(false && "Broadcast size not support.");
  }
}
/**
 * [src_m, 1] -> [src_m, dst_k] broadcast of 64-bit values handled as pairs of int32 words.
 *
 * Phase 1 (scalar + normal mask): for each of the src_m values, its two int32 words are read with
 * GetValue and duplicated into one 32-byte block of calc_buf. The first word goes to lanes 0,2,4,6 (mask 0x55)
 * and the second word to lanes 1,3,5,7 (mask 0xaa), so the block holds the value repeated.
 * Phase 2 (counter mask): block r of calc_buf is copied across dst row r (dst_k * 2 int32 words) in
 * chunks of MAX_REPEAT_TIMES rows.
 *
 * src_k, src_z, dst_m, dst_z and last_dim_stride are not read. Leaves the mask in counter mode;
 * the caller (BroadcastInt64) restores normal mode.
 */
inline __aicore__ void BroadcastInt64LastDim(const LocalTensor<int32_t> &dst_int32,
                                             const LocalTensor<int32_t> &src_int32, const uint32_t src_m,
                                             const uint32_t src_k, const uint32_t src_z, const uint32_t dst_m,
                                             const uint32_t dst_k, const uint32_t dst_z, LocalTensor<int32_t> &calc_buf,
                                             const uint32_t last_dim_stride = 1) {
  auto vec_to_scalar = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
  auto scalar_to_vec = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
  AscendC::SetMaskNorm();
  constexpr uint32_t kWordsPerBlk = ONE_BLK_SIZE / sizeof(int32_t);
  constexpr uint32_t kWordsPerValue = sizeof(int64_t) / sizeof(int32_t);
  AscendC::SetFlag<HardEvent::V_S>(vec_to_scalar);
  AscendC::WaitFlag<HardEvent::V_S>(vec_to_scalar);
  uint32_t word_pos = 0;
  for (uint32_t value_idx = 0; value_idx < src_m; ++value_idx) {
    const uint32_t blk_pos = value_idx * kWordsPerBlk;
    for (uint32_t word = 0; word < 2; ++word) {
      const int32_t word_val = src_int32.GetValue(word_pos++);
      AscendC::SetFlag<HardEvent::S_V>(scalar_to_vec);
      AscendC::WaitFlag<HardEvent::S_V>(scalar_to_vec);
      AscendC::SetVectorMask<int32_t, MaskMode::NORMAL>(0, (word == 0) ? 0x55 : 0xaa);
      AscendC::Duplicate<int32_t, false>(calc_buf[blk_pos], word_val, MASK_PLACEHOLDER, 1, 1, 0);
    }
  }
  AscendC::PipeBarrier<PIPE_V>();
  AscendC::SetMaskCount();
  const uint32_t row_words = dst_k * kWordsPerValue;
  AscendC::SetVectorMask<int32_t, MaskMode::COUNTER>(row_words);
  AscendC::CopyRepeatParams row_params{1, 0, uint16_t(dst_k * sizeof(int64_t) / ONE_BLK_SIZE), 1};
  uint32_t rows_left = src_m;
  uint32_t dst_pos = 0;
  uint32_t buf_pos = 0;
  while (rows_left >= MAX_REPEAT_TIMES) {
    AscendC::Copy<int32_t, false>(dst_int32[dst_pos], calc_buf[buf_pos], MASK_PLACEHOLDER, MAX_REPEAT_TIMES,
                                  row_params);
    dst_pos += MAX_REPEAT_TIMES * row_words;
    buf_pos += MAX_REPEAT_TIMES * kWordsPerBlk;
    rows_left -= MAX_REPEAT_TIMES;
  }
  if (rows_left > 0) {
    AscendC::Copy<int32_t, false>(dst_int32[dst_pos], calc_buf[buf_pos], MASK_PLACEHOLDER, rows_left, row_params);
  }
}
/**
 * Shape dispatcher for 64-bit element types, operating on int32 reinterpretations of dst/src.
 * Shapes are [m, k, z]; z == 0 means "no inner axis". src_z must equal dst_z.
 *
 *  - [1, k, z] -> [m, k, z] : BroadcastFirstDim over int32 words.
 *  - [m, 1]    -> [m, k]    : BroadcastInt64LastDim, in batches sized by tmp_buf (one 32-byte block per value).
 *  - [m, 1, z] -> [m, k, z] : per outer row, AscendC::Broadcast along axis 0 of a [k, 2z] int32 slice.
 *  - [1, 1]    -> [m, k]    : scalar read of src[0] followed by Duplicate.
 *
 * The scalar branch now waits on a V_S event before GetValue. This matches the other scalar reads in
 * this file (GetSrcTensorWithoutStride, BroadcastInt64LastDim), so a vector-produced src is complete
 * before the scalar unit reads it.
 */
template <typename T>
inline __aicore__ void BroadcastInt64(const LocalTensor<T> &dst, const LocalTensor<T> &src, const uint32_t src_m,
                                      const uint32_t src_k, const uint32_t src_z, const uint32_t dst_m,
                                      const uint32_t dst_k, const uint32_t dst_z, LocalTensor<uint8_t> &tmp_buf,
                                      const uint32_t last_dim_stride = 1) {
  if (src_z != dst_z) {
    ASSERT(false && "Broadcast inner_axis mismatch is not supported.");
  }
  LocalTensor<int32_t> dst_int32 = dst.template ReinterpretCast<int32_t>();
  LocalTensor<int32_t> src_int32 = src.template ReinterpretCast<int32_t>();
  // int32 words per element of T.
  constexpr uint32_t kRatio = sizeof(T) / sizeof(int32_t);
  uint32_t inner_offset = src_z == 0 ? 1 : src_z;
  uint32_t dst_inner_offset = dst_k * inner_offset * kRatio;
  if (src_m == 1 && src_k == dst_k) {
    BroadcastFirstDim(dst_int32, src_int32, src_m, dst_inner_offset, 0, dst_m, dst_inner_offset, 0, tmp_buf);
  } else if (src_k == 1 && src_m == dst_m) {
    if (src_z == 0) {
      // BroadcastInt64LastDim uses one 32-byte block of tmp_buf per source value.
      const uint32_t buf_max_calc_num = tmp_buf.GetSize() * sizeof(uint8_t) / ONE_BLK_SIZE;
      LocalTensor<int32_t> calc_buf = tmp_buf.template ReinterpretCast<int32_t>();
      uint32_t buf_loop_cnt = src_m * src_k / buf_max_calc_num;
      uint32_t src_offset = 0;
      uint32_t dst_offset = 0;
      const uint32_t dst_stride = buf_max_calc_num * dst_inner_offset;
      const uint32_t src_stride = buf_max_calc_num * kRatio;
      for (uint32_t loop_idx = 0; loop_idx < buf_loop_cnt; loop_idx++) {
        BroadcastInt64LastDim(dst_int32[dst_offset], src_int32[src_offset], buf_max_calc_num, src_k, src_z,
                              buf_max_calc_num, dst_k, dst_z, calc_buf, last_dim_stride);
        dst_offset += dst_stride;
        src_offset += src_stride;
      }
      const uint32_t tail_calc_num = src_m - buf_loop_cnt * buf_max_calc_num;
      if (tail_calc_num != 0) {
        BroadcastInt64LastDim(dst_int32[dst_offset], src_int32[src_offset], tail_calc_num, src_k, src_z, tail_calc_num,
                              dst_k, dst_z, calc_buf, last_dim_stride);
      }
      // BroadcastInt64LastDim leaves counter mode on; restore normal mode here.
      AscendC::SetMaskNorm();
      AscendC::ResetMask();
    } else {
      uint32_t dst_offset = 0;
      uint32_t src_offset = 0;
      const uint32_t src_inner_offset = inner_offset * kRatio;
      const uint32_t dst_shape[2]{dst_k, dst_z * kRatio};
      const uint32_t src_shape[2]{src_k, src_z * kRatio};
      for (uint32_t i = 0; i < src_m; i++) {
        AscendC::Broadcast<int32_t, 2, 0>(dst_int32[dst_offset], src_int32[src_offset], dst_shape, src_shape, tmp_buf);
        dst_offset += dst_inner_offset;
        src_offset += src_inner_offset;
      }
    }
  } else if (src_k == 1 && src_m == 1 && src_z == 0 && dst_z == 0) {
    auto event_id_v_to_s = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
    AscendC::SetFlag<HardEvent::V_S>(event_id_v_to_s);
    AscendC::WaitFlag<HardEvent::V_S>(event_id_v_to_s);
    const T scalar_value = src.GetValue(0);
    auto event_id = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
    AscendC::SetFlag<HardEvent::S_V>(event_id);
    AscendC::WaitFlag<HardEvent::S_V>(event_id);
    Duplicate(dst, scalar_value, dst_m * dst_k, tmp_buf);
  } else {
    ASSERT(false && "Broadcast size not support.");
  }
}
/**
 * Broadcasts src of shape [src_m, src_k, src_z] to dst of shape [dst_m, dst_k, dst_z]
 * (z == 0 means "no inner axis"), routing by element type:
 *  - 64-bit integers          -> BroadcastInt64
 *  - 8-bit integers           -> BroadcastWithCast
 *  - 16/32-bit ints, half/float -> BroadcastCommon
 * Other element types trip ASSERT.
 */
template <typename T>
inline __aicore__ void Broadcast(const LocalTensor<T> &dst, const LocalTensor<T> &src, const uint32_t src_m,
                                 const uint32_t src_k, const uint32_t src_z, const uint32_t dst_m, const uint32_t dst_k,
                                 const uint32_t dst_z, LocalTensor<uint8_t> &tmp_buf,
                                 const uint32_t last_dim_stride = 1) {
  constexpr bool kIs64Bit = AscendC::SupportType<T, int64_t, uint64_t>();
  constexpr bool kIs8Bit = AscendC::SupportType<T, uint8_t, int8_t>();
  constexpr bool kIsNative = AscendC::SupportType<T, int16_t, uint16_t, half, float, int32_t, uint32_t>();
  if constexpr (kIs64Bit) {
    BroadcastInt64(dst, src, src_m, src_k, src_z, dst_m, dst_k, dst_z, tmp_buf, last_dim_stride);
  } else if constexpr (kIs8Bit) {
    BroadcastWithCast(dst, src, src_m, src_k, src_z, dst_m, dst_k, dst_z, tmp_buf, last_dim_stride);
  } else if constexpr (kIsNative) {
    BroadcastCommon(dst, src, src_m, src_k, src_z, dst_m, dst_k, dst_z, tmp_buf, last_dim_stride);
  } else {
    ASSERT(false && "Broadcast type not support.");
  }
}
/**
 * Four-axis broadcast [src_m, src_n, src_k, src_z] -> [dst_m, dst_n, dst_k, dst_z]
 * (z == 0 means "no inner axis", as in the three-axis overload).
 *
 *  - [1, n, 1, z] -> [m, n, k, z]: first fill one [n, k, z] slab, then replicate it dst_m - 1 more times.
 *    The slab size now treats z == 0 as 1 (matching the rest of this file); previously it became 0 and
 *    the replication did nothing. The replication is also skipped for dst_m <= 1, because dst_m - 1 would
 *    otherwise underflow to 0xFFFFFFFF when dst_m == 0.
 *  - [m, 1, k, 1] -> [m, n, k, z]: broadcast z into a scratch [m*k, z] buffer carved from tmp_buf,
 *    then broadcast that across n.
 * Other shape combinations trip ASSERT.
 */
template <typename T>
inline __aicore__ void Broadcast(const LocalTensor<T> &dst, const LocalTensor<T> &src, const uint32_t src_m,
                                 const uint32_t src_n, const uint32_t src_k, const uint32_t src_z, const uint32_t dst_m,
                                 const uint32_t dst_n, const uint32_t dst_k, const uint32_t dst_z,
                                 LocalTensor<uint8_t> &tmp_buf, const uint32_t last_dim_stride = 1) {
  if (src_m == 1 && src_k == 1 && src_n == dst_n && src_z == dst_z) {
    Broadcast(dst, src, src_n, src_k, src_z, dst_n, dst_k, dst_z, tmp_buf);
    AscendC::PipeBarrier<PIPE_V>();
    if (dst_m > 1) {
      const uint32_t offset = dst_n * dst_k * (dst_z == 0 ? 1 : dst_z);
      Broadcast(dst[offset], dst, src_m, offset, 0, dst_m - 1, offset, 0, tmp_buf);
    }
  } else if (src_n == 1 && src_z == 1) {
    const uint32_t offset = dst_m * src_n * dst_k;
    ASSERT((tmp_buf.GetSize() > (offset * dst_z + dst_m * dst_n) * sizeof(T)) && "tmp_buf size is not enough.");
    LocalTensor<T> inter_buf = tmp_buf.template ReinterpretCast<T>();
    inter_buf.SetSize(offset * dst_z);
    LocalTensor<uint8_t> left_buf = tmp_buf[offset * dst_z * sizeof(T)];
    Broadcast(inter_buf, src, offset, 1, 0, offset, dst_z, 0, left_buf, last_dim_stride);
    AscendC::PipeBarrier<PIPE_V>();
    Broadcast(dst, inter_buf, src_m, 1, dst_k * dst_z, dst_m, dst_n, dst_k * dst_z, left_buf);
  } else {
    ASSERT(false && "Broadcast type not support.");
  }
}
#endif