* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <ATen/Operators.h>
#include <torch/all.h>
#include <torch/library.h>
#include "acl/acl.h"
#include "torch_npu/csrc/core/npu/NPUStream.h"
#include "torch_npu/csrc/core/npu/DeviceUtils.h"
#include "torch_npu/csrc/framework/OpCommand.h"
namespace npu_ops_transformer_ext {
namespace RotaryStride {
#include <iostream>
#include <stdio.h>
#include "kernel_operator.h"
#include "tiling/platform/platform_ascendc.h"
#include "dtype_convert.h"
using namespace AscendC;
// Byte budget that rotary_stride::Init() starts from: the fixed sin/cos/psql buffers are
// subtracted from it and the remainder is divided among per-head tiles.
// (UB presumably stands for the on-chip Unified Buffer -- confirm against the target platform.)
constexpr int64_t UB_MAX_BYTES = 184*1024;
// Depth used for the input/output TQue declarations and for their InitBuffer() calls.
constexpr int64_t BUFFER_NUM = 1;
/**
 * Device-side rotary embedding applied to the trailing gbD elements of every stride-wide
 * head row of a tensor addressed as [gbB, gbS, gbN, stride].
 *
 * Every (batch, sequence) pair is one task; tasks are dealt round-robin to the cores
 * (task = iteration * blockNum + blockIdx). For a task the sin/cos row is chosen by
 * `sequence index + psql[batch]`, and each head is rotated as
 *     out_lo = in_lo * cos_lo - in_hi * sin_hi
 *     out_hi = in_hi * cos_hi + in_lo * sin_lo
 * where lo/hi are the first/second gbD/2 elements. The arithmetic runs in float and the
 * result is cast back to T. Elements in front of the trailing gbD ones are neither read
 * nor written by this class.
 *
 * @tparam T element type of in/sin/cos/out as stored in global memory.
 */
template<typename T>
class rotary_stride {
public:
    __aicore__ inline rotary_stride() {}

