* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef RANDOM_KERNEL_BASE_H
#define RANDOM_KERNEL_BASE_H
#include "kernel_operator.h"
#include "op_kernel/platform_util.h"
#include "op_kernel/math_util.h"
#include "random_unified_tiling_data_arch35.h"
#include "simt_api/asc_simt.h"
namespace RandomKernelBase {
using namespace AscendC;
static constexpr MicroAPI::CastTrait castTraitTf = {
MicroAPI::RegLayout::ZERO, MicroAPI::SatMode::NO_SAT, MicroAPI::MaskMergeMode::ZEROING};
static constexpr uint16_t ALG_KEY_SIZE = 2;
static constexpr uint16_t ALG_COUNTER_SIZE = 4;
static constexpr uint32_t INT32_ONE_REPEAT = Ops::Base::GetVRegSize() / sizeof(int32_t);
static constexpr uint32_t RIGHT_SHIFT = 32;
static constexpr uint16_t BLOCK_SIZE = 32;
static constexpr uint16_t DOUBLE_UNIFORM_RESULT = 2;
static constexpr uint32_t INT32_FLOAT32_ONE_REPEAT = Ops::Base::GetVRegSize() / sizeof(int32_t);
static constexpr float DOUBLE_MULTIPLE = 2.0f;
static constexpr float PI = 3.14159265358979323846f;
static constexpr int IDX_2 = 2;
static constexpr int IDX_3 = 3;
static constexpr int32_t CONTINUOUS_USE = 0;
static constexpr int32_t DIS_CONTINUOUS_USE = 1;
static constexpr uint32_t PHILOX_W32_A = 0x9E3779B9;
static constexpr uint32_t PHILOX_W32_B = 0xBB67AE85;
static constexpr uint32_t PHILOX_M4X32_A = 0xD2511F53;
static constexpr uint32_t PHILOX_M4X32_B = 0xCD9E8D57;
static constexpr float RAND_2POW32_INV = 2.3283064e-10f;
static constexpr float RAND_2POW32_INV_HALF = RAND_2POW32_INV / 2.0f;
constexpr uint64_t VEC_2 = 2;
constexpr uint64_t VEC_4 = 4;
constexpr uint64_t VEC_8 = 8;
constexpr uint64_t VEC_16 = 16;
template <typename T>
__aicore__ inline void CopyOut(
LocalTensor<T> yLocal, GlobalTensor<T> yGm, uint32_t burstNum, uint32_t busrtLength, int64_t gmOffset)
{
DataCopyExtParams copyParams;
copyParams.blockCount = burstNum;
copyParams.blockLen = busrtLength;
DataCopyPad(yGm[gmOffset], yLocal, copyParams);
}
template <typename T>
__aicore__ inline void Uint32ToFloat(LocalTensor<T>& yOutput, LocalTensor<uint32_t>& philoxRes, const uint32_t calCount)
{
__ubuf__ int32_t* ubPhilox = (__ubuf__ int32_t*)philoxRes.GetPhyAddr();
__ubuf__ float* ubOut = (__ubuf__ float*)yOutput.GetPhyAddr();
uint32_t repeatTimes = Ops::Base::CeilDiv(calCount, static_cast<uint32_t>(INT32_ONE_REPEAT));
__VEC_SCOPE__
{
MicroAPI::RegTensor<int32_t> vReg0;
MicroAPI::RegTensor<int32_t> vReg1;
MicroAPI::RegTensor<int32_t> vReg2;
MicroAPI::RegTensor<int32_t> vReg3;
MicroAPI::RegTensor<int32_t> vReg4;
MicroAPI::RegTensor<float> vReg5;
MicroAPI::RegTensor<float> vReg6;
MicroAPI::MaskReg mask;
uint32_t sReg1 = static_cast<uint32_t>(calCount);
uint32_t sReg2 = static_cast<uint32_t>(0x7fffffu);
uint32_t exp = static_cast<uint32_t>(127);
uint32_t sReg3 = exp << 23;
MicroAPI::MaskReg maskAll = MicroAPI::CreateMask<int32_t, MicroAPI::MaskPattern::ALL>();
MicroAPI::Duplicate<int32_t, MicroAPI::MaskMergeMode::ZEROING>(vReg1, sReg2, maskAll);
MicroAPI::Duplicate<int32_t, MicroAPI::MaskMergeMode::ZEROING>(vReg3, sReg3, maskAll);
float sReg4 = static_cast<float>(-1.0);
int32_t offSet = static_cast<int32_t>(INT32_ONE_REPEAT);
for (uint16_t i = 0; i < static_cast<uint16_t>(repeatTimes); ++i) {
mask = MicroAPI::UpdateMask<float>(sReg1);
MicroAPI::DataCopy<int32_t, MicroAPI::PostLiteral::POST_MODE_UPDATE, MicroAPI::LoadDist::DIST_NORM>(
vReg0, ubPhilox, offSet);
MicroAPI::And<int32_t, MicroAPI::MaskMergeMode::ZEROING>(vReg2, vReg0, vReg1, mask);
MicroAPI::Or<int32_t, MicroAPI::MaskMergeMode::ZEROING>(vReg4, vReg2, vReg3, mask);
vReg5 = (MicroAPI::RegTensor<float>&)vReg4;
MicroAPI::Adds<float, float, MicroAPI::MaskMergeMode::ZEROING>(vReg6, vReg5, sReg4, mask);
MicroAPI::DataCopy<float, MicroAPI::PostLiteral::POST_MODE_UPDATE, MicroAPI::StoreDist::DIST_NORM_B32>(
ubOut, vReg6, offSet, mask);
}
}
}
template <typename T>
__aicore__ inline void Uint16ToHalf(LocalTensor<T>& yOutput, LocalTensor<uint32_t>& philoxRes, const uint32_t calCount)
{
__ubuf__ int32_t* ubPhilox = (__ubuf__ int32_t*)philoxRes.GetPhyAddr();
__ubuf__ half* ubOut = (__ubuf__ half*)yOutput.GetPhyAddr();
uint32_t repeatTimes = Ops::Base::CeilDiv(calCount, static_cast<uint32_t>(INT32_ONE_REPEAT));
SetCtrlSpr<60, 60>(0);
__VEC_SCOPE__
{
MicroAPI::RegTensor<int32_t> vRegT;
MicroAPI::RegTensor<int16_t> vReg0;
MicroAPI::RegTensor<int16_t> vReg1;
MicroAPI::RegTensor<int16_t> vReg2;
MicroAPI::RegTensor<int16_t> vReg3;
MicroAPI::RegTensor<int16_t> vReg4;
MicroAPI::RegTensor<half> vReg5;
MicroAPI::RegTensor<half> vReg6;
MicroAPI::MaskReg mask;
uint32_t sReg1 = static_cast<uint32_t>(calCount);
uint16_t sReg2 = static_cast<uint16_t>(0x3ffu);
uint16_t exp = static_cast<uint16_t>(15);
uint16_t sReg3 = exp << 10;
MicroAPI::MaskReg maskAll = MicroAPI::CreateMask<int32_t>();
MicroAPI::Duplicate<int16_t, MicroAPI::MaskMergeMode::ZEROING>(vReg1, sReg2, maskAll);
MicroAPI::Duplicate<int16_t, MicroAPI::MaskMergeMode::ZEROING>(vReg3, sReg3, maskAll);
half sReg4 = static_cast<half>(-1.0);
int32_t offset = static_cast<int32_t>(INT32_ONE_REPEAT);
for (uint16_t i = 0; i < static_cast<uint16_t>(repeatTimes); ++i) {
mask = MicroAPI::UpdateMask<float>(sReg1);
MicroAPI::DataCopy<int32_t, MicroAPI::PostLiteral::POST_MODE_UPDATE, MicroAPI::LoadDist::DIST_NORM>(
vRegT, ubPhilox, offset);
MicroAPI::Cast<int16_t, int32_t, castTraitTf>(vReg0, vRegT, mask);
MicroAPI::And<int16_t, MicroAPI::MaskMergeMode::ZEROING>(vReg2, vReg0, vReg1, mask);
MicroAPI::Or<int16_t, MicroAPI::MaskMergeMode::ZEROING>(vReg4, vReg2, vReg3, mask);
vReg5 = (MicroAPI::RegTensor<half>&)vReg4;
MicroAPI::Adds<half, half, MicroAPI::MaskMergeMode::ZEROING>(vReg6, vReg5, sReg4, mask);
MicroAPI::DataCopy<half, MicroAPI::PostLiteral::POST_MODE_UPDATE, MicroAPI::StoreDist::DIST_PACK_B32>(
ubOut, vReg6, offset, mask);
}
}
SetCtrlSpr<60, 60>(1);
}
template <typename T>
__aicore__ inline void Uint16ToBfloat16(
LocalTensor<T>& yOutput, LocalTensor<uint32_t>& philoxRes, const uint32_t calCount)
{
__ubuf__ int32_t* ubPhilox = (__ubuf__ int32_t*)philoxRes.GetPhyAddr();
__ubuf__ bfloat16_t* ubOut = (__ubuf__ bfloat16_t*)yOutput.GetPhyAddr();
uint32_t repeatTimes = Ops::Base::CeilDiv(calCount, static_cast<uint32_t>(INT32_ONE_REPEAT));
SetCtrlSpr<60, 60>(0);
__VEC_SCOPE__
{
MicroAPI::RegTensor<int32_t> vRegT;
MicroAPI::RegTensor<int16_t> vReg0;
MicroAPI::RegTensor<int16_t> vReg1;
MicroAPI::RegTensor<int16_t> vReg2;
MicroAPI::RegTensor<int16_t> vReg3;
MicroAPI::RegTensor<int16_t> vReg4;
MicroAPI::RegTensor<bfloat16_t> vReg5;
MicroAPI::RegTensor<bfloat16_t> vReg6;
uint32_t sReg1 = static_cast<uint32_t>(calCount);
uint16_t sReg2 = static_cast<uint16_t>(0x7fu);
uint16_t exp = static_cast<uint16_t>(127);
uint16_t sReg3 = exp << 7;
MicroAPI::MaskReg maskAll = MicroAPI::CreateMask<half, MicroAPI::MaskPattern::ALL>();
MicroAPI::Duplicate(vReg1, sReg2, maskAll);
MicroAPI::Duplicate(vReg3, sReg3, maskAll);
bfloat16_t sReg4 = static_cast<bfloat16_t>(-1.0);
MicroAPI::MaskReg mask;
int32_t offSet = static_cast<int32_t>(INT32_ONE_REPEAT);
for (uint16_t i = 0; i < static_cast<uint16_t>(repeatTimes); ++i) {
mask = MicroAPI::UpdateMask<float>(sReg1);
MicroAPI::DataCopy<int32_t, MicroAPI::PostLiteral::POST_MODE_UPDATE, MicroAPI::LoadDist::DIST_NORM>(
vRegT, ubPhilox, offSet);
MicroAPI::Cast<int16_t, int32_t, castTraitTf>(vReg0, vRegT, mask);
MicroAPI::And<int16_t, MicroAPI::MaskMergeMode::ZEROING>(vReg2, vReg0, vReg1, mask);
MicroAPI::Or<int16_t, MicroAPI::MaskMergeMode::ZEROING>(vReg4, vReg2, vReg3, mask);
vReg5 = (MicroAPI::RegTensor<bfloat16_t>&)vReg4;
MicroAPI::Adds<bfloat16_t, bfloat16_t, MicroAPI::MaskMergeMode::ZEROING>(vReg6, vReg5, sReg4, mask);
MicroAPI::DataCopy<bfloat16_t, MicroAPI::PostLiteral::POST_MODE_UPDATE, MicroAPI::StoreDist::DIST_PACK_B32>(
ubOut, vReg6, offSet, mask);
}
}
SetCtrlSpr<60, 60>(1);
}
template <typename T>
__aicore__ inline void U32Conversion(LocalTensor<T>& yOutput, LocalTensor<uint32_t>& philoxRes, const uint32_t calCount)
{
if constexpr (AscendC::IsSameType<T, half>::value) {
Uint16ToHalf(yOutput, philoxRes, calCount);
} else if constexpr (AscendC::IsSameType<T, float>::value) {
Uint32ToFloat(yOutput, philoxRes, calCount);
} else if constexpr (AscendC::IsSameType<T, bfloat16_t>::value) {
Uint16ToBfloat16(yOutput, philoxRes, calCount);
}
}
template <typename T>
__aicore__ inline void Float32Conversion(
LocalTensor<T> yOutput, LocalTensor<float> normalFloatResult, const uint32_t calCount)
{
if constexpr (AscendC::IsSameType<T, float>::value) {
DataCopy(yOutput, normalFloatResult, Ops::Base::CeilAlign(calCount, static_cast<uint32_t>(BLOCK_SIZE)));
} else if constexpr (AscendC::IsSameType<T, half>::value) {
Cast(yOutput, normalFloatResult, RoundMode::CAST_NONE, calCount);
} else {
Cast(yOutput, normalFloatResult, RoundMode::CAST_RINT, calCount);
}
}
* Formula: Box-Muller
* X = sqrt(-2 * ln(U1)) * cos(2 * PI * U2)
* Y = sqrt(-2 * ln(U1)) * sin(2 * PI * U2)
*/
template <typename T>
__aicore__ inline void BoxMullerFloatSIMD(
LocalTensor<float>& yOutputTmp, LocalTensor<float> v1Result, LocalTensor<float> u2Result, const uint32_t calCount)
{
__ubuf__ float* uniformRes = (__ubuf__ float*)yOutputTmp.GetPhyAddr();
__ubuf__ float* ubV1Out = (__ubuf__ float*)v1Result.GetPhyAddr();
__ubuf__ float* ubU2Out = (__ubuf__ float*)u2Result.GetPhyAddr();
uint32_t repeatTimes =
Ops::Base::CeilDiv(calCount / DOUBLE_UNIFORM_RESULT, static_cast<uint32_t>(INT32_FLOAT32_ONE_REPEAT));
uint32_t sreg1 = calCount;
__VEC_SCOPE__
{
MicroAPI::RegTensor<float> vreg0;
MicroAPI::RegTensor<float> vreg1;
MicroAPI::RegTensor<float> vreg2;
MicroAPI::RegTensor<float> vreg3;
MicroAPI::RegTensor<float> vreg4;
MicroAPI::RegTensor<float> vreg5;
MicroAPI::RegTensor<float> vreg6;
MicroAPI::MaskReg mask;
float epsScalar = 1.0e-7f;
float doublePiScalar = DOUBLE_MULTIPLE * PI;
int32_t offset = static_cast<int32_t>(INT32_FLOAT32_ONE_REPEAT);
for (uint16_t i = 0; i < static_cast<uint16_t>(repeatTimes); ++i) {
mask = MicroAPI::UpdateMask<float>(sreg1);
MicroAPI::DataCopy<float, MicroAPI::LoadDist::DIST_DINTLV_B32>(
vreg0, vreg1, uniformRes + i * offset * DOUBLE_UNIFORM_RESULT);
MicroAPI::Maxs(vreg2, vreg0, epsScalar, mask);
MicroAPI::Ln(vreg3, vreg2, mask);
MicroAPI::Muls(vreg4, vreg1, doublePiScalar, mask);
MicroAPI::Muls(vreg5, vreg3, -DOUBLE_MULTIPLE, mask);
MicroAPI::Sqrt(vreg6, vreg5, mask);
MicroAPI::DataCopy<float, MicroAPI::PostLiteral::POST_MODE_UPDATE>(ubV1Out, vreg4, offset, mask);
MicroAPI::DataCopy<float, MicroAPI::PostLiteral::POST_MODE_UPDATE>(ubU2Out, vreg6, offset, mask);
}
}
}
template <typename T>
__aicore__ inline void BoxMullerMulSIMD(
LocalTensor<float>& yOutputTmp, LocalTensor<float> v1Result, LocalTensor<float> u2Result,
LocalTensor<float> yOutput, const uint32_t calCount)
{
Cos<float, false>(yOutputTmp, v1Result);
Sin<float, false>(yOutput, v1Result);
uint32_t repeatTimes =
Ops::Base::CeilDiv(calCount / DOUBLE_UNIFORM_RESULT, static_cast<uint32_t>(INT32_FLOAT32_ONE_REPEAT));
__ubuf__ float* ubSinResult = (__ubuf__ float*)yOutput.GetPhyAddr();
__ubuf__ float* ubCosResult = (__ubuf__ float*)yOutputTmp.GetPhyAddr();
__ubuf__ float* ubU2Result = (__ubuf__ float*)u2Result.GetPhyAddr();
__ubuf__ float* ubOut = (__ubuf__ float*)v1Result.GetPhyAddr();
__VEC_SCOPE__
{
MicroAPI::RegTensor<float> vreg0;
MicroAPI::RegTensor<float> vreg1;
MicroAPI::RegTensor<float> vreg2;
MicroAPI::RegTensor<float> vreg3;
MicroAPI::RegTensor<float> vreg4;
MicroAPI::RegTensor<float> vreg5;
MicroAPI::MaskReg mask;
uint32_t sreg1 = static_cast<uint32_t>(calCount);
int32_t offset = static_cast<int32_t>(INT32_FLOAT32_ONE_REPEAT);
for (uint16_t i = 0; i < static_cast<uint16_t>(repeatTimes); ++i) {
mask = MicroAPI::UpdateMask<float>(sreg1);
MicroAPI::DataCopy<float, MicroAPI::PostLiteral::POST_MODE_UPDATE>(vreg0, ubSinResult, offset);
MicroAPI::DataCopy<float, MicroAPI::PostLiteral::POST_MODE_UPDATE>(vreg1, ubCosResult, offset);
MicroAPI::DataCopy<float, MicroAPI::PostLiteral::POST_MODE_UPDATE>(vreg2, ubU2Result, offset);
MicroAPI::Mul(vreg3, vreg0, vreg2, mask);
MicroAPI::Mul(vreg4, vreg1, vreg2, mask);
MicroAPI::DataCopy<int32_t, MicroAPI::StoreDist::DIST_INTLV_B32>(
reinterpret_cast<__ubuf__ int32_t*>(ubOut + i * offset * DOUBLE_UNIFORM_RESULT),
(MicroAPI::RegTensor<int32_t>&)(vreg3), (MicroAPI::RegTensor<int32_t>&)(vreg4), mask);
}
}
}
__simt_callee__ __aicore__ inline void BoxMullerFloatSafe(float u1, const float u2, float* z0, float* z1)
{
const float eps = 1.0e-7f;
if (u1 < eps) {
u1 = eps;
}
float v = static_cast<float>(DOUBLE_MULTIPLE * PI * u2);
float r = sqrtf(-DOUBLE_MULTIPLE * logf(u1));
sincosf(v, z0, z1);
*z0 *= r;
*z1 *= r;
}
__simt_callee__ __aicore__ inline void BoxMullerFloat(float u1, const float u2, float* z0, float* z1)
{
float v = static_cast<float>(DOUBLE_MULTIPLE * PI * u2);
float r = sqrtf(-DOUBLE_MULTIPLE * logf(u1));
sincosf(v, z0, z1);
*z0 *= r;
*z1 *= r;
}
template <uint16_t COPY_SIZE>
__simt_callee__ __aicore__ inline void CopyArray(uint32_t* dst, const uint32_t* src)
{
#pragma unroll
for (uint16_t i = 0; i < COPY_SIZE; i++) {
dst[i] = src[i];
}
}
__simt_callee__ __aicore__ inline void SkipOne(uint32_t* counter)
{
if (++counter[0])
return;
if (++counter[1])
return;
if (++counter[IDX_2])
return;
++counter[IDX_3];
}
__simt_callee__ __aicore__ inline void SkipLo(uint32_t* counter, uint64_t n)
{
const uint32_t nlo = static_cast<uint32_t>(n);
uint32_t nhi = static_cast<uint32_t>(n >> RIGHT_SHIFT);
counter[0] += nlo;
if (counter[0] < nlo) {
nhi++;
}
counter[1] += nhi;
if (nhi <= counter[1])
return;
if (++counter[IDX_2])
return;
++counter[IDX_3];
}
__simt_callee__ __aicore__ inline void SkipHi(uint32_t* counter, uint64_t n)
{
const uint32_t countLo = static_cast<uint32_t>(n);
uint32_t countHi = static_cast<uint32_t>(n >> RIGHT_SHIFT);
counter[IDX_2] += countLo;
if (counter[IDX_2] < countLo) {
countHi++;
}
counter[IDX_3] += countHi;
}
__simt_callee__ __aicore__ inline void FlashCounter(uint64_t globalThreadIdx, uint64_t offset, uint32_t* counter)
{
SkipHi(counter, globalThreadIdx);
SkipLo(counter, offset);
}
__simt_callee__ __aicore__ inline void PhiloxAlgParsInit(uint32_t* key, uint32_t* counter, int64_t seed, int64_t offset)
{
key[0] = static_cast<uint32_t>(seed);
key[1] = static_cast<uint32_t>(seed >> RIGHT_SHIFT);
offset = (offset + VEC_4 - 1) / VEC_4;
SkipLo(counter, offset);
}
__simt_callee__ __aicore__ inline void MultiplyHighLow(
uint32_t a, uint32_t b, uint32_t* resultLow, uint32_t* resultHigh)
{
const uint64_t product = static_cast<uint64_t>(a) * b;
*resultLow = static_cast<uint32_t>(product);
*resultHigh = static_cast<uint32_t>(product >> RIGHT_SHIFT);
}
__simt_callee__ __aicore__ inline void Philox4x32Round(uint32_t* counter, const uint32_t* key)
{
uint32_t lo0;
uint32_t hi0;
MultiplyHighLow(PHILOX_M4X32_A, counter[0], &lo0, &hi0);
uint32_t lo1;
uint32_t hi1;
MultiplyHighLow(PHILOX_M4X32_B, counter[IDX_2], &lo1, &hi1);
uint32_t result[ALG_COUNTER_SIZE];
result[0] = hi1 ^ counter[1] ^ key[0];
result[1] = lo1;
result[IDX_2] = hi0 ^ counter[IDX_3] ^ key[1];
result[IDX_3] = lo0;
CopyArray<ALG_COUNTER_SIZE>(counter, result);
}
__simt_callee__ __aicore__ inline void KeyInc(uint32_t* key)
{
key[0] += PHILOX_W32_A;
key[1] += PHILOX_W32_B;
}
__simt_callee__ __aicore__ inline void PhiloxRandomSimt(const uint32_t* key, const uint32_t* counter, uint32_t* results)
{
uint32_t keyTmp[ALG_KEY_SIZE];
uint32_t counterTmp[ALG_COUNTER_SIZE];
CopyArray<ALG_KEY_SIZE>(keyTmp, key);
CopyArray<ALG_COUNTER_SIZE>(counterTmp, counter);
Philox4x32Round(counterTmp, keyTmp);
KeyInc(keyTmp);
Philox4x32Round(counterTmp, keyTmp);
KeyInc(keyTmp);
Philox4x32Round(counterTmp, keyTmp);
KeyInc(keyTmp);
Philox4x32Round(counterTmp, keyTmp);
KeyInc(keyTmp);
Philox4x32Round(counterTmp, keyTmp);
KeyInc(keyTmp);
Philox4x32Round(counterTmp, keyTmp);
KeyInc(keyTmp);
Philox4x32Round(counterTmp, keyTmp);
KeyInc(keyTmp);
Philox4x32Round(counterTmp, keyTmp);
KeyInc(keyTmp);
Philox4x32Round(counterTmp, keyTmp);
KeyInc(keyTmp);
Philox4x32Round(counterTmp, keyTmp);
CopyArray<ALG_COUNTER_SIZE>(results, counterTmp);
}
__simt_callee__ __aicore__ inline void PhiloxRandomSimt(const uint32_t* key, const uint32_t* counter, float* results)
{
uint32_t resultU32[ALG_COUNTER_SIZE];
PhiloxRandomSimt(key, counter, resultU32);
#pragma unroll
for (uint16_t i = 0; i < ALG_COUNTER_SIZE; i++) {
results[i] = resultU32[i] * RAND_2POW32_INV + RAND_2POW32_INV_HALF;
}
}
template <int32_t STEP, int32_t ARANGE_MODE>
__simt_callee__ __aicore__ inline void ThreadMappingAndSkip(
uint64_t idx, uint32_t* counter, uint64_t magic, uint64_t shift, uint64_t totalThreads)
{
static_assert(
!(ARANGE_MODE == DIS_CONTINUOUS_USE && STEP != VEC_4),
"When ARANGE_MODE is DIS_CONTINUOUS_USE, STEP must be 4");
uint64_t idxTmp = idx / STEP;
uint64_t globalThreadIdx = 0;
uint64_t repeat = Simt::UintDiv(idxTmp, magic, shift);
if constexpr (ARANGE_MODE == CONTINUOUS_USE) {
globalThreadIdx = idxTmp - repeat * totalThreads;
} else {
auto repeatTmp = Simt::UintDiv(idx, magic, shift);
globalThreadIdx = idx - repeatTmp * totalThreads;
}
if constexpr (STEP == VEC_8 || STEP == VEC_16) {
repeat = repeat * (STEP / VEC_4) + (idx / VEC_4 % (STEP / VEC_4));
}
FlashCounter(globalThreadIdx, repeat, counter);
}
使用说明
uint32_t key[ALG_KEY_SIZE] = {0, 0};
uint32_t counter[ALG_COUNTER_SIZE] = {0, 0, 0, 0};
PhiloxAlgParsInit(key, counter, seed, offset);
int32_t step = 4;
uint64_t totalThreads = gridDimx(动态计算) * blockDim(固定值);
uint64_t magic, shift;
GetUintDivMagicAndShift(magic, shift, totalThreads);
// 方式1:
for (int64_t i = (blockIdx.x * blockDim.x + threadIdx.x) * step; i < outputLength;
i += gridDim.x * blockDim.x * step)
{
uint32_t results[ALG_COUNTER_SIZE]; // 或者 float类型
uint32_t counterTmp[ALG_COUNTER_SIZE] = {0, 0, 0, 0};
CopyArray<ALG_COUNTER_SIZE>(counterTmp, counter);
ThreadMappingAndSkip<step, CONTINUOUS_USE>(i, counterTmp, magic, shift, totalThreads);
PhiloxRandomSimt(key, counterTmp, results);
// 使用results 对连续的4个索引做操作
}
// 方式2:
for (int64_t i = (blockIdx.x * blockDim.x + threadIdx.x); i < outputLength;
i += gridDim.x * blockDim.x * step)
{
uint32_t results[ALG_COUNTER_SIZE]; // 或者 float类型
uint32_t counterTmp[ALG_COUNTER_SIZE] = {0, 0, 0, 0};
CopyArray<ALG_COUNTER_SIZE>(counterTmp, counter);
ThreadMappingAndSkip<step, DIS_CONTINUOUS_USE>(i, counterTmp, magic, shift, totalThreads);
PhiloxRandomSimt(key, counterTmp, results);
// 使用results 对非连续的4个索引做操作,stride为totalThreads
}
*/
class RandomKernelBaseOp {
public:
__aicore__ inline RandomKernelBaseOp(const RandomUnifiedTilingDataStruct* __restrict tilingData)
: tiling_(tilingData){};
__aicore__ inline void VarsInit();
__aicore__ inline void Skip(const uint64_t count);
__aicore__ inline void GenRandomSIMD(LocalTensor<uint32_t> randomLocal, const uint64_t count);
const RandomUnifiedTilingDataStruct* tiling_;
uint32_t key_[ALG_KEY_SIZE] = {0};
uint32_t counter_[ALG_COUNTER_SIZE] = {0};
uint32_t blockIdx_;
int64_t curCoreProNum_ = 0;
int64_t ubRepeatimes_ = 0;
};
__aicore__ inline void RandomKernelBaseOp::VarsInit()
{
blockIdx_ = GetBlockIdx();
if (blockIdx_ > tiling_->usedCoreNum) {
return;
}
if (blockIdx_ == tiling_->usedCoreNum - 1) {
curCoreProNum_ = tiling_->tailCoreProNum;
} else {
curCoreProNum_ = tiling_->normalCoreProNum;
}
ubRepeatimes_ = Ops::Base::CeilDiv(curCoreProNum_, tiling_->singleBufferSize);
for (uint32_t i = 0; i < ALG_KEY_SIZE; i++) {
key_[i] = tiling_->key[i];
}
for (uint32_t i = 0; i < ALG_COUNTER_SIZE; i++) {
counter_[i] = tiling_->counter[i];
}
GlobalTensor SetGlobalBuffer
Que InitBuffer
*/
}
__aicore__ inline void RandomKernelBaseOp::Skip(const uint64_t count)
{
const uint32_t countLo = static_cast<uint32_t>(count);
uint32_t countHi = static_cast<uint32_t>(count >> RIGHT_SHIFT);
counter_[0] += countLo;
if (counter_[0] < countLo) {
++countHi;
}
counter_[1] += countHi;
if (counter_[1] < countHi) {
if (++counter_[2] == 0) {
++counter_[3];
}
}
}
__aicore__ inline void RandomKernelBaseOp::GenRandomSIMD(LocalTensor<uint32_t> randomLocal, const uint64_t count)
{
constexpr uint16_t kPhiloxRounds = 10;
PhiloxRandom<kPhiloxRounds>(
randomLocal, {key_[0], key_[1]}, {counter_[0], counter_[1], counter_[2], counter_[3]}, count);
}
constexpr uint32_t SIMT_STEP = 4;
constexpr uint32_t DEFAULT_SIMT_THREAD_NUM = 512;
struct ExecutionPolicyKernel {
uint64_t magic;
uint64_t shift;
};
template <typename T, typename TransformFunc, uint32_t ThreadNum = DEFAULT_SIMT_THREAD_NUM, uint32_t Vec = SIMT_STEP>
__simt_vf__ __aicore__ LAUNCH_BOUND(ThreadNum) inline void PhiloxSimtKernelContinuous(
__gm__ volatile T* outputGm, int64_t offset, int64_t seed, uint64_t outputSize, uint64_t magic, uint64_t shift,
uint64_t totalThreads, TransformFunc transform)
{
uint32_t key[ALG_KEY_SIZE] = {0, 0};
uint32_t counter[ALG_COUNTER_SIZE] = {0, 0, 0, 0};
PhiloxAlgParsInit(key, counter, seed, offset);
uint64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
uint32_t randRounds = (Vec + SIMT_STEP - 1) / SIMT_STEP;
uint32_t resultsAll[16] = {0U};
for (uint64_t i = idx * Vec; i < outputSize; i += blockDim.x * gridDim.x * Vec) {
for (uint32_t randIdx = 0; randIdx < randRounds; randIdx++) {
uint32_t counterTmp[ALG_COUNTER_SIZE] = {0, 0, 0, 0};
CopyArray<ALG_COUNTER_SIZE>(counterTmp, counter);
ThreadMappingAndSkip<Vec, CONTINUOUS_USE>(i + randIdx * SIMT_STEP, counterTmp, magic, shift, totalThreads);
uint32_t results[SIMT_STEP];
PhiloxRandomSimt(key, counterTmp, results);
for (uint32_t j = 0; j < SIMT_STEP && randIdx * SIMT_STEP + j < Vec; j++) {
resultsAll[randIdx * SIMT_STEP + j] = results[j];
}
}
for (uint32_t j = 0; j < Vec; j++) {
uint64_t li = i + j;
if (li < outputSize) {
transform(outputGm, li, resultsAll, j);
}
}
}
}
template <typename T, typename TransformFunc, uint32_t ThreadNum = DEFAULT_SIMT_THREAD_NUM, uint32_t UNROLL = 4>
__simt_vf__ __aicore__ LAUNCH_BOUND(ThreadNum) inline void PhiloxSimtKernelDiscontinuous(
__gm__ volatile T* outputGm, int64_t offset, int64_t seed, uint64_t outputSize, uint64_t magic, uint64_t shift,
uint64_t totalThreads, TransformFunc transform)
{
uint32_t key[ALG_KEY_SIZE] = {0, 0};
uint32_t counter[ALG_COUNTER_SIZE] = {0, 0, 0, 0};
PhiloxAlgParsInit(key, counter, seed, offset);
uint64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
uint64_t repeatTime = (outputSize + totalThreads * UNROLL - 1) / (totalThreads * UNROLL);
for (uint64_t loopIdx = 0; loopIdx < repeatTime; loopIdx++) {
for (uint64_t linearIndex = idx; linearIndex < totalThreads; linearIndex += blockDim.x * gridDim.x) {
uint32_t counterTmp[ALG_COUNTER_SIZE] = {0, 0, 0, 0};
CopyArray<ALG_COUNTER_SIZE>(counterTmp, counter);
FlashCounter(linearIndex, loopIdx, counterTmp);
uint32_t results[SIMT_STEP];
PhiloxRandomSimt(key, counterTmp, results);
for (uint32_t iStep = 0; iStep < UNROLL; iStep++) {
uint64_t li = linearIndex + loopIdx * totalThreads * UNROLL + totalThreads * iStep;
if (li < outputSize) {
transform(outputGm, li, results, iStep, SIMT_STEP / UNROLL);
}
}
}
}
}
template <typename Launcher>
__aicore__ inline void ProcessWithSplitBlocks(
const RandomUnifiedSimtTilingDataStruct* __restrict tilingData, Launcher& launcher)
{
for (int64_t blockIdx = 0; blockIdx < tilingData->splitBlockCount; blockIdx++) {
const SplitBlockInfo& block = tilingData->splitBlocks[blockIdx];
ExecutionPolicyKernel policy;
GetUintDivMagicAndShift(policy.magic, policy.shift, static_cast<uint64_t>(block.totalThreads));
launcher(policy, block.gmOffset, block.kernelOffset, block.numel, block.grid, block.totalThreads);
}
}
}
* 其他算子接入SIMT Kernel模板说明:
*
* 1. Tiling侧配置:
* OpTilingConfig config;
* config.kernelMode = RandomKernelMode::SIMT;
* config.enableSplitBlocks = true;
*
* 2. Kernel侧定义TransformFunc:
* template <typename T>
* struct MyTransform {
* float param_;
* __aicore__ MyTransform(float param) : param_(param) {}
* __simt_callee__ __aicore__ inline void operator()(
* __gm__ volatile T* outputGm, uint64_t li, const uint32_t* results, uint32_t iStep,
* [[maybe_unused]] uint32_t randPerOutput = 1)
* {
* // results[iStep] 是 uint32 随机数,需自行转换为 float
* float u = results[iStep] * RAND_2POW32_INV + RAND_2POW32_INV_HALF;
* float val = logf(u);
* outputGm[li] = static_cast<T>(val * param_);
* }
* };
*
* 3. TransformFunc 参数说明:
* - outputGm: 输出 Global Memory 地址
* - li: 当前写入的线性索引
* - results: Philox 生成的 uint32 随机数数组(大小为 SIMT_STEP=4)
* - iStep: 当前处理的元素索引(0 ~ UNROLL-1)
* - randPerOutput: 每个输出消耗的 Philox 输出数量(SIMT_STEP/UNROLL,默认=1)
*
* 4. 调用模板:
* AscendC::Simt::VF_CALL<PhiloxSimtKernelDiscontinuous<T, MyTransform<T>>>(
* AscendC::Simt::Dim3(DEFAULT_SIMT_THREAD_NUM),
* outputGm, kernelOffset, seed, numel,
* policy.magic, policy.shift, totalThreads, transform);
*
* 5. 切分执行(如果 enableSplitBlocks=true):
* struct MyLauncher {
* __aicore__ inline void operator()(
* const ExecutionPolicyKernel& policy, int64_t gmOffset, int64_t kernelOffset,
* int64_t numel, int64_t grid, int64_t totalThreads)
* {
* __gm__ volatile T* gmPtr = (__gm__ volatile T*)baseAddr + gmOffset;
* MyTransform<T> transform(param);
* AscendC::Simt::VF_CALL<PhiloxSimtKernelDiscontinuous<T, MyTransform<T>>>(
* AscendC::Simt::Dim3(DEFAULT_SIMT_THREAD_NUM),
* gmPtr, kernelOffset, seed, numel,
* policy.magic, policy.shift, totalThreads, transform);
* }
* };
* ProcessWithSplitBlocks(tilingData, launcher);
*/
#endif