/**
 * Copyright (c) 2025-2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef TLA_LAYOUT_HPP
#define TLA_LAYOUT_HPP
#include "catlass/catlass.hpp"
#include "catlass/arch/arch.hpp"
#include "catlass/numeric_size.hpp"
#include "tla/numeric/integral_constant.hpp"
#include "tla/numeric/math.hpp"
#include "tla/tuple.hpp"
#include "tla/int_tuple.hpp"
#include "catlass/layout/layout.hpp"
namespace tla {
// Role-naming aliases for tla::tuple. Shape, Stride and Coord are all the same
// underlying tuple type; the separate names only document intent at call sites.
template <typename... ShapeElems>
using Shape = tla::tuple<ShapeElems...>;

template <typename... StrideElems>
using Stride = tla::tuple<StrideElems...>;

template <typename... CoordElems>
using Coord = tla::tuple<CoordElems...>;
// Factory helpers returning a tla::tuple (under the matching role alias) that holds
// copies of the arguments. An argument may be a runtime integer, a static Int<N>,
// or a nested tuple produced by one of these same helpers.
template <class... Ts>
CATLASS_HOST_DEVICE constexpr
Shape<Ts...> MakeShape(Ts const&... t)
{
    return Shape<Ts...>{t...};
}

template <class... Ts>
CATLASS_HOST_DEVICE constexpr
Stride<Ts...> MakeStride(Ts const&... t)
{
    return Stride<Ts...>{t...};
}

template <class... Ts>
CATLASS_HOST_DEVICE constexpr
Coord<Ts...> MakeCoord(Ts const&... t)
{
    return Coord<Ts...>{t...};
}
namespace detail {
// Computes the default OriginShape type of a Layout: a flat Shape holding one
// uint32_t per top-level mode of the Stride (e.g. rank 2 -> Shape<uint32_t, uint32_t>).
template <int Rank, class Sequence = void>
struct MakeOriginShapeTypeImpl;
// Specialization that expands the index pack Is... into one uint32_t per index.
// (Rank itself is not used here; the count comes from the index sequence.)
template <int Rank, size_t... Is>
struct MakeOriginShapeTypeImpl<Rank, tla::index_sequence<Is...>> {
// Ignores its index argument; it exists only so that Is... can be pack-expanded.
template <size_t>
using repeat_type = uint32_t;
using type = Shape<repeat_type<Is>...>;
};
// Default OriginShape type for a given Stride: rank_v<Stride> copies of uint32_t.
template <class Stride>
using MakeOriginShapeType = typename MakeOriginShapeTypeImpl<rank_v<Stride>, tla::make_index_sequence<rank_v<Stride>>>::type;
// Function object that narrows every argument to uint32_t and packs them into a Shape.
// Used by MakeLayout(shape, stride) together with tla::transform_apply and Product{};
// presumably it receives the per-mode products of the shape (see transform_apply).
struct UnpackedMakeOriginShapeU32 {
template <class... T>
CATLASS_HOST_DEVICE constexpr
auto operator()(T const&... a) const {
return MakeShape(static_cast<uint32_t>(a)...);
}
};
}
// Layout: maps (possibly hierarchical) coordinates to linear element offsets.
//
// Template parameters:
//   Shape       - extent of every mode: integers, static Int<N>, or nested tuples.
//   Stride      - stride of every mode; expected to mirror Shape's nesting
//                 (crd2offset statically checks the ranks it visits).
//   OriginShape - logical, unpadded extents of the data, e.g. (rows, cols) for a
//                 fractal layout whose Shape is rounded up to whole fractals.
//                 Defaults to one uint32_t per top-level mode of Stride.
//
// `rank` and `depth` are both derived from Stride.
template <class Shape, class Stride, class OriginShape = detail::MakeOriginShapeType<Stride>>
struct Layout : private tla::tuple<Shape, Stride, OriginShape> {
private:
    // The three components live in the privately inherited tuple, in this order:
    // 0 = shape, 1 = stride, 2 = origin shape.
    using Storage = tla::tuple<Shape, Stride, OriginShape>;

public:
    CATLASS_HOST_DEVICE constexpr
    Layout(Shape const& shape = {}, Stride const& stride = {}, OriginShape const& originShape = {})
        : Storage(shape, stride, originShape)
    {
    }

    static constexpr int rank = rank_v<Stride>;
    static constexpr int depth = depth_v<Stride>;

    // Shape, or the (nested) sub-shape selected by I..., e.g. shape<1, 0>().
    template <int... I>
    CATLASS_HOST_DEVICE constexpr
    decltype(auto) shape()
    {
        return get<0, I...>(static_cast<Storage&>(*this));
    }

    template <int... I>
    CATLASS_HOST_DEVICE constexpr
    decltype(auto) shape() const
    {
        return get<0, I...>(static_cast<Storage const&>(*this));
    }

    // Stride, or the (nested) sub-stride selected by I....
    template <int... I>
    CATLASS_HOST_DEVICE constexpr
    decltype(auto) stride()
    {
        return get<1, I...>(static_cast<Storage&>(*this));
    }

    template <int... I>
    CATLASS_HOST_DEVICE constexpr
    decltype(auto) stride() const
    {
        return get<1, I...>(static_cast<Storage const&>(*this));
    }

    // Origin shape, or the component selected by I....
    template <int... I>
    CATLASS_HOST_DEVICE constexpr
    decltype(auto) originShape()
    {
        return get<2, I...>(static_cast<Storage&>(*this));
    }

    template <int... I>
    CATLASS_HOST_DEVICE constexpr
    decltype(auto) originShape() const
    {
        return get<2, I...>(static_cast<Storage const&>(*this));
    }

    // Linear offset of `coord` under this layout's shape and stride.
    template <class Coord>
    CATLASS_HOST_DEVICE constexpr
    auto operator()(Coord const& coord) const
    {
        return crd2offset(coord, shape(), stride());
    }
};
// Build a Layout from an explicit shape, stride and origin shape.
// Each argument must be either a tla tuple or an integral value.
template <class Shape, class Stride, class OriginShape>
CATLASS_HOST_DEVICE constexpr
auto MakeLayout(Shape const& shape, Stride const& stride, OriginShape const& originShape)
{
    static_assert(is_tuple<Shape>::value || is_integral<Shape>::value);
    static_assert(is_tuple<Stride>::value || is_integral<Stride>::value);
    static_assert(is_tuple<OriginShape>::value || is_integral<OriginShape>::value);
    using ResultLayout = Layout<Shape, Stride, OriginShape>;
    return ResultLayout(shape, stride, originShape);
}

// Build a Layout from shape and stride only. The origin shape is derived from
// `shape` via transform_apply with Product{} and a uint32_t narrowing packer,
// i.e. presumably one uint32_t extent per top-level mode of `shape`.
template <class Shape, class Stride>
CATLASS_HOST_DEVICE constexpr
auto MakeLayout(Shape const& shape, Stride const& stride)
{
    static_assert(is_tuple<Shape>::value || is_integral<Shape>::value);
    static_assert(is_tuple<Stride>::value || is_integral<Stride>::value);
    auto const defaultOrigin =
        tla::transform_apply(shape, Product{}, detail::UnpackedMakeOriginShapeU32{});
    return MakeLayout(shape, stride, defaultOrigin);
}
// Convert a Catlass::layout tag object (carrying runtime extents and strides) into
// a tla::Layout.
//   VectorLayout           -> rank-1 layout (shape(0)) with stride (stride(0)).
//   RowMajor / ColumnMajor -> rank-2 layout (shape(0), shape(1)); the contiguous
//                             dimension gets the static unit stride Int<1>, so
//                             tag.stride(1) (RowMajor) / tag.stride(0) (ColumnMajor)
//                             is not read.
//   zN / nZ / L0C          -> hierarchical ((shape(0), shape(1)), (shape(2), shape(3)))
//                             layout with matching nested strides and origin shape
//                             (orgShape(0), orgShape(1)).
template <class LayoutTag>
CATLASS_HOST_DEVICE constexpr
auto MakeLayoutFromTag(LayoutTag const& tag)
{
    static_assert(std::is_same_v<LayoutTag, Catlass::layout::RowMajor> ||
                  std::is_same_v<LayoutTag, Catlass::layout::ColumnMajor> ||
                  std::is_same_v<LayoutTag, Catlass::layout::VectorLayout> ||
                  std::is_same_v<LayoutTag, Catlass::layout::zN> ||
                  std::is_same_v<LayoutTag, Catlass::layout::nZ> ||
                  std::is_same_v<LayoutTag, Catlass::layout::L0C>,
                  // Adjacent literals are concatenated: keep the trailing space.
                  "Unsupported LayoutTag for MakeLayoutFromTag, only support Catlass::layout::RowMajor or "
                  "Catlass::layout::ColumnMajor or Catlass::layout::VectorLayout or Catlass::layout::zN or Catlass::layout::nZ or Catlass::layout::L0C");
    if constexpr (std::is_same_v<LayoutTag, Catlass::layout::VectorLayout>) {
        return MakeLayout(MakeShape(tag.shape(0)),
                          MakeStride(tag.stride(0)),
                          MakeShape(tag.shape(0)));
    } else if constexpr (std::is_same_v<LayoutTag, Catlass::layout::RowMajor>) {
        return MakeLayout(MakeShape(tag.shape(0), tag.shape(1)),
                          MakeStride(tag.stride(0), Int<1>{}),
                          MakeShape(tag.shape(0), tag.shape(1)));
    } else if constexpr (std::is_same_v<LayoutTag, Catlass::layout::ColumnMajor>) {
        return MakeLayout(MakeShape(tag.shape(0), tag.shape(1)),
                          MakeStride(Int<1>{}, tag.stride(1)),
                          MakeShape(tag.shape(0), tag.shape(1)));
    } else {
        // zN, nZ and L0C all expose four shape/stride entries plus a 2-D origin shape.
        return MakeLayout(MakeShape(MakeShape(tag.shape(0), tag.shape(1)), MakeShape(tag.shape(2), tag.shape(3))),
                          MakeStride(MakeStride(tag.stride(0), tag.stride(1)), MakeStride(tag.stride(2), tag.stride(3))),
                          MakeShape(tag.orgShape(0), tag.orgShape(1)));
    }
}
// Free-function forms of the Layout accessors. Is... selects a (possibly nested)
// sub-mode, e.g. shape<1, 0>(layout); with an empty Is... the whole component is
// returned. decltype(auto) forwards exactly what the member accessor returns
// (reference or value). The non-const overloads forward the non-const member
// accessors, so a returned reference may be used to modify the layout.
template <int... Is, class Shape, class Stride, class OriginShape>
CATLASS_HOST_DEVICE constexpr
decltype(auto) shape(Layout<Shape, Stride, OriginShape>& layout)
{
return layout.template shape<Is...>();
}
template <int... Is, class Shape, class Stride, class OriginShape>
CATLASS_HOST_DEVICE constexpr
decltype(auto) shape(Layout<Shape, Stride, OriginShape> const& layout)
{
return layout.template shape<Is...>();
}
// Stride of the layout, or the sub-stride selected by Is....
template <int... Is, class Shape, class Stride, class OriginShape>
CATLASS_HOST_DEVICE constexpr
decltype(auto) stride(Layout<Shape, Stride, OriginShape>& layout)
{
return layout.template stride<Is...>();
}
template <int... Is, class Shape, class Stride, class OriginShape>
CATLASS_HOST_DEVICE constexpr
decltype(auto) stride(Layout<Shape, Stride, OriginShape> const& layout)
{
return layout.template stride<Is...>();
}
// Origin shape of the layout, or the component selected by Is....
template <int... Is, class Shape, class Stride, class OriginShape>
CATLASS_HOST_DEVICE constexpr
decltype(auto) originShape(Layout<Shape, Stride, OriginShape>& layout)
{
return layout.template originShape<Is...>();
}
template <int... Is, class Shape, class Stride, class OriginShape>
CATLASS_HOST_DEVICE constexpr
decltype(auto) originShape(Layout<Shape, Stride, OriginShape> const& layout)
{
return layout.template originShape<Is...>();
}
// Rank of the selected sub-shape (the whole shape when Is... is empty).
// NOTE: unlike Layout::rank, which is computed from Stride, this reads the Shape.
template <int... Is, class Shape, class Stride, class OriginShape>
CATLASS_HOST_DEVICE constexpr
auto rank(Layout<Shape, Stride, OriginShape> const& layout)
{
return rank(shape<Is...>(layout));
}
// Nesting depth of the selected sub-shape (the whole shape when Is... is empty).
template <int... Is, class Shape, class Stride, class OriginShape>
CATLASS_HOST_DEVICE constexpr
auto depth(Layout<Shape, Stride, OriginShape> const& layout)
{
return depth(shape<Is...>(layout));
}
// Forward declaration: the detail helpers below recurse back into crd2offset.
template <class Coord, class Shape, class Stride>
CATLASS_HOST_DEVICE constexpr
auto crd2offset(Coord const& coord, Shape const& shape, Stride const& stride);
namespace detail {
// Tuple coord / tuple shape / tuple stride: the offset is the sum of
// crd2offset over every (coord_i, shape_i, stride_i) triple.
// NOTE(review): a unary fold over '+' is ill-formed for an empty pack, so rank-0
// tuples are not supported here.
template <class Coord, class Shape, class Stride, int... Is>
CATLASS_HOST_DEVICE constexpr
auto crd2offset_ttt(Coord const& coord, Shape const& shape, Stride const& stride, seq<Is...>)
{
return (... + crd2offset(get<Is>(coord), get<Is>(shape), get<Is>(stride)));
}
// Scalar coord / tuple shape / tuple stride: the scalar is split across modes with
// the first mode varying fastest. coord % size(mode I0) indexes mode I0 and
// coord / size(mode I0) carries into the remaining modes.
template <class CInt, class STuple, class DTuple, int I0, int... Is>
CATLASS_HOST_DEVICE constexpr
auto crd2offset_itt(CInt const& coord, STuple const& shape, DTuple const& stride, seq<I0, Is...>)
{
// Last mode: it receives the whole remaining coordinate (no modulo).
if constexpr (sizeof...(Is) == 0) {
return crd2offset(coord, get<I0>(shape), get<I0>(stride));
// Static zero coordinate: every mode gets _0, so no division or modulo is emitted.
} else if constexpr (is_constant<0, CInt>::value) {
return crd2offset(_0{}, get<I0>(shape), get<I0>(stride)) +
(_0{} + ... + crd2offset(_0{}, get<Is>(shape), get<Is>(stride)));
} else {
return crd2offset(coord % Product{}(get<I0>(shape)), get<I0>(shape), get<I0>(stride)) +
crd2offset_itt(coord / Product{}(get<I0>(shape)), shape, stride, seq<Is...>{});
}
}
}
// Map a coordinate to a linear offset under (shape, stride):
//  - tuple coord + tuple shape  : ranks must match; offsets are summed per mode.
//  - scalar coord + tuple shape : the scalar is split across the modes (crd2offset_itt).
//  - scalar coord + scalar shape: offset = coord * stride.
//  - tuple coord + scalar shape : rejected at compile time.
template <class Coord, class Shape, class Stride>
CATLASS_HOST_DEVICE constexpr
auto crd2offset(Coord const& coord, Shape const& shape, Stride const& stride)
{
if constexpr (is_tuple<Coord>::value) {
if constexpr (is_tuple<Shape>::value) {
static_assert(tuple_size<Coord>::value == tuple_size<Shape>::value, "Mismatched Ranks");
static_assert(tuple_size<Coord>::value == tuple_size<Stride>::value, "Mismatched Ranks");
return detail::crd2offset_ttt(coord, shape, stride, tuple_seq<Coord>{});
} else {
// Dependent on Coord, so it only fires when this branch is instantiated.
static_assert(sizeof(Coord) == 0, "Invalid parameters");
}
} else {
if constexpr (is_tuple<Shape>::value) {
static_assert(tuple_size<Shape>::value == tuple_size<Stride>::value, "Mismatched Ranks");
return detail::crd2offset_itt(coord, shape, stride, tuple_seq<Shape>{});
} else {
return coord * stride;
}
}
}
// Type trait: derives from true_type iff T is a specialization of tla::Layout.
// cv/ref-qualified Layout types are not recognized, so strip them before asking.
template <class T>
struct is_layout : false_type {};

template <class Shape, class Stride, class OriginShape>
struct is_layout<Layout<Shape, Stride, OriginShape>> : true_type {};
namespace detail {
// Flat-layout classification traits.
// Each check is evaluated on a default-constructed Layout{}, so it can only be
// true when the compared stride is a static Int<N> constant. Presumably runtime
// strides value-initialize to 0 there and compare unequal.
// `value` is `static constexpr` (implicitly inline since C++17), so odr-uses such
// as binding it to a const reference need no out-of-class definition.

// True for a flat rank-1 layout whose stride is statically 1.
template <class Layout, class Enable = void>
struct isVector {
    static constexpr bool value = false;
};
template <class Layout>
struct isVector<Layout, std::enable_if_t<Layout::depth == 1 && Layout::rank == 1>> {
    static constexpr bool value = (stride<0>(Layout{}) == 1);
};

// True for a flat rank-2 layout whose column stride is statically 1.
template <class Layout, class Enable = void>
struct isRowMajor {
    static constexpr bool value = false;
};
template <class Layout>
struct isRowMajor<Layout, std::enable_if_t<Layout::depth == 1 && Layout::rank == 2>> {
    static constexpr bool value = (stride<1>(Layout{}) == 1);
};

// True for a flat rank-2 layout whose row stride is statically 1.
template <class Layout, class Enable = void>
struct isColumnMajor {
    static constexpr bool value = false;
};
template <class Layout>
struct isColumnMajor<Layout, std::enable_if_t<Layout::depth == 1 && Layout::rank == 2>> {
    static constexpr bool value = (stride<0>(Layout{}) == 1);
};
// True when Layout is a zN fractal layout for Element. Required static components:
//   shape<0,0>  == C0_NUM_PER_FRACTAL
//   shape<1,0>  == ELE_NUM_PER_C0
//   stride<1,0> == 1 (contiguous inside a fractal row)
//   stride<0,1> == ELE_NUM_PER_FRACTAL (successive row-fractals are adjacent)
// The second enable_if is a separate parameter so that it is only substituted once
// the depth/rank check has succeeded (SFINAE stops at the first failure).
// `value` is static constexpr (implicitly inline in C++17).
template <class Element, class Layout, class Enable1 = void, class Enable2 = void>
struct iszN {
    static constexpr bool value = false;
};
template <class Element, class Layout>
struct iszN<Element, Layout,
    std::enable_if_t<Layout::depth == 2 && Layout::rank == 2>,
    std::enable_if_t<rank_v<decltype(shape<0>(Layout{}))> == 2 &&
        rank_v<decltype(shape<1>(Layout{}))> == 2 && rank_v<decltype(stride<0>(Layout{}))> == 2 &&
        rank_v<decltype(stride<1>(Layout{}))> == 2>> {
    static constexpr uint32_t ELE_NUM_PER_C0 =
        Catlass::BytesToBits(Catlass::BYTE_PER_C0) / Catlass::SizeOfBits<Element>::value;
    static constexpr uint32_t ELE_NUM_PER_FRACTAL =
        Catlass::BytesToBits(Catlass::BYTE_PER_FRACTAL) / Catlass::SizeOfBits<Element>::value;
    static constexpr bool value = (shape<0, 0>(Layout{}) == Catlass::C0_NUM_PER_FRACTAL &&
                                   shape<1, 0>(Layout{}) == ELE_NUM_PER_C0 &&
                                   stride<1, 0>(Layout{}) == 1 &&
                                   stride<0, 1>(Layout{}) == ELE_NUM_PER_FRACTAL);
};
/*
 * zN variant for matmul operands whose m axis is not C0-aligned.
 * Example: oriShape(m, k) = (127, 256)
 *   zNUnAlign shape: ((127, 1), (16, 256/16))
 *   zN shape:        ((16, Ceil(127/16)), (16, 256/16))
 */
