* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file kernel_micro_common_impl.h
* \brief
*/
#ifndef ASCENDC_MODULE_MICRO_COMMON_IMPL_H
#define ASCENDC_MODULE_MICRO_COMMON_IMPL_H
#include "kernel_tensor.h"
#include "micro_api/kernel_micro_utils.h"
namespace AscendC {
namespace MicroAPI {
template <typename T, StoreDist dist> __aicore__ inline constexpr StoreDist GetStoreDist()
{
if constexpr (dist == StoreDist::DIST_NORM) {
static_assert(SupportBytes<T, 1, 2, 4, 8>(),
"StoreDist DIST_NORM only support type b8/b16/b32/b64 on current device");
if constexpr (sizeof(T) == 1) {
return StoreDist::DIST_NORM_B8;
} else if constexpr (sizeof(T) == 2) {
return StoreDist::DIST_NORM_B16;
} else if constexpr (sizeof(T) == 4) {
return StoreDist::DIST_NORM_B32;
} else if constexpr (sizeof(T) == 8) {
return StoreDist::DIST_NORM_B32;
}
}
return dist;
}
template <typename T, const RegTrait& otherTrait = RegTraitNumOne> constexpr __aicore__ inline bool CheckRegTrait()
{
constexpr RegTrait regTrait = T::trait;
return regTrait.REG_NUM == otherTrait.REG_NUM;
}
#ifndef __ASC_NPU_HOST__
template <RoundMode mode> __aicore__ inline constexpr ::ROUND GetRound()
{
#if defined(ASCENDC_CPU_DEBUG) && ASCENDC_CPU_DEBUG == 1
if constexpr (mode == RoundMode::CAST_RINT) {
return ::ROUND::CAST_RINT;
} else if constexpr (mode == RoundMode::CAST_ROUND) {
return ::ROUND::CAST_ROUND;
} else if constexpr (mode == RoundMode::CAST_FLOOR) {
return ::ROUND::CAST_FLOOR;
} else if constexpr (mode == RoundMode::CAST_CEIL) {
return ::ROUND::CAST_CEIL;
} else if constexpr (mode == RoundMode::CAST_TRUNC) {
return ::ROUND::CAST_TRUNC;
} else if constexpr (mode == RoundMode::CAST_ODD) {
return ::ROUND::CAST_ODD;
} else {
return ::ROUND::CAST_HYBRID;
}
#else
if constexpr (mode == RoundMode::CAST_RINT) {
return ::ROUND::R;
} else if constexpr (mode == RoundMode::CAST_ROUND) {
return ::ROUND::A;
} else if constexpr (mode == RoundMode::CAST_FLOOR) {
return ::ROUND::F;
} else if constexpr (mode == RoundMode::CAST_CEIL) {
return ::ROUND::C;
} else if constexpr (mode == RoundMode::CAST_TRUNC) {
return ::ROUND::Z;
} else if constexpr (mode == RoundMode::CAST_ODD) {
return ::ROUND::O;
} else {
return ::ROUND::H;
}
#endif
}
#endif
#ifndef __ASC_NPU_HOST__
template <MaskMergeMode mode> __aicore__ inline constexpr auto GetMaskMergeMode()
{
#if defined(ASCENDC_CPU_DEBUG) && ASCENDC_CPU_DEBUG == 1
return std::integral_constant<::CpuMode, static_cast<::CpuMode>(mode)>();
#else
return std::integral_constant<::Mode, static_cast<::Mode>(mode)>();
#endif
}
template <MemType src, MemType dst> __simd_callee__ inline void LocalMemBarImpl()
{
static_assert((SupportEnum<src, MemType::VEC_STORE>() && SupportEnum<dst, MemType::VEC_LOAD>()) ||
(SupportEnum<src, MemType::VEC_LOAD>() && SupportEnum<dst, MemType::VEC_STORE>()) ||
(SupportEnum<src, MemType::VEC_STORE>() && SupportEnum<dst, MemType::VEC_STORE>()) ||
(SupportEnum<src, MemType::VEC_STORE>() && SupportEnum<dst, MemType::SCALAR_LOAD>()) ||
(SupportEnum<src, MemType::VEC_STORE>() && SupportEnum<dst, MemType::SCALAR_STORE>()) ||
(SupportEnum<src, MemType::VEC_LOAD>() && SupportEnum<dst, MemType::SCALAR_STORE>()) ||
(SupportEnum<src, MemType::SCALAR_STORE>() && SupportEnum<dst, MemType::VEC_LOAD>()) ||
(SupportEnum<src, MemType::SCALAR_STORE>() && SupportEnum<dst, MemType::VEC_STORE>()) ||
(SupportEnum<src, MemType::SCALAR_LOAD>() && SupportEnum<dst, MemType::VEC_STORE>()) ||
(SupportEnum<src, MemType::VEC_ALL>() && SupportEnum<dst, MemType::VEC_ALL>()) ||
(SupportEnum<src, MemType::VEC_ALL>() && SupportEnum<dst, MemType::SCALAR_ALL>()) ||
(SupportEnum<src, MemType::SCALAR_ALL>() && SupportEnum<dst, MemType::VEC_ALL>()),
"LocalMemBar does support current MemType Combination on current device!");
if constexpr (src == MemType::VEC_STORE && dst == MemType::VEC_LOAD) {
mem_bar(VST_VLD);
} else if constexpr (src == MemType::VEC_LOAD && dst == MemType::VEC_STORE) {
mem_bar(VLD_VST);
} else if constexpr (src == MemType::VEC_STORE && dst == MemType::VEC_STORE) {
mem_bar(VST_VST);
} else if constexpr (src == MemType::VEC_STORE && dst == MemType::SCALAR_LOAD) {
mem_bar(VST_LD);
} else if constexpr (src == MemType::VEC_STORE && dst == MemType::SCALAR_STORE) {
mem_bar(VST_ST);
} else if constexpr (src == MemType::VEC_LOAD && dst == MemType::SCALAR_STORE) {
mem_bar(VLD_ST);
} else if constexpr (src == MemType::SCALAR_STORE && dst == MemType::VEC_LOAD) {
mem_bar(ST_VLD);
} else if constexpr (src == MemType::SCALAR_STORE && dst == MemType::VEC_STORE) {
mem_bar(ST_VST);
} else if constexpr (src == MemType::SCALAR_LOAD && dst == MemType::VEC_STORE) {
mem_bar(LD_VST);
} else if constexpr (src == MemType::VEC_ALL && dst == MemType::VEC_ALL) {
mem_bar(VV_ALL);
} else if constexpr (src == MemType::VEC_ALL && dst == MemType::SCALAR_ALL) {
mem_bar(VS_ALL);
} else if constexpr (src == MemType::SCALAR_ALL && dst == MemType::VEC_ALL) {
mem_bar(SV_ALL);
}
}
#endif
template <typename T, typename U, typename ShortType>
__simd_callee__ inline void TraitOneToTaitTwoTmpl(T& dstReg, U& srcReg)
{
using ActualT1 = typename U::ActualT;
using ActualT2 = typename T::ActualT;
static_assert(CheckRegTrait<T, RegTraitNumTwo>() && CheckRegTrait<U, RegTraitNumOne>(),
"T should be RegTraitNumTwo and U should be RegTraitNumOne");
static_assert(sizeof(ActualT2) == (sizeof(ShortType) * 2) && sizeof(ActualT1) == (sizeof(ShortType) * 2),
"T and U should be 2 times of shortType lenth");
RegTensor<ShortType> zeroReg;
MaskReg maskFull = CreateMask<ShortType, MaskPattern::ALL>();
Duplicate(zeroReg, 0, maskFull);
DeInterleave((RegTensor<ShortType>&)dstReg.reg[0], (RegTensor<ShortType>&)dstReg.reg[1],
(RegTensor<ShortType>&)srcReg, zeroReg);
}
template <typename T, typename U, typename ShortType>
__simd_callee__ inline void TraitTwoToTaitOneTmpl(T& dstReg, U& srcReg)
{
using ActualT1 = typename T::ActualT;
using ActualT2 = typename U::ActualT;
static_assert(CheckRegTrait<T, RegTraitNumOne>() && CheckRegTrait<U, RegTraitNumTwo>(),
"T should be RegTraitNumOne and U should be RegTraitNumTwo");
static_assert(sizeof(ActualT2) == (sizeof(ShortType) * 2) && sizeof(ActualT1) == (sizeof(ShortType) * 2),
"U and T should be 2 times of shortType lenth");
RegTensor<ShortType> dstRegShortFake;
Interleave((RegTensor<ShortType>&)dstReg, dstRegShortFake, (RegTensor<ShortType>&)srcReg.reg[0],
(RegTensor<ShortType>&)srcReg.reg[1]);
}
template <typename T = DefaultType, MaskMergeMode mode = MaskMergeMode::MERGING, typename RegT>
__simd_callee__ inline void CopyMerging(RegT& dstReg, RegT& srcReg, MaskReg& mask)
{
using ActualT = typename RegT::ActualT;
constexpr auto modeValue = GetMaskMergeMode<mode>();
if constexpr (IsSameType<ActualT, bool>::value) {
vmov((RegTensor<int8_t>&)dstReg, (RegTensor<int8_t>&)srcReg, mask, modeValue);
} else if constexpr (sizeof(ActualT) == 1) {
vmov((RegTensor<uint8_t>&)dstReg, (RegTensor<uint8_t>&)srcReg, mask, modeValue);
} else if constexpr (sizeof(ActualT) == 2) {
vmov((RegTensor<uint16_t>&)dstReg, (RegTensor<uint16_t>&)srcReg, mask, modeValue);
} else if constexpr (sizeof(ActualT) == 4) {
vmov((RegTensor<uint32_t>&)dstReg, (RegTensor<uint32_t>&)srcReg, mask, modeValue);
} else if constexpr (sizeof(ActualT) == 8) {
if constexpr (CheckRegTrait<RegT, RegTraitNumOne>()) {
constexpr auto lowerDist =
std::integral_constant<::HiloPart, static_cast<::HiloPart>(HighLowPart::LOWEST)>();
MaskReg dstMask;
MaskReg tmpMask;
MaskReg dumpMask;
ppack(tmpMask, mask, lowerDist);
pintlv_b32(dstMask, dumpMask, tmpMask, tmpMask);
vmov((RegTensor<uint32_t> &)dstReg, (RegTensor<uint32_t> &)srcReg, dstMask, modeValue);
} else if constexpr (CheckRegTrait<RegT, RegTraitNumTwo>()) {
vmov((RegTensor<uint32_t> &)dstReg.reg[0], (RegTensor<uint32_t> &)srcReg.reg[0], mask, modeValue);
vmov((RegTensor<uint32_t> &)dstReg.reg[1], (RegTensor<uint32_t> &)srcReg.reg[1], mask, modeValue);
}
}
}
template <typename T, typename U>
__simd_callee__ inline void B64TraitOneToTaitTwo(T& dstReg, U& srcReg)
{
TraitOneToTaitTwoTmpl<T, U, uint32_t>(dstReg, srcReg);
}
template <typename T, typename U>
__simd_callee__ inline void B64TraitTwoToTaitOne(T& dstReg, U& srcReg)
{
TraitTwoToTaitOneTmpl<T, U, uint32_t>(dstReg, srcReg);
}
template <typename T, typename U>
__simd_callee__ inline void B32TraitOneToTaitTwo(T& dstReg, U& srcReg)
{
TraitOneToTaitTwoTmpl<T, U, uint16_t>(dstReg, srcReg);
}
template <typename T, typename U>
__simd_callee__ inline void B32TraitTwoToTaitOne(T& dstReg, U& srcReg)
{
TraitTwoToTaitOneTmpl<T, U, uint16_t>(dstReg, srcReg);
}
}
}
#endif