* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file softmaxflashv2.h
* \brief SoftmaxFlashV2 api of AscendC
*/
#if !defined(__ASCENDC_INCLUDE_INTERNAL_HEADERS__)
#define __ASCENDC_INCLUDE_INTERNAL_HEADERS__
#define __UNDEF_ASCENDC_INCLUDE_INTERNAL_HEADERS_SOFTMAXFLASHV2_H__
#endif
#ifndef LIB_SOFTMAX_SOFTMAXFLASHV2_H
#define LIB_SOFTMAX_SOFTMAXFLASHV2_H
#include "kernel_tensor.h"
#include "kernel_tiling/kernel_tiling.h"
#include "include/adv_api/activation/softmax_utils.h"
#if defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201 || __NPU_ARCH__ == 2002 || __NPU_ARCH__ == 3510 || __NPU_ARCH__ == 3002 || __NPU_ARCH__ == 5102 || \
__NPU_ARCH__ == 3003 || __NPU_ARCH__ == 3113)
#include "../../../impl/adv_api/detail/activation/softmax/softmax_common.h"
#include "../../../impl/adv_api/detail/activation/softmax/softmax_flashv2_base_impl.h"
#endif
#pragma begin_pipe(V)
namespace AscendC {
* \ingroup SoftmaxFlashV2
* \brief compute tiling
* \param [out] softmaxTiling: the new tiling for SoftmaxFlashV2 api
* \param [in] shapeInfo: input src shape
* \param [in] dataTypeSize1: typesize of input src tensor data
* \param [in] dataTypeSize2: typesize of input max&&sum tensor data
* \param [in] localWorkSpaceSize: input local temporary Tensor size, byte unit
* \param [in] isUpdate: whether to update the sum and expmax, keep the same with SoftmaxFlashV2 api
* \param [in] isBasicBlock: keep the same with SoftmaxFlashV2 api
* \param [in] isDataFormatNZ: if the data format of input srcTensor is NZ, keep the same with SoftmaxFlashV2 api
*/
__aicore__ inline constexpr SoftMaxTiling SoftMaxFlashV2TilingFunc(const SoftMaxShapeInfo& shapeInfo,
const uint32_t dataTypeSize1, const uint32_t dataTypeSize2, const uint32_t localWorkSpaceSize,
const bool isUpdate = false, const bool isBasicBlock = false, const bool isDataFormatNZ = false, const bool isFlashOutputBrc = false)
{
#if defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201 || __NPU_ARCH__ == 2002 || __NPU_ARCH__ == 3510 || __NPU_ARCH__ == 3002 || __NPU_ARCH__ == 5102 || \
__NPU_ARCH__ == 3003 || __NPU_ARCH__ == 3113)
return SoftMaxFlashV2TilingFuncImpl(shapeInfo.srcM, shapeInfo.srcK, dataTypeSize1, dataTypeSize2,
localWorkSpaceSize / B32_BYTE_SIZE, isUpdate, isBasicBlock, isDataFormatNZ, isFlashOutputBrc);
#else
return {};
#endif
}
* \ingroup SoftmaxFlashV2
* \brief softmax calculation in FlashAttention-2 process: you can refet to the paper:
* <FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning>
* if isUpdate = false:
* max = rowmax(x), y = exp(x-max), sum = rowsum(y)
* if isUpdate = true:
* max = max(rowmax(x), inmax), y = exp(x-max), sum = rowsum(y)
* expmax = exp(inmax - max), sum = expmax * insum + sum
* \note support data type: half and float
* \param [out] dstTensor: output y
* \param [out] expSumTensor: output sum
* \param [out] maxTensor: output max
* \param [out] expMaxTensor: output expmax, only valid when isUpdate is true
* \param [in] srcTensor: input x
* \param [in] inExpSumTensor: input last insum
* \param [in] inMaxTensor: input last inmax
* \param [in] softmaxShapeInfo: input src shape
* \param [in] tiling: input softmaxtiling
* \param [in] isUpdate: whether to update the sum and expmax
* \param [in] isReuseSource: reserved param
* \param [in] isBasicBlock: if src shape[m,k] satisfy the condition(m%8 == 0 && k%64 == 0), you can set true to improve
* performance, but it is a reserved param when isDataFormatNZ = true
* \param [in] isDataFormatNZ: if the data format of input srcTensor is NZ
*/
template <typename T, bool isUpdate = false, bool isReuseSource = false, bool isBasicBlock = false,
bool isDataFormatNZ = false, const SoftmaxConfig& config = SOFTMAX_DEFAULT_CFG>
__aicore__ inline void SoftmaxFlashV2(const LocalTensor<T>& dstTensor, const LocalTensor<T>& expSumTensor,
const LocalTensor<T>& maxTensor, const LocalTensor<T>& srcTensor, const LocalTensor<T>& expMaxTensor,
const LocalTensor<T>& inExpSumTensor, const LocalTensor<T>& inMaxTensor, const SoftMaxTiling& tiling,
const SoftMaxShapeInfo& softmaxShapeInfo = {})
{
if ASCEND_IS_AIC {
return;
}
#if defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201 || __NPU_ARCH__ == 2002 || __NPU_ARCH__ == 3510 || __NPU_ARCH__ == 3002 || __NPU_ARCH__ == 5102 || \
__NPU_ARCH__ == 3003 || __NPU_ARCH__ == 3113)
TRACE_START(TraceId::SoftmaxFlashV2);
SoftmaxFlashV2Impl<T, T, isUpdate, isReuseSource, isBasicBlock, isDataFormatNZ, config>(dstTensor, expSumTensor,
maxTensor, srcTensor, expMaxTensor, inExpSumTensor, inMaxTensor, tiling, softmaxShapeInfo);
TRACE_STOP(TraceId::SoftmaxFlashV2);
#endif
}
* \ingroup SoftmaxFlashV2
* \brief softmax calculation in FlashAttention-2 process: you can refet to the paper:
* <FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning>
* if isUpdate = false:
* max = rowmax(x), y = exp(x-max), sum = rowsum(y)
* if isUpdate = true:
* max = max(rowmax(x), inmax), y = exp(x-max), sum = rowsum(y)
* expmax = exp(inmax - max), sum = expmax * insum + sum
* \param [out] dstTensor: output y with dtype of half
* \param [out] expSumTensor: output sum with dtype of float
* \param [out] maxTensor: output max with dtype of float
* \param [out] expMaxTensor: output expmax with dtype of half, only valid when isUpdate is true
* \param [in] srcTensor: input x with dtype of half
* \param [in] inExpSumTensor: input last insum with dtype of float
* \param [in] inMaxTensor: input last inmax with dtype of float
* \param [in] softmaxShapeInfo: input src shape
* \param [in] tiling: input softmaxtiling
* \param [in] isUpdate: whether to update the sum and expmax
* \param [in] isReuseSource: reserved param
* \param [in] isBasicBlock: if src shape[m,k] satisfy the condition(m%8 == 0 && k%64 == 0), you can set true to improve
* performance, but it is a reserved param when isDataFormatNZ = true
* \param [in] isDataFormatNZ: if the data format of input srcTensor is NZ
*/
template <typename T, bool isUpdate = false, bool isReuseSource = false, bool isBasicBlock = false,
bool isDataFormatNZ = false, const SoftmaxConfig& config = SOFTMAX_DEFAULT_CFG>
__aicore__ inline void SoftmaxFlashV2(const LocalTensor<half>& dstTensor, const LocalTensor<float>& expSumTensor,
const LocalTensor<float>& maxTensor, const LocalTensor<half>& srcTensor, const LocalTensor<half>& expMaxTensor,
const LocalTensor<float>& inExpSumTensor, const LocalTensor<float>& inMaxTensor, const SoftMaxTiling& tiling,
const SoftMaxShapeInfo& softmaxShapeInfo = {})
{
if ASCEND_IS_AIC {
return;
}
#if defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201 || __NPU_ARCH__ == 2002 || __NPU_ARCH__ == 3510 || __NPU_ARCH__ == 3002 || __NPU_ARCH__ == 5102 || \
__NPU_ARCH__ == 3003 || __NPU_ARCH__ == 3113)
TRACE_START(TraceId::SoftmaxFlashV2);
SoftmaxFlashV2Impl<half, float, isUpdate, isReuseSource, isBasicBlock, isDataFormatNZ, config>(dstTensor, expSumTensor,
maxTensor, srcTensor, expMaxTensor, inExpSumTensor, inMaxTensor, tiling, softmaxShapeInfo);
TRACE_STOP(TraceId::SoftmaxFlashV2);
#endif
}
* \ingroup SoftmaxFlashV2
* \brief softmax calculation in FlashAttention-2 process: you can refet to the paper:
* <FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning>
* if isUpdate = false:
* max = rowmax(x), y = exp(x-max), sum = rowsum(y)
* if isUpdate = true:
* max = max(rowmax(x), inmax), y = exp(x-max), sum = rowsum(y)
* expmax = exp(inmax - max), sum = expmax * insum + sum
* \note support data type: half and float
* \param [out] dstTensor: output y
* \param [out] expSumTensor: output sum
* \param [out] maxTensor: output max
* \param [out] expMaxTensor: output expmax, only valid when isUpdate is true
* \param [in] srcTensor: input x
* \param [in] inExpSumTensor: input last insum
* \param [in] inMaxTensor: input last inmax
* \param [in] sharedTmpBuffer: input local temporary Tensor, you can get the range by tilingfunc of
* GetSoftMaxFlashV2MinTmpSize/GetSoftMaxFlashV2MaxTmpSize
* \param [in] softmaxShapeInfo: input src shape
* \param [in] tiling: input softmaxtiling
* \param [in] isUpdate: whether to update the sum and expmax
* \param [in] isReuseSource: reserved param
* \param [in] isBasicBlock: if src shape[m,k] satisfy the condition(m%8 == 0 && k%64 == 0), you can set true to improve
* performance, but it is a reserved param when isDataFormatNZ = true
* \param [in] isDataFormatNZ: if the data format of input srcTensor is NZ
*/
template <typename T, bool isUpdate = false, bool isReuseSource = false, bool isBasicBlock = false,
bool isDataFormatNZ = false, const SoftmaxConfig& config = SOFTMAX_DEFAULT_CFG>
__aicore__ inline void SoftmaxFlashV2(const LocalTensor<T>& dstTensor, const LocalTensor<T>& expSumTensor,
const LocalTensor<T>& maxTensor, const LocalTensor<T>& srcTensor, const LocalTensor<T>& expMaxTensor,
const LocalTensor<T>& inExpSumTensor, const LocalTensor<T>& inMaxTensor,
const LocalTensor<uint8_t>& sharedTmpBuffer, const SoftMaxTiling& tiling,
const SoftMaxShapeInfo& softmaxShapeInfo = {})
{
if ASCEND_IS_AIC {
return;
}
#if defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201 || __NPU_ARCH__ == 2002 || __NPU_ARCH__ == 3510 || __NPU_ARCH__ == 3002 || __NPU_ARCH__ == 5102 || \
__NPU_ARCH__ == 3003 || __NPU_ARCH__ == 3113)
TRACE_START(TraceId::SoftmaxFlashV2);
SoftmaxFlashV2Impl<T, T, isUpdate, isReuseSource, isBasicBlock, isDataFormatNZ, config>(dstTensor, expSumTensor,
maxTensor, srcTensor, expMaxTensor, inExpSumTensor, inMaxTensor, sharedTmpBuffer, tiling, softmaxShapeInfo);
TRACE_STOP(TraceId::SoftmaxFlashV2);
#endif
}
* \ingroup SoftmaxFlashV2
* \brief softmax calculation in FlashAttention-2 process: you can refet to the paper:
* <FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning>
* if isUpdate = false:
* max = rowmax(x), y = exp(x-max), sum = rowsum(y)
* if isUpdate = true:
* max = max(rowmax(x), inmax), y = exp(x-max), sum = rowsum(y)
* expmax = exp(inmax - max), sum = expmax * insum + sum
* \param [out] dstTensor: output y with dtype of half
* \param [out] expSumTensor: output sum with dtype of float
* \param [out] maxTensor: output max with dtype of float
* \param [out] expMaxTensor: output expmax with dtype of half, only valid when isUpdate is true
* \param [in] srcTensor: input x with dtype of half
* \param [in] inExpSumTensor: input last insum with dtype of float
* \param [in] inMaxTensor: input last inmax with dtype of float
* \param [in] sharedTmpBuffer: input local temporary Tensor, you can get the range by tilingfunc of
* GetSoftMaxFlashV2MinTmpSize/GetSoftMaxFlashV2MaxTmpSize
* \param [in] softmaxShapeInfo: input src shape
* \param [in] tiling: input softmaxtiling
* \param [in] isUpdate: whether to update the sum and expmax
* \param [in] isReuseSource: reserved param
* \param [in] isBasicBlock: if src shape[m,k] satisfy the condition(m%8 == 0 && k%64 == 0), you can set true to improve
* performance, but it is a reserved param when isDataFormatNZ = true
* \param [in] isDataFormatNZ: if the data format of input srcTensor is NZ
*/
template <typename T, bool isUpdate = false, bool isReuseSource = false, bool isBasicBlock = false,
bool isDataFormatNZ = false, const SoftmaxConfig& config = SOFTMAX_DEFAULT_CFG>
__aicore__ inline void SoftmaxFlashV2(const LocalTensor<half>& dstTensor, const LocalTensor<float>& expSumTensor,
const LocalTensor<float>& maxTensor, const LocalTensor<half>& srcTensor, const LocalTensor<half>& expMaxTensor,
const LocalTensor<float>& inExpSumTensor, const LocalTensor<float>& inMaxTensor,
const LocalTensor<uint8_t>& sharedTmpBuffer, const SoftMaxTiling& tiling,
const SoftMaxShapeInfo& softmaxShapeInfo = {})
{
if ASCEND_IS_AIC {
return;
}
#if defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201 || __NPU_ARCH__ == 2002 || __NPU_ARCH__ == 3510 || __NPU_ARCH__ == 3002 || __NPU_ARCH__ == 5102 || \
__NPU_ARCH__ == 3003 || __NPU_ARCH__ == 3113)
TRACE_START(TraceId::SoftmaxFlashV2);
SoftmaxFlashV2Impl<half, float, isUpdate, isReuseSource, isBasicBlock, isDataFormatNZ, config>(dstTensor, expSumTensor,
maxTensor, srcTensor, expMaxTensor, inExpSumTensor, inMaxTensor, sharedTmpBuffer, tiling, softmaxShapeInfo);
TRACE_STOP(TraceId::SoftmaxFlashV2);
#endif
}
* \ingroup SoftmaxFlashV2
* \brief softmax calculation in FlashAttention-2 process: you can refet to the paper:
* <FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning>
* if isUpdate = false:
* max = rowmax(x), y = exp(x-max), sum = rowsum(y)
* if isUpdate = true:
* max = max(rowmax(x), inmax), y = exp(x-max), sum = rowsum(y)
* expmax = exp(inmax - max), sum = expmax * insum + sum
* \note support data type: half and float
* \param [out] dstTensor: output y
* \param [out] outReduceMax: output reduce max
* \param [out] outExpSum: output sum
* \param [out] outMax: output max
* \param [out] outExpMax: output expmax, only valid when isUpdate is true
* \param [in] srcTensor: input x
* \param [in] inExpSum: input last insum
* \param [in] inMax: input last inmax
* \param [in] softmaxShapeInfo: input src shape
* \param [in] tiling: input softmaxtiling
* \param [in] isUpdate: whether to update the sum and expmax
* \param [in] isReuseSource: reserved param
* \param [in] isBasicBlock: if src shape[m,k] satisfy the condition(m%8 == 0 && k%64 == 0), you can set true to improve
* performance, but it is a reserved param when isDataFormatNZ = true
* \param [in] isDataFormatNZ: if the data format of input srcTensor is NZ
*/
template <typename T, bool isUpdate = false, bool isReuseSource = false, bool isBasicBlock = false,
bool isDataFormatNZ = false, const SoftmaxConfig& config = SOFTMAX_DEFAULT_CFG>
__aicore__ inline void SoftmaxFlashV2(const LocalTensor<T>& dstTensor, const LocalTensor<T>& outReduceMax,
const LocalTensor<T>& outExpSum, const LocalTensor<T>& outMax, const LocalTensor<T>& srcTensor,
const LocalTensor<T>& outExpMax, const LocalTensor<T>& inExpSum, const LocalTensor<T>& inMax,
const SoftMaxTiling& tiling, const SoftMaxShapeInfo& softmaxShapeInfo = {})
{
if ASCEND_IS_AIC {
return;
}
#if defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201 || __NPU_ARCH__ == 2002 || __NPU_ARCH__ == 3510 || __NPU_ARCH__ == 3002 || __NPU_ARCH__ == 5102 || \
__NPU_ARCH__ == 3003 || __NPU_ARCH__ == 3113)
TRACE_START(TraceId::SoftmaxFlashV2);
SoftmaxFlashV2MaxImpl<T, T, isUpdate, isReuseSource, isBasicBlock, isDataFormatNZ, config>(dstTensor, outReduceMax,
outExpSum, outMax, srcTensor, outExpMax, inExpSum, inMax, tiling, softmaxShapeInfo);
TRACE_STOP(TraceId::SoftmaxFlashV2);
#endif
}
* \ingroup SoftmaxFlashV2
* \brief softmax calculation in FlashAttention-2 process: you can refet to the paper:
* <FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning>
* if isUpdate = false:
* max = rowmax(x), y = exp(x-max), sum = rowsum(y)
* if isUpdate = true:
* max = max(rowmax(x), inmax), y = exp(x-max), sum = rowsum(y)
* expmax = exp(inmax - max), sum = expmax * insum + sum
* \param [out] dstTensor: output y with dtype of half
* \param [out] outReduceMax: output reduce max
* \param [out] outExpSum: output sum with dtype of float
* \param [out] outMax: output max with dtype of float
* \param [out] outExpMax: output expmax with dtype of half, only valid when isUpdate is true
* \param [in] srcTensor: input x with dtype of half
* \param [in] inExpSum: input last insum with dtype of float
* \param [in] inMax: input last inmax with dtype of float
* \param [in] sharedTmpBuffer: input local temporary Tensor, you can get the range by tilingfunc of
* GetSoftMaxFlashV2MinTmpSize/GetSoftMaxFlashV2MaxTmpSize
* \param [in] softmaxShapeInfo: input src shape
* \param [in] tiling: input softmaxtiling
* \param [in] isUpdate: whether to update the sum and expmax
* \param [in] isReuseSource: reserved param
* \param [in] isBasicBlock: if src shape[m,k] satisfy the condition(m%8 == 0 && k%64 == 0), you can set true to improve
* performance, but it is a reserved param when isDataFormatNZ = true
* \param [in] isDataFormatNZ: if the data format of input srcTensor is NZ
*/
template <typename T, bool isUpdate = false, bool isReuseSource = false, bool isBasicBlock = false,
bool isDataFormatNZ = false, const SoftmaxConfig& config = SOFTMAX_DEFAULT_CFG>
__aicore__ inline void SoftmaxFlashV2(const LocalTensor<T>& dstTensor, const LocalTensor<T>& outReduceMax,
const LocalTensor<T>& outExpSum, const LocalTensor<T>& outMax, const LocalTensor<T>& srcTensor,
const LocalTensor<T>& outExpMax, const LocalTensor<T>& inExpSum, const LocalTensor<T>& inMax,
const LocalTensor<uint8_t>& sharedTmpBuffer, const SoftMaxTiling& tiling,
const SoftMaxShapeInfo& softmaxShapeInfo = {})
{
if ASCEND_IS_AIC {
return;
}
#if defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201 || __NPU_ARCH__ == 2002 || __NPU_ARCH__ == 3510 || __NPU_ARCH__ == 3002 || __NPU_ARCH__ == 5102 || \
__NPU_ARCH__ == 3003 || __NPU_ARCH__ == 3113)
TRACE_START(TraceId::SoftmaxFlashV2);
SoftmaxFlashV2MaxImpl<T, T, isUpdate, isReuseSource, isBasicBlock, isDataFormatNZ, config>(dstTensor, outReduceMax,
outExpSum, outMax, srcTensor, outExpMax, inExpSum, inMax, sharedTmpBuffer, tiling, softmaxShapeInfo);
TRACE_STOP(TraceId::SoftmaxFlashV2);
#endif
}
}
#pragma end_pipe
#endif
#if defined(__UNDEF_ASCENDC_INCLUDE_INTERNAL_HEADERS_SOFTMAXFLASHV2_H__)
#undef __ASCENDC_INCLUDE_INTERNAL_HEADERS__
#undef __UNDEF_ASCENDC_INCLUDE_INTERNAL_HEADERS_SOFTMAXFLASHV2_H__
#endif