    /**
     * Stores the problem shape, derives the tile sizes and binds global/local buffers.
     *
     * @param in     input tensor base address.
     * @param sin    sine table base address, addressed as rows of gbD elements.
     * @param cos    cosine table base address, same addressing as sin.
     * @param psql   gbB int32 values; psql[b] is added to the sequence index to pick the sin/cos row.
     * @param out    output tensor base address, addressed with the same offsets as in.
     * @param gbB    batch count.
     * @param gbS    sequence length.
     * @param gbN    number of heads per (batch, sequence) position.
     * @param gbD    number of trailing elements of each head row that are rotated.
     * @param gbMAXS number of rows in the sin/cos tables (only used to size the global buffers).
     * @param stride distance, in elements, between consecutive head rows.
     */
    __aicore__ inline void Init(__gm__ void* in, __gm__ void* sin, __gm__ void* cos,
        __gm__ void* psql, __gm__ void* out,
        const int64_t gbB, const int64_t gbS, const int64_t gbN, const int64_t gbD, const int64_t gbMAXS,
        const int64_t stride)
    {
        blockNum_ = GetBlockNum();
        blockIdx_ = GetBlockIdx();
        gbB_ = gbB;
        gbS_ = gbS;
        gbN_ = gbN;
        gbD_ = gbD;
        gbMAXS_ = gbMAXS;
        stride_ = stride;
        // psql buffer length, padded so that it occupies a multiple of 256 bytes.
        gbAlignB_ = AlignUp(gbB_, 256/sizeof(int32_t));
        bkB_ = 1;
        bkS_ = 1;
        bkN_ = gbN_;
        bkD_ = gbD_;
        bkHalfD_ = bkD_ / 2;
        // Each half row is padded to a multiple of 256 bytes of T; a head row in local
        // memory is therefore 2 * bkAlignHalfD_ elements wide.
        bkAlignHalfD_ = AlignUp(bkHalfD_, 256/sizeof(T));
        // Iterations per core = ceil(gbB * gbS / blockNum).
        bkLoop_ = gbB_*gbS_ / blockNum_;
        if (gbB_*gbS_ % blockNum_ != 0){ bkLoop_ += 1; }
        // Fixed cost first: sin + cos in T, sin + cos in float, and the psql buffer.
        int64_t used_bytes = UB_MAX_BYTES - bkAlignHalfD_ * 2 * sizeof(T) * 2;
        used_bytes -= bkAlignHalfD_ * 2 * sizeof(float) * 2;
        used_bytes -= gbAlignB_ * sizeof(int32_t);
        // Per-head cost: in + out rows in T plus four float rows (in, out, sin product, cos product).
        // NOTE(review): judge_rotary_stride_launch() on the host repeats this computation and is
        // expected to reject shapes for which this value would not be at least 1.
        bkMaxN_ = used_bytes / ((bkAlignHalfD_ * 2 * sizeof(T) * 2) + (bkAlignHalfD_ * 2 * sizeof(float) * 4));
        // NOTE(review): the in/out sizes below use gbD_ while addressing uses stride_; confirm the
        // size passed to SetGlobalBuffer is advisory only.
        gmIn_.SetGlobalBuffer((__gm__ T*)(in), gbB_ * gbS_ * gbN_ * gbD_);
        gmSin_.SetGlobalBuffer((__gm__ T*)(sin), gbMAXS_ * gbD_);
        gmCos_.SetGlobalBuffer((__gm__ T*)(cos), gbMAXS_ * gbD_);
        gmPsql_.SetGlobalBuffer((__gm__ int32_t*)(psql), gbB_);
        gmOut_.SetGlobalBuffer((__gm__ T*)(out), gbB_ * gbS_ * gbN_ * gbD_);
        // The order of the InitBuffer calls is kept exactly as it was.
        pipe_.InitBuffer(inQueIn_, BUFFER_NUM, bkMaxN_ * bkAlignHalfD_ * 2 * sizeof(T));
        pipe_.InitBuffer(inQueSin_, BUFFER_NUM, bkAlignHalfD_ * 2 * sizeof(T));
        pipe_.InitBuffer(inQueCos_, BUFFER_NUM, bkAlignHalfD_ * 2 * sizeof(T));
        pipe_.InitBuffer(local_in_f_, bkMaxN_ * bkAlignHalfD_ * 2 * sizeof(float));
        pipe_.InitBuffer(local_out_f_, bkMaxN_ * bkAlignHalfD_ * 2 * sizeof(float));
        pipe_.InitBuffer(local_sin_f_, bkAlignHalfD_ * 2 * sizeof(float));
        pipe_.InitBuffer(local_cos_f_, bkAlignHalfD_ * 2 * sizeof(float));
        pipe_.InitBuffer(outQueOut_, BUFFER_NUM, bkMaxN_ * bkAlignHalfD_ * 2 * sizeof(T));
        pipe_.InitBuffer(inQuePsql_, 1, gbAlignB_ * sizeof(int32_t));
        pipe_.InitBuffer(bufSinMul_, bkMaxN_ * bkAlignHalfD_ * 2 * sizeof(float));
        pipe_.InitBuffer(bufCosMul_, bkMaxN_ * bkAlignHalfD_ * 2 * sizeof(float));
    }

