* Copyright (c) 2025-2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file uniform.h
* \brief Uniform random number generator implementation
*/
#ifndef TILEOP_TILE_OPERATOR_UNIFORM__H
#define TILEOP_TILE_OPERATOR_UNIFORM__H
#include "pto_tile.h"
#include "utils/layout.h"
#include "utils/tile_tensor.h"
#if defined(PTO_NPU_ARCH_A5) || defined(__CPU_SIM)
#define OP_TILE_OP_UNIFORM TUniform
template <typename TDst, typename TTmp>
TILEOP void TUniform(TDst dst, TTmp tmpbuf, uint64_t key, uint64_t counter0, uint64_t counter1, uint16_t rounds) {
using ShapeValueType = typename Std::tuple_element<0, typename TDst::Shape>::type;
constexpr auto shapeSize = Std::tuple_size<typename TDst::Shape>::value;
constexpr int Size = Std::tuple_element<shapeSize - 1, typename TDst::TileShape>::type::value;
constexpr int tileW = (Size + 7) / 8 * 8;
constexpr size_t ALIGN_SIZE = 32;
uint64_t tileCounter[2] = {counter0, counter1};
uint64_t tmpbufAddr = tmpbuf.GetAddr();
__ubuf__ uint32_t* uint32Buffer = reinterpret_cast<__ubuf__ uint32_t*>(tmpbufAddr);
using TileUint32 = pto::Tile<pto::TileType::Vec, uint32_t, 1, tileW, pto::BLayout::RowMajor, -1, -1>;
TileUint32 uint32Tile(1, Size);
TileUint32 dstUint32Tile(1, Size);
pto::TASSIGN(uint32Tile, (uint64_t)uint32Buffer);
pto::TRandomKey uniformKey = {static_cast<uint32_t>(key & 0xFFFFFFFF),
static_cast<uint32_t>(key >> 32)};
pto::TRandomCounter uniformCounter = {static_cast<uint32_t>(tileCounter[0] & 0xFFFFFFFF),
static_cast<uint32_t>(tileCounter[0] >> 32),
static_cast<uint32_t>(tileCounter[1] & 0xFFFFFFFF),
static_cast<uint32_t>(tileCounter[1] >> 32)};
if (rounds == 7) {
pto::TRANDOM<7>(uint32Tile, uniformKey, uniformCounter);
} else {
pto::TRANDOM<10>(uint32Tile, uniformKey, uniformCounter);
}
using DstType = typename TDst::Type;
constexpr bool isFloat = std::is_same_v<DstType, float>;
constexpr bool isHalf = std::is_same_v<DstType, half>;
constexpr bool isBfloat16 = std::is_same_v<DstType, bfloat16_t>;
if constexpr (isFloat) {
pto::TASSIGN(dstUint32Tile, (uint64_t)dst.GetAddr());
pto::TANDS(dstUint32Tile, uint32Tile, 0x7fffff);
pto::TORS(uint32Tile, dstUint32Tile, 0x3f800000);
using TileFloat = pto::Tile<pto::TileType::Vec, float, 1, tileW, pto::BLayout::RowMajor, -1, -1>;
TileFloat floatTile(1, Size);
pto::TASSIGN(floatTile, (uint64_t)uint32Buffer);
using TileDst = pto::Tile<pto::TileType::Vec, DstType, 1, tileW, pto::BLayout::RowMajor, -1, -1>;
TileDst dstTile(1, Size);
pto::TASSIGN(dstTile, (uint64_t)(dst.GetAddr()));
pto::TSUBS(dstTile, floatTile, 1.0f);
} else if constexpr (isHalf || isBfloat16) {
constexpr int64_t uint32BufferBytes = ((Size * sizeof(uint32_t) + ALIGN_SIZE - 1) / ALIGN_SIZE) * ALIGN_SIZE;
__ubuf__ uint32_t* uint32BufferLow = reinterpret_cast<__ubuf__ uint32_t*>(tmpbufAddr + uint32BufferBytes);
TileUint32 uint32TileLow(1, Size);
pto::TASSIGN(uint32TileLow, (uint64_t)uint32BufferLow);
pto::TANDS(uint32TileLow, uint32Tile, 0xFFFF);
__ubuf__ uint16_t* uint16Buffer = reinterpret_cast<__ubuf__ uint16_t*>(tmpbufAddr);
using TileUint16 = pto::Tile<pto::TileType::Vec, uint16_t, 1, tileW, pto::BLayout::RowMajor, -1, -1>;
TileUint16 uint16Tile(1, Size);
TileUint16 dstUint16Tile(1, Size);
pto::TASSIGN(uint16Tile, (uint64_t)uint16Buffer);
pto::TASSIGN(dstUint16Tile, (uint64_t)dst.GetAddr());
pto::TCVT(uint16Tile, uint32TileLow, pto::RoundMode::CAST_NONE);
if constexpr (isHalf) {
pto::TANDS(dstUint16Tile, uint16Tile, 0x3ff);
pto::TORS(uint16Tile, dstUint16Tile, 0x3c00);
} else {
pto::TANDS(dstUint16Tile, uint16Tile, 0x7f);
pto::TORS(uint16Tile, dstUint16Tile, 0x3f80);
}
__ubuf__ DstType* resultBuffer = reinterpret_cast<__ubuf__ DstType*>(uint16Buffer);
using TileResult = pto::Tile<pto::TileType::Vec, DstType, 1, tileW, pto::BLayout::RowMajor, -1, -1>;
TileResult resultTile(1, Size);
pto::TASSIGN(resultTile, (uint64_t)resultBuffer);
using TileDst = pto::Tile<pto::TileType::Vec, DstType, 1, tileW, pto::BLayout::RowMajor, -1, -1>;
TileDst dstTile(1, Size);
pto::TASSIGN(dstTile, (uint64_t)(dst.GetAddr()));
if constexpr (isHalf) {
pto::TSUBS(dstTile, resultTile, static_cast<half>(1.0));
} else {
pto::TSUBS(dstTile, resultTile, static_cast<bfloat16_t>(1.0));
}
}
}
#endif
#endif