* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file broadcast_c310_impl.h
* \brief
*/
#if !defined(__ASCENDC_INCLUDE_INTERNAL_HEADERS__)
#pragma message("impl/adv_api/detail/pad/broadcast/broadcast_c310_impl.h is an internal header file and must not be used directly. Functions or variables defined in this file may be removed in the future. Please use \"#include \"adv_api/pad/broadcast.h\"\" and use public functions or variables defined in interface headers files.")
#define __ASCENDC_INCLUDE_INTERNAL_HEADERS__
#define __UNDEF_ASCENDC_INCLUDE_INTERNAL_HEADERS_PAD_BROADCAST_BROADCAST_C310_IMPL_H__
#endif
#ifndef IMPL_PAD_BROADCAST_BROADCAST_C310_IMPL_H
#define IMPL_PAD_BROADCAST_BROADCAST_C310_IMPL_H
#include "kernel_basic_intf.h"
#include "kernel_tensor.h"
#include "broadcast_gather_c310_impl.h"
namespace AscendC {
namespace BroadcastInternal {
// Loads srcUb into dstReg using the E2B load distribution matching T's width.
// 2-byte types use the B16 variant; every other width uses the B32 variant.
// NOTE(review): callers advance srcUb by DEFAULT_BLK_NUM elements per output
// vector, which suggests each loaded element is expanded into one data block.
// Confirm this against the Reg::LoadDist documentation.
template <typename T>
__simd_callee__ inline void E2bLoad(Reg::RegTensor<T> &dstReg, __ubuf__ T *srcUb)
{
    constexpr bool isHalfWord = (sizeof(T) == 2);
    if constexpr (!isHalfWord) {
        Reg::LoadAlign<T, Reg::LoadDist::DIST_E2B_B32>(dstReg, srcUb);
    } else {
        Reg::LoadAlign<T, Reg::LoadDist::DIST_E2B_B16>(dstReg, srcUb);
    }
}
// Broadcast-loads from srcUb into dstReg using the BRC load distribution for T's width.
// 4-byte types use B32, 2-byte types use B16, and any other width uses B8.
// NOTE(review): BrcDuplicate stores this register across its whole destination,
// which implies the element at srcUb is replicated into every lane.
// Confirm this against the Reg::LoadDist documentation.
template <typename T>
__simd_callee__ inline void BrcLoad(Reg::RegTensor<T> &dstReg, __ubuf__ T *srcUb)
{
    constexpr auto elemBytes = sizeof(T);
    if constexpr (elemBytes == 4) {
        Reg::LoadAlign<T, Reg::LoadDist::DIST_BRC_B32>(dstReg, srcUb);
    } else if constexpr (elemBytes == 2) {
        Reg::LoadAlign<T, Reg::LoadDist::DIST_BRC_B16>(dstReg, srcUb);
    } else {
        Reg::LoadAlign<T, Reg::LoadDist::DIST_BRC_B8>(dstReg, srcUb);
    }
}
// Fills dstUb[0, dstSize) with the element at srcUb.
// The element is broadcast-loaded once (BrcLoad) and then written one vector at a time
// with aligned stores. The mask of the final store is trimmed to the remaining count.
// This assumes UpdateMask yields min(remaining, VL) lanes and decrements its argument.
template <typename T>
__simd_vf__ inline void BrcDuplicate(__ubuf__ T *dstUb, __ubuf__ T *srcUb, uint32_t dstSize)
{
    // Elements of T per vector register.
    constexpr uint16_t VF_LEN = GetVecLen() / sizeof(T);
    Reg::RegTensor<T> fillReg;
    BrcLoad<T>(fillReg, srcUb);
    uint16_t storeCount = CeilDivision(dstSize, VF_LEN);
    uint32_t remaining = dstSize;
    Reg::MaskReg storeMask;
    for (uint16_t blk = 0; blk < storeCount; ++blk) {
        storeMask = Reg::UpdateMask<T>(remaining);
        Reg::StoreAlign(dstUb + blk * VF_LEN, fillReg, storeMask);
    }
}
// Writes one vector register of gather indices to indexUb: (offset + lane) / size1.
// The offset + lane form assumes Reg::Arange produces offset, offset + 1, ...
// The two-dim last-axis wrapper uses this so that each source element covers
// size1 consecutive output elements.
template <typename T>
__simd_vf__ inline void GenLastGatherIndex(__ubuf__ T *indexUb, uint32_t size1, uint32_t offset)
{
    Reg::MaskReg allLanes = Reg::CreateMask<T>();
    // Holds size1 in every lane, then is overwritten with the quotient.
    Reg::RegTensor<T> quotient;
    // Running lane index starting at offset.
    Reg::RegTensor<T> laneIdx;
    Reg::Duplicate(quotient, (T)size1, allLanes);
    Reg::Arange(laneIdx, (T)offset);
    Reg::Div(quotient, laneIdx, quotient, allLanes);
    Reg::StoreAlign(indexUb, quotient, allLanes);
}
// Writes one vector register of gather indices to indexUb: (offset + lane) % size1.
// The modulo is computed as t - (t / size1) * size1.
// The offset + lane form assumes Reg::Arange produces offset, offset + 1, ...
template <typename T>
__simd_vf__ inline void GenNlastGatherIndex(__ubuf__ T *indexUb, uint32_t size1, uint32_t offset)
{
    Reg::MaskReg allLanes = Reg::CreateMask<T>();
    // Holds size1 in every lane, then is overwritten with the remainder.
    Reg::RegTensor<T> divisor;
    // t = offset + lane
    Reg::RegTensor<T> laneIdx;
    Reg::RegTensor<T> scratch;
    Reg::Duplicate(divisor, (T)size1, allLanes);
    Reg::Arange(laneIdx, (T)offset);
    // scratch = t / size1
    Reg::Div(scratch, laneIdx, divisor, allLanes);
    // scratch = (t / size1) * size1
    Reg::Mul(scratch, divisor, scratch, allLanes);
    // divisor = t % size1
    Reg::Sub(divisor, laneIdx, scratch, allLanes);
    Reg::StoreAlign(indexUb, divisor, allLanes);
}
// Last-axis broadcast of a 2-D output [size0, size1] that fits in one vector register.
// BrcLastWrapperForTwoDim calls it when size0 * size1 < VF_LEN.
// It gathers srcUb through the index table at indexUb and writes the
// size0 * size1 output elements to dstUb with one masked aligned store.
// indexUb holds indices prepared by the caller; the two-dim wrapper uses GenLastGatherIndex.
// For 8-bit T the gather works on IndexT-wide lanes:
//   - two index registers are used, the second starting at indexUb + VF_LEN_HALF;
//   - the two gathered registers are recombined with DeInterleave.
template <typename T, typename IndexT>
__simd_vf__ inline void BrcLastGatherOne(
__ubuf__ T *dstUb, __ubuf__ T *srcUb, __ubuf__ IndexT *indexUb, uint16_t size0, uint16_t size1)
{
// Half a vector register, counted in T elements.
// Used as an IndexT offset, it selects the second index register for 8-bit T.
// This equals one full register only if IndexT is 16-bit; TODO confirm.
constexpr uint32_t VF_LEN_HALF = GetVecLen() / 2 / sizeof(T);
uint32_t main = size0 * size1;
Reg::MaskReg pregFull = Reg::CreateMask<IndexT>();
Reg::MaskReg pregCnt;
Reg::RegTensor<T> srcReg;
// Receives the discarded half of DeInterleave.
Reg::RegTensor<T> dummyReg;
Reg::RegTensor<IndexT> srcReg1;
Reg::RegTensor<IndexT> srcReg2;
Reg::RegTensor<IndexT> indexReg1;
Reg::RegTensor<IndexT> indexReg2;
Reg::LoadAlign(indexReg1, indexUb);
if constexpr (sizeof(T) == sizeof(uint8_t)) {
Reg::LoadAlign(indexReg2, indexUb + VF_LEN_HALF);
}
// Mask covering the first `main` lanes (assumed UpdateMask semantics).
uint32_t sreg = main;
pregCnt = Reg::UpdateMask<T>(sreg);
if constexpr (sizeof(T) == sizeof(uint8_t)) {
Reg::Gather(srcReg1, srcUb, indexReg1, pregFull);
Reg::Gather(srcReg2, srcUb, indexReg2, pregFull);
// Keep the even-indexed T elements of srcReg1:srcReg2 in srcReg; the odd ones go to dummyReg.
// The kept elements are presumably the low byte of each IndexT lane.
Reg::DeInterleave(
srcReg, dummyReg, (Reg::RegTensor<T> &)srcReg1, (Reg::RegTensor<T> &)srcReg2);
} else {
Reg::Gather(srcReg, srcUb, indexReg1, pregCnt);
}
Reg::StoreAlign(dstUb, srcReg, pregCnt);
}
// Last-axis broadcast of a 2-D output [size0, size1] that spans several vector
// registers, implemented with gathers.
// Work is split into chunks of `factor` whole output rows, where factor = VF_LEN / size1:
//   - repeatTimes full chunks of `main` elements each;
//   - then one final chunk of `tail` elements.
// For chunk i, the base gather indices in indexUb are shifted by i * factor source elements.
// indexUb is prepared by the caller; the two-dim wrapper uses GenLastGatherIndex.
// Output is appended through an unaligned store stream:
//   - StoreUnAlign evidently advances dstUb, since no offset is added between calls;
//   - StoreUnAlignPost presumably flushes the data still buffered in ureg0.
// size1 must be <= VF_LEN, otherwise factor is 0.
// BrcLastWrapperForTwoDim only calls this with size1 < VF_LEN_HALF.
template <typename T, typename IndexT>
__simd_vf__ inline void BrcLastGatherTwo(
__ubuf__ T *dstUb, __ubuf__ T *srcUb, __ubuf__ IndexT *indexUb, uint16_t size0, uint16_t size1)
{
constexpr uint16_t VF_LEN = GetVecLen() / sizeof(T);
// Offset, in IndexT elements, of the second index register used for 8-bit T.
constexpr uint32_t VF_LEN_HALF = GetVecLen() / 2 / sizeof(T);
// Whole output rows that fit in one register.
uint16_t factor = VF_LEN / size1;
// Number of full chunks.
// The last chunk, whether full or partial, is always written after the loop.
uint16_t repeatTimes = CeilDivision(size0, factor) - 1;
uint32_t main = factor * size1;
uint32_t mainBlock = main * repeatTimes;
// Source-element offset of the final chunk.
uint32_t offset = factor * repeatTimes;
uint32_t tail = size0 * size1 - mainBlock;
Reg::MaskReg pregFull = Reg::CreateMask<IndexT>();
Reg::RegTensor<T> srcReg;
Reg::RegTensor<T> dummyReg;
Reg::RegTensor<IndexT> indexReg1;
Reg::RegTensor<IndexT> indexReg2;
Reg::RegTensor<IndexT> factorReg;
Reg::RegTensor<IndexT> srcReg1;
Reg::RegTensor<IndexT> srcReg2;
Reg::RegTensor<IndexT> dstReg;
Reg::RegTensor<IndexT> tmpReg;
Reg::UnalignReg ureg0;
Reg::Duplicate(factorReg, (IndexT)factor, pregFull);
Reg::LoadAlign(indexReg1, indexUb);
if constexpr (sizeof(T) == sizeof(uint8_t)) {
Reg::LoadAlign(indexReg2, indexUb + VF_LEN_HALF);
}
for (uint16_t i = 0; i < repeatTimes; ++i) {
// tmpReg = i * factor in every lane: the source offset of chunk i.
Reg::Muls(tmpReg, factorReg, (IndexT)i, pregFull);
Reg::Add(dstReg, tmpReg, indexReg1, pregFull);
if constexpr (sizeof(T) == sizeof(uint8_t)) {
Reg::Gather(srcReg1, srcUb, dstReg, pregFull);
Reg::Add(dstReg, tmpReg, indexReg2, pregFull);
Reg::Gather(srcReg2, srcUb, dstReg, pregFull);
// Narrow the two IndexT-wide gathers back to T; the odd elements are discarded.
Reg::DeInterleave(
srcReg, dummyReg, (Reg::RegTensor<T> &)srcReg1, (Reg::RegTensor<T> &)srcReg2);
} else {
Reg::Gather(srcReg, srcUb, dstReg, pregFull);
}
Reg::StoreUnAlign(dstUb, srcReg, ureg0, main);
}
// Final chunk: indices shifted by `offset` source elements, `tail` elements written.
Reg::Adds(dstReg, indexReg1, (IndexT)offset, pregFull);
if constexpr (sizeof(T) == sizeof(uint8_t)) {
Reg::Gather(srcReg1, srcUb, dstReg, pregFull);
Reg::Adds(dstReg, indexReg2, (IndexT)offset, pregFull);
Reg::Gather(srcReg2, srcUb, dstReg, pregFull);
Reg::DeInterleave(
srcReg, dummyReg, (Reg::RegTensor<T> &)srcReg1, (Reg::RegTensor<T> &)srcReg2);
} else {
Reg::Gather(srcReg, srcUb, dstReg, pregFull);
}
Reg::StoreUnAlign(dstUb, srcReg, ureg0, tail);
Reg::StoreUnAlignPost(dstUb, ureg0, 0);
}
// Non-last-axis broadcast of a 2-D output [size0, size1] that fits in one vector register.
// It gathers srcUb through the index table at indexUb and writes the
// size0 * size1 output elements to dstUb with one masked aligned store.
// indexUb holds indices prepared by the caller, which is not visible here.
// They are presumably GenNlastGatherIndex output, i.e. lane % size1; TODO confirm.
// For 8-bit T, two IndexT-wide gathers are recombined with DeInterleave.
// Change from the original: the UnalignReg that was declared but never used is removed.
template <typename T, typename IndexT>
__simd_vf__ inline void BrcNlastGatherOne(
    __ubuf__ T *dstUb, __ubuf__ T *srcUb, __ubuf__ IndexT *indexUb, uint16_t size0, uint16_t size1)
{
    // Half a register in T elements.
    // Used as an IndexT offset, it selects the second index register for 8-bit T.
    constexpr uint32_t VF_LEN_HALF = GetVecLen() / 2 / sizeof(T);
    uint32_t main = size0 * size1;
    Reg::MaskReg pregFull = Reg::CreateMask<IndexT>();
    Reg::MaskReg pregCnt;
    Reg::RegTensor<IndexT> indexReg1;
    Reg::RegTensor<IndexT> indexReg2;
    Reg::RegTensor<IndexT> srcReg1;
    Reg::RegTensor<IndexT> srcReg2;
    Reg::RegTensor<T> srcReg;
    // Receives the discarded half of DeInterleave.
    Reg::RegTensor<T> dummyReg;
    // Mask covering the first `main` lanes (assumed UpdateMask semantics).
    uint32_t sreg = main;
    pregCnt = Reg::UpdateMask<T>(sreg);
    Reg::LoadAlign(indexReg1, indexUb);
    if constexpr (sizeof(T) == sizeof(uint8_t)) {
        Reg::LoadAlign(indexReg2, indexUb + VF_LEN_HALF);
        Reg::Gather(srcReg1, srcUb, indexReg1, pregFull);
        Reg::Gather(srcReg2, srcUb, indexReg2, pregFull);
        Reg::DeInterleave(
            srcReg, dummyReg, (Reg::RegTensor<T> &)srcReg1, (Reg::RegTensor<T> &)srcReg2);
    } else {
        Reg::Gather(srcReg, srcUb, indexReg1, pregCnt);
    }
    Reg::StoreAlign(dstUb, srcReg, pregCnt);
}
// Non-last-axis broadcast of a 2-D output [size0, size1] that spans several vector registers.
// The index table is gathered once. The resulting register holds `factor` whole copies
// of the source row (assuming indexUb holds lane % size1 indices; the caller is not visible).
// That register is then appended repeatTimes times as `main` elements, plus one `tail` append.
// StoreUnAlign evidently advances dstUb.
// StoreUnAlignPost presumably flushes the data still buffered in ureg0.
// size1 must be <= VF_LEN, otherwise factor is 0.
template <typename T, typename IndexT>
__simd_vf__ inline void BrcNlastGatherTwo(
__ubuf__ T *dstUb, __ubuf__ T *srcUb, __ubuf__ IndexT *indexUb, uint16_t size0, uint16_t size1)
{
constexpr uint16_t VF_LEN = GetVecLen() / sizeof(T);
constexpr uint32_t VF_LEN_HALF = GetVecLen() / 2 / sizeof(T);
// Whole source rows held by one register.
uint16_t factor = VF_LEN / size1;
// Number of full appends; the last (full or partial) chunk is written after the loop.
uint16_t repeatTimes = CeilDivision(size0, factor) - 1;
uint32_t main = factor * size1;
uint32_t mainBlock = main * repeatTimes;
uint32_t tail = size0 * size1 - mainBlock;
Reg::MaskReg pregFull = Reg::CreateMask<IndexT>();
Reg::RegTensor<IndexT> indexReg1;
Reg::RegTensor<IndexT> indexReg2;
Reg::RegTensor<IndexT> srcReg1;
Reg::RegTensor<IndexT> srcReg2;
Reg::RegTensor<T> srcReg;
Reg::RegTensor<T> dummyReg;
Reg::UnalignReg ureg0;
Reg::LoadAlign(indexReg1, indexUb);
if constexpr (sizeof(T) == sizeof(uint8_t)) {
Reg::LoadAlign(indexReg2, indexUb + VF_LEN_HALF);
Reg::Gather(srcReg1, srcUb, indexReg1, pregFull);
Reg::Gather(srcReg2, srcUb, indexReg2, pregFull);
// Narrow the two IndexT-wide gathers back to T; the odd elements are discarded.
Reg::DeInterleave(
srcReg, dummyReg, (Reg::RegTensor<T> &)srcReg1, (Reg::RegTensor<T> &)srcReg2);
} else {
Reg::Gather(srcReg, srcUb, indexReg1, pregFull);
}
// The same register is reused for every chunk: the row pattern repeats every size1 elements.
for (uint16_t i = 0; i < repeatTimes; ++i) {
Reg::StoreUnAlign(dstUb, srcReg, ureg0, main);
}
Reg::StoreUnAlign(dstUb, srcReg, ureg0, tail);
Reg::StoreUnAlignPost(dstUb, ureg0, 0);
}
// Last-axis broadcast of a 2-D output [size0, size1] where size1 is one data block of T.
// BrcLastWrapperForTwoDim calls it when size1 == oneBlockElementNum and T is not 8-bit.
// Each iteration:
//   - E2B-loads DEFAULT_BLK_NUM source elements;
//   - writes one vector (VF_LEN elements) to dstUb with an aligned store.
// UpdateMask trims the final store to the remaining count
// (assumes UpdateMask yields min(sreg, VL) lanes and decrements sreg).
template <typename T>
__simd_vf__ inline void BrcLastE2B(__ubuf__ T *dstUb, __ubuf__ T *srcUb, uint16_t size0, uint16_t size1)
{
constexpr uint16_t VF_LEN = GetVecLen() / sizeof(T);
// Output rows per vector register.
uint16_t factor = VF_LEN / size1;
uint16_t repeatTimes = CeilDivision(size0, factor);
Reg::MaskReg pregCnt;
Reg::RegTensor<T> srcReg;
// Remaining output elements.
uint32_t sreg = size0 * size1;
for (uint16_t i = 0; i < repeatTimes; ++i) {
pregCnt = Reg::UpdateMask<T>(sreg);
// The source advances DEFAULT_BLK_NUM elements per output vector.
E2bLoad<T>(srcReg, srcUb + i * DEFAULT_BLK_NUM);
Reg::StoreAlign(dstUb + i * VF_LEN, srcReg, pregCnt);
}
}
// Last-axis broadcast of a 3-D output [size0, size1, size2] where size2 is one data block of T.
// BrcLastWrapperForThreeDim calls it when size1 * size2 > VF_LEN.
// For each outer index i, ceil(size1 / factor) vectors are produced.
// Each vector E2B-loads DEFAULT_BLK_NUM source elements starting at srcUb + i * srcStride0.
// UpdateMask trims the last vector of each i.
template <typename T>
__simd_vf__ inline void BrcLastE2BLargerThanVL(
__ubuf__ T *dstUb, __ubuf__ T *srcUb, uint16_t size0, uint16_t size1, uint16_t size2, uint16_t srcStride0)
{
constexpr uint16_t VF_LEN = GetVecLen() / sizeof(T);
// Rows of size2 elements per vector register.
uint16_t factor = VF_LEN / size2;
uint16_t repeatTimes = CeilDivision(size1, factor);
// Output elements per outer index. This is a scalar count, not a mask.
uint32_t preg = size1 * size2;
uint32_t sreg;
Reg::MaskReg pregCnt;
Reg::RegTensor<T> srcReg;
for (uint16_t i = 0; i < size0; ++i) {
// Reset the remaining-element counter for each outer slice.
sreg = preg;
for (uint16_t j = 0; j < repeatTimes; ++j) {
pregCnt = Reg::UpdateMask<T>(sreg);
E2bLoad<T>(srcReg, srcUb + j * DEFAULT_BLK_NUM + i * srcStride0);
Reg::StoreAlign(dstUb + i * size1 * size2 + j * VF_LEN, srcReg, pregCnt);
}
}
}
// Last-axis broadcast of a 3-D output [size0, size1, size2] where size2 is one data block of T
// and each outer slice (size1 * size2 elements) fits in one vector register.
// BrcLastWrapperForThreeDim calls it when size1 * size2 <= VF_LEN.
// Per outer index i it issues one E2B load from srcUb + i * srcStride0 and one masked aligned store.
// Change from the original: the unused VF_LEN constant and the redundant preg/sreg copy are removed.
template <typename T>
__simd_vf__ inline void BrcLastE2BLessThanVL(
    __ubuf__ T *dstUb, __ubuf__ T *srcUb, uint16_t size0, uint16_t size1, uint16_t size2, uint16_t srcStride0)
{
    Reg::MaskReg pregCnt;
    Reg::RegTensor<T> srcReg;
    // Every slice has the same length, so the mask is built once.
    uint32_t sreg = size1 * size2;
    pregCnt = Reg::UpdateMask<T>(sreg);
    for (uint16_t i = 0; i < size0; ++i) {
        E2bLoad<T>(srcReg, srcUb + i * srcStride0);
        Reg::StoreAlign(dstUb + i * size1 * size2, srcReg, pregCnt);
    }
}
// Last-axis broadcast of a 4-D output [size0, size1, size2, size3] using E2B loads.
// size3 is presumably one data block of T, as in the 2-D overload; the caller is not
// visible here, so TODO confirm.
// For every (i, j) slice:
//   - reads E2B loads starting at srcUb + i * srcStride0 + j * srcStride1, advancing
//     DEFAULT_BLK_NUM elements per output vector;
//   - writes ceil(size2 / factor) aligned vectors, the last one trimmed by UpdateMask.
template <typename T>
__simd_vf__ inline void BrcLastE2B(__ubuf__ T *dstUb, __ubuf__ T *srcUb, uint16_t size0, uint16_t size1,
uint16_t size2, uint16_t size3, uint16_t srcStride0, uint16_t srcStride1)
{
constexpr uint16_t VF_LEN = GetVecLen() / sizeof(T);
// Rows of size3 elements per vector register.
uint16_t factor = VF_LEN / size3;
uint16_t repeatTimes = CeilDivision(size2, factor);
// Output elements per (i, j) slice. This is a scalar count, not a mask.
uint32_t preg = size2 * size3;
uint32_t sreg;
Reg::MaskReg pregCnt;
Reg::RegTensor<T> srcReg;
for (uint16_t i = 0; i < size0; ++i) {
for (uint16_t j = 0; j < size1; ++j) {
sreg = preg;
for (uint16_t k = 0; k < repeatTimes; ++k) {
pregCnt = Reg::UpdateMask<T>(sreg);
E2bLoad<T>(srcReg, srcUb + i * srcStride0 + j * srcStride1 + k * DEFAULT_BLK_NUM);
Reg::StoreAlign(
dstUb + i * size1 * size2 * size3 + j * size2 * size3 + k * VF_LEN, srcReg, pregCnt);
}
}
}
}
// Non-last-axis broadcast of a 2-D output [size0, size1] that fits in one vector register,
// implemented with GatherB.
// GatherB presumably gathers whole data blocks, given its name; confirm against the Reg API.
// The uint32_t indices at indexUb are prepared by the caller, which is not visible here.
// The first size0 * size1 gathered elements are written to dstUb with one masked aligned store.
// Change from the original: the unused oneBlockElementNum constant is removed.
template <typename T>
__simd_vf__ inline void BrcNlastGatherBOne(
    __ubuf__ T *dstUb, __ubuf__ T *srcUb, __ubuf__ uint32_t *indexUb, uint16_t size0, uint16_t size1)
{
    uint32_t main = size0 * size1;
    Reg::MaskReg pregFull = Reg::CreateMask<T>();
    Reg::MaskReg pregCnt;
    Reg::RegTensor<T> srcReg;
    Reg::RegTensor<uint32_t> indexReg;
    Reg::LoadAlign(indexReg, indexUb);
    Reg::GatherB(srcReg, srcUb, indexReg, pregFull);
    // Mask covering the first `main` lanes (assumed UpdateMask semantics).
    uint32_t sreg = main;
    pregCnt = Reg::UpdateMask<T>(sreg);
    Reg::StoreAlign(dstUb, srcReg, pregCnt);
}
// Non-last-axis broadcast of a 2-D output [size0, size1] spanning several vector registers,
// implemented with a single GatherB.
// The gathered register is assumed to hold `factor` whole copies of the source row.
// It is stored repeatTimes times as `main` elements at dstUb + i * main, then once more
// for the `tail` elements at dstUb + mainBlock.
// NOTE(review): these are aligned stores, so dstUb + i * main must meet StoreAlign's
// alignment requirement; presumably the caller guarantees this. TODO confirm.
// size1 must be <= VF_LEN, otherwise factor is 0.
// Change from the original: the unused oneBlockElementNum constant is removed.
template <typename T>
__simd_vf__ inline void BrcNlastGatherBTwo(
    __ubuf__ T *dstUb, __ubuf__ T *srcUb, __ubuf__ uint32_t *indexUb, uint16_t size0, uint16_t size1)
{
    constexpr uint16_t VF_LEN = GetVecLen() / sizeof(T);
    // Whole source rows held by one register.
    uint16_t factor = VF_LEN / size1;
    // Number of full stores; the last (full or partial) chunk is written after the loop.
    uint16_t repeatTimes = CeilDivision(size0, factor) - 1;
    uint32_t main = factor * size1;
    uint32_t mainBlock = main * repeatTimes;
    uint32_t tail = size0 * size1 - mainBlock;
    Reg::MaskReg pregFull = Reg::CreateMask<T>();
    Reg::MaskReg pregCnt;
    Reg::RegTensor<T> srcReg;
    Reg::RegTensor<uint32_t> indexReg;
    Reg::LoadAlign(indexReg, indexUb);
    Reg::GatherB(srcReg, srcUb, indexReg, pregFull);
    // Mask for the `main` elements of each full chunk.
    uint32_t sreg = main;
    pregCnt = Reg::UpdateMask<T>(sreg);
    for (uint16_t i = 0; i < repeatTimes; ++i) {
        Reg::StoreAlign(dstUb + i * main, srcReg, pregCnt);
    }
    // Mask for the final `tail` elements.
    sreg = tail;
    pregCnt = Reg::UpdateMask<T>(sreg);
    Reg::StoreAlign(dstUb + mainBlock, srcReg, pregCnt);
}
// Last-axis broadcast of a 3-D output [size0, size1, size2].
// BrcLastWrapperForThreeDim calls it when size2 <= VF_LEN and size2 is a multiple of
// oneBlockElementNum, so each row fits in one register.
// The source element at srcUb + j * srcStride0 + i * srcStride1 is broadcast-loaded and
// written as the output row (j, i) with one aligned store.
// All rows have length size2, so the store mask is built once.
template <typename T>
__simd_vf__ inline void BrcLastLessThanVLAligned(__ubuf__ T *dstUb, __ubuf__ T *srcUb, uint16_t size0,
uint16_t size1, uint16_t size2, uint16_t srcStride0, uint16_t srcStride1)
{
Reg::MaskReg pregCnt;
Reg::RegTensor<T> srcReg;
uint32_t sreg = size2;
pregCnt = Reg::UpdateMask<T>(sreg);
for (uint16_t i = 0; i < size1; ++i) {
for (uint16_t j = 0; j < size0; ++j) {
BrcLoad<T>(srcReg, srcUb + j * srcStride0 + i * srcStride1);
Reg::StoreAlign(dstUb + j * size1 * size2 + i * size2, srcReg, pregCnt);
}
}
}
// Last-axis broadcast of a 4-D output [size0, size1, size2, size3].
// Each row of size3 elements is filled with one broadcast-loaded source element using a
// single aligned store.
// Presumably size3 <= VF_LEN and is block-aligned, as in the 3-D overload; the caller is
// not visible here, so TODO confirm.
// The source for output row (i, k, j) is srcUb + i * srcStride0 + k * srcStride1 + j * srcStride2.
// The loops visit axis 1 (k) innermost.
template <typename T>
__simd_vf__ inline void BrcLastLessThanVLAligned(__ubuf__ T *dstUb, __ubuf__ T *srcUb, uint16_t size0,
uint16_t size1, uint16_t size2, uint16_t size3, uint16_t srcStride0, uint16_t srcStride1, uint16_t srcStride2)
{
Reg::MaskReg pregCnt;
Reg::RegTensor<T> srcReg;
// All rows have length size3, so the mask is built once.
uint32_t sreg = size3;
pregCnt = Reg::UpdateMask<T>(sreg);
for (uint16_t i = 0; i < size0; ++i) {
for (uint16_t j = 0; j < size2; ++j) {
for (uint16_t k = 0; k < size1; ++k) {
BrcLoad<T>(srcReg, srcUb + i * srcStride0 + j * srcStride2 + k * srcStride1);
Reg::StoreAlign(
dstUb + i * size1 * size2 * size3 + k * size2 * size3 + j * size3, srcReg, pregCnt);
}
}
}
}
// Non-last-axis broadcast of a 4-D output [size0, size1, size2, size3].
// Each output row (j, i, k) of size3 elements is copied from the source row at
// srcUb + j * srcStride0 + i * srcStride1 + k * srcStride2.
// Broadcast axes presumably have stride 0; the caller is not visible here.
// Each copy is one full-register aligned load plus one masked aligned store, so
// size3 <= VF_LEN is assumed.
template <typename T>
__simd_vf__ inline void BrcNlastLessThanVLAligned(__ubuf__ T *dstUb, __ubuf__ T *srcUb, uint16_t size0,
uint16_t size1, uint16_t size2, uint16_t size3, uint16_t srcStride0, uint16_t srcStride1, uint16_t srcStride2)
{
Reg::MaskReg pregCnt;
Reg::RegTensor<T> srcReg;
// All rows have length size3, so the mask is built once.
uint32_t sreg = size3;
pregCnt = Reg::UpdateMask<T>(sreg);
for (uint16_t i = 0; i < size1; ++i) {
for (uint16_t j = 0; j < size0; ++j) {
for (uint16_t k = 0; k < size2; ++k) {
Reg::LoadAlign(srcReg, srcUb + i * srcStride1 + j * srcStride0 + k * srcStride2);
Reg::StoreAlign(
dstUb + j * size1 * size2 * size3 + i * size2 * size3 + k * size3, srcReg, pregCnt);
}
}
}
}
// Last-axis broadcast of a 2-D output [size0, size1].
// For each i, srcUb[i] is broadcast-loaded and size1 copies are appended to the unaligned
// store stream at dstUb.
// BrcLastWrapperForTwoDim uses this when VF_LEN_HALF <= size1 <= VF_LEN and the E2B case
// does not apply.
// StoreUnAlign evidently advances dstUb.
// StoreUnAlignPost presumably flushes the data still buffered in ureg0.
template <typename T>
__simd_vf__ inline void BrcLastLessThanVLUnaligned(
__ubuf__ T *dstUb, __ubuf__ T *srcUb, uint16_t size0, uint16_t size1)
{
Reg::RegTensor<T> srcReg;
Reg::UnalignReg ureg0;
for (uint16_t i = 0; i < size0; ++i) {
BrcLoad<T>(srcReg, srcUb + i);
Reg::StoreUnAlign(dstUb, srcReg, ureg0, size1);
}
Reg::StoreUnAlignPost(dstUb, ureg0, 0);
}
// Last-axis broadcast of a 3-D output [size0, size1, size2].
// For each (i, j), the source element at srcUb + j * srcStride1 + i * srcStride0 is
// broadcast-loaded and size2 copies are appended to the unaligned store stream.
// BrcLastWrapperForThreeDim uses this for size2 <= VF_LEN when size2 is not block-aligned
// and constRank != -1.
template <typename T>
__simd_vf__ inline void BrcLastLessThanVLUnaligned(__ubuf__ T *dstUb, __ubuf__ T *srcUb, uint16_t size0,
uint16_t size1, uint16_t size2, uint16_t srcStride0, uint16_t srcStride1)
{
Reg::RegTensor<T> srcReg;
Reg::UnalignReg ureg0;
for (uint16_t i = 0; i < size0; ++i) {
for (uint16_t j = 0; j < size1; ++j) {
BrcLoad<T>(srcReg, srcUb + j * srcStride1 + i * srcStride0);
Reg::StoreUnAlign(dstUb, srcReg, ureg0, size2);
}
}
// Flush whatever the unaligned store stream still buffers.
Reg::StoreUnAlignPost(dstUb, ureg0, 0);
}
// Last-axis broadcast of a 4-D output [size0, size1, size2, size3].
// For each (i, j, k), the source element at
// srcUb + i * srcStride0 + j * srcStride1 + k * srcStride2 is broadcast-loaded and
// size3 copies are appended to the unaligned store stream.
// A single broadcast register holds at most VF_LEN copies, so size3 <= VF_LEN is
// presumably required; the caller is not visible here.
template <typename T>
__simd_vf__ inline void BrcLastLessThanVLUnaligned(__ubuf__ T *dstUb, __ubuf__ T *srcUb, uint16_t size0,
uint16_t size1, uint16_t size2, uint16_t size3, uint16_t srcStride0, uint16_t srcStride1, uint16_t srcStride2)
{
Reg::RegTensor<T> srcReg;
Reg::UnalignReg ureg0;
for (uint16_t i = 0; i < size0; ++i) {
for (uint16_t j = 0; j < size1; ++j) {
for (uint16_t k = 0; k < size2; ++k) {
BrcLoad<T>(srcReg, srcUb + i * srcStride0 + j * srcStride1 + k * srcStride2);
Reg::StoreUnAlign(dstUb, srcReg, ureg0, size3);
}
}
}
// Flush whatever the unaligned store stream still buffers.
Reg::StoreUnAlignPost(dstUb, ureg0, 0);
}
// Non-last-axis broadcast of a 2-D output [size0, size1].
// The source row is loaded once with an aligned load, and size0 copies of its first size1
// elements are appended to the unaligned store stream at dstUb.
// A single register holds the row, so size1 <= VF_LEN is presumably required.
// NOTE(review): pregCnt is computed from sreg but never used by any store in this function.
template <typename T>
__simd_vf__ inline void BrcNlastLessThanVLUnaligned(
__ubuf__ T *dstUb, __ubuf__ T *srcUb, uint16_t size0, uint16_t size1)
{
Reg::MaskReg pregCnt;
Reg::RegTensor<T> srcReg;
Reg::UnalignReg ureg0;
uint32_t sreg = size1;
pregCnt = Reg::UpdateMask<T>(sreg);
Reg::LoadAlign(srcReg, srcUb);
for (uint16_t i = 0; i < size0; ++i) {
Reg::StoreUnAlign(dstUb, srcReg, ureg0, size1);
}
// Flush whatever the unaligned store stream still buffers.
Reg::StoreUnAlignPost(dstUb, ureg0, 0);
}
// Non-last-axis broadcast of a 3-D output [size0, size1, size2].
// For each (i, j), size2 elements are read with an unaligned load from
// srcUb + i * srcStride0 + j * srcStride1 and appended to the unaligned store stream.
// ureg0 serves the load side and ureg1 the store side.
// A single register holds a row, so size2 <= VF_LEN is presumably required.
// NOTE(review): pregCnt is computed but never used in this function.
template <typename T>
__simd_vf__ inline void BrcNlastLessThanVLUnaligned(__ubuf__ T *dstUb, __ubuf__ T *srcUb, uint16_t size0,
uint16_t size1, uint16_t size2, uint16_t srcStride0, uint16_t srcStride1)
{
Reg::MaskReg pregCnt;
Reg::RegTensor<T> srcReg;
Reg::UnalignReg ureg0, ureg1;
uint32_t sreg = size2;
pregCnt = Reg::UpdateMask<T>(sreg);
for (uint16_t i = 0; i < size0; ++i) {
for (uint16_t j = 0; j < size1; ++j) {
auto srcUbT = srcUb + i * srcStride0 + j * srcStride1;
// Prime the unaligned-load register for this source row.
Reg::LoadUnAlignPre(ureg0, srcUbT);
Reg::LoadUnAlign(srcReg, ureg0, srcUbT, size2);
Reg::StoreUnAlign(dstUb, srcReg, ureg1, size2);
}
}
// Flush whatever the unaligned store stream still buffers.
Reg::StoreUnAlignPost(dstUb, ureg1, 0);
}
// Non-last-axis broadcast of a 4-D output [size0, size1, size2, size3].
// For each (i, j, k), size3 elements are read with an unaligned load from
// srcUb + i * srcStride0 + j * srcStride1 + k * srcStride2 and appended to the unaligned
// store stream.
// ureg0 serves the load side and ureg1 the store side.
// A single register holds a row, so size3 <= VF_LEN is presumably required.
// NOTE(review): pregCnt is computed but never used in this function.
template <typename T>
__simd_vf__ inline void BrcNlastLessThanVLUnaligned(__ubuf__ T *dstUb, __ubuf__ T *srcUb, uint16_t size0,
uint16_t size1, uint16_t size2, uint16_t size3, uint16_t srcStride0, uint16_t srcStride1, uint16_t srcStride2)
{
Reg::MaskReg pregCnt;
Reg::RegTensor<T> srcReg;
Reg::UnalignReg ureg0, ureg1;
uint32_t sreg = size3;
pregCnt = Reg::UpdateMask<T>(sreg);
for (uint16_t i = 0; i < size0; ++i) {
for (uint16_t j = 0; j < size1; ++j) {
for (uint16_t k = 0; k < size2; ++k) {
__ubuf__ T *srcUbTmp = srcUb + i * srcStride0 + j * srcStride1 + k * srcStride2;
// Prime the unaligned-load register for this source row.
Reg::LoadUnAlignPre(ureg0, srcUbTmp);
Reg::LoadUnAlign(srcReg, ureg0, srcUbTmp, size3);
Reg::StoreUnAlign(dstUb, srcReg, ureg1, size3);
}
}
}
// Flush whatever the unaligned store stream still buffers.
Reg::StoreUnAlignPost(dstUb, ureg1, 0);
}
// Last-axis broadcast of a 2-D output [size0, size1].
// BrcLastWrapperForTwoDim calls it when size1 > VF_LEN and size1 is a multiple of
// oneBlockElementNum.
// Each srcUb[i] is broadcast-loaded once, and row i is written as ceil(size1 / VF_LEN)
// aligned stores; UpdateMask trims the last store of each row.
// NOTE(review): pregFull is created but not used in this function.
template <typename T>
__simd_vf__ inline void BrcLastLargerThanVLAligned(
__ubuf__ T *dstUb, __ubuf__ T *srcUb, uint16_t size0, uint16_t size1)
{
constexpr uint16_t VF_LEN = GetVecLen() / sizeof(T);
// Vector-sized stores per output row.
uint16_t factor = CeilDivision(size1, VF_LEN);
Reg::MaskReg pregFull = Reg::CreateMask<T>();
Reg::MaskReg pregCnt;
Reg::RegTensor<T> srcReg;
for (uint16_t i = 0; i < size0; ++i) {
BrcLoad<T>(srcReg, srcUb + i);
// Remaining elements of row i.
uint32_t sreg = size1;
for (uint16_t j = 0; j < factor; ++j) {
pregCnt = Reg::UpdateMask<T>(sreg);
Reg::StoreAlign(dstUb + i * size1 + j * VF_LEN, srcReg, pregCnt);
}
}
}
// Last-axis broadcast of a 3-D output [size0, size1, size2].
// BrcLastWrapperForThreeDim calls it when size2 > VF_LEN and size2 is a multiple of
// oneBlockElementNum.
// The source element at srcUb + i * srcStride1 + j * srcStride0 is broadcast-loaded, and
// output row (j, i) is written as ceil(size2 / VF_LEN) aligned stores, the last one
// trimmed by UpdateMask.
// NOTE(review): pregFull is created but not used in this function.
template <typename T>
__simd_vf__ inline void BrcLastLargerThanVLAligned(__ubuf__ T *dstUb, __ubuf__ T *srcUb, uint16_t size0,
uint16_t size1, uint16_t size2, uint16_t srcStride0, uint16_t srcStride1)
{
constexpr uint16_t VF_LEN = GetVecLen() / sizeof(T);
// Vector-sized stores per output row.
uint16_t factor = CeilDivision(size2, VF_LEN);
Reg::MaskReg pregFull = Reg::CreateMask<T>();
Reg::MaskReg pregCnt;
Reg::RegTensor<T> srcReg;
for (uint16_t i = 0; i < size1; ++i) {
for (uint16_t j = 0; j < size0; ++j) {
BrcLoad<T>(srcReg, srcUb + i * srcStride1 + j * srcStride0);
// Remaining elements of this row.
uint32_t sreg = size2;
for (uint16_t k = 0; k < factor; ++k) {
pregCnt = Reg::UpdateMask<T>(sreg);
Reg::StoreAlign(dstUb + j * size1 * size2 + i * size2 + k * VF_LEN, srcReg, pregCnt);
}
}
}
}
// Last-axis broadcast of a 4-D output [size0, size1, size2, size3] with rows of size3
// elements, written in ceil(size3 / VF_LEN) aligned chunks.
// For each (i, j) and chunk k, the chunk mask is computed once and reused for all size1
// rows (t).
// Each of those rows broadcast-loads its source element from
// srcUb + i * srcStride0 + j * srcStride2 + t * srcStride1.
// The alignment of size3 is presumably guaranteed by the caller, which is not visible here.
// NOTE(review): pregFull is created but not used in this function.
template <typename T>
__simd_vf__ inline void BrcLastLargerThanVLAligned(__ubuf__ T *dstUb, __ubuf__ T *srcUb, uint16_t size0,
uint16_t size1, uint16_t size2, uint16_t size3, uint16_t srcStride0, uint16_t srcStride1, uint16_t srcStride2)
{
constexpr uint16_t VF_LEN = GetVecLen() / sizeof(T);
// Vector-sized chunks per output row.
uint16_t factor = CeilDivision(size3, VF_LEN);
Reg::MaskReg pregFull = Reg::CreateMask<T>();
Reg::MaskReg pregCnt;
Reg::RegTensor<T> srcReg;
for (uint16_t i = 0; i < size0; ++i) {
for (uint16_t j = 0; j < size2; ++j) {
uint32_t sreg = size3;
for (uint16_t k = 0; k < factor; ++k) {
// Mask for chunk k; shared by every t below.
pregCnt = Reg::UpdateMask<T>(sreg);
for (uint16_t t = 0; t < size1; ++t) {
BrcLoad<T>(srcReg, srcUb + i * srcStride0 + j * srcStride2 + t * srcStride1);
Reg::StoreAlign(
dstUb + i * size1 * size2 * size3 + t * size2 * size3 + j * size3 + k * VF_LEN,
srcReg,
pregCnt);
}
}
}
}
}
// Non-last-axis broadcast of a 2-D output [size0, size1].
// The source row srcUb[0, size1) is copied into each of the size0 output rows in
// vector-sized chunks, using aligned loads and aligned stores.
// UpdateMask trims the last chunk of each row.
// The row is re-read for every output row.
// Alignment of size1 is presumably guaranteed by the caller, which is not visible here.
// NOTE(review): pregFull is created but not used in this function.
template <typename T>
__simd_vf__ inline void BrcNlastLargerThanVLAlignedWithBlock(
__ubuf__ T *dstUb, __ubuf__ T *srcUb, uint16_t size0, uint16_t size1)
{
constexpr uint16_t VF_LEN = GetVecLen() / sizeof(T);
// Vector-sized chunks per row.
uint16_t factor = CeilDivision(size1, VF_LEN);
Reg::MaskReg pregFull = Reg::CreateMask<T>();
Reg::MaskReg pregCnt;
Reg::RegTensor<T> srcReg;
for (uint16_t i = 0; i < size0; ++i) {
uint32_t sreg = size1;
for (uint16_t j = 0; j < factor; ++j) {
pregCnt = Reg::UpdateMask<T>(sreg);
Reg::LoadAlign(srcReg, srcUb + j * VF_LEN);
Reg::StoreAlign(dstUb + i * size1 + j * VF_LEN, srcReg, pregCnt);
}
}
}
// Non-last-axis broadcast of a 3-D output [size0, size1, size2].
// Each output row (i, k) of size2 elements is copied from the source row at
// srcUb + i * srcStride0 + k * srcStride1.
// The copy uses ceil(size2 / VF_LEN) aligned load/store chunks.
// Each chunk mask is computed once and reused for every k.
// Alignment of size2 is presumably guaranteed by the caller, which is not visible here.
// Change from the original: the unused local `jStride` is removed.
// It was never read, and it wrongly suggested that the chunk offset depends on srcStride1.
template <typename T>
__simd_vf__ inline void BrcNlastLargerThanVLAlignedWithBlock(__ubuf__ T *dstUb, __ubuf__ T *srcUb,
    uint16_t size0, uint16_t size1, uint16_t size2, uint16_t srcStride0, uint16_t srcStride1)
{
    constexpr uint16_t VF_LEN = GetVecLen() / sizeof(T);
    // Vector-sized chunks per output row.
    uint16_t factor = CeilDivision(size2, VF_LEN);
    Reg::MaskReg pregFull = Reg::CreateMask<T>();
    Reg::MaskReg pregCnt;
    Reg::RegTensor<T> srcReg;
    for (uint16_t i = 0; i < size0; ++i) {
        uint32_t sreg = size2;
        for (uint16_t j = 0; j < factor; ++j) {
            // Mask for chunk j; shared by every k below.
            pregCnt = Reg::UpdateMask<T>(sreg);
            for (uint16_t k = 0; k < size1; ++k) {
                // Within a source row the chunk offset is always j * VF_LEN.
                // srcStride1 (even when 0) only moves the row base.
                Reg::LoadAlign(srcReg, srcUb + k * srcStride1 + i * srcStride0 + j * VF_LEN);
                Reg::StoreAlign(dstUb + i * size1 * size2 + k * size2 + j * VF_LEN, srcReg, pregCnt);
            }
        }
    }
}
// Non-last-axis broadcast of a 4-D output [size0, size1, size2, size3].
// Each output row (k, i, t) of size3 elements is copied from the source row at
// srcUb + k * srcStride0 + i * srcStride1 + t * srcStride2.
// The copy uses ceil(size3 / VF_LEN) aligned load/store chunks.
// Each chunk mask is computed once per (i, j) and reused for all (k, t).
// Alignment of size3 is presumably guaranteed by the caller, which is not visible here.
// NOTE(review): pregFull is created but not used in this function.
template <typename T>
__simd_vf__ inline void BrcNlastLargerThanVLAlignedWithBlock(__ubuf__ T *dstUb, __ubuf__ T *srcUb,
uint16_t size0, uint16_t size1, uint16_t size2, uint16_t size3, uint16_t srcStride0, uint16_t srcStride1,
uint16_t srcStride2)
{
constexpr uint16_t VF_LEN = GetVecLen() / sizeof(T);
// Vector-sized chunks per output row.
uint16_t factor = CeilDivision(size3, VF_LEN);
Reg::MaskReg pregFull = Reg::CreateMask<T>();
Reg::MaskReg pregCnt;
Reg::RegTensor<T> srcReg;
for (uint16_t i = 0; i < size1; ++i) {
uint32_t sreg = size3;
for (uint16_t j = 0; j < factor; ++j) {
pregCnt = Reg::UpdateMask<T>(sreg);
for (uint16_t k = 0; k < size0; ++k) {
for (uint16_t t = 0; t < size2; ++t) {
Reg::LoadAlign(srcReg, srcUb + j * VF_LEN +
i * srcStride1 + k * srcStride0 + t * srcStride2);
Reg::StoreAlign(dstUb + k * size1 * size2 * size3 +
i * size2 * size3 + t * size3 + j * VF_LEN, srcReg, pregCnt);
}
}
}
}
}
// Non-last-axis broadcast of a 2-D output [size0, size1].
// Each VF_LEN chunk of the source row is loaded once, then stored into all size0 output
// rows with full-register stores (pregFull).
// NOTE(review): every store writes VF_LEN elements, so this is only correct when
// size1 % VF_LEN == 0. Otherwise the last chunk of row j would overwrite the start of row
// j + 1 and run past the end of dstUb. The "WithVL" name suggests callers guarantee this;
// confirm at the call site, which is not visible here.
template <typename T>
__simd_vf__ inline void BrcNlastLargerThanVLAlignedWithVL(
__ubuf__ T *dstUb, __ubuf__ T *srcUb, uint16_t size0, uint16_t size1)
{
constexpr uint16_t VF_LEN = GetVecLen() / sizeof(T);
uint16_t factor = CeilDivision(size1, VF_LEN);
Reg::MaskReg pregFull = Reg::CreateMask<T>();
Reg::RegTensor<T> srcReg;
for (uint16_t i = 0; i < factor; ++i) {
Reg::LoadAlign(srcReg, srcUb + i * VF_LEN);
for (uint16_t j = 0; j < size0; ++j) {
Reg::StoreAlign(dstUb + i * VF_LEN + j * size1, srcReg, pregFull);
}
}
}
// Non-last-axis broadcast of a 3-D output [size0, size1, size2].
// Output row (i, k) is copied from the source row at srcUb + i * srcStride0 + k * srcStride1
// in VF_LEN chunks, using full-register aligned loads and stores.
// NOTE(review): stores always write VF_LEN elements, so size2 % VF_LEN == 0 is assumed.
// Otherwise rows overlap and the final store overruns dstUb. The caller is not visible here.
template <typename T>
__simd_vf__ inline void BrcNlastLargerThanVLAlignedWithVL(__ubuf__ T *dstUb, __ubuf__ T *srcUb, uint16_t size0,
uint16_t size1, uint16_t size2, uint16_t srcStride0, uint16_t srcStride1)
{
constexpr uint16_t VF_LEN = GetVecLen() / sizeof(T);
uint16_t factor = CeilDivision(size2, VF_LEN);
Reg::MaskReg pregFull = Reg::CreateMask<T>();
Reg::RegTensor<T> srcReg;
for (uint16_t i = 0; i < size0; ++i) {
for (uint16_t j = 0; j < factor; ++j) {
for (uint16_t k = 0; k < size1; ++k) {
Reg::LoadAlign(srcReg, srcUb + i * srcStride0 + j * VF_LEN + k * srcStride1);
Reg::StoreAlign(dstUb + j * VF_LEN + k * size2 + i * size1 * size2, srcReg, pregFull);
}
}
}
}
// Non-last-axis broadcast of a 4-D output [size0, size1, size2, size3].
// Output row (k, i, t) is copied from the source row at
// srcUb + k * srcStride0 + i * srcStride1 + t * srcStride2 in VF_LEN chunks, using
// full-register aligned loads and stores.
// NOTE(review): stores always write VF_LEN elements, so size3 % VF_LEN == 0 is assumed.
// The caller is not visible here.
template <typename T>
__simd_vf__ inline void BrcNlastLargerThanVLAlignedWithVL(__ubuf__ T *dstUb, __ubuf__ T *srcUb, uint16_t size0,
uint16_t size1, uint16_t size2, uint16_t size3, uint16_t srcStride0, uint16_t srcStride1, uint16_t srcStride2)
{
constexpr uint16_t VF_LEN = GetVecLen() / sizeof(T);
uint16_t factor = CeilDivision(size3, VF_LEN);
Reg::MaskReg pregFull = Reg::CreateMask<T>();
Reg::RegTensor<T> srcReg;
for (uint16_t i = 0; i < size1; ++i) {
for (uint16_t j = 0; j < factor; ++j) {
for (uint16_t k = 0; k < size0; ++k) {
for (uint16_t t = 0; t < size2; ++t) {
Reg::LoadAlign(srcReg, srcUb + i * srcStride1 +
j * VF_LEN + k * srcStride0 + t * srcStride2);
Reg::StoreAlign(dstUb + j * VF_LEN + t * size3 +
i * size2 * size3 + k * size1 * size2 * size3, srcReg, pregFull);
}
}
}
}
}
// Last-axis broadcast of a 2-D output [size0, size1], where size1 may exceed VF_LEN.
// For each i, srcUb[i] is broadcast-loaded and row i is appended to the unaligned store
// stream: `factor` full vectors, then size1tail elements.
// BrcLastWrapperForTwoDim uses this when size1 > VF_LEN, size1 is not block-aligned, and
// constRank != -1.
template <typename T>
__simd_vf__ inline void BrcLastLargerThanVLUnaligned(
__ubuf__ T *dstUb, __ubuf__ T *srcUb, uint16_t size0, uint16_t size1)
{
constexpr uint16_t VF_LEN = GetVecLen() / sizeof(T);
// Full vectors per row, plus the remaining elements.
uint16_t factor = size1 / VF_LEN;
uint32_t size1tail = size1 - factor * VF_LEN;
Reg::RegTensor<T> srcReg;
Reg::UnalignReg ureg0;
for (uint16_t i = 0; i < size0; ++i) {
BrcLoad<T>(srcReg, srcUb + i);
for (uint16_t j = 0; j < factor; ++j) {
Reg::StoreUnAlign(dstUb, srcReg, ureg0, VF_LEN);
}
Reg::StoreUnAlign(dstUb, srcReg, ureg0, size1tail);
}
// Flush whatever the unaligned store stream still buffers.
Reg::StoreUnAlignPost(dstUb, ureg0, 0);
}
// Last-axis broadcast of a 3-D output [size0, size1, size2], where size2 may exceed VF_LEN.
// For each (i, j), the source element at srcUb + j * srcStride1 + i * srcStride0 is
// broadcast-loaded and appended as `factor` full vectors plus size2tail elements.
// BrcLastWrapperForThreeDim uses this when size2 > VF_LEN, size2 is not block-aligned, and
// constRank != -1.
template <typename T>
__simd_vf__ inline void BrcLastLargerThanVLUnaligned(__ubuf__ T *dstUb, __ubuf__ T *srcUb, uint16_t size0,
uint16_t size1, uint16_t size2, uint16_t srcStride0, uint16_t srcStride1)
{
constexpr uint16_t VF_LEN = GetVecLen() / sizeof(T);
// Full vectors per row, plus the remaining elements.
uint16_t factor = size2 / VF_LEN;
uint32_t size2tail = size2 - factor * VF_LEN;
Reg::RegTensor<T> srcReg;
Reg::UnalignReg ureg0;
for (uint16_t i = 0; i < size0; ++i) {
for (uint16_t j = 0; j < size1; ++j) {
BrcLoad<T>(srcReg, srcUb + j * srcStride1 + i * srcStride0);
for (uint16_t k = 0; k < factor; ++k) {
Reg::StoreUnAlign(dstUb, srcReg, ureg0, VF_LEN);
}
Reg::StoreUnAlign(dstUb, srcReg, ureg0, size2tail);
}
}
// Flush whatever the unaligned store stream still buffers.
Reg::StoreUnAlignPost(dstUb, ureg0, 0);
}
// Last-axis broadcast of a 4-D output [size0, size1, size2, size3], where size3 may exceed VF_LEN.
// For each (i, j, k), the source element at
// srcUb + i * srcStride0 + j * srcStride1 + k * srcStride2 is broadcast-loaded and appended
// as `factor` full vectors plus size3tail elements.
template <typename T>
__simd_vf__ inline void BrcLastLargerThanVLUnaligned(__ubuf__ T *dstUb, __ubuf__ T *srcUb, uint16_t size0,
uint16_t size1, uint16_t size2, uint16_t size3, uint16_t srcStride0, uint16_t srcStride1, uint16_t srcStride2)
{
constexpr uint16_t VF_LEN = GetVecLen() / sizeof(T);
// Full vectors per row, plus the remaining elements.
uint16_t factor = size3 / VF_LEN;
uint32_t size3tail = size3 - factor * VF_LEN;
Reg::RegTensor<T> srcReg;
Reg::UnalignReg ureg0;
for (uint16_t i = 0; i < size0; ++i) {
for (uint16_t j = 0; j < size1; ++j) {
for (uint16_t k = 0; k < size2; ++k) {
BrcLoad<T>(srcReg, srcUb + i * srcStride0 + j * srcStride1 + k * srcStride2);
for (uint16_t t = 0; t < factor; ++t) {
Reg::StoreUnAlign(dstUb, srcReg, ureg0, VF_LEN);
}
Reg::StoreUnAlign(dstUb, srcReg, ureg0, size3tail);
}
}
}
// Flush whatever the unaligned store stream still buffers.
Reg::StoreUnAlignPost(dstUb, ureg0, 0);
}
// Non-last-axis broadcast of a 2-D output [size0, size1], where size1 may exceed VF_LEN.
// The source row srcUb[0, size1) is appended size0 times to the unaligned store stream:
//   - its full VF_LEN chunks are re-loaded (aligned) for every output row;
//   - the tail chunk is loaded once up front and appended as size1tail elements.
// NOTE(review): the tail load at srcUb + factor * VF_LEN happens even when size1tail == 0.
// Change from the original: the unused scalar `sreg` is removed.
template <typename T>
__simd_vf__ inline void BrcNlastLargerThanVLUnaligned(
    __ubuf__ T *dstUb, __ubuf__ T *srcUb, uint16_t size0, uint16_t size1)
{
    constexpr uint16_t VF_LEN = GetVecLen() / sizeof(T);
    // Full vectors per row, plus the remaining elements.
    uint16_t factor = size1 / VF_LEN;
    uint32_t size1tail = size1 - factor * VF_LEN;
    Reg::RegTensor<T> srcReg;
    // Holds the row's trailing partial chunk for all output rows.
    Reg::RegTensor<T> tmpReg;
    Reg::UnalignReg ureg0;
    Reg::LoadAlign(tmpReg, srcUb + factor * VF_LEN);
    for (uint16_t i = 0; i < size0; ++i) {
        for (uint16_t j = 0; j < factor; ++j) {
            Reg::LoadAlign(srcReg, srcUb + j * VF_LEN);
            Reg::StoreUnAlign(dstUb, srcReg, ureg0, VF_LEN);
        }
        Reg::StoreUnAlign(dstUb, tmpReg, ureg0, size1tail);
    }
    // Flush whatever the unaligned store stream still buffers.
    Reg::StoreUnAlignPost(dstUb, ureg0, 0);
}
// Non-last-axis broadcast of a 3-D output [size0, size1, size2], where size2 may exceed VF_LEN.
// For each (i, j), the source row at srcUb + i * srcStride0 + j * srcStride1 is streamed with
// unaligned loads (ureg0) and appended to the unaligned store stream (ureg1):
// `factor` full vectors, then size2tail elements.
// LoadUnAlign evidently advances tmpSrcUb, since no offset is added between calls.
template <typename T>
__simd_vf__ inline void BrcNlastLargerThanVLUnaligned(__ubuf__ T *dstUb, __ubuf__ T *srcUb, uint16_t size0,
uint16_t size1, uint16_t size2, uint16_t srcStride0, uint16_t srcStride1)
{
constexpr uint16_t VF_LEN = GetVecLen() / sizeof(T);
// Full vectors per row, plus the remaining elements.
uint16_t factor = size2 / VF_LEN;
uint32_t size2tail = size2 - factor * VF_LEN;
Reg::RegTensor<T> srcReg;
Reg::UnalignReg ureg0, ureg1;
uint32_t sreg = size2tail;
for (uint16_t i = 0; i < size0; ++i) {
for (uint16_t j = 0; j < size1; ++j) {
__ubuf__ T *tmpSrcUb = srcUb + i * srcStride0 + j * srcStride1;
// Prime the unaligned-load register for this source row.
Reg::LoadUnAlignPre(ureg0, tmpSrcUb);
for (uint16_t k = 0; k < factor; ++k) {
Reg::LoadUnAlign(srcReg, ureg0, tmpSrcUb, VF_LEN);
Reg::StoreUnAlign(dstUb, srcReg, ureg1, VF_LEN);
}
Reg::LoadUnAlign(srcReg, ureg0, tmpSrcUb, sreg);
Reg::StoreUnAlign(dstUb, srcReg, ureg1, sreg);
}
}
// Flush whatever the unaligned store stream still buffers.
Reg::StoreUnAlignPost(dstUb, ureg1, 0);
}
// Non-last-axis broadcast of a 4-D output [size0, size1, size2, size3], where size3 may exceed VF_LEN.
// For each (i, j, k), the source row at srcUb + i * srcStride0 + j * srcStride1 + k * srcStride2
// is streamed with unaligned loads (ureg0) and appended to the unaligned store stream (ureg1):
// `factor` full vectors, then size3tail elements.
template <typename T>
__simd_vf__ inline void BrcNlastLargerThanVLUnaligned(__ubuf__ T *dstUb, __ubuf__ T *srcUb, uint16_t size0,
uint16_t size1, uint16_t size2, uint16_t size3, uint16_t srcStride0, uint16_t srcStride1, uint16_t srcStride2)
{
constexpr uint16_t VF_LEN = GetVecLen() / sizeof(T);
// Full vectors per row, plus the remaining elements.
uint16_t factor = size3 / VF_LEN;
uint32_t size3tail = size3 - factor * VF_LEN;
Reg::RegTensor<T> srcReg;
Reg::UnalignReg ureg0, ureg1;
uint32_t sreg = size3tail;
for (uint16_t i = 0; i < size0; ++i) {
for (uint16_t j = 0; j < size1; ++j) {
for (uint16_t k = 0; k < size2; ++k) {
__ubuf__ T *tmpSrcUb = srcUb + i * srcStride0 + j * srcStride1 + k * srcStride2;
// Prime the unaligned-load register for this source row.
Reg::LoadUnAlignPre(ureg0, tmpSrcUb);
for (uint16_t t = 0; t < factor; ++t) {
Reg::LoadUnAlign(srcReg, ureg0, tmpSrcUb, VF_LEN);
Reg::StoreUnAlign(dstUb, srcReg, ureg1, VF_LEN);
}
Reg::LoadUnAlign(srcReg, ureg0, tmpSrcUb, sreg);
Reg::StoreUnAlign(dstUb, srcReg, ureg1, sreg);
}
}
}
// Flush whatever the unaligned store stream still buffers.
Reg::StoreUnAlignPost(dstUb, ureg1, 0);
}
// Dispatches a last-axis broadcast of a 2-D output [dstShape[0], dstShape[1]] to a kernel.
// Output row i is filled with srcUb[i]; the choice depends on the row length size1 = dstShape[1]:
//   - size1 == one data block of T, and T not 8-bit: E2B path (BrcLastE2B).
//   - size1 < VF_LEN_HALF: gather path. Indices (offset + lane) / size1 are generated into
//     a stack buffer, then a single-register gather (BrcLastGatherOne) is used if the whole
//     output fits in one register, otherwise the chunked gather (BrcLastGatherTwo).
//   - size1 <= VF_LEN: per-row broadcast with unaligned stores.
//   - size1 > VF_LEN, block-aligned: aligned per-row broadcast.
//   - size1 > VF_LEN, not block-aligned: the unaligned kernel, used only for a compile-time
//     rank.
// Returns true only when that last kernel is needed but constRank == -1. Nothing is written
// in that case; the caller, not visible here, presumably falls back to another path.
// Shape values are narrowed to uint16_t.
// Change from the original: the inner `else if (sizeI[1] < VF_LEN)` was always true inside
// the `sizeI[1] < VF_LEN_HALF` branch and is now a plain `else`; the nested else is flattened.
template <typename T, int32_t constRank = -1>
__aicore__ inline bool BrcLastWrapperForTwoDim(
    __ubuf__ T *dstUb, __ubuf__ T *srcUb, const uint32_t *dstShape)
{
    constexpr uint16_t VF_LEN = GetVecLen() / sizeof(T);
    constexpr uint32_t VF_LEN_HALF = GetVecLen() / 2 / sizeof(T);
    constexpr uint32_t oneBlockElementNum = GetDataBlockSizeInBytes() / sizeof(T);
    using GatherIndexType = typename ExtractSignedTypeBySize<sizeof(T)>::T;
    using BrcIndexType = typename ExtractIndexTypeBySize<sizeof(T)>::T;
    uint16_t sizeI[2];
    sizeI[0] = static_cast<uint16_t>(dstShape[0]);
    sizeI[1] = static_cast<uint16_t>(dstShape[1]);
    if (sizeI[1] == oneBlockElementNum && sizeof(T) != sizeof(uint8_t)) {
        BrcLastE2B(dstUb, srcUb, sizeI[0], sizeI[1]);
    } else if (sizeI[1] < VF_LEN_HALF) {
        LocalTensor<T> indexLocal;
        PopStackBuffer<T, TPosition::LCM>(indexLocal);
        __ubuf__ GatherIndexType *indexUb1 = (__ubuf__ GatherIndexType *)indexLocal.GetPhyAddr();
        __ubuf__ GatherIndexType *indexUb2 = (__ubuf__ GatherIndexType *)indexLocal.GetPhyAddr(VF_LEN);
        GenLastGatherIndex<GatherIndexType>(indexUb1, sizeI[1], 0);
        if constexpr (sizeof(T) == sizeof(uint8_t)) {
            // 8-bit T needs a second index register starting at lane VF_LEN_HALF.
            GenLastGatherIndex<GatherIndexType>(indexUb2, sizeI[1], VF_LEN_HALF);
        }
        __ubuf__ BrcIndexType *indexUb = (__ubuf__ BrcIndexType *)indexLocal.GetPhyAddr();
        if (sizeI[0] * sizeI[1] < VF_LEN) {
            BrcLastGatherOne<T, BrcIndexType>(dstUb, srcUb, indexUb, sizeI[0], sizeI[1]);
        } else {
            // sizeI[1] < VF_LEN_HALF < VF_LEN holds here, so the chunked gather always applies.
            BrcLastGatherTwo<T, BrcIndexType>(dstUb, srcUb, indexUb, sizeI[0], sizeI[1]);
        }
    } else if (sizeI[1] <= VF_LEN) {
        BrcLastLessThanVLUnaligned<T>(dstUb, srcUb, sizeI[0], sizeI[1]);
    } else if (sizeI[1] % oneBlockElementNum == 0) {
        BrcLastLargerThanVLAligned<T>(dstUb, srcUb, sizeI[0], sizeI[1]);
    } else {
        if constexpr (constRank == -1) {
            return true;
        } else {
            BrcLastLargerThanVLUnaligned<T>(dstUb, srcUb, sizeI[0], sizeI[1]);
        }
    }
    return false;
}
// Dispatches a last-axis broadcast of a 3-D output [dstShape[0], dstShape[1], dstShape[2]]
// with source strides srcStride[0..2]. Branches are tested in this order:
//   1. E2B path. Requires: size2 is one data block of T; T is not 8-bit;
//      size1 * size2 > VF_LEN_HALF; size1 % DEFAULT_BLK_NUM == 0; stride1 != 0.
//      BrcLastE2BLargerThanVL is used when size1 * size2 > VF_LEN,
//      otherwise BrcLastE2BLessThanVL.
//   2. size2 < VF_LEN_HALF and T is not 8-bit: GatherWrapper.
//   3. size2 <= VF_LEN: aligned per-row broadcast when size2 is block-aligned,
//      otherwise the unaligned kernel.
//   4. size2 > VF_LEN: the same aligned/unaligned split for rows longer than a register.
// The unaligned kernels are used only for a compile-time rank. With constRank == -1 the
// function returns true and writes nothing; the caller, not visible here, presumably falls
// back to another path. Returns false otherwise.
// Shape and stride values are narrowed to uint16_t.
template <typename T, int32_t constRank = -1>
__aicore__ inline bool BrcLastWrapperForThreeDim(__ubuf__ T *dstUb, __ubuf__ T *srcUb,
const uint32_t *dstShape, const uint32_t *srcStride)
{
constexpr uint16_t VF_LEN = GetVecLen() / sizeof(T);
constexpr uint32_t VF_LEN_HALF = GetVecLen() / 2 / sizeof(T);
constexpr uint32_t oneBlockElementNum = GetDataBlockSizeInBytes() / sizeof(T);
uint16_t sizeI[3];
uint16_t stride[3];
sizeI[0] = static_cast<uint16_t>(dstShape[0]);
sizeI[1] = static_cast<uint16_t>(dstShape[1]);
sizeI[2] = static_cast<uint16_t>(dstShape[2]);
stride[0] = static_cast<uint16_t>(srcStride[0]);
stride[1] = static_cast<uint16_t>(srcStride[1]);
stride[2] = static_cast<uint16_t>(srcStride[2]);
// The E2B kernels read DEFAULT_BLK_NUM consecutive source elements along axis 1, so axis 1
// must not be broadcast (stride1 != 0).
if (sizeI[2] == oneBlockElementNum && sizeof(T) != sizeof(uint8_t) && sizeI[1] * sizeI[2] > VF_LEN_HALF &&
sizeI[1] % DEFAULT_BLK_NUM == 0 && stride[1] != 0) {
if (sizeI[1] * sizeI[2] > VF_LEN) {
BrcLastE2BLargerThanVL(dstUb, srcUb, sizeI[0], sizeI[1], sizeI[2], stride[0]);
} else {
BrcLastE2BLessThanVL(dstUb, srcUb, sizeI[0], sizeI[1], sizeI[2], stride[0]);
}
} else if (sizeI[2] < VF_LEN_HALF && sizeof(T) != sizeof(uint8_t)) {
// GatherWrapper takes non-const arrays, so copy the shape and strides.
uint32_t newDstShape[3] = {dstShape[0], dstShape[1], dstShape[2]};
uint32_t newSrcStride[3] = {srcStride[0], srcStride[1], srcStride[2]};
GatherWrapper(dstUb, srcUb, newDstShape, newSrcStride);
} else if (sizeI[2] <= VF_LEN) {
if (sizeI[2] % oneBlockElementNum == 0) {
BrcLastLessThanVLAligned<T>(dstUb, srcUb, sizeI[0], sizeI[1], sizeI[2], stride[0], stride[1]);
} else {
if constexpr (constRank == -1) {
return true;
} else {
BrcLastLessThanVLUnaligned<T>(dstUb, srcUb, sizeI[0], sizeI[1], sizeI[2], stride[0], stride[1]);
}
}
} else {
if (sizeI[2] % oneBlockElementNum == 0) {
BrcLastLargerThanVLAligned<T>(dstUb, srcUb, sizeI[0], sizeI[1], sizeI[2], stride[0], stride[1]);
} else {
if constexpr (constRank == -1) {
return true;
} else {
BrcLastLargerThanVLUnaligned<T>(dstUb, srcUb, sizeI[0], sizeI[1], sizeI[2], stride[0], stride[1]);
}
}
}
return false;
}
// Broadcast along the last axis of a 4-D destination [dstShape[0..3]], reading the source
// through srcStride. Dispatch is driven by the innermost length dstShape[3] (elements of T).
// Returns true only when the innermost axis is not a multiple of a data block, no E2B/gather path
// applies, and the rank is dynamic (constRank == -1); that case is left to the caller.
// Returns false after dispatching to a kernel.
template <typename T, int32_t constRank = -1>
__aicore__ inline bool BrcLastWrapperForFourDim(__ubuf__ T *dstUb, __ubuf__ T *srcUb,
    const uint32_t *dstShape, const uint32_t *srcStride)
{
    constexpr uint16_t VF_LEN = GetVecLen() / sizeof(T);
    constexpr uint32_t VF_LEN_HALF = GetVecLen() / 2 / sizeof(T);
    constexpr uint32_t oneBlockElementNum = GetDataBlockSizeInBytes() / sizeof(T);
    constexpr bool isByteType = (sizeof(T) == sizeof(uint8_t));

    uint16_t dims[4];
    uint16_t strides[4];
    for (int32_t axis = 0; axis < 4; ++axis) {
        dims[axis] = static_cast<uint16_t>(dstShape[axis]);
        strides[axis] = static_cast<uint16_t>(srcStride[axis]);
    }

    // Element-to-block path: innermost axis is one data block, axis 2 has a non-zero source
    // stride and is a multiple of DEFAULT_BLK_NUM.
    const bool e2bEligible = !isByteType && dims[3] == oneBlockElementNum &&
        strides[2] != 0 && dims[2] % DEFAULT_BLK_NUM == 0;
    if (e2bEligible) {
        BrcLastE2B(dstUb, srcUb, dims[0], dims[1], dims[2], dims[3], strides[0], strides[1]);
        return false;
    }

    // Short innermost axis (not for 1-byte types): generic 4-D gather.
    if (!isByteType && dims[3] < VF_LEN_HALF) {
        uint32_t gatherShape[4] = {dstShape[0], dstShape[1], dstShape[2], dstShape[3]};
        uint32_t gatherStride[4] = {srcStride[0], srcStride[1], srcStride[2], srcStride[3]};
        GatherWrapperForFourDim(dstUb, srcUb, gatherShape, gatherStride);
        return false;
    }

    const bool blockAligned = (dims[3] % oneBlockElementNum == 0);

    // Innermost axis fits in one vector.
    if (dims[3] <= VF_LEN) {
        if (blockAligned) {
            BrcLastLessThanVLAligned<T>(
                dstUb, srcUb, dims[0], dims[1], dims[2], dims[3], strides[0], strides[1], strides[2]);
            return false;
        }
        if constexpr (constRank == -1) {
            return true;
        } else {
            BrcLastLessThanVLUnaligned<T>(
                dstUb, srcUb, dims[0], dims[1], dims[2], dims[3], strides[0], strides[1], strides[2]);
        }
        return false;
    }

    // Innermost axis longer than one vector.
    if (blockAligned) {
        BrcLastLargerThanVLAligned<T>(
            dstUb, srcUb, dims[0], dims[1], dims[2], dims[3], strides[0], strides[1], strides[2]);
        return false;
    }
    if constexpr (constRank == -1) {
        return true;
    } else {
        BrcLastLargerThanVLUnaligned<T>(
            dstUb, srcUb, dims[0], dims[1], dims[2], dims[3], strides[0], strides[1], strides[2]);
    }
    return false;
}
// Broadcast along a non-last axis of a 2-D destination [dstShape[0], dstShape[1]], dispatching on
// the last-axis length dstShape[1] (elements of T) relative to one data block, half a vector
// (VF_LEN_HALF) and one vector (VF_LEN).
// Returns true only when dstShape[1] > VF_LEN, is not a multiple of a data block, and the rank is
// dynamic (constRank == -1); that case is not handled here and is left to the caller.
// Returns false after dispatching to a kernel.
template <typename T, int32_t constRank = -1>
__aicore__ inline bool BrcNlastWrapperForTwoDim(
__ubuf__ T *dstUb, __ubuf__ T *srcUb, const uint32_t *dstShape)
{
constexpr uint16_t VF_LEN = GetVecLen() / sizeof(T);
constexpr uint32_t VF_LEN_HALF = GetVecLen() / 2 / sizeof(T);
constexpr uint32_t oneBlockElementNum = GetDataBlockSizeInBytes() / sizeof(T);
using GatherIndexType = typename ExtractSignedTypeBySize<sizeof(T)>::T;
using BrcIndexType = typename ExtractIndexTypeBySize<sizeof(T)>::T;
uint16_t sizeI[2];
sizeI[0] = static_cast<uint16_t>(dstShape[0]);
sizeI[1] = static_cast<uint16_t>(dstShape[1]);
if (sizeI[1] < VF_LEN_HALF) {
// Last axis shorter than half a vector: gather through an index table in a stack buffer.
LocalTensor<T> indexLocal;
PopStackBuffer<T, TPosition::LCM>(indexLocal);
if (sizeI[1] % oneBlockElementNum == 0) {
// Block-aligned last axis: fill an 8-entry uint32 offset table with scalar stores.
// The stores are bracketed by V_S and S_V flag waits (presumably to order them against
// vector-pipe work that precedes / consumes the table).
event_t eventIdVToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
SetFlag<HardEvent::V_S>(eventIdVToS);
WaitFlag<HardEvent::V_S>(eventIdVToS);
__ubuf__ uint32_t *indexUb = (__ubuf__ uint32_t *)indexLocal.GetPhyAddr();
// Each pattern repeats the offsets of the row's data blocks (1, 2 or 3 blocks per row).
// Offsets step by 32 — presumably byte offsets of 32-byte data blocks; TODO confirm against
// BrcNlastGatherB*. NOTE(review): assumes sizeI[1] / oneBlockElementNum <= 3 here.
if (sizeI[1] / oneBlockElementNum == 1) {
indexUb[0] = 0;
indexUb[1] = 0;
indexUb[2] = 0;
indexUb[3] = 0;
indexUb[4] = 0;
indexUb[5] = 0;
indexUb[6] = 0;
indexUb[7] = 0;
} else if (sizeI[1] / oneBlockElementNum == 2) {
indexUb[0] = 0;
indexUb[1] = 32;
indexUb[2] = 0;
indexUb[3] = 32;
indexUb[4] = 0;
indexUb[5] = 32;
indexUb[6] = 0;
indexUb[7] = 32;
} else {
// Three blocks per row (also reached when sizeI[1] == 0): two full repetitions, and the
// last two entries point at offset 0.
indexUb[0] = 0;
indexUb[1] = 32;
indexUb[2] = 64;
indexUb[3] = 0;
indexUb[4] = 32;
indexUb[5] = 64;
indexUb[6] = 0;
indexUb[7] = 0;
}
event_t eventIdSToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
SetFlag<HardEvent::S_V>(eventIdSToV);
WaitFlag<HardEvent::S_V>(eventIdSToV);
// The else-if below is always taken when the first test fails, because sizeI[1] < VF_LEN_HALF.
if (sizeI[0] * sizeI[1] < VF_LEN) {
BrcNlastGatherBOne<T>(dstUb, srcUb, (__ubuf__ uint32_t *)indexUb, sizeI[0], sizeI[1]);
} else if (sizeI[1] < VF_LEN) {
BrcNlastGatherBTwo<T>(dstUb, srcUb, (__ubuf__ uint32_t *)indexUb, sizeI[0], sizeI[1]);
}
} else {
// Unaligned last axis: generate element gather indices; 1-byte types also fill a second
// table at element offset VF_LEN, starting at VF_LEN_HALF.
__ubuf__ GatherIndexType *indexUb1 = (__ubuf__ GatherIndexType *)indexLocal.GetPhyAddr();
__ubuf__ GatherIndexType *indexUb2 = (__ubuf__ GatherIndexType *)indexLocal.GetPhyAddr(VF_LEN);
GenNlastGatherIndex<GatherIndexType>(indexUb1, sizeI[1], 0);
if constexpr (sizeof(T) == sizeof(uint8_t)) {
GenNlastGatherIndex<GatherIndexType>(indexUb2, sizeI[1], VF_LEN_HALF);
}
__ubuf__ BrcIndexType *indexUb = (__ubuf__ BrcIndexType *)indexLocal.GetPhyAddr();
// As above, the else-if is always true when reached.
if (sizeI[0] * sizeI[1] < VF_LEN) {
BrcNlastGatherOne<T, BrcIndexType>(dstUb, srcUb, indexUb, sizeI[0], sizeI[1]);
} else if (sizeI[1] < VF_LEN) {
BrcNlastGatherTwo<T, BrcIndexType>(dstUb, srcUb, indexUb, sizeI[0], sizeI[1]);
}
}
} else if (sizeI[1] <= VF_LEN) {
// Last axis fits in one vector.
BrcNlastLessThanVLUnaligned<T>(dstUb, srcUb, sizeI[0], sizeI[1]);
} else {
// Last axis longer than one vector.
if (sizeI[1] % oneBlockElementNum == 0) {
// Whole-vector variant only when the axis is a multiple of VF_LEN and there are more than
// DEFAULT_BLK_NUM rows; otherwise the block-granular variant.
if (sizeI[1] % VF_LEN == 0 && sizeI[0] > DEFAULT_BLK_NUM) {
BrcNlastLargerThanVLAlignedWithVL<T>(dstUb, srcUb, sizeI[0], sizeI[1]);
} else {
BrcNlastLargerThanVLAlignedWithBlock<T>(dstUb, srcUb, sizeI[0], sizeI[1]);
}
} else {
// Unaligned long axis: unsupported for dynamic rank, reported to the caller via `true`.
if constexpr (constRank == -1) {
return true;
} else {
BrcNlastLargerThanVLUnaligned<T>(dstUb, srcUb, sizeI[0], sizeI[1]);
}
}
}
return false;
}
// Broadcast along non-last axes of a 3-D destination [dstShape[0..2]], reading the source through
// srcStride. Dispatch is driven by the innermost length dstShape[2] (elements of T).
// An innermost axis of at most one vector is always handled here, aligned or not.
// Returns true only when the innermost axis is longer than VF_LEN, is not a multiple of a data
// block, and the rank is dynamic (constRank == -1); that case is left to the caller.
// Returns false after dispatching to a kernel.
template <typename T, int32_t constRank = -1>
__aicore__ inline bool BrcNlastWrapperForThreeDim(__ubuf__ T *dstUb, __ubuf__ T *srcUb,
    const uint32_t *dstShape, const uint32_t *srcStride)
{
    constexpr uint16_t VF_LEN = GetVecLen() / sizeof(T);
    constexpr uint32_t VF_LEN_HALF = GetVecLen() / 2 / sizeof(T);
    constexpr uint32_t oneBlockElementNum = GetDataBlockSizeInBytes() / sizeof(T);
    constexpr bool isByteType = (sizeof(T) == sizeof(uint8_t));

    uint16_t dims[3];
    uint16_t strides[3];
    for (int32_t axis = 0; axis < 3; ++axis) {
        dims[axis] = static_cast<uint16_t>(dstShape[axis]);
        strides[axis] = static_cast<uint16_t>(srcStride[axis]);
    }

    // Short innermost axis (not for 1-byte types): generic gather on the full shape and stride.
    if (!isByteType && dims[2] < VF_LEN_HALF) {
        uint32_t gatherShape[3] = {dstShape[0], dstShape[1], dstShape[2]};
        uint32_t gatherStride[3] = {srcStride[0], srcStride[1], srcStride[2]};
        GatherWrapper(dstUb, srcUb, gatherShape, gatherStride);
        return false;
    }

    // Innermost axis fits in one vector.
    if (dims[2] <= VF_LEN) {
        BrcNlastLessThanVLUnaligned<T>(dstUb, srcUb, dims[0], dims[1], dims[2], strides[0], strides[1]);
        return false;
    }

    // Innermost axis longer than one vector and not block aligned.
    if (dims[2] % oneBlockElementNum != 0) {
        if constexpr (constRank == -1) {
            return true;
        } else {
            BrcNlastLargerThanVLUnaligned<T>(dstUb, srcUb, dims[0], dims[1], dims[2], strides[0], strides[1]);
        }
        return false;
    }

    // Block aligned: whole-vector variant when the axis is a multiple of VF_LEN.
    if (dims[2] % VF_LEN == 0) {
        BrcNlastLargerThanVLAlignedWithVL<T>(dstUb, srcUb, dims[0], dims[1], dims[2], strides[0], strides[1]);
    } else {
        BrcNlastLargerThanVLAlignedWithBlock<T>(
            dstUb, srcUb, dims[0], dims[1], dims[2], strides[0], strides[1]);
    }
    return false;
}
// Broadcast along non-last axes of a 4-D destination [dstShape[0..3]], reading the source through
// srcStride. Dispatch is driven by the innermost length dstShape[3] (elements of T).
// Returns true only when the innermost axis is not a multiple of a data block, the gather path
// does not apply, and the rank is dynamic (constRank == -1); that case is left to the caller.
// Returns false after dispatching to a kernel.
template <typename T, int32_t constRank = -1>
__aicore__ inline bool BrcNlastWrapperForFourDim(__ubuf__ T *dstUb, __ubuf__ T *srcUb,
    const uint32_t *dstShape, const uint32_t *srcStride)
{
    constexpr uint16_t VF_LEN = GetVecLen() / sizeof(T);
    constexpr uint32_t VF_LEN_HALF = GetVecLen() / 2 / sizeof(T);
    constexpr uint32_t oneBlockElementNum = GetDataBlockSizeInBytes() / sizeof(T);
    constexpr bool isByteType = (sizeof(T) == sizeof(uint8_t));

    uint16_t dims[4];
    uint16_t strides[4];
    for (int32_t axis = 0; axis < 4; ++axis) {
        dims[axis] = static_cast<uint16_t>(dstShape[axis]);
        strides[axis] = static_cast<uint16_t>(srcStride[axis]);
    }

    // Short innermost axis (not for 1-byte types): generic 4-D gather.
    if (!isByteType && dims[3] < VF_LEN_HALF) {
        uint32_t gatherShape[4] = {dstShape[0], dstShape[1], dstShape[2], dstShape[3]};
        uint32_t gatherStride[4] = {srcStride[0], srcStride[1], srcStride[2], srcStride[3]};
        GatherWrapperForFourDim(dstUb, srcUb, gatherShape, gatherStride);
        return false;
    }

    const bool blockAligned = (dims[3] % oneBlockElementNum == 0);

    // Innermost axis fits in one vector.
    if (dims[3] <= VF_LEN) {
        if (blockAligned) {
            BrcNlastLessThanVLAligned<T>(
                dstUb, srcUb, dims[0], dims[1], dims[2], dims[3], strides[0], strides[1], strides[2]);
            return false;
        }
        if constexpr (constRank == -1) {
            return true;
        } else {
            BrcNlastLessThanVLUnaligned<T>(
                dstUb, srcUb, dims[0], dims[1], dims[2], dims[3], strides[0], strides[1], strides[2]);
        }
        return false;
    }

    // Innermost axis longer than one vector.
    if (blockAligned) {
        // Whole-vector variant when the axis is a multiple of VF_LEN, block-granular otherwise.
        if (dims[3] % VF_LEN == 0) {
            BrcNlastLargerThanVLAlignedWithVL<T>(
                dstUb, srcUb, dims[0], dims[1], dims[2], dims[3], strides[0], strides[1], strides[2]);
        } else {
            BrcNlastLargerThanVLAlignedWithBlock<T>(
                dstUb, srcUb, dims[0], dims[1], dims[2], dims[3], strides[0], strides[1], strides[2]);
        }
        return false;
    }
    if constexpr (constRank == -1) {
        return true;
    } else {
        BrcNlastLargerThanVLUnaligned<T>(
            dstUb, srcUb, dims[0], dims[1], dims[2], dims[3], strides[0], strides[1], strides[2]);
    }
    return false;
}
// Non-last-axis broadcast over a destination described by dstShape[0..4].
// Peels the outermost axis: for each index p < dstShape[0], dst/src are offset by
// p * dstStride[0] / p * srcStride[0], and the remaining four axes (dstShape[1..4],
// srcStride[1..4]) are handed to GatherWrapperForFourDim.
// NOTE(review): the gather path is used regardless of the innermost length — confirm this is
// intended for every caller.
template <typename T>
__aicore__ inline void BrcNlastWrapperForMoreDim(__ubuf__ T *dstUb, __ubuf__ T *srcUb,
    const uint32_t *dstShape, const uint32_t *dstStride, const uint32_t *srcStride)
{
    __ubuf__ T *srcUbTmp = srcUb;
    __ubuf__ T *dstUbTmp = dstUb;
    for (uint16_t p = 0; p < static_cast<uint16_t>(dstShape[0]); ++p) {
        dstUb = dstUbTmp + p * dstStride[0];
        srcUb = srcUbTmp + p * srcStride[0];
        // Rebuilt for every call in case the callee modifies its array arguments
        // (its signature is not visible here).
        uint32_t newDstShape[4] = {dstShape[1], dstShape[2], dstShape[3], dstShape[4]};
        uint32_t newSrcStride[4] = {srcStride[1], srcStride[2], srcStride[3], srcStride[4]};
        GatherWrapperForFourDim(dstUb, srcUb, newDstShape, newSrcStride);
    }
}
// Non-last-axis broadcast for a runtime rank `dim`.
// The innermost (up to) four axes are treated as one 4-D problem: when dim < 4 they are
// left-padded with size-1 axes and zero strides. Up to five further outer axes
// (dim - 9 .. dim - 5) are iterated explicitly, offsetting dst/src by those axes' dstStride /
// srcStride entries; axes that do not exist run once with a zero offset.
// NOTE(review): assumes dim <= 9 (axes before dim - 9 are not iterated) — confirm with callers.
// dstShape, dstStride and srcStride are only read at indices < dim.
template <typename T>
__aicore__ inline void BrcNlastWrapperForMoreDimDynamicShape(__ubuf__ T *dstUb, __ubuf__ T *srcUb,
    const uint32_t dim, const uint32_t *dstShape, const uint32_t *dstStride, const uint32_t *srcStride)
{
    constexpr uint16_t VF_LEN = GetVecLen() / sizeof(T);
    constexpr uint16_t VF_LEN_HALF = GetVecLen() / 2 / sizeof(T);
    constexpr uint32_t oneBlockElementNum = GetDataBlockSizeInBytes() / sizeof(T);
    constexpr uint32_t INNER_DIMS = 4;
    constexpr int32_t OUTER_LOOPS = 5;
    constexpr int32_t MAX_DIMS = 9;

    // Innermost (up to) four axes, left-padded so they always form a 4-D problem. Building these
    // from the valid axes only avoids indexing with a wrapped-around unsigned `dim - 4` when
    // dim < 4 (which previously read far out of bounds in the gather branch).
    const uint32_t tailDims = (dim > INNER_DIMS) ? INNER_DIMS : dim;
    uint32_t tailShape[INNER_DIMS] = {1, 1, 1, 1};
    uint32_t tailStride[INNER_DIMS] = {0, 0, 0, 0};
    for (uint32_t i = 0; i < tailDims; ++i) {
        tailShape[INNER_DIMS - tailDims + i] = dstShape[dim - tailDims + i];
        tailStride[INNER_DIMS - tailDims + i] = srcStride[dim - tailDims + i];
    }
    uint16_t sizeI[INNER_DIMS];
    uint16_t stride[INNER_DIMS];
    for (uint32_t i = 0; i < INNER_DIMS; ++i) {
        sizeI[i] = static_cast<uint16_t>(tailShape[i]);
        stride[i] = static_cast<uint16_t>(tailStride[i]);
    }

    // Outer loop slot s covers axis (dim - 9 + s). Missing axes keep a trip count of 1 and a zero
    // stride, so strides of non-existent axes are never read.
    uint16_t loops[OUTER_LOOPS] = {1, 1, 1, 1, 1};
    uint32_t outerDstStride[OUTER_LOOPS] = {0, 0, 0, 0, 0};
    uint32_t outerSrcStride[OUTER_LOOPS] = {0, 0, 0, 0, 0};
    for (int32_t slot = 0; slot < OUTER_LOOPS; ++slot) {
        const int32_t axis = static_cast<int32_t>(dim) - MAX_DIMS + slot;
        if (axis >= 0) {
            loops[slot] = static_cast<uint16_t>(dstShape[axis]);
            outerDstStride[slot] = dstStride[axis];
            outerSrcStride[slot] = srcStride[axis];
        }
    }

    __ubuf__ T *srcUbTmp = srcUb;
    __ubuf__ T *dstUbTmp = dstUb;
    for (uint16_t i = 0; i < loops[0]; ++i) {
        for (uint16_t j = 0; j < loops[1]; ++j) {
            for (uint16_t k = 0; k < loops[2]; ++k) {
                for (uint16_t t = 0; t < loops[3]; ++t) {
                    for (uint16_t p = 0; p < loops[4]; ++p) {
                        dstUb = dstUbTmp + p * outerDstStride[4] + t * outerDstStride[3] +
                            k * outerDstStride[2] + j * outerDstStride[1] + i * outerDstStride[0];
                        srcUb = srcUbTmp + p * outerSrcStride[4] + t * outerSrcStride[3] +
                            k * outerSrcStride[2] + j * outerSrcStride[1] + i * outerSrcStride[0];
                        if (sizeI[3] < VF_LEN_HALF && sizeof(T) != sizeof(uint8_t)) {
                            // Rebuilt for every call in case the callee modifies its array arguments.
                            uint32_t newDstShape[INNER_DIMS] = {
                                tailShape[0], tailShape[1], tailShape[2], tailShape[3]};
                            uint32_t newSrcStride[INNER_DIMS] = {
                                tailStride[0], tailStride[1], tailStride[2], tailStride[3]};
                            GatherWrapperForFourDim(dstUb, srcUb, newDstShape, newSrcStride);
                        } else if (sizeI[3] <= VF_LEN) {
                            if (sizeI[3] % oneBlockElementNum == 0) {
                                BrcNlastLessThanVLAligned<T>(dstUb, srcUb, sizeI[0], sizeI[1], sizeI[2],
                                    sizeI[3], stride[0], stride[1], stride[2]);
                            } else {
                                BrcNlastLessThanVLUnaligned<T>(dstUb, srcUb, sizeI[0], sizeI[1], sizeI[2],
                                    sizeI[3], stride[0], stride[1], stride[2]);
                            }
                        } else {
                            if (sizeI[3] % oneBlockElementNum == 0) {
                                BrcNlastLargerThanVLAlignedWithBlock<T>(dstUb, srcUb, sizeI[0], sizeI[1],
                                    sizeI[2], sizeI[3], stride[0], stride[1], stride[2]);
                            } else {
                                BrcNlastLargerThanVLUnaligned<T>(dstUb, srcUb, sizeI[0], sizeI[1], sizeI[2],
                                    sizeI[3], stride[0], stride[1], stride[2]);
                            }
                        }
                    }
                }
            }
        }
    }
}
// Last-axis broadcast for a runtime rank `dim`.
// The innermost (up to) four axes are treated as one 4-D problem: when dim < 4 they are
// left-padded with size-1 axes and zero strides. They are always dispatched to the unaligned
// kernels: the less-than-VL kernel when the innermost axis is at most VF_LEN elements, the
// larger-than-VL kernel otherwise. Up to five further outer axes (dim - 9 .. dim - 5) are
// iterated explicitly, offsetting dst/src by those axes' dstStride / srcStride entries; axes
// that do not exist run once with a zero offset.
// NOTE(review): assumes dim <= 9 (axes before dim - 9 are not iterated) — confirm with callers.
// dstShape, dstStride and srcStride are only read at indices < dim.
template <typename T>
__aicore__ inline void BrcLastWrapperForMoreDimDynamicShape(__ubuf__ T *dstUb, __ubuf__ T *srcUb,
    const uint32_t dim, const uint32_t *dstShape, const uint32_t *dstStride, const uint32_t *srcStride)
{
    constexpr uint16_t VF_LEN = GetVecLen() / sizeof(T);
    constexpr uint32_t INNER_DIMS = 4;
    constexpr int32_t OUTER_LOOPS = 5;
    constexpr int32_t MAX_DIMS = 9;

    // Innermost (up to) four axes, left-padded with size-1 axes / zero strides when dim < 4.
    uint16_t sizeI[INNER_DIMS] = {1, 1, 1, 1};
    uint16_t stride[INNER_DIMS] = {0, 0, 0, 0};
    const uint32_t tailDims = (dim > INNER_DIMS) ? INNER_DIMS : dim;
    for (uint32_t i = 0; i < tailDims; ++i) {
        sizeI[INNER_DIMS - tailDims + i] = static_cast<uint16_t>(dstShape[dim - tailDims + i]);
        stride[INNER_DIMS - tailDims + i] = static_cast<uint16_t>(srcStride[dim - tailDims + i]);
    }

    // Outer loop slot s covers axis (dim - 9 + s). Missing axes keep a trip count of 1 and a zero
    // stride, so strides of non-existent axes are never read.
    uint16_t loops[OUTER_LOOPS] = {1, 1, 1, 1, 1};
    uint32_t outerDstStride[OUTER_LOOPS] = {0, 0, 0, 0, 0};
    uint32_t outerSrcStride[OUTER_LOOPS] = {0, 0, 0, 0, 0};
    for (int32_t slot = 0; slot < OUTER_LOOPS; ++slot) {
        const int32_t axis = static_cast<int32_t>(dim) - MAX_DIMS + slot;
        if (axis >= 0) {
            loops[slot] = static_cast<uint16_t>(dstShape[axis]);
            outerDstStride[slot] = dstStride[axis];
            outerSrcStride[slot] = srcStride[axis];
        }
    }

    __ubuf__ T *srcUbTmp = srcUb;
    __ubuf__ T *dstUbTmp = dstUb;
    for (uint16_t i = 0; i < loops[0]; ++i) {
        for (uint16_t j = 0; j < loops[1]; ++j) {
            for (uint16_t k = 0; k < loops[2]; ++k) {
                for (uint16_t t = 0; t < loops[3]; ++t) {
                    for (uint16_t p = 0; p < loops[4]; ++p) {
                        dstUb = dstUbTmp + p * outerDstStride[4] + t * outerDstStride[3] +
                            k * outerDstStride[2] + j * outerDstStride[1] + i * outerDstStride[0];
                        srcUb = srcUbTmp + p * outerSrcStride[4] + t * outerSrcStride[3] +
                            k * outerSrcStride[2] + j * outerSrcStride[1] + i * outerSrcStride[0];
                        if (sizeI[3] <= VF_LEN) {
                            BrcLastLessThanVLUnaligned<T>(
                                dstUb, srcUb, sizeI[0], sizeI[1], sizeI[2], sizeI[3], stride[0], stride[1], stride[2]);
                        } else {
                            BrcLastLargerThanVLUnaligned<T>(
                                dstUb, srcUb, sizeI[0], sizeI[1], sizeI[2], sizeI[3], stride[0], stride[1], stride[2]);
                        }
                    }
                }
            }
        }
    }
}
}
}
#endif
#if defined(__UNDEF_ASCENDC_INCLUDE_INTERNAL_HEADERS_PAD_BROADCAST_BROADCAST_C310_IMPL_H__)
#undef __ASCENDC_INCLUDE_INTERNAL_HEADERS__
#undef __UNDEF_ASCENDC_INCLUDE_INTERNAL_HEADERS_PAD_BROADCAST_BROADCAST_C310_IMPL_H__
#endif