    /**
     * Loads psql once, then walks this core's tasks and rotates their heads in chunks of at
     * most bkMaxN_ heads (CopyIn -> Compute -> CopyOut per chunk).
     */
    __aicore__ inline void Process()
    {
        lmSinMul_ = bufSinMul_.Get<float>();
        lmCosMul_ = bufCosMul_.Get<float>();
        LocalTensor<int32_t> local_psql = inQuePsql_.AllocTensor<int32_t>();
        // One burst of gbB int32 values.
        DataCopyExtParams copy_pas{1, (uint32_t)(gbB_*sizeof(int32_t)), 0, 0, 0};
        DataCopyPadExtParams<int32_t> pad_pas;
        DataCopyPad(local_psql, gmPsql_, copy_pas, pad_pas);
        inQuePsql_.EnQue(local_psql);
        lmPsql_ = inQuePsql_.DeQue<int32_t>();
        for (int64_t bkl = 0; bkl < bkLoop_; bkl++) {
            int64_t task_index = bkl * blockNum_ + blockIdx_;
            // task_index only grows with bkl, so the first out-of-range task ends this core's
            // work. Leaving here, before psql is read, keeps the batch index below gbB_; reading
            // first would touch a psql slot that was never loaded.
            if (task_index >= gbB_*gbS_) {
                break;
            }
            int64_t real_b = task_index / gbS_;
            int64_t real_s = task_index % gbS_;
            int32_t psql_value = lmPsql_.GetValue(real_b);
            // Tasks whose sin/cos row index would be negative are skipped.
            if ((real_s + psql_value) >= 0) {
                // First rotated element of head 0 for this (batch, sequence) position.
                int64_t offset_inout = (stride_ - gbD_) + (real_b*gbS_ + real_s) * (gbN_*stride_);
                int64_t offset_sincos = (real_s + psql_value) * gbD_;
                for (int64_t bkn = 0; bkn < bkN_; bkn += bkMaxN_) {
                    int64_t bkn_len = bkMaxN_;
                    if (bkn + bkMaxN_ > bkN_) {
                        bkn_len = bkN_ - bkn;
                    }
                    CopyIn(task_index, offset_inout, offset_sincos, bkn, bkn_len);
                    Compute(task_index, offset_inout, offset_sincos, bkn, bkn_len);
                    CopyOut(task_index, offset_inout, offset_sincos, bkn, bkn_len);
                }
            }
        }
        inQuePsql_.FreeTensor(local_psql);
    }

private:
    /**
     * Copies bkn_len heads (starting at head bkn_offset) plus one sin and one cos row into
     * local memory and enqueues them.
     *
     * Local layout of a head row: low half at element 0, high half at element bkAlignHalfD_,
     * row pitch 2 * bkAlignHalfD_ elements. The sin/cos rows use the same low/high placement.
     */
    __aicore__ inline void CopyIn(int64_t task_index, int64_t offset_inout, int64_t offset_sincos,
        int64_t bkn_offset, int64_t bkn_len)
    {
        LocalTensor<T> local_in = inQueIn_.AllocTensor<T>();
        LocalTensor<T> local_sin = inQueSin_.AllocTensor<T>();
        LocalTensor<T> local_cos = inQueCos_.AllocTensor<T>();
        DataCopyExtParams copy_pas;
        // One burst per head, each burst is half a rotated row.
        copy_pas.blockCount = (uint16_t)(bkn_len);
        copy_pas.blockLen = (uint32_t)(bkHalfD_ * sizeof(T));
        // Source gap in bytes from the end of one half row to the same half of the next head.
        copy_pas.srcStride = (uint32_t)((stride_ - bkHalfD_) * sizeof(T));
        // Destination gap in 32-byte units: local row pitch minus the blocks one burst occupies.
        copy_pas.dstStride = (uint32_t)(2 * (bkAlignHalfD_*sizeof(T)/32) -
                                        (AlignUp(bkHalfD_*sizeof(T), 32)/32));
        DataCopyPadExtParams<T> pad_pas;
        // Low halves, then high halves.
        DataCopyPad(local_in, gmIn_[offset_inout + bkn_offset * stride_], copy_pas, pad_pas);
        DataCopyPad(local_in[bkAlignHalfD_], gmIn_[offset_inout + bkn_offset * stride_ + bkHalfD_],
                    copy_pas, pad_pas);
        // sin/cos: two back-to-back bursts (low half, high half) from one contiguous row of gbD
        // elements, spread to the padded local positions.
        copy_pas.blockCount = 2;
        copy_pas.blockLen = (uint32_t)(bkHalfD_ * sizeof(T));
        copy_pas.srcStride = 0;
        copy_pas.dstStride = (uint32_t)((bkAlignHalfD_*sizeof(T)/32) - (AlignUp(bkHalfD_*sizeof(T), 32)/32));
        DataCopyPad(local_sin, gmSin_[offset_sincos], copy_pas, pad_pas);
        DataCopyPad(local_cos, gmCos_[offset_sincos], copy_pas, pad_pas);
        inQueIn_.EnQue(local_in);
        inQueSin_.EnQue(local_sin);
        inQueCos_.EnQue(local_cos);
    }

