/**
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
 * \file stpsv_kernel.cpp
 * \brief Single-precision triangular packed solve (STPSV) kernel: scalar path, used for n < 128.
 *
 * For n >= 128, stpsv_kernel_simt.cpp provides a SIMT (VF) parallel kernel;
 * stpsv_kernel_do() routes to it whenever tiling.numThreads > 0.
 */
#include <cstdint>
#include "cann_ops_blas_common.h"
#include "kernel_operator.h"
#include "stpsv_tiling_data.h"
#include "common/helper/kernel_constant.h"
#include "stpsv_kernel_utils.h"
using namespace AscendC;
/** Which triangle of the matrix is held in the packed array. */
enum class TpsvUplo {
    UPPER,
    LOWER
};

/** Whether the system is solved with A itself or with its transpose. */
enum class TpsvTrans {
    NO_TRANS,
    TRANS
};

/** Whether the stored diagonal is used (NON_UNIT) or taken to be all ones (UNIT). */
enum class TpsvDiag {
    UNIT,
    NON_UNIT
};
/**
 * Scalar solver for op(A) * x = b, where A is an n x n triangular matrix held
 * in packed storage (ap) and x holds b on entry and the solution on exit.
 *
 * UPLO selects which triangle ap stores, TRANS selects op(A) = A or A^T, and
 * DIAG selects whether the stored diagonal is divided by (NON_UNIT) or
 * skipped (UNIT). Rows are solved one at a time; each row value is staged
 * through a VECIN/VECOUT queue pair, while matrix entries and previously
 * solved x entries are read with scalar GlobalTensor accesses.
 */
template <TpsvUplo UPLO, TpsvTrans TRANS, TpsvDiag DIAG>
class StpsvKernel {
public:
    __aicore__ inline StpsvKernel() {}
    /** Bind the GM buffers described by the tiling data and allocate the staging queues. */
    __aicore__ inline void Init(StpsvTilingData tiling);
    /** Solve all n rows in dependency order (forward or backward substitution). */
    __aicore__ inline void Process();

private:
    __aicore__ inline void ProcessForward();
    __aicore__ inline void ProcessBackward();
    /** Diagonal entry A(row, row) read from packed storage. */
    __aicore__ inline float GetDiag(uint32_t row);
    /** Entry op(A)(row, col) read from packed storage (indices swapped for TRANS). */
    __aicore__ inline float GetElemOffDiag(uint32_t row, uint32_t col);
    /** GM element offset of logical x[idx]; 64-bit so large |incx| cannot truncate or overflow. */
    __aicore__ inline uint64_t XOffset(uint32_t idx);
    __aicore__ inline void CopyIn(uint32_t row);
    __aicore__ inline void Compute(uint32_t row);
    __aicore__ inline void CopyOut(uint32_t row);

    TPipe pipe;
    GlobalTensor<float> apGM;
    GlobalTensor<float> xGM;
    TQue<QuePosition::VECIN, BUFFER_NUM> xInQueue;
    TQue<QuePosition::VECOUT, BUFFER_NUM> xOutQueue;

    // True when op(A) is lower triangular (A lower and not transposed, or A
    // upper and transposed): rows are then solved in ascending order, each
    // depending only on earlier rows. Otherwise rows are solved descending.
    static constexpr bool kForward =
        (UPLO == TpsvUplo::LOWER && TRANS == TpsvTrans::NO_TRANS) ||
        (UPLO == TpsvUplo::UPPER && TRANS == TpsvTrans::TRANS);

    uint32_t n = 0;    // matrix order; overwritten by Init
    int64_t incx = 1;  // stride between logical x elements (may be negative); overwritten by Init
};

/**
 * Map logical vector index idx (0 <= idx < n) to its GM element offset.
 * For incx >= 0 element idx lives at idx * incx; for incx < 0 the vector is
 * traversed from the far end, so element idx lives at (n - 1 - idx) * |incx|.
 * The arithmetic is done in 64 bits and |incx| is formed without negating a
 * signed value, so INT64_MIN and strides >= 2^32 are handled without UB or
 * truncation.
 */
