/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/

/*!
 * \file philox_c310_impl.h
 * \brief
 */

#if !defined(__ASCENDC_INCLUDE_INTERNAL_HEADERS__)
#pragma message("impl/adv_api/detail/math/philox/philox_c310_impl.h is an internal header file and must not be used directly. Functions or variables defined in this file may be removed in the future. Please use \"#include \"adv_api/math/philox.h\"\" and use public functions or variables defined in interface headers files.")
#define __ASCENDC_INCLUDE_INTERNAL_HEADERS__
#define __UNDEF_ASCENDC_INCLUDE_INTERNAL_HEADERS_MATH_PHILOX_PHILOX_C310_IMPL_H__
#endif
#ifndef IMPL_MATH_PHILOX_PHILOX_C310_IMPL_H
#define IMPL_MATH_PHILOX_PHILOX_C310_IMPL_H
#include <cstdint>
#include "kernel_basic_intf.h"
namespace AscendC {
namespace PhiloxInternal {
// Philox 4x32 round constants: multipliers for counter words 0 and 2 ...
constexpr uint32_t CONST_MUL_0 = 0xD2511F53u;
constexpr uint32_t CONST_MUL_1 = 0xCD9E8D57u;
// ... and the values added to key words 0 and 1 after every round.
constexpr uint32_t CONST_KEY_ADD_0 = 0x9E3779B9u;
constexpr uint32_t CONST_KEY_ADD_1 = 0xBB67AE85u;

// One Philox evaluation works on a 128-bit block (4 * 32 bit).
constexpr uint16_t PHILOX_ONCE_COUNTER_BIT = 128;
constexpr uint16_t PHILOX_ONCE_COUNTER_BYTE = PHILOX_ONCE_COUNTER_BIT / 8;
// Number of B32 elements (uint32_t / int32_t / float) in one such block.
constexpr uint16_t PHILOX_ONCE_COUNTER_NUM = PHILOX_ONCE_COUNTER_BYTE / sizeof(uint32_t);
// Number of B32 lanes in one vector register (VL / 4 bytes); one lane evaluates one Philox block.
constexpr uint32_t ELE_CNT_B32_ONCE = GetVecLen() / sizeof(uint32_t);
// B32 elements produced by one vectorised evaluation: 4 words per lane * lanes per register.
constexpr uint16_t PHILOX_ONCE_REPEAT_NUM = PHILOX_ONCE_COUNTER_NUM * ELE_CNT_B32_ONCE;

// uint32 -> float32 bit manipulation
// |1|_____8____|___________23___________|
// |s|exponent  | mantissa               |
// Low 23 bits: the mantissa field.
constexpr uint32_t MANTISSA = 0x007FFFFFu;
// Biased exponent 127 placed in the 8-bit exponent field, i.e. the bit pattern of 1.0f (0x3F800000).
constexpr uint32_t EXP_MASK = 127u << 23;
} // namespace PhiloxInternal

// Philox state: a 64-bit key (2 x 32-bit words) and a 128-bit counter (4 x 32-bit words).
// Both are little endian at word granularity: element 0 is the least significant word.
constexpr uint16_t PHILOX_KEY_SIZE = 2;
constexpr uint16_t PHILOX_COUNTER_SIZE = 4;

using PhiloxKey = uint32_t[PHILOX_KEY_SIZE];
using PhiloxCounter = uint32_t[PHILOX_COUNTER_SIZE];

// Key and counter bundled together; every word defaults to zero.
struct philoxStruct {
    PhiloxKey philoxKey = {};
    PhiloxCounter philoxCounter = {};
};

// Geometry of a strided 2-D destination for the multi-row generator.
// NOTE(review): the visible code divides stride and column by PHILOX_ONCE_COUNTER_NUM (4), so both
// look like counts of B32 elements that are expected to be multiples of 4 — confirm with the caller.
struct PhiloxRandomParams {
    uint32_t stride; // offset between the starts of consecutive rows (element (i, j) maps to i * stride + j)
    uint32_t row;    // number of rows
    uint32_t column; // number of elements generated per row
};

// Per-lane 128-bit increment: {ctr3:ctr2:ctr1:ctr0} += value, with ctr0 as the least significant word.
// `value` is added to ctr0 only; the upper words receive zero plus the carry chained from the word
// below (presumably AddCarryOut produces the carry mask and AddCarryOuts also consumes one).
// The carry out of ctr3 is dropped, so the counter wraps modulo 2^128.
__simd_callee__ inline void AddWith128Bits(Reg::RegTensor<uint32_t> &ctr0, Reg::RegTensor<uint32_t> &ctr1,
    Reg::RegTensor<uint32_t> &ctr2, Reg::RegTensor<uint32_t> &ctr3, Reg::RegTensor<uint32_t> &value,
    Reg::MaskReg &pg)
{
    Reg::MaskReg carry;
    Reg::RegTensor<uint32_t> zeros;
    Duplicate(zeros, 0x0);

    // word 0: plain add, carry out only
    AddCarryOut(carry, ctr0, ctr0, value, pg);
    // words 1..3: ripple the carry upwards
    AddCarryOuts(carry, ctr1, ctr1, zeros, carry, pg);
    AddCarryOuts(carry, ctr2, ctr2, zeros, carry, pg);
    AddCarryOuts(carry, ctr3, ctr3, zeros, carry, pg);
}

// Converts four registers of raw 32-bit random words to float32 values, in place.
// Each lane keeps its low 23 bits as mantissa and gets sign 0 / biased exponent 127, which is the bit
// pattern of a float in [1.0, 2.0); subtracting 1.0 moves the result into [0.0, 1.0).
// The registers keep their uint32_t type; callers store the bits as float data.
__simd_callee__ inline void UInt2Float(Reg::RegTensor<uint32_t> &tmpCtr0, Reg::RegTensor<uint32_t> &tmpCtr1,
    Reg::RegTensor<uint32_t> &tmpCtr2, Reg::RegTensor<uint32_t> &tmpCtr3, Reg::MaskReg &pg)
{
    Reg::RegTensor<uint32_t> mantissaBits;
    Reg::RegTensor<uint32_t> exponentBits;
    Duplicate(mantissaBits, PhiloxInternal::MANTISSA);
    Duplicate(exponentBits, PhiloxInternal::EXP_MASK);

    // clear sign and exponent, keep 23 random mantissa bits
    And(tmpCtr0, tmpCtr0, mantissaBits, pg);
    And(tmpCtr1, tmpCtr1, mantissaBits, pg);
    And(tmpCtr2, tmpCtr2, mantissaBits, pg);
    And(tmpCtr3, tmpCtr3, mantissaBits, pg);

    // install the exponent of 1.0f
    Or(tmpCtr0, tmpCtr0, exponentBits, pg);
    Or(tmpCtr1, tmpCtr1, exponentBits, pg);
    Or(tmpCtr2, tmpCtr2, exponentBits, pg);
    Or(tmpCtr3, tmpCtr3, exponentBits, pg);

    // view the same registers as float and shift [1, 2) down to [0, 1)
    Reg::RegTensor<float> &asFloat0 = (Reg::RegTensor<float> &)tmpCtr0;
    Reg::RegTensor<float> &asFloat1 = (Reg::RegTensor<float> &)tmpCtr1;
    Reg::RegTensor<float> &asFloat2 = (Reg::RegTensor<float> &)tmpCtr2;
    Reg::RegTensor<float> &asFloat3 = (Reg::RegTensor<float> &)tmpCtr3;
    Adds(asFloat0, asFloat0, -1.0f, pg);
    Adds(asFloat1, asFloat1, -1.0f, pg);
    Adds(asFloat2, asFloat2, -1.0f, pg);
    Adds(asFloat3, asFloat3, -1.0f, pg);
}

// Runs `Rounds` Philox rounds in place on the working state (tmpCtr0..3, tmpKey0..1).
//   tmpL0/tmpH0, tmpL1/tmpH1 : scratch registers for the two products (overwritten every round)
//   tmpCtr0..tmpCtr3         : in: counter words; out: state after `Rounds` rounds
//   tmpKey0/tmpKey1          : in: round key; bumped once per round, including after the last round
//   cMul0/cMul1              : registers holding the two round multipliers
//   pg                       : lane mask applied to every instruction
// One round (presumably Mull yields the low and high 32 bits of the 64-bit product):
//   (L0, H0) = ctr0 * cMul0      (L1, H1) = ctr2 * cMul1
//   ctr0' = H1 ^ ctr1 ^ key0     ctr1' = L1
//   ctr2' = H0 ^ ctr3 ^ key1     ctr3' = L0
template <uint16_t Rounds>
__simd_callee__ inline void SpNetworkKernel(Reg::RegTensor<uint32_t> &tmpL0, Reg::RegTensor<uint32_t> &tmpH0,
    Reg::RegTensor<uint32_t> &tmpL1, Reg::RegTensor<uint32_t> &tmpH1, Reg::RegTensor<uint32_t> &tmpCtr0,
    Reg::RegTensor<uint32_t> &tmpCtr1, Reg::RegTensor<uint32_t> &tmpCtr2,
    Reg::RegTensor<uint32_t> &tmpCtr3, Reg::RegTensor<uint32_t> &tmpKey0,
    Reg::RegTensor<uint32_t> &tmpKey1, Reg::RegTensor<uint32_t> &cMul0, Reg::RegTensor<uint32_t> &cMul1,
    Reg::MaskReg &pg)
{
    // Measurement note: this loop with pragma unroll vs. the manually unrolled 4+4+2 main-block kernel.
    // when count=16384, round 10, cycles(6904 vs 5549), vex(6822 vs 4260), ipc(0.988 vs 0.768)
    // #pragma unroll  // ccec may have a bug: https://codehub-y.huawei.com/c00564736/D_Compile_Issue/issues/172
    for (uint16_t j = 0; j < Rounds; j++) {
        // the two multiplies read the state of the previous round
        Mull(tmpL0, tmpH0, tmpCtr0, cMul0, pg);
        Mull(tmpL1, tmpH1, tmpCtr2, cMul1, pg);
        // ctr0' = H1 ^ ctr1 ^ key0
        Xor(tmpH1, tmpH1, tmpCtr1, pg);
        Xor(tmpCtr0, tmpH1, tmpKey0, pg);
        // ctr2' = H0 ^ ctr3 ^ key1
        Xor(tmpH0, tmpH0, tmpCtr3, pg);
        Xor(tmpCtr2, tmpH0, tmpKey1, pg);
        // the low halves become the odd words (old ctr1/ctr3 were consumed by the Xors above)
        tmpCtr1 = tmpL1;
        tmpCtr3 = tmpL0;
        // advance the key for the next round
        Adds(tmpKey0, tmpKey0, PhiloxInternal::CONST_KEY_ADD_0, pg);
        Adds(tmpKey1, tmpKey1, PhiloxInternal::CONST_KEY_ADD_1, pg);
    }
}

// Generates one vector-wide batch of Philox output with the generic round loop and writes the first
// `tailCount` B32 elements of it to dstUbTail.
//   Rounds      : number of Philox rounds
//   T           : output element type; float triggers the uint32 -> [0, 1) float conversion
//   DstUnalign  : true  -> write with Scatter (4 words per lane, tailCount / 4 lanes)
//                 false -> interleave the four registers and write them with masked StoreAlign
//   ctr0..ctr3, key0, key1 : read only here; the rounds run on local copies
//   cMul0/cMul1 : registers holding the round multipliers
//   pg          : lane mask used for the arithmetic
template <uint16_t Rounds, typename T, bool DstUnalign = false>
__simd_callee__ inline void SpNetworkFull(__ubuf__ uint32_t *dstUbTail, uint16_t tailCount,
    Reg::RegTensor<uint32_t> &ctr0, Reg::RegTensor<uint32_t> &ctr1, Reg::RegTensor<uint32_t> &ctr2,
    Reg::RegTensor<uint32_t> &ctr3, Reg::RegTensor<uint32_t> &key0, Reg::RegTensor<uint32_t> &key1,
    Reg::RegTensor<uint32_t> &cMul0, Reg::RegTensor<uint32_t> &cMul1, Reg::MaskReg &pg)
{
    // work on copies so the caller's counter and key registers stay untouched
    Reg::RegTensor<uint32_t> tmpCtr3, tmpCtr2, tmpCtr1, tmpCtr0;
    tmpCtr0 = ctr0;
    tmpCtr1 = ctr1;
    tmpCtr2 = ctr2;
    tmpCtr3 = ctr3;
    Reg::RegTensor<uint32_t> tmpKey0 = key0;
    Reg::RegTensor<uint32_t> tmpKey1 = key1;
    Reg::RegTensor<uint32_t> tmpL0, tmpH0, tmpL1, tmpH1;
    SpNetworkKernel<Rounds>(tmpL0, tmpH0, tmpL1, tmpH1, tmpCtr0, tmpCtr1, tmpCtr2, tmpCtr3, tmpKey0, tmpKey1, cMul0,
        cMul1, pg);

    if constexpr (std::is_same_v<T, float>) {
        UInt2Float(tmpCtr0, tmpCtr1, tmpCtr2, tmpCtr3, pg);
    }

    if constexpr (DstUnalign) {
        // lane i owns destination elements 4 * i .. 4 * i + 3
        Reg::RegTensor<uint32_t> reorderIndex;
        Reg::Arange((Reg::RegTensor<int32_t> &)reorderIndex, 0);
        Muls(reorderIndex, reorderIndex, PhiloxInternal::PHILOX_ONCE_COUNTER_NUM, pg);
        // this path relies on the element count being a multiple of 4 (column % 4 == 0),
        // so the scatter mask covers tailCount / 4 lanes
        uint32_t sreg = static_cast<uint32_t>(tailCount / PhiloxInternal::PHILOX_ONCE_COUNTER_NUM);
        Reg::MaskReg pgTail = Reg::UpdateMask<T>(sreg);
        // word k of every active lane goes to dstUbTail + k + 4 * lane
        Reg::Scatter(dstUbTail, tmpCtr0, reorderIndex, pgTail);
        Reg::Scatter(dstUbTail + 1, tmpCtr1, reorderIndex, pgTail);
        Reg::Scatter(dstUbTail + 2, tmpCtr2, reorderIndex, pgTail);
        Reg::Scatter(dstUbTail + 3, tmpCtr3, reorderIndex, pgTail);
    } else {
        // two interleave stages; presumably this reorders the four word-planes into
        // lane-major order (w0 w1 w2 w3 per lane) spread over tmpCtr0, tmpCtr1, tmpCtr2, tmpCtr3
        Interleave(tmpCtr0, tmpCtr2, tmpCtr0, tmpCtr2);
        Interleave(tmpCtr1, tmpCtr3, tmpCtr1, tmpCtr3);
        Interleave(tmpCtr0, tmpCtr1, tmpCtr0, tmpCtr1);
        Interleave(tmpCtr2, tmpCtr3, tmpCtr2, tmpCtr3);
        // sreg holds the number of elements still to be written; UpdateMask is called before every
        // store (presumably it yields the mask for the next register and reduces sreg accordingly)
        uint32_t sreg = static_cast<uint32_t>(tailCount);
        Reg::MaskReg pgTail = Reg::UpdateMask<T>(sreg);
        Reg::StoreAlign<uint32_t, Reg::PostLiteral::POST_MODE_UPDATE>(dstUbTail, tmpCtr0,
            PhiloxInternal::ELE_CNT_B32_ONCE, pgTail);
        pgTail = Reg::UpdateMask<T>(sreg);
        Reg::StoreAlign<uint32_t, Reg::PostLiteral::POST_MODE_UPDATE>(dstUbTail, tmpCtr1,
            PhiloxInternal::ELE_CNT_B32_ONCE, pgTail);
        pgTail = Reg::UpdateMask<T>(sreg);
        Reg::StoreAlign<uint32_t, Reg::PostLiteral::POST_MODE_UPDATE>(dstUbTail, tmpCtr2,
            PhiloxInternal::ELE_CNT_B32_ONCE, pgTail);
        pgTail = Reg::UpdateMask<T>(sreg);
        Reg::StoreAlign<uint32_t, Reg::PostLiteral::POST_MODE_UPDATE>(dstUbTail, tmpCtr3,
            PhiloxInternal::ELE_CNT_B32_ONCE, pgTail);
    }
}

// Spills the four intermediate state registers to dstUbT as four consecutive planes of
// ELE_CNT_B32_ONCE words (tmpCtr0 first). dstUbT is taken by reference and handed to post-update
// stores, so on return it presumably points just past the written data.
//   DstUnalign : false -> aligned stores under mask pg
//                true  -> unaligned stores through a local UnalignReg, flushed at the end (pg unused)
template <bool DstUnalign = false>
__simd_callee__ inline void PhiloxUnrollStoreTmpCtrl(__ubuf__ uint32_t *&dstUbT, Reg::RegTensor<uint32_t> &tmpCtr0,
    Reg::RegTensor<uint32_t> &tmpCtr1, Reg::RegTensor<uint32_t> &tmpCtr2,
    Reg::RegTensor<uint32_t> &tmpCtr3, Reg::MaskReg &pg)
{
    constexpr uint32_t lanesPerReg = PhiloxInternal::ELE_CNT_B32_ONCE;
    if constexpr (!DstUnalign) {
        Reg::StoreAlign<uint32_t, Reg::PostLiteral::POST_MODE_UPDATE>(dstUbT, tmpCtr0, lanesPerReg, pg);
        Reg::StoreAlign<uint32_t, Reg::PostLiteral::POST_MODE_UPDATE>(dstUbT, tmpCtr1, lanesPerReg, pg);
        Reg::StoreAlign<uint32_t, Reg::PostLiteral::POST_MODE_UPDATE>(dstUbT, tmpCtr2, lanesPerReg, pg);
        Reg::StoreAlign<uint32_t, Reg::PostLiteral::POST_MODE_UPDATE>(dstUbT, tmpCtr3, lanesPerReg, pg);
    } else {
        Reg::UnalignReg pending;
        Reg::StoreUnAlign<uint32_t, Reg::PostLiteral::POST_MODE_UPDATE>(dstUbT, tmpCtr0, pending, lanesPerReg);
        Reg::StoreUnAlign<uint32_t, Reg::PostLiteral::POST_MODE_UPDATE>(dstUbT, tmpCtr1, pending, lanesPerReg);
        Reg::StoreUnAlign<uint32_t, Reg::PostLiteral::POST_MODE_UPDATE>(dstUbT, tmpCtr2, pending, lanesPerReg);
        Reg::StoreUnAlign<uint32_t, Reg::PostLiteral::POST_MODE_UPDATE>(dstUbT, tmpCtr3, pending, lanesPerReg);
        // flush whatever is still buffered in the unalign register
        Reg::StoreUnAlignPost<uint32_t, Reg::PostLiteral::POST_MODE_UPDATE>(dstUbT, pending, 0);
    }
}

// Reloads the four intermediate state registers from four independent plane pointers.
// Every load is given a step of one whole block (ELE_CNT_B32_ONCE * PHILOX_ONCE_COUNTER_NUM words);
// the pointers are taken by reference, so each presumably ends up at the same plane of the next block.
//   DstUnalign : false -> aligned post-update loads
//                true  -> unaligned loads, each primed with LoadUnAlignPre on its own pointer
template <bool DstUnalign = false>
__simd_callee__ inline void PhiloxUnrollLoadTmpCtrl(__ubuf__ uint32_t *&dstUbTT0, __ubuf__ uint32_t *&dstUbTT1,
    __ubuf__ uint32_t *&dstUbTT2, __ubuf__ uint32_t *&dstUbTT3, Reg::RegTensor<uint32_t> &tmpCtr0,
    Reg::RegTensor<uint32_t> &tmpCtr1, Reg::RegTensor<uint32_t> &tmpCtr2,
    Reg::RegTensor<uint32_t> &tmpCtr3)
{
    constexpr uint32_t blockStep = PhiloxInternal::ELE_CNT_B32_ONCE * PhiloxInternal::PHILOX_ONCE_COUNTER_NUM;
    if constexpr (!DstUnalign) {
        Reg::LoadAlign<uint32_t, Reg::PostLiteral::POST_MODE_UPDATE>(tmpCtr0, dstUbTT0, blockStep);
        Reg::LoadAlign<uint32_t, Reg::PostLiteral::POST_MODE_UPDATE>(tmpCtr1, dstUbTT1, blockStep);
        Reg::LoadAlign<uint32_t, Reg::PostLiteral::POST_MODE_UPDATE>(tmpCtr2, dstUbTT2, blockStep);
        Reg::LoadAlign<uint32_t, Reg::PostLiteral::POST_MODE_UPDATE>(tmpCtr3, dstUbTT3, blockStep);
    } else {
        Reg::UnalignReg staging;
        Reg::LoadUnAlignPre(staging, dstUbTT0);
        Reg::LoadUnAlign(tmpCtr0, staging, dstUbTT0, blockStep);
        Reg::LoadUnAlignPre(staging, dstUbTT1);
        Reg::LoadUnAlign(tmpCtr1, staging, dstUbTT1, blockStep);
        Reg::LoadUnAlignPre(staging, dstUbTT2);
        Reg::LoadUnAlign(tmpCtr2, staging, dstUbTT2, blockStep);
        Reg::LoadUnAlignPre(staging, dstUbTT3);
        Reg::LoadUnAlign(tmpCtr3, staging, dstUbTT3, blockStep);
    }
}

// Main-block generator for 10-round Philox, manually unrolled as 4 + 4 + 2 rounds in three passes.
// Every iteration produces one block of PHILOX_ONCE_REPEAT_NUM B32 elements (4 words per lane).
//   pass 1: rounds 1-4 on the current counters; the intermediate state is spilled to dstUb as four
//           word-planes per block, and the counters advance by vEleStrideB32OneRow
//   pass 2: reload each block, fast-forward the key by 4 increments, rounds 5-8, spill again
//   pass 3: reload each block, fast-forward the key by 8 increments, rounds 9-10, optional float
//           conversion, then scatter so that lane i owns elements 4 * i .. 4 * i + 3 of its block
// dstUb therefore serves as scratch between passes and holds the final output afterwards.
//   T          : output element type; float triggers UInt2Float before the final scatter
//   DstUnalign : selects unaligned spill/reload of the intermediate state
//   mainIter   : number of full blocks to generate
//   ctr0..ctr3 : counters; advanced in place, so on return they address the block after the last one
//   key0/key1, cMul0/cMul1 : read only
//   vEleStrideB32OneRow    : per-lane counter increment applied after every block
//   pg         : lane mask used for all arithmetic and aligned stores
// A round computes (L0, H0) = ctr0 * cMul0, (L1, H1) = ctr2 * cMul1 and then
//   ctr0' = H1 ^ ctr1 ^ key0, ctr1' = L1, ctr2' = H0 ^ ctr3 ^ key1, ctr3' = L0.
// The key increment is omitted after the last round of every pass: the key copy is local to the loop
// body and the next pass rebuilds it from key0/key1, so that increment would never be read.
template <typename T, bool DstUnalign = false>
__simd_callee__ inline void PhiloxRound10MainBlockUnroll442(__ubuf__ uint32_t *dstUb, uint16_t mainIter,
    Reg::RegTensor<uint32_t> &ctr0, Reg::RegTensor<uint32_t> &ctr1, Reg::RegTensor<uint32_t> &ctr2,
    Reg::RegTensor<uint32_t> &ctr3, Reg::RegTensor<uint32_t> &key0, Reg::RegTensor<uint32_t> &key1,
    Reg::RegTensor<uint32_t> &cMul0, Reg::RegTensor<uint32_t> &cMul1,
    Reg::RegTensor<uint32_t> &vEleStrideB32OneRow, Reg::MaskReg &pg)
{
    Reg::RegTensor<uint32_t> tmpCtr3, tmpCtr2, tmpCtr1, tmpCtr0;
    __ubuf__ uint32_t *dstUbT = dstUb;

    // scatter offsets for pass 3: lane i -> element 4 * i of the current block
    Reg::RegTensor<uint32_t> reorderIndex;
    Arange((Reg::RegTensor<int32_t> &)reorderIndex, 0);
    Muls(reorderIndex, reorderIndex, PhiloxInternal::PHILOX_ONCE_COUNTER_NUM, pg);

    // ---- pass 1: rounds 1-4 ----
    for (uint16_t i = 0; i < mainIter; i++) {
        tmpCtr0 = ctr0;
        tmpCtr1 = ctr1;
        tmpCtr2 = ctr2;
        tmpCtr3 = ctr3;
        Reg::RegTensor<uint32_t> tmpKey0 = key0;
        Reg::RegTensor<uint32_t> tmpKey1 = key1;
        Reg::RegTensor<uint32_t> tmpL0, tmpH0, tmpL1, tmpH1;

        // round 1
        Mull(tmpL0, tmpH0, tmpCtr0, cMul0, pg);
        Mull(tmpL1, tmpH1, tmpCtr2, cMul1, pg);
        Xor(tmpH1, tmpH1, tmpCtr1, pg);
        Xor(tmpCtr0, tmpH1, tmpKey0, pg);
        Xor(tmpH0, tmpH0, tmpCtr3, pg);
        Xor(tmpCtr2, tmpH0, tmpKey1, pg);
        tmpCtr1 = tmpL1;
        tmpCtr3 = tmpL0;
        Adds(tmpKey0, tmpKey0, PhiloxInternal::CONST_KEY_ADD_0, pg);
        Adds(tmpKey1, tmpKey1, PhiloxInternal::CONST_KEY_ADD_1, pg);

        // round 2
        Mull(tmpL0, tmpH0, tmpCtr0, cMul0, pg);
        Mull(tmpL1, tmpH1, tmpCtr2, cMul1, pg);
        Xor(tmpH1, tmpH1, tmpCtr1, pg);
        Xor(tmpCtr0, tmpH1, tmpKey0, pg);
        Xor(tmpH0, tmpH0, tmpCtr3, pg);
        Xor(tmpCtr2, tmpH0, tmpKey1, pg);
        tmpCtr1 = tmpL1;
        tmpCtr3 = tmpL0;
        Adds(tmpKey0, tmpKey0, PhiloxInternal::CONST_KEY_ADD_0, pg);
        Adds(tmpKey1, tmpKey1, PhiloxInternal::CONST_KEY_ADD_1, pg);

        // round 3
        Mull(tmpL0, tmpH0, tmpCtr0, cMul0, pg);
        Mull(tmpL1, tmpH1, tmpCtr2, cMul1, pg);
        Xor(tmpH1, tmpH1, tmpCtr1, pg);
        Xor(tmpCtr0, tmpH1, tmpKey0, pg);
        Xor(tmpH0, tmpH0, tmpCtr3, pg);
        Xor(tmpCtr2, tmpH0, tmpKey1, pg);
        tmpCtr1 = tmpL1;
        tmpCtr3 = tmpL0;
        Adds(tmpKey0, tmpKey0, PhiloxInternal::CONST_KEY_ADD_0, pg);
        Adds(tmpKey1, tmpKey1, PhiloxInternal::CONST_KEY_ADD_1, pg);

        // round 4 (last of this pass: no key increment needed)
        Mull(tmpL0, tmpH0, tmpCtr0, cMul0, pg);
        Mull(tmpL1, tmpH1, tmpCtr2, cMul1, pg);
        Xor(tmpH1, tmpH1, tmpCtr1, pg);
        Xor(tmpCtr0, tmpH1, tmpKey0, pg);
        Xor(tmpH0, tmpH0, tmpCtr3, pg);
        Xor(tmpCtr2, tmpH0, tmpKey1, pg);
        tmpCtr1 = tmpL1;
        tmpCtr3 = tmpL0;

        // spill the intermediate state and move the counters on to the next block
        PhiloxUnrollStoreTmpCtrl<DstUnalign>(dstUbT, tmpCtr0, tmpCtr1, tmpCtr2, tmpCtr3, pg);
        AddWith128Bits(ctr0, ctr1, ctr2, ctr3, vEleStrideB32OneRow, pg);
    }

    // rewind: one read pointer per word-plane, one write pointer for the spill
    dstUbT = dstUb;
    __ubuf__ uint32_t *dstUbTT0 = dstUbT;
    __ubuf__ uint32_t *dstUbTT1 = dstUbT + PhiloxInternal::ELE_CNT_B32_ONCE;
    __ubuf__ uint32_t *dstUbTT2 = dstUbT + PhiloxInternal::ELE_CNT_B32_ONCE * 2;
    __ubuf__ uint32_t *dstUbTT3 = dstUbT + PhiloxInternal::ELE_CNT_B32_ONCE * 3;
    // the spills of pass 1 must be visible before pass 2 loads them
    Reg::LocalMemBar<Reg::MemType::VEC_STORE, Reg::MemType::VEC_LOAD>();

    // ---- pass 2: rounds 5-8 ----
    for (uint16_t i = 0; i < mainIter; i++) {
        PhiloxUnrollLoadTmpCtrl<DstUnalign>(dstUbTT0, dstUbTT1, dstUbTT2, dstUbTT3, tmpCtr0, tmpCtr1, tmpCtr2, tmpCtr3);

        Reg::RegTensor<uint32_t> tmpKey0 = key0;
        Reg::RegTensor<uint32_t> tmpKey1 = key1;

        // fast-forward the key past rounds 1-4 (four increments)
        Adds(tmpKey0, tmpKey0, PhiloxInternal::CONST_KEY_ADD_0, pg);
        Adds(tmpKey1, tmpKey1, PhiloxInternal::CONST_KEY_ADD_1, pg);

        Adds(tmpKey0, tmpKey0, PhiloxInternal::CONST_KEY_ADD_0, pg);
        Adds(tmpKey1, tmpKey1, PhiloxInternal::CONST_KEY_ADD_1, pg);

        Adds(tmpKey0, tmpKey0, PhiloxInternal::CONST_KEY_ADD_0, pg);
        Adds(tmpKey1, tmpKey1, PhiloxInternal::CONST_KEY_ADD_1, pg);

        Adds(tmpKey0, tmpKey0, PhiloxInternal::CONST_KEY_ADD_0, pg);
        Adds(tmpKey1, tmpKey1, PhiloxInternal::CONST_KEY_ADD_1, pg);

        Reg::RegTensor<uint32_t> tmpL0, tmpH0, tmpL1, tmpH1;

        // round 5
        Mull(tmpL0, tmpH0, tmpCtr0, cMul0, pg);
        Mull(tmpL1, tmpH1, tmpCtr2, cMul1, pg);
        Xor(tmpH1, tmpH1, tmpCtr1, pg);
        Xor(tmpCtr0, tmpH1, tmpKey0, pg);
        Xor(tmpH0, tmpH0, tmpCtr3, pg);
        Xor(tmpCtr2, tmpH0, tmpKey1, pg);
        tmpCtr1 = tmpL1;
        tmpCtr3 = tmpL0;
        Adds(tmpKey0, tmpKey0, PhiloxInternal::CONST_KEY_ADD_0, pg);
        Adds(tmpKey1, tmpKey1, PhiloxInternal::CONST_KEY_ADD_1, pg);

        // round 6
        Mull(tmpL0, tmpH0, tmpCtr0, cMul0, pg);
        Mull(tmpL1, tmpH1, tmpCtr2, cMul1, pg);
        Xor(tmpH1, tmpH1, tmpCtr1, pg);
        Xor(tmpCtr0, tmpH1, tmpKey0, pg);
        Xor(tmpH0, tmpH0, tmpCtr3, pg);
        Xor(tmpCtr2, tmpH0, tmpKey1, pg);
        tmpCtr1 = tmpL1;
        tmpCtr3 = tmpL0;
        Adds(tmpKey0, tmpKey0, PhiloxInternal::CONST_KEY_ADD_0, pg);
        Adds(tmpKey1, tmpKey1, PhiloxInternal::CONST_KEY_ADD_1, pg);

        // round 7
        Mull(tmpL0, tmpH0, tmpCtr0, cMul0, pg);
        Mull(tmpL1, tmpH1, tmpCtr2, cMul1, pg);
        Xor(tmpH1, tmpH1, tmpCtr1, pg);
        Xor(tmpCtr0, tmpH1, tmpKey0, pg);
        Xor(tmpH0, tmpH0, tmpCtr3, pg);
        Xor(tmpCtr2, tmpH0, tmpKey1, pg);
        tmpCtr1 = tmpL1;
        tmpCtr3 = tmpL0;
        Adds(tmpKey0, tmpKey0, PhiloxInternal::CONST_KEY_ADD_0, pg);
        Adds(tmpKey1, tmpKey1, PhiloxInternal::CONST_KEY_ADD_1, pg);

        // round 8 (last of this pass: no key increment needed)
        Mull(tmpL0, tmpH0, tmpCtr0, cMul0, pg);
        Mull(tmpL1, tmpH1, tmpCtr2, cMul1, pg);
        Xor(tmpH1, tmpH1, tmpCtr1, pg);
        Xor(tmpCtr0, tmpH1, tmpKey0, pg);
        Xor(tmpH0, tmpH0, tmpCtr3, pg);
        Xor(tmpCtr2, tmpH0, tmpKey1, pg);
        tmpCtr1 = tmpL1;
        tmpCtr3 = tmpL0;

        // overwrite the block that was just loaded with the state after round 8
        PhiloxUnrollStoreTmpCtrl<DstUnalign>(dstUbT, tmpCtr0, tmpCtr1, tmpCtr2, tmpCtr3, pg);
    }

    dstUbT = dstUb;
    dstUbTT0 = dstUbT;
    dstUbTT1 = dstUbT + PhiloxInternal::ELE_CNT_B32_ONCE;
    dstUbTT2 = dstUbT + PhiloxInternal::ELE_CNT_B32_ONCE * 2;
    dstUbTT3 = dstUbT + PhiloxInternal::ELE_CNT_B32_ONCE * 3;
    // the spills of pass 2 must be visible before pass 3 loads them
    Reg::LocalMemBar<Reg::MemType::VEC_STORE, Reg::MemType::VEC_LOAD>();

    // ---- pass 3: rounds 9-10, conversion and final layout ----
    for (uint16_t i = 0; i < mainIter; i++) {
        PhiloxUnrollLoadTmpCtrl<DstUnalign>(dstUbTT0, dstUbTT1, dstUbTT2, dstUbTT3, tmpCtr0, tmpCtr1, tmpCtr2, tmpCtr3);

        Reg::RegTensor<uint32_t> tmpKey0 = key0;
        Reg::RegTensor<uint32_t> tmpKey1 = key1;

        // fast-forward the key past rounds 1-8 (eight increments)
        Adds(tmpKey0, tmpKey0, PhiloxInternal::CONST_KEY_ADD_0, pg);
        Adds(tmpKey1, tmpKey1, PhiloxInternal::CONST_KEY_ADD_1, pg);

        Adds(tmpKey0, tmpKey0, PhiloxInternal::CONST_KEY_ADD_0, pg);
        Adds(tmpKey1, tmpKey1, PhiloxInternal::CONST_KEY_ADD_1, pg);

        Adds(tmpKey0, tmpKey0, PhiloxInternal::CONST_KEY_ADD_0, pg);
        Adds(tmpKey1, tmpKey1, PhiloxInternal::CONST_KEY_ADD_1, pg);

        Adds(tmpKey0, tmpKey0, PhiloxInternal::CONST_KEY_ADD_0, pg);
        Adds(tmpKey1, tmpKey1, PhiloxInternal::CONST_KEY_ADD_1, pg);

        Adds(tmpKey0, tmpKey0, PhiloxInternal::CONST_KEY_ADD_0, pg);
        Adds(tmpKey1, tmpKey1, PhiloxInternal::CONST_KEY_ADD_1, pg);

        Adds(tmpKey0, tmpKey0, PhiloxInternal::CONST_KEY_ADD_0, pg);
        Adds(tmpKey1, tmpKey1, PhiloxInternal::CONST_KEY_ADD_1, pg);

        Adds(tmpKey0, tmpKey0, PhiloxInternal::CONST_KEY_ADD_0, pg);
        Adds(tmpKey1, tmpKey1, PhiloxInternal::CONST_KEY_ADD_1, pg);

        Adds(tmpKey0, tmpKey0, PhiloxInternal::CONST_KEY_ADD_0, pg);
        Adds(tmpKey1, tmpKey1, PhiloxInternal::CONST_KEY_ADD_1, pg);

        Reg::RegTensor<uint32_t> tmpL0, tmpH0, tmpL1, tmpH1;

        // round 9
        Mull(tmpL0, tmpH0, tmpCtr0, cMul0, pg);
        Mull(tmpL1, tmpH1, tmpCtr2, cMul1, pg);
        Xor(tmpH1, tmpH1, tmpCtr1, pg);
        Xor(tmpCtr0, tmpH1, tmpKey0, pg);
        Xor(tmpH0, tmpH0, tmpCtr3, pg);
        Xor(tmpCtr2, tmpH0, tmpKey1, pg);
        tmpCtr1 = tmpL1;
        tmpCtr3 = tmpL0;
        Adds(tmpKey0, tmpKey0, PhiloxInternal::CONST_KEY_ADD_0, pg);
        Adds(tmpKey1, tmpKey1, PhiloxInternal::CONST_KEY_ADD_1, pg);

        // round 10
        Mull(tmpL0, tmpH0, tmpCtr0, cMul0, pg);
        Mull(tmpL1, tmpH1, tmpCtr2, cMul1, pg);
        Xor(tmpH1, tmpH1, tmpCtr1, pg);
        Xor(tmpCtr0, tmpH1, tmpKey0, pg);
        Xor(tmpH0, tmpH0, tmpCtr3, pg);
        Xor(tmpCtr2, tmpH0, tmpKey1, pg);
        tmpCtr1 = tmpL1;
        tmpCtr3 = tmpL0;

        if constexpr (std::is_same_v<T, float>) {
            UInt2Float(tmpCtr0, tmpCtr1, tmpCtr2, tmpCtr3, pg);
        }

        // word k of lane i goes to dstUb[k + 4 * i + block offset]; the block was fully loaded above,
        // so writing the final layout over the scratch data of the same block is safe
        Reg::Scatter(dstUb, tmpCtr0, reorderIndex, pg);
        Reg::Scatter(dstUb + 1, tmpCtr1, reorderIndex, pg);
        Reg::Scatter(dstUb + 2, tmpCtr2, reorderIndex, pg);
        Reg::Scatter(dstUb + 3, tmpCtr3, reorderIndex, pg);
        // move the scatter offsets to the next block
        Adds(reorderIndex, reorderIndex, PhiloxInternal::PHILOX_ONCE_REPEAT_NUM, pg);
    }
}

// Main-block generator for 7-round Philox, manually unrolled as 4 + 3 rounds in two passes.
// Every iteration produces one block of PHILOX_ONCE_REPEAT_NUM B32 elements (4 words per lane).
//   pass 1: rounds 1-4 on the current counters; the intermediate state is spilled to dstUb as four
//           word-planes per block, and the counters advance by vEleStrideB32OneRow
//   pass 2: reload each block, fast-forward the key by 4 increments, rounds 5-7, optional float
//           conversion, then scatter so that lane i owns elements 4 * i .. 4 * i + 3 of its block
// dstUb therefore serves as scratch between the passes and holds the final output afterwards.
//   T          : output element type; float triggers UInt2Float before the final scatter
//   DstUnalign : selects unaligned spill/reload of the intermediate state
//   mainIter   : number of full blocks to generate
//   ctr0..ctr3 : counters; advanced in place, so on return they address the block after the last one
//   key0/key1, cMul0/cMul1 : read only
//   vEleStrideB32OneRow    : per-lane counter increment applied after every block
//   pg         : lane mask used for all arithmetic and aligned stores
// A round computes (L0, H0) = ctr0 * cMul0, (L1, H1) = ctr2 * cMul1 and then
//   ctr0' = H1 ^ ctr1 ^ key0, ctr1' = L1, ctr2' = H0 ^ ctr3 ^ key1, ctr3' = L0.
// The key increment is omitted after the last round of every pass: the key copy is local to the loop
// body and the next pass rebuilds it from key0/key1, so that increment would never be read.
template <typename T, bool DstUnalign = false>
__simd_callee__ inline void PhiloxRound7MainBlockUnroll43(__ubuf__ uint32_t *dstUb, uint16_t mainIter,
    Reg::RegTensor<uint32_t> &ctr0, Reg::RegTensor<uint32_t> &ctr1, Reg::RegTensor<uint32_t> &ctr2,
    Reg::RegTensor<uint32_t> &ctr3, Reg::RegTensor<uint32_t> &key0, Reg::RegTensor<uint32_t> &key1,
    Reg::RegTensor<uint32_t> &cMul0, Reg::RegTensor<uint32_t> &cMul1,
    Reg::RegTensor<uint32_t> &vEleStrideB32OneRow, Reg::MaskReg &pg)
{
    Reg::RegTensor<uint32_t> tmpCtr3, tmpCtr2, tmpCtr1, tmpCtr0;
    __ubuf__ uint32_t *dstUbT = dstUb;

    // scatter offsets for pass 2: lane i -> element 4 * i of the current block
    Reg::RegTensor<uint32_t> reorderIndex;
    Reg::Arange((Reg::RegTensor<int32_t> &)reorderIndex, 0);
    Muls(reorderIndex, reorderIndex, PhiloxInternal::PHILOX_ONCE_COUNTER_NUM, pg);

    // ---- pass 1: rounds 1-4 ----
    for (uint16_t i = 0; i < mainIter; i++) {
        tmpCtr0 = ctr0;
        tmpCtr1 = ctr1;
        tmpCtr2 = ctr2;
        tmpCtr3 = ctr3;
        Reg::RegTensor<uint32_t> tmpKey0 = key0;
        Reg::RegTensor<uint32_t> tmpKey1 = key1;
        Reg::RegTensor<uint32_t> tmpL0, tmpH0, tmpL1, tmpH1;

        // round 1
        Mull(tmpL0, tmpH0, tmpCtr0, cMul0, pg);
        Mull(tmpL1, tmpH1, tmpCtr2, cMul1, pg);
        Xor(tmpH1, tmpH1, tmpCtr1, pg);
        Xor(tmpCtr0, tmpH1, tmpKey0, pg);
        Xor(tmpH0, tmpH0, tmpCtr3, pg);
        Xor(tmpCtr2, tmpH0, tmpKey1, pg);
        tmpCtr1 = tmpL1;
        tmpCtr3 = tmpL0;
        Adds(tmpKey0, tmpKey0, PhiloxInternal::CONST_KEY_ADD_0, pg);
        Adds(tmpKey1, tmpKey1, PhiloxInternal::CONST_KEY_ADD_1, pg);

        // round 2
        Mull(tmpL0, tmpH0, tmpCtr0, cMul0, pg);
        Mull(tmpL1, tmpH1, tmpCtr2, cMul1, pg);
        Xor(tmpH1, tmpH1, tmpCtr1, pg);
        Xor(tmpCtr0, tmpH1, tmpKey0, pg);
        Xor(tmpH0, tmpH0, tmpCtr3, pg);
        Xor(tmpCtr2, tmpH0, tmpKey1, pg);
        tmpCtr1 = tmpL1;
        tmpCtr3 = tmpL0;
        Adds(tmpKey0, tmpKey0, PhiloxInternal::CONST_KEY_ADD_0, pg);
        Adds(tmpKey1, tmpKey1, PhiloxInternal::CONST_KEY_ADD_1, pg);

        // round 3
        Mull(tmpL0, tmpH0, tmpCtr0, cMul0, pg);
        Mull(tmpL1, tmpH1, tmpCtr2, cMul1, pg);
        Xor(tmpH1, tmpH1, tmpCtr1, pg);
        Xor(tmpCtr0, tmpH1, tmpKey0, pg);
        Xor(tmpH0, tmpH0, tmpCtr3, pg);
        Xor(tmpCtr2, tmpH0, tmpKey1, pg);
        tmpCtr1 = tmpL1;
        tmpCtr3 = tmpL0;
        Adds(tmpKey0, tmpKey0, PhiloxInternal::CONST_KEY_ADD_0, pg);
        Adds(tmpKey1, tmpKey1, PhiloxInternal::CONST_KEY_ADD_1, pg);

        // round 4 (last of this pass: no key increment needed)
        Mull(tmpL0, tmpH0, tmpCtr0, cMul0, pg);
        Mull(tmpL1, tmpH1, tmpCtr2, cMul1, pg);
        Xor(tmpH1, tmpH1, tmpCtr1, pg);
        Xor(tmpCtr0, tmpH1, tmpKey0, pg);
        Xor(tmpH0, tmpH0, tmpCtr3, pg);
        Xor(tmpCtr2, tmpH0, tmpKey1, pg);
        tmpCtr1 = tmpL1;
        tmpCtr3 = tmpL0;

        // spill the intermediate state and move the counters on to the next block
        PhiloxUnrollStoreTmpCtrl<DstUnalign>(dstUbT, tmpCtr0, tmpCtr1, tmpCtr2, tmpCtr3, pg);
        AddWith128Bits(ctr0, ctr1, ctr2, ctr3, vEleStrideB32OneRow, pg);
    }

    // rewind: one read pointer per word-plane
    dstUbT = dstUb;
    __ubuf__ uint32_t *dstUbTT0 = dstUbT;
    __ubuf__ uint32_t *dstUbTT1 = dstUbT + PhiloxInternal::ELE_CNT_B32_ONCE;
    __ubuf__ uint32_t *dstUbTT2 = dstUbT + PhiloxInternal::ELE_CNT_B32_ONCE * 2;
    __ubuf__ uint32_t *dstUbTT3 = dstUbT + PhiloxInternal::ELE_CNT_B32_ONCE * 3;
    // the spills of pass 1 must be visible before pass 2 loads them
    Reg::LocalMemBar<Reg::MemType::VEC_STORE, Reg::MemType::VEC_LOAD>();

    // ---- pass 2: rounds 5-7, conversion and final layout ----
    for (uint16_t i = 0; i < mainIter; i++) {
        PhiloxUnrollLoadTmpCtrl<DstUnalign>(dstUbTT0, dstUbTT1, dstUbTT2, dstUbTT3, tmpCtr0, tmpCtr1, tmpCtr2, tmpCtr3);

        Reg::RegTensor<uint32_t> tmpKey0 = key0;
        Reg::RegTensor<uint32_t> tmpKey1 = key1;

        // fast-forward the key past rounds 1-4 (four increments)
        Adds(tmpKey0, tmpKey0, PhiloxInternal::CONST_KEY_ADD_0, pg);
        Adds(tmpKey1, tmpKey1, PhiloxInternal::CONST_KEY_ADD_1, pg);

        Adds(tmpKey0, tmpKey0, PhiloxInternal::CONST_KEY_ADD_0, pg);
        Adds(tmpKey1, tmpKey1, PhiloxInternal::CONST_KEY_ADD_1, pg);

        Adds(tmpKey0, tmpKey0, PhiloxInternal::CONST_KEY_ADD_0, pg);
        Adds(tmpKey1, tmpKey1, PhiloxInternal::CONST_KEY_ADD_1, pg);

        Adds(tmpKey0, tmpKey0, PhiloxInternal::CONST_KEY_ADD_0, pg);
        Adds(tmpKey1, tmpKey1, PhiloxInternal::CONST_KEY_ADD_1, pg);

        Reg::RegTensor<uint32_t> tmpL0, tmpH0, tmpL1, tmpH1;

        // round 5
        Mull(tmpL0, tmpH0, tmpCtr0, cMul0, pg);
        Mull(tmpL1, tmpH1, tmpCtr2, cMul1, pg);
        Xor(tmpH1, tmpH1, tmpCtr1, pg);
        Xor(tmpCtr0, tmpH1, tmpKey0, pg);
        Xor(tmpH0, tmpH0, tmpCtr3, pg);
        Xor(tmpCtr2, tmpH0, tmpKey1, pg);
        tmpCtr1 = tmpL1;
        tmpCtr3 = tmpL0;
        Adds(tmpKey0, tmpKey0, PhiloxInternal::CONST_KEY_ADD_0, pg);
        Adds(tmpKey1, tmpKey1, PhiloxInternal::CONST_KEY_ADD_1, pg);

        // round 6
        Mull(tmpL0, tmpH0, tmpCtr0, cMul0, pg);
        Mull(tmpL1, tmpH1, tmpCtr2, cMul1, pg);
        Xor(tmpH1, tmpH1, tmpCtr1, pg);
        Xor(tmpCtr0, tmpH1, tmpKey0, pg);
        Xor(tmpH0, tmpH0, tmpCtr3, pg);
        Xor(tmpCtr2, tmpH0, tmpKey1, pg);
        tmpCtr1 = tmpL1;
        tmpCtr3 = tmpL0;
        Adds(tmpKey0, tmpKey0, PhiloxInternal::CONST_KEY_ADD_0, pg);
        Adds(tmpKey1, tmpKey1, PhiloxInternal::CONST_KEY_ADD_1, pg);

        // round 7
        Mull(tmpL0, tmpH0, tmpCtr0, cMul0, pg);
        Mull(tmpL1, tmpH1, tmpCtr2, cMul1, pg);
        Xor(tmpH1, tmpH1, tmpCtr1, pg);
        Xor(tmpCtr0, tmpH1, tmpKey0, pg);
        Xor(tmpH0, tmpH0, tmpCtr3, pg);
        Xor(tmpCtr2, tmpH0, tmpKey1, pg);
        tmpCtr1 = tmpL1;
        tmpCtr3 = tmpL0;

        if constexpr (std::is_same_v<T, float>) {
            UInt2Float(tmpCtr0, tmpCtr1, tmpCtr2, tmpCtr3, pg);
        }

        // word k of lane i goes to dstUb[k + 4 * i + block offset]; the block was fully loaded above,
        // so writing the final layout over the scratch data of the same block is safe
        Reg::Scatter(dstUb, tmpCtr0, reorderIndex, pg);
        Reg::Scatter(dstUb + 1, tmpCtr1, reorderIndex, pg);
        Reg::Scatter(dstUb + 2, tmpCtr2, reorderIndex, pg);
        Reg::Scatter(dstUb + 3, tmpCtr3, reorderIndex, pg);
        // move the scatter offsets to the next block
        Adds(reorderIndex, reorderIndex, PhiloxInternal::PHILOX_ONCE_REPEAT_NUM, pg);
    }
}

// Dispatches the main-block generation to the manually unrolled kernel for the requested round count.
//   Rounds     : 10 -> 4 + 4 + 2 unrolled kernel, 7 -> 4 + 3 unrolled kernel
//   T          : output element type, forwarded to the kernel
//   DstUnalign : forwarded to the kernel (unaligned spill/reload of intermediate state)
// All runtime arguments are forwarded unchanged; the counters ctr0..ctr3 are advanced by the kernel.
// Only 7 and 10 rounds have an unrolled kernel. Any other value used to fall through to the 7-round
// kernel silently, which disagrees with the generic SpNetworkFull<Rounds> path used for tails, so it
// is rejected at compile time instead.
template <uint16_t Rounds = 7, typename T, bool DstUnalign = false>
__simd_callee__ inline void PhiloxRoundMainBlockUnroll(__ubuf__ uint32_t *dstUb, uint16_t mainIter,
    Reg::RegTensor<uint32_t> &ctr0, Reg::RegTensor<uint32_t> &ctr1, Reg::RegTensor<uint32_t> &ctr2,
    Reg::RegTensor<uint32_t> &ctr3, Reg::RegTensor<uint32_t> &key0, Reg::RegTensor<uint32_t> &key1,
    Reg::RegTensor<uint32_t> &cMul0, Reg::RegTensor<uint32_t> &cMul1,
    Reg::RegTensor<uint32_t> &vEleStrideB32OneRow, Reg::MaskReg &pg)
{
    static_assert(Rounds == 7 || Rounds == 10,
        "PhiloxRoundMainBlockUnroll only provides unrolled kernels for 7 or 10 rounds.");
    if constexpr (Rounds == 10) {
        // main block with 4 + 4 + 2 unroll
        PhiloxRound10MainBlockUnroll442<T, DstUnalign>(dstUb, mainIter, ctr0, ctr1, ctr2, ctr3, key0, key1, cMul0,
            cMul1, vEleStrideB32OneRow, pg);
    } else {
        // main block with 4 + 3 unroll
        PhiloxRound7MainBlockUnroll43<T, DstUnalign>(dstUb, mainIter, ctr0, ctr1, ctr2, ctr3, key0, key1, cMul0, cMul1,
            vEleStrideB32OneRow, pg);
    }
}

// Builds the per-lane start counters: every lane receives the 128-bit base counter
// (philoxCounter[0] is the least significant word) plus its own offset taken from incIdx.
// incIdx is declared as int32_t but its bits are added as unsigned 32-bit values.
__simd_callee__ inline void PhiloxCounterInit(const PhiloxCounter &philoxCounter, Reg::RegTensor<uint32_t> &ctr0,
    Reg::RegTensor<uint32_t> &ctr1, Reg::RegTensor<uint32_t> &ctr2, Reg::RegTensor<uint32_t> &ctr3,
    Reg::RegTensor<int32_t> &incIdx, Reg::MaskReg &pg)
{
    // broadcast the base counter, one word per register
    Duplicate(ctr0, philoxCounter[0]);
    Duplicate(ctr1, philoxCounter[1]);
    Duplicate(ctr2, philoxCounter[2]);
    Duplicate(ctr3, philoxCounter[3]);

    // add the lane offsets with carry propagation across all four words
    Reg::RegTensor<uint32_t> &laneOffset = (Reg::RegTensor<uint32_t> &)incIdx;
    AddWith128Bits(ctr0, ctr1, ctr2, ctr3, laneOffset, pg);
}

// Vector-function entry for one contiguous run of random numbers.
//   Rounds     : number of Philox rounds
//   T          : output element type (float output is converted to [0, 1))
//   DstUnalign : forwarded to the main-block kernel only
//   dstUb      : destination of the mainIter full blocks (PHILOX_ONCE_REPEAT_NUM elements each)
//   dstUbTail  : destination of the remaining tailCount elements
//   philox     : key and base counter; lane i starts at counter + i
template <uint16_t Rounds = 7, typename T, bool DstUnalign = false>
__simd_vf__ inline void PhiloxRandomOneRow(__ubuf__ uint32_t *dstUb, __ubuf__ uint32_t *dstUbTail,
    const philoxStruct philox, uint16_t mainIter, uint16_t tailCount)
{
    Reg::MaskReg pg = Reg::CreateMask<uint32_t>();

    // per-lane counters: base counter + lane index
    Reg::RegTensor<uint32_t> ctr3, ctr2, ctr1, ctr0;
    Reg::RegTensor<int32_t> incIdx;
    Reg::Arange(incIdx, 0);
    PhiloxCounterInit(philox.philoxCounter, ctr0, ctr1, ctr2, ctr3, incIdx, pg);

    // after each block every lane skips ahead by the number of lanes
    Reg::RegTensor<uint32_t> vEleStrideB32OneRow;
    Duplicate(vEleStrideB32OneRow, PhiloxInternal::ELE_CNT_B32_ONCE);

    Reg::RegTensor<uint32_t> key1, key0;
    Duplicate(key0, philox.philoxKey[0]);
    Duplicate(key1, philox.philoxKey[1]);

    Reg::RegTensor<uint32_t> cMul0, cMul1;
    Duplicate(cMul0, PhiloxInternal::CONST_MUL_0);
    Duplicate(cMul1, PhiloxInternal::CONST_MUL_1);

    // full blocks; the counters are left pointing at the block that follows the last full one
    PhiloxRoundMainBlockUnroll<Rounds, T, DstUnalign>(dstUb, mainIter, ctr0, ctr1, ctr2, ctr3, key0, key1, cMul0, cMul1,
        vEleStrideB32OneRow, pg);

    // NOTE(review): DstUnalign is not forwarded here, so the tail always takes the interleave +
    // StoreAlign path of SpNetworkFull — confirm that dstUbTail is suitably aligned for every caller.
    if (tailCount > 0) {
        SpNetworkFull<Rounds, T>(dstUbTail, tailCount, ctr0, ctr1, ctr2, ctr3, key0, key1, cMul0, cMul1, pg);
    }
}

/*
// derive index calculation
for (i : row)
  for (j : column)  // column < half one repeat
    index[i * column + j] = i * stride + j
==>
factor = ONE_REPEAT_LEN / column
i.i_extent = factor
i.o_extent = row / factor
i = i.o * factor + i.i  < row // contains an if constraint condition
for (i.o, 0, row / factor) {
  for (i.i, 0, factor) {
    if (i.o * factor + i.i  < row) {
      for (j, 0, column) {
        index[(i.o * factor + i.i) * column + j] = i.o * factor * stride + i.i * stride + j
      }
    }
  }
}
==>
j_fuse = i.i * column + j
j_fuse_extent = factor * column < ONE_REPEAT_LEN
i.i = j_fuse / column
j = j_fuse % column
for (i.o, 0, row / factor) {
  for (j_fuse, 0, factor * column) {
    if (i.o * factor * column + j_fuse < row * column) {
      index[i.o * factor * column + j_fuse] = i.o * factor * stride + j_fuse / column * stride + j_fuse % column
    }
  }
}
*/
// Writes the per-lane counter offsets for `fuseFactor` fused rows to indexUb.
// For fused row i, the values i * stride + 0, i * stride + 1, ... (elementNum of them) are appended
// to the table, where stride and elementNum are params.stride and params.column divided by
// PHILOX_ONCE_COUNTER_NUM (4), i.e. expressed in 4-word Philox blocks rather than B32 elements.
// The rows are stored back to back with unaligned stores (presumably elementNum elements per store),
// then the unalign register is flushed and a store->load barrier makes the table visible to the
// vector loads that follow.
//   indexUb    : destination table in UB
//   params     : only stride and column are read here
//   fuseFactor : number of rows fused into one vector register
__simd_callee__ inline void PhiloxRandomIndexCal(__ubuf__ int32_t *indexUb, const PhiloxRandomParams &params,
    const uint32_t fuseFactor)
{
    __ubuf__ int32_t *indexUbT = indexUb;
    Reg::MaskReg pg = Reg::CreateMask<uint32_t>();
    Reg::RegTensor<int32_t> index, incIdx;
    Reg::UnalignReg ureg;
    uint32_t stride = static_cast<uint32_t>(params.stride / PhiloxInternal::PHILOX_ONCE_COUNTER_NUM);
    uint32_t elementNum = params.column / PhiloxInternal::PHILOX_ONCE_COUNTER_NUM;
    // the lane ramp 0, 1, 2, ... is the same for every row, so it is built once outside the loop
    Reg::Arange(incIdx, 0);
    for (uint16_t i = 0; i < fuseFactor; i++) {
        // index = i * stride + lane
        Reg::Duplicate(index, i);
        Reg::Muls(index, index, stride, pg);
        Reg::Add(index, index, incIdx, pg);
        Reg::StoreUnAlign<int32_t, Reg::PostLiteral::POST_MODE_UPDATE>(indexUbT, index, ureg, elementNum);
    }
    // flush the data still held in the unalign register
    Reg::StoreUnAlignPost<int32_t, Reg::PostLiteral::POST_MODE_UPDATE>(indexUbT, ureg, 0);
    Reg::LocalMemBar<Reg::MemType::VEC_STORE, Reg::MemType::VEC_LOAD>();
}

// Multi-row Philox generation for the case where several rows are fused into one repeat.
//
// Template parameters:
//   Rounds           : forwarded to the round helpers (the entry point restricts it to 7 or 10).
//   T                : destination element type, forwarded to the round helpers.
//   DstBlockUnalign  : forwarded to PhiloxRoundMainBlockUnroll / SpNetworkFull as their third template argument.
//   DstRepeatUnalign : when true, the main part is written one SpNetworkFull call at a time instead of
//                      through PhiloxRoundMainBlockUnroll.
// Parameters:
//   dstUbStart   : start of the destination, addressed as uint32_t.
//   indexUb      : scratch buffer that PhiloxRandomIndexCal fills with per-lane counter offsets.
//   philox       : key and counter words.
//   params       : row / column / stride description; stride and column are used here.
//   fuseFactor   : number of rows fused into one step.
//   mainFuseAxis : number of destination elements written per main step.
//   mainRowsNum  : number of main steps.
//   tailFuseAxis : number of destination elements written by the trailing partial step (0 = none).
template <uint16_t Rounds = 7, typename T, bool DstBlockUnalign, bool DstRepeatUnalign>
__simd_vf__ inline void PhiloxRandomMultiRowWithFuse(__ubuf__ uint32_t *dstUbStart, __ubuf__ int32_t *indexUb,
    const philoxStruct philox, const PhiloxRandomParams params, const uint32_t fuseFactor, const uint32_t mainFuseAxis,
    const uint32_t mainRowsNum, const uint32_t tailFuseAxis)
{
    // Build the offset table in indexUb (the helper ends with a store->load barrier).
    PhiloxRandomIndexCal(indexUb, params, fuseFactor);
    Reg::RegTensor<uint32_t> ctr3, ctr2, ctr1, ctr0;

    Reg::MaskReg pg = Reg::CreateMask<uint32_t>();
    // Load one vector of offsets from the table and use it to initialise the four counter registers.
    Reg::RegTensor<int32_t> incIdx;
    Reg::LoadAlign(incIdx, indexUb);
    PhiloxCounterInit(philox.philoxCounter, ctr0, ctr1, ctr2, ctr3, incIdx, pg);

    // Broadcast the two key words.
    Reg::RegTensor<uint32_t> key1, key0;
    Duplicate(key0, philox.philoxKey[0]);
    Duplicate(key1, philox.philoxKey[1]);

    // Broadcast the two Philox multiplier constants.
    Reg::RegTensor<uint32_t> cMul0, cMul1;
    Duplicate(cMul0, PhiloxInternal::CONST_MUL_0);
    Duplicate(cMul1, PhiloxInternal::CONST_MUL_1);

    // Counter increment applied between main steps: fuseFactor rows, each params.stride /
    // PHILOX_ONCE_COUNTER_NUM counters apart.
    Reg::RegTensor<uint32_t> vEleStrideB32OneRow;
    Duplicate(vEleStrideB32OneRow, fuseFactor * params.stride / PhiloxInternal::PHILOX_ONCE_COUNTER_NUM);

    if constexpr (!DstRepeatUnalign) {
        PhiloxRoundMainBlockUnroll<Rounds, T, DstBlockUnalign>(dstUbStart, mainRowsNum, ctr0, ctr1, ctr2, ctr3, key0,
            key1, cMul0, cMul1, vEleStrideB32OneRow, pg);
    } else {
        // When the destination is not repeat-aligned, the MainBlockUnroll algorithm may produce overlapping
        // destination writes. SpNetworkFull can control the number of elements written, so use it per step.
        for (uint16_t i = 0; i < mainRowsNum; i++) {
            __ubuf__ uint32_t *dstUb = dstUbStart + i * mainFuseAxis;
            SpNetworkFull<Rounds, T, DstBlockUnalign>(dstUb, mainFuseAxis, ctr0, ctr1, ctr2, ctr3, key0, key1, cMul0,
                cMul1, pg);
            // Advance the counters by one fused step before the next iteration.
            AddWith128Bits(ctr0, ctr1, ctr2, ctr3, vEleStrideB32OneRow, pg);
        }
    }

    // Trailing partial step: written right after the mainRowsNum * mainFuseAxis main elements.
    // NOTE(review): this relies on the counters already pointing at the tail rows after the main part
    // in both branches above -- confirm for PhiloxRoundMainBlockUnroll, which is not visible here.
    if (tailFuseAxis > 0) {
        __ubuf__ uint32_t *dstUbTail = dstUbStart + mainRowsNum * mainFuseAxis;
        SpNetworkFull<Rounds, T, DstBlockUnalign>(dstUbTail, tailFuseAxis, ctr0, ctr1, ctr2, ctr3, key0, key1, cMul0,
            cMul1, pg);
    }
}

// Multi-row Philox generation for the case where rows are not fused; each row is processed on its own.
//
// Template parameters:
//   Rounds     : forwarded to the round helpers (the entry point restricts it to 7 or 10).
//   T          : destination element type, forwarded to the round helpers.
//   DstUnalign : forwarded to PhiloxRoundMainBlockUnroll / SpNetworkFull as their third template argument.
// Parameters:
//   dstUbStart          : start of the destination, addressed as uint32_t; row i starts at i * params.column.
//   philox              : key and counter words.
//   params              : row / column / stride description; row and column are used here.
//   strideCounterOneRow : counter offset between the starts of two consecutive rows.
//   mainIter            : number of full repeats per row.
//   tailCount           : number of elements in the trailing partial repeat of each row.
//   hasTail             : trip count of the tail loop; the visible caller passes 1 when tailCount > 0, else 0.
template <uint16_t Rounds = 7, typename T, bool DstUnalign>
__simd_vf__ inline void PhiloxRandomMultiRowNoFuse(__ubuf__ uint32_t *dstUbStart, const philoxStruct philox,
    const PhiloxRandomParams params, uint32_t strideCounterOneRow, uint16_t mainIter, uint16_t tailCount,
    uint16_t hasTail)
{
    // Counter increment handed to PhiloxRoundMainBlockUnroll: ELE_CNT_B32_ONCE, i.e. the number of B32
    // lanes in one vector.
    Reg::RegTensor<uint32_t> vEleStrideB32OneRow;
    Duplicate(vEleStrideB32OneRow, PhiloxInternal::ELE_CNT_B32_ONCE);

    // Broadcast the two key words.
    Reg::RegTensor<uint32_t> key1, key0;
    Duplicate(key0, philox.philoxKey[0]);
    Duplicate(key1, philox.philoxKey[1]);

    // Broadcast the two Philox multiplier constants.
    Reg::RegTensor<uint32_t> cMul0, cMul1;
    Duplicate(cMul0, PhiloxInternal::CONST_MUL_0);
    Duplicate(cMul1, PhiloxInternal::CONST_MUL_1);

    Reg::RegTensor<uint32_t> ctr3, ctr2, ctr1, ctr0;
    Reg::MaskReg pg = Reg::CreateMask<uint32_t>();

    for (uint16_t i = 0; i < params.row; i++) {
        // Destination of row i and of its trailing partial repeat.
        __ubuf__ uint32_t *dstUb = dstUbStart + i * params.column;
        __ubuf__ uint32_t *dstUbTail = dstUb + mainIter * PhiloxInternal::PHILOX_ONCE_REPEAT_NUM;

        // Re-initialise the counters for every row, starting the lane ramp at i * strideCounterOneRow.
        Reg::RegTensor<int32_t> incIdx;
        Reg::Arange(incIdx, i * strideCounterOneRow);
        PhiloxCounterInit(philox.philoxCounter, ctr0, ctr1, ctr2, ctr3, incIdx, pg);
        PhiloxRoundMainBlockUnroll<Rounds, T, DstUnalign>(dstUb, mainIter, ctr0, ctr1, ctr2, ctr3, key0, key1, cMul0,
            cMul1, vEleStrideB32OneRow, pg);
        // The tail is expressed as a loop with trip count hasTail rather than an `if`
        // (presumably to fit vector-function control-flow constraints -- confirm).
        for (uint16_t j = 0; j < hasTail; j++) {
            SpNetworkFull<Rounds, T, DstUnalign>(dstUbTail, tailCount, ctr0, ctr1, ctr2, ctr3, key0, key1, cMul0, cMul1,
                pg);
        }
    }
}

// Contiguous PhiloxRandom entry: fills `count` B32 elements of dstLocal.
//
// The element range is split into full repeats of PHILOX_ONCE_REPEAT_NUM elements plus a remainder,
// the key / counter words are copied into a philoxStruct, and the work is handed to PhiloxRandomOneRow.
//
// dstLocal      : destination tensor; its physical address is reinterpreted as uint32_t.
// philoxKey     : key words; when it compares equal to nullptr no key word is copied.
// philoxCounter : counter words; when it compares equal to nullptr no counter word is copied.
// count         : number of elements to generate.
template <uint16_t Rounds = 7, typename T>
__aicore__ inline void PhiloxRandomImpl(const LocalTensor<T> &dstLocal, const PhiloxKey &philoxKey,
    const PhiloxCounter &philoxCounter, uint16_t count)
{
    static_assert(SupportType<T, int32_t, uint32_t, float>(),
        "PhiloxRandom API only support int32_t/uint32_t/float type");
    static_assert(Rounds == 7 || Rounds == 10, "PhiloxRandom API only support 7 or 10 Rounds ");

    // Split count into full repeats and the leftover element count.
    uint16_t fullRepeats = count / PhiloxInternal::PHILOX_ONCE_REPEAT_NUM;
    const auto fullElems = fullRepeats * PhiloxInternal::PHILOX_ONCE_REPEAT_NUM;
    uint16_t leftover = count - fullElems;

    // Head of the destination and the position where the leftover elements start.
    __ubuf__ uint32_t *dstHead = (__ubuf__ uint32_t *)dstLocal.GetPhyAddr();
    __ubuf__ uint32_t *dstRest = dstHead + fullElems;

    // Number of words to copy from each input; zero when the input is null.
    uint16_t keyWords = PHILOX_KEY_SIZE;
    if (philoxKey == nullptr) {
        keyWords = 0;
    }
    uint16_t counterWords = PHILOX_COUNTER_SIZE;
    if (philoxCounter == nullptr) {
        counterWords = 0;
    }

    // Words that are not copied keep whatever philoxStruct's default construction leaves in them.
    philoxStruct seed;
    for (uint16_t k = 0; k < keyWords; ++k) {
        seed.philoxKey[k] = philoxKey[k];
    }
    for (uint16_t k = 0; k < counterWords; ++k) {
        seed.philoxCounter[k] = philoxCounter[k];
    }

    PhiloxRandomOneRow<Rounds, T, false>(dstHead, dstRest, seed, fullRepeats, leftover);
}

// Strided PhiloxRandom entry: generates a params.row x params.column region described by `params`.
//
// Dispatch:
//   1. one row, or stride == column      -> delegate to the contiguous (count) overload;
//   2. column >  PHILOX_ONCE_REPEAT_NUM  -> row-by-row generation (PhiloxRandomMultiRowNoFuse);
//   3. column <= PHILOX_ONCE_REPEAT_NUM  -> several rows fused per repeat (PhiloxRandomMultiRowWithFuse),
//                                           using a stack buffer for the index table.
// Within 2 and 3 the template flags are chosen from the destination alignment predicates computed below.
//
// dstLocal      : destination tensor; its physical address is reinterpreted as uint32_t.
// philoxKey     : key words; when it compares equal to nullptr no key word is copied.
// philoxCounter : counter words; when it compares equal to nullptr no counter word is copied.
// params        : stride / row / column; checked by the assertions at the top of the body.
template <uint16_t Rounds = 7, typename T>
__aicore__ inline void PhiloxRandomImpl(const LocalTensor<T> &dstLocal, const PhiloxKey &philoxKey,
    const PhiloxCounter &philoxCounter, const PhiloxRandomParams &params)
{
    static_assert(SupportType<T, int32_t, uint32_t, float>(),
        "PhiloxRandom API only support int32_t/uint32_t/float type");
    static_assert(Rounds == 7 || Rounds == 10, "PhiloxRandom API only support 7 or 10 Rounds ");

    // Parameter contract: stride and column are multiples of PHILOX_ONCE_COUNTER_NUM, stride covers a
    // whole row, and the region is non-empty.
    ASCENDC_ASSERT((params.stride % PhiloxInternal::PHILOX_ONCE_COUNTER_NUM == 0),
                   { KERNEL_LOG(KERNEL_ERROR, "params.stride % 4 = 0!"); });
    ASCENDC_ASSERT((params.column % PhiloxInternal::PHILOX_ONCE_COUNTER_NUM == 0),
                   { KERNEL_LOG(KERNEL_ERROR, "params.column % 4 = 0!"); });
    ASCENDC_ASSERT((params.stride >= params.column), { KERNEL_LOG(KERNEL_ERROR, "params.stride >= params.column!"); });
    ASCENDC_ASSERT((params.row > 0 && params.column > 0),
                   { KERNEL_LOG(KERNEL_ERROR, "params.row > 0 && params.column > 0!"); });

    __ubuf__ uint32_t *dstBase = (__ubuf__ uint32_t *)dstLocal.GetPhyAddr();

    // Number of words to copy from each input; zero when the input is null.
    uint16_t keyWords = PHILOX_KEY_SIZE;
    if (philoxKey == nullptr) {
        keyWords = 0;
    }
    uint16_t counterWords = PHILOX_COUNTER_SIZE;
    if (philoxCounter == nullptr) {
        counterWords = 0;
    }

    // Words that are not copied keep whatever philoxStruct's default construction leaves in them.
    philoxStruct seed;
    for (uint16_t k = 0; k < keyWords; ++k) {
        seed.philoxKey[k] = philoxKey[k];
    }
    for (uint16_t k = 0; k < counterWords; ++k) {
        seed.philoxCounter[k] = philoxCounter[k];
    }

    // Case 1: a single row, or rows laid out back to back -> one contiguous run of row * column elements.
    if (params.row == 1 || params.stride == params.column) {
        PhiloxRandomImpl<Rounds, T>(dstLocal, philoxKey, philoxCounter, params.row * params.column);
        return;
    }

    // Case 2: a row is longer than one repeat -> no axis fusion, rows are generated one by one.
    if (params.column > PhiloxInternal::PHILOX_ONCE_REPEAT_NUM) {
        const uint32_t counterStridePerRow = params.stride / PhiloxInternal::PHILOX_ONCE_COUNTER_NUM;
        const uint16_t repeatsPerRow = params.column / PhiloxInternal::PHILOX_ONCE_REPEAT_NUM;
        const uint16_t rowRemainder = params.column - repeatsPerRow * PhiloxInternal::PHILOX_ONCE_REPEAT_NUM;
        const uint16_t tailFlag = (rowRemainder > 0) ? 1 : 0;
        // Alignment test on column / PHILOX_ONCE_COUNTER_NUM elements, measured in bytes against one block.
        const bool rowAligned =
            (params.column / PhiloxInternal::PHILOX_ONCE_COUNTER_NUM * sizeof(uint32_t) % ONE_BLOCK_SIZE == 0);
        if (rowAligned) {
            // aligned pattern: PhiloxRoundMainBlockUnroll aligned condition
            PhiloxRandomMultiRowNoFuse<Rounds, T, false>(dstBase, seed, params,
                counterStridePerRow, repeatsPerRow, rowRemainder, tailFlag);
        } else {
            // unaligned pattern: each step processes column / PHILOX_ONCE_COUNTER_NUM elements
            PhiloxRandomMultiRowNoFuse<Rounds, T, true>(dstBase, seed, params,
                counterStridePerRow, repeatsPerRow, rowRemainder, tailFlag);
        }
        return;
    }

    // Case 3: a row fits into one repeat -> pack rowsPerStep rows into every step (rowsPerStep >= 1).
    const uint32_t rowsPerStep = PhiloxInternal::PHILOX_ONCE_REPEAT_NUM / params.column;
    const uint32_t elemsPerStep = params.column * rowsPerStep;
    const uint32_t fullSteps = params.row / rowsPerStep;
    const uint32_t tailElems = params.row * params.column - elemsPerStep * fullSteps;

    // Scratch space for the index table comes from the stack buffer.
    LocalTensor<T> scratch;
    PopStackBuffer<T, TPosition::LCM>(scratch);
    __ubuf__ int32_t *indexTable = (__ubuf__ int32_t *)scratch.GetPhyAddr();

    const bool fillsWholeRepeat = (elemsPerStep == PhiloxInternal::PHILOX_ONCE_REPEAT_NUM);
    const bool stepBlockAligned = (elemsPerStep * sizeof(uint32_t) % ONE_BLOCK_SIZE == 0);
    if (fillsWholeRepeat) {
        // destination block-aligned and repeat-aligned
        PhiloxRandomMultiRowWithFuse<Rounds, T, false, false>(dstBase, indexTable, seed,
            params, rowsPerStep, elemsPerStep, fullSteps, tailElems);
    } else if (stepBlockAligned) {
        // destination block-aligned but not repeat-aligned: SpNetworkFull aligned condition
        PhiloxRandomMultiRowWithFuse<Rounds, T, false, true>(dstBase, indexTable, seed,
            params, rowsPerStep, elemsPerStep, fullSteps, tailElems);
    } else {
        // destination neither block-aligned nor repeat-aligned
        PhiloxRandomMultiRowWithFuse<Rounds, T, true, true>(dstBase, indexTable, seed,
            params, rowsPerStep, elemsPerStep, fullSteps, tailElems);
    }
}
} // namespace AscendC
#endif // IMPL_MATH_PHILOX_PHILOX_C310_IMPL_H

#if defined(__UNDEF_ASCENDC_INCLUDE_INTERNAL_HEADERS_MATH_PHILOX_PHILOX_C310_IMPL_H__)
#undef __ASCENDC_INCLUDE_INTERNAL_HEADERS__
#undef __UNDEF_ASCENDC_INCLUDE_INTERNAL_HEADERS_MATH_PHILOX_PHILOX_C310_IMPL_H__
#endif