* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file philox_3510_impl.h
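* \brief Internal vector-register implementation of the Philox 4x32 counter-based random number generator (7 or 10 rounds).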
*/
#if !defined(__ASCENDC_INCLUDE_INTERNAL_HEADERS__)
#pragma message( \
"impl/adv_api/detail/math/philox/philox_3510_impl.h is an internal header file and must not be used directly. Functions or variables defined in this file may be removed in the future. Please use \"#include \"adv_api/math/philox.h\"\" and use public functions or variables defined in interface headers files.")
#define __ASCENDC_INCLUDE_INTERNAL_HEADERS__
#define __UNDEF_ASCENDC_INCLUDE_INTERNAL_HEADERS_MATH_PHILOX_PHILOX_C310_IMPL_H__
#endif
#ifndef IMPL_MATH_PHILOX_PHILOX_C310_IMPL_H
#define IMPL_MATH_PHILOX_PHILOX_C310_IMPL_H
#include <cstdint>
#include "kernel_basic_intf.h"
#include "include/adv_api/math/philox_utils.h"
namespace AscendC {
// Compile-time constants shared by the Philox kernels below.
namespace PhiloxInternal {
// Multipliers used by the two widening multiplies of every round.
constexpr uint32_t CONST_MUL_0 = 0xD2511F53;
constexpr uint32_t CONST_MUL_1 = 0xCD9E8D57;

// Increments added to key word 0 / key word 1 after each round.
constexpr uint32_t CONST_KEY_ADD_0 = 0x9E3779B9;
constexpr uint32_t CONST_KEY_ADD_1 = 0xBB67AE85;

// One counter is 128 bits = 16 bytes = four 32-bit words.
constexpr uint16_t PHILOX_ONCE_COUNTER_BIT = 128;
constexpr uint16_t PHILOX_ONCE_COUNTER_BYTE = PHILOX_ONCE_COUNTER_BIT / 8;
constexpr uint16_t PHILOX_ONCE_COUNTER_NUM = PHILOX_ONCE_COUNTER_BYTE / sizeof(uint32_t);

// Number of 32-bit lanes in one vector register.
constexpr uint32_t ELE_CNT_B32_ONCE = GetVecLen() / sizeof(uint32_t);
// 32-bit outputs produced per batch: four words for every lane.
constexpr uint16_t PHILOX_ONCE_REPEAT_NUM = PHILOX_ONCE_COUNTER_NUM * ELE_CNT_B32_ONCE;

// Low 23 bits of a 32-bit word (the float32 mantissa field).
constexpr uint32_t MANTISSA = 0x007FFFFFu;
// 127 << 23: the float32 exponent field of 1.0f.
constexpr uint32_t EXP_MASK = 0x3F800000u;
} // namespace PhiloxInternal
// Value bundle holding the key words and the starting counter words; both arrays
// are zero-initialised unless the caller overwrites them.
struct philoxStruct {
    uint32_t philoxKey[PHILOX_KEY_SIZE] = {};
    uint32_t philoxCounter[PHILOX_COUNTER_SIZE] = {};
};
// Adds `value` to the counter spread over ctr0..ctr3 (in place, per lane).
// `value` is added to ctr0; the carry mask produced by each step is fed into the
// addition on the next word (ctr1, then ctr2, then ctr3), whose other operand is zero.
// pg: predicate mask passed to every add.
__simd_callee__ inline void AddWith128Bits(
    Reg::RegTensor<uint32_t>& ctr0, Reg::RegTensor<uint32_t>& ctr1, Reg::RegTensor<uint32_t>& ctr2,
    Reg::RegTensor<uint32_t>& ctr3, Reg::RegTensor<uint32_t>& value, Reg::MaskReg& pg)
{
    Reg::RegTensor<uint32_t> zeros;
    Duplicate(zeros, 0x0);

    Reg::MaskReg carry;
    AddCarryOut(carry, ctr0, ctr0, value, pg);
    // Upper words only absorb the carry coming out of the word below.
    AddCarryOuts(carry, ctr1, ctr1, zeros, carry, pg);
    AddCarryOuts(carry, ctr2, ctr2, zeros, carry, pg);
    AddCarryOuts(carry, ctr3, ctr3, zeros, carry, pg);
}
// Rewrites four registers of 32-bit words in place as float32 bit patterns:
//   1. keep only the low 23 bits (MANTISSA),
//   2. OR in EXP_MASK (127 << 23), giving the bit pattern of a float in [1, 2),
//   3. add -1.0f on the float view of the same register.
// pg: predicate mask passed to every operation.
__simd_callee__ inline void UInt2Float(
    Reg::RegTensor<uint32_t>& tmpCtr0, Reg::RegTensor<uint32_t>& tmpCtr1, Reg::RegTensor<uint32_t>& tmpCtr2,
    Reg::RegTensor<uint32_t>& tmpCtr3, Reg::MaskReg& pg)
{
    Reg::RegTensor<uint32_t> mantissaBits;
    Reg::RegTensor<uint32_t> exponentBits;
    Duplicate(mantissaBits, PhiloxInternal::MANTISSA);
    Duplicate(exponentBits, PhiloxInternal::EXP_MASK);

    // Step 1: strip everything but the mantissa field.
    And(tmpCtr0, tmpCtr0, mantissaBits, pg);
    And(tmpCtr1, tmpCtr1, mantissaBits, pg);
    And(tmpCtr2, tmpCtr2, mantissaBits, pg);
    And(tmpCtr3, tmpCtr3, mantissaBits, pg);

    // Step 2: install the exponent of 1.0f.
    Or(tmpCtr0, tmpCtr0, exponentBits, pg);
    Or(tmpCtr1, tmpCtr1, exponentBits, pg);
    Or(tmpCtr2, tmpCtr2, exponentBits, pg);
    Or(tmpCtr3, tmpCtr3, exponentBits, pg);

    // Step 3: shift the value down by one, operating on the float view.
    Adds((Reg::RegTensor<float>&)tmpCtr0, (Reg::RegTensor<float>&)tmpCtr0, -1.0f, pg);
    Adds((Reg::RegTensor<float>&)tmpCtr1, (Reg::RegTensor<float>&)tmpCtr1, -1.0f, pg);
    Adds((Reg::RegTensor<float>&)tmpCtr2, (Reg::RegTensor<float>&)tmpCtr2, -1.0f, pg);
    Adds((Reg::RegTensor<float>&)tmpCtr3, (Reg::RegTensor<float>&)tmpCtr3, -1.0f, pg);
}
// Applies `Rounds` rounds to the working counter/key registers, in place.
// One round:
//   (tmpL0, tmpH0) = Mull(tmpCtr0, cMul0);  (tmpL1, tmpH1) = Mull(tmpCtr2, cMul1)
//   tmpCtr0 = tmpH1 ^ tmpCtr1 ^ tmpKey0;    tmpCtr2 = tmpH0 ^ tmpCtr3 ^ tmpKey1
//   tmpCtr1 = tmpL1;                        tmpCtr3 = tmpL0
//   tmpKey0 += CONST_KEY_ADD_0;             tmpKey1 += CONST_KEY_ADD_1
// tmpL0/tmpH0/tmpL1/tmpH1 are scratch registers supplied by the caller.
template <uint16_t Rounds>
__simd_callee__ inline void SpNetworkKernel(
    Reg::RegTensor<uint32_t>& tmpL0, Reg::RegTensor<uint32_t>& tmpH0, Reg::RegTensor<uint32_t>& tmpL1,
    Reg::RegTensor<uint32_t>& tmpH1, Reg::RegTensor<uint32_t>& tmpCtr0, Reg::RegTensor<uint32_t>& tmpCtr1,
    Reg::RegTensor<uint32_t>& tmpCtr2, Reg::RegTensor<uint32_t>& tmpCtr3, Reg::RegTensor<uint32_t>& tmpKey0,
    Reg::RegTensor<uint32_t>& tmpKey1, Reg::RegTensor<uint32_t>& cMul0, Reg::RegTensor<uint32_t>& cMul1,
    Reg::MaskReg& pg)
{
    for (uint16_t round = 0; round < Rounds; ++round) {
        // Widening multiplies of counter words 0 and 2.
        Mull(tmpL0, tmpH0, tmpCtr0, cMul0, pg);
        Mull(tmpL1, tmpH1, tmpCtr2, cMul1, pg);

        // Mix the second result of each multiply with the odd counter words and the keys.
        Xor(tmpH1, tmpH1, tmpCtr1, pg);
        Xor(tmpCtr0, tmpH1, tmpKey0, pg);
        Xor(tmpH0, tmpH0, tmpCtr3, pg);
        Xor(tmpCtr2, tmpH0, tmpKey1, pg);

        // The first result of each multiply becomes the new odd counter word.
        tmpCtr1 = tmpL1;
        tmpCtr3 = tmpL0;

        // Advance the key schedule for the next round.
        Adds(tmpKey0, tmpKey0, PhiloxInternal::CONST_KEY_ADD_0, pg);
        Adds(tmpKey1, tmpKey1, PhiloxInternal::CONST_KEY_ADD_1, pg);
    }
}
// Runs the complete `Rounds`-round network on private copies of the counter and key
// registers (the caller's registers are not modified) and writes one batch of at most
// `tailCount` 32-bit elements to dstUbTail.
//   T == float   : the words are converted with UInt2Float before being written.
//   DstUnalign   : true  -> the four words are scattered with an index of lane * 4,
//                           for tailCount / 4 lanes;
//                  false -> the words are interleaved in registers and written with four
//                           StoreAlign calls whose masks are derived from tailCount.
template <uint16_t Rounds, typename T, bool DstUnalign = false>
__simd_callee__ inline void SpNetworkFull(
    __ubuf__ uint32_t* dstUbTail, uint16_t tailCount, Reg::RegTensor<uint32_t>& ctr0, Reg::RegTensor<uint32_t>& ctr1,
    Reg::RegTensor<uint32_t>& ctr2, Reg::RegTensor<uint32_t>& ctr3, Reg::RegTensor<uint32_t>& key0,
    Reg::RegTensor<uint32_t>& key1, Reg::RegTensor<uint32_t>& cMul0, Reg::RegTensor<uint32_t>& cMul1, Reg::MaskReg& pg)
{
    // Working copies of the counter words and keys.
    Reg::RegTensor<uint32_t> word0 = ctr0;
    Reg::RegTensor<uint32_t> word1 = ctr1;
    Reg::RegTensor<uint32_t> word2 = ctr2;
    Reg::RegTensor<uint32_t> word3 = ctr3;
    Reg::RegTensor<uint32_t> roundKey0 = key0;
    Reg::RegTensor<uint32_t> roundKey1 = key1;

    Reg::RegTensor<uint32_t> lo0, hi0, lo1, hi1;
    SpNetworkKernel<Rounds>(lo0, hi0, lo1, hi1, word0, word1, word2, word3, roundKey0, roundKey1, cMul0, cMul1, pg);

    if constexpr (std::is_same_v<T, float>) {
        UInt2Float(word0, word1, word2, word3, pg);
    }

    if constexpr (DstUnalign) {
        // Scatter index: lane * 4, so word k of a lane is written at dstUbTail[lane * 4 + k].
        Reg::RegTensor<uint32_t> scatterIdx;
        Reg::Arange((Reg::RegTensor<int32_t>&)scatterIdx, 0);
        Muls(scatterIdx, scatterIdx, PhiloxInternal::PHILOX_ONCE_COUNTER_NUM, pg);

        // One active lane per group of four output elements.
        uint32_t lanesLeft = static_cast<uint32_t>(tailCount / PhiloxInternal::PHILOX_ONCE_COUNTER_NUM);
        Reg::MaskReg laneMask = Reg::UpdateMask<T>(lanesLeft);
        Reg::Scatter(dstUbTail, word0, scatterIdx, laneMask);
        Reg::Scatter(dstUbTail + 1, word1, scatterIdx, laneMask);
        Reg::Scatter(dstUbTail + 2, word2, scatterIdx, laneMask);
        Reg::Scatter(dstUbTail + 3, word3, scatterIdx, laneMask);
    } else {
        // Two interleave stages rearrange the four word registers into output order.
        Interleave(word0, word2, word0, word2);
        Interleave(word1, word3, word1, word3);
        Interleave(word0, word1, word0, word1);
        Interleave(word2, word3, word2, word3);

        // A fresh mask is requested from the running element count before every store.
        uint32_t elemsLeft = static_cast<uint32_t>(tailCount);
        Reg::MaskReg storeMask = Reg::UpdateMask<T>(elemsLeft);
        Reg::StoreAlign<uint32_t, Reg::PostLiteral::POST_MODE_UPDATE>(
            dstUbTail, word0, PhiloxInternal::ELE_CNT_B32_ONCE, storeMask);
        storeMask = Reg::UpdateMask<T>(elemsLeft);
        Reg::StoreAlign<uint32_t, Reg::PostLiteral::POST_MODE_UPDATE>(
            dstUbTail, word1, PhiloxInternal::ELE_CNT_B32_ONCE, storeMask);
        storeMask = Reg::UpdateMask<T>(elemsLeft);
        Reg::StoreAlign<uint32_t, Reg::PostLiteral::POST_MODE_UPDATE>(
            dstUbTail, word2, PhiloxInternal::ELE_CNT_B32_ONCE, storeMask);
        storeMask = Reg::UpdateMask<T>(elemsLeft);
        Reg::StoreAlign<uint32_t, Reg::PostLiteral::POST_MODE_UPDATE>(
            dstUbTail, word3, PhiloxInternal::ELE_CNT_B32_ONCE, storeMask);
    }
}
// Writes the four working counter registers to dstUbT one after another, each store
// using post-update mode with a step of ELE_CNT_B32_ONCE elements.
//   DstUnalign == false : StoreAlign under mask pg.
//   DstUnalign == true  : StoreUnAlign through a local UnalignReg, finished with
//                         StoreUnAlignPost; pg is not used on this path.
template <bool DstUnalign = false>
__simd_callee__ inline void PhiloxUnrollStoreTmpCtrl(
    __ubuf__ uint32_t*& dstUbT, Reg::RegTensor<uint32_t>& tmpCtr0, Reg::RegTensor<uint32_t>& tmpCtr1,
    Reg::RegTensor<uint32_t>& tmpCtr2, Reg::RegTensor<uint32_t>& tmpCtr3, Reg::MaskReg& pg)
{
    constexpr uint32_t kStoreStep = PhiloxInternal::ELE_CNT_B32_ONCE;

    if constexpr (!DstUnalign) {
        Reg::StoreAlign<uint32_t, Reg::PostLiteral::POST_MODE_UPDATE>(dstUbT, tmpCtr0, kStoreStep, pg);
        Reg::StoreAlign<uint32_t, Reg::PostLiteral::POST_MODE_UPDATE>(dstUbT, tmpCtr1, kStoreStep, pg);
        Reg::StoreAlign<uint32_t, Reg::PostLiteral::POST_MODE_UPDATE>(dstUbT, tmpCtr2, kStoreStep, pg);
        Reg::StoreAlign<uint32_t, Reg::PostLiteral::POST_MODE_UPDATE>(dstUbT, tmpCtr3, kStoreStep, pg);
    } else {
        Reg::UnalignReg pending;
        Reg::StoreUnAlign<uint32_t, Reg::PostLiteral::POST_MODE_UPDATE>(dstUbT, tmpCtr0, pending, kStoreStep);
        Reg::StoreUnAlign<uint32_t, Reg::PostLiteral::POST_MODE_UPDATE>(dstUbT, tmpCtr1, pending, kStoreStep);
        Reg::StoreUnAlign<uint32_t, Reg::PostLiteral::POST_MODE_UPDATE>(dstUbT, tmpCtr2, pending, kStoreStep);
        Reg::StoreUnAlign<uint32_t, Reg::PostLiteral::POST_MODE_UPDATE>(dstUbT, tmpCtr3, pending, kStoreStep);
        // Finalise the unaligned store sequence (presumably flushes what is still held in `pending`).
        Reg::StoreUnAlignPost<uint32_t, Reg::PostLiteral::POST_MODE_UPDATE>(dstUbT, pending, 0);
    }
}
// Reads the four working counter registers back from four separate cursors
// (dstUbTT0..dstUbTT3). Every load is given a step of
// ELE_CNT_B32_ONCE * PHILOX_ONCE_COUNTER_NUM elements, i.e. one full batch.
//   DstUnalign == false : LoadAlign in post-update mode.
//   DstUnalign == true  : LoadUnAlignPre followed by LoadUnAlign for each cursor.
template <bool DstUnalign = false>
__simd_callee__ inline void PhiloxUnrollLoadTmpCtrl(
    __ubuf__ uint32_t*& dstUbTT0, __ubuf__ uint32_t*& dstUbTT1, __ubuf__ uint32_t*& dstUbTT2,
    __ubuf__ uint32_t*& dstUbTT3, Reg::RegTensor<uint32_t>& tmpCtr0, Reg::RegTensor<uint32_t>& tmpCtr1,
    Reg::RegTensor<uint32_t>& tmpCtr2, Reg::RegTensor<uint32_t>& tmpCtr3)
{
    constexpr uint32_t kBatchStep = PhiloxInternal::ELE_CNT_B32_ONCE * PhiloxInternal::PHILOX_ONCE_COUNTER_NUM;

    if constexpr (!DstUnalign) {
        Reg::LoadAlign<uint32_t, Reg::PostLiteral::POST_MODE_UPDATE>(tmpCtr0, dstUbTT0, kBatchStep);
        Reg::LoadAlign<uint32_t, Reg::PostLiteral::POST_MODE_UPDATE>(tmpCtr1, dstUbTT1, kBatchStep);
        Reg::LoadAlign<uint32_t, Reg::PostLiteral::POST_MODE_UPDATE>(tmpCtr2, dstUbTT2, kBatchStep);
        Reg::LoadAlign<uint32_t, Reg::PostLiteral::POST_MODE_UPDATE>(tmpCtr3, dstUbTT3, kBatchStep);
    } else {
        // The same UnalignReg is re-primed before each of the four loads.
        Reg::UnalignReg staging;
        Reg::LoadUnAlignPre(staging, dstUbTT0);
        Reg::LoadUnAlign(tmpCtr0, staging, dstUbTT0, kBatchStep);
        Reg::LoadUnAlignPre(staging, dstUbTT1);
        Reg::LoadUnAlign(tmpCtr1, staging, dstUbTT1, kBatchStep);
        Reg::LoadUnAlignPre(staging, dstUbTT2);
        Reg::LoadUnAlign(tmpCtr2, staging, dstUbTT2, kBatchStep);
        Reg::LoadUnAlignPre(staging, dstUbTT3);
        Reg::LoadUnAlign(tmpCtr3, staging, dstUbTT3, kBatchStep);
    }
}
// Main-block generator for 10 rounds, hand-unrolled as 4 + 4 + 2 rounds over three
// passes that use dstUb as intermediate storage.
//   Pass 1: for each of the mainIter batches, run rounds 1-4 on a copy of the counter,
//           store the four intermediate word registers to dstUb and advance the
//           counter (ctr0..ctr3) by vEleStrideB32OneRow.
//   Pass 2: reload each batch, fast-forward the key by four increments, run rounds 5-8
//           and store the result back to dstUb.
//   Pass 3: reload each batch, fast-forward the key by eight increments, run rounds
//           9-10, convert to float when T is float, and scatter the four word
//           registers to dstUb using reorderIndex.
// A VEC_STORE -> VEC_LOAD memory barrier separates consecutive passes.
// ctr0..ctr3 are modified (advanced once per batch); key0/key1 are only read.
template <typename T, bool DstUnalign = false>
__simd_callee__ inline void PhiloxRound10MainBlockUnroll442(
__ubuf__ uint32_t* dstUb, uint16_t mainIter, Reg::RegTensor<uint32_t>& ctr0, Reg::RegTensor<uint32_t>& ctr1,
Reg::RegTensor<uint32_t>& ctr2, Reg::RegTensor<uint32_t>& ctr3, Reg::RegTensor<uint32_t>& key0,
Reg::RegTensor<uint32_t>& key1, Reg::RegTensor<uint32_t>& cMul0, Reg::RegTensor<uint32_t>& cMul1,
Reg::RegTensor<uint32_t>& vEleStrideB32OneRow, Reg::MaskReg& pg)
{
Reg::RegTensor<uint32_t> tmpCtr3, tmpCtr2, tmpCtr1, tmpCtr0;
__ubuf__ uint32_t* dstUbT = dstUb;
// Scatter index for pass 3: Arange from 0 scaled by PHILOX_ONCE_COUNTER_NUM
// (presumably lane * 4, so the four words of one lane land next to each other).
Reg::RegTensor<uint32_t> reorderIndex;
Arange((Reg::RegTensor<int32_t>&)reorderIndex, 0);
Muls(reorderIndex, reorderIndex, PhiloxInternal::PHILOX_ONCE_COUNTER_NUM, pg);
// ---- Pass 1: rounds 1-4 ----
for (uint16_t i = 0; i < mainIter; i++) {
tmpCtr0 = ctr0;
tmpCtr1 = ctr1;
tmpCtr2 = ctr2;
tmpCtr3 = ctr3;
Reg::RegTensor<uint32_t> tmpKey0 = key0;
Reg::RegTensor<uint32_t> tmpKey1 = key1;
Reg::RegTensor<uint32_t> tmpL0, tmpH0, tmpL1, tmpH1;
// round 1
Mull(tmpL0, tmpH0, tmpCtr0, cMul0, pg);
Mull(tmpL1, tmpH1, tmpCtr2, cMul1, pg);
Xor(tmpH1, tmpH1, tmpCtr1, pg);
Xor(tmpCtr0, tmpH1, tmpKey0, pg);
Xor(tmpH0, tmpH0, tmpCtr3, pg);
Xor(tmpCtr2, tmpH0, tmpKey1, pg);
tmpCtr1 = tmpL1;
tmpCtr3 = tmpL0;
Adds(tmpKey0, tmpKey0, PhiloxInternal::CONST_KEY_ADD_0, pg);
Adds(tmpKey1, tmpKey1, PhiloxInternal::CONST_KEY_ADD_1, pg);
// round 2
Mull(tmpL0, tmpH0, tmpCtr0, cMul0, pg);
Mull(tmpL1, tmpH1, tmpCtr2, cMul1, pg);
Xor(tmpH1, tmpH1, tmpCtr1, pg);
Xor(tmpCtr0, tmpH1, tmpKey0, pg);
Xor(tmpH0, tmpH0, tmpCtr3, pg);
Xor(tmpCtr2, tmpH0, tmpKey1, pg);
tmpCtr1 = tmpL1;
tmpCtr3 = tmpL0;
Adds(tmpKey0, tmpKey0, PhiloxInternal::CONST_KEY_ADD_0, pg);
Adds(tmpKey1, tmpKey1, PhiloxInternal::CONST_KEY_ADD_1, pg);
// round 3
Mull(tmpL0, tmpH0, tmpCtr0, cMul0, pg);
Mull(tmpL1, tmpH1, tmpCtr2, cMul1, pg);
Xor(tmpH1, tmpH1, tmpCtr1, pg);
Xor(tmpCtr0, tmpH1, tmpKey0, pg);
Xor(tmpH0, tmpH0, tmpCtr3, pg);
Xor(tmpCtr2, tmpH0, tmpKey1, pg);
tmpCtr1 = tmpL1;
tmpCtr3 = tmpL0;
Adds(tmpKey0, tmpKey0, PhiloxInternal::CONST_KEY_ADD_0, pg);
Adds(tmpKey1, tmpKey1, PhiloxInternal::CONST_KEY_ADD_1, pg);
// round 4
Mull(tmpL0, tmpH0, tmpCtr0, cMul0, pg);
Mull(tmpL1, tmpH1, tmpCtr2, cMul1, pg);
Xor(tmpH1, tmpH1, tmpCtr1, pg);
Xor(tmpCtr0, tmpH1, tmpKey0, pg);
Xor(tmpH0, tmpH0, tmpCtr3, pg);
Xor(tmpCtr2, tmpH0, tmpKey1, pg);
tmpCtr1 = tmpL1;
tmpCtr3 = tmpL0;
Adds(tmpKey0, tmpKey0, PhiloxInternal::CONST_KEY_ADD_0, pg);
Adds(tmpKey1, tmpKey1, PhiloxInternal::CONST_KEY_ADD_1, pg);
// Park the intermediate state in dstUb and move the counter on to the next batch.
PhiloxUnrollStoreTmpCtrl<DstUnalign>(dstUbT, tmpCtr0, tmpCtr1, tmpCtr2, tmpCtr3, pg);
AddWith128Bits(ctr0, ctr1, ctr2, ctr3, vEleStrideB32OneRow, pg);
}
// Rewind the store cursor and set up four load cursors, one per word register,
// spaced ELE_CNT_B32_ONCE elements apart (matching the store layout of pass 1).
dstUbT = dstUb;
__ubuf__ uint32_t* dstUbTT0 = dstUbT;
__ubuf__ uint32_t* dstUbTT1 = dstUbT + PhiloxInternal::ELE_CNT_B32_ONCE;
__ubuf__ uint32_t* dstUbTT2 = dstUbT + PhiloxInternal::ELE_CNT_B32_ONCE * 2;
__ubuf__ uint32_t* dstUbTT3 = dstUbT + PhiloxInternal::ELE_CNT_B32_ONCE * 3;
// Barrier between the vector stores of pass 1 and the vector loads of pass 2.
Reg::LocalMemBar<Reg::MemType::VEC_STORE, Reg::MemType::VEC_LOAD>();
// ---- Pass 2: rounds 5-8 ----
for (uint16_t i = 0; i < mainIter; i++) {
PhiloxUnrollLoadTmpCtrl<DstUnalign>(dstUbTT0, dstUbTT1, dstUbTT2, dstUbTT3, tmpCtr0, tmpCtr1, tmpCtr2, tmpCtr3);
Reg::RegTensor<uint32_t> tmpKey0 = key0;
Reg::RegTensor<uint32_t> tmpKey1 = key1;
// Replay the four key increments already consumed by rounds 1-4.
Adds(tmpKey0, tmpKey0, PhiloxInternal::CONST_KEY_ADD_0, pg);
Adds(tmpKey1, tmpKey1, PhiloxInternal::CONST_KEY_ADD_1, pg);
Adds(tmpKey0, tmpKey0, PhiloxInternal::CONST_KEY_ADD_0, pg);
Adds(tmpKey1, tmpKey1, PhiloxInternal::CONST_KEY_ADD_1, pg);
Adds(tmpKey0, tmpKey0, PhiloxInternal::CONST_KEY_ADD_0, pg);
Adds(tmpKey1, tmpKey1, PhiloxInternal::CONST_KEY_ADD_1, pg);
Adds(tmpKey0, tmpKey0, PhiloxInternal::CONST_KEY_ADD_0, pg);
Adds(tmpKey1, tmpKey1, PhiloxInternal::CONST_KEY_ADD_1, pg);
Reg::RegTensor<uint32_t> tmpL0, tmpH0, tmpL1, tmpH1;
// round 5
Mull(tmpL0, tmpH0, tmpCtr0, cMul0, pg);
Mull(tmpL1, tmpH1, tmpCtr2, cMul1, pg);
Xor(tmpH1, tmpH1, tmpCtr1, pg);
Xor(tmpCtr0, tmpH1, tmpKey0, pg);
Xor(tmpH0, tmpH0, tmpCtr3, pg);
Xor(tmpCtr2, tmpH0, tmpKey1, pg);
tmpCtr1 = tmpL1;
tmpCtr3 = tmpL0;
Adds(tmpKey0, tmpKey0, PhiloxInternal::CONST_KEY_ADD_0, pg);
Adds(tmpKey1, tmpKey1, PhiloxInternal::CONST_KEY_ADD_1, pg);
// round 6
Mull(tmpL0, tmpH0, tmpCtr0, cMul0, pg);
Mull(tmpL1, tmpH1, tmpCtr2, cMul1, pg);
Xor(tmpH1, tmpH1, tmpCtr1, pg);
Xor(tmpCtr0, tmpH1, tmpKey0, pg);
Xor(tmpH0, tmpH0, tmpCtr3, pg);
Xor(tmpCtr2, tmpH0, tmpKey1, pg);
tmpCtr1 = tmpL1;
tmpCtr3 = tmpL0;
Adds(tmpKey0, tmpKey0, PhiloxInternal::CONST_KEY_ADD_0, pg);
Adds(tmpKey1, tmpKey1, PhiloxInternal::CONST_KEY_ADD_1, pg);
// round 7
Mull(tmpL0, tmpH0, tmpCtr0, cMul0, pg);
Mull(tmpL1, tmpH1, tmpCtr2, cMul1, pg);
Xor(tmpH1, tmpH1, tmpCtr1, pg);
Xor(tmpCtr0, tmpH1, tmpKey0, pg);
Xor(tmpH0, tmpH0, tmpCtr3, pg);
Xor(tmpCtr2, tmpH0, tmpKey1, pg);
tmpCtr1 = tmpL1;
tmpCtr3 = tmpL0;
Adds(tmpKey0, tmpKey0, PhiloxInternal::CONST_KEY_ADD_0, pg);
Adds(tmpKey1, tmpKey1, PhiloxInternal::CONST_KEY_ADD_1, pg);
// round 8
Mull(tmpL0, tmpH0, tmpCtr0, cMul0, pg);
Mull(tmpL1, tmpH1, tmpCtr2, cMul1, pg);
Xor(tmpH1, tmpH1, tmpCtr1, pg);
Xor(tmpCtr0, tmpH1, tmpKey0, pg);
Xor(tmpH0, tmpH0, tmpCtr3, pg);
Xor(tmpCtr2, tmpH0, tmpKey1, pg);
tmpCtr1 = tmpL1;
tmpCtr3 = tmpL0;
Adds(tmpKey0, tmpKey0, PhiloxInternal::CONST_KEY_ADD_0, pg);
Adds(tmpKey1, tmpKey1, PhiloxInternal::CONST_KEY_ADD_1, pg);
// Write the state after round 8 back through the store cursor.
PhiloxUnrollStoreTmpCtrl<DstUnalign>(dstUbT, tmpCtr0, tmpCtr1, tmpCtr2, tmpCtr3, pg);
}
// Rewind all cursors again for the last pass.
dstUbT = dstUb;
dstUbTT0 = dstUbT;
dstUbTT1 = dstUbT + PhiloxInternal::ELE_CNT_B32_ONCE;
dstUbTT2 = dstUbT + PhiloxInternal::ELE_CNT_B32_ONCE * 2;
dstUbTT3 = dstUbT + PhiloxInternal::ELE_CNT_B32_ONCE * 3;
Reg::LocalMemBar<Reg::MemType::VEC_STORE, Reg::MemType::VEC_LOAD>();
// ---- Pass 3: rounds 9-10, optional float conversion, final scatter ----
for (uint16_t i = 0; i < mainIter; i++) {
PhiloxUnrollLoadTmpCtrl<DstUnalign>(dstUbTT0, dstUbTT1, dstUbTT2, dstUbTT3, tmpCtr0, tmpCtr1, tmpCtr2, tmpCtr3);
Reg::RegTensor<uint32_t> tmpKey0 = key0;
Reg::RegTensor<uint32_t> tmpKey1 = key1;
// Replay the eight key increments already consumed by rounds 1-8.
Adds(tmpKey0, tmpKey0, PhiloxInternal::CONST_KEY_ADD_0, pg);
Adds(tmpKey1, tmpKey1, PhiloxInternal::CONST_KEY_ADD_1, pg);
Adds(tmpKey0, tmpKey0, PhiloxInternal::CONST_KEY_ADD_0, pg);
Adds(tmpKey1, tmpKey1, PhiloxInternal::CONST_KEY_ADD_1, pg);
Adds(tmpKey0, tmpKey0, PhiloxInternal::CONST_KEY_ADD_0, pg);
Adds(tmpKey1, tmpKey1, PhiloxInternal::CONST_KEY_ADD_1, pg);
Adds(tmpKey0, tmpKey0, PhiloxInternal::CONST_KEY_ADD_0, pg);
Adds(tmpKey1, tmpKey1, PhiloxInternal::CONST_KEY_ADD_1, pg);
Adds(tmpKey0, tmpKey0, PhiloxInternal::CONST_KEY_ADD_0, pg);
Adds(tmpKey1, tmpKey1, PhiloxInternal::CONST_KEY_ADD_1, pg);
Adds(tmpKey0, tmpKey0, PhiloxInternal::CONST_KEY_ADD_0, pg);
Adds(tmpKey1, tmpKey1, PhiloxInternal::CONST_KEY_ADD_1, pg);
Adds(tmpKey0, tmpKey0, PhiloxInternal::CONST_KEY_ADD_0, pg);
Adds(tmpKey1, tmpKey1, PhiloxInternal::CONST_KEY_ADD_1, pg);
Adds(tmpKey0, tmpKey0, PhiloxInternal::CONST_KEY_ADD_0, pg);
Adds(tmpKey1, tmpKey1, PhiloxInternal::CONST_KEY_ADD_1, pg);
Reg::RegTensor<uint32_t> tmpL0, tmpH0, tmpL1, tmpH1;
// round 9
Mull(tmpL0, tmpH0, tmpCtr0, cMul0, pg);
Mull(tmpL1, tmpH1, tmpCtr2, cMul1, pg);
Xor(tmpH1, tmpH1, tmpCtr1, pg);
Xor(tmpCtr0, tmpH1, tmpKey0, pg);
Xor(tmpH0, tmpH0, tmpCtr3, pg);
Xor(tmpCtr2, tmpH0, tmpKey1, pg);
tmpCtr1 = tmpL1;
tmpCtr3 = tmpL0;
Adds(tmpKey0, tmpKey0, PhiloxInternal::CONST_KEY_ADD_0, pg);
Adds(tmpKey1, tmpKey1, PhiloxInternal::CONST_KEY_ADD_1, pg);
// round 10 (last round: the key is not advanced afterwards because it is no longer needed)
Mull(tmpL0, tmpH0, tmpCtr0, cMul0, pg);
Mull(tmpL1, tmpH1, tmpCtr2, cMul1, pg);
Xor(tmpH1, tmpH1, tmpCtr1, pg);
Xor(tmpCtr0, tmpH1, tmpKey0, pg);
Xor(tmpH0, tmpH0, tmpCtr3, pg);
Xor(tmpCtr2, tmpH0, tmpKey1, pg);
tmpCtr1 = tmpL1;
tmpCtr3 = tmpL0;
if constexpr (std::is_same_v<T, float>) {
UInt2Float(tmpCtr0, tmpCtr1, tmpCtr2, tmpCtr3, pg);
}
// Word k of every lane goes to dstUb + k at the offsets held in reorderIndex.
Reg::Scatter(dstUb, tmpCtr0, reorderIndex, pg);
Reg::Scatter(dstUb + 1, tmpCtr1, reorderIndex, pg);
Reg::Scatter(dstUb + 2, tmpCtr2, reorderIndex, pg);
Reg::Scatter(dstUb + 3, tmpCtr3, reorderIndex, pg);
// Move the scatter offsets forward by one batch (PHILOX_ONCE_REPEAT_NUM elements).
Adds(reorderIndex, reorderIndex, PhiloxInternal::PHILOX_ONCE_REPEAT_NUM, pg);
}
}
// Main-block generator for 7 rounds, hand-unrolled as 4 + 3 rounds over two passes
// that use dstUb as intermediate storage.
//   Pass 1: for each of the mainIter batches, run rounds 1-4 on a copy of the counter,
//           store the four intermediate word registers to dstUb and advance the
//           counter (ctr0..ctr3) by vEleStrideB32OneRow.
//   Pass 2: reload each batch, fast-forward the key by four increments, run rounds
//           5-7, convert to float when T is float, and scatter the four word
//           registers to dstUb using reorderIndex.
// A VEC_STORE -> VEC_LOAD memory barrier separates the two passes.
// ctr0..ctr3 are modified (advanced once per batch); key0/key1 are only read.
template <typename T, bool DstUnalign = false>
__simd_callee__ inline void PhiloxRound7MainBlockUnroll43(
__ubuf__ uint32_t* dstUb, uint16_t mainIter, Reg::RegTensor<uint32_t>& ctr0, Reg::RegTensor<uint32_t>& ctr1,
Reg::RegTensor<uint32_t>& ctr2, Reg::RegTensor<uint32_t>& ctr3, Reg::RegTensor<uint32_t>& key0,
Reg::RegTensor<uint32_t>& key1, Reg::RegTensor<uint32_t>& cMul0, Reg::RegTensor<uint32_t>& cMul1,
Reg::RegTensor<uint32_t>& vEleStrideB32OneRow, Reg::MaskReg& pg)
{
Reg::RegTensor<uint32_t> tmpCtr3, tmpCtr2, tmpCtr1, tmpCtr0;
__ubuf__ uint32_t* dstUbT = dstUb;
// Scatter index for pass 2: Arange from 0 scaled by PHILOX_ONCE_COUNTER_NUM
// (presumably lane * 4, so the four words of one lane land next to each other).
Reg::RegTensor<uint32_t> reorderIndex;
Reg::Arange((Reg::RegTensor<int32_t>&)reorderIndex, 0);
Muls(reorderIndex, reorderIndex, PhiloxInternal::PHILOX_ONCE_COUNTER_NUM, pg);
// ---- Pass 1: rounds 1-4 ----
for (uint16_t i = 0; i < mainIter; i++) {
tmpCtr0 = ctr0;
tmpCtr1 = ctr1;
tmpCtr2 = ctr2;
tmpCtr3 = ctr3;
Reg::RegTensor<uint32_t> tmpKey0 = key0;
Reg::RegTensor<uint32_t> tmpKey1 = key1;
Reg::RegTensor<uint32_t> tmpL0, tmpH0, tmpL1, tmpH1;
// round 1
Mull(tmpL0, tmpH0, tmpCtr0, cMul0, pg);
Mull(tmpL1, tmpH1, tmpCtr2, cMul1, pg);
Xor(tmpH1, tmpH1, tmpCtr1, pg);
Xor(tmpCtr0, tmpH1, tmpKey0, pg);
Xor(tmpH0, tmpH0, tmpCtr3, pg);
Xor(tmpCtr2, tmpH0, tmpKey1, pg);
tmpCtr1 = tmpL1;
tmpCtr3 = tmpL0;
Adds(tmpKey0, tmpKey0, PhiloxInternal::CONST_KEY_ADD_0, pg);
Adds(tmpKey1, tmpKey1, PhiloxInternal::CONST_KEY_ADD_1, pg);
// round 2
Mull(tmpL0, tmpH0, tmpCtr0, cMul0, pg);
Mull(tmpL1, tmpH1, tmpCtr2, cMul1, pg);
Xor(tmpH1, tmpH1, tmpCtr1, pg);
Xor(tmpCtr0, tmpH1, tmpKey0, pg);
Xor(tmpH0, tmpH0, tmpCtr3, pg);
Xor(tmpCtr2, tmpH0, tmpKey1, pg);
tmpCtr1 = tmpL1;
tmpCtr3 = tmpL0;
Adds(tmpKey0, tmpKey0, PhiloxInternal::CONST_KEY_ADD_0, pg);
Adds(tmpKey1, tmpKey1, PhiloxInternal::CONST_KEY_ADD_1, pg);
// round 3
Mull(tmpL0, tmpH0, tmpCtr0, cMul0, pg);
Mull(tmpL1, tmpH1, tmpCtr2, cMul1, pg);
Xor(tmpH1, tmpH1, tmpCtr1, pg);
Xor(tmpCtr0, tmpH1, tmpKey0, pg);
Xor(tmpH0, tmpH0, tmpCtr3, pg);
Xor(tmpCtr2, tmpH0, tmpKey1, pg);
tmpCtr1 = tmpL1;
tmpCtr3 = tmpL0;
Adds(tmpKey0, tmpKey0, PhiloxInternal::CONST_KEY_ADD_0, pg);
Adds(tmpKey1, tmpKey1, PhiloxInternal::CONST_KEY_ADD_1, pg);
// round 4
Mull(tmpL0, tmpH0, tmpCtr0, cMul0, pg);
Mull(tmpL1, tmpH1, tmpCtr2, cMul1, pg);
Xor(tmpH1, tmpH1, tmpCtr1, pg);
Xor(tmpCtr0, tmpH1, tmpKey0, pg);
Xor(tmpH0, tmpH0, tmpCtr3, pg);
Xor(tmpCtr2, tmpH0, tmpKey1, pg);
tmpCtr1 = tmpL1;
tmpCtr3 = tmpL0;
Adds(tmpKey0, tmpKey0, PhiloxInternal::CONST_KEY_ADD_0, pg);
Adds(tmpKey1, tmpKey1, PhiloxInternal::CONST_KEY_ADD_1, pg);
// Park the intermediate state in dstUb and move the counter on to the next batch.
PhiloxUnrollStoreTmpCtrl<DstUnalign>(dstUbT, tmpCtr0, tmpCtr1, tmpCtr2, tmpCtr3, pg);
AddWith128Bits(ctr0, ctr1, ctr2, ctr3, vEleStrideB32OneRow, pg);
}
// Set up four load cursors, one per word register, spaced ELE_CNT_B32_ONCE
// elements apart (matching the store layout of pass 1).
dstUbT = dstUb;
__ubuf__ uint32_t* dstUbTT0 = dstUbT;
__ubuf__ uint32_t* dstUbTT1 = dstUbT + PhiloxInternal::ELE_CNT_B32_ONCE;
__ubuf__ uint32_t* dstUbTT2 = dstUbT + PhiloxInternal::ELE_CNT_B32_ONCE * 2;
__ubuf__ uint32_t* dstUbTT3 = dstUbT + PhiloxInternal::ELE_CNT_B32_ONCE * 3;
// Barrier between the vector stores of pass 1 and the vector loads of pass 2.
Reg::LocalMemBar<Reg::MemType::VEC_STORE, Reg::MemType::VEC_LOAD>();
// ---- Pass 2: rounds 5-7, optional float conversion, final scatter ----
for (uint16_t i = 0; i < mainIter; i++) {
PhiloxUnrollLoadTmpCtrl<DstUnalign>(dstUbTT0, dstUbTT1, dstUbTT2, dstUbTT3, tmpCtr0, tmpCtr1, tmpCtr2, tmpCtr3);
Reg::RegTensor<uint32_t> tmpKey0 = key0;
Reg::RegTensor<uint32_t> tmpKey1 = key1;
// Replay the four key increments already consumed by rounds 1-4.
Adds(tmpKey0, tmpKey0, PhiloxInternal::CONST_KEY_ADD_0, pg);
Adds(tmpKey1, tmpKey1, PhiloxInternal::CONST_KEY_ADD_1, pg);
Adds(tmpKey0, tmpKey0, PhiloxInternal::CONST_KEY_ADD_0, pg);
Adds(tmpKey1, tmpKey1, PhiloxInternal::CONST_KEY_ADD_1, pg);
Adds(tmpKey0, tmpKey0, PhiloxInternal::CONST_KEY_ADD_0, pg);
Adds(tmpKey1, tmpKey1, PhiloxInternal::CONST_KEY_ADD_1, pg);
Adds(tmpKey0, tmpKey0, PhiloxInternal::CONST_KEY_ADD_0, pg);
Adds(tmpKey1, tmpKey1, PhiloxInternal::CONST_KEY_ADD_1, pg);
Reg::RegTensor<uint32_t> tmpL0, tmpH0, tmpL1, tmpH1;
// round 5
Mull(tmpL0, tmpH0, tmpCtr0, cMul0, pg);
Mull(tmpL1, tmpH1, tmpCtr2, cMul1, pg);
Xor(tmpH1, tmpH1, tmpCtr1, pg);
Xor(tmpCtr0, tmpH1, tmpKey0, pg);
Xor(tmpH0, tmpH0, tmpCtr3, pg);
Xor(tmpCtr2, tmpH0, tmpKey1, pg);
tmpCtr1 = tmpL1;
tmpCtr3 = tmpL0;
Adds(tmpKey0, tmpKey0, PhiloxInternal::CONST_KEY_ADD_0, pg);
Adds(tmpKey1, tmpKey1, PhiloxInternal::CONST_KEY_ADD_1, pg);
// round 6
Mull(tmpL0, tmpH0, tmpCtr0, cMul0, pg);
Mull(tmpL1, tmpH1, tmpCtr2, cMul1, pg);
Xor(tmpH1, tmpH1, tmpCtr1, pg);
Xor(tmpCtr0, tmpH1, tmpKey0, pg);
Xor(tmpH0, tmpH0, tmpCtr3, pg);
Xor(tmpCtr2, tmpH0, tmpKey1, pg);
tmpCtr1 = tmpL1;
tmpCtr3 = tmpL0;
Adds(tmpKey0, tmpKey0, PhiloxInternal::CONST_KEY_ADD_0, pg);
Adds(tmpKey1, tmpKey1, PhiloxInternal::CONST_KEY_ADD_1, pg);
// round 7 (last round: the key is not advanced afterwards because it is no longer needed)
Mull(tmpL0, tmpH0, tmpCtr0, cMul0, pg);
Mull(tmpL1, tmpH1, tmpCtr2, cMul1, pg);
Xor(tmpH1, tmpH1, tmpCtr1, pg);
Xor(tmpCtr0, tmpH1, tmpKey0, pg);
Xor(tmpH0, tmpH0, tmpCtr3, pg);
Xor(tmpCtr2, tmpH0, tmpKey1, pg);
tmpCtr1 = tmpL1;
tmpCtr3 = tmpL0;
if constexpr (std::is_same_v<T, float>) {
UInt2Float(tmpCtr0, tmpCtr1, tmpCtr2, tmpCtr3, pg);
}
// Word k of every lane goes to dstUb + k at the offsets held in reorderIndex.
Reg::Scatter(dstUb, tmpCtr0, reorderIndex, pg);
Reg::Scatter(dstUb + 1, tmpCtr1, reorderIndex, pg);
Reg::Scatter(dstUb + 2, tmpCtr2, reorderIndex, pg);
Reg::Scatter(dstUb + 3, tmpCtr3, reorderIndex, pg);
// Move the scatter offsets forward by one batch (PHILOX_ONCE_REPEAT_NUM elements).
Adds(reorderIndex, reorderIndex, PhiloxInternal::PHILOX_ONCE_REPEAT_NUM, pg);
}
}
// Compile-time dispatcher for the main (full-batch) part of the generation:
// Rounds == 10 selects the 4+4+2 unrolled variant, any other value the 4+3 variant.
// All arguments are forwarded unchanged.
template <uint16_t Rounds = 7, typename T, bool DstUnalign = false>
__simd_callee__ inline void PhiloxRoundMainBlockUnroll(
    __ubuf__ uint32_t* dstUb, uint16_t mainIter, Reg::RegTensor<uint32_t>& ctr0, Reg::RegTensor<uint32_t>& ctr1,
    Reg::RegTensor<uint32_t>& ctr2, Reg::RegTensor<uint32_t>& ctr3, Reg::RegTensor<uint32_t>& key0,
    Reg::RegTensor<uint32_t>& key1, Reg::RegTensor<uint32_t>& cMul0, Reg::RegTensor<uint32_t>& cMul1,
    Reg::RegTensor<uint32_t>& vEleStrideB32OneRow, Reg::MaskReg& pg)
{
    if constexpr (Rounds != 10) {
        PhiloxRound7MainBlockUnroll43<T, DstUnalign>(
            dstUb, mainIter, ctr0, ctr1, ctr2, ctr3, key0, key1, cMul0, cMul1, vEleStrideB32OneRow, pg);
    } else {
        PhiloxRound10MainBlockUnroll442<T, DstUnalign>(
            dstUb, mainIter, ctr0, ctr1, ctr2, ctr3, key0, key1, cMul0, cMul1, vEleStrideB32OneRow, pg);
    }
}
// Initialises the per-lane counters: every lane of ctr0..ctr3 receives the
// corresponding word of philoxCounter, then the per-lane offset in incIdx
// (reinterpreted as unsigned) is added with carry propagation across the four words.
__simd_callee__ inline void PhiloxCounterInit(
    const PhiloxCounter& philoxCounter, Reg::RegTensor<uint32_t>& ctr0, Reg::RegTensor<uint32_t>& ctr1,
    Reg::RegTensor<uint32_t>& ctr2, Reg::RegTensor<uint32_t>& ctr3, Reg::RegTensor<int32_t>& incIdx, Reg::MaskReg& pg)
{
    // Broadcast the base counter to all lanes.
    Duplicate(ctr0, philoxCounter[0]);
    Duplicate(ctr1, philoxCounter[1]);
    Duplicate(ctr2, philoxCounter[2]);
    Duplicate(ctr3, philoxCounter[3]);

    // Give each lane its own counter value.
    Reg::RegTensor<uint32_t>& laneOffset = (Reg::RegTensor<uint32_t>&)incIdx;
    AddWith128Bits(ctr0, ctr1, ctr2, ctr3, laneOffset, pg);
}
// SIMD entry point for a contiguous output of mainIter full batches followed by an
// optional partial batch.
//   dstUb     : start of the output; receives mainIter * PHILOX_ONCE_REPEAT_NUM elements.
//   dstUbTail : where the partial batch of tailCount elements (if any) is written.
//   philox    : key and starting counter words.
// Lane k starts from counter + k, and the counter advances by ELE_CNT_B32_ONCE per batch.
// DstUnalign is forwarded to both the main-block and the tail path so that the two
// always use the same store strategy (the tail previously fell back to the aligned
// path regardless of DstUnalign).
template <uint16_t Rounds = 7, typename T, bool DstUnalign = false>
__simd_vf__ inline void PhiloxRandomOneRow(
    __ubuf__ uint32_t* dstUb, __ubuf__ uint32_t* dstUbTail, const philoxStruct philox, uint16_t mainIter,
    uint16_t tailCount)
{
    Reg::MaskReg pg = Reg::CreateMask<uint32_t>();

    // Per-lane counters: base counter plus the lane index.
    Reg::RegTensor<uint32_t> ctr3, ctr2, ctr1, ctr0;
    Reg::RegTensor<int32_t> incIdx;
    Reg::Arange(incIdx, 0);
    PhiloxCounterInit(philox.philoxCounter, ctr0, ctr1, ctr2, ctr3, incIdx, pg);

    // Counter step between two consecutive batches.
    Reg::RegTensor<uint32_t> vEleStrideB32OneRow;
    Duplicate(vEleStrideB32OneRow, PhiloxInternal::ELE_CNT_B32_ONCE);

    // Broadcast keys and round multipliers.
    Reg::RegTensor<uint32_t> key1, key0;
    Duplicate(key0, philox.philoxKey[0]);
    Duplicate(key1, philox.philoxKey[1]);
    Reg::RegTensor<uint32_t> cMul0, cMul1;
    Duplicate(cMul0, PhiloxInternal::CONST_MUL_0);
    Duplicate(cMul1, PhiloxInternal::CONST_MUL_1);

    PhiloxRoundMainBlockUnroll<Rounds, T, DstUnalign>(
        dstUb, mainIter, ctr0, ctr1, ctr2, ctr3, key0, key1, cMul0, cMul1, vEleStrideB32OneRow, pg);

    // The main block leaves the counters positioned at the first unprocessed batch.
    if (tailCount > 0) {
        SpNetworkFull<Rounds, T, DstUnalign>(
            dstUbTail, tailCount, ctr0, ctr1, ctr2, ctr3, key0, key1, cMul0, cMul1, pg);
    }
}
// derive index calculation
for (i : row)
for (j : column) // column < half one repeat
index[i * column + j] = i * stride + j
==>
factor = ONE_REPEAT_LEN / column
i.i_extent = factor
i.o_extent = row / factor
i = i.o * factor + i.i, valid only while i < row // the split introduces an if-guard because row need not be a multiple of factor
for (i.o, 0, row / factor) {
for (i.i, 0, factor) {
if (i.o * factor + i.i < row) {
for (j, 0, column) {
index[(i.o * factor + i.i) * column + j] = i.o * factor * stride + i.i * stride + j
}
}
}
}
==>
j_fuse = i.i * column + j
j_fuse_extent = factor * column < ONE_REPEAT_LEN
i.i = j_fuse / column
j = j_fuse % column
for (i.o, 0, row / factor) {
for (j_fuse, 0, factor * column) {
if (i.o * factor * column + j_fuse < row * column) {
index[i.o * factor * column + j_fuse] = i.o * factor * stride + j_fuse / column * stride + j_fuse % column
}
}
}
*/
// Builds the per-lane counter-offset table used when several short rows are fused
// into one batch, and writes it to indexUb.
// For each fused row i in [0, fuseFactor) the vector (i * stride/4 + lane) is formed and
// its first column/4 entries are appended to indexUb with unaligned stores, where
// stride/4 and column/4 are params.stride and params.column expressed in counters
// (PHILOX_ONCE_COUNTER_NUM words each).
// Ends with a VEC_STORE -> VEC_LOAD barrier so the caller can load the table right away.
__simd_callee__ inline void PhiloxRandomIndexCal(
    __ubuf__ int32_t* indexUb, const PhiloxRandomParams& params, const uint32_t fuseFactor)
{
    __ubuf__ int32_t* indexUbT = indexUb;
    Reg::MaskReg pg = Reg::CreateMask<uint32_t>();
    Reg::RegTensor<int32_t> index, incIdx;
    Reg::UnalignReg ureg;
    uint32_t stride = static_cast<uint32_t>(params.stride / PhiloxInternal::PHILOX_ONCE_COUNTER_NUM);
    uint32_t elementNum = params.column / PhiloxInternal::PHILOX_ONCE_COUNTER_NUM;

    // The lane ramp does not depend on the row, so it is generated once instead of
    // once per loop iteration.
    Reg::Arange(incIdx, 0);

    for (uint16_t i = 0; i < fuseFactor; i++) {
        // index = i * stride + lane
        Reg::Duplicate(index, i);
        Reg::Muls(index, index, stride, pg);
        Reg::Add(index, index, incIdx, pg);
        Reg::StoreUnAlign<int32_t, Reg::PostLiteral::POST_MODE_UPDATE>(indexUbT, index, ureg, elementNum);
    }
    // Finalise the unaligned store sequence.
    Reg::StoreUnAlignPost<int32_t, Reg::PostLiteral::POST_MODE_UPDATE>(indexUbT, ureg, 0);
    Reg::LocalMemBar<Reg::MemType::VEC_STORE, Reg::MemType::VEC_LOAD>();
}
// SIMD entry point for strided multi-row output where several rows are fused into
// one batch.
//   indexUb      : scratch buffer that receives the per-lane counter offsets
//                  (filled by PhiloxRandomIndexCal, then loaded as the lane offsets).
//   fuseFactor   : rows fused into one batch.
//   mainFuseAxis : output elements written per fused batch.
//   mainRowsNum  : number of complete fused batches.
//   tailFuseAxis : output elements of the trailing partial batch (0 if none).
// DstRepeatUnalign == false uses the unrolled main-block path; otherwise every fused
// batch is generated with SpNetworkFull and the counter is advanced by hand.
// DstBlockUnalign is forwarded as the DstUnalign flag of the callee.
template <uint16_t Rounds = 7, typename T, bool DstBlockUnalign, bool DstRepeatUnalign>
__simd_vf__ inline void PhiloxRandomMultiRowWithFuse(
    __ubuf__ uint32_t* dstUbStart, __ubuf__ int32_t* indexUb, const philoxStruct philox,
    const PhiloxRandomParams params, const uint32_t fuseFactor, const uint32_t mainFuseAxis, const uint32_t mainRowsNum,
    const uint32_t tailFuseAxis)
{
    // Produce the lane-offset table, then read it back as the initial per-lane offsets.
    PhiloxRandomIndexCal(indexUb, params, fuseFactor);

    Reg::RegTensor<uint32_t> ctr3, ctr2, ctr1, ctr0;
    Reg::MaskReg pg = Reg::CreateMask<uint32_t>();
    Reg::RegTensor<int32_t> laneOffsets;
    Reg::LoadAlign(laneOffsets, indexUb);
    PhiloxCounterInit(philox.philoxCounter, ctr0, ctr1, ctr2, ctr3, laneOffsets, pg);

    // Broadcast keys and round multipliers.
    Reg::RegTensor<uint32_t> key1, key0;
    Duplicate(key0, philox.philoxKey[0]);
    Duplicate(key1, philox.philoxKey[1]);
    Reg::RegTensor<uint32_t> cMul0, cMul1;
    Duplicate(cMul0, PhiloxInternal::CONST_MUL_0);
    Duplicate(cMul1, PhiloxInternal::CONST_MUL_1);

    // Counter distance between two fused batches: fuseFactor rows of stride/4 counters.
    Reg::RegTensor<uint32_t> batchStep;
    Duplicate(batchStep, fuseFactor * params.stride / PhiloxInternal::PHILOX_ONCE_COUNTER_NUM);

    if constexpr (DstRepeatUnalign) {
        for (uint16_t batch = 0; batch < mainRowsNum; batch++) {
            __ubuf__ uint32_t* batchDst = dstUbStart + batch * mainFuseAxis;
            SpNetworkFull<Rounds, T, DstBlockUnalign>(
                batchDst, mainFuseAxis, ctr0, ctr1, ctr2, ctr3, key0, key1, cMul0, cMul1, pg);
            AddWith128Bits(ctr0, ctr1, ctr2, ctr3, batchStep, pg);
        }
    } else {
        PhiloxRoundMainBlockUnroll<Rounds, T, DstBlockUnalign>(
            dstUbStart, mainRowsNum, ctr0, ctr1, ctr2, ctr3, key0, key1, cMul0, cMul1, batchStep, pg);
    }

    // Rows left over after the complete fused batches.
    if (tailFuseAxis != 0) {
        __ubuf__ uint32_t* tailDst = dstUbStart + mainRowsNum * mainFuseAxis;
        SpNetworkFull<Rounds, T, DstBlockUnalign>(
            tailDst, tailFuseAxis, ctr0, ctr1, ctr2, ctr3, key0, key1, cMul0, cMul1, pg);
    }
}
// SIMD entry point for strided multi-row output where every row is generated on its own.
// For row i the per-lane counters restart at base counter + i * strideCounterOneRow + lane;
// the row is then produced as mainIter full batches plus, when hasTail is 1, one partial
// batch of tailCount elements. Rows are written back to back, params.column elements apart.
template <uint16_t Rounds = 7, typename T, bool DstUnalign>
__simd_vf__ inline void PhiloxRandomMultiRowNoFuse(
    __ubuf__ uint32_t* dstUbStart, const philoxStruct philox, const PhiloxRandomParams params,
    uint32_t strideCounterOneRow, uint16_t mainIter, uint16_t tailCount, uint16_t hasTail)
{
    // Counter step between two consecutive batches of the same row.
    Reg::RegTensor<uint32_t> batchStep;
    Duplicate(batchStep, PhiloxInternal::ELE_CNT_B32_ONCE);

    // Broadcast keys and round multipliers once for all rows.
    Reg::RegTensor<uint32_t> key1, key0;
    Duplicate(key0, philox.philoxKey[0]);
    Duplicate(key1, philox.philoxKey[1]);
    Reg::RegTensor<uint32_t> cMul0, cMul1;
    Duplicate(cMul0, PhiloxInternal::CONST_MUL_0);
    Duplicate(cMul1, PhiloxInternal::CONST_MUL_1);

    Reg::RegTensor<uint32_t> ctr3, ctr2, ctr1, ctr0;
    Reg::MaskReg pg = Reg::CreateMask<uint32_t>();

    for (uint16_t rowIdx = 0; rowIdx < params.row; rowIdx++) {
        __ubuf__ uint32_t* rowDst = dstUbStart + rowIdx * params.column;
        __ubuf__ uint32_t* rowTailDst = rowDst + mainIter * PhiloxInternal::PHILOX_ONCE_REPEAT_NUM;

        // Lane offsets for this row start at rowIdx * strideCounterOneRow.
        Reg::RegTensor<int32_t> laneOffsets;
        Reg::Arange(laneOffsets, rowIdx * strideCounterOneRow);
        PhiloxCounterInit(philox.philoxCounter, ctr0, ctr1, ctr2, ctr3, laneOffsets, pg);

        PhiloxRoundMainBlockUnroll<Rounds, T, DstUnalign>(
            rowDst, mainIter, ctr0, ctr1, ctr2, ctr3, key0, key1, cMul0, cMul1, batchStep, pg);

        // The tail is expressed as a loop of hasTail (0 or 1) iterations.
        for (uint16_t tailPass = 0; tailPass < hasTail; tailPass++) {
            SpNetworkFull<Rounds, T, DstUnalign>(
                rowTailDst, tailCount, ctr0, ctr1, ctr2, ctr3, key0, key1, cMul0, cMul1, pg);
        }
    }
}
// Contiguous variant: fills the first `count` elements of dstLocal.
// The request is split into count / PHILOX_ONCE_REPEAT_NUM full batches plus a remainder,
// key and counter words are copied into a philoxStruct (left zero when the argument
// compares equal to nullptr), and the work is handed to PhiloxRandomOneRow with
// DstUnalign = false.
template <uint16_t Rounds = 7, typename T>
__aicore__ inline void PhiloxRandomImpl(
    const LocalTensor<T>& dstLocal, const PhiloxKey& philoxKey, const PhiloxCounter& philoxCounter, uint16_t count)
{
    static_assert(
        SupportType<T, int32_t, uint32_t, float>(), "PhiloxRandom API only support int32_t/uint32_t/float type");
    static_assert(Rounds == 7 || Rounds == 10, "PhiloxRandom API only support 7 or 10 Rounds ");

    // Pack key and counter into the value type taken by the SIMD entry point.
    philoxStruct state;
    if (!(philoxKey == nullptr)) {
        for (uint16_t k = 0; k < PHILOX_KEY_SIZE; k++) {
            state.philoxKey[k] = philoxKey[k];
        }
    }
    if (!(philoxCounter == nullptr)) {
        for (uint16_t c = 0; c < PHILOX_COUNTER_SIZE; c++) {
            state.philoxCounter[c] = philoxCounter[c];
        }
    }

    // Split into full batches and the leftover elements.
    const uint16_t fullBatches = count / PhiloxInternal::PHILOX_ONCE_REPEAT_NUM;
    const uint16_t leftover = count - fullBatches * PhiloxInternal::PHILOX_ONCE_REPEAT_NUM;

    __ubuf__ uint32_t* dstBegin = (__ubuf__ uint32_t*)dstLocal.GetPhyAddr();
    __ubuf__ uint32_t* dstLeftover = dstBegin + fullBatches * PhiloxInternal::PHILOX_ONCE_REPEAT_NUM;

    PhiloxRandomOneRow<Rounds, T, false>(dstBegin, dstLeftover, state, fullBatches, leftover);
}
// Strided (row/column) variant: fills params.row rows of params.column elements, where
// consecutive rows are params.stride elements apart in the random stream.
// Dispatch:
//   1. row == 1 or stride == column -> the stream is contiguous; delegate to the
//      count-based overload.
//   2. column <= PHILOX_ONCE_REPEAT_NUM -> several rows are fused into one batch
//      (PhiloxRandomMultiRowWithFuse); a stack buffer holds the lane-offset table.
//   3. otherwise -> one row at a time (PhiloxRandomMultiRowNoFuse).
// Preconditions (checked with ASCENDC_ASSERT): stride and column are multiples of 4,
// stride >= column, row > 0 and column > 0.
template <uint16_t Rounds = 7, typename T>
__aicore__ inline void PhiloxRandomImpl(
    const LocalTensor<T>& dstLocal, const PhiloxKey& philoxKey, const PhiloxCounter& philoxCounter,
    const PhiloxRandomParams& params)
{
    static_assert(
        SupportType<T, int32_t, uint32_t, float>(), "PhiloxRandom API only support int32_t/uint32_t/float type");
    static_assert(Rounds == 7 || Rounds == 10, "PhiloxRandom API only support 7 or 10 Rounds ");
    // The messages avoid a bare '%', which a printf-style formatter would read as the
    // start of a conversion specification.
    ASCENDC_ASSERT((params.stride % PhiloxInternal::PHILOX_ONCE_COUNTER_NUM == 0), {
        KERNEL_LOG(KERNEL_ERROR, "params.stride must be a multiple of 4!");
    });
    ASCENDC_ASSERT((params.column % PhiloxInternal::PHILOX_ONCE_COUNTER_NUM == 0), {
        KERNEL_LOG(KERNEL_ERROR, "params.column must be a multiple of 4!");
    });
    ASCENDC_ASSERT((params.stride >= params.column), { KERNEL_LOG(KERNEL_ERROR, "params.stride >= params.column!"); });
    ASCENDC_ASSERT(
        (params.row > 0 && params.column > 0), { KERNEL_LOG(KERNEL_ERROR, "params.row > 0 && params.column > 0!"); });

    if (params.row == 1 || params.stride == params.column) {
        // The count-based overload takes a uint16_t; make the narrowing explicit and
        // report requests that would silently wrap.
        ASCENDC_ASSERT((static_cast<uint64_t>(params.row) * params.column <= UINT16_MAX), {
            KERNEL_LOG(KERNEL_ERROR, "params.row * params.column must not exceed 65535!");
        });
        PhiloxRandomImpl<Rounds, T>(
            dstLocal, philoxKey, philoxCounter, static_cast<uint16_t>(params.row * params.column));
        return;
    }

    // Only the strided paths need the packed key/counter state and the raw destination.
    __ubuf__ uint32_t* dstUbStart = (__ubuf__ uint32_t*)dstLocal.GetPhyAddr();
    philoxStruct philox;
    uint16_t philoxKeySize = (philoxKey == nullptr) ? 0 : PHILOX_KEY_SIZE;
    uint16_t philoxCounterSize = (philoxCounter == nullptr) ? 0 : PHILOX_COUNTER_SIZE;
    for (uint16_t i = 0; i < philoxKeySize; i++) {
        philox.philoxKey[i] = philoxKey[i];
    }
    for (uint16_t i = 0; i < philoxCounterSize; i++) {
        philox.philoxCounter[i] = philoxCounter[i];
    }

    if (params.column <= PhiloxInternal::PHILOX_ONCE_REPEAT_NUM) {
        // Fuse as many whole rows as fit into one batch.
        uint32_t fuseFactor = PhiloxInternal::PHILOX_ONCE_REPEAT_NUM / params.column;
        uint32_t mainFuseAxis = params.column * fuseFactor;
        uint32_t mainRowsNum = params.row / fuseFactor;
        uint32_t tailFuseAxis = params.row * params.column - mainFuseAxis * mainRowsNum;
        LocalTensor<T> indexTensor;
        PopStackBuffer<T, TPosition::LCM>(indexTensor);
        __ubuf__ int32_t* indexUb = (__ubuf__ int32_t*)indexTensor.GetPhyAddr();
        if (mainFuseAxis == PhiloxInternal::PHILOX_ONCE_REPEAT_NUM) {
            // A fused batch fills a whole repeat: use the unrolled main-block path.
            PhiloxRandomMultiRowWithFuse<Rounds, T, false, false>(
                dstUbStart, indexUb, philox, params, fuseFactor, mainFuseAxis, mainRowsNum, tailFuseAxis);
        } else if (mainFuseAxis * sizeof(uint32_t) % ONE_BLOCK_SIZE == 0) {
            // Partial repeat, but each fused batch is a whole number of blocks.
            PhiloxRandomMultiRowWithFuse<Rounds, T, false, true>(
                dstUbStart, indexUb, philox, params, fuseFactor, mainFuseAxis, mainRowsNum, tailFuseAxis);
        } else {
            // Fused batch size is not a multiple of the block size.
            PhiloxRandomMultiRowWithFuse<Rounds, T, true, true>(
                dstUbStart, indexUb, philox, params, fuseFactor, mainFuseAxis, mainRowsNum, tailFuseAxis);
        }
    } else {
        // Rows longer than one batch are generated one by one.
        uint32_t strideCounterOneRow = params.stride / PhiloxInternal::PHILOX_ONCE_COUNTER_NUM;
        uint16_t mainIter = params.column / PhiloxInternal::PHILOX_ONCE_REPEAT_NUM;
        uint16_t tailCount = params.column - mainIter * PhiloxInternal::PHILOX_ONCE_REPEAT_NUM;
        uint16_t hasTail = static_cast<uint16_t>(tailCount > 0);
        if (params.column / PhiloxInternal::PHILOX_ONCE_COUNTER_NUM * sizeof(uint32_t) % ONE_BLOCK_SIZE == 0) {
            PhiloxRandomMultiRowNoFuse<Rounds, T, false>(
                dstUbStart, philox, params, strideCounterOneRow, mainIter, tailCount, hasTail);
        } else {
            PhiloxRandomMultiRowNoFuse<Rounds, T, true>(
                dstUbStart, philox, params, strideCounterOneRow, mainIter, tailCount, hasTail);
        }
    }
}
}
#endif
#if defined(__UNDEF_ASCENDC_INCLUDE_INTERNAL_HEADERS_MATH_PHILOX_PHILOX_C310_IMPL_H__)
#undef __ASCENDC_INCLUDE_INTERNAL_HEADERS__
#undef __UNDEF_ASCENDC_INCLUDE_INTERNAL_HEADERS_MATH_PHILOX_PHILOX_C310_IMPL_H__
#endif