* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
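/*!
 * \file fa_def.cpp
 * \brief Defines FlashAttentionNew, a tile-by-tile attention computation built from tile_fwk tensor operations.
 */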
#include "operator/models/llama/llama_def.h"
namespace npu::tile_fwk {
// File-local named constants for the integer literals 64 and 128.
static constexpr int NUM_64{64};
static constexpr int NUM_128{128};
/**
 * \brief Builds a tiled attention computation over q, k and v using a running (online) softmax.
 *
 * The sequence length and head size are derived from q's shape: s = q.GetShape()[0] / atDims.b and
 * d = q.GetShape()[1] / atDims.n. For every (batch, head) pair the key/value rows are visited in tiles of
 * atDims.singleN rows and, inside that, the query rows in tiles of atDims.singleM rows. Each query tile keeps
 * three pieces of state between key/value tiles: the un-normalised output, the running row sum and the running
 * row max. The output is divided by the running sum only once the last key/value tile has been folded in.
 *
 * No score scaling and no masking is applied in this function.
 *
 * \param q      query tensor; the offsets used below address it as rows [b * s] by columns [n * d]
 * \param k      key tensor, addressed with the same row/column offsets as q
 * \param v      value tensor, addressed with the same row/column offsets as q
 * \param m      unused
 * \param l      unused
 * \param atDims supplies b, n, singleM and singleN; all four must be positive and singleM / singleN must
 *               divide the derived sequence length s
 * \return the tensor produced by Assemble() from every query tile's final output and its {row, column} offset
 */
Tensor FlashAttentionNew(
    const Tensor& q, const Tensor& k, const Tensor& v, const Tensor& m, const Tensor& l, const AttentionDims& atDims)
{
    (void)m;
    (void)l;
    int dim0 = q.GetShape()[0];
    int dim1 = q.GetShape()[1];
    int b = atDims.b;
    int n = atDims.n;
    int singleM = atDims.singleM;
    int singleN = atDims.singleN;

    // All four values are used as divisors below; reject non-positive ones before dividing.
    ASSERT(b > 0);
    ASSERT(n > 0);
    ASSERT(singleM > 0);
    ASSERT(singleN > 0);

    int s = dim0 / b;
    int d = dim1 / n;

    // The tile loops only visit whole tiles, so a remainder would silently leave trailing rows unprocessed.
    ASSERT(s % singleM == 0);
    ASSERT(s % singleN == 0);

    int s1Loop = s / singleM;
    int s2Loop = s / singleN;
    std::cout << "FlashAttention, B, N, S, D --------" << b << "," << n << "," << s << "," << d << std::endl;
    std::cout << "s1Loop, s2Loop -------" << s1Loop << "," << s2Loop << "," << std::endl;

    // Per-query-tile state carried across key/value tiles, keyed by the tile's {row, column} offset.
    // lastOi is additionally walked in key order at the end to build the Assemble() input.
    std::map<std::vector<int64_t>, Tensor> lastOi;
    std::map<std::vector<int64_t>, Tensor> lastMi;
    std::map<std::vector<int64_t>, Tensor> lastLi;

    for (int bIdx = 0; bIdx < b; bIdx++) {
        for (int nIdx = 0; nIdx < n; nIdx++) {
            int headCol = nIdx * d;
            for (int s2Idx = 0; s2Idx < s2Loop; s2Idx++) {
                int kvRow = bIdx * s + s2Idx * singleN;
                auto kj = View(k, {singleN, d}, {kvRow, headCol});
                auto vj = View(v, {singleN, d}, {kvRow, headCol});
                bool isLastKvTile = (s2Idx == s2Loop - 1);

                for (int s1Idx = 0; s1Idx < s1Loop; s1Idx++) {
                    int qRow = bIdx * s + s1Idx * singleM;
                    auto qi = View(q, {singleM, d}, {qRow, headCol});
                    std::vector<int64_t> oiOffset = {qRow, headCol};
                    // The row sum and row max share one key: one row per (batch, head, query row), column 0.
                    std::vector<int64_t> statOffset = {(bIdx * n + nIdx) * s + s1Idx * singleM, 0};

                    // Scores for this tile pair and their tile-local softmax pieces.
                    auto sij = Matrix::Matmul(DataType::DT_FP32, qi, kj, false, true);
                    auto tildaMij = Amax(sij, -1, true);
                    auto tsub = Sub(sij, tildaMij);
                    auto tildaPij = Exp(tsub);
                    auto tildaPijF16 = Cast(tildaPij, DataType::DT_FP16);
                    auto tildaLij = Sum(tildaPij, -1, true);

                    if (s2Idx == 0) {
                        // First key/value tile: there is no previous state to merge, so just seed it.
                        auto oiTmp = Matrix::Matmul(DataType::DT_FP32, tildaPijF16, vj, false, false);
                        if (isLastKvTile) {
                            auto liExpand = Reciprocal(tildaLij);
                            lastOi[oiOffset] = Mul(oiTmp, liExpand);
                        } else {
                            lastOi[oiOffset] = oiTmp;
                        }
                        lastLi[statOffset] = tildaLij;
                        lastMi[statOffset] = tildaMij;
                        continue;
                    }

                    // One search per map; the iterators are reused for both the read and the write-back.
                    auto oiIt = lastOi.find(oiOffset);
                    auto liIt = lastLi.find(statOffset);
                    auto miIt = lastMi.find(statOffset);
                    ASSERT(oiIt != lastOi.end());
                    ASSERT(liIt != lastLi.end());
                    ASSERT(miIt != lastMi.end());
                    auto oi = oiIt->second;
                    auto li = liIt->second;
                    auto mi = miIt->second;

                    // Merge the previous state with this tile: both sides are rescaled to the new row max.
                    auto miNew = Maximum(mi, tildaMij);
                    auto t1 = Sub(mi, miNew);
                    auto t2 = Exp(t1);
                    auto t3 = Sub(tildaMij, miNew);
                    auto t4 = Exp(t3);
                    auto t5 = Mul(t4, tildaLij);
                    auto t6 = Mul(t2, li);
                    auto liNew = Add(t6, t5);
                    auto q3 = Mul(oi, t2);
                    auto q1 = Matrix::Matmul(DataType::DT_FP32, tildaPijF16, vj, false, false);
                    auto q2 = Mul(q1, t4);
                    auto oiTmp = Add(q3, q2);

                    if (isLastKvTile) {
                        // All key/value tiles are folded in: normalise by the final running sum.
                        oiIt->second = Mul(oiTmp, Reciprocal(liNew));
                    } else {
                        oiIt->second = oiTmp;
                    }
                    liIt->second = liNew;
                    miIt->second = miNew;
                }
            }
        }
    }

    std::vector<std::pair<Tensor, std::vector<int64_t>>> aggregation;
    aggregation.reserve(lastOi.size());
    for (const auto& [offset, tensor] : lastOi) {
        aggregation.emplace_back(tensor, offset);
    }
    Tensor result = Assemble(aggregation);

    // Use the same ASSERT macro as the rest of the function for these checks.
    ASSERT(result.GetShape()[0] == b * s);
    ASSERT(result.GetShape()[1] == n * d);
    return result;
}
}