/**
* This file is part of the OpenBOAT project at Harbin Institute of Technology (HIT)
* and is contributed to the CANN Open Software.
*
* Copyright (c) 2025 AISS Group, Harbin Institute of Technology (HIT).
* All Rights Reserved.
*
* Authors (accounts):
* - Liu Jun <@kbryantttt>
* - Su Tonghua <@sutonghua>
*
* This program is free software: you can redistribute it and/or modify it.
* Licensed under the CANN Open Software License Agreement Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* See the LICENSE file at the root of the repository for the full text of the License.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTIES OF ANY KIND, EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
*/
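/*!
* \file tril.h
* \brief Ascend C kernel class for the Tril operator: elements with (column - row) <= diagonal
*        are kept, all other elements are written as zero.
*/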
#ifndef __TRIL_H__
#define __TRIL_H__
#include "kernel_operator.h"
#include "kernel_tiling/kernel_tiling.h"
#include "tril_tiling_data.h"
#include "tril_tiling_key.h"
namespace NsTril {
using namespace AscendC;
// Buffer count used as the depth template argument of every TQue/TQueBind in Tril
// and as the buffer-number argument of each pipe.InitBuffer call.
constexpr int32_t BUFFER_NUM = 2;
// Fallback stored in Tril::columnLength when the tiling data reports zero columns.
constexpr int32_t minNum = 1;
// Values of the `key` argument that Tril::Process() compares against to pick a path.
constexpr int keyOne = 1;   // Process() runs NaivePath()
constexpr int keyTwo = 2;   // Process() runs SheerDup()
constexpr int keyThree = 3; // Process() runs SheerZero()
constexpr int keyFour = 4;  // Process() runs FastPath()
// Divided by the element size in Init() to derive Tril::repeatTimes.
// Presumably the number of bytes one vector repeat handles -- TODO confirm.
constexpr int computeBatchSize = 256;
// Kernel-side implementation of the Tril operator.
//
// Usage: construct, call Init() once with the input/output global-memory addresses, the
// tiling data and the tiling key, then call Process().
//
// NOTE(review): the template parameter T is not referenced by any member in this file;
// all tensors are typed with DTYPE_X instead -- confirm this is intentional.
template <typename T>
class Tril {
public:
__aicore__ inline Tril(){};
// Copies the tiling fields into members, binds the global tensors and allocates the
// queue buffers that the path selected by `key` uses.
__aicore__ inline void Init(GM_ADDR x, GM_ADDR y, const TrilTilingData* tilingData, uint32_t key);
// Runs the path selected by the key stored in Init(); other key values do nothing.
__aicore__ inline void Process();
private:
// Tile-by-tile copy of x to y through queBind (key == keyTwo).
__aicore__ inline void SheerDup();
// Tile-by-tile CopyIn / AllZero / CopyOut (key == keyThree).
__aicore__ inline void SheerZero();
// Element-by-element scalar path using GetValue/SetValue on global memory (key == keyOne).
__aicore__ inline void NaivePath();
// Per-matrix tiled path built on CopyIn / Compute / CopyOut (key == keyFour).
__aicore__ inline void FastPath();
// Consumes one input tile from inQueueX and enqueues an output tile of tileLength zeros.
__aicore__ inline void AllZero(uint32_t tileLength);
// Copies tileLength elements from xGm[GmOffset] into a tensor enqueued on inQueueX.
__aicore__ inline void CopyIn(uint32_t GmOffset, uint32_t tileLength);
// Copies tileLength elements from the next outQueueY tensor to yGm[GmOffset].
__aicore__ inline void CopyOut(uint32_t GmOffset, uint32_t tileLength);
// Copies one input tile to the output tile, then zeroes a leading segment of `adjust` rows.
__aicore__ inline void Compute(int32_t cnt, uint32_t initLength, int32_t adjust);
private:
// Pipe on which all queue buffers are initialised.
AscendC::TPipe pipe;
// Combined in/out queue; initialised in Init() for keys other than keyThree/keyFour
// and used by SheerDup().
AscendC::TQueBind<AscendC::QuePosition::VECIN, AscendC::QuePosition::VECOUT, BUFFER_NUM> queBind;
// Input / output queues; initialised in Init() only for keyThree and keyFour.
AscendC::TQue<AscendC::QuePosition::VECIN, BUFFER_NUM> inQueueX;
AscendC::TQue<AscendC::QuePosition::VECOUT, BUFFER_NUM> outQueueY;
// Global-memory views of the input and the output.
AscendC::GlobalTensor<DTYPE_X> xGm;
AscendC::GlobalTensor<DTYPE_X> yGm;
// NOTE(review): never assigned or read in this file (Init() uses the tiling field directly).
uint32_t totalLengthAligned;
// Number of matrices iterated by NaivePath() and FastPath().
int32_t matrixNum;
// Copied from the tiling data; not read elsewhere in this file.
int32_t matrixSize;
// Rows per matrix (inner loop bound in NaivePath()).
int32_t rowLength;
// Columns per matrix; also the row stride used inside Compute().
int32_t columnLength;
// Diagonal offset: NaivePath() keeps elements with (column - row) <= diagVal.
int32_t diagVal;
// Tile count: the tiled paths run (loopCnt - 1) full tiles followed by one last tile.
int32_t loopCnt;
// Element counts of a full tile and of the last tile.
uint32_t fullTileLength;
uint32_t lastTileLength;
// `cnt` arguments passed to Compute() for a full tile and for the last tile.
int32_t fullCnt;
int32_t lastCnt;
// NOTE(review): never assigned or read in this file.
int32_t alignNum;
// Element size in bytes; Init() falls back to sizeof(float) when the tiling value is 0.
uint32_t typeSize;
// fullTileLength divided by the column count, computed in Init().
uint32_t fullRowInc;
// Starting segment length handed to Compute() at the beginning of each matrix in FastPath().
uint32_t initLength;
// columnLength / (computeBatchSize / typeSize); multiplied by `cnt` in Compute().
int32_t repeatTimes;
// Tiling key selecting the path in Process().
uint32_t key;
};
/*!
 * \brief Stores the tiling parameters, binds the global tensors and allocates queue buffers.
 *
 * \param x          Global-memory address of the input.
 * \param y          Global-memory address of the output.
 * \param tilingData Tiling parameters produced on the host side.
 * \param key        Tiling key; keyThree/keyFour allocate inQueueX and outQueueY, any other
 *                   value allocates queBind.
 */
template <typename T>
__aicore__ inline void Tril<T>::Init(GM_ADDR x, GM_ADDR y, const TrilTilingData* tilingData, uint32_t key)
{
    this->matrixNum = tilingData->matrixNum;
    this->matrixSize = tilingData->matrixSize;
    this->rowLength = tilingData->rowLength;
    this->columnLength = tilingData->columnLength;
    this->diagVal = tilingData->diagVal;
    this->fullCnt = tilingData->fullCnt;
    this->lastCnt = tilingData->lastCnt;

    // Guard against a zero column count so the division below is always defined.
    if (this->columnLength == 0) {
        this->columnLength = minNum;
    }
    // Divide by the guarded member, not by the raw tiling field: the raw field is still
    // zero in exactly the case the guard above is meant to protect.
    this->fullRowInc = tilingData->fullTileLength / this->columnLength;
    this->initLength = 1;

    // A zero element size would make the repeat computation divide by zero; fall back to
    // the size of float.
    this->typeSize = tilingData->typeSize;
    if (this->typeSize == 0) {
        this->typeSize = sizeof(float);
    }
    this->repeatTimes = columnLength / (computeBatchSize / this->typeSize);
    this->key = key;

    uint64_t gmBuffer = tilingData->totalLengthAligned;
    xGm.SetGlobalBuffer((__gm__ DTYPE_X *)x, gmBuffer);
    yGm.SetGlobalBuffer((__gm__ DTYPE_X *)y, gmBuffer);

    this->loopCnt = tilingData->loopCnt;
    this->fullTileLength = tilingData->fullTileLength;
    this->lastTileLength = tilingData->lastTileLength;

    // Each buffer must hold the larger of a full tile and the last tile.
    uint32_t singleBuffer = tilingData->fullTileLength;
    if (singleBuffer < tilingData->lastTileLength) {
        singleBuffer = tilingData->lastTileLength;
    }

    if (key == keyThree || key == keyFour) {
        // SheerZero / FastPath use separate input and output queues.
        pipe.InitBuffer(inQueueX, BUFFER_NUM, singleBuffer * this->typeSize);
        pipe.InitBuffer(outQueueY, BUFFER_NUM, singleBuffer * this->typeSize);
    } else {
        // The remaining paths only initialise the bound in/out queue (used by SheerDup).
        pipe.InitBuffer(queBind, BUFFER_NUM, singleBuffer * this->typeSize);
    }
}
/*!
 * \brief Copies the input to the output tile by tile through queBind.
 *
 * Runs (loopCnt - 1) tiles of fullTileLength elements followed by exactly one tile of
 * lastTileLength elements; the last tile is always processed, even when loopCnt <= 1.
 */
template <typename T>
__aicore__ inline void Tril<T>::SheerDup()
{
    const int32_t fullTiles = this->loopCnt - 1;
    uint32_t offset = 0;
    int32_t tile = 0;
    do {
        // Every tile before index fullTiles is a full tile; the final pass is the last tile.
        const bool isLast = (tile >= fullTiles);
        const uint32_t length = isLast ? this->lastTileLength : this->fullTileLength;

        AscendC::LocalTensor<DTYPE_X> staged = queBind.AllocTensor<DTYPE_X>();
        AscendC::DataCopy(staged, xGm[offset], length);
        queBind.EnQue(staged);

        staged = queBind.DeQue<DTYPE_X>();
        AscendC::DataCopy(yGm[offset], staged, length);
        queBind.FreeTensor(staged);

        if (!isLast) {
            offset += this->fullTileLength;
        }
        ++tile;
    } while (tile <= fullTiles);
}
/*!
 * \brief Produces an all-zero output tile by tile via CopyIn / AllZero / CopyOut.
 *
 * Runs (loopCnt - 1) tiles of fullTileLength elements followed by exactly one tile of
 * lastTileLength elements; the last tile is always processed.
 */
template <typename T>
__aicore__ inline void Tril<T>::SheerZero()
{
    const int32_t fullTiles = this->loopCnt - 1;
    uint32_t offset = 0;
    int32_t tile = 0;
    do {
        const bool isLast = (tile >= fullTiles);
        const uint32_t length = isLast ? this->lastTileLength : this->fullTileLength;

        CopyIn(offset, length);
        AllZero(length);
        CopyOut(offset, length);

        if (!isLast) {
            offset += this->fullTileLength;
        }
        ++tile;
    } while (tile <= fullTiles);
}
/*!
 * \brief Scalar fallback: walks every element of every matrix in row-major order.
 *
 * An element at (row, col) is copied from the input when (col - row) <= diagVal and is
 * written as zero otherwise. Because (col - row) only grows along a row, once the copy
 * condition fails it stays false for the rest of that row.
 */
template <typename T>
__aicore__ inline void Tril<T>::NaivePath()
{
    int32_t flatIdx = 0;
    for (int32_t mat = 0; mat < this->matrixNum; ++mat) {
        for (int32_t row = 0; row < this->rowLength; ++row) {
            for (int32_t col = 0; col < this->columnLength; ++col, ++flatIdx) {
                if (col - row <= this->diagVal) {
                    DTYPE_X kept = xGm.GetValue(flatIdx);
                    yGm.SetValue(flatIdx, kept);
                } else {
                    yGm.SetValue(flatIdx, (DTYPE_X)0);
                }
            }
        }
    }
}
/*!
 * \brief Tiled path: for each matrix runs (loopCnt - 1) full tiles and one last tile
 *        through CopyIn / Compute / CopyOut.
 *
 * Two running values are passed to Compute() and updated after every full tile:
 *  - rowAdjust: set to (1 - diagVal) at the start of each matrix when diagVal <= 0 and
 *    reduced by fullRowInc per full tile while positive, clamped at 0; it is not reset
 *    between matrices when diagVal > 0.
 *  - segLength: restarts at initLength for every matrix; grows by fullRowInc per full tile
 *    once rowAdjust is no longer positive, or by the overshoot when rowAdjust crosses zero.
 */
template <typename T>
__aicore__ inline void Tril<T>::FastPath()
{
    const int32_t fullTiles = this->loopCnt - 1;
    uint32_t gmPos = 0;
    int32_t rowAdjust = 0;

    for (int32_t mat = 0; mat < this->matrixNum; ++mat) {
        uint32_t segLength = this->initLength;
        if (this->diagVal <= 0) {
            rowAdjust = 1 - this->diagVal;
        }

        for (int32_t tile = 0; tile < fullTiles; ++tile) {
            CopyIn(gmPos, this->fullTileLength);
            Compute(this->fullCnt, segLength, rowAdjust);
            CopyOut(gmPos, this->fullTileLength);
            gmPos += this->fullTileLength;

            if (rowAdjust <= 0) {
                segLength += this->fullRowInc;
                continue;
            }
            rowAdjust -= this->fullRowInc;
            if (rowAdjust < 0) {
                // rowAdjust is negative here, so this increases segLength by the overshoot.
                segLength -= rowAdjust;
                rowAdjust = 0;
            }
        }

        CopyIn(gmPos, this->lastTileLength);
        Compute(this->lastCnt, segLength, rowAdjust);
        CopyOut(gmPos, this->lastTileLength);
        gmPos += this->lastTileLength;
    }
}
/*!
 * \brief Consumes one input tile and enqueues an output tile filled with zeros.
 *
 * \param tileLength Number of elements to write as zero.
 *
 * The zeros are produced by broadcasting a zero scalar rather than by computing x - x:
 * for floating-point data x - x evaluates to NaN wherever the input holds NaN or Inf,
 * which would leak non-zero values into the output.
 */
template <typename T>
__aicore__ inline void Tril<T>::AllZero(uint32_t tileLength)
{
    // The input tile is still dequeued and released so that the CopyIn() enqueue is
    // balanced, even though its contents are not read.
    auto xLocal = inQueueX.DeQue<DTYPE_X>();
    auto yLocal = outQueueY.AllocTensor<DTYPE_X>();
    DTYPE_X scalarZero = 0;
    AscendC::Duplicate(yLocal, scalarZero, static_cast<int32_t>(tileLength));
    outQueueY.EnQue(yLocal);
    inQueueX.FreeTensor(xLocal);
}
/*!
 * \brief Loads one tile from the input global tensor and enqueues it on inQueueX.
 *
 * \param GmOffset   Element offset into xGm at which the tile starts.
 * \param tileLength Number of elements to copy.
 */
template <typename T>
__aicore__ inline void Tril<T>::CopyIn(uint32_t GmOffset, uint32_t tileLength)
{
    AscendC::LocalTensor<DTYPE_X> inTile = inQueueX.AllocTensor<DTYPE_X>();
    AscendC::DataCopy(inTile, xGm[GmOffset], tileLength);
    inQueueX.EnQue(inTile);
}
/*!
 * \brief Dequeues one tile from outQueueY, stores it to the output global tensor and
 *        releases the local tensor.
 *
 * \param GmOffset   Element offset into yGm at which the tile is written.
 * \param tileLength Number of elements to copy.
 */
template <typename T>
__aicore__ inline void Tril<T>::CopyOut(uint32_t GmOffset, uint32_t tileLength)
{
    AscendC::LocalTensor<DTYPE_X> outTile = outQueueY.DeQue<DTYPE_X>();
    AscendC::DataCopy(yGm[GmOffset], outTile, tileLength);
    outQueueY.FreeTensor(outTile);
}
/*!
 * \brief Copies one input tile to an output tile, then overwrites a leading segment of the
 *        first `adjust` rows with zeros.
 *
 * \param cnt        Multiplied by repeatTimes to give the repeat count of the tile copy.
 * \param initLength Length of the zeroed segment in the first row; it shrinks by one
 *                   element for each following row.
 * \param adjust     Number of rows (stride columnLength) that get a zeroed segment.
 *
 * The zeroed segments are written with a broadcast zero scalar instead of x - x, because
 * x - x evaluates to NaN for floating-point inputs that contain NaN or Inf.
 */
template <typename T>
__aicore__ inline void Tril<T>::Compute(int32_t cnt, uint32_t initLength, int32_t adjust)
{
    auto xLocal = inQueueX.DeQue<DTYPE_X>();
    auto yLocal = outQueueY.AllocTensor<DTYPE_X>();
    uint32_t localOffset = 0;
    uint32_t currLength = initLength;
    DTYPE_X scalarZero = 0;
    // Adding zero under an all-ones mask copies the tile from xLocal to yLocal.
    // NOTE(review): confirm that an all-ones upper mask word is valid for 32-bit element
    // types and that repeatTimes * cnt always fits the repeat-count parameter of Adds.
    uint64_t mask[2] = {UINT64_MAX, UINT64_MAX};
    AscendC::Adds(yLocal, xLocal, scalarZero, mask, this->repeatTimes * cnt, {1, 1, 8, 8});
    for (int32_t i = 0; i < adjust; i++)
    {
        // Write exact zeros over the segment starting at the beginning of row i.
        AscendC::Duplicate(yLocal[localOffset], scalarZero, static_cast<int32_t>(currLength));
        currLength--;
        localOffset += this->columnLength;
    }
    outQueueY.EnQue(yLocal);
    inQueueX.FreeTensor(xLocal);
}
/*!
 * \brief Dispatches to the execution path matching the tiling key stored by Init().
 *
 * keyOne -> NaivePath, keyTwo -> SheerDup, keyThree -> SheerZero, keyFour -> FastPath.
 * Any other key value performs no work.
 */
template <typename T>
__aicore__ inline void Tril<T>::Process()
{
    switch (this->key) {
        case keyOne:
            NaivePath();
            break;
        case keyTwo:
            SheerDup();
            break;
        case keyThree:
            SheerZero();
            break;
        case keyFour:
            FastPath();
            break;
        default:
            break;
    }
}
}
#endif