template <TpsvUplo UPLO, TpsvTrans TRANS, TpsvDiag DIAG>
__aicore__ inline uint64_t StpsvKernel<UPLO, TRANS, DIAG>::XOffset(uint32_t idx)
{
    if (incx >= 0) {
        return static_cast<uint64_t>(idx) * static_cast<uint64_t>(incx);
    }
    const uint64_t absIncx = 0 - static_cast<uint64_t>(incx);
    return static_cast<uint64_t>(n - 1 - idx) * absIncx;
}
/**
 * Read problem parameters from the tiling data, bind the packed matrix and the
 * x vector as GM tensors, and allocate one-float staging buffers for the
 * input/output queues.
 *
 * Buffer element counts are computed in 64 bits: n * (n + 1) and
 * |incx| * (n - 1) can exceed the uint32_t range, and |incx| is formed without
 * negating the signed value (which would be UB for INT64_MIN).
 */
template <TpsvUplo UPLO, TpsvTrans TRANS, TpsvDiag DIAG>
__aicore__ inline void StpsvKernel<UPLO, TRANS, DIAG>::Init(StpsvTilingData tiling)
{
    this->n = tiling.n;
    this->incx = tiling.incx;

    const uint64_t order = static_cast<uint64_t>(n);
    // Packed triangular storage holds n*(n+1)/2 elements.
    const uint64_t apCount = order * (order + 1) / 2;
    const uint64_t absIncx = (incx >= 0) ? static_cast<uint64_t>(incx)
                                         : (0 - static_cast<uint64_t>(incx));
    // Strided vector spans |incx|*(n-1)+1 elements; empty when n == 0.
    const uint64_t xCount = (n > 0) ? (absIncx * (order - 1) + 1) : 0;

    apGM.SetGlobalBuffer(reinterpret_cast<__gm__ float*>(tiling.ap), apCount);
    xGM.SetGlobalBuffer(reinterpret_cast<__gm__ float*>(tiling.x), xCount);
    pipe.InitBuffer(xInQueue, BUFFER_NUM, sizeof(float));
    pipe.InitBuffer(xOutQueue, BUFFER_NUM, sizeof(float));
}
/**
 * Fetch the diagonal entry A(row, row) from packed storage.
 * The diagonal occupies the same packed slot whether or not op(A) is
 * transposed, so only the stored triangle (UPLO) selects the index helper.
 */
template <TpsvUplo UPLO, TpsvTrans TRANS, TpsvDiag DIAG>
__aicore__ inline float StpsvKernel<UPLO, TRANS, DIAG>::GetDiag(uint32_t row)
{
    if constexpr (UPLO == TpsvUplo::UPPER) {
        return apGM.GetValue(TpsvPackedUpperIdx(row, row));
    } else {
        return apGM.GetValue(TpsvPackedLowerIdx(row, row, n));
    }
}
/**
 * Fetch element (row, col) of op(A) from packed storage.
 * For op(A) = A^T the requested element is stored as A(col, row), so the
 * indices are swapped before the packed offset is computed.
 */
template <TpsvUplo UPLO, TpsvTrans TRANS, TpsvDiag DIAG>
__aicore__ inline float StpsvKernel<UPLO, TRANS, DIAG>::GetElemOffDiag(uint32_t row, uint32_t col)
{
    uint32_t storedRow = row;
    uint32_t storedCol = col;
    if constexpr (TRANS == TpsvTrans::TRANS) {
        storedRow = col;
        storedCol = row;
    }

    if constexpr (UPLO == TpsvUplo::UPPER) {
        return apGM.GetValue(TpsvPackedUpperIdx(storedRow, storedCol));
    } else {
        return apGM.GetValue(TpsvPackedLowerIdx(storedRow, storedCol, n));
    }
}
/**
 * Stage the current right-hand-side value x[row] into slot 0 of a buffer
 * taken from the input queue, then hand it to Compute via EnQue.
 */