// True when Layout is the unaligned zN variant: the row mode is (m, 1), i.e. a
// single block of arbitrary height. Required static components:
//   shape<0,1>  == 1
//   shape<1,0>  == ELE_NUM_PER_C0
//   stride<0,0> == ELE_NUM_PER_C0
//   stride<1,0> == 1
// `value` is static constexpr (implicitly inline in C++17).
template <class Element, class Layout, class Enable1 = void, class Enable2 = void>
struct iszNUnAlign {
    static constexpr bool value = false;
};
template <class Element, class Layout>
struct iszNUnAlign<Element, Layout,
    std::enable_if_t<Layout::depth == 2 && Layout::rank == 2>,
    std::enable_if_t<rank_v<decltype(shape<0>(Layout{}))> == 2 &&
        rank_v<decltype(shape<1>(Layout{}))> == 2>> {
    static constexpr uint32_t ELE_NUM_PER_C0 =
        Catlass::BytesToBits(Catlass::BYTE_PER_C0) / Catlass::SizeOfBits<Element>::value;
    static constexpr bool value = (shape<0, 1>(Layout{}) == 1 &&
                                   shape<1, 0>(Layout{}) == ELE_NUM_PER_C0 &&
                                   stride<0, 0>(Layout{}) == ELE_NUM_PER_C0 &&
                                   stride<1, 0>(Layout{}) == 1);
};
template <class Element, class Layout, class Enable = void>
struct iszZ {
static bool const value = false;
};
template <class Element, class Layout>
struct iszZ<Element, Layout, std::enable_if_t<Layout::depth == 2 && Layout::rank == 2>> {
static constexpr uint32_t ELE_NUM_PER_C0 =
Catlass::BytesToBits(Catlass::BYTE_PER_C0) / Catlass::SizeOfBits<Element>::value;
static constexpr uint32_t ELE_NUM_PER_FRACTAL =
Catlass::BytesToBits(Catlass::BYTE_PER_FRACTAL) / Catlass::SizeOfBits<Element>::value;
static bool const value = (shape<0, 0>(Layout{}) == Catlass::C0_NUM_PER_FRACTAL &&
shape<1, 0>(Layout{}) == ELE_NUM_PER_C0 &&
stride<1, 0>(Layout{}) == 1 &&
stride<1, 1>(Layout{}) == ELE_NUM_PER_FRACTAL);
};
template <class Element, class Layout, class Enable1 = void, class Enable2 = void>
struct isnZ {
static bool const value = false;
};
template <class Element, class Layout>
struct isnZ<Element, Layout,
std::enable_if_t<Layout::depth == 2 && Layout::rank == 2>,
std::enable_if_t<rank_v<decltype(stride<0>(Layout{}))> == 2 &&
rank_v<decltype(stride<1>(Layout{}))> == 2>> {
static constexpr uint32_t ELE_NUM_PER_C0 =
Catlass::BytesToBits(Catlass::BYTE_PER_C0) / Catlass::SizeOfBits<Element>::value;
static constexpr uint32_t ELE_NUM_PER_FRACTAL =
Catlass::BytesToBits(Catlass::BYTE_PER_FRACTAL) / Catlass::SizeOfBits<Element>::value;
static bool const value = (shape<0, 0>(Layout{}) == ELE_NUM_PER_C0 &&
shape<1, 0>(Layout{}) == Catlass::C0_NUM_PER_FRACTAL &&
stride<0, 0>(Layout{}) == 1 &&
stride<1, 1>(Layout{}) == ELE_NUM_PER_FRACTAL);
};
#if (defined(CATLASS_ARCH) && CATLASS_ARCH == 3510)
// MX scale (float8_e8m0_t) layout classification traits.
// Each primary template reports false. The float8_e8m0_t specializations match
// purely on the nesting of the layout (ranks of the sub-shapes/sub-strides) and on
// whether a particular stride is the static constant Int<2>. They are intended to
// recognise the layouts built by MakeMxScaleLayout.
// `value` is static constexpr (implicitly inline in C++17).

