* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <ATen/Operators.h>
#include <torch/all.h>
#include <torch/library.h>
#include "acl/acl.h"
#include "torch_npu/csrc/core/npu/NPUStream.h"
#include "torch_npu/csrc/core/npu/DeviceUtils.h"
#include "torch_npu/csrc/framework/OpCommand.h"
namespace npu_ops_transformer_ext {
namespace FinalRouting {
#include <iostream>
#include <stdio.h>
#include "kernel_operator.h"
using namespace AscendC;
constexpr int64_t UB_MAX_BYTES = 184*1024;
constexpr int64_t BUFFER_NUM = 1;
class final_routing {
public:
__aicore__ inline final_routing() {}
__aicore__ inline void Init(GM_ADDR in, GM_ADDR token_table, GM_ADDR score_table, GM_ADDR out,
const int64_t expert_num, const int64_t token_num, const int64_t copy_byte)
{
blockNum_ = GetBlockNum();
blockIdx_ = GetBlockIdx();
gbExpertNum_ = expert_num;
gbTokenNum_ = token_num;
gbCopyByte_ = copy_byte;
gbCopyNum_ = gbCopyByte_ / sizeof(bfloat16_t);
gbTokenNumAlign_ = AlignUp(gbTokenNum_, 64 / sizeof(int32_t));
gbExpertNumAlign_ = AlignUp(gbExpertNum_, 64 / sizeof(int32_t));
bkExpertNum_ = 1;
bkTokenNum_ = 1;
bkCopyByte_ = gbCopyByte_;
bkCopyNum_ = bkCopyByte_ / sizeof(bfloat16_t);
bkCopyByteAlign_ = AlignUp(bkCopyByte_, 64);
bkCopyNumAlign_ = bkCopyByteAlign_ / sizeof(bfloat16_t);
gmIn_.SetGlobalBuffer((__gm__ bfloat16_t*)(in), gbExpertNum_ * gbTokenNum_ * gbCopyNum_);
gmTokenTable_.SetGlobalBuffer((__gm__ int32_t*)(token_table), gbTokenNum_ * gbExpertNum_);
gmScoreTable_.SetGlobalBuffer((__gm__ bfloat16_t*)(score_table), gbTokenNum_ * gbExpertNum_);
gmOut_.SetGlobalBuffer((__gm__ bfloat16_t*)(out), gbTokenNum_ * gbCopyNum_);
pipe_.InitBuffer(inQueIn_, BUFFER_NUM, bkCopyByteAlign_ * 2);
pipe_.InitBuffer(inQueTokenTable_, BUFFER_NUM, gbExpertNumAlign_ * sizeof(int32_t));
pipe_.InitBuffer(inQueScoreTable_, BUFFER_NUM, gbExpertNumAlign_ * sizeof(bfloat16_t));
pipe_.InitBuffer(outQueOut_, BUFFER_NUM, bkCopyByteAlign_ * 2);
}
__aicore__ inline void Process()
{
for (int64_t tis = 0; tis < gbTokenNum_; tis += blockNum_) {
int64_t token_idx = tis + blockIdx_;
if (token_idx < gbTokenNum_) {
LocalTensor<int32_t> local_token_table = inQueTokenTable_.AllocTensor<int32_t>();
LocalTensor<bfloat16_t> local_score_table = inQueScoreTable_.AllocTensor<bfloat16_t>();
DataCopyParams copy_pas{1, (uint16_t)(gbExpertNum_ * sizeof(int32_t)), 0, 0};
DataCopyPadParams pad_pas;
DataCopyPad(local_token_table, gmTokenTable_[token_idx * gbExpertNum_], copy_pas, pad_pas);
copy_pas.blockLen = (uint16_t)(gbExpertNum_ * sizeof(bfloat16_t));
DataCopyPad(local_score_table, gmScoreTable_[token_idx * gbExpertNum_], copy_pas, pad_pas);
inQueTokenTable_.EnQue(local_token_table);
inQueScoreTable_.EnQue(local_score_table);
lmTokenTable_ = inQueTokenTable_.DeQue<int32_t>();
lmScoreTable_ = inQueScoreTable_.DeQue<bfloat16_t>();
lmOut_ = outQueOut_.AllocTensor<bfloat16_t>();
lmOutFloat_ = lmOut_.ReinterpretCast<float>();
Duplicate(lmOutFloat_, (float)0.0, bkCopyNumAlign_);
for (int64_t expert_idx = 0; expert_idx < gbExpertNum_; expert_idx++) {
int64_t routing_token_idx = lmTokenTable_.GetValue(expert_idx);
if (routing_token_idx >= 0) {
CopyIn(routing_token_idx);
Compute(AscendC::ToFloat(lmScoreTable_.GetValue(expert_idx)));
}
}
Cast(lmOut_, lmOutFloat_, RoundMode::CAST_ROUND, bkCopyNum_);
outQueOut_.EnQue(lmOut_);
outQueOut_.DeQue<bfloat16_t>();
CopyOut(token_idx);
inQueTokenTable_.FreeTensor(local_token_table);
inQueScoreTable_.FreeTensor(local_score_table);
outQueOut_.FreeTensor(lmOut_);
}
}
}
private:
__aicore__ inline void CopyIn(int64_t routing_token_idx)
{
LocalTensor<bfloat16_t> local_in = inQueIn_.AllocTensor<bfloat16_t>();
DataCopyExtParams copy_pas{1, (uint32_t)(bkCopyByte_), 0, 0, 0};
DataCopyPadExtParams<bfloat16_t> pad_pas;
int64_t offset = routing_token_idx * gbCopyNum_;
DataCopyPad(local_in[bkCopyNumAlign_], gmIn_[offset], copy_pas, pad_pas);
inQueIn_.EnQue(local_in);
}
__aicore__ inline void Compute(float score)
{
LocalTensor<bfloat16_t> local_in = inQueIn_.DeQue<bfloat16_t>();
LocalTensor<float> local_in_float = local_in.ReinterpretCast<float>();
Cast(local_in_float, local_in[bkCopyNumAlign_], RoundMode::CAST_NONE, bkCopyNum_);
Muls(local_in_float, local_in_float, score, bkCopyNum_);
Add(lmOutFloat_, lmOutFloat_, local_in_float, bkCopyNum_);
inQueIn_.FreeTensor(local_in);
}
__aicore__ inline void CopyOut(int64_t token_idx)
{
DataCopyExtParams copy_pas{1, (uint32_t)(bkCopyByte_), 0, 0, 0};
DataCopyPad(gmOut_[token_idx * bkCopyNum_], lmOut_, copy_pas);
}
private:
TPipe pipe_;
TQue<QuePosition::VECIN, BUFFER_NUM> inQueIn_, inQueTokenTable_, inQueScoreTable_;
TQue<QuePosition::VECOUT, BUFFER_NUM> outQueOut_;
GlobalTensor<bfloat16_t> gmIn_, gmOut_;
GlobalTensor<int32_t> gmTokenTable_;
GlobalTensor<bfloat16_t> gmScoreTable_;
LocalTensor<int32_t> lmTokenTable_;
LocalTensor<bfloat16_t> lmScoreTable_;
LocalTensor<bfloat16_t> lmOut_;
LocalTensor<float> lmOutFloat_;
int64_t blockNum_, blockIdx_;
int64_t gbExpertNum_, gbTokenNum_, gbCopyByte_, gbCopyNum_;
int64_t gbExpertNumAlign_, gbTokenNumAlign_;
int64_t bkExpertNum_, bkTokenNum_, bkCopyByte_, bkCopyNum_;
int64_t bkCopyByteAlign_, bkCopyNumAlign_;
};
extern "C" __global__ __aicore__ void compute_final_routing(
GM_ADDR in, GM_ADDR token_table, GM_ADDR score_table, GM_ADDR out,
const int64_t expert_num, const int64_t token_num, const int64_t copy_byte)
{
final_routing op;
op.Init(in, token_table, score_table, out, expert_num, token_num, copy_byte);
op.Process();
}
void final_routing_kernel_lanuch(int64_t block_dim, void* stream,
uint8_t* in, uint8_t* token_table, uint8_t* score_table, uint8_t* out,
const int64_t expert_num, const int64_t token_num, const int64_t copy_byte)
{
compute_final_routing<<<block_dim, nullptr, stream>>>(in, token_table, score_table, out,
expert_num, token_num, copy_byte);
}
inline int64_t align_up(const int64_t number, const int64_t alignSize)
{
if (number % alignSize == 0) {
return number;
}
return ((number / alignSize + 1) * alignSize);
}
int judge_final_routing_lanuch(const int64_t expert_num, const int64_t token_num, const int64_t copy_byte,
const int64_t ubSize, const int64_t vCores)
{
(void)(vCores);
constexpr int64_t BUFFER_NUM = 1;
int64_t gbExpertNum_;
int64_t gbTokenNum_;
int64_t gbCopyByte_;
int64_t gbExpertNumAlign_;
int64_t gbTokenNumAlign_;
int64_t bkCopyByte_;
int64_t bkCopyByteAlign_;
gbExpertNum_ = expert_num;
gbTokenNum_ = token_num;
gbCopyByte_ = copy_byte;
gbTokenNumAlign_ = align_up(gbTokenNum_, 64 / sizeof(int32_t));
gbExpertNumAlign_ = align_up(gbExpertNum_, 64 / sizeof(int32_t));
bkCopyByte_ = gbCopyByte_;
bkCopyByteAlign_ = align_up(bkCopyByte_, 64);
float use_byte = BUFFER_NUM * bkCopyByteAlign_ * 2;
use_byte += gbTokenNumAlign_ * sizeof(int32_t);
use_byte += gbExpertNumAlign_ * sizeof(int32_t);
use_byte += gbExpertNumAlign_ * sizeof(int16_t);
if (use_byte > ubSize) {
std::cout << __FUNCTION__ << ": " << use_byte/1024 << " KB" << std::endl;
return 1;
}
return 0;
}
int final_routing_lanuch(int64_t block_dim, void* stream,
uint8_t* in, uint8_t* token_table, uint8_t* score_table, uint8_t* out,
const int64_t expert_num, const int64_t token_num, const int64_t copy_byte)
{
int64_t ubSize = 184 * 1024;
int64_t vCores = 40;
int ret = judge_final_routing_lanuch(expert_num, token_num, copy_byte, ubSize, vCores);
if (ret == 0) {
final_routing_kernel_lanuch(block_dim, stream, in, token_table, score_table,
out, expert_num, token_num, copy_byte);
return 0;
}
std::cout << __FUNCTION__ << ": " << "UB size is limited, please check!" << std::endl;
return 1;
}
int64_t final_routing_npu(int64_t block_dim, torch::Tensor &in, torch::Tensor &token_table,
torch::Tensor &score_table, torch::Tensor &out)
{
TORCH_CHECK(torch_npu::utils::is_npu(in), "input tensor must be on NPU device");
TORCH_CHECK(torch_npu::utils::is_npu(token_table), "token table tensor must be on NPU device");
TORCH_CHECK(torch_npu::utils::is_npu(score_table), "score table tensor must be on NPU device");
TORCH_CHECK(torch_npu::utils::is_npu(out), "output tensor must be on NPU device");
const int64_t token_num = token_table.size(0);
const int64_t expert_num = token_table.size(1);
const int64_t copy_byte = in.element_size() * in.size(1);
auto stream = c10_npu::getCurrentNPUStream().stream(false);
int launchStatus = 0;
auto acl_call = [=, &launchStatus]() -> int {
launchStatus = final_routing_lanuch(block_dim, stream, (uint8_t *)in.data_ptr(),
(uint8_t *)token_table.data_ptr(), (uint8_t *)score_table.data_ptr(), (uint8_t *)out.data_ptr(),
expert_num, token_num, copy_byte);
return 0;
};
at_npu::native::OpCommand::RunOpApi("FinalRouting", acl_call);
return launchStatus;
}
TORCH_LIBRARY_IMPL(npu_ops_transformer_ext, PrivateUse1, m)
{
m.impl("final_routing", final_routing_npu);
}
}
}