template <TpsvUplo UPLO, TpsvTrans TRANS, TpsvDiag DIAG>
__aicore__ inline void StpsvKernel<UPLO, TRANS, DIAG>::CopyIn(uint32_t row)
{
    LocalTensor<float> staged = xInQueue.AllocTensor<float>();
    const float rhs = xGM.GetValue(XOffset(row));
    staged.SetValue(0, rhs);
    xInQueue.EnQue<float>(staged);
}
/**
 * Solve for x[row]:
 *   x[row] = (b[row] - sum_j op(A)(row, j) * x[j]) / op(A)(row, row)
 * where j ranges over the rows already solved (j < row for forward
 * substitution, j > row for backward), and the division is skipped for a
 * UNIT diagonal.
 *
 * Queue ownership: the input tensor is returned to xInQueue (which allocated
 * it), and the result is written into a tensor allocated from xOutQueue, so
 * CopyOut's DeQue/FreeTensor on xOutQueue operate on a buffer that queue owns.
 * (Previously the xInQueue buffer was pushed into xOutQueue and then freed to
 * both queues.)
 */
template <TpsvUplo UPLO, TpsvTrans TRANS, TpsvDiag DIAG>
__aicore__ inline void StpsvKernel<UPLO, TRANS, DIAG>::Compute(uint32_t row)
{
    LocalTensor<float> xIn = xInQueue.DeQue<float>();
    float sum = xIn.GetValue(0);
    xInQueue.FreeTensor(xIn);

    // Already-solved neighbours: [0, row) going forward, (row, n) going backward.
    const uint32_t jBegin = kForward ? 0u : row + 1;
    const uint32_t jEnd = kForward ? row : n;
    for (uint32_t j = jBegin; j < jEnd; ++j) {
        sum -= GetElemOffDiag(row, j) * xGM.GetValue(XOffset(j));
    }
    if constexpr (DIAG == TpsvDiag::NON_UNIT) {
        sum = sum / GetDiag(row);
    }

    LocalTensor<float> xOut = xOutQueue.AllocTensor<float>();
    xOut.SetValue(0, sum);
    xOutQueue.EnQue<float>(xOut);
}
/**
 * Take the solved value from the output queue, store it back to x[row] in GM
 * (overwriting the right-hand side), and release the buffer to the output queue.
 */
template <TpsvUplo UPLO, TpsvTrans TRANS, TpsvDiag DIAG>
__aicore__ inline void StpsvKernel<UPLO, TRANS, DIAG>::CopyOut(uint32_t row)
{
    LocalTensor<float> result = xOutQueue.DeQue<float>();
    const float solved = result.GetValue(0);
    xGM.SetValue(XOffset(row), solved);
    xOutQueue.FreeTensor(result);
}
/**
 * Forward substitution: visit rows 0, 1, ..., n-1 so that every row only
 * reads entries of x that have already been solved.
 */
template <TpsvUplo UPLO, TpsvTrans TRANS, TpsvDiag DIAG>
__aicore__ inline void StpsvKernel<UPLO, TRANS, DIAG>::ProcessForward()
{
    uint32_t row = 0;
    while (row < n) {
        CopyIn(row);
        Compute(row);
        CopyOut(row);
        ++row;
    }
}
/**
 * Backward substitution: visit rows n-1, n-2, ..., 0 so that every row only
 * reads entries of x that have already been solved.
 */
template <TpsvUplo UPLO, TpsvTrans TRANS, TpsvDiag DIAG>
__aicore__ inline void StpsvKernel<UPLO, TRANS, DIAG>::ProcessBackward()
{
    for (uint32_t done = 0; done < n; ++done) {
        const uint32_t row = n - 1 - done;
        CopyIn(row);
        Compute(row);
        CopyOut(row);
    }
}
/**
 * Run the solve; the substitution direction is fixed at compile time by
 * kForward (i.e. by the UPLO/TRANS template arguments).
 */
