* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file mmad_impl.h
* \brief General Matrix Multiplication Interface Implementation (Atlas A3, Ascend 950PR/Ascend 950DT)
*/
#ifndef TILEOP_TILE_OPERATOR_MMAD_IMPL__H
#define TILEOP_TILE_OPERATOR_MMAD_IMPL__H
#include "cube_utils.h"
template <bool isZeroC, TransMode transMode, bool kAlignFlag, typename TileAcc, typename TileLeft, typename TileRight>
INLINE void TMatmulImpl(TileAcc& c, TileLeft& a, TileRight& b)
{
int64_t validM = GetShape<0>(a);
int64_t validN = GetShape<1>(b);
int64_t validK = GetShape<1>(a);
if (validM == 0 || validK == 0 || validN == 0) {
return;
}
constexpr uint64_t shapeSizeA = Std::tuple_size<typename TileLeft::Shape>::value;
constexpr uint64_t shapeSizeB = Std::tuple_size<typename TileRight::Shape>::value;
constexpr uint64_t shapeSizeC = Std::tuple_size<typename TileAcc::Shape>::value;
constexpr auto staticL0AH = Std::tuple_element<shapeSizeA - SHAPE_DIM2, typename TileLeft::TileShape>::type::value;
constexpr auto staticL0AW = Std::tuple_element<shapeSizeA - 1, typename TileLeft::TileShape>::type::value;
constexpr auto staticL0BH = Std::tuple_element<shapeSizeB - SHAPE_DIM2, typename TileRight::TileShape>::type::value;
constexpr auto staticL0BW = Std::tuple_element<shapeSizeB - 1, typename TileRight::TileShape>::type::value;
constexpr auto staticL0CH = Std::tuple_element<shapeSizeC - SHAPE_DIM2, typename TileAcc::TileShape>::type::value;
constexpr auto staticL0CW = Std::tuple_element<shapeSizeC - 1, typename TileAcc::TileShape>::type::value;
using tileL0ATensor = pto::TileLeft<typename TileLeft::Type, staticL0AH, staticL0AW, -1, -1>;
using tileL0BTensor = pto::TileRight<typename TileRight::Type, staticL0BH, staticL0BW, -1, -1>;
using tileL0CTensor = pto::TileAcc<typename TileAcc::Type, staticL0CH, staticL0CW, -1, -1>;
validM = (validM + BLOCK_CUBE_M_N - 1) / BLOCK_CUBE_M_N * BLOCK_CUBE_M_N;
tileL0ATensor l0a(validM, validK);
tileL0BTensor l0b(validK, validN);
tileL0CTensor l0c(validM, validN);
if constexpr (std::is_same<typename tileL0ATensor::DType, float>::value) {
l0a.ResetMadMode();
l0a.SetKAligned(kAlignFlag);
}
if constexpr (transMode != TransMode::CAST_NONE) {
l0a.SetMadTF32Mode(static_cast<pto::RoundMode>(transMode));
}
pto::TASSIGN(l0a, static_cast<uint64_t>(a.GetAddr()));
pto::TASSIGN(l0b, static_cast<uint64_t>(b.GetAddr()));
pto::TASSIGN(l0c, static_cast<uint64_t>(c.GetAddr()));
if constexpr (!isZeroC) {
pto::TMATMUL(l0c, l0a, l0b);
} else {
pto::TMATMUL_ACC(l0c, l0c, l0a, l0b);
}
if constexpr (transMode != TransMode::CAST_NONE) {
l0a.ResetMadMode();
}
}
template <TransMode transMode, typename TileAcc, typename TileLeft, typename TileRight, typename TileBias>
INLINE void TMatmulImpl(TileAcc& c, TileLeft& a, TileRight& b, TileBias& bias)
{
int64_t validM = GetShape<0>(a);
int64_t validN = GetShape<1>(b);
int64_t validK = GetShape<1>(a);
if (validM == 0 || validK == 0 || validN == 0) {
return;
}
constexpr uint64_t shapeSizeA = Std::tuple_size<typename TileLeft::Shape>::value;
constexpr uint64_t shapeSizeC = Std::tuple_size<typename TileAcc::Shape>::value;
constexpr uint64_t shapeSizeB = Std::tuple_size<typename TileRight::Shape>::value;
constexpr auto staticL0AW = Std::tuple_element<shapeSizeA - 1, typename TileLeft::TileShape>::type::value;
constexpr auto staticL0AH = Std::tuple_element<shapeSizeA - SHAPE_DIM2, typename TileLeft::TileShape>::type::value;
constexpr auto staticL0BH = Std::tuple_element<shapeSizeB - SHAPE_DIM2, typename TileRight::TileShape>::type::value;
constexpr auto staticL0BW = Std::tuple_element<shapeSizeB - 1, typename TileRight::TileShape>::type::value;
constexpr auto staticL0CW = Std::tuple_element<shapeSizeC - 1, typename TileAcc::TileShape>::type::value;
constexpr auto staticL0CH = Std::tuple_element<shapeSizeC - SHAPE_DIM2, typename TileAcc::TileShape>::type::value;
using tileL0ATensor = pto::TileLeft<typename TileLeft::Type, staticL0AH, staticL0AW, -1, -1>;
using tileL0BTensor = pto::TileRight<typename TileRight::Type, staticL0BH, staticL0BW, -1, -1>;
using tileL0CTensor = pto::TileAcc<typename TileAcc::Type, staticL0CH, staticL0CW, -1, -1>;
using tileBiasTensor =
pto::Tile<pto::TileType::Bias, typename TileBias::Type, 1, staticL0BW, pto::BLayout::RowMajor, -1, -1>;
validM = (validM + BLOCK_CUBE_M_N - 1) / BLOCK_CUBE_M_N * BLOCK_CUBE_M_N;
tileL0ATensor l0a(validM, validK);
tileL0BTensor l0b(validK, validN);
tileL0CTensor l0c(validM, validN);
tileBiasTensor biasT(1, validN);
if constexpr (std::is_same<typename tileL0ATensor::DType, float>::value) {
l0a.ResetMadMode();
}
if constexpr (transMode != TransMode::CAST_NONE) {
l0a.SetMadTF32Mode(static_cast<pto::RoundMode>(transMode));
}
pto::TASSIGN(l0a, static_cast<uint64_t>(a.GetAddr()));
pto::TASSIGN(l0b, static_cast<uint64_t>(b.GetAddr()));
pto::TASSIGN(l0c, static_cast<uint64_t>(c.GetAddr()));
pto::TASSIGN(biasT, static_cast<uint64_t>(bias.GetAddr()));
pto::TMATMUL_BIAS(l0c, l0a, l0b, biasT);
if constexpr (transMode != TransMode::CAST_NONE) {
l0a.ResetMadMode();
}
}
#endif