* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file kernel_micro_datacopy_impl.h
* \brief
*/
#ifndef ASCENDC_MODULE_MICRO_DATACOPY_IMPL_H
#define ASCENDC_MODULE_MICRO_DATACOPY_IMPL_H
#include "micro_api/dav_m510/kernel_micro_datacopy_load_impl.h"
#include "micro_api/dav_m510/kernel_micro_datacopy_store_impl.h"
namespace AscendC {
namespace MicroAPI {
template <typename T0 = DefaultType, typename T1, typename T2 = DefaultType, typename T3, typename T4>
__simd_callee__ inline void DataCopyGatherB64Impl(T3& dstReg, __ubuf__ T1* baseAddr, T4& index, MaskReg& mask)
{
if constexpr (CheckRegTrait<T3, RegTraitNumOne>()) {
MaskReg dstMask;
RegTensor<uint32_t> oddIndex;
RegTensor<uint32_t> evenIndex;
RegTensor<uint32_t> oddReg;
RegTensor<uint32_t> evenReg;
RegTensor<uint32_t> tmpReg;
MaskPack(dstMask, mask);
MaskReg lowerMask = CreateMask<uint32_t, MaskPattern::VL32>();
MaskReg preg = CreateMask<uint32_t, MaskPattern::ALL>();
MaskAnd(dstMask, dstMask, lowerMask, preg);
Muls(oddIndex, index, uint32_t(2), dstMask);
Adds(evenIndex, oddIndex, uint32_t(1), dstMask);
vgather2(oddReg, (__ubuf__ uint32_t*)baseAddr, oddIndex, dstMask);
vgather2(evenReg, (__ubuf__ uint32_t*)baseAddr, evenIndex, dstMask);
Interleave((RegTensor<uint32_t>&)dstReg, tmpReg, oddReg, evenReg);
} else {
RegTensor<uint32_t> oddIndex;
RegTensor<uint32_t> evenIndex;
Muls(oddIndex, index, uint32_t(2), mask);
Adds(evenIndex, oddIndex, uint32_t(1), mask);
vgather2((RegTensor<uint32_t>&)dstReg.reg[0], (__ubuf__ uint32_t*)baseAddr, oddIndex, mask);
vgather2((RegTensor<uint32_t>&)dstReg.reg[1], (__ubuf__ uint32_t*)baseAddr, evenIndex, mask);
}
}
template <typename T0 = DefaultType, typename T1, typename T2 = DefaultType, typename T3, typename T4>
__simd_callee__ inline void DataCopyGatherImpl(T3& dstReg, __ubuf__ T1* baseAddr, T4& index, MaskReg& mask)
{
using ActualDstT = typename T3::ActualT;
using ActualIndexT = typename T4::ActualT;
static_assert(std::is_same_v<T0, DefaultType> || std::is_same_v<T0, ActualDstT>, "T0 type is not correct!");
static_assert(std::is_same_v<T2, DefaultType> || std::is_same_v<T2, ActualIndexT>, "T2 type is not correct!");
static_assert((sizeof(T1) == 1 && sizeof(ActualDstT) == 2 && std::is_same_v<ActualIndexT, uint16_t>) ||
(sizeof(T1) == 2 && sizeof(ActualDstT) == 2 && std::is_same_v<ActualIndexT, uint16_t>) ||
(sizeof(T1) == 4 && sizeof(ActualDstT) == 4 && std::is_same_v<ActualIndexT, uint32_t>) ||
(sizeof(T1) == 8 && sizeof(ActualDstT) == 8 && SupportType<ActualIndexT, uint32_t, uint64_t>()),
"DataCopyGather only support src data type b8/b16/b32/b64 with dst type is b16/b16/b32/b64 respectively and "
"each index type is u16/u16/u32/(u32/u64) respectively on current device");
static_assert(!(sizeof(T1) == 8 && std::is_same_v<ActualIndexT, uint64_t> &&
CheckRegTrait<T4, RegTraitNumOne>() && CheckRegTrait<T3, RegTraitNumTwo>()),
"current data type is not supported on current device!");
if constexpr (sizeof(T1) == 1 && sizeof(ActualDstT) == 2) {
vgather2((vector_s16&)dstReg, (__ubuf__ int8_t*)baseAddr, index, mask);
} else if constexpr (sizeof(T1) == 2 && sizeof(ActualDstT) == 2) {
vgather2((vector_s16&)dstReg, (__ubuf__ int16_t*)baseAddr, index, mask);
} else if constexpr (sizeof(T1) == 4 && sizeof(ActualDstT) == 4) {
vgather2((vector_s32&)dstReg, (__ubuf__ int32_t*)baseAddr, index, mask);
} else {
if constexpr (std::is_same_v<ActualIndexT, uint32_t>) {
DataCopyGatherB64Impl(dstReg, baseAddr, index, mask);
} else if constexpr (std::is_same_v<ActualIndexT, uint64_t>) {
if constexpr (CheckRegTrait<T4, RegTraitNumOne>() && CheckRegTrait<T3, RegTraitNumOne>()) {
RegTensor<uint32_t> lowIndex;
RegTensor<uint32_t> highIndex;
DeInterleave(lowIndex, highIndex, (RegTensor<uint32_t>&)index, (RegTensor<uint32_t>&)index);
DataCopyGatherB64Impl(dstReg, baseAddr, lowIndex, mask);
} else if constexpr ((CheckRegTrait<T4, RegTraitNumTwo>() && CheckRegTrait<T3, RegTraitNumOne>()) ||
(CheckRegTrait<T4, RegTraitNumTwo>() && CheckRegTrait<T3, RegTraitNumTwo>())) {
DataCopyGatherB64Impl(dstReg, baseAddr, (RegTensor<uint32_t>&)index.reg[0], mask);
}
}
}
}
template <typename T = DefaultType, typename U, typename S>
__simd_callee__ inline void DataCopyGatherBImpl(U& dstReg, __ubuf__ T* baseAddr, S& index, MaskReg& mask)
{
using ActualT = typename U::ActualT;
using ActualIndexT = typename S::ActualT;
static_assert(std::is_same_v<T, DefaultType> || std::is_same_v<T, ActualT>, "T type is not correct!");
static_assert(std::is_same_v<ActualIndexT, uint32_t>, "IndexT type is not correct!");
static_assert(CheckRegTrait<U, RegTraitNumOne>(), "RegTensor only suppoort RegTraitNumOne on current device!");
static_assert(CheckRegTrait<S, RegTraitNumOne>(), "RegTensor only suppoort RegTraitNumOne on current device!");
static_assert(SupportBytes<ActualT, 1, 2, 4, 8>(),
"DataCopyGatherB only support src & dst datatype b8/b16/b32/b64 on current device");
if constexpr (sizeof(ActualT) == 1) {
vgatherb((vector_s8&)dstReg, (__ubuf__ int8_t*)baseAddr, index, mask);
} else if constexpr (sizeof(ActualT) == 2) {
vgatherb((vector_s16&)dstReg, (__ubuf__ int16_t*)baseAddr, index, mask);
} else if constexpr (sizeof(ActualT) == 4) {
vgatherb((vector_s32&)dstReg, (__ubuf__ int32_t*)baseAddr, index, mask);
} else {
vgatherb((vector_s64&)dstReg, (__ubuf__ int64_t*)baseAddr, index, mask);
}
}
template <typename T = DefaultType, typename U = DefaultType, typename S, typename V>
__simd_callee__ inline void DataCopyScatterB64Impl(__ubuf__ T* baseAddr, S& srcReg, V& index, MaskReg& mask)
{
if constexpr (CheckRegTrait<S, RegTraitNumOne>()) {
MaskReg dstMask;
RegTensor<uint32_t> oddIndex;
RegTensor<uint32_t> evenIndex;
RegTensor<uint32_t> oddReg;
RegTensor<uint32_t> evenReg;
RegTensor<uint32_t> dstReg0;
RegTensor<uint32_t> dstReg1;
MaskPack(dstMask, mask);
MaskReg lowerMask = CreateMask<uint32_t, MaskPattern::VL32>();
MaskReg preg = CreateMask<uint32_t, MaskPattern::ALL>();
MaskAnd(dstMask, dstMask, lowerMask, preg);
Muls(oddIndex, index, uint32_t(2), dstMask);
Adds(evenIndex, oddIndex, uint32_t(1), dstMask);
DeInterleave(dstReg0, dstReg1, (RegTensor<uint32_t>&)srcReg, (RegTensor<uint32_t>&)srcReg);
vscatter(dstReg0, (__ubuf__ uint32_t*)baseAddr, oddIndex, dstMask);
vscatter(dstReg1, (__ubuf__ uint32_t*)baseAddr, evenIndex, dstMask);
} else {
RegTensor<uint32_t> oddIndex;
RegTensor<uint32_t> evenIndex;
Muls(oddIndex, index, uint32_t(2), mask);
Adds(evenIndex, oddIndex, uint32_t(1), mask);
vscatter((RegTensor<uint32_t>&)srcReg.reg[0], (__ubuf__ uint32_t*)baseAddr, oddIndex, mask);
vscatter((RegTensor<uint32_t>&)srcReg.reg[1], (__ubuf__ uint32_t*)baseAddr, evenIndex, mask);
}
}
template <typename T = DefaultType, typename U = DefaultType, typename S, typename V>
__simd_callee__ inline void DataCopyScatterImpl(__ubuf__ T* baseAddr, S& srcReg, V& index, MaskReg& mask)
{
using ActualT = typename S::ActualT;
using ActualIndexT = typename V::ActualT;
static_assert(std::is_same_v<T, DefaultType> || std::is_same_v<T, ActualT>, "T type is not correct!");
static_assert(std::is_same_v<U, DefaultType> || std::is_same_v<U, ActualIndexT>, "U type is not correct!");
static_assert((sizeof(ActualT) == 1 && std::is_same_v<ActualIndexT, uint16_t>) ||
(sizeof(ActualT) == 2 && std::is_same_v<ActualIndexT, uint16_t>) ||
(sizeof(ActualT) == 4 && std::is_same_v<ActualIndexT, uint32_t>) ||
(sizeof(ActualT) == 8 && SupportType<ActualIndexT, uint32_t, uint64_t>()),
"DataCopyScatter only support data type b8/b16/b32/b64"
"with each index type is u16/u16/u32/(u32/u64) respectively on current device");
static_assert(!(sizeof(ActualT) == 8 && std::is_same_v<ActualIndexT, uint64_t> &&
CheckRegTrait<S, RegTraitNumTwo>() && CheckRegTrait<V, RegTraitNumOne>()),
"current data type is not supported on current device!");
if constexpr (sizeof(ActualT) == 8) {
if constexpr (std::is_same_v<ActualIndexT, uint32_t>) {
DataCopyScatterB64Impl(baseAddr, srcReg, index, mask);
} else if constexpr (std::is_same_v<ActualIndexT, uint64_t>) {
if constexpr (CheckRegTrait<S, RegTraitNumOne>() && CheckRegTrait<V, RegTraitNumOne>()) {
RegTensor<uint32_t> lowIndex;
RegTensor<uint32_t> highIndex;
DeInterleave(lowIndex, highIndex, (RegTensor<uint32_t>&)index, (RegTensor<uint32_t>&)index);
DataCopyScatterB64Impl(baseAddr, srcReg, lowIndex, mask);
} else if constexpr ((CheckRegTrait<S, RegTraitNumOne>() && CheckRegTrait<V, RegTraitNumTwo>()) ||
(CheckRegTrait<S, RegTraitNumTwo>() && CheckRegTrait<V, RegTraitNumTwo>())) {
DataCopyScatterB64Impl(baseAddr, srcReg, (RegTensor<uint32_t>&)index.reg[0], mask);
}
}
} else {
vscatter(srcReg, baseAddr, index, mask);
}
}
template <typename T, MaskDist dist = MaskDist::DIST_NORM>
__simd_callee__ inline void DataCopyImpl(MaskReg& mask, __ubuf__ T* srcAddr, AddrReg offset)
{
static_assert(SupportBytes<T, 1, 2, 4, 8>(), "DataCopy only support type b8/b16/b32/b64 on current device");
static_assert(SupportEnum<dist, MaskDist::DIST_NORM, MaskDist::DIST_US, MaskDist::DIST_DS>(),
"DataCopy not support this dist on current device");
constexpr auto distValue = std::integral_constant<::Dist, static_cast<::Dist>(dist)>();
pld(mask, (__ubuf__ uint32_t*)srcAddr, offset, distValue);
}
template <typename T, MaskDist dist = MaskDist::DIST_NORM>
__simd_callee__ inline void DataCopyImpl(MaskReg& mask, __ubuf__ T* srcAddr)
{
static_assert(SupportBytes<T, 1, 2, 4, 8>(), "DataCopy only support type b8/b16/b32/b64 on current device");
static_assert(SupportEnum<dist, MaskDist::DIST_NORM, MaskDist::DIST_US, MaskDist::DIST_DS>(),
"DataCopy not support this dist on current device");
constexpr auto distValue = std::integral_constant<::Dist, static_cast<::Dist>(dist)>();
plds(mask, (__ubuf__ uint32_t*)srcAddr, 0, distValue);
}
template <typename T, PostLiteral postMode, MaskDist dist = MaskDist::DIST_NORM>
__simd_callee__ inline void DataCopyImpl(MaskReg& mask, __ubuf__ T*& srcAddr, int32_t offset)
{
static_assert(SupportBytes<T, 1, 2, 4, 8>(), "DataCopy only support type b8/b16/b32/b64 on current device");
static_assert(SupportEnum<dist, MaskDist::DIST_NORM, MaskDist::DIST_US, MaskDist::DIST_DS>(),
"DataCopy not support this dist on current device");
constexpr auto distValue = std::integral_constant<::Dist, static_cast<::Dist>(dist)>();
constexpr auto postValue = std::integral_constant<::Post, static_cast<::Post>(postMode)>();
plds(mask, (__ubuf__ uint32_t*&)srcAddr, offset, distValue, postValue);
}
template <typename T, MaskDist dist = MaskDist::DIST_NORM>
__simd_callee__ inline void DataCopyImpl(__ubuf__ T* dstAddr, MaskReg& mask, AddrReg offset)
{
static_assert(SupportBytes<T, 1, 2, 4, 8>(), "DataCopy only support type b8/b16/b32/b64 on current device");
static_assert(SupportEnum<dist, MaskDist::DIST_NORM, MaskDist::DIST_PACK>(),
"DataCopy not support this dist on current device");
constexpr auto distValue = std::integral_constant<::Dist, static_cast<::Dist>(dist)>();
pst(mask, (__ubuf__ uint32_t*)dstAddr, offset, distValue);
}
template <typename T, MaskDist dist = MaskDist::DIST_NORM>
__simd_callee__ inline void DataCopyImpl(__ubuf__ T* dstAddr, MaskReg& mask)
{
static_assert(SupportBytes<T, 1, 2, 4, 8>(), "DataCopy only support type b8/b16/b32/b64 on current device");
static_assert(SupportEnum<dist, MaskDist::DIST_NORM, MaskDist::DIST_PACK>(),
"DataCopy not support this dist on current device");
constexpr auto distValue = std::integral_constant<::Dist, static_cast<::Dist>(dist)>();
psts(mask, (__ubuf__ uint32_t*)dstAddr, 0, distValue);
}
template <typename T, PostLiteral postMode, MaskDist dist = MaskDist::DIST_NORM>
__simd_callee__ inline void DataCopyImpl(__ubuf__ T*& dstAddr, MaskReg& mask, int32_t offset)
{
static_assert(SupportBytes<T, 1, 2, 4, 8>(), "DataCopy only support type b8/b16/b32/b64 on current device");
static_assert(SupportEnum<dist, MaskDist::DIST_NORM, MaskDist::DIST_PACK>(),
"DataCopy not support this dist on current device");
constexpr auto distValue = std::integral_constant<::Dist, static_cast<::Dist>(dist)>();
constexpr auto postValue = std::integral_constant<::Post, static_cast<::Post>(postMode)>();
psts(mask, (__ubuf__ uint32_t*&)dstAddr, offset, distValue, postValue);
}
template <typename T>
__simd_callee__ inline void DataCopyUnAlignImpl(__ubuf__ T*& dstAddr, MaskReg& mask, UnalignReg& ureg)
{
static_assert(SupportBytes<T, 2, 4>(), "DataCopy only support type b16/b32 on current device");
if constexpr (sizeof(T) == 2) {
pstu(ureg, mask, (__ubuf__ uint16_t*&)dstAddr);
} else if constexpr (sizeof(T) == 4) {
pstu(ureg, mask, (__ubuf__ uint32_t*&)dstAddr);
}
}
}
}
#endif