// A-side, row-major: flat stride<0> that is NOT Int<2>, paired stride<1>.
template <class Element, class Layout, class Enable1 = void, class Enable2 = void>
struct isMxScaleForRowMajorA {
    static constexpr bool value = false;
};
template <class Layout>
struct isMxScaleForRowMajorA<float8_e8m0_t, Layout,
    std::enable_if_t<Layout::depth == 2 && Layout::rank == 2>,
    std::enable_if_t<rank_v<decltype(stride<0>(Layout{}))> == 1 &&
        rank_v<decltype(stride<1>(Layout{}))> == 2 && !is_constant<2, decltype(stride<0>(Layout{}))>::value &&
        ((rank_v<decltype(shape<0>(Layout{}))> == 1 && rank_v<decltype(shape<1>(Layout{}))> == 2) ||
         (rank_v<decltype(shape<0>(Layout{}))> == 2 && rank_v<decltype(shape<1>(Layout{}))> == 2))>> {
    static constexpr bool value = true;
};

// A-side, column-major: flat stride<0> that IS Int<2>, paired stride<1>.
template <class Element, class Layout, class Enable1 = void, class Enable2 = void>
struct isMxScaleForColumnMajorA {
    static constexpr bool value = false;
};
template <class Layout>
struct isMxScaleForColumnMajorA<float8_e8m0_t, Layout,
    std::enable_if_t<Layout::depth == 2 && Layout::rank == 2>,
    std::enable_if_t<rank_v<decltype(stride<0>(Layout{}))> == 1 &&
        rank_v<decltype(stride<1>(Layout{}))> == 2 && is_constant<2, decltype(stride<0>(Layout{}))>::value &&
        ((rank_v<decltype(shape<0>(Layout{}))> == 1 && rank_v<decltype(shape<1>(Layout{}))> == 2) ||
         (rank_v<decltype(shape<0>(Layout{}))> == 2 && rank_v<decltype(shape<1>(Layout{}))> == 2))>> {
    static constexpr bool value = true;
};

