/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
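/*!
* \file attention_post.cpp
* \brief Attention post stage: a batched matmul of the input with weightUV followed by a matmul with weightO,
*        each with an optional quantized path, computed tile by tile and assembled into the output tensor.
*/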
#include "attention_post.h"
#include "interface/function/function.h"
#include "tilefwk/tilefwk_op.h"
#include "tilefwk/tensor.h"
#include "interface/tensor/logical_tensor.h"
#include "interface/utils/common.h"
#include "interface/configs/config_manager.h"
namespace npu::tile_fwk {
/**
 * \brief Emits the attention "post" computation for every (batch, sequence) tile of \p input and assembles
 *        each tile's result into \p postOut.
 *
 * Per tile the function:
 *   1. views a {tileB, tileS, n, kvLoraRank} window of \p input and reshapes it to {tileBS, n, kvLoraRank};
 *   2. transposes axes {0, 1} and runs a batched matmul against postTensors.weightUV;
 *   3. transposes axes {0, 1} again, reshapes to {tileBS, n * vHeadDim} and runs a matmul against
 *      postTensors.weightO;
 *   4. reshapes to {tileB, tileS, h} and assembles it into \p postOut at offset {bOffset, sOffset, 0}.
 *
 * Either matmul switches to a quantized path (Quant -> INT32 matmul -> FP32 dequant -> cast back to the input
 * dtype) when the matching weight scale tensor has storage; the matching smooth-scale tensor is forwarded to
 * Quant only when it has storage as well.
 *
 * \param input        Rank-4 tensor (asserted). Dims 0 and 1 are tiled by tileB / tileS. The views below assume
 *                     dims 2 and 3 equal weightUV dims 0 and 1 -- NOTE(review): not checked here, confirm
 *                     against callers.
 * \param postTensors  weightUV (rank 3, asserted), weightO (rank 2, asserted) and the optional
 *                     weightUvScale / smoothScalesWUv / weightOScale / smoothScalesWo tensors.
 * \param tileConfig   Supplies tileB and tileS. Both must be positive and must evenly divide input dims 0 and 1.
 * \param postOut      Destination tensor that receives each {tileB, tileS, h} tile.
 *
 * \note All precondition checks use assert and are therefore compiled out when NDEBUG is defined.
 */
void PostCompute(Tensor& input, PostTensors& postTensors, const PostTileConfig& tileConfig, Tensor& postOut)
{
    assert(
        input.GetShape().size() == SHAPE_DIM4 && postTensors.weightUV.GetShape().size() == SHAPE_DIM3 &&
        postTensors.weightO.GetShape().size() == SHAPE_DIM2);
    auto dtype = input.GetStorage()->Datatype();
    auto n = postTensors.weightUV.GetShape()[0];
    auto kvLoraRank = postTensors.weightUV.GetShape()[1];
    auto vHeadDim = postTensors.weightUV.GetShape()[2];
    auto h = postTensors.weightO.GetShape()[1];
    int tileB = tileConfig.tileB;
    int tileS = tileConfig.tileS;
    // The tile sizes are used as divisors below; zero would be undefined behaviour.
    assert(tileB > 0 && tileS > 0 && "PostCompute: tileB and tileS must be positive");
    int tileBS = tileB * tileS;
    // A scale / smooth tensor counts as "provided" when it has backing storage.
    bool isQuantWUv = postTensors.weightUvScale.GetStorage() != nullptr;
    bool isSmoothWUv = postTensors.smoothScalesWUv.GetStorage() != nullptr;
    bool isQuantWo = postTensors.weightOScale.GetStorage() != nullptr;
    bool isSmoothWo = postTensors.smoothScalesWo.GetStorage() != nullptr;
    int b = input.GetShape()[0];
    int s = input.GetShape()[1];
    // The loop counts are floor divisions: without exact divisibility the trailing rows (or, when the tile is
    // larger than the dimension, every row) would be skipped and never assembled into postOut.
    assert(b % tileB == 0 && s % tileS == 0 && "PostCompute: tileB/tileS must evenly divide input dims 0/1");
    SymbolicScalar bLoop = b / tileB;
    SymbolicScalar sLoop = s / tileS;
    LOOP("POST_LOOP_L0_bIdx", FunctionType::DYNAMIC_LOOP, bIdx, LoopRange(0, bLoop, 1), {}, true)
    {
        SymbolicScalar bOffset = bIdx * tileB;
        LOOP("POST_LOOP_L1_sIdx", FunctionType::DYNAMIC_LOOP, sIdx, LoopRange(0, sLoop, 1))
        {
            SymbolicScalar sOffset = sIdx * tileS;
            // Destination offset of this tile inside postOut.
            std::vector<SymbolicScalar> outOffset = {bOffset, sOffset, 0};

            // Stage 1: cut the current tile out of the input and merge the batch/sequence axes.
            TileShape::Current().SetVecTile({1, 1, 32, kvLoraRank});
            auto inputView = View(input, {tileB, tileS, n, kvLoraRank}, {bOffset, sOffset, 0, 0});
            config::SetSemanticLabel("postReshape1");
            auto inputRes = Reshape(inputView, {tileBS, n, kvLoraRank});

            // Stage 2: transpose axes {0, 1} so that n leads, matching weightUV's leading dimension.
            TileShape::Current().SetVecTile({std::min(32, tileBS), 2, kvLoraRank});
            config::SetSemanticLabel("postTranspose1");
            auto inputTrans = Transpose(inputRes, {0, 1});

            // Stage 3: batched matmul with weightUV.
            config::SetSemanticLabel("postBmm");
            // m is min(32, tileBS) rounded up to a multiple of c0.
            // NOTE(review): c0 = 16 is presumably the cube unit's block granularity -- confirm.
            int c0 = 16;
            int m = (std::min(32, tileBS) + c0 - 1) / c0 * c0;
            TileShape::Current().SetCubeTile(
                {m, m}, {std::min(256L, kvLoraRank), std::min(512L, kvLoraRank)}, {vHeadDim, vHeadDim});
            Tensor bmm;
            if (isQuantWUv) {
                config::SetSemanticLabel("postQuantWUv");
                TileShape::Current().SetVecTile({1, 1, std::min(512L, kvLoraRank)});
                std::tuple<Tensor, Tensor> quantRes;
                if (isSmoothWUv) {
                    quantRes = Quant(inputTrans, true, true, postTensors.smoothScalesWUv);
                } else {
                    quantRes = Quant(inputTrans, true, false);
                }
                // Element 0 feeds the INT32 matmul; element 1 is multiplied back in during dequantization.
                auto inputTransQuant = std::get<0>(quantRes);
                auto scaleDequant = std::get<1>(quantRes);
                auto mm = Matrix::BatchMatmul(DT_INT32, inputTransQuant, postTensors.weightUV);
                config::SetSemanticLabel("postDequantWUv");
                TileShape::Current().SetVecTile({1, std::min(16, tileBS), std::min(32L, vHeadDim)});
                // Dequantize in FP32: apply the activation scale, then the weight scale, then cast back.
                Tensor res = Cast(mm, DataType::DT_FP32);
                res = Mul(res, scaleDequant);
                res = Mul(res, postTensors.weightUvScale);
                bmm = Cast(res, dtype, CAST_RINT);
            } else {
                bmm = Matrix::BatchMatmul(dtype, inputTrans, postTensors.weightUV);
            }

            // Stage 4: transpose axes {0, 1} back and merge the last two axes into n * vHeadDim.
            config::SetSemanticLabel("postTranspose2");
            TileShape::Current().SetVecTile({4, std::min(32, tileBS), vHeadDim});
            auto bmmTrans = Transpose(bmm, {0, 1});
            config::SetSemanticLabel("postReshape2");
            auto bmmRes = Reshape(bmmTrans, {tileBS, n * vHeadDim});

            // Stage 5: matmul with weightO.
            Tensor mmRes;
            TileShape::Current().SetCubeTile(
                {m, m}, {std::min(512L, n * vHeadDim), std::min(512L, n * vHeadDim)},
                {std::min(64L, h), std::min(64L, h)});
            if (isQuantWo) {
                config::SetSemanticLabel("postQuantWo");
                TileShape::Current().SetVecTile({1, n * vHeadDim});
                std::tuple<Tensor, Tensor> quantRes;
                if (isSmoothWo) {
                    quantRes = Quant(bmmRes, true, true, postTensors.smoothScalesWo);
                } else {
                    quantRes = Quant(bmmRes, true, false);
                }
                auto bmmResQuant = std::get<0>(quantRes);
                auto scaleDequant = std::get<1>(quantRes);
                config::SetSemanticLabel("postMm");
                Tensor mm = Matrix::Matmul(DataType::DT_INT32, bmmResQuant, postTensors.weightO);
                config::SetSemanticLabel("postDequantWo");
                TileShape::Current().SetVecTile({std::min(32, tileBS), std::min(32L, h)});
                // Same dequantization sequence as for the weightUV path.
                Tensor res = Cast(mm, DataType::DT_FP32);
                res = Mul(res, scaleDequant);
                res = Mul(res, postTensors.weightOScale);
                mmRes = Cast(res, dtype, CAST_RINT);
            } else {
                mmRes = Matrix::Matmul(dtype, bmmRes, postTensors.weightO);
            }

            // Stage 6: split the merged axis back into {tileB, tileS} and write the tile into postOut.
            config::SetSemanticLabel("postReshape3");
            auto postOutView = Reshape(mmRes, {tileB, tileS, h});
            TileShape::Current().SetVecTile({1, 1, h});
            Assemble(postOutView, outOffset, postOut);
        }
    }
}
/**
 * \brief Standalone entry point: wraps PostCompute in a FUNCTION scope named "POST_MAIN".
 *
 * The first brace list passed to FUNCTION holds \p input followed by the six tensors of \p postTensors
 * (weightUV, weightO, weightUvScale, smoothScalesWUv, weightOScale, smoothScalesWo); the second holds
 * \p postOut. NOTE(review): these are presumably the function's declared inputs and outputs -- confirm
 * against the FUNCTION macro definition.
 *
 * \param input        Forwarded unchanged to PostCompute.
 * \param postTensors  Forwarded unchanged to PostCompute.
 * \param tileConfig   Forwarded unchanged to PostCompute.
 * \param postOut      Forwarded unchanged to PostCompute.
 */
void AttentionPostStandalone(Tensor& input, PostTensors& postTensors, const PostTileConfig& tileConfig, Tensor& postOut)
{
FUNCTION(
"POST_MAIN",
{input, postTensors.weightUV, postTensors.weightO, postTensors.weightUvScale, postTensors.smoothScalesWUv,
postTensors.weightOScale, postTensors.smoothScalesWo},
{postOut})
{
// All of the work happens in PostCompute; this wrapper only supplies the enclosing FUNCTION scope.
PostCompute(input, postTensors, tileConfig, postOut);
}
}
}