* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file getri_batched_kernel.cpp
* \brief Kernel-side implementation for batched single-precision matrix inversion (SIMT).
*/
#include <cstdint>
#include "cann_ops_blas_common.h"
#include "kernel_operator.h"
#include "simt_api/asc_simt.h"
#include "getri_batched_tiling_data.h"
using namespace AscendC;
constexpr uint32_t GETRI_SIMT_MAX_THREADS = 256;
static_assert(
(GETRI_SIMT_MAX_THREADS & (GETRI_SIMT_MAX_THREADS - 1)) == 0, "GETRI_SIMT_MAX_THREADS must be a power of 2");
constexpr uint32_t GETRI_REG_TILE_N = 8;
__simt_callee__ __aicore__ inline __gm__ float* ReadPtrFromArray(GM_ADDR arrayBase, uint32_t batchIdx)
{
__gm__ uint64_t* addrSlot = reinterpret_cast<__gm__ uint64_t*>(arrayBase) + batchIdx;
uint64_t rawAddr = *addrSlot;
return reinterpret_cast<__gm__ float*>(rawAddr);
}
__simt_callee__ __aicore__ inline void InitIdentity(__gm__ float* C, uint32_t ldc, uint32_t n)
{
for (uint32_t col = threadIdx.x; col < n; col += blockDim.x) {
for (uint32_t row = 0; row < n; row++) {
C[row + col * ldc] = (row == col) ? 1.0f : 0.0f;
}
}
asc_syncthreads();
}
__simt_callee__ __aicore__ inline void ApplyPermutation(__gm__ float* C, uint32_t ldc, uint32_t n, __gm__ int* piv)
{
if (n <= GETRI_REG_TILE_N) {
if (threadIdx.x < n) {
uint32_t col = threadIdx.x;
float regs[GETRI_REG_TILE_N] = {};
for (uint32_t i = 0; i < n; i++) {
regs[i] = C[i + col * ldc];
}
for (uint32_t k = 0; k < n; k++) {
int pivotRow = piv[k] - 1;
if (pivotRow >= 0 && pivotRow < static_cast<int>(n) && pivotRow != static_cast<int>(k)) {
uint32_t pr = static_cast<uint32_t>(pivotRow);
float tmp = regs[k];
regs[k] = regs[pr];
regs[pr] = tmp;
}
}
for (uint32_t i = 0; i < n; i++) {
C[i + col * ldc] = regs[i];
}
}
asc_syncthreads();
} else {
for (uint32_t k = 0; k < n; k++) {
int pivotRow = piv[k] - 1;
if (pivotRow >= 0 && pivotRow < static_cast<int>(n) && pivotRow != static_cast<int>(k)) {
for (uint32_t col = threadIdx.x; col < n; col += blockDim.x) {
float tmp = C[k + col * ldc];
C[k + col * ldc] = C[pivotRow + col * ldc];
C[pivotRow + col * ldc] = tmp;
}
}
asc_syncthreads();
}
}
}
__simt_callee__ __aicore__ inline void ForwardSolve(
__gm__ float* A, uint32_t lda, __gm__ float* C, uint32_t ldc, uint32_t n)
{
if (n <= GETRI_REG_TILE_N) {
if (threadIdx.x < n) {
uint32_t col = threadIdx.x;
float regs[GETRI_REG_TILE_N] = {};
for (uint32_t i = 0; i < n; i++) {
regs[i] = C[i + col * ldc];
}
for (uint32_t j = 0; j < n; j++) {
float cJC = regs[j];
for (uint32_t i = j + 1; i < n; i++) {
regs[i] -= A[i + j * lda] * cJC;
}
}
for (uint32_t i = 0; i < n; i++) {
C[i + col * ldc] = regs[i];
}
}
asc_syncthreads();
} else {
for (uint32_t j = 0; j < n; j++) {
for (uint32_t col = threadIdx.x; col < n; col += blockDim.x) {
float cJC = C[j + col * ldc];
asc_syncthreads();
for (uint32_t i = j + 1; i < n; i++) {
C[i + col * ldc] -= A[i + j * lda] * cJC;
}
}
asc_syncthreads();
}
}
}
__simt_callee__ __aicore__ inline void BackwardSolveReg(
__gm__ float* A, uint32_t lda, __gm__ float* C, uint32_t ldc, uint32_t n, __ubuf__ int* sharedInfoBuf)
{
float regs[GETRI_REG_TILE_N] = {};
if (threadIdx.x < n) {
uint32_t col = threadIdx.x;
for (uint32_t i = 0; i < n; i++) {
regs[i] = C[i + col * ldc];
}
}
for (int32_t j = static_cast<int32_t>(n) - 1; j >= 0; j--) {
uint32_t uj = static_cast<uint32_t>(j);
float ujj = A[uj + uj * lda];
if (ujj == 0.0f) {
sharedInfoBuf[0] = j + 1;
}
asc_syncthreads();
if (sharedInfoBuf[0] != 0)
break;
if (threadIdx.x < n) {
float cJC = regs[uj] / ujj;
regs[uj] = cJC;
for (uint32_t i = 0; i < uj; i++) {
regs[i] -= A[i + uj * lda] * cJC;
}
}
asc_syncthreads();
}
if (threadIdx.x < n && sharedInfoBuf[0] == 0) {
uint32_t col = threadIdx.x;
for (uint32_t i = 0; i < n; i++) {
C[i + col * ldc] = regs[i];
}
}
asc_syncthreads();
}
__simt_callee__ __aicore__ inline void BackwardSolveGeneral(
__gm__ float* A, uint32_t lda, __gm__ float* C, uint32_t ldc, uint32_t n, __ubuf__ int* sharedInfoBuf)
{
for (int32_t j = static_cast<int32_t>(n) - 1; j >= 0; j--) {
uint32_t uj = static_cast<uint32_t>(j);
float ujj = A[uj + uj * lda];
if (ujj == 0.0f) {
sharedInfoBuf[0] = j + 1;
}
asc_syncthreads();
if (sharedInfoBuf[0] != 0)
break;
for (uint32_t col = threadIdx.x; col < n; col += blockDim.x) {
float cJC = C[uj + col * ldc] / ujj;
C[uj + col * ldc] = cJC;
for (uint32_t i = 0; i < uj; i++) {
C[i + col * ldc] -= A[i + uj * lda] * cJC;
}
}
asc_syncthreads();
}
}
__simt_callee__ __aicore__ inline void BackwardSolve(
__gm__ float* A, uint32_t lda, __gm__ float* C, uint32_t ldc, uint32_t n, __ubuf__ int* sharedInfoBuf)
{
if (n <= GETRI_REG_TILE_N) {
BackwardSolveReg(A, lda, C, ldc, n, sharedInfoBuf);
} else {
BackwardSolveGeneral(A, lda, C, ldc, n, sharedInfoBuf);
}
}
__simt_vf__ __aicore__ LAUNCH_BOUND(GETRI_SIMT_MAX_THREADS) inline void GetriSimtPivot(
uint32_t n, uint32_t lda, uint32_t ldc, uint32_t numBatch, uint32_t startBatch, GM_ADDR aarrayBase,
GM_ADDR carrayBase, __gm__ int* pivotArray, __gm__ int* infoArray)
{
__ubuf__ int sharedInfoBuf[1];
for (uint32_t b = 0; b < numBatch; b++) {
uint32_t batchIdx = startBatch + b;
__gm__ float* A = ReadPtrFromArray(aarrayBase, batchIdx);
__gm__ float* C = ReadPtrFromArray(carrayBase, batchIdx);
__gm__ int* piv = pivotArray + batchIdx * n;
sharedInfoBuf[0] = 0;
asc_syncthreads();
InitIdentity(C, ldc, n);
ApplyPermutation(C, ldc, n, piv);
ForwardSolve(A, lda, C, ldc, n);
BackwardSolve(A, lda, C, ldc, n, sharedInfoBuf);
if (threadIdx.x == 0) {
infoArray[batchIdx] = sharedInfoBuf[0];
}
asc_syncthreads();
}
}
__simt_vf__ __aicore__ LAUNCH_BOUND(GETRI_SIMT_MAX_THREADS) inline void GetriSimtNoPivot(
uint32_t n, uint32_t lda, uint32_t ldc, uint32_t numBatch, uint32_t startBatch, GM_ADDR aarrayBase,
GM_ADDR carrayBase, __gm__ int* infoArray)
{
__ubuf__ int sharedInfoBuf[1];
for (uint32_t b = 0; b < numBatch; b++) {
uint32_t batchIdx = startBatch + b;
__gm__ float* A = ReadPtrFromArray(aarrayBase, batchIdx);
__gm__ float* C = ReadPtrFromArray(carrayBase, batchIdx);
sharedInfoBuf[0] = 0;
asc_syncthreads();
InitIdentity(C, ldc, n);
ForwardSolve(A, lda, C, ldc, n);
BackwardSolve(A, lda, C, ldc, n, sharedInfoBuf);
if (threadIdx.x == 0 && infoArray != nullptr) {
infoArray[batchIdx] = sharedInfoBuf[0];
}
asc_syncthreads();
}
}
extern "C" __global__ __aicore__ void sgetri_batched_kernel(
GM_ADDR aarray, GM_ADDR pivotArray, GM_ADDR carray, GM_ADDR infoArray, const SgetriBatchedTilingData tiling)
{
KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_AIV_ONLY);
if (tiling.usedCoreNum == 0) {
return;
}
uint32_t blockId = GetBlockIdx();
if (blockId >= tiling.usedCoreNum) {
return;
}
uint32_t startBatch = blockId * tiling.batchPerCore;
uint32_t numBatch = (blockId == tiling.usedCoreNum - 1) ? tiling.batchTail : tiling.batchPerCore;
if (numBatch == 0) {
return;
}
auto* pivots = reinterpret_cast<__gm__ int*>(pivotArray);
auto* infos = reinterpret_cast<__gm__ int*>(infoArray);
if (tiling.usePivot) {
asc_vf_call<GetriSimtPivot>(
dim3{GETRI_SIMT_MAX_THREADS, 1, 1}, tiling.n, tiling.lda, tiling.ldc, numBatch, startBatch, aarray, carray,
pivots, infos);
} else {
asc_vf_call<GetriSimtNoPivot>(
dim3{GETRI_SIMT_MAX_THREADS, 1, 1}, tiling.n, tiling.lda, tiling.ldc, numBatch, startBatch, aarray, carray,
infos);
}
}
void sgetri_batched_kernel_do(
GM_ADDR aarray, GM_ADDR pivotArray, GM_ADDR carray, GM_ADDR infoArray, const SgetriBatchedTilingData& tiling,
uint32_t numBlocks, void* stream)
{
sgetri_batched_kernel<<<numBlocks, nullptr, stream>>>(aarray, pivotArray, carray, infoArray, tiling);
}