// B-side, row-major: paired stride<0> whose second entry is NOT Int<2>, flat stride<1>.
template <class Element, class Layout, class Enable1 = void, class Enable2 = void, class Enable3 = void>
struct isMxScaleForRowMajorB {
    static constexpr bool value = false;
};
template <class Layout>
struct isMxScaleForRowMajorB<float8_e8m0_t, Layout,
    std::enable_if_t<Layout::depth == 2 && Layout::rank == 2>,
    std::enable_if_t<rank_v<decltype(stride<0>(Layout{}))> == 2 &&
        rank_v<decltype(stride<1>(Layout{}))> == 1>,
    std::enable_if_t<!is_constant<2, decltype(stride<0, 1>(Layout{}))>::value &&
        ((rank_v<decltype(shape<0>(Layout{}))> == 2 && rank_v<decltype(shape<1>(Layout{}))> == 1) ||
         (rank_v<decltype(shape<0>(Layout{}))> == 2 && rank_v<decltype(shape<1>(Layout{}))> == 2))>> {
    static constexpr bool value = true;
};

// B-side, column-major: paired stride<0> whose second entry IS Int<2>, flat stride<1>.
template <class Element, class Layout, class Enable1 = void, class Enable2 = void, class Enable3 = void>
struct isMxScaleForColumnMajorB {
    static constexpr bool value = false;
};
template <class Layout>
struct isMxScaleForColumnMajorB<float8_e8m0_t, Layout,
    std::enable_if_t<Layout::depth == 2 && Layout::rank == 2>,
    std::enable_if_t<rank_v<decltype(stride<0>(Layout{}))> == 2 &&
        rank_v<decltype(stride<1>(Layout{}))> == 1>,
    std::enable_if_t<is_constant<2, decltype(stride<0, 1>(Layout{}))>::value &&
        ((rank_v<decltype(shape<0>(Layout{}))> == 2 && rank_v<decltype(shape<1>(Layout{}))> == 1) ||
         (rank_v<decltype(shape<0>(Layout{}))> == 2 && rank_v<decltype(shape<1>(Layout{}))> == 2))>> {
    static constexpr bool value = true;
};