template <TpsvUplo UPLO, TpsvTrans TRANS, TpsvDiag DIAG>
__aicore__ inline void StpsvKernel<UPLO, TRANS, DIAG>::Process()
{
    if constexpr (!kForward) {
        ProcessBackward();
    } else {
        ProcessForward();
    }
}
// Emits one AIV-only kernel entry point named KERNEL_NAME that instantiates
// StpsvKernel for the given (uplo, trans, diag) combination and runs it.
#define DEFINE_TPSV_KERNEL(UPLO_V, TRANS_V, DIAG_V, KERNEL_NAME)                      \
    __global__ __aicore__ void KERNEL_NAME(StpsvTilingData tiling)                    \
    {                                                                                \
        KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_AIV_ONLY);                              \
        StpsvKernel<TpsvUplo::UPLO_V, TpsvTrans::TRANS_V, TpsvDiag::DIAG_V> solver;  \
        solver.Init(tiling);                                                         \
        solver.Process();                                                            \
    }

// All eight specializations; stpsv_kernel_do selects among them at launch time.
DEFINE_TPSV_KERNEL(LOWER, NO_TRANS, NON_UNIT, stpsv_kernel_lower_no_trans_non_unit)
DEFINE_TPSV_KERNEL(LOWER, NO_TRANS, UNIT,     stpsv_kernel_lower_no_trans_unit)
DEFINE_TPSV_KERNEL(UPPER, NO_TRANS, NON_UNIT, stpsv_kernel_upper_no_trans_non_unit)
DEFINE_TPSV_KERNEL(UPPER, NO_TRANS, UNIT,     stpsv_kernel_upper_no_trans_unit)
DEFINE_TPSV_KERNEL(LOWER, TRANS,    NON_UNIT, stpsv_kernel_lower_trans_non_unit)
DEFINE_TPSV_KERNEL(LOWER, TRANS,    UNIT,     stpsv_kernel_lower_trans_unit)
DEFINE_TPSV_KERNEL(UPPER, TRANS,    NON_UNIT, stpsv_kernel_upper_trans_non_unit)
DEFINE_TPSV_KERNEL(UPPER, TRANS,    UNIT,     stpsv_kernel_upper_trans_unit)

#undef DEFINE_TPSV_KERNEL
void stpsv_simt_kernel_do(const StpsvTilingData &tiling, void* stream);

/**
 * Host-side launcher for STPSV.
 *
 * A positive tiling.numThreads routes to the SIMT implementation. Otherwise
 * one of the eight scalar kernels is launched on a single block:
 *   - uplo == ACLBLAS_LOWER selects the lower variants, anything else upper;
 *   - trans == ACLBLAS_OP_N selects no-transpose, anything else transpose;
 *   - diag == ACLBLAS_NON_UNIT selects non-unit, anything else unit.
 */
void stpsv_kernel_do(const StpsvTilingData &tiling, void* stream)
{
    if (tiling.numThreads > 0) {
        stpsv_simt_kernel_do(tiling, stream);
        return;
    }

    const bool isLower = (tiling.uplo == ACLBLAS_LOWER);
    const bool isNoTrans = (tiling.trans == ACLBLAS_OP_N);
    const bool isNonUnit = (tiling.diag == ACLBLAS_NON_UNIT);

    if (isLower && isNoTrans && isNonUnit) {
        stpsv_kernel_lower_no_trans_non_unit<<<1, nullptr, stream>>>(tiling);
    } else if (isLower && isNoTrans) {
        stpsv_kernel_lower_no_trans_unit<<<1, nullptr, stream>>>(tiling);
    } else if (isLower && isNonUnit) {
        stpsv_kernel_lower_trans_non_unit<<<1, nullptr, stream>>>(tiling);
    } else if (isLower) {
        stpsv_kernel_lower_trans_unit<<<1, nullptr, stream>>>(tiling);
    } else if (isNoTrans && isNonUnit) {
        stpsv_kernel_upper_no_trans_non_unit<<<1, nullptr, stream>>>(tiling);
    } else if (isNoTrans) {
        stpsv_kernel_upper_no_trans_unit<<<1, nullptr, stream>>>(tiling);
    } else if (isNonUnit) {
        stpsv_kernel_upper_trans_non_unit<<<1, nullptr, stream>>>(tiling);
    } else {
        stpsv_kernel_upper_trans_unit<<<1, nullptr, stream>>>(tiling);
    }
}