* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under
* the terms and conditions of CANN Open Software License Agreement Version 2.0
* (the "License"). Please refer to the License for details. You may not use
* this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY
* KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
* NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the
* License.
*/
* mhc_pre AscendC Kernel - Production
* Computes: output[b, seq, d] = Σ_s (h_pre[s] × x[b * num_streams + s, seq, d])
* Supports: fp32, fp16, bf16 (via fp32 compute)
*/
#include "kernel_operator.h"
using namespace AscendC;
constexpr int32_t BUFFER_NUM = 2;
constexpr int32_t UB_SIZE = 192 * 1024;
template <typename T, int32_t ALIGN>
class MhcPreKernel {
public:
__aicore__ inline MhcPreKernel() {}
__aicore__ inline void Init(
GM_ADDR input,
GM_ADDR h_pre,
GM_ADDR output,
int64_t batch, int64_t seq_len, int64_t dim, int64_t num_streams,
int64_t tile_length, int64_t block_dim
) {
this->gm_input = reinterpret_cast<__gm__ T*>(input);
this->gm_output = reinterpret_cast<__gm__ T*>(output);
this->batch = batch;
this->num_streams = num_streams;
this->batch_elements = seq_len * dim;
this->tile_length = tile_length;
this->block_dim = block_dim;
gm_h.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(h_pre), num_streams);
pipe.InitBuffer(inQue, BUFFER_NUM, tile_length * sizeof(T));
pipe.InitBuffer(outQue, BUFFER_NUM, tile_length * sizeof(T));
pipe.InitBuffer(tmpBuf, tile_length * sizeof(T));
}
__aicore__ inline void Process() {
int64_t block_idx = GetBlockIdx();
for (int64_t batch_idx = block_idx; batch_idx < batch; batch_idx += block_dim) {
ProcessOneBatch(batch_idx);
}
}
private:
__aicore__ inline void ProcessOneBatch(int64_t batch_idx) {
int64_t out_off = batch_idx * batch_elements;
gm_out.SetGlobalBuffer(gm_output + out_off, batch_elements);
int64_t tiles = (batch_elements + tile_length - 1) / tile_length;
for (int64_t t = 0; t < tiles; ++t) {
int64_t tile_off = t * tile_length;
int64_t len = (tile_off + tile_length > batch_elements)
? (batch_elements - tile_off) : tile_length;
ProcessOneTile(batch_idx, tile_off, len);
}
}
__aicore__ inline void ProcessOneTile(int64_t batch_idx, int64_t tile_off, int64_t len) {
LocalTensor<T> acc = outQue.AllocTensor<T>();
LocalTensor<T> tmp = tmpBuf.Get<T>();
uint32_t alen = ((len + ALIGN - 1) / ALIGN) * ALIGN;
for (int64_t s = 0; s < num_streams; ++s) {
int64_t in_off = (batch_idx * num_streams + s) * batch_elements + tile_off;
gm_in.SetGlobalBuffer(gm_input + in_off, len);
CopyIn(len);
LocalTensor<T> in = inQue.DeQue<T>();
T weight = gm_h.GetValue(s);
if (s == 0) {
Muls(acc, in, weight, alen);
} else {
Muls(tmp, in, weight, alen);
Add(acc, acc, tmp, alen);
}
inQue.FreeTensor(in);
}
outQue.EnQue(acc);
CopyOut(tile_off, len);
}
__aicore__ inline void CopyIn(int64_t len) {
LocalTensor<T> local = inQue.AllocTensor<T>();
uint32_t l = static_cast<uint32_t>(len);
if (l % ALIGN == 0) {
DataCopy(local, gm_in[0], l);
} else {
DataCopyExtParams p{1, l * (uint32_t)sizeof(T), 0, 0, 0};
DataCopyPadExtParams<T> pad{false, 0, 0, T(0)};
DataCopyPad(local, gm_in[0], p, pad);
}
inQue.EnQue(local);
}
__aicore__ inline void CopyOut(int64_t off, int64_t len) {
LocalTensor<T> local = outQue.DeQue<T>();
uint32_t l = static_cast<uint32_t>(len);
if (l % ALIGN == 0) {
DataCopy(gm_out[off], local, l);
} else {
DataCopyExtParams p{1, l * (uint32_t)sizeof(T), 0, 0, 0};
DataCopyPad(gm_out[off], local, p);
}
outQue.FreeTensor(local);
}
private:
TPipe pipe;
TQue<QuePosition::VECIN, BUFFER_NUM> inQue;
TQue<QuePosition::VECOUT, BUFFER_NUM> outQue;
TBuf<QuePosition::VECCALC> tmpBuf;
GlobalTensor<T> gm_in, gm_h, gm_out;
__gm__ T* gm_input;
__gm__ T* gm_output;
int64_t batch, num_streams, batch_elements, tile_length, block_dim;
};
#if __CCE_AICORE__ == 220
template <>
class MhcPreKernel<bfloat16_t, 16> {
public:
__aicore__ inline MhcPreKernel() {}
__aicore__ inline void Init(
GM_ADDR input,
GM_ADDR h_pre,
GM_ADDR output,
int64_t batch, int64_t seq_len, int64_t dim, int64_t num_streams,
int64_t tile_length, int64_t block_dim
) {
this->gm_input = reinterpret_cast<__gm__ bfloat16_t*>(input);
this->gm_output = reinterpret_cast<__gm__ bfloat16_t*>(output);
this->batch = batch;
this->num_streams = num_streams;
this->batch_elements = seq_len * dim;
this->tile_length = tile_length;
this->block_dim = block_dim;
gm_h.SetGlobalBuffer(reinterpret_cast<__gm__ float*>(h_pre), num_streams);
pipe.InitBuffer(inQue, BUFFER_NUM, tile_length * sizeof(bfloat16_t));
pipe.InitBuffer(outQue, BUFFER_NUM, tile_length * sizeof(bfloat16_t));
pipe.InitBuffer(accBuf, tile_length * sizeof(float));
pipe.InitBuffer(tmpBuf, tile_length * sizeof(float));
}
__aicore__ inline void Process() {
int64_t block_idx = GetBlockIdx();
for (int64_t batch_idx = block_idx; batch_idx < batch; batch_idx += block_dim) {
ProcessOneBatch(batch_idx);
}
}
private:
__aicore__ inline void ProcessOneBatch(int64_t batch_idx) {
int64_t out_off = batch_idx * batch_elements;
gm_out.SetGlobalBuffer(gm_output + out_off, batch_elements);
int64_t tiles = (batch_elements + tile_length - 1) / tile_length;
for (int64_t t = 0; t < tiles; ++t) {
int64_t tile_off = t * tile_length;
int64_t len = (tile_off + tile_length > batch_elements)
? (batch_elements - tile_off) : tile_length;
ProcessOneTile(batch_idx, tile_off, len);
}
}
__aicore__ inline void ProcessOneTile(int64_t batch_idx, int64_t tile_off, int64_t len) {
LocalTensor<float> acc = accBuf.Get<float>();
LocalTensor<float> tmp = tmpBuf.Get<float>();
uint32_t alen = ((len + 15) / 16) * 16;
for (int64_t s = 0; s < num_streams; ++s) {
int64_t in_off = (batch_idx * num_streams + s) * batch_elements + tile_off;
gm_in.SetGlobalBuffer(gm_input + in_off, len);
CopyIn(len);
LocalTensor<bfloat16_t> in = inQue.DeQue<bfloat16_t>();
float weight = gm_h.GetValue(s);
Cast(tmp, in, RoundMode::CAST_NONE, alen);
if (s == 0) {
Muls(acc, tmp, weight, alen);
} else {
Muls(tmp, tmp, weight, alen);
Add(acc, acc, tmp, alen);
}
inQue.FreeTensor(in);
}
LocalTensor<bfloat16_t> out = outQue.AllocTensor<bfloat16_t>();
Cast(out, acc, RoundMode::CAST_RINT, alen);
outQue.EnQue(out);
CopyOut(tile_off, len);
}
__aicore__ inline void CopyIn(int64_t len) {
LocalTensor<bfloat16_t> local = inQue.AllocTensor<bfloat16_t>();
uint32_t l = static_cast<uint32_t>(len);
if (l % 16 == 0) {
DataCopy(local, gm_in[0], l);
} else {
DataCopyExtParams p{1, l * 2, 0, 0, 0};
DataCopyPadExtParams<bfloat16_t> pad{false, 0, 0, bfloat16_t(0)};
DataCopyPad(local, gm_in[0], p, pad);
}
inQue.EnQue(local);
}
__aicore__ inline void CopyOut(int64_t off, int64_t len) {
LocalTensor<bfloat16_t> local = outQue.DeQue<bfloat16_t>();
uint32_t l = static_cast<uint32_t>(len);
if (l % 16 == 0) {
DataCopy(gm_out[off], local, l);
} else {
DataCopyExtParams p{1, l * 2, 0, 0, 0};
DataCopyPad(gm_out[off], local, p);
}
outQue.FreeTensor(local);
}
private:
TPipe pipe;
TQue<QuePosition::VECIN, BUFFER_NUM> inQue;
TQue<QuePosition::VECOUT, BUFFER_NUM> outQue;
TBuf<QuePosition::VECCALC> accBuf;
TBuf<QuePosition::VECCALC> tmpBuf;
GlobalTensor<bfloat16_t> gm_in, gm_out;
GlobalTensor<float> gm_h;
__gm__ bfloat16_t* gm_input;
__gm__ bfloat16_t* gm_output;
int64_t batch, num_streams, batch_elements, tile_length, block_dim;
};
#endif
inline int64_t CalcTile(int64_t elem, int64_t sz, int64_t align) {
int64_t t = UB_SIZE / ((BUFFER_NUM * 2 + 1) * sz);
t = (t / align) * align;
if (t > elem) t = ((elem + align - 1) / align) * align;
return t < align ? align : t;
}
extern "C" __global__ __aicore__ void mhc_pre_kernel_fp32(
GM_ADDR in, GM_ADDR h, GM_ADDR out,
int64_t b, int64_t s, int64_t d, int64_t n, int64_t tile, int64_t blkDim
) {
MhcPreKernel<float, 8> op;
op.Init(in, h, out, b, s, d, n, tile, blkDim);
op.Process();
}
extern "C" __global__ __aicore__ void mhc_pre_kernel_fp16(
GM_ADDR in, GM_ADDR h, GM_ADDR out,
int64_t b, int64_t s, int64_t d, int64_t n, int64_t tile, int64_t blkDim
) {
MhcPreKernel<half, 16> op;
op.Init(in, h, out, b, s, d, n, tile, blkDim);
op.Process();
}
#if __CCE_AICORE__ == 220
extern "C" __global__ __aicore__ void mhc_pre_kernel_bf16(
GM_ADDR in, GM_ADDR h, GM_ADDR out,
int64_t b, int64_t s, int64_t d, int64_t n, int64_t tile, int64_t blkDim
) {
MhcPreKernel<bfloat16_t, 16> op;
op.Init(in, h, out, b, s, d, n, tile, blkDim);
op.Process();
}
#endif
extern "C" void mhc_pre_do_fp32(
uint32_t blockDim, void* stream, uint8_t* in, uint8_t* h, uint8_t* out,
int64_t b, int64_t s, int64_t d, int64_t n
) {
int64_t tile = CalcTile(s * d, sizeof(float), 8);
uint32_t maxBlk = b;
if (blockDim == 0 || blockDim > maxBlk) blockDim = maxBlk;
mhc_pre_kernel_fp32<<<blockDim, nullptr, stream>>>(in, h, out, b, s, d, n, tile, blockDim);
}
extern "C" void mhc_pre_do_fp16(
uint32_t blockDim, void* stream, uint8_t* in, uint8_t* h, uint8_t* out,
int64_t b, int64_t s, int64_t d, int64_t n
) {
int64_t tile = CalcTile(s * d, sizeof(half), 16);
uint32_t maxBlk = b;
if (blockDim == 0 || blockDim > maxBlk) blockDim = maxBlk;
mhc_pre_kernel_fp16<<<blockDim, nullptr, stream>>>(in, h, out, b, s, d, n, tile, blockDim);
}
extern "C" void mhc_pre_do_bf16(
uint32_t blockDim, void* stream, uint8_t* in, uint8_t* h, uint8_t* out,
int64_t b, int64_t s, int64_t d, int64_t n
) {
int64_t tile = CalcTile(s * d, 2 + 4, 16);
uint32_t maxBlk = b;
if (blockDim == 0 || blockDim > maxBlk) blockDim = maxBlk;
mhc_pre_kernel_bf16<<<blockDim, nullptr, stream>>>(in, h, out, b, s, d, n, tile, blockDim);
}
extern "C" void mhc_pre_do(
uint32_t blockDim, void* stream, uint8_t* in, uint8_t* h, uint8_t* out,
int64_t b, int64_t s, int64_t d, int64_t n
) {
mhc_pre_do_fp32(blockDim, stream, in, h, out, b, s, d, n);
}