* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file epilogue_alpha_beta_kernel.cpp
* \brief SIMT epilogue: D = alpha * D_raw + beta * C.
*/
#include <cstdint>
#include "epilogue_alpha_beta_tiling_data.h"
#include "kernel_operator.h"
#include "simt_api/asc_bf16.h"
#include "simt_api/asc_simt.h"
namespace {
constexpr uint32_t SIMT_MAX_THREAD_NUM = 2048;
template <typename TDRaw, typename TC, typename TD>
__simt_vf__ __aicore__ LAUNCH_BOUND(SIMT_MAX_THREAD_NUM) inline void EpilogueAlphaBetaSimt(
uint32_t m, uint32_t n, uint32_t ldc, uint32_t ldd, uint32_t lddRaw, float alpha, float beta, bool useC,
__gm__ const TDRaw* dRawGm, __gm__ const TC* cGm, __gm__ TD* dGm)
{
constexpr bool RAW_BF16 = AscendC::IsSameType<TDRaw, bfloat16_t>::value;
constexpr bool C_BF16 = AscendC::IsSameType<TC, bfloat16_t>::value;
constexpr bool D_BF16 = AscendC::IsSameType<TD, bfloat16_t>::value;
const uint32_t total = m * n;
for (uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < total; idx += gridDim.x * blockDim.x) {
const uint32_t row = idx / n;
const uint32_t col = idx % n;
const uint32_t rawIdx = row * lddRaw + col;
const uint32_t outIdx = row * ldd + col;
float rawVal = RAW_BF16 ? __bfloat162float(dRawGm[rawIdx]) : static_cast<float>(dRawGm[rawIdx]);
float cVal = 0.0f;
if (useC) {
const uint32_t cIdx = row * ldc + col;
cVal = C_BF16 ? __bfloat162float(cGm[cIdx]) : static_cast<float>(cGm[cIdx]);
}
const float result = alpha * rawVal + beta * cVal;
if constexpr (D_BF16) {
dGm[outIdx] = __float2bfloat16_rn_sat(result);
} else {
dGm[outIdx] = static_cast<TD>(result);
}
}
}
template <typename TDRaw, typename TC, typename TD>
__aicore__ inline void LaunchEpilogueAlphaBetaTyped(
const EpilogueAlphaBetaTilingData& tiling, GM_ADDR dRaw, GM_ADDR c, GM_ADDR d)
{
auto* dRawGm = reinterpret_cast<__gm__ const TDRaw*>(dRaw);
auto* cGm = reinterpret_cast<__gm__ const TC*>(c);
auto* dGm = reinterpret_cast<__gm__ TD*>(d);
const bool useC = tiling.useC != 0U;
asc_vf_call<EpilogueAlphaBetaSimt<TDRaw, TC, TD>>(
dim3{tiling.numThreads, 1, 1}, tiling.m, tiling.n, tiling.ldc, tiling.ldd, tiling.lddRaw, tiling.alpha,
tiling.beta, useC, dRawGm, cGm, dGm);
}
__aicore__ inline void DispatchEpilogueAlphaBeta(
const EpilogueAlphaBetaTilingData& tiling, GM_ADDR dRaw, GM_ADDR c, GM_ADDR d)
{
const bool rawBf16 = tiling.dtypeDRaw != 0U;
const bool cBf16 = tiling.dtypeC != 0U;
const bool dBf16 = tiling.dtypeD != 0U;
if (!rawBf16 && !cBf16 && !dBf16) {
LaunchEpilogueAlphaBetaTyped<float, float, float>(tiling, dRaw, c, d);
} else if (!rawBf16 && !cBf16 && dBf16) {
LaunchEpilogueAlphaBetaTyped<float, float, bfloat16_t>(tiling, dRaw, c, d);
} else if (!rawBf16 && cBf16 && !dBf16) {
LaunchEpilogueAlphaBetaTyped<float, bfloat16_t, float>(tiling, dRaw, c, d);
} else if (!rawBf16 && cBf16 && dBf16) {
LaunchEpilogueAlphaBetaTyped<float, bfloat16_t, bfloat16_t>(tiling, dRaw, c, d);
} else if (rawBf16 && !cBf16 && !dBf16) {
LaunchEpilogueAlphaBetaTyped<bfloat16_t, float, float>(tiling, dRaw, c, d);
} else if (rawBf16 && !cBf16 && dBf16) {
LaunchEpilogueAlphaBetaTyped<bfloat16_t, float, bfloat16_t>(tiling, dRaw, c, d);
} else if (rawBf16 && cBf16 && !dBf16) {
LaunchEpilogueAlphaBetaTyped<bfloat16_t, bfloat16_t, float>(tiling, dRaw, c, d);
} else {
LaunchEpilogueAlphaBetaTyped<bfloat16_t, bfloat16_t, bfloat16_t>(tiling, dRaw, c, d);
}
}
}
__global__ __aicore__ void EpilogueAlphaBetaKernel(GM_ADDR dRaw, GM_ADDR c, GM_ADDR d, EpilogueAlphaBetaTilingData tiling)
{
KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_AIV_ONLY);
DispatchEpilogueAlphaBeta(tiling, dRaw, c, d);
}
void epilogue_alpha_beta_kernel_do(
uint8_t* dRaw, uint8_t* c, uint8_t* d, const EpilogueAlphaBetaTilingData& tiling, void* stream)
{
EpilogueAlphaBetaTilingData tilingCopy = tiling;
EpilogueAlphaBetaKernel<<<tilingCopy.usedCoreNum, nullptr, stream>>>(dRaw, c, d, tilingCopy);
}