* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
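/*!
 * \file layout.h
 * \brief Shape / stride / tile-shape tuple aliases, the Layout type and related compile-time helpers for TileOp.
 */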
#ifndef TILEOP_UTILS_LAYOUT_H
#define TILEOP_UTILS_LAYOUT_H
#include "tuple.h"
#include "../tileop_common.h"
// Zero-based positions of the 1st .. 6th entry of a shape / stride / coordinate tuple.
constexpr size_t DIM_1ST{0};
constexpr size_t DIM_2ND{1};
constexpr size_t DIM_3RD{2};
constexpr size_t DIM_4TH{3};
constexpr size_t DIM_5TH{4};
constexpr size_t DIM_6TH{5};
// NOTE(review): MAX_DIMS is 5 although DIM_6TH and the 6-D aliases exist in this header; it is not referenced
// in this file, so its exact role should be confirmed against its users.
constexpr size_t MAX_DIMS{5};
namespace TileOp {
// Forward declaration only: Layout::operator() calls Crd2Idx(coord, shape, stride).
// The return type is deduced, so the definition (not part of this section) must be visible
// before any use of Layout::operator() is instantiated.
template <typename CoordType, typename ShapeType, typename StrideType>
__aicore__ inline constexpr auto Crd2Idx(const CoordType& coord, const ShapeType& shape, const StrideType& stride);
// The aliases below are all plain Std::tuple<...>; the distinct names only document the role of the tuple.
// Because they alias the same template, Shape<A, B>, Stride<A, B>, ... are the same type.

// Extents of a tensor, one entry per dimension.
template <typename... Shapes>
using Shape = Std::tuple<Shapes...>;
// Strides of a tensor, one entry per dimension.
template <typename... Strides>
using Stride = Std::tuple<Strides...>;
// Extents of a tile, one entry per dimension.
template <typename... TileShapes>
using TileShape = Std::tuple<TileShapes...>;
// Coordinate, one entry per dimension.
template <typename... Coords>
using Coord = Std::tuple<Coords...>;
// Per-dimension "last use" values (see the LastUseNDim aliases near the end of this header).
template <typename... LastUses>
using LastUse = Std::tuple<LastUses...>;
// True when the FIRST element type of Tuple satisfies Std::IsIntegralConstantV.
// Only element 0 is inspected, so Tuple must have at least one element.
// NOTE(review): presumably tuples are expected to be either all-constant or all-runtime — confirm.
template <typename Tuple>
constexpr bool IsConstantTuple = Std::IsIntegralConstantV<typename Std::tuple_element<0, Tuple>::type>;
/**
 * \brief Builds a Shape tuple from the given extents, keeping each argument's own type.
 * \param t one value per dimension
 * \return Shape<Ts...> holding copies of the arguments
 */
template <typename... Ts>
__aicore__ inline constexpr Shape<Ts...> MakeShape(const Ts&... t)
{
    return Shape<Ts...>{t...};
}
/**
 * \brief Builds a Stride tuple from the given strides, keeping each argument's own type.
 * \param t one value per dimension
 * \return Stride<Ts...> holding copies of the arguments
 */
template <typename... Ts>
__aicore__ inline constexpr Stride<Ts...> MakeStride(const Ts&... t)
{
    return Stride<Ts...>{t...};
}
/**
 * \brief Builds a TileShape tuple from the given tile extents, keeping each argument's own type.
 * \param t one value per dimension
 * \return TileShape<Ts...> holding copies of the arguments
 */
template <typename... Ts>
__aicore__ inline constexpr TileShape<Ts...> MakeTileShape(const Ts&... t)
{
    return TileShape<Ts...>{t...};
}
/**
 * \brief Compile-time read of one entry of a tuple of constant types, viewed as an expect_size-wide tuple.
 *
 * The tuple is right-aligned inside the expect_size-wide view:
 *  - if Tuple has fewer than expect_size entries, the leading (expect_size - size) positions yield
 *    default_value;
 *  - every other position maps to entry (index + size - expect_size) of Tuple, so a tuple longer than
 *    expect_size is addressed through its trailing expect_size entries.
 *
 * \tparam Tuple          tuple type; the selected element type must expose a static `value`
 * \tparam index          position inside the view, must be < expect_size
 * \tparam expect_size    width of the view (defaults to the tuple's own size)
 * \tparam default_value  value returned for padded leading positions
 */
template <typename Tuple, size_t index, size_t expect_size = Std::tuple_size<Tuple>::value, size_t default_value = 1>
__aicore__ inline constexpr size_t GetTupleElement()
{
    static_assert(index < expect_size, "The index of tuple is out of range.");
    constexpr auto tupleSize = Std::tuple_size<Tuple>::value;
    // The subtraction is only evaluated when tupleSize < expect_size, so it cannot wrap.
    constexpr bool isPaddedSlot = (tupleSize < expect_size) && (index < (expect_size - tupleSize));
    if constexpr (isPaddedSlot) {
        return default_value;
    } else {
        constexpr auto sourceIndex = index + tupleSize - expect_size;
        return Std::tuple_element<sourceIndex, Tuple>::type::value;
    }
}
/**
 * \brief Reads one entry of a tuple object, viewed as an expect_size-wide tuple.
 *
 * The tuple is right-aligned inside the expect_size-wide view:
 *  - if Tuple has fewer than expect_size entries, the leading (expect_size - size) positions yield
 *    default_value;
 *  - every other position maps to entry (index + size - expect_size) of t, so a tuple longer than
 *    expect_size is addressed through its trailing expect_size entries.
 *
 * \tparam index          position inside the view, must be < expect_size
 * \tparam expect_size    width of the view (defaults to the tuple's own size)
 * \tparam default_value  value returned for padded leading positions
 * \param t               tuple to read from
 * \return the selected entry converted to size_t, or default_value
 */
template <typename Tuple, size_t index, size_t expect_size = Std::tuple_size<Tuple>::value, size_t default_value = 1>
__aicore__ inline constexpr size_t GetTupleElement(const Tuple& t)
{
    static_assert(index < expect_size, "The index of tuple is out of range.");
    constexpr auto tupleSize = Std::tuple_size<Tuple>::value;
    // The subtraction is only evaluated when tupleSize < expect_size, so it cannot wrap.
    constexpr bool isPaddedSlot = (tupleSize < expect_size) && (index < (expect_size - tupleSize));
    if constexpr (isPaddedSlot) {
        return default_value;
    } else {
        constexpr auto sourceIndex = index + tupleSize - expect_size;
        return Std::get<sourceIndex>(t);
    }
}
/**
 * \brief Groups a shape tuple, a stride tuple and a tile-shape tuple.
 *
 * The three tuples are stored in a private Std::tuple base in the order (shape, stride, tile shape).
 * Each of them may hold runtime values or integral-constant types; the Get*Dim accessors select the
 * matching GetTupleElement overload through IsConstantTuple.
 *
 * \tparam ShapeType      tuple type of the extents
 * \tparam StrideType     tuple type of the strides
 * \tparam TileShapeType  tuple type of the tile extents
 */
template <typename ShapeType, typename StrideType, typename TileShapeType>
struct Layout : private Std::tuple<ShapeType, StrideType, TileShapeType> {
private:
    using BaseTuple = Std::tuple<ShapeType, StrideType, TileShapeType>;

public:
    using Shape = ShapeType;
    using Stride = StrideType;
    using TileShape = TileShapeType;

    /**
     * \brief Stores copies of the three tuples; every argument defaults to a value-initialized tuple.
     */
    __aicore__ inline constexpr Layout(
        const ShapeType& shape = {}, const StrideType& stride = {}, const TileShapeType& tileShape = {})
        : BaseTuple(shape, stride, tileShape)
    {
        static_assert(
            Std::is_tuple_v<ShapeType> && Std::is_tuple_v<StrideType> && Std::is_tuple_v<TileShapeType>,
            "Shape, Stride or TileShape is not tuple!");
    }

    /// Returns this layout itself (by reference).
    __aicore__ inline constexpr decltype(auto) layout() { return *this; }
    __aicore__ inline constexpr decltype(auto) layout() const { return *this; }

    /**
     * \brief Shape access. With no indices the stored shape tuple is returned; with indices I... a new tuple
     *        made of the selected shape entries is returned.
     */
    template <size_t... I>
    __aicore__ inline constexpr decltype(auto) GetShape()
    {
        return GetValue<0, I...>(static_cast<const BaseTuple&>(*this));
    }
    template <size_t... I>
    __aicore__ inline constexpr decltype(auto) GetShape() const
    {
        return GetValue<0, I...>(static_cast<const BaseTuple&>(*this));
    }

    /**
     * \brief Extent at position index of the shape viewed as expect_size-wide (right-aligned, padded with 1).
     */
    template <size_t index, size_t expect_size = Std::tuple_size<ShapeType>::value>
    __aicore__ inline constexpr decltype(auto) GetShapeDim() const
    {
        if constexpr (IsConstantTuple<ShapeType>) {
            return GetTupleElement<ShapeType, index, expect_size, 1>();
        } else {
            return GetTupleElement<ShapeType, index, expect_size, 1>(GetShape());
        }
    }

    /**
     * \brief Stride access. With no indices the stored stride tuple is returned; with indices I... a new
     *        tuple made of the selected stride entries is returned.
     */
    template <size_t... I>
    __aicore__ inline constexpr decltype(auto) GetStride()
    {
        return GetValue<1, I...>(static_cast<const BaseTuple&>(*this));
    }
    template <size_t... I>
    __aicore__ inline constexpr decltype(auto) GetStride() const
    {
        return GetValue<1, I...>(static_cast<const BaseTuple&>(*this));
    }

    /**
     * \brief Stride at position index of the stride tuple viewed as expect_size-wide (right-aligned,
     *        padded with 0).
     */
    template <size_t index, size_t expect_size = Std::tuple_size<StrideType>::value>
    __aicore__ inline constexpr decltype(auto) GetStrideDim() const
    {
        if constexpr (IsConstantTuple<StrideType>) {
            return GetTupleElement<StrideType, index, expect_size, 0>();
        } else {
            return GetTupleElement<StrideType, index, expect_size, 0>(GetStride());
        }
    }

    /**
     * \brief Tile-shape access. With no indices the stored tile-shape tuple is returned; with indices I...
     *        a new tuple made of the selected tile-shape entries is returned.
     */
    template <size_t... I>
    __aicore__ inline constexpr decltype(auto) GetTileShape()
    {
        return GetValue<2, I...>(static_cast<const BaseTuple&>(*this));
    }
    template <size_t... I>
    __aicore__ inline constexpr decltype(auto) GetTileShape() const
    {
        return GetValue<2, I...>(static_cast<const BaseTuple&>(*this));
    }

    /**
     * \brief Tile extent at position index of the tile shape viewed as expect_size-wide (right-aligned,
     *        padded with 0).
     *
     * NOTE(review): the padding value here is 0, whereas the free function GetTensorTileShapeDim pads
     * with 1 — confirm which one is intended before relying on padded positions.
     */
    template <size_t index, size_t expect_size = Std::tuple_size<TileShapeType>::value>
    __aicore__ inline constexpr decltype(auto) GetTileShapeDim() const
    {
        if constexpr (IsConstantTuple<TileShapeType>) {
            return GetTupleElement<TileShapeType, index, expect_size, 0>();
        } else {
            return GetTupleElement<TileShapeType, index, expect_size, 0>(GetTileShape());
        }
    }

    /// True when the shape tuple is a constant tuple (only the shape is inspected).
    __aicore__ inline static constexpr auto IsStaticLayout() { return IsConstantTuple<ShapeType>; }

    /// Forwards to Crd2Idx(coord, shape, stride).
    template <typename CoordType>
    __aicore__ inline constexpr auto operator()(const CoordType& coord) const
    {
        return Crd2Idx(coord, GetShape(), GetStride());
    }

    /**
     * \brief Sum over all positions of stride * coordinate, with both tuples viewed as expect_size-wide.
     *
     * Strides and coordinates are right-aligned inside the expect_size-wide view and padded with 0, so
     * padded positions contribute nothing. Every position 0 .. expect_size-1 is accumulated; earlier
     * revisions hard-coded exactly five terms, which rejected expect_size < 5 at compile time and
     * silently ignored position 5 for expect_size == 6.
     *
     * \tparam Tuple        coordinate tuple type
     * \tparam expect_size  width of the view (defaults to the coordinate tuple's size)
     * \param coordinate    coordinate to convert
     * \return the accumulated offset as size_t
     */
    template <typename Tuple, size_t expect_size = Std::tuple_size<Tuple>::value>
    __aicore__ inline constexpr decltype(auto) GetGmOffset(const Tuple& coordinate) const
    {
        return AccumulateGmOffset<DIM_1ST, expect_size>(coordinate);
    }

private:
    /// Adds stride * coordinate for positions dim .. expect_size-1 (compile-time recursion).
    template <size_t dim, size_t expect_size, typename Tuple>
    __aicore__ inline constexpr size_t AccumulateGmOffset(const Tuple& coordinate) const
    {
        if constexpr (dim < expect_size) {
            return GetStrideDim<dim, expect_size>() * GetTupleElement<Tuple, dim, expect_size, 0>(coordinate) +
                   AccumulateGmOffset<dim + 1, expect_size>(coordinate);
        } else {
            return 0;
        }
    }

    /// Builds a new tuple from entries I, Is... of member tuple number `index` of t.
    template <size_t index, size_t I, size_t... Is, typename Tuple>
    __aicore__ inline static constexpr decltype(auto) GetValue(const Tuple& t)
    {
        auto tupleEle = Std::get<index>(t);
        return Std::make_tuple(Std::get<I>(tupleEle), Std::get<Is>(tupleEle)...);
    }

    /// Returns member tuple number `index` of t as a whole.
    template <size_t index, typename Tuple>
    __aicore__ inline static constexpr decltype(auto) GetValue(const Tuple& t)
    {
        return Std::get<index>(t);
    }
};
/**
 * \brief Builds a Layout whose template arguments are deduced from the three tuples.
 * \param shape      extents tuple
 * \param stride     strides tuple
 * \param tileShape  tile extents tuple
 * \return Layout<ShapeType, StrideType, TileShapeType> holding copies of the arguments
 */
template <typename ShapeType, typename StrideType, typename TileShapeType>
__aicore__ inline constexpr auto MakeLayout(
    const ShapeType& shape, const StrideType& stride, const TileShapeType& tileShape)
{
    using LayoutType = Layout<ShapeType, StrideType, TileShapeType>;
    return LayoutType(shape, stride, tileShape);
}
// Type trait: is_layout<T>::value is true only when T is exactly a Layout<...> specialization.
// cv-qualified or reference types are not matched by the partial specialization.
template <typename T>
struct is_layout : Std::false_type {};
template <typename ShapeType, typename StrideType, typename TileShapeType>
struct is_layout<Layout<ShapeType, StrideType, TileShapeType>> : Std::true_type {};
// Convenience variable template for is_layout<T>::value.
template <typename T>
constexpr bool is_layout_v = is_layout<T>::value;
/**
 * \brief Compile-time value of the first entry of a constant stride tuple.
 * \tparam StrideType tuple type whose first element type exposes a static `value`
 */
template <typename StrideType>
__aicore__ inline constexpr auto GetOuterStride()
{
    using OuterStrideElem = typename Std::tuple_element<0, StrideType>::type;
    return OuterStrideElem::value;
}
/**
 * \brief Compile-time extent at position index of T::Shape, viewed as expect_size-wide and padded with 1.
 * \tparam T type exposing a nested constant tuple type `Shape`
 */
template <typename T, size_t index, size_t expect_size = Std::tuple_size<typename T::Shape>::value>
__aicore__ inline constexpr size_t GetTensorShapeDim()
{
    using ShapeTuple = typename T::Shape;
    constexpr size_t padValue = 1;
    return GetTupleElement<ShapeTuple, index, expect_size, padValue>();
}
/**
 * \brief Compile-time stride at position index of T::Stride, viewed as expect_size-wide and padded with 0.
 * \tparam T type exposing a nested constant tuple type `Stride`
 */
template <typename T, size_t index, size_t expect_size = Std::tuple_size<typename T::Stride>::value>
__aicore__ inline constexpr size_t GetTensorStrideDim()
{
    using StrideTuple = typename T::Stride;
    constexpr size_t padValue = 0;
    return GetTupleElement<StrideTuple, index, expect_size, padValue>();
}
/**
 * \brief Compile-time tile extent at position index of T::TileShape, viewed as expect_size-wide and
 *        padded with 1.
 * \tparam T type exposing a nested constant tuple type `TileShape`
 */
template <typename T, size_t index, size_t expect_size = Std::tuple_size<typename T::TileShape>::value>
__aicore__ inline constexpr size_t GetTensorTileShapeDim()
{
    using TileTuple = typename T::TileShape;
    constexpr size_t padValue = 1;
    return GetTupleElement<TileTuple, index, expect_size, padValue>();
}
/**
 * \brief Extent of the 1-based axis `axis` of Shape when leftAxis <= axis <= rightAxis, otherwise the
 *        literal 1. The tuple element is only looked up when the axis is inside the range.
 */
template <int axis, int leftAxis, int rightAxis, typename Shape>
__aicore__ inline constexpr auto GetAxisExtentIfSelected()
{
    if constexpr (leftAxis <= axis && axis <= rightAxis) {
        return Std::tuple_element<static_cast<size_t>(axis - 1), Shape>::type::value;
    } else {
        return 1;
    }
}

/**
 * \brief Product of the extents of the 1-based axes leftAxis .. rightAxis of a constant shape tuple.
 *
 * Only axes 1 to 5 are considered; axes outside [leftAxis, rightAxis] contribute a factor of 1.
 * NOTE(review): an axis 6 is never included even if rightAxis >= 6 — confirm this is intended for
 * 6-D shapes.
 *
 * \tparam leftAxis   first axis to merge (1-based, inclusive)
 * \tparam rightAxis  last axis to merge (1-based, inclusive)
 * \tparam Shape      tuple type whose selected element types expose a static `value`
 */
template <int leftAxis, int rightAxis, typename Shape>
__aicore__ inline constexpr size_t GetAnyAxisMergeResult()
{
    constexpr auto extent1 = GetAxisExtentIfSelected<1, leftAxis, rightAxis, Shape>();
    constexpr auto extent2 = GetAxisExtentIfSelected<2, leftAxis, rightAxis, Shape>();
    constexpr auto extent3 = GetAxisExtentIfSelected<3, leftAxis, rightAxis, Shape>();
    constexpr auto extent4 = GetAxisExtentIfSelected<4, leftAxis, rightAxis, Shape>();
    constexpr auto extent5 = GetAxisExtentIfSelected<5, leftAxis, rightAxis, Shape>();
    return extent1 * extent2 * extent3 * extent4 * extent5;
}
/**
 * \brief Product of the extents of axes 2 .. shapeSize (1-based) of a constant shape tuple, i.e. every
 *        axis except the first, as computed by GetAnyAxisMergeResult.
 * \tparam shapeSize  last axis to include (1-based)
 * \tparam Shape      constant shape tuple type
 */
template <size_t shapeSize, typename Shape>
__aicore__ inline constexpr size_t GetNonFirstAxisMergeResult()
{
    constexpr int firstMergedAxis = 2;
    return GetAnyAxisMergeResult<firstMergedAxis, shapeSize, Shape>();
}
/**
 * \brief Product of the first four entries of a constant shape tuple viewed as 5-wide (right-aligned,
 *        padded with 1); the fifth (last) entry of that view is left out.
 * \tparam shapeSize  not used by the implementation
 * \tparam Shape      constant shape tuple type
 */
template <size_t shapeSize, typename Shape>
__aicore__ inline constexpr size_t GetOutterAxisMergeResult()
{
    constexpr size_t paddedRank = 5;
    constexpr auto extent0 = GetTupleElement<Shape, DIM_1ST, paddedRank, 1>();
    constexpr auto extent1 = GetTupleElement<Shape, DIM_2ND, paddedRank, 1>();
    constexpr auto extent2 = GetTupleElement<Shape, DIM_3RD, paddedRank, 1>();
    constexpr auto extent3 = GetTupleElement<Shape, DIM_4TH, paddedRank, 1>();
    constexpr auto outerProduct = extent0 * extent1 * extent2 * extent3;
    return outerProduct;
}
/**
 * \brief True when at least one of the BrcOperands template arguments is non-zero.
 *
 * An empty pack yields false: a unary fold over `||` with no operands evaluates to false.
 */
template <int... BrcOperands>
__aicore__ inline constexpr bool HasBrcOperand()
{
    return ((BrcOperands != 0) || ...);
}
/**
 * \brief Compile-time check on a layout-carrying type T0.
 *
 * Returns false when T0::IsStaticLayout() is false. For static layouts it returns true when the shape
 * has 1 or 2 entries, or when the first stride entry equals the product of all non-first shape axes
 * (GetNonFirstAxisMergeResult). The stride and shape lookups are still instantiated for rank 1 and 2,
 * so T0::Stride must be a constant tuple whenever the layout is static.
 *
 * \tparam T0 type exposing IsStaticLayout(), Shape and Stride
 */
template <typename T0>
__aicore__ inline constexpr bool JudgeValidShapeEqualTileShape()
{
    if constexpr (!T0::IsStaticLayout()) {
        return false;
    } else {
        using ShapeTuple = typename T0::Shape;
        using StrideTuple = typename T0::Stride;
        constexpr auto rank = Std::tuple_size<ShapeTuple>::value;
        if constexpr (rank == 1 || rank == 2) {
            return true;
        }
        constexpr auto outerStride = GetOuterStride<StrideTuple>();
        constexpr auto mergedInnerExtent = GetNonFirstAxisMergeResult<rank, ShapeTuple>();
        return outerStride == mergedInnerExtent;
    }
}
/**
 * \brief Single-type case: the result of JudgeValidShapeEqualTileShape<T0>().
 */
template <typename T0>
__aicore__ constexpr bool IsConstContinous()
{
    constexpr bool verdict = JudgeValidShapeEqualTileShape<T0>();
    return verdict;
}

/**
 * \brief Multi-type case: true only when JudgeValidShapeEqualTileShape holds for every listed type.
 */
template <typename T0, typename T1, typename... Args>
__aicore__ constexpr bool IsConstContinous()
{
    return JudgeValidShapeEqualTileShape<T0>() && IsConstContinous<T1, Args...>();
}
}
// Runtime (size_t) shapes of rank 1 to 6.
using Shape1Dim = TileOp::Shape<size_t>;
using Shape2Dim = TileOp::Shape<size_t, size_t>;
using Shape3Dim = TileOp::Shape<size_t, size_t, size_t>;
using Shape4Dim = TileOp::Shape<size_t, size_t, size_t, size_t>;
using Shape5Dim = TileOp::Shape<size_t, size_t, size_t, size_t, size_t>;
using Shape6Dim = TileOp::Shape<size_t, size_t, size_t, size_t, size_t, size_t>;
// Runtime (size_t) strides of rank 1 to 6.
using Stride1Dim = TileOp::Stride<size_t>;
using Stride2Dim = TileOp::Stride<size_t, size_t>;
using Stride3Dim = TileOp::Stride<size_t, size_t, size_t>;
using Stride4Dim = TileOp::Stride<size_t, size_t, size_t, size_t>;
using Stride5Dim = TileOp::Stride<size_t, size_t, size_t, size_t, size_t>;
using Stride6Dim = TileOp::Stride<size_t, size_t, size_t, size_t, size_t, size_t>;
// Runtime (size_t) coordinates of rank 1 to 6.
using Coord1Dim = TileOp::Coord<size_t>;
using Coord2Dim = TileOp::Coord<size_t, size_t>;
using Coord3Dim = TileOp::Coord<size_t, size_t, size_t>;
using Coord4Dim = TileOp::Coord<size_t, size_t, size_t, size_t>;
using Coord5Dim = TileOp::Coord<size_t, size_t, size_t, size_t, size_t>;
using Coord6Dim = TileOp::Coord<size_t, size_t, size_t, size_t, size_t, size_t>;
// Fully dynamic layouts of rank 1 to 6: shape, stride and tile shape are all size_t tuples.
using DynLayout1Dim = TileOp::Layout<Shape1Dim, Stride1Dim, TileOp::TileShape<size_t>>;
using DynLayout2Dim = TileOp::Layout<Shape2Dim, Stride2Dim, TileOp::TileShape<size_t, size_t>>;
using DynLayout3Dim = TileOp::Layout<Shape3Dim, Stride3Dim, TileOp::TileShape<size_t, size_t, size_t>>;
using DynLayout4Dim = TileOp::Layout<Shape4Dim, Stride4Dim, TileOp::TileShape<size_t, size_t, size_t, size_t>>;
using DynLayout5Dim = TileOp::Layout<Shape5Dim, Stride5Dim, TileOp::TileShape<size_t, size_t, size_t, size_t, size_t>>;
using DynLayout6Dim =
TileOp::Layout<Shape6Dim, Stride6Dim, TileOp::TileShape<size_t, size_t, size_t, size_t, size_t, size_t>>;
// LastUseNDim: a LastUse tuple of N Std::Int<> constants, one per dimension, each defaulting to 0.
// NOTE(review): the meaning of the individual values is not visible in this header — confirm with the users.
template <size_t TileW = 0>
using LastUse1Dim = TileOp::LastUse<Std::Int<TileW>>;
template <size_t TileH = 0, size_t TileW = 0>
using LastUse2Dim = TileOp::LastUse<Std::Int<TileH>, Std::Int<TileW>>;
template <size_t TileD = 0, size_t TileH = 0, size_t TileW = 0>
using LastUse3Dim = TileOp::LastUse<Std::Int<TileD>, Std::Int<TileH>, Std::Int<TileW>>;
template <size_t TileN = 0, size_t TileD = 0, size_t TileH = 0, size_t TileW = 0>
using LastUse4Dim = TileOp::LastUse<Std::Int<TileN>, Std::Int<TileD>, Std::Int<TileH>, Std::Int<TileW>>;
template <size_t TileS = 0, size_t TileN = 0, size_t TileD = 0, size_t TileH = 0, size_t TileW = 0>
using LastUse5Dim =
TileOp::LastUse<Std::Int<TileS>, Std::Int<TileN>, Std::Int<TileD>, Std::Int<TileH>, Std::Int<TileW>>;
template <size_t TileB = 0, size_t TileS = 0, size_t TileN = 0, size_t TileD = 0, size_t TileH = 0, size_t TileW = 0>
using LastUse6Dim = TileOp::LastUse<
Std::Int<TileB>, Std::Int<TileS>, Std::Int<TileN>, Std::Int<TileD>, Std::Int<TileH>, Std::Int<TileW>>;
// LocalLayoutNDim: the shape is a runtime size_t tuple (ShapeNDim), while stride and tile shape are
// compile-time Std::Int<> constants. The strides are derived from the tile extents only: the innermost
// stride is 1 and each outer stride is the product of the tile extents of all dimensions inside it.
// The outermost tile extent appears in the tile shape but does not enter any stride.
template <size_t TileW>
using LocalLayout1Dim = TileOp::Layout<Shape1Dim, TileOp::Stride<Std::Int<1>>, TileOp::TileShape<Std::Int<TileW>>>;
template <size_t TileH, size_t TileW>
using LocalLayout2Dim = TileOp::Layout<
Shape2Dim, TileOp::Stride<Std::Int<TileW>, Std::Int<1>>, TileOp::TileShape<Std::Int<TileH>, Std::Int<TileW>>>;
template <size_t TileD, size_t TileH, size_t TileW>
using LocalLayout3Dim = TileOp::Layout<
Shape3Dim, TileOp::Stride<Std::Int<TileH * TileW>, Std::Int<TileW>, Std::Int<1>>,
TileOp::TileShape<Std::Int<TileD>, Std::Int<TileH>, Std::Int<TileW>>>;
template <size_t TileN, size_t TileD, size_t TileH, size_t TileW>
using LocalLayout4Dim = TileOp::Layout<
Shape4Dim, TileOp::Stride<Std::Int<TileD * TileH * TileW>, Std::Int<TileH * TileW>, Std::Int<TileW>, Std::Int<1>>,
TileOp::TileShape<Std::Int<TileN>, Std::Int<TileD>, Std::Int<TileH>, Std::Int<TileW>>>;
template <size_t TileS, size_t TileN, size_t TileD, size_t TileH, size_t TileW>
using LocalLayout5Dim = TileOp::Layout<
Shape5Dim,
TileOp::Stride<
Std::Int<TileN * TileD * TileH * TileW>, Std::Int<TileD * TileH * TileW>, Std::Int<TileH * TileW>,
Std::Int<TileW>, Std::Int<1>>,
TileOp::TileShape<Std::Int<TileS>, Std::Int<TileN>, Std::Int<TileD>, Std::Int<TileH>, Std::Int<TileW>>>;
template <size_t TileB, size_t TileS, size_t TileN, size_t TileD, size_t TileH, size_t TileW>
using LocalLayout6Dim = TileOp::Layout<
Shape6Dim,
TileOp::Stride<
Std::Int<TileS * TileN * TileD * TileH * TileW>, Std::Int<TileN * TileD * TileH * TileW>,
Std::Int<TileD * TileH * TileW>, Std::Int<TileH * TileW>, Std::Int<TileW>, Std::Int<1>>,
TileOp::TileShape<
Std::Int<TileB>, Std::Int<TileS>, Std::Int<TileN>, Std::Int<TileD>, Std::Int<TileH>, Std::Int<TileW>>>;
// StaticLayoutNDim: shape, stride and tile shape are all compile-time Std::Int<> constants.
// The leading template arguments are the shape extents, the trailing ones the tile extents.
// As with LocalLayoutNDim, the strides are computed from the TILE extents (not from the shape extents):
// innermost stride 1, each outer stride the product of the tile extents of all dimensions inside it.
template <size_t W, size_t TileW>
using StaticLayout1Dim =
TileOp::Layout<TileOp::Shape<Std::Int<W>>, TileOp::Stride<Std::Int<1>>, TileOp::TileShape<Std::Int<TileW>>>;
template <size_t H, size_t W, size_t TileH, size_t TileW>
using StaticLayout2Dim = TileOp::Layout<
TileOp::Shape<Std::Int<H>, Std::Int<W>>, TileOp::Stride<Std::Int<TileW>, Std::Int<1>>,
TileOp::TileShape<Std::Int<TileH>, Std::Int<TileW>>>;
template <size_t D, size_t H, size_t W, size_t TileD, size_t TileH, size_t TileW>
using StaticLayout3Dim = TileOp::Layout<
TileOp::Shape<Std::Int<D>, Std::Int<H>, Std::Int<W>>,
TileOp::Stride<Std::Int<TileH * TileW>, Std::Int<TileW>, Std::Int<1>>,
TileOp::TileShape<Std::Int<TileD>, Std::Int<TileH>, Std::Int<TileW>>>;
template <size_t N, size_t D, size_t H, size_t W, size_t TileN, size_t TileD, size_t TileH, size_t TileW>
using StaticLayout4Dim = TileOp::Layout<
TileOp::Shape<Std::Int<N>, Std::Int<D>, Std::Int<H>, Std::Int<W>>,
TileOp::Stride<Std::Int<TileD * TileH * TileW>, Std::Int<TileH * TileW>, Std::Int<TileW>, Std::Int<1>>,
TileOp::TileShape<Std::Int<TileN>, Std::Int<TileD>, Std::Int<TileH>, Std::Int<TileW>>>;
template <
size_t S, size_t N, size_t D, size_t H, size_t W, size_t TileS, size_t TileN, size_t TileD, size_t TileH,
size_t TileW>
using StaticLayout5Dim = TileOp::Layout<
TileOp::Shape<Std::Int<S>, Std::Int<N>, Std::Int<D>, Std::Int<H>, Std::Int<W>>,
TileOp::Stride<
Std::Int<TileN * TileD * TileH * TileW>, Std::Int<TileD * TileH * TileW>, Std::Int<TileH * TileW>,
Std::Int<TileW>, Std::Int<1>>,
TileOp::TileShape<Std::Int<TileS>, Std::Int<TileN>, Std::Int<TileD>, Std::Int<TileH>, Std::Int<TileW>>>;
template <
size_t B, size_t S, size_t N, size_t D, size_t H, size_t W, size_t TileB, size_t TileS, size_t TileN, size_t TileD,
size_t TileH, size_t TileW>
using StaticLayout6Dim = TileOp::Layout<
TileOp::Shape<Std::Int<B>, Std::Int<S>, Std::Int<N>, Std::Int<D>, Std::Int<H>, Std::Int<W>>,
TileOp::Stride<
Std::Int<TileS * TileN * TileD * TileH * TileW>, Std::Int<TileN * TileD * TileH * TileW>,
Std::Int<TileD * TileH * TileW>, Std::Int<TileH * TileW>, Std::Int<TileW>, Std::Int<1>>,
TileOp::TileShape<
Std::Int<TileB>, Std::Int<TileS>, Std::Int<TileN>, Std::Int<TileD>, Std::Int<TileH>, Std::Int<TileW>>>;
// Integer type for loop counters: uint16_t when __DAV_V310 is defined, size_t otherwise.
// NOTE(review): with uint16_t, loop bounds above 65535 would wrap — confirm the bounds used on that target.
#ifdef __DAV_V310
using LoopVar = uint16_t;
#else
using LoopVar = size_t;
#endif
#endif