* Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "initializer.h"
#include <cfloat>
#include <cmath>
#include <iostream>
#include "initializer_kernel_ops.h"
#include "torch_npu/csrc/core/npu/NPUStream.h"
#include "torch_utils.h"
// Fallback error logger, defined only when no previously included header has
// already provided LOG_ERROR. It streams a severity tag followed by `msg` to
// std::cout and flushes the line with std::endl.
// The tag is "[ERROR]" so the output matches the macro's name; `msg` is left
// unparenthesized on purpose so callers may pass a `<<` chain.
#ifndef LOG_ERROR
#define LOG_ERROR(msg) std::cout << "[ERROR]" << msg << std::endl
#endif
namespace dyn_emb {
namespace {
/**
 * Converts an ATen scalar type into the DataType value used by the launch helpers.
 *
 * Pure forwarding wrapper around scalartype_to_datatype(); the mapping itself
 * is defined by that helper, not here.
 *
 * @param scalar_type Scalar type reported by a tensor.
 * @return Whatever scalartype_to_datatype() yields for `scalar_type`.
 */
DataType to_data_type(at::ScalarType scalar_type)
{
    const DataType mapped_type = scalartype_to_datatype(scalar_type);
    return mapped_type;
}
/**
 * Bundle of arguments shared by every launch_index_*_init call in this file.
 *
 * Field meanings below describe how make_index_init_launch_args() fills them.
 * Every member carries a default initializer so that a default-constructed
 * instance never exposes indeterminate values (in particular the raw pointers
 * and the stream handle).
 */
struct IndexInitLaunchArgs {
    DataType value_type{};    // element type of the buffer tensor
    DataType index_type{};    // element type of the indices tensor
    int64_t num{0};           // indices.size(0)
    int64_t dim{0};           // buffer.size(1)
    int64_t stride{0};        // buffer.stride(0)
    void* buffer{nullptr};    // buffer.data_ptr()
    void* indices{nullptr};   // indices.data_ptr()
    aclrtStream stream{};     // current NPU stream at the time the args were built
};
/**
 * Collects the launch arguments for an index-based initializer.
 *
 * Reads dimension 1 of `buffer` and dimension 0 of both tensors without
 * checking their rank; dispatch_index_init() validates `buffer` before
 * calling this.
 *
 * @param buffer  Tensor whose type, size(1), stride(0) and data pointer are captured.
 * @param indices Tensor whose type, size(0) and data pointer are captured.
 * @return Fully populated IndexInitLaunchArgs bound to the current NPU stream.
 */
IndexInitLaunchArgs make_index_init_launch_args(const at::Tensor& buffer, const at::Tensor& indices)
{
    // Initializers are listed in member declaration order; a braced list is
    // evaluated left to right, so the stream is still fetched last.
    return IndexInitLaunchArgs{
        to_data_type(buffer.scalar_type()),                // value_type
        to_data_type(indices.scalar_type()),               // index_type
        indices.size(0),                                   // num
        buffer.size(1),                                    // dim
        buffer.stride(0),                                  // stride
        buffer.data_ptr(),                                 // buffer
        indices.data_ptr(),                                // indices
        c10_npu::getCurrentNPUStream().stream(true),       // stream
    };
}
/**
 * Validates the tensors of an index-based initialization and, when they are
 * acceptable, hands the assembled launch arguments to `launch`.
 *
 * The request is dropped (an error is logged and `launch` is never called) when:
 *   - `indices` holds no elements,
 *   - `buffer` is not 2-dimensional, or
 *   - `buffer` has a stride other than 1 along dimension 1.
 *
 * @param buffer  Destination tensor.
 * @param indices Index tensor.
 * @param launch  Callable invoked with the IndexInitLaunchArgs built from the tensors.
 */
template <typename LaunchFn>
void dispatch_index_init(const at::Tensor& buffer, const at::Tensor& indices, LaunchFn launch)
{
    const bool has_indices = indices.numel() != 0;
    if (!has_indices) {
        LOG_ERROR("Indices tensor is empty, initialization failed.");
        return;
    }

    const bool buffer_is_2d = buffer.dim() == 2;
    if (!buffer_is_2d) {
        LOG_ERROR("Initializer input buffer must be 2D.");
        return;
    }

    // Dimension 1 is only queried once the buffer is known to have one.
    const bool unit_inner_stride = buffer.stride(1) == 1;
    if (!unit_inner_stride) {
        LOG_ERROR("Initializer input buffer must be contiguous at dim 1.");
        return;
    }

    auto launch_args = make_index_init_launch_args(buffer, indices);
    launch(launch_args);
}
/**
 * Checks that a (lower, upper) pair can be used as initialization bounds.
 *
 * NaN bounds are rejected explicitly: every ordered comparison involving NaN
 * evaluates to false, so the range test on its own would accept them.
 * Otherwise the pair is rejected only when `lower` exceeds `upper` by more
 * than FLT_EPSILON; equal bounds, and differences within that tolerance, are
 * accepted.
 *
 * @param lower Lower bound.
 * @param upper Upper bound.
 * @return true when the bounds are accepted, false after logging the reason.
 */
bool validate_init_bounds(float lower, float upper)
{
    if (std::isnan(lower) || std::isnan(upper)) {
        LOG_ERROR("Init bounds must not be NaN.");
        return false;
    }
    if (lower - upper > FLT_EPSILON) {
        LOG_ERROR("Lower bound must be less than upper bound.");
        return false;
    }
    return true;
}
}
/**
 * Sets up the context by asking alloc_curand_states() to fill `states_`,
 * passing it the NPU stream that is current at construction time.
 */
CurandStateContext::CurandStateContext()
{
    auto current_stream = c10_npu::getCurrentNPUStream().stream(true);

    alloc_curand_states(&states_, current_stream);
}
// Hands `states_` to free_curand_states() and then clears the member.
// NOTE(review): `states_` is passed without a null check; whether
// free_curand_states() tolerates a null pointer is not visible here - confirm.
CurandStateContext::~CurandStateContext()
{
free_curand_states(states_);
// Reset the member so no stale pointer value remains after the release call.
states_ = nullptr;
}
/**
 * Accessor for the state pointer held by this context.
 *
 * @return The current value of `states_`; the context keeps holding it.
 */
curandState* CurandStateContext::ptr()
{
    return this->states_;
}
/**
 * Requests a normal-distribution initialization through launch_index_normal_init().
 *
 * The tensors are validated by dispatch_index_init(); when validation fails the
 * error is logged there and nothing is launched.
 *
 * @param buffer               Destination tensor (must be 2-D with unit stride on dim 1).
 * @param indices              Index tensor (must not be empty).
 * @param curand_state_context Supplies the random-state pointer passed to the launcher.
 * @param mean                 Forwarded to the launcher as the mean.
 * @param std_dev              Forwarded to the launcher as the standard deviation.
 */
void normal_init(at::Tensor buffer, at::Tensor indices, CurandStateContext& curand_state_context, float mean,
                 float std_dev)
{
    const auto launch_normal = [&](const IndexInitLaunchArgs& launch_args) {
        launch_index_normal_init(launch_args.value_type, launch_args.index_type, launch_args.num, launch_args.dim,
                                 launch_args.stride, launch_args.buffer, launch_args.indices,
                                 curand_state_context.ptr(), mean, std_dev, launch_args.stream);
    };

    dispatch_index_init(buffer, indices, launch_normal);
}
/**
 * Requests a truncated-normal initialization through launch_index_truncated_normal_init().
 *
 * The bounds are checked with validate_init_bounds() first; if they are
 * rejected the function returns without touching the tensors. The tensors are
 * then validated by dispatch_index_init() before anything is launched.
 *
 * @param buffer               Destination tensor (must be 2-D with unit stride on dim 1).
 * @param indices              Index tensor (must not be empty).
 * @param curand_state_context Supplies the random-state pointer passed to the launcher.
 * @param mean                 Forwarded to the launcher as the mean.
 * @param std_dev              Forwarded to the launcher as the standard deviation.
 * @param lower                Forwarded to the launcher as the lower bound.
 * @param upper                Forwarded to the launcher as the upper bound.
 */
void truncated_normal_init(at::Tensor buffer, at::Tensor indices, CurandStateContext& curand_state_context, float mean,
                           float std_dev, float lower, float upper)
{
    const bool bounds_accepted = validate_init_bounds(lower, upper);
    if (bounds_accepted) {
        const auto launch_truncated = [&](const IndexInitLaunchArgs& launch_args) {
            launch_index_truncated_normal_init(launch_args.value_type, launch_args.index_type, launch_args.num,
                                               launch_args.dim, launch_args.stride, launch_args.buffer,
                                               launch_args.indices, curand_state_context.ptr(), mean, std_dev,
                                               lower, upper, launch_args.stream);
        };
        dispatch_index_init(buffer, indices, launch_truncated);
    }
}
/**
 * Requests a uniform-distribution initialization through launch_index_uniform_init().
 *
 * The bounds are checked with validate_init_bounds() first; if they are
 * rejected the function returns without touching the tensors. The tensors are
 * then validated by dispatch_index_init() before anything is launched.
 *
 * @param buffer               Destination tensor (must be 2-D with unit stride on dim 1).
 * @param indices              Index tensor (must not be empty).
 * @param curand_state_context Supplies the random-state pointer passed to the launcher.
 * @param lower                Forwarded to the launcher as the lower bound.
 * @param upper                Forwarded to the launcher as the upper bound.
 */
void uniform_init(at::Tensor buffer, at::Tensor indices, CurandStateContext& curand_state_context, float lower,
                  float upper)
{
    const bool bounds_accepted = validate_init_bounds(lower, upper);
    if (bounds_accepted) {
        const auto launch_uniform = [&](const IndexInitLaunchArgs& launch_args) {
            launch_index_uniform_init(launch_args.value_type, launch_args.index_type, launch_args.num,
                                      launch_args.dim, launch_args.stride, launch_args.buffer, launch_args.indices,
                                      curand_state_context.ptr(), lower, upper, launch_args.stream);
        };
        dispatch_index_init(buffer, indices, launch_uniform);
    }
}
/**
 * Requests a constant-value initialization through launch_index_const_init().
 *
 * The tensors are validated by dispatch_index_init(); when validation fails the
 * error is logged there and nothing is launched.
 *
 * @param buffer  Destination tensor (must be 2-D with unit stride on dim 1).
 * @param indices Index tensor (must not be empty).
 * @param value   Forwarded to the launcher as the constant.
 */
void const_init(at::Tensor buffer, at::Tensor indices, float value)
{
    const auto launch_const = [&](const IndexInitLaunchArgs& launch_args) {
        launch_index_const_init(launch_args.value_type, launch_args.index_type, launch_args.num, launch_args.dim,
                                launch_args.stride, launch_args.buffer, launch_args.indices, value,
                                launch_args.stream);
    };

    dispatch_index_init(buffer, indices, launch_const);
}
/**
 * Requests the debug initialization through launch_index_debug_init().
 *
 * The buffer and indices tensors are validated by dispatch_index_init(); when
 * validation fails the error is logged there and nothing is launched. `keys`
 * is only inspected (type and data pointer) once validation has passed.
 *
 * @param buffer  Destination tensor (must be 2-D with unit stride on dim 1).
 * @param indices Index tensor (must not be empty).
 * @param keys    Tensor whose element type and data pointer are forwarded to the launcher.
 */
void debug_init(at::Tensor buffer, at::Tensor indices, at::Tensor keys)
{
    const auto launch_debug = [&](const IndexInitLaunchArgs& launch_args) {
        // NOTE(review): the meaning of the literal 100000 is defined by
        // launch_index_debug_init and is not visible here - confirm.
        launch_index_debug_init(launch_args.value_type, launch_args.index_type, to_data_type(keys.scalar_type()),
                                launch_args.num, launch_args.dim, launch_args.stride, launch_args.buffer,
                                launch_args.indices, keys.data_ptr(), 100000, launch_args.stream);
    };

    dispatch_index_init(buffer, indices, launch_debug);
}
}