    /**
     * Rotates the bkn_len heads that CopyIn staged and enqueues the result for CopyOut.
     */
    __aicore__ inline void Compute(int64_t task_index, int64_t offset_inout, int64_t offset_sincos,
        int64_t bkn_offset, int64_t bkn_len)
    {
        LocalTensor<T> local_in = inQueIn_.DeQue<T>();
        LocalTensor<T> local_sin = inQueSin_.DeQue<T>();
        LocalTensor<T> local_cos = inQueCos_.DeQue<T>();
        LocalTensor<float> local_in_f = local_in_f_.Get<float>();
        LocalTensor<float> local_sin_f = local_sin_f_.Get<float>();
        LocalTensor<float> local_cos_f = local_cos_f_.Get<float>();
        // Widen to float; the casts run over whole padded rows, padding lanes included.
        Cast(local_in_f, local_in, RoundMode::CAST_NONE, bkn_len * bkAlignHalfD_ * 2);
        Cast(local_sin_f, local_sin, RoundMode::CAST_NONE, bkAlignHalfD_ * 2);
        Cast(local_cos_f, local_cos, RoundMode::CAST_NONE, bkAlignHalfD_ * 2);
        LocalTensor<T> local_out = outQueOut_.AllocTensor<T>();
        BinaryRepeatParams repeat_params;
        repeat_params.dstBlkStride = 1;
        repeat_params.src0BlkStride = 1;
        repeat_params.src1BlkStride = 1;
        // One repeat per head: dst/src0 advance by one float row (in 32-byte units) while
        // src1 stays put, so the single sin/cos row is reused for every head.
        repeat_params.dstRepStride = bkAlignHalfD_ * sizeof(float) * 2 / 32;
        repeat_params.src0RepStride = bkAlignHalfD_ * sizeof(float) * 2 / 32;
        repeat_params.src1RepStride = 0;
        // NOTE(review): the element mask is bkHalfD_ and the repeat count is narrowed to uint8_t;
        // presumably callers keep gbD/2 and bkn_len within the per-instruction limits -- confirm.
        uint64_t mask = bkHalfD_;
        // i == 0 handles the low halves, i == 1 the high halves.
        for (int i = 0; i < 2; i++) {
            Mul(lmSinMul_[i*bkAlignHalfD_], local_in_f[i*bkAlignHalfD_], local_sin_f[i*bkAlignHalfD_],
                mask, (uint8_t)bkn_len, repeat_params);
            Mul(lmCosMul_[i*bkAlignHalfD_], local_in_f[i*bkAlignHalfD_], local_cos_f[i*bkAlignHalfD_],
                mask, (uint8_t)bkn_len, repeat_params);
        }
        // Both operands are per-head from here on, so src1 advances by a row as well.
        repeat_params.dstRepStride = bkAlignHalfD_ * sizeof(float) * 2 / 32;
        repeat_params.src0RepStride = bkAlignHalfD_ * sizeof(float) * 2 / 32;
        repeat_params.src1RepStride = bkAlignHalfD_ * sizeof(float) * 2 / 32;
        LocalTensor<float>local_out_f = local_out_f_.Get<float>();
        // out_lo = in_lo * cos_lo - in_hi * sin_hi
        Sub(local_out_f, lmCosMul_, lmSinMul_[bkAlignHalfD_], mask, (uint8_t)bkn_len, repeat_params);
        // out_hi = in_hi * cos_hi + in_lo * sin_lo
        Add(local_out_f[bkAlignHalfD_], lmCosMul_[bkAlignHalfD_], lmSinMul_, mask,
            (uint8_t)bkn_len, repeat_params);
        Cast(local_out, local_out_f, RoundMode::CAST_ROUND, bkn_len * bkAlignHalfD_ * 2);
        outQueOut_.EnQue(local_out);
        inQueIn_.FreeTensor(local_in);
        inQueSin_.FreeTensor(local_sin);
        inQueCos_.FreeTensor(local_cos);
        local_in_f_.FreeTensor(local_in_f);
        local_sin_f_.FreeTensor(local_sin_f);
        local_cos_f_.FreeTensor(local_cos_f);
        local_out_f_.FreeTensor(local_out_f);
    }

