* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file tile_shape.cpp
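 * \brief Tile-shape descriptors (VecTile, CubeTile, ConvTile, DistTile) and the TileShape configuration holder.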
*/
#include "tilefwk/tile_shape.h"
#include "interface/configs/config_manager_ng.h"
using namespace npu::tile_fwk;
// Returns true when the vector tile has at least one dimension and every
// dimension is strictly positive.
bool VecTile::valid() const
{
    // The predicate takes int64_t (as CubeTile::valid does) rather than int:
    // the tile is built from std::vector<int64_t> (see TileShape::SetVecTile),
    // and an int parameter would narrow large values before the comparison
    // (e.g. 4294967296 typically wraps to 0), misclassifying them.
    return !tile.empty() && std::all_of(tile.begin(), tile.end(), [](int64_t dim) { return dim > 0; });
}
// A cube tile is valid when every m, k and n entry is strictly positive.
bool CubeTile::valid() const
{
    const auto isPositive = [](int64_t dim) { return dim > 0; };
    if (!std::all_of(m.begin(), m.end(), isPositive)) {
        return false;
    }
    if (!std::all_of(k.begin(), k.end(), isPositive)) {
        return false;
    }
    return std::all_of(n.begin(), n.end(), isPositive);
}
// Renders the cube tile as
// "CubeTile: {m: {m0, m1}, k: {k0, k1, k2}, n: {n0, n1}, enableSplitK: <0|1>}; ".
// Only the first two m/n entries and the first three k entries are printed;
// enableSplitK is streamed without boolalpha, so it appears as 0 or 1.
std::string CubeTile::ToString() const
{
    std::ostringstream out;
    out << "CubeTile: {";
    out << "m: {" << m[0] << ", " << m[1] << "}, ";
    out << "k: {" << k[0] << ", " << k[1] << ", " << k[0x2] << "}, ";
    out << "n: {" << n[0] << ", " << n[1] << "}, ";
    out << "enableSplitK: " << enableSplitK << "}; ";
    return out.str();
}
// A conv tile is valid when every L1 tile extent is positive and, if an
// explicit L0 tile was requested (setL0Tile), every L0 extent is positive too.
bool ConvTile::valid() const
{
    const bool badL1 = tileL1Info.tileHin <= 0 || tileL1Info.tileHout <= 0 || tileL1Info.tileWin <= 0 ||
                       tileL1Info.tileWout <= 0 || tileL1Info.tileCinFmap <= 0 ||
                       tileL1Info.tileCinWeight <= 0 || tileL1Info.tileN <= 0 || tileL1Info.tileBatch <= 0;
    if (badL1) {
        return false;
    }
    // L0 extents are only inspected when setL0Tile is set.
    const bool badL0 = setL0Tile && (tileL0Info.tileH <= 0 || tileL0Info.tileW <= 0 || tileL0Info.tileK <= 0 ||
                                     tileL0Info.tileN <= 0);
    return !badL0;
}
// Renders the conv tiling on one line:
// "ConvTile: {tileL1Info: {...}, tileL0Info: {...}, setL0Tile: true|false}; ".
// The section ends with "}; " like the other tile ToString() helpers, so the
// sections concatenated by TileShape::ToString() stay separated by a space
// (previously "};" made the ConvTile and DistTile sections run together).
std::string ConvTile::ToString() const
{
    std::stringstream ss;
    ss << "ConvTile: " << '{' << "tileL1Info: {"
       << "tileHin: " << tileL1Info.tileHin << ", "
       << "tileHout: " << tileL1Info.tileHout << ", "
       << "tileWin: " << tileL1Info.tileWin << ", "
       << "tileWout: " << tileL1Info.tileWout << ", "
       << "tileCinFmap: " << tileL1Info.tileCinFmap << ", "
       << "tileCinWeight: " << tileL1Info.tileCinWeight << ", "
       << "tileN: " << tileL1Info.tileN << ", "
       << "tileBatch: " << tileL1Info.tileBatch << "}, "
       << "tileL0Info: {"
       << "tileH: " << tileL0Info.tileH << ", "
       << "tileW: " << tileL0Info.tileW << ", "
       << "tileK: " << tileL0Info.tileK << ", "
       << "tileN: " << tileL0Info.tileN << "}, "
       << "setL0Tile: " << (setL0Tile ? "true" : "false") << "}; ";
    return ss.str();
}
// A dist tile is valid when every row, col and rank entry is strictly positive
// and rankId is non-negative.
bool DistTile::valid() const
{
    const auto isPositive = [](int value) { return value > 0; };
    if (!std::all_of(row.begin(), row.end(), isPositive)) {
        return false;
    }
    if (!std::all_of(col.begin(), col.end(), isPositive)) {
        return false;
    }
    if (!std::all_of(rank.begin(), rank.end(), isPositive)) {
        return false;
    }
    return rankId >= 0;
}
// Renders the dist tile as
// "DistTile: {row: {r0, r1, r2}, col: {c0, c1, c2}, rank: {k0, k1, k2}, rankId: R}; ".
// Only the first three entries of row/col/rank are printed.
std::string DistTile::ToString() const
{
    std::ostringstream out;
    out << "DistTile: {";
    out << "row: {" << row[0] << ", " << row[1] << ", " << row[0x2] << "}, ";
    out << "col: {" << col[0] << ", " << col[1] << ", " << col[0x2] << "}, ";
    out << "rank: {" << rank[0] << ", " << rank[1] << ", " << rank[0x2] << "}, ";
    out << "rankId: " << rankId << "}; ";
    return out.str();
}
// Default constructor: every member is initialized with empty braces `{}`.
// NOTE(review): keep the brace form; switching to `()` or `= default` could
// change how members of the externally declared tile types are initialized.
TileShape::TileShape() : vecTile{}, cubeTile{}, convTile{}, distTile{}, matrixSize{} {}
// Member-wise constructor.
//   vTile  - raw vector-tile dimensions; brace-initializes `vecTile` from them.
//   cTile  - copied into `cubeTile`.
//   cvTile - copied into `convTile`.
//   dTile  - copied into `distTile`.
//   mSize  - copied into `matrixSize`.
// Unlike the setters, this constructor does not push anything into the
// ConfigManagerNg scope.
TileShape::TileShape(
const std::vector<int64_t>& vTile, const CubeTile& cTile, const ConvTile& cvTile, const DistTile& dTile,
const std::vector<int64_t>& mSize)
: vecTile{vTile}, cubeTile(cTile), convTile(cvTile), distTile(dTile), matrixSize(mSize)
{}
// Returns a TileShape regenerated from the current ConfigManagerNg scope.
// The result lives in a single function-local static that is overwritten on
// every call, so a reference kept from an earlier call will observe the newest
// snapshot once Current() is called again.
// NOTE(review): the assignment is not guarded by any lock here; confirm that
// concurrent callers are not expected.
TileShape& TileShape::Current()
{
    static TileShape latest;
    latest = ConfigManagerNg::CurrentScope()->GenerateTileShape();
    return latest;
}
// Replaces the conv tile with the given L1/L0 tiling and mirrors it into the
// current config scope under "conv_tile_shapes".
void TileShape::SetConvTile(const Conv::TileL1Info& tileL1Info, const Conv::TileL0Info& tileL0Info, bool setL0Tile)
{
    convTile = ConvTile{tileL1Info, tileL0Info, setL0Tile};
    ConfigManagerNg::CurrentScope()->UpdateValue("conv_tile_shapes", convTile);
}
void TileShape::SetVecTile(const std::vector<int64_t>& tile)
{
vecTile = {tile};
ConfigManagerNg::CurrentScope()->UpdateValue("vec_tile_shapes", tile);
}
void TileShape::SetVecTile(const VecTile& tile)
{
vecTile = tile;
ConfigManagerNg::CurrentScope()->UpdateValue("vec_tile_shapes", tile.tile);
}
void TileShape::SetCubeTile(
const std::array<int64_t, MAX_M_DIM_SIZE>& m, const std::array<int64_t, MAX_K_DIM_SIZE>& k,
const std::array<int64_t, MAX_N_DIM_SIZE>& n, bool enableSplitK)
{
auto nk = k;
if (nk[2] == 0) {
nk[2] = nk[1];
}
cubeTile = {m, nk, n, enableSplitK};
ConfigManagerNg::CurrentScope()->UpdateValue("cube_tile_shapes", cubeTile);
}
void TileShape::SetMatrixSize(const std::vector<int64_t>& size)
{
this->matrixSize = size;
ConfigManagerNg::CurrentScope()->UpdateValue("matrix_size", size);
}
// Publishes the current distTile to the active config scope under
// "dist_tile_shapes"; every dist-tile setter calls this after mutating distTile.
void TileShape::UpdateScopeDistTile()
{
    ConfigManagerNg::CurrentScope()->UpdateValue("dist_tile_shapes", distTile);
}
void TileShape::SetDistTile(
const std::array<int, MAX_DIST_DIM_SIZE>& row, const std::array<int, MAX_DIST_DIM_SIZE>& col,
const std::array<int, MAX_DIST_DIM_SIZE>& rank)
{
distTile.row = row;
distTile.col = col;
distTile.rank = rank;
UpdateScopeDistTile();
}
void TileShape::SetDistRankId(int64_t rankId)
{
distTile.rankId = rankId;
UpdateScopeDistTile();
}
void TileShape::SetDistTileCol(const std::array<int, MAX_DIST_DIM_SIZE>& col)
{
distTile.col = col;
UpdateScopeDistTile();
}
void TileShape::SetDistTileRow(const std::array<int, MAX_DIST_DIM_SIZE>& row)
{
distTile.row = row;
UpdateScopeDistTile();
}
void TileShape::SetDistTileRank(const std::array<int, MAX_DIST_DIM_SIZE>& rank)
{
distTile.rank = rank;
UpdateScopeDistTile();
}
// Formats the requested tile section(s) as text. TileType::MAX selects every
// section, emitted in the order VecTile, CubeTile, ConvTile, DistTile; any other
// value selects only its own section.
std::string TileShape::ToString(TileType type) const
{
    const bool everything = (type == TileType::MAX);
    std::stringstream ss;
    if (everything || type == TileType::VEC) {
        // Vec dims are joined with ", " and wrapped as "VecTile: {...}; ".
        ss << "VecTile: {";
        const char* separator = "";
        for (const auto& dim : vecTile.tile) {
            ss << separator << dim;
            separator = ", ";
        }
        ss << "}; ";
    }
    if (everything || type == TileType::CUBE) {
        ss << cubeTile.ToString();
    }
    if (everything || type == TileType::CONV) {
        ss << convTile.ToString();
    }
    if (everything || type == TileType::DIST) {
        ss << distTile.ToString();
    }
    return ss.str();
}