// Fractal zZ scale layout: every sub-shape and sub-stride is rank 2.
template <class Element, class Layout, class Enable1 = void, class Enable2 = void>
struct isMxScaleForzZ {
    static constexpr bool value = false;
};
template <class Layout>
struct isMxScaleForzZ<float8_e8m0_t, Layout,
    std::enable_if_t<Layout::depth == 2 && Layout::rank == 2>,
    std::enable_if_t<rank_v<decltype(shape<0>(Layout{}))> == 2 &&
        rank_v<decltype(shape<1>(Layout{}))> == 2 && rank_v<decltype(stride<0>(Layout{}))> == 2 &&
        rank_v<decltype(stride<1>(Layout{}))> == 2>> {
    static constexpr bool value = true;
};

// Fractal nN scale layout: every sub-shape and sub-stride is rank 2.
// NOTE(review): the structural condition is identical to isMxScaleForzZ, so both
// report true for the same layouts; callers must disambiguate by other means.
template <class Element, class Layout, class Enable1 = void, class Enable2 = void>
struct isMxScaleFornN {
    static constexpr bool value = false;
};
template <class Layout>
struct isMxScaleFornN<float8_e8m0_t, Layout,
    std::enable_if_t<Layout::depth == 2 && Layout::rank == 2>,
    std::enable_if_t<rank_v<decltype(shape<0>(Layout{}))> == 2 &&
        rank_v<decltype(shape<1>(Layout{}))> == 2 && rank_v<decltype(stride<0>(Layout{}))> == 2 &&
        rank_v<decltype(stride<1>(Layout{}))> == 2>> {
    static constexpr bool value = true;
};
#endif
}
// Build a contiguous 1-D layout of `len` elements with static unit stride.
// The origin shape is the same one-element shape as the layout's shape.
template <class T>
CATLASS_HOST_DEVICE constexpr
auto MakeLayout(T const& len)
{
    const auto extent = MakeShape(len);
    const auto unitStride = MakeStride(Int<1>{});
    return MakeLayout(extent, unitStride, extent);
}
// Build the layout of a rows x cols matrix of `Element` stored in the format named
// by `LayoutTag`. The returned OriginShape is always the unpadded (rows, cols).
//   VectorLayout : 1-D, `cols` elements with unit stride (`rows` is not used).
//   RowMajor     : (rows, cols), strides (cols, 1).
//   ColumnMajor  : (rows, cols), strides (1, rows).
//   zN           : ((C0_NUM_PER_FRACTAL, ceil), (ELE_NUM_PER_C0, ceil)); row-major inside
//                  a fractal, fractals stacked down the rows first.
//   zZ           : same fractal as zN, fractals laid out along the columns first.
//   nZ           : ((ELE_NUM_PER_C0, ceil), (C0_NUM_PER_FRACTAL, ceil)); column-major
//                  inside a fractal, fractals laid out along the columns first.
//   L0C          : C0_NUM_PER_FRACTAL-square fractals of 256 elements, independent
//                  of Element; fractals stacked down the rows first.
// On CATLASS_ARCH 3510, the packed fp4 types float4_e2m1x2_t / float4_e1m2x2_t
// round the leading dimension of RowMajor/ColumnMajor up to a multiple of 2.
template <class Element, class LayoutTag, class T, class U>
CATLASS_HOST_DEVICE constexpr
auto MakeLayout(T const& rows, U const& cols)
{
    static_assert(std::is_same_v<LayoutTag, Catlass::layout::RowMajor> ||
                  std::is_same_v<LayoutTag, Catlass::layout::ColumnMajor> ||
                  std::is_same_v<LayoutTag, Catlass::layout::VectorLayout> ||
                  std::is_same_v<LayoutTag, Catlass::layout::zN> ||
                  std::is_same_v<LayoutTag, Catlass::layout::nZ> ||
                  std::is_same_v<LayoutTag, Catlass::layout::zZ> ||
                  std::is_same_v<LayoutTag, Catlass::layout::L0C>,
                  // Adjacent literals are concatenated: keep the trailing spaces.
                  "Unsupported LayoutTag for MakeLayout, only support Catlass::layout::RowMajor or "
                  "Catlass::layout::ColumnMajor or Catlass::layout::VectorLayout or Catlass::layout::zN or "
                  "Catlass::layout::nZ or Catlass::layout::zZ or Catlass::layout::L0C");
    // Elements per C0 block (BYTE_PER_C0 bytes) and per fractal (BYTE_PER_FRACTAL bytes).
    constexpr uint32_t ELE_NUM_PER_C0 =
        Catlass::BytesToBits(Catlass::BYTE_PER_C0) / Catlass::SizeOfBits<Element>::value;
    constexpr uint32_t ELE_NUM_PER_FRACTAL =
        Catlass::BytesToBits(Catlass::BYTE_PER_FRACTAL) / Catlass::SizeOfBits<Element>::value;
    if constexpr (std::is_same_v<LayoutTag, Catlass::layout::VectorLayout>) {
        return MakeLayout(MakeShape(cols), MakeStride(Int<1>{}), MakeShape(cols));
    } else if constexpr (std::is_same_v<LayoutTag, Catlass::layout::RowMajor>) {
#if (defined(CATLASS_ARCH) && CATLASS_ARCH == 3510)
        if constexpr (std::is_same_v<Element, float4_e2m1x2_t> || std::is_same_v<Element, float4_e1m2x2_t>) {
            return MakeLayout(MakeShape(rows, cols),
                              MakeStride(static_cast<int64_t>(RoundUp(cols, 2)), Int<1>{}),
                              MakeShape(rows, cols));
        }
#endif
        return MakeLayout(MakeShape(rows, cols),
                          MakeStride(static_cast<int64_t>(cols), Int<1>{}),
                          MakeShape(rows, cols));
    } else if constexpr (std::is_same_v<LayoutTag, Catlass::layout::ColumnMajor>) {
#if (defined(CATLASS_ARCH) && CATLASS_ARCH == 3510)
        if constexpr (std::is_same_v<Element, float4_e2m1x2_t> || std::is_same_v<Element, float4_e1m2x2_t>) {
            return MakeLayout(MakeShape(rows, cols),
                              MakeStride(Int<1>{}, static_cast<int64_t>(RoundUp(rows, 2))),
                              MakeShape(rows, cols));
        }
#endif
        return MakeLayout(MakeShape(rows, cols),
                          MakeStride(Int<1>{}, static_cast<int64_t>(rows)),
                          MakeShape(rows, cols));
    } else if constexpr (std::is_same_v<LayoutTag, Catlass::layout::zN>) {
        return MakeLayout(
            MakeShape(MakeShape(Int<Catlass::C0_NUM_PER_FRACTAL>{}, CeilDiv(rows, Int<Catlass::C0_NUM_PER_FRACTAL>{})),
                      MakeShape(Int<ELE_NUM_PER_C0>{}, CeilDiv(cols, Int<ELE_NUM_PER_C0>{}))),
            MakeStride(MakeStride(Int<ELE_NUM_PER_C0>{}, Int<ELE_NUM_PER_FRACTAL>{}),
                       MakeStride(Int<1>{},
                                  RoundUp(static_cast<int64_t>(rows), Int<Catlass::C0_NUM_PER_FRACTAL>{}) *
                                      ELE_NUM_PER_C0)),
            MakeShape(rows, cols));
    } else if constexpr (std::is_same_v<LayoutTag, Catlass::layout::zZ>) {
        return MakeLayout(
            MakeShape(MakeShape(Int<Catlass::C0_NUM_PER_FRACTAL>{}, CeilDiv(rows, Int<Catlass::C0_NUM_PER_FRACTAL>{})),
                      MakeShape(Int<ELE_NUM_PER_C0>{}, CeilDiv(cols, Int<ELE_NUM_PER_C0>{}))),
            MakeStride(MakeStride(Int<ELE_NUM_PER_C0>{},
                                  RoundUp(static_cast<int64_t>(cols), Int<ELE_NUM_PER_C0>{}) *
                                      Catlass::C0_NUM_PER_FRACTAL),
                       MakeStride(Int<1>{}, Int<ELE_NUM_PER_FRACTAL>{})),
            MakeShape(rows, cols));
    } else if constexpr (std::is_same_v<LayoutTag, Catlass::layout::L0C>) {
        // L0C fractals hold 256 elements regardless of Element; a distinct name
        // avoids shadowing the Element-dependent ELE_NUM_PER_FRACTAL above.
        constexpr uint32_t L0C_ELE_NUM_PER_FRACTAL = 256;
        return MakeLayout(
            MakeShape(MakeShape(Int<Catlass::C0_NUM_PER_FRACTAL>{}, CeilDiv(rows, Int<Catlass::C0_NUM_PER_FRACTAL>{})),
                      MakeShape(Int<Catlass::C0_NUM_PER_FRACTAL>{}, CeilDiv(cols, Int<Catlass::C0_NUM_PER_FRACTAL>{}))),
            MakeStride(MakeStride(Int<Catlass::C0_NUM_PER_FRACTAL>{}, Int<L0C_ELE_NUM_PER_FRACTAL>{}),
                       MakeStride(Int<1>{},
                                  RoundUp(static_cast<int64_t>(rows), Int<Catlass::C0_NUM_PER_FRACTAL>{}) *
                                      Catlass::C0_NUM_PER_FRACTAL)),
            MakeShape(rows, cols));
    } else {
        // nZ
        return MakeLayout(
            MakeShape(MakeShape(Int<ELE_NUM_PER_C0>{}, CeilDiv(rows, Int<ELE_NUM_PER_C0>{})),
                      MakeShape(Int<Catlass::C0_NUM_PER_FRACTAL>{}, CeilDiv(cols, Int<Catlass::C0_NUM_PER_FRACTAL>{}))),
            MakeStride(
                MakeStride(Int<1>{},
                           RoundUp(static_cast<int64_t>(cols), Int<Catlass::C0_NUM_PER_FRACTAL>{}) * ELE_NUM_PER_C0),
                MakeStride(Int<ELE_NUM_PER_C0>{}, Int<ELE_NUM_PER_FRACTAL>{})),
            MakeShape(rows, cols));
    }
}
#if (defined(CATLASS_ARCH) && CATLASS_ARCH == 3510)
// Build the layout of an MX scale matrix (float8_e8m0_t) with logical size rows x cols.
// Scale values are grouped in pairs (ELE_NUM_PER_C0 = 2). For the A-side
// (isMxScaleB == false) the pairs run along cols; for the B-side
// (isMxScaleB == true) they run along rows. Presumably this is the K (reduction)
// axis in both cases — TODO confirm. The returned OriginShape is always (rows, cols).
// NOTE(review): some strides (rows * 2, cols * 2, RoundUp(rows/cols, 2)) are computed
// in the argument type without widening to int64_t, unlike the zZ/nN branches.
template <class Element, class LayoutTag, bool isMxScaleB, class T, class U>
CATLASS_HOST_DEVICE constexpr
auto MakeMxScaleLayout(T const& rows, U const& cols)
{
static_assert(
std::is_same_v<Element, float8_e8m0_t> &&
(std::is_same_v<LayoutTag, Catlass::layout::RowMajor> ||
std::is_same_v<LayoutTag, Catlass::layout::ColumnMajor> ||
std::is_same_v<LayoutTag, Catlass::layout::zZ> || std::is_same_v<LayoutTag, Catlass::layout::nN>),
"only support RowMajor, ColumnMajor, zZ, nN in fp8_e8m0_t dtype"
);
// A C0 group is 2 scale values; a fractal is 16 x 2 = 32 values.
constexpr uint32_t ELE_NUM_PER_C0 = 2;
constexpr uint32_t ELE_NUM_PER_FRACTAL = 32;
if constexpr (std::is_same_v<LayoutTag, Catlass::layout::RowMajor>) {
if constexpr (!isMxScaleB) {
// (r, (c0, c1)) -> r * RoundUp(cols, 2) + c0 + 2 * c1: plain row-major with an even row pitch.
return MakeLayout(
MakeShape(rows, MakeShape(Int<ELE_NUM_PER_C0>{}, CeilDiv(cols, Int<ELE_NUM_PER_C0>{}))),
MakeStride(RoundUp(cols, Int<ELE_NUM_PER_C0>{}), MakeStride(Int<1>{}, Int<ELE_NUM_PER_C0>{})),
MakeShape(rows, cols));
} else {
// ((r0, r1), c) -> r0 + 2 * cols * r1 + 2 * c: pairs of rows are interleaved
// so the two values of a row-pair for one column are adjacent.
return MakeLayout(
MakeShape(MakeShape(Int<ELE_NUM_PER_C0>{}, CeilDiv(rows, Int<ELE_NUM_PER_C0>{})), cols),
MakeStride(MakeStride(Int<1>{}, cols * ELE_NUM_PER_C0), Int<ELE_NUM_PER_C0>{}),
MakeShape(rows, cols));
}
} else if constexpr (std::is_same_v<LayoutTag, Catlass::layout::ColumnMajor>) {
if constexpr (!isMxScaleB) {
// (r, (c0, c1)) -> 2 * r + c0 + 2 * rows * c1: pairs of columns are interleaved
// so the two values of a column-pair for one row are adjacent.
return MakeLayout(
MakeShape(rows, MakeShape(Int<ELE_NUM_PER_C0>{}, CeilDiv(cols, Int<ELE_NUM_PER_C0>{}))),
MakeStride(Int<ELE_NUM_PER_C0>{}, MakeStride(Int<1>{}, rows * ELE_NUM_PER_C0)),
MakeShape(rows, cols));
} else {
// ((r0, r1), c) -> r0 + 2 * r1 + RoundUp(rows, 2) * c: plain column-major with an even column pitch.
return MakeLayout(
MakeShape(MakeShape(Int<ELE_NUM_PER_C0>{}, CeilDiv(rows, Int<ELE_NUM_PER_C0>{})), cols),
MakeStride(MakeStride(Int<1>{}, Int<ELE_NUM_PER_C0>{}), RoundUp(rows, Int<ELE_NUM_PER_C0>{})),
MakeShape(rows, cols));
}
} else if constexpr (std::is_same_v<LayoutTag, Catlass::layout::zZ>) {
// 16 x 2 fractals (32 values), row-major inside a fractal. Fractals follow
// each other along the columns (+32); a row of fractals spans
// RoundUp(cols, 2) * 16 values.
return MakeLayout(
MakeShape(
MakeShape(Int<Catlass::C0_NUM_PER_FRACTAL>{}, CeilDiv(rows, Int<Catlass::C0_NUM_PER_FRACTAL>{})),
MakeShape(Int<ELE_NUM_PER_C0>{}, CeilDiv(cols, Int<ELE_NUM_PER_C0>{}))
),
MakeStride(
MakeStride(
Int<ELE_NUM_PER_C0>{}, RoundUp((int64_t)cols, Int<ELE_NUM_PER_C0>{}) * Catlass::C0_NUM_PER_FRACTAL
),
MakeStride(Int<1>{}, Int<ELE_NUM_PER_FRACTAL>{})
),
MakeShape(rows, cols));
} else {
// nN: 2 x 16 fractals (32 values), column-major inside a fractal. Fractals
// follow each other along the rows (+32); a column of fractals spans
// RoundUp(rows, 2) * 16 values.
return MakeLayout(
MakeShape(
MakeShape(Int<ELE_NUM_PER_C0>{}, CeilDiv(rows, Int<ELE_NUM_PER_C0>{})),
MakeShape(Int<Catlass::C0_NUM_PER_FRACTAL>{}, CeilDiv(cols, Int<Catlass::C0_NUM_PER_FRACTAL>{}))
),
MakeStride(
MakeStride(Int<1>{}, Int<ELE_NUM_PER_FRACTAL>{}),
MakeStride(
Int<ELE_NUM_PER_C0>{}, RoundUp((int64_t)rows, Int<ELE_NUM_PER_C0>{}) * Catlass::C0_NUM_PER_FRACTAL
)
),
MakeShape(rows, cols));
}
}
#endif
namespace detail {
// Extent left in a dimension of size `extent` when starting at `start`.
// Returns 0 once `start` is at or past the end.
CATLASS_HOST_DEVICE constexpr
uint32_t RemainingExtentU32(uint32_t extent, uint32_t start)
{
    if (start < extent) {
        return extent - start;
    }
    return 0u;
}

// Crop each mode of `tileShape` so that a tile placed at `coord` stays inside
// `originBase`. Per mode: result_i = min(tile_i, max(origin_i - coord_i, 0)).
// Every component is narrowed to uint32_t first.
template <class OriginBase, class TileShape, class Coord, int... Is>
CATLASS_HOST_DEVICE constexpr
auto CropOriginShape(OriginBase const& originBase, TileShape const& tileShape, Coord const& coord, seq<Is...>)
{
    return MakeShape(
        tla::min(static_cast<uint32_t>(get<Is>(tileShape)),
                 RemainingExtentU32(static_cast<uint32_t>(get<Is>(originBase)),
                                    static_cast<uint32_t>(get<Is>(coord))))...);
}
}
// Derive the layout of a tile of `layout` with extent `tileShape` starting at the
// logical coordinate `coord`.
//  - Strides are reused unchanged from `layout`. Presumably the caller offsets the
//    base address separately, e.g. with layout(coord) — TODO confirm.
//  - The tile's OriginShape is `tileShape` cropped against the parent's originShape
//    at `coord`, i.e. the valid (unpadded) part of the tile.
//  - depth-1 layouts use `tileShape` directly as the new shape.
//  - depth>1 layouts keep the parent's inner (per-fractal or per-group) extents
//    and recompute the outer counts from (rows, cols) with CeilDiv.
template <class Layout, class TileShape, class Coord>
CATLASS_HOST_DEVICE constexpr
auto GetTileLayout(Layout const& layout, TileShape const& tileShape, Coord const& coord)
{
static_assert(is_tuple<TileShape>::value && depth_v<TileShape> == 1 && rank_v<TileShape> >= 1,
"GetTileLayout: TileShape must be a flat tuple with rank >= 1.");
static_assert(is_tuple<Coord>::value && depth_v<Coord> == 1 && rank_v<Coord> == rank_v<TileShape>,
"GetTileLayout: Coord must have the same rank as TileShape.");
auto tileOriginShape = detail::CropOriginShape(layout.originShape(), tileShape, coord, tuple_seq<TileShape>{});
if constexpr (Layout::depth == 1) {
static_assert(Layout::rank == rank_v<TileShape>,
"GetTileLayout: for depth==1 layouts, TileShape rank must match layout rank.");
return MakeLayout(tileShape, layout.stride(), tileOriginShape);
} else {
static_assert(rank_v<TileShape> == 2,
"GetTileLayout: for depth>1 (fractal) layouts, TileShape must be rank-2 (rows, cols).");
const uint32_t rows = get<0>(tileShape);
const uint32_t cols = get<1>(tileShape);
// Flat row mode, grouped column mode (rows, (C0, n)), e.g. the A-side layouts
// of MakeMxScaleLayout: only the column mode is re-tiled.
if constexpr (Layout::depth == 2 && Layout::rank == 2 && rank_v<decltype(shape<0>(Layout{}))> == 1 &&
rank_v<decltype(shape<1>(Layout{}))> == 2) {
constexpr uint32_t ELE_NUM_PER_C0 = decltype(shape<1, 0>(layout))::value;
return MakeLayout(
MakeShape(rows, MakeShape(Int<ELE_NUM_PER_C0>{}, CeilDiv(cols, Int<ELE_NUM_PER_C0>{}))),
layout.stride(),
tileOriginShape);
}
// Grouped row mode, flat column mode ((C0, m), cols), e.g. the B-side layouts
// of MakeMxScaleLayout: only the row mode is re-tiled.
else if constexpr (Layout::depth == 2 && Layout::rank == 2 && rank_v<decltype(shape<0>(Layout{}))> == 2 &&
rank_v<decltype(shape<1>(Layout{}))> == 1) {
constexpr uint32_t ELE_NUM_PER_C0 = decltype(shape<0, 0>(layout))::value;
return MakeLayout(
MakeShape(MakeShape(Int<ELE_NUM_PER_C0>{}, CeilDiv(rows, Int<ELE_NUM_PER_C0>{})), cols),
layout.stride(),
tileOriginShape);
}
// Fully fractal layout with static inner extents: keep them as Int<> constants.
else if constexpr (is_static<decltype(shape<0, 0>(layout))>::value &&
is_static<decltype(shape<1, 0>(layout))>::value) {
constexpr uint32_t dstInnerShapeRow = decltype(shape<0, 0>(layout))::value;
constexpr uint32_t dstInnerShapeCol = decltype(shape<1, 0>(layout))::value;
return MakeLayout(
MakeShape(MakeShape(Int<dstInnerShapeRow>{}, CeilDiv<dstInnerShapeRow>(rows)),
MakeShape(Int<dstInnerShapeCol>{}, CeilDiv<dstInnerShapeCol>(cols))),
layout.stride(),
tileOriginShape);
}
// Fully fractal layout with runtime inner extents: read them from `layout`.
else {
const uint32_t dstInnerShapeRow = shape<0, 0>(layout);
const uint32_t dstInnerShapeCol = shape<1, 0>(layout);
return MakeLayout(
MakeShape(MakeShape(dstInnerShapeRow, CeilDiv(rows, dstInnerShapeRow)),
MakeShape(dstInnerShapeCol, CeilDiv(cols, dstInnerShapeCol))),
layout.stride(),
tileOriginShape);
}
}
}
// Build the L0C layout for a rows x cols tile. Fractals are
// C0_NUM_PER_FRACTAL x C0_NUM_PER_FRACTAL, row-major inside. Successive row-fractals
// are 256 elements apart; a column of fractals spans rows rounded up to whole
// fractals. The OriginShape is the unpadded (rows, cols).
template <class T, class U>
CATLASS_HOST_DEVICE constexpr
auto MakeLayoutL0C(T const& rows, U const& cols)
{
    using FractalEdge = Int<Catlass::C0_NUM_PER_FRACTAL>;
    constexpr uint32_t ELE_NUM_PER_FRACTAL = 256;
    // Distance between horizontally adjacent fractals.
    const auto colFractalStride =
        RoundUp(static_cast<int64_t>(rows), FractalEdge{}) * Catlass::C0_NUM_PER_FRACTAL;

    const auto tiledShape = MakeShape(MakeShape(FractalEdge{}, CeilDiv(rows, FractalEdge{})),
                                      MakeShape(FractalEdge{}, CeilDiv(cols, FractalEdge{})));
    const auto tiledStride = MakeStride(MakeStride(FractalEdge{}, Int<ELE_NUM_PER_FRACTAL>{}),
                                        MakeStride(Int<1>{}, colFractalStride));
    return MakeLayout(tiledShape, tiledStride, MakeShape(rows, cols));
}
// Build a 5-D feature-map layout in (batch, cin1, h, w, C0) order, i.e. NC1HWC0
// style. The layout is fully packed with C0 innermost; C0 = BYTE_PER_C0 / sizeof(Element).
// All extents are narrowed to uint32_t; strides are int64_t.
template <class Element, class T1, class T2, class T3, class T4>
CATLASS_HOST_DEVICE constexpr
auto MakeLayoutFmap(T1 const& batch, T2 const& cin1, T3 const& h, T4 const& w)
{
    constexpr uint32_t ELE_NUM_PER_C0 = Catlass::BYTE_PER_C0 / sizeof(Element);
    // Widen every factor to int64_t before multiplying. Previously `w * ELE_NUM_PER_C0`
    // was evaluated in the (possibly 32-bit) argument type and could wrap before
    // being stored as int64_t.
    const int64_t strideH = static_cast<int64_t>(w) * ELE_NUM_PER_C0;
    const int64_t strideCin1 = static_cast<int64_t>(h) * strideH;
    const int64_t strideBatch = static_cast<int64_t>(cin1) * strideCin1;
    return MakeLayout(
        MakeShape(static_cast<uint32_t>(batch), static_cast<uint32_t>(cin1), static_cast<uint32_t>(h),
                  static_cast<uint32_t>(w), Int<ELE_NUM_PER_C0>{}),
        MakeStride(strideBatch, strideCin1, strideH, Int<ELE_NUM_PER_C0>{}, Int<1>{}),
        MakeShape(static_cast<uint32_t>(batch), static_cast<uint32_t>(cin1), static_cast<uint32_t>(h),
                  static_cast<uint32_t>(w), ELE_NUM_PER_C0)
    );
}
// Build a 5-D filter (weight) layout in (cin1, kh, kw, cout, C0) order, fully packed
// with C0 innermost; C0 = BYTE_PER_C0 / sizeof(Element). When PositionType is
// Catlass::Arch::PositionL1, cout is rounded up to a multiple of C0_NUM_PER_FRACTAL;
// both the shape and the strides use that rounded value.
template <class Element, class PositionType, class T1, class T2, class T3, class T4>
CATLASS_HOST_DEVICE constexpr
auto MakeLayoutFilter(T1 const& cin1, T2 const& kh, T3 const& kw, T4 const& cout)
{
    constexpr uint32_t ELE_NUM_PER_C0 = Catlass::BYTE_PER_C0 / sizeof(Element);
    const uint32_t coutRound = std::is_same_v<PositionType, Catlass::Arch::PositionL1> ?
        RoundUp(cout, Catlass::C0_NUM_PER_FRACTAL) : cout;
    // Widen every factor to int64_t before multiplying. Previously
    // `coutRound * ELE_NUM_PER_C0` was a uint32_t product that could wrap before
    // being stored as int64_t.
    const int64_t strideKw = static_cast<int64_t>(coutRound) * ELE_NUM_PER_C0;
    const int64_t strideKh = static_cast<int64_t>(kw) * strideKw;
    const int64_t strideCin1 = static_cast<int64_t>(kh) * strideKh;
    return MakeLayout(
        MakeShape(static_cast<uint32_t>(cin1), static_cast<uint32_t>(kh), static_cast<uint32_t>(kw),
                  static_cast<uint32_t>(coutRound), Int<ELE_NUM_PER_C0>{}),
        MakeStride(strideCin1, strideKh, strideKw, Int<ELE_NUM_PER_C0>{}, Int<1>{}),
        MakeShape(static_cast<uint32_t>(cin1), static_cast<uint32_t>(kh), static_cast<uint32_t>(kw),
                  static_cast<uint32_t>(coutRound), ELE_NUM_PER_C0)
    );
}
namespace detail {
// Type-level helper: the Layout type obtained by prepending one leading mode to
// `Layout`. The new mode's shape/stride/origin element types are NewShapeT /
// NewStrideT / NewOriginT; the existing modes keep their (cvref-stripped) types.
// Only a type is produced; no values are computed.
// NOTE(review): the existing modes are enumerated with Layout::rank (the rank of
// its Stride); this assumes Shape and OriginShape have at least that many
// top-level modes.
template <class Layout, class NewShapeT = uint32_t, class NewStrideT = int64_t, class NewOriginT = uint32_t,
class Seq = tla::make_seq<Layout::rank>>
struct PrependDimLayout;
template <class Layout, class NewShapeT, class NewStrideT, class NewOriginT, int... Is>
struct PrependDimLayout<Layout, NewShapeT, NewStrideT, NewOriginT, tla::seq<Is...>> {
using ShapeOld = tla::remove_cvref_t<decltype(std::declval<Layout const&>().shape())>;
using StrideOld = tla::remove_cvref_t<decltype(std::declval<Layout const&>().stride())>;
using OriginOld = tla::remove_cvref_t<decltype(std::declval<Layout const&>().originShape())>;
using ShapeNew = tla::Shape<NewShapeT, tla::remove_cvref_t<decltype(tla::get<Is>(std::declval<ShapeOld>()))>...>;
using StrideNew = tla::Stride<NewStrideT, tla::remove_cvref_t<decltype(tla::get<Is>(std::declval<StrideOld>()))>...>;
using OriginNew = tla::Shape<NewOriginT, tla::remove_cvref_t<decltype(tla::get<Is>(std::declval<OriginOld>()))>...>;
using type = tla::Layout<ShapeNew, StrideNew, OriginNew>;
};
}
// Layout type with one extra leading mode (see detail::PrependDimLayout).
template <class Layout, class NewShapeT = uint32_t, class NewStrideT = int64_t, class NewOriginT = uint32_t>
using PrependDimLayout_t = typename detail::PrependDimLayout<Layout, NewShapeT, NewStrideT, NewOriginT>::type;
// Layout type with a leading uint32_t-extent / int64_t-stride mode, presumably a
// batch dimension (per the name).
template <class Layout>
using MakeBatchedLayout_t = PrependDimLayout_t<Layout>;
}
# endif