    /**
     * Writes the rotated heads back to global memory using the mirror image of the CopyIn
     * addressing (local padded rows -> strided global rows).
     */
    __aicore__ inline void CopyOut(int64_t task_index, int64_t offset_inout, int64_t offset_sincos,
        int64_t bkn_offset, int64_t bkn_len)
    {
        LocalTensor<T> local_out = outQueOut_.DeQue<T>();
        DataCopyExtParams copy_pas;
        copy_pas.blockCount = (uint16_t)(bkn_len);
        copy_pas.blockLen = (uint32_t)(bkHalfD_ * sizeof(T));
        // Local gap in 32-byte units, global gap in bytes.
        copy_pas.srcStride = (uint32_t)(2*(bkAlignHalfD_*sizeof(T)/32) -
                                        (AlignUp(bkHalfD_*sizeof(T), 32)/32));
        copy_pas.dstStride = (uint32_t)((stride_ - bkHalfD_) * sizeof(T));
        DataCopyPad(gmOut_[offset_inout + bkn_offset * stride_], local_out, copy_pas);
        DataCopyPad(gmOut_[offset_inout + bkn_offset * stride_ + bkHalfD_], local_out[bkAlignHalfD_], copy_pas);
        outQueOut_.FreeTensor(local_out);
    }

private:
    TPipe pipe_;
    // Staging queues for T-typed input rows, the psql vector and T-typed output rows.
    TQue<QuePosition::VECIN, BUFFER_NUM> inQueIn_, inQueSin_, inQueCos_;
    TQue<QuePosition::VECIN, 1> inQuePsql_;
    TQue<QuePosition::VECOUT, BUFFER_NUM> outQueOut_;
    // Float scratch: in*sin and in*cos products, then widened inputs and the float result.
    TBuf<QuePosition::VECCALC> bufSinMul_, bufCosMul_;
    TBuf<QuePosition::VECCALC> local_in_f_, local_sin_f_, local_cos_f_, local_out_f_;
    GlobalTensor<T> gmIn_, gmSin_, gmCos_, gmOut_;
    GlobalTensor<int32_t> gmPsql_;
    LocalTensor<int32_t> lmPsql_;
    LocalTensor<float> lmSinMul_, lmCosMul_;
    int64_t blockNum_, blockIdx_;
    // Problem shape as passed to Init().
    int64_t gbB_, gbS_, gbN_, gbD_, gbMAXS_;
    int64_t gbAlignB_;
    // Per-task tile shape; bkMaxN_ is the number of heads processed per chunk.
    int64_t bkB_, bkS_, bkN_;
    int64_t bkMaxN_;
    int64_t bkD_, bkHalfD_, bkAlignHalfD_;
    int64_t bkLoop_;
    int64_t stride_;
};
/**
 * Device entry point: builds a rotary_stride<T> for the requested element type, initialises it
 * with the given buffers and shape, and runs it.
 *
 * `dtype` is handed to TYPE_SWITCH, which supplies the type alias T used inside the braces
 * (the macro is defined outside this view; its accepted dtype codes are not visible here).
 * All remaining arguments are forwarded unchanged to rotary_stride<T>::Init().
 */
extern "C" __global__ __aicore__ void compute_rotary_stride(__gm__ void* in,
    __gm__ void* sin, __gm__ void* cos, __gm__ void* psql, __gm__ void* out,
    const int64_t gbB, const int64_t gbS, const int64_t gbN, const int64_t gbD, const int64_t gbMAXS,
    const int64_t stride, int32_t dtype)
{
    TYPE_SWITCH(dtype, T, {
        rotary_stride<T> kernel;
        kernel.Init(in, sin, cos, psql, out, gbB, gbS, gbN, gbD, gbMAXS, stride);
        kernel.Process();
    })
}
/**
 * Rounds `number` up to the next multiple of `alignSize`.
 *
 * @param number    value to round; callers here pass non-negative values.
 * @param alignSize alignment step; must be non-zero (it is used as a divisor).
 * @return `number` itself when already a multiple, otherwise the next multiple above it.
 */
inline int64_t align_up(const int64_t number, const int64_t alignSize)
{
    if (number % alignSize == 0) {
        return number;
    }
    return ((number / alignSize + 1) * alignSize);
}

/**
 * Host-side check that at least one head fits into the on-chip buffer budget.
 *
 * Repeats the tile sizing done by the device kernel for 2-byte element types: the fixed
 * buffers (sin/cos in the element type and in float, plus the padded psql vector) are taken
 * from `ubSize`, and what remains is divided by the per-head cost (input + output rows in the
 * element type and four float rows).
 *
 * All arithmetic is done in int64_t on purpose: mixing in `sizeof` would switch the division
 * to unsigned and turn a negative remaining budget into a huge positive head count.
 *
 * @param gbB    batch count (length of the psql vector).
 * @param gbD    number of rotated elements per head row; values below 2 are rejected because
 *               half of it is used as a size and would make the per-head cost zero.
 * @param ubSize total byte budget.
 * @return 0 when at least one head fits, 1 otherwise (a diagnostic is printed to stdout).
 */
int judge_rotary_stride_launch(const int64_t gbB, const int64_t gbD, const int64_t ubSize)
{
    constexpr int64_t kAlignBytes = 256;
    constexpr int64_t kElemBytes = static_cast<int64_t>(sizeof(uint16_t));   // fp16 / bf16
    constexpr int64_t kFloatBytes = static_cast<int64_t>(sizeof(float));
    constexpr int64_t kIndexBytes = static_cast<int64_t>(sizeof(int32_t));

    if (gbB < 0 || gbD < 2) {
        std::cout << __FUNCTION__ << ": invalid shape, gbB = " << gbB << ", gbD = " << gbD << std::endl;
        return 1;
    }

    const int64_t alignedB = align_up(gbB, kAlignBytes / kIndexBytes);
    const int64_t alignedHalfD = align_up(gbD / 2, kAlignBytes / kElemBytes);
    const int64_t rowElems = alignedHalfD * 2;   // padded low half + padded high half

    int64_t freeBytes = ubSize;
    freeBytes -= rowElems * kElemBytes * 2;      // sin + cos rows in the element type
    freeBytes -= rowElems * kFloatBytes * 2;     // sin + cos rows in float
    freeBytes -= alignedB * kIndexBytes;         // psql vector

    // Per head: in + out rows in the element type, four float rows.
    const int64_t bytesPerHead = rowElems * kElemBytes * 2 + rowElems * kFloatBytes * 4;
    const int64_t bkMaxN = freeBytes / bytesPerHead;
    if (bkMaxN < 1) {
        std::cout << __FUNCTION__ << ": bkMaxN = " << bkMaxN << std::endl;
        return 1;
    }
    return 0;
}
/**
 * Launches compute_rotary_stride on `stream`, using at most one core per (batch, sequence) task.
 *
 * @param blockDim requested core count; lowered to gbB * gbS when there are fewer tasks.
 * @param stream   stream handle passed to the kernel launch.
 * The remaining arguments are forwarded unchanged to the kernel.
 *
 * When gbB * gbS is not positive there is no task to run and no kernel is launched.
 */
void rotary_stride_kernel_lanuch(int64_t blockDim, void* stream,
    void* in, void* sin, void* cos, void* psql, void* out,
    const int64_t gbB, const int64_t gbS, const int64_t gbN, const int64_t gbD, const int64_t gbMAXS,
    const int64_t stride, int32_t dtype)
{
    const int64_t taskCount = gbB * gbS;
    // Clamping blockDim to an empty workload would request a launch with zero blocks.
    if (taskCount <= 0) {
        return;
    }
    if (taskCount < blockDim) { blockDim = taskCount; }
    compute_rotary_stride<<<blockDim, nullptr, stream>>>(in, sin, cos, psql, out,
        gbB, gbS, gbN, gbD, gbMAXS, stride, dtype);
}
/**
 * Validates that the shape fits the on-chip buffer budget and, if so, launches the kernel.
 *
 * @param blockDim requested core count, forwarded to rotary_stride_kernel_lanuch().
 * @param stream   stream handle used for the launch.
 * The remaining arguments are forwarded unchanged.
 * @return 0 when the kernel launch was issued, 1 when judge_rotary_stride_launch() rejected
 *         the shape (a message is printed to stdout in that case).
 */
int rotary_stride_lanuch(int64_t blockDim, void* stream,
    void* in, void* sin, void* cos, void* psql, void* out,
    const int64_t gbB, const int64_t gbS, const int64_t gbN, const int64_t gbD, const int64_t gbMAXS,
    const int64_t stride, int32_t dtype)
{
    // Check against the very constant the kernel sizes its tiles from, so the host check and
    // the device-side sizing cannot drift apart.
    if (judge_rotary_stride_launch(gbB, gbD, UB_MAX_BYTES) != 0) {
        std::cout << __FUNCTION__ << ": " << "UB size is limited, please check!" << std::endl;
        return 1;
    }
    rotary_stride_kernel_lanuch(blockDim, stream, in, sin, cos, psql, out, gbB, gbS, gbN,
        gbD, gbMAXS, stride, dtype);
    return 0;
}
/**
 * Torch entry point for the strided rotary embedding.
 *
 * @param blockDim  requested number of cores; must be positive.
 * @param in        4-D NPU tensor (fp16 or bf16). size(0..2) are forwarded as batch, sequence and
 *                  head counts, size(3) as the element distance between head rows.
 *                  NOTE(review): the kernel computes offsets from sizes only, so `in` and `out`
 *                  are presumably expected to be contiguous -- confirm with callers.
 * @param sin       NPU sine table with the same dtype as `in`; size(0) is forwarded as its row count.
 * @param cos       NPU cosine table with the same dtype as `in`.
 * @param out       NPU output tensor with the same dtype as `in`, addressed like `in`.
 * @param hiddenDim number of trailing elements of each head row to rotate; 2 <= hiddenDim <= in.size(3).
 * @return 0 when the kernel launch was issued (or the input is empty), 1 when the shape was
 *         rejected by the buffer-budget check.
 */
int64_t rotary_stride_npu(int64_t blockDim, torch::Tensor &in, torch::Tensor &sin, torch::Tensor &cos,
    torch::Tensor &out, const int64_t hiddenDim)
{
    TORCH_CHECK(torch_npu::utils::is_npu(in), "input tensor must be on NPU device");
    TORCH_CHECK(torch_npu::utils::is_npu(sin), "sin tensor must be on NPU device");
    TORCH_CHECK(torch_npu::utils::is_npu(cos), "cosin tensor must be on NPU device");
    TORCH_CHECK(torch_npu::utils::is_npu(out), "output tensor must be on NPU device");
    TORCH_CHECK(in.scalar_type() == at::kBFloat16 || in.scalar_type() == at::kHalf,
        "dtype of input tensor is invalid, only BF16 or FP16 is supported.");
    TORCH_CHECK(out.scalar_type() == at::kBFloat16 || out.scalar_type() == at::kHalf,
        "dtype of output tensor is invalid, only BF16 or FP16 is supported.");
    TORCH_CHECK(sin.scalar_type() == at::kBFloat16 || sin.scalar_type() == at::kHalf,
        "dtype of sin tensor is invalid, only BF16 or FP16 is supported.");
    TORCH_CHECK(cos.scalar_type() == at::kBFloat16 || cos.scalar_type() == at::kHalf,
        "dtype of cos tensor is invalid, only BF16 or FP16 is supported.");
    // A single dtype code (taken from `in`) is passed down and the kernel reinterprets all four
    // buffers with it, so mixed dtypes would be read with the wrong element format.
    TORCH_CHECK(out.scalar_type() == in.scalar_type() && sin.scalar_type() == in.scalar_type() &&
        cos.scalar_type() == in.scalar_type(),
        "input, sin, cos and output tensors must share the same dtype.");
    // in.size(3) and sin.size(0) are read below.
    TORCH_CHECK(in.dim() == 4, "input tensor must be 4-dimensional.");
    TORCH_CHECK(sin.dim() >= 1, "sin tensor must have at least one dimension.");
    // The kernel starts each row at (stride - hiddenDim) and uses hiddenDim / 2 as a size.
    TORCH_CHECK(hiddenDim >= 2 && hiddenDim <= in.size(3),
        "hiddenDim must be in the range [2, in.size(3)].");
    TORCH_CHECK(blockDim > 0, "blockDim must be positive.");
    if (in.size(0) == 0 || in.size(1) == 0) {
        // No (batch, sequence) task exists; also avoids calling arange with a zero step.
        return 0;
    }
    auto stream = c10_npu::getCurrentNPUStream().stream(false);
    // Per-batch offset into the sin/cos tables: batch b starts at row b * in.size(1).
    // The kernel reads this buffer as int32_t, while arange with integer bounds would default to
    // int64, so the dtype is requested explicitly.
    torch::Tensor psql = torch::arange(0, in.size(0) * in.size(1), in.size(1),
        torch::TensorOptions().dtype(torch::kInt32)).to(at::Device("npu"));
    int launchStatus = 0;
    // Tensors and the stream are captured by value so they stay alive for the callback.
    // NOTE(review): launchStatus is captured by reference and returned right after RunOpApi();
    // this relies on RunOpApi invoking the callback before it returns -- confirm for deferred
    // (task-queue) execution.
    auto acl_call = [=, &launchStatus]() -> int {
        // Last argument is the dtype code handed to the kernel: 1 for fp16, 27 for bf16.
        launchStatus = rotary_stride_lanuch(blockDim, stream, in.data_ptr(), sin.data_ptr(), cos.data_ptr(),
            psql.data_ptr(), out.data_ptr(), in.size(0), in.size(1), in.size(2), hiddenDim, sin.size(0),
            in.size(3), in.scalar_type() == at::kHalf ? 1 : 27);
        return 0;
    };
    at_npu::native::OpCommand::RunOpApi("RotaryStride", acl_call);
    return launchStatus;
}
// Registers rotary_stride_npu as the PrivateUse1 implementation of the "rotary_stride"
// operator in the npu_ops_transformer_ext library. The operator schema is declared
// outside this view.
TORCH_LIBRARY_IMPL(npu_ops_transformer_ext, PrivateUse1, m)
{
m.impl("rotary_stride", rotary_stride_npu);
}
}
}