* This program is free software, you can redistribute it and/or modify.
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef TLA_LAYOUT_HPP
#define TLA_LAYOUT_HPP
#include "../attn_infra/bsa_base_defs.hpp"
#include "../tla/numeric/integral_constant.hpp"
#include "../tla/tuple.hpp"
#include "../tla/int_tuple.hpp"
#include "../attn_infra/layout/bsa_layout.hpp"
namespace tla {
/// Shape, Stride and Coord are all aliases of tla::tuple; the distinct names
/// only document what role a tuple plays at a call site.

/// Tuple holding the extents of a (possibly nested) shape.
template <class... Extents>
using Shape = tla::tuple<Extents...>;

/// Tuple holding the strides that pair with a Shape.
template <class... Steps>
using Stride = tla::tuple<Steps...>;

/// Tuple holding the coordinates used to index a layout.
template <class... Indices>
using Coord = tla::tuple<Indices...>;
/// Packs the given extents into a Shape tuple, keeping each argument's type
/// as the corresponding tuple element type.
template <class... Extents>
HOST_DEVICE constexpr
Shape<Extents...> MakeShape(Extents const&... extents)
{
    return Shape<Extents...>{extents...};
}
/// Packs the given strides into a Stride tuple, keeping each argument's type
/// as the corresponding tuple element type.
template <class... Steps>
HOST_DEVICE constexpr
Stride<Steps...> MakeStride(Steps const&... steps)
{
    return Stride<Steps...>{steps...};
}
/// Packs the given indices into a Coord tuple, keeping each argument's type
/// as the corresponding tuple element type.
template <class... Indices>
HOST_DEVICE constexpr
Coord<Indices...> MakeCoord(Indices const&... indices)
{
    return Coord<Indices...>{indices...};
}
/// A layout couples a Shape with a Stride and converts logical coordinates
/// into linear offsets through crd2offset.
/// Both parts live in a privately inherited tla::tuple<Shape, Stride>:
/// element 0 is the shape, element 1 is the stride.
template <class Shape, class Stride>
struct Layout : private tla::tuple<Shape, Stride> {
private:
    // Short name for the storage base used by the accessors below.
    using Storage = tla::tuple<Shape, Stride>;

public:
    /// Builds the layout from a shape and a stride; an omitted argument is
    /// value-initialized (`{}`).
    HOST_DEVICE constexpr
    Layout(Shape const& shapeValue = {}, Stride const& strideValue = {})
        : Storage(shapeValue, strideValue) {}

    /// Rank and depth of the layout; both are derived from the Stride type.
    static constexpr int rank = rank_v<Stride>;
    static constexpr int depth = depth_v<Stride>;

    /// Mutable access to the shape, or to the sub-element selected by I...
    template <int... I>
    HOST_DEVICE constexpr
    decltype(auto) shape()
    {
        return get<0, I...>(static_cast<Storage&>(*this));
    }

    /// Read-only access to the shape, or to the sub-element selected by I...
    template <int... I>
    HOST_DEVICE constexpr
    decltype(auto) shape() const
    {
        return get<0, I...>(static_cast<Storage const&>(*this));
    }

    /// Mutable access to the stride, or to the sub-element selected by I...
    template <int... I>
    HOST_DEVICE constexpr
    decltype(auto) stride()
    {
        return get<1, I...>(static_cast<Storage&>(*this));
    }

    /// Read-only access to the stride, or to the sub-element selected by I...
    template <int... I>
    HOST_DEVICE constexpr
    decltype(auto) stride() const
    {
        return get<1, I...>(static_cast<Storage const&>(*this));
    }

    /// Maps a coordinate to a linear offset using this layout's shape and stride.
    template <class Crd>
    HOST_DEVICE constexpr
    auto operator()(Crd const& coord) const
    {
        return crd2offset(coord, shape(), stride());
    }
};
/// Builds a Layout<Shape, Stride> from an explicit shape and stride.
/// Both arguments must be either a tla tuple or a type accepted by is_integral;
/// anything else is rejected at compile time.
template <class Shape, class Stride>
HOST_DEVICE constexpr
auto MakeLayout(Shape const& shape, Stride const& stride)
{
    // Messages make a failed instantiation point at the offending argument.
    static_assert(is_tuple<Shape>::value || is_integral<Shape>::value,
                  "MakeLayout: Shape must be a tuple or an integral type");
    static_assert(is_tuple<Stride>::value || is_integral<Stride>::value,
                  "MakeLayout: Stride must be a tuple or an integral type");
    return Layout<Shape, Stride>(shape, stride);
}
/// Converts an NpuArch layout tag object into a tla Layout.
///  - RowMajor:    shape (shape(0), shape(1)), stride (stride(0), Int<1>).
///  - ColumnMajor: shape (shape(0), shape(1)), stride (Int<1>, stride(1)).
///  - zN / nZ:     nested shape ((shape(0), shape(1)), (shape(2), shape(3))) with
///                 the matching nested stride taken from stride(0..3).
/// For RowMajor/ColumnMajor the unit stride is a compile-time Int<1>; the tag's
/// own value for that stride is not read.
template <class LayoutTag>
HOST_DEVICE constexpr
auto MakeLayoutFromTag(LayoutTag const& tag)
{
    // Adjacent literals are concatenated verbatim, so each fragment must end
    // with a space to keep the words of the message separated.
    static_assert(std::is_same_v<LayoutTag, NpuArch::layout::RowMajor> ||
                  std::is_same_v<LayoutTag, NpuArch::layout::ColumnMajor> ||
                  std::is_same_v<LayoutTag, NpuArch::layout::zN> ||
                  std::is_same_v<LayoutTag, NpuArch::layout::nZ>,
                  "Unsupported LayoutTag for MakeLayoutFromTag, only support NpuArch::layout::RowMajor or "
                  "NpuArch::layout::ColumnMajor or NpuArch::layout::zN or NpuArch::layout::nZ");
    if constexpr (std::is_same_v<LayoutTag, NpuArch::layout::RowMajor>) {
        return MakeLayout(MakeShape(tag.shape(0), tag.shape(1)), MakeStride(tag.stride(0), Int<1>{}));
    } else if constexpr (std::is_same_v<LayoutTag, NpuArch::layout::ColumnMajor>) {
        return MakeLayout(MakeShape(tag.shape(0), tag.shape(1)), MakeStride(Int<1>{}, tag.stride(1)));
    } else {
        // zN and nZ share the same four-mode tag interface.
        return MakeLayout(MakeShape(MakeShape(tag.shape(0), tag.shape(1)), MakeShape(tag.shape(2), tag.shape(3))),
                          MakeStride(MakeStride(tag.stride(0), tag.stride(1)),
                                     MakeStride(tag.stride(2), tag.stride(3))));
    }
}
/// Free-function form of Layout::shape<Is...>() for a mutable layout.
template <int... Is, class S, class D>
HOST_DEVICE constexpr
decltype(auto) shape(Layout<S, D>& self)
{
    return self.template shape<Is...>();
}

/// Free-function form of Layout::shape<Is...>() for a const layout.
template <int... Is, class S, class D>
HOST_DEVICE constexpr
decltype(auto) shape(Layout<S, D> const& self)
{
    return self.template shape<Is...>();
}
/// Free-function form of Layout::stride<Is...>() for a mutable layout.
template <int... Is, class S, class D>
HOST_DEVICE constexpr
decltype(auto) stride(Layout<S, D>& self)
{
    return self.template stride<Is...>();
}

/// Free-function form of Layout::stride<Is...>() for a const layout.
template <int... Is, class S, class D>
HOST_DEVICE constexpr
decltype(auto) stride(Layout<S, D> const& self)
{
    return self.template stride<Is...>();
}
/// Rank of the layout's shape, or of the sub-shape selected by Is...
/// (the whole shape is used when Is is empty).
template <int... Is, class S, class D>
HOST_DEVICE constexpr
auto rank(Layout<S, D> const& self)
{
    return rank(shape<Is...>(self));
}
/// Depth of the layout's shape, or of the sub-shape selected by Is...
/// (the whole shape is used when Is is empty).
template <int... Is, class S, class D>
HOST_DEVICE constexpr
auto depth(Layout<S, D> const& self)
{
    return depth(shape<Is...>(self));
}
// Forward declaration: the detail:: helpers that follow call crd2offset
// recursively before its definition appears further down in this file.
template <class Coord, class Shape, class Stride>
HOST_DEVICE constexpr
auto crd2offset(Coord const& coord, Shape const& shape, Stride const& stride);
namespace detail {
/// Tuple coordinate / tuple shape / tuple stride case: sums, over every mode
/// index in the sequence, the offset contributed by that mode.
template <class CTuple, class STuple, class DTuple, int... Modes>
HOST_DEVICE constexpr
auto crd2offset_ttt(CTuple const& crd, STuple const& extents, DTuple const& steps, seq<Modes...>)
{
    // Left fold over '+', one crd2offset call per mode.
    return (... + crd2offset(get<Modes>(crd), get<Modes>(extents), get<Modes>(steps)));
}
/// Non-tuple coordinate / tuple shape / tuple stride case: peels the
/// coordinate apart mode by mode, in the order given by the index sequence.
template <class CInt, class STuple, class DTuple, int Head, int... Tail>
HOST_DEVICE constexpr
auto crd2offset_itt(CInt const& crd, STuple const& extents, DTuple const& steps, seq<Head, Tail...>)
{
    if constexpr (sizeof...(Tail) == 0) {
        // Last mode: hand over whatever is left of the coordinate, no modulo.
        return crd2offset(crd, get<Head>(extents), get<Head>(steps));
    } else if constexpr (is_constant<0, CInt>::value) {
        // Coordinate is the compile-time constant 0: every mode gets _0.
        return crd2offset(_0{}, get<Head>(extents), get<Head>(steps)) +
               (_0{} + ... + crd2offset(_0{}, get<Tail>(extents), get<Tail>(steps)));
    } else {
        // General case: this mode takes crd modulo its size, the remaining
        // modes recurse on the quotient.
        auto const headSize = Product{}(get<Head>(extents));
        return crd2offset(crd % headSize, get<Head>(extents), get<Head>(steps)) +
               crd2offset_itt(crd / headSize, extents, steps, seq<Tail...>{});
    }
}
}
/// Maps a coordinate to a linear offset for the given shape and stride.
///  - tuple coord, tuple shape:   per-mode sum (ranks of all three must match).
///  - tuple coord, scalar shape:  rejected at compile time.
///  - scalar coord, tuple shape:  the coordinate is split across the modes.
///  - scalar coord, scalar shape: coord * stride.
template <class Coord, class Shape, class Stride>
HOST_DEVICE constexpr
auto crd2offset(Coord const& coord, Shape const& shape, Stride const& stride)
{
    constexpr bool coordIsTuple = is_tuple<Coord>::value;
    constexpr bool shapeIsTuple = is_tuple<Shape>::value;

    if constexpr (coordIsTuple && shapeIsTuple) {
        static_assert(tuple_size<Coord>::value == tuple_size<Shape>::value, "Mismatched Ranks");
        static_assert(tuple_size<Coord>::value == tuple_size<Stride>::value, "Mismatched Ranks");
        return detail::crd2offset_ttt(coord, shape, stride, tuple_seq<Coord>{});
    } else if constexpr (coordIsTuple) {
        // Dependent always-false condition: fires only if this branch is instantiated.
        static_assert(sizeof(Coord) == 0, "Invalid parameters");
    } else if constexpr (shapeIsTuple) {
        static_assert(tuple_size<Shape>::value == tuple_size<Stride>::value, "Mismatched Ranks");
        return detail::crd2offset_itt(coord, shape, stride, tuple_seq<Shape>{});
    } else {
        return coord * stride;
    }
}
/// Type trait: derives from true_type for Layout<Shape, Stride> and from
/// false_type for every other type.
template <class T>
struct is_layout : false_type {};

template <class ShapeT, class StrideT>
struct is_layout<Layout<ShapeT, StrideT>> : true_type {};
namespace detail {
/// Primary template: types that are not depth-1, rank-1 layouts are not vectors.
/// `value` is constexpr (implicitly inline in C++17) so it can be ODR-used
/// without an out-of-class definition.
template <class Layout, class Enable = void>
struct isVector {
    static constexpr bool value = false;
};

/// Depth-1, rank-1 layouts: true when stride<0> of a default-constructed
/// Layout compares equal to 1.
/// NOTE(review): the check runs on `Layout{}`, so a runtime (non-static) stride
/// presumably value-initializes to 0 and fails it — confirm against tla::tuple.
template <class Layout>
struct isVector<Layout, std::enable_if_t<Layout::depth == 1 && Layout::rank == 1>> {
    static constexpr bool value = (stride<0>(Layout{}) == 1);
};
/// Primary template: types that are not depth-1, rank-2 layouts are not row-major.
/// `value` is constexpr (implicitly inline in C++17) so it can be ODR-used
/// without an out-of-class definition.
template <class Layout, class Enable = void>
struct isRowMajor {
    static constexpr bool value = false;
};

/// Depth-1, rank-2 layouts: true when stride<1> of a default-constructed
/// Layout compares equal to 1.
template <class Layout>
struct isRowMajor<Layout, std::enable_if_t<Layout::depth == 1 && Layout::rank == 2>> {
    static constexpr bool value = (stride<1>(Layout{}) == 1);
};
/// Primary template: types that are not depth-1, rank-2 layouts are not column-major.
/// `value` is constexpr (implicitly inline in C++17) so it can be ODR-used
/// without an out-of-class definition.
template <class Layout, class Enable = void>
struct isColumnMajor {
    static constexpr bool value = false;
};

/// Depth-1, rank-2 layouts: true when stride<0> of a default-constructed
/// Layout compares equal to 1.
template <class Layout>
struct isColumnMajor<Layout, std::enable_if_t<Layout::depth == 1 && Layout::rank == 2>> {
    static constexpr bool value = (stride<0>(Layout{}) == 1);
};
/// Primary template: layouts that do not have the nested 2x2 form are not zN.
/// `value` is constexpr (implicitly inline in C++17) so it can be ODR-used
/// without an out-of-class definition.
template <class Element, class Layout, class Enable1 = void, class Enable2 = void>
struct iszN {
    static constexpr bool value = false;
};

/// Selected for depth-2, rank-2 layouts whose two top-level shape modes both
/// have rank 2. On a default-constructed Layout, `value` requires:
///   shape<0, 0>  == NpuArch::C0_NUM_PER_FRACTAL
///   shape<1, 0>  == NpuArch::BYTE_PER_C0 / sizeof(Element)
///   stride<1, 0> == 1
///   stride<0, 1> == NpuArch::BYTE_PER_FRACTAL / sizeof(Element)
template <class Element, class Layout>
struct iszN<Element, Layout,
            std::enable_if_t<Layout::depth == 2 && Layout::rank == 2>,
            std::enable_if_t<rank_v<decltype(shape<0>(Layout{}))> == 2 &&
                             rank_v<decltype(shape<1>(Layout{}))> == 2>> {
    static constexpr uint32_t ELE_NUM_PER_C0 = NpuArch::BYTE_PER_C0 / sizeof(Element);
    static constexpr uint32_t ELE_NUM_PER_FRACTAL = NpuArch::BYTE_PER_FRACTAL / sizeof(Element);
    static constexpr bool value = (shape<0, 0>(Layout{}) == NpuArch::C0_NUM_PER_FRACTAL &&
                                   shape<1, 0>(Layout{}) == ELE_NUM_PER_C0 &&
                                   stride<1, 0>(Layout{}) == 1 &&
                                   stride<0, 1>(Layout{}) == ELE_NUM_PER_FRACTAL);
};
/// Primary template: layouts that do not have the nested 2x2 form are not zZ.
/// `value` is constexpr (implicitly inline in C++17) so it can be ODR-used
/// without an out-of-class definition.
template <class Element, class Layout, class Enable1 = void, class Enable2 = void>
struct iszZ {
    static constexpr bool value = false;
};

/// Selected for depth-2, rank-2 layouts whose two top-level shape modes both
/// have rank 2. On a default-constructed Layout, `value` requires:
///   shape<0, 0>  == NpuArch::C0_NUM_PER_FRACTAL
///   shape<1, 0>  == NpuArch::BYTE_PER_C0 / sizeof(Element)
///   stride<1, 0> == 1
///   stride<1, 1> == NpuArch::BYTE_PER_FRACTAL / sizeof(Element)
template <class Element, class Layout>
struct iszZ<Element, Layout,
            std::enable_if_t<Layout::depth == 2 && Layout::rank == 2>,
            std::enable_if_t<rank_v<decltype(shape<0>(Layout{}))> == 2 &&
                             rank_v<decltype(shape<1>(Layout{}))> == 2>> {
    static constexpr uint32_t ELE_NUM_PER_C0 = NpuArch::BYTE_PER_C0 / sizeof(Element);
    static constexpr uint32_t ELE_NUM_PER_FRACTAL = NpuArch::BYTE_PER_FRACTAL / sizeof(Element);
    static constexpr bool value = (shape<0, 0>(Layout{}) == NpuArch::C0_NUM_PER_FRACTAL &&
                                   shape<1, 0>(Layout{}) == ELE_NUM_PER_C0 &&
                                   stride<1, 0>(Layout{}) == 1 &&
                                   stride<1, 1>(Layout{}) == ELE_NUM_PER_FRACTAL);
};
/// Primary template: layouts that do not have the nested 2x2 form are not nZ.
/// `value` is constexpr (implicitly inline in C++17) so it can be ODR-used
/// without an out-of-class definition.
template <class Element, class Layout, class Enable1 = void, class Enable2 = void>
struct isnZ {
    static constexpr bool value = false;
};

/// Selected for depth-2, rank-2 layouts whose two top-level shape modes both
/// have rank 2. On a default-constructed Layout, `value` requires:
///   shape<0, 0>  == NpuArch::BYTE_PER_C0 / sizeof(Element)
///   shape<1, 0>  == NpuArch::C0_NUM_PER_FRACTAL
///   stride<0, 0> == 1
///   stride<1, 1> == NpuArch::BYTE_PER_FRACTAL / sizeof(Element)
template <class Element, class Layout>
struct isnZ<Element, Layout,
            std::enable_if_t<Layout::depth == 2 && Layout::rank == 2>,
            std::enable_if_t<rank_v<decltype(shape<0>(Layout{}))> == 2 &&
                             rank_v<decltype(shape<1>(Layout{}))> == 2>> {
    static constexpr uint32_t ELE_NUM_PER_C0 = NpuArch::BYTE_PER_C0 / sizeof(Element);
    static constexpr uint32_t ELE_NUM_PER_FRACTAL = NpuArch::BYTE_PER_FRACTAL / sizeof(Element);
    static constexpr bool value = (shape<0, 0>(Layout{}) == ELE_NUM_PER_C0 &&
                                   shape<1, 0>(Layout{}) == NpuArch::C0_NUM_PER_FRACTAL &&
                                   stride<0, 0>(Layout{}) == 1 &&
                                   stride<1, 1>(Layout{}) == ELE_NUM_PER_FRACTAL);
};
#if defined(CATLASS_ARCH_A5_ENABLED)
/// Primary template: false for every Element/Layout pair not matched below.
/// `value` is constexpr (implicitly inline in C++17) so it can be ODR-used
/// without an out-of-class definition.
template <class Element, class Layout, class Enable1 = void, class Enable2 = void>
struct isMxScaleANoTrans {
    static constexpr bool value = false;
};

/// Selected for AscendC::fp8_e8m0_t with a depth-2, rank-2 layout whose shape
/// mode 0 has rank 1 and mode 1 has rank 2. On a default-constructed Layout,
/// `value` requires shape<1, 0> == 2, stride<1, 0> == 1 and stride<1, 1> == 2.
template <class Layout>
struct isMxScaleANoTrans<AscendC::fp8_e8m0_t, Layout,
                         std::enable_if_t<Layout::depth == 2 && Layout::rank == 2>,
                         std::enable_if_t<rank_v<decltype(shape<0>(Layout{}))> == 1 &&
                                          rank_v<decltype(shape<1>(Layout{}))> == 2>> {
    static constexpr uint32_t ELE_NUM_PER_C0 = 2;
    static constexpr bool value =
        (shape<1, 0>(Layout{}) == ELE_NUM_PER_C0 && stride<1, 0>(Layout{}) == 1 &&
         stride<1, 1>(Layout{}) == ELE_NUM_PER_C0);
};
/// Primary template: false for every Element/Layout pair not matched below.
/// `value` is constexpr (implicitly inline in C++17) so it can be ODR-used
/// without an out-of-class definition.
template <class Element, class Layout, class Enable1 = void, class Enable2 = void>
struct isMxScaleATrans {
    static constexpr bool value = false;
};

/// Selected for AscendC::fp8_e8m0_t with a depth-2, rank-2 layout whose shape
/// mode 0 has rank 1 and mode 1 has rank 2. On a default-constructed Layout,
/// `value` requires shape<1, 0> == 2, stride<1, 0> == 1 and stride<0> == 2.
template <class Layout>
struct isMxScaleATrans<AscendC::fp8_e8m0_t, Layout,
                       std::enable_if_t<Layout::depth == 2 && Layout::rank == 2>,
                       std::enable_if_t<rank_v<decltype(shape<0>(Layout{}))> == 1 &&
                                        rank_v<decltype(shape<1>(Layout{}))> == 2>> {
    static constexpr uint32_t ELE_NUM_PER_C0 = 2;
    static constexpr bool value =
        (shape<1, 0>(Layout{}) == ELE_NUM_PER_C0 && stride<1, 0>(Layout{}) == 1 &&
         stride<0>(Layout{}) == ELE_NUM_PER_C0);
};
/// Primary template: false for every Element/Layout pair not matched below.
/// `value` is constexpr (implicitly inline in C++17) so it can be ODR-used
/// without an out-of-class definition.
template <class Element, class Layout, class Enable1 = void, class Enable2 = void>
struct isMxScaleBNoTrans {
    static constexpr bool value = false;
};

/// Selected for AscendC::fp8_e8m0_t with a depth-2, rank-2 layout whose shape
/// mode 0 has rank 2 and mode 1 has rank 1. On a default-constructed Layout,
/// `value` requires shape<0, 0> == 2, stride<0, 0> == 1 and stride<1> == 2.
template <class Layout>
struct isMxScaleBNoTrans<AscendC::fp8_e8m0_t, Layout,
                         std::enable_if_t<Layout::depth == 2 && Layout::rank == 2>,
                         std::enable_if_t<rank_v<decltype(shape<0>(Layout{}))> == 2 &&
                                          rank_v<decltype(shape<1>(Layout{}))> == 1>> {
    static constexpr uint32_t ELE_NUM_PER_C0 = 2;
    static constexpr bool value =
        (shape<0, 0>(Layout{}) == ELE_NUM_PER_C0 && stride<0, 0>(Layout{}) == 1 &&
         stride<1>(Layout{}) == ELE_NUM_PER_C0);
};
/// Primary template: false for every Element/Layout pair not matched below.
/// `value` is constexpr (implicitly inline in C++17) so it can be ODR-used
/// without an out-of-class definition.
template <class Element, class Layout, class Enable1 = void, class Enable2 = void>
struct isMxScaleBTrans {
    static constexpr bool value = false;
};

/// Selected for AscendC::fp8_e8m0_t with a depth-2, rank-2 layout whose shape
/// mode 0 has rank 2 and mode 1 has rank 1. On a default-constructed Layout,
/// `value` requires shape<0, 0> == 2, stride<0, 0> == 1 and stride<0, 1> == 2.
template <class Layout>
struct isMxScaleBTrans<AscendC::fp8_e8m0_t, Layout,
                       std::enable_if_t<Layout::depth == 2 && Layout::rank == 2>,
                       std::enable_if_t<rank_v<decltype(shape<0>(Layout{}))> == 2 &&
                                        rank_v<decltype(shape<1>(Layout{}))> == 1>> {
    static constexpr uint32_t ELE_NUM_PER_C0 = 2;
    static constexpr bool value =
        (shape<0, 0>(Layout{}) == ELE_NUM_PER_C0 && stride<0, 0>(Layout{}) == 1 &&
         stride<0, 1>(Layout{}) == ELE_NUM_PER_C0);
};
/// Primary template: false for every Element/Layout pair not matched below.
/// `value` is constexpr (implicitly inline in C++17) so it can be ODR-used
/// without an out-of-class definition.
template <class Element, class Layout, class Enable1 = void, class Enable2 = void>
struct isMxScalezZ {
    static constexpr bool value = false;
};

/// Selected for AscendC::fp8_e8m0_t with a depth-2, rank-2 layout whose two
/// top-level shape modes both have rank 2. On a default-constructed Layout,
/// `value` requires:
///   shape<0, 0>  == NpuArch::C0_NUM_PER_FRACTAL
///   shape<1, 0>  == 2
///   stride<1, 0> == 1
///   stride<1, 1> == 32
template <class Layout>
struct isMxScalezZ<AscendC::fp8_e8m0_t, Layout,
                   std::enable_if_t<Layout::depth == 2 && Layout::rank == 2>,
                   std::enable_if_t<rank_v<decltype(shape<0>(Layout{}))> == 2 &&
                                    rank_v<decltype(shape<1>(Layout{}))> == 2>> {
    static constexpr uint32_t ELE_NUM_PER_C0 = 2;
    static constexpr uint32_t ELE_NUM_PER_FRACTAL = 32;
    static constexpr bool value = (shape<0, 0>(Layout{}) == NpuArch::C0_NUM_PER_FRACTAL &&
                                   shape<1, 0>(Layout{}) == ELE_NUM_PER_C0 &&
                                   stride<1, 0>(Layout{}) == 1 &&
                                   stride<1, 1>(Layout{}) == ELE_NUM_PER_FRACTAL);
};
/// Primary template: false for every Element/Layout pair not matched below.
/// `value` is constexpr (implicitly inline in C++17) so it can be ODR-used
/// without an out-of-class definition.
template <class Element, class Layout, class Enable1 = void, class Enable2 = void>
struct isMxScalenN {
    static constexpr bool value = false;
};

/// Selected for AscendC::fp8_e8m0_t with a depth-2, rank-2 layout whose two
/// top-level shape modes both have rank 2. On a default-constructed Layout,
/// `value` requires:
///   shape<0, 0>  == 2
///   shape<1, 0>  == NpuArch::C0_NUM_PER_FRACTAL
///   stride<0, 0> == 1
///   stride<0, 1> == 32
template <class Layout>
struct isMxScalenN<AscendC::fp8_e8m0_t, Layout,
                   std::enable_if_t<Layout::depth == 2 && Layout::rank == 2>,
                   std::enable_if_t<rank_v<decltype(shape<0>(Layout{}))> == 2 &&
                                    rank_v<decltype(shape<1>(Layout{}))> == 2>> {
    static constexpr uint32_t ELE_NUM_PER_C0 = 2;
    static constexpr uint32_t ELE_NUM_PER_FRACTAL = 32;
    static constexpr bool value = (shape<0, 0>(Layout{}) == ELE_NUM_PER_C0 &&
                                   shape<1, 0>(Layout{}) == NpuArch::C0_NUM_PER_FRACTAL &&
                                   stride<0, 0>(Layout{}) == 1 &&
                                   stride<0, 1>(Layout{}) == ELE_NUM_PER_FRACTAL);
};
#endif
}
/// Builds a rank-1 layout of the given length with a compile-time stride of 1.
template <class T>
HOST_DEVICE constexpr
auto MakeLayout(T const& len)
{
    auto const extent = MakeShape(len);
    auto const unitStep = MakeStride(Int<1>{});
    return MakeLayout(extent, unitStep);
}
/// Builds the layout selected by LayoutTag for a rows x cols matrix of Element.
///  - VectorLayout: rank-1 layout of `cols` elements with stride Int<1>; `rows` is not used.
///  - RowMajor:     shape (rows, cols), stride (cols, Int<1>).
///  - ColumnMajor:  shape (rows, cols), stride (Int<1>, rows).
///  - zN / zZ / nZ: nested two-level layouts built from NpuArch::C0_NUM_PER_FRACTAL,
///                  BYTE_PER_C0 / sizeof(Element) and BYTE_PER_FRACTAL / sizeof(Element);
///                  outer extents are rounded up with CeilDiv / RoundUp.
/// Runtime strides are computed from rows/cols converted to int64_t.
template <class Element, class LayoutTag, class T, class U>
HOST_DEVICE constexpr
auto MakeLayout(T const& rows, U const& cols)
{
    // The message names this function and lists every accepted tag; each
    // literal fragment ends with a space because adjacent literals are
    // concatenated verbatim.
    static_assert(std::is_same_v<LayoutTag, NpuArch::layout::RowMajor> ||
                  std::is_same_v<LayoutTag, NpuArch::layout::ColumnMajor> ||
                  std::is_same_v<LayoutTag, NpuArch::layout::VectorLayout> ||
                  std::is_same_v<LayoutTag, NpuArch::layout::zN> ||
                  std::is_same_v<LayoutTag, NpuArch::layout::nZ> ||
                  std::is_same_v<LayoutTag, NpuArch::layout::zZ>,
                  "Unsupported LayoutTag for MakeLayout, only support NpuArch::layout::RowMajor or "
                  "NpuArch::layout::ColumnMajor or NpuArch::layout::VectorLayout or NpuArch::layout::zN or "
                  "NpuArch::layout::nZ or NpuArch::layout::zZ");
    constexpr uint32_t ELE_NUM_PER_C0 = NpuArch::BYTE_PER_C0 / sizeof(Element);
    constexpr uint32_t ELE_NUM_PER_FRACTAL = NpuArch::BYTE_PER_FRACTAL / sizeof(Element);
    if constexpr (std::is_same_v<LayoutTag, NpuArch::layout::VectorLayout>) {
        return MakeLayout(MakeShape(cols), MakeStride(Int<1>{}));
    } else if constexpr (std::is_same_v<LayoutTag, NpuArch::layout::RowMajor>) {
        return MakeLayout(MakeShape(rows, cols), MakeStride(static_cast<int64_t>(cols), Int<1>{}));
    } else if constexpr (std::is_same_v<LayoutTag, NpuArch::layout::ColumnMajor>) {
        return MakeLayout(MakeShape(rows, cols), MakeStride(Int<1>{}, static_cast<int64_t>(rows)));
    } else if constexpr (std::is_same_v<LayoutTag, NpuArch::layout::zN>) {
        return MakeLayout(
            MakeShape(MakeShape(Int<NpuArch::C0_NUM_PER_FRACTAL>{},
                                CeilDiv(rows, Int<NpuArch::C0_NUM_PER_FRACTAL>{})),
                      MakeShape(Int<ELE_NUM_PER_C0>{}, CeilDiv(cols, Int<ELE_NUM_PER_C0>{}))),
            MakeStride(MakeStride(Int<ELE_NUM_PER_C0>{}, Int<ELE_NUM_PER_FRACTAL>{}),
                       MakeStride(Int<1>{},
                                  RoundUp(static_cast<int64_t>(rows), Int<NpuArch::C0_NUM_PER_FRACTAL>{}) *
                                      ELE_NUM_PER_C0)));
    } else if constexpr (std::is_same_v<LayoutTag, NpuArch::layout::zZ>) {
        return MakeLayout(
            MakeShape(MakeShape(Int<NpuArch::C0_NUM_PER_FRACTAL>{},
                                CeilDiv(rows, Int<NpuArch::C0_NUM_PER_FRACTAL>{})),
                      MakeShape(Int<ELE_NUM_PER_C0>{}, CeilDiv(cols, Int<ELE_NUM_PER_C0>{}))),
            MakeStride(MakeStride(Int<ELE_NUM_PER_C0>{},
                                  RoundUp(static_cast<int64_t>(cols), Int<ELE_NUM_PER_C0>{}) *
                                      NpuArch::C0_NUM_PER_FRACTAL),
                       MakeStride(Int<1>{}, Int<ELE_NUM_PER_FRACTAL>{})));
    } else {
        // Remaining accepted tag: nZ.
        return MakeLayout(
            MakeShape(MakeShape(Int<ELE_NUM_PER_C0>{}, CeilDiv(rows, Int<ELE_NUM_PER_C0>{})),
                      MakeShape(Int<NpuArch::C0_NUM_PER_FRACTAL>{},
                                CeilDiv(cols, Int<NpuArch::C0_NUM_PER_FRACTAL>{}))),
            MakeStride(
                MakeStride(Int<1>{},
                           RoundUp(static_cast<int64_t>(cols), Int<NpuArch::C0_NUM_PER_FRACTAL>{}) *
                               ELE_NUM_PER_C0),
                MakeStride(Int<ELE_NUM_PER_C0>{}, Int<ELE_NUM_PER_FRACTAL>{})));
    }
}
/// Builds the layout of an fp8_e8m0_t scale matrix of rows x cols.
/// Element must be AscendC::fp8_e8m0_t and LayoutTag one of RowMajor,
/// ColumnMajor, zZ or nN. For RowMajor/ColumnMajor, isMxScaleB selects which
/// dimension is split into groups of ELE_NUM_PER_C0 (= 2): columns when false,
/// rows when true. zZ and nN ignore isMxScaleB.
template <class Element, class LayoutTag, bool isMxScaleB, class T, class U>
HOST_DEVICE constexpr
auto MakeMxScaleLayout(T const& rows, U const& cols)
{
    static_assert(
        std::is_same_v<Element, AscendC::fp8_e8m0_t> &&
        (std::is_same_v<LayoutTag, NpuArch::layout::RowMajor> ||
         std::is_same_v<LayoutTag, NpuArch::layout::ColumnMajor> ||
         std::is_same_v<LayoutTag, NpuArch::layout::zZ> || std::is_same_v<LayoutTag, NpuArch::layout::nN>),
        "only support RowMajor, ColumnMajor, zZ, nN in fp8_e8m0_t dtype"
    );
    constexpr uint32_t ELE_NUM_PER_C0 = 2;
    constexpr uint32_t ELE_NUM_PER_FRACTAL = 32;
    // Compile-time constants used as static extents/strides below.
    using C0 = Int<ELE_NUM_PER_C0>;
    using FractalDim = Int<NpuArch::C0_NUM_PER_FRACTAL>;
    using Unit = Int<1>;

    if constexpr (std::is_same_v<LayoutTag, NpuArch::layout::RowMajor>) {
        if constexpr (!isMxScaleB) {
            auto const extents = MakeShape(rows, MakeShape(C0{}, CeilDiv(cols, C0{})));
            auto const steps = MakeStride(RoundUp(cols, C0{}), MakeStride(Unit{}, C0{}));
            return MakeLayout(extents, steps);
        } else {
            auto const extents = MakeShape(MakeShape(C0{}, CeilDiv(rows, C0{})), cols);
            auto const steps = MakeStride(MakeStride(Unit{}, cols * ELE_NUM_PER_C0), C0{});
            return MakeLayout(extents, steps);
        }
    } else if constexpr (std::is_same_v<LayoutTag, NpuArch::layout::ColumnMajor>) {
        if constexpr (!isMxScaleB) {
            auto const extents = MakeShape(rows, MakeShape(C0{}, CeilDiv(cols, C0{})));
            auto const steps = MakeStride(C0{}, MakeStride(Unit{}, rows * ELE_NUM_PER_C0));
            return MakeLayout(extents, steps);
        } else {
            auto const extents = MakeShape(MakeShape(C0{}, CeilDiv(rows, C0{})), cols);
            auto const steps = MakeStride(MakeStride(Unit{}, C0{}), RoundUp(rows, C0{}));
            return MakeLayout(extents, steps);
        }
    } else if constexpr (std::is_same_v<LayoutTag, NpuArch::layout::zZ>) {
        auto const rowExtents = MakeShape(FractalDim{}, CeilDiv(rows, FractalDim{}));
        auto const colExtents = MakeShape(C0{}, CeilDiv(cols, C0{}));
        auto const rowSteps =
            MakeStride(C0{}, RoundUp(static_cast<int64_t>(cols), C0{}) * NpuArch::C0_NUM_PER_FRACTAL);
        auto const colSteps = MakeStride(Unit{}, Int<ELE_NUM_PER_FRACTAL>{});
        return MakeLayout(MakeShape(rowExtents, colExtents), MakeStride(rowSteps, colSteps));
    } else {
        // Remaining accepted tag: nN.
        auto const rowExtents = MakeShape(C0{}, CeilDiv(rows, C0{}));
        auto const colExtents = MakeShape(FractalDim{}, CeilDiv(cols, FractalDim{}));
        auto const rowSteps = MakeStride(Unit{}, Int<ELE_NUM_PER_FRACTAL>{});
        auto const colSteps =
            MakeStride(C0{}, RoundUp(static_cast<int64_t>(rows), C0{}) * NpuArch::C0_NUM_PER_FRACTAL);
        return MakeLayout(MakeShape(rowExtents, colExtents), MakeStride(rowSteps, colSteps));
    }
}
/// Builds a layout for a tile of `layout`: the result keeps layout.stride()
/// and takes its extents from shapeNew (a depth-1 tuple of rank 1 or 2).
///  - Flat layouts (depth 1, rank 1 or 2) use shapeNew directly as the shape.
///  - Nested layouts rebuild the nested shape: the inner extent of each split
///    mode is taken from `layout`, the outer extent is CeilDiv(tile extent, inner).
/// In the nested cases the tile extents are read into uint32_t.
template <class Layout, class ShapeNew>
HOST_DEVICE constexpr
auto MakeLayoutTile(Layout const& layout, ShapeNew const& shapeNew)
{
    static_assert(
        is_tuple<ShapeNew>::value && depth_v<ShapeNew> == 1 && (rank_v<ShapeNew> == 1 || rank_v<ShapeNew> == 2)
    );
    if constexpr (Layout::depth == 1 && (Layout::rank == 1 || Layout::rank == 2)) {
        return MakeLayout(shapeNew, layout.stride());
    } else {
        // Every nested case needs both tile extents.
        const uint32_t tileRows = get<0>(shapeNew);
        const uint32_t tileCols = get<1>(shapeNew);

        if constexpr (Layout::depth == 2 && Layout::rank == 2 && rank_v<decltype(shape<0>(Layout{}))> == 1 &&
                      rank_v<decltype(shape<1>(Layout{}))> == 2) {
            // Only the column mode is split; its inner extent is a static constant.
            constexpr uint32_t innerCols = decltype(shape<1, 0>(layout))::value;
            return MakeLayout(
                MakeShape(tileRows, MakeShape(Int<innerCols>{}, CeilDiv(tileCols, Int<innerCols>{}))),
                layout.stride());
        } else if constexpr (Layout::depth == 2 && Layout::rank == 2 &&
                             rank_v<decltype(shape<0>(Layout{}))> == 2 &&
                             rank_v<decltype(shape<1>(Layout{}))> == 1) {
            // Only the row mode is split; its inner extent is a static constant.
            constexpr uint32_t innerRows = decltype(shape<0, 0>(layout))::value;
            return MakeLayout(
                MakeShape(MakeShape(Int<innerRows>{}, CeilDiv(tileRows, Int<innerRows>{})), tileCols),
                layout.stride());
        } else if constexpr (is_static<decltype(shape<0, 0>(layout))>::value &&
                             is_static<decltype(shape<1, 0>(layout))>::value) {
            // Both modes split with static inner extents: keep them as Int<> constants.
            constexpr uint32_t innerRows = decltype(shape<0, 0>(layout))::value;
            constexpr uint32_t innerCols = decltype(shape<1, 0>(layout))::value;
            return MakeLayout(
                MakeShape(MakeShape(Int<innerRows>{}, CeilDiv<innerRows>(tileRows)),
                          MakeShape(Int<innerCols>{}, CeilDiv<innerCols>(tileCols))),
                layout.stride());
        } else {
            // Both modes split with inner extents read from the layout at run time.
            const uint32_t innerRows = shape<0, 0>(layout);
            const uint32_t innerCols = shape<1, 0>(layout);
            return MakeLayout(
                MakeShape(MakeShape(innerRows, CeilDiv(tileRows, innerRows)),
                          MakeShape(innerCols, CeilDiv(tileCols, innerCols))),
                layout.stride());
        }
    }
}
/// Builds the nested layout used for an L0C matrix of rows x cols.
/// Both modes are split into blocks of NpuArch::C0_NUM_PER_FRACTAL; the outer
/// extents are rounded up with CeilDiv. A fractal is taken to hold 256 elements.
template <class T, class U>
HOST_DEVICE constexpr
auto MakeLayoutL0C(T const& rows, U const& cols)
{
    constexpr uint32_t ELE_NUM_PER_FRACTAL = 256;
    using FractalDim = Int<NpuArch::C0_NUM_PER_FRACTAL>;

    auto const rowExtents = MakeShape(FractalDim{}, CeilDiv(rows, FractalDim{}));
    auto const colExtents = MakeShape(FractalDim{}, CeilDiv(cols, FractalDim{}));
    auto const rowSteps = MakeStride(FractalDim{}, Int<ELE_NUM_PER_FRACTAL>{});
    // The outer column stride spans the row count rounded up to whole fractal
    // blocks, times the block width.
    auto const colSteps = MakeStride(
        Int<1>{}, RoundUp(static_cast<int64_t>(rows), FractalDim{}) * NpuArch::C0_NUM_PER_FRACTAL);

    return MakeLayout(MakeShape(rowExtents, colExtents), MakeStride(rowSteps, colSteps));
}
}
# endif