* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file layout_utils.h
* \brief
*/
#ifndef UTILS_LAYOUT_UTILS_H
#define UTILS_LAYOUT_UTILS_H
#include "./integral_constant.h"
namespace Cgmct {
namespace Gemm {
namespace layout {
struct RowMajor {};
struct RowMajorAlign {};
struct ColumnMajor {};
struct ColumnMajorAlign {};
struct Nz {};
struct Zn {};
}
constexpr int32_t C0_BYTE_SIZE = 32;
constexpr int32_t C0_SIZE_FP16 = C0_BYTE_SIZE / sizeof(half);
constexpr int32_t C0_SIZE_BF16 = C0_BYTE_SIZE / sizeof(bfloat16_t);
constexpr int32_t C0_SIZE_FP32 = C0_BYTE_SIZE / sizeof(float);
constexpr int32_t C0_NUM_PER_FRACTAL = 16;
constexpr int32_t C0_SIZE_L0C = 16;
template <typename T>
struct TagToFormat {
static_assert(AscendC::Std::always_false_v<T>, "TagToFormat is not implemented for this layout");
};
template <>
struct TagToFormat<layout::RowMajor> {
using tag = layout::RowMajor;
static constexpr CubeFormat format = CubeFormat::ND;
};
template <>
struct TagToFormat<layout::RowMajorAlign> {
using tag = layout::RowMajorAlign;
static constexpr CubeFormat format = CubeFormat::ND_ALIGN;
};
template <>
struct TagToFormat<layout::ColumnMajor> {
using tag = layout::ColumnMajor;
static constexpr CubeFormat format = CubeFormat::ND;
};
template <>
struct TagToFormat<layout::ColumnMajorAlign> {
using tag = layout::ColumnMajorAlign;
static constexpr CubeFormat format = CubeFormat::ND_ALIGN;
};
template <>
struct TagToFormat<layout::Zn> {
using tag = layout::Zn;
static constexpr CubeFormat format = CubeFormat::NZ;
};
template <>
struct TagToFormat<layout::Nz> {
using tag = layout::Nz;
static constexpr CubeFormat format = CubeFormat::NZ;
};
template <typename T>
struct TagToTrans {
static_assert(AscendC::Std::always_false_v<T>, "TagToTrans is not implemented for this layout");
};
template <>
struct TagToTrans<layout::RowMajor> {
static constexpr bool value = false;
};
template <>
struct TagToTrans<layout::RowMajorAlign> {
static constexpr bool value = false;
};
template <>
struct TagToTrans<layout::ColumnMajor> {
static constexpr bool value = true;
};
template <>
struct TagToTrans<layout::ColumnMajorAlign> {
static constexpr bool value = false;
};
template <>
struct TagToTrans<layout::Zn> {
static constexpr bool value = true;
};
template <>
struct TagToTrans<layout::Nz> {
static constexpr bool value = false;
};
template <typename T, typename Layout>
struct LayoutToFormat {
static_assert(AscendC::Std::always_false_v<Layout>, "LayoutToFormat is not implemented for this layout");
};
template <typename T, typename IntType>
struct LayoutToFormat<T, AscendC::Layout<AscendC::Shape<IntType, IntType>, AscendC::Stride<IntType, _1>>> {
static_assert(AscendC::Std::is_integral_v<IntType>, "Layout element type must be integral.");
static constexpr CubeFormat format = CubeFormat::ND;
};
template <typename IntType>
struct LayoutToFormat<
half, AscendC::Layout<AscendC::Shape<AscendC::Shape<Int<C0_NUM_PER_FRACTAL>, IntType>,
AscendC::Shape<Int<C0_SIZE_FP16>, IntType>>,
AscendC::Stride<AscendC::Stride<Int<C0_SIZE_FP16>, Int<C0_NUM_PER_FRACTAL * C0_SIZE_FP16>>,
AscendC::Stride<_1, IntType>>>> {
static_assert(AscendC::Std::is_integral_v<IntType>, "Layout element type must be integral.");
static constexpr CubeFormat format = CubeFormat::NZ;
};
template <typename IntType>
struct LayoutToFormat<
bfloat16_t,
AscendC::Layout<
AscendC::Shape<AscendC::Shape<Int<C0_NUM_PER_FRACTAL>, IntType>, AscendC::Shape<Int<C0_SIZE_BF16>, IntType>>,
AscendC::Stride<AscendC::Stride<Int<C0_SIZE_BF16>, Int<C0_NUM_PER_FRACTAL * C0_SIZE_BF16>>,
AscendC::Stride<_1, IntType>>>> {
static_assert(AscendC::Std::is_integral_v<IntType>, "Layout element type must be integral.");
static constexpr CubeFormat format = CubeFormat::NZ;
};
template <typename IntType>
struct LayoutToFormat<
float, AscendC::Layout<AscendC::Shape<AscendC::Shape<Int<C0_NUM_PER_FRACTAL>, IntType>,
AscendC::Shape<Int<C0_SIZE_FP32>, IntType>>,
AscendC::Stride<AscendC::Stride<Int<C0_SIZE_FP32>, Int<C0_NUM_PER_FRACTAL * C0_SIZE_FP32>>,
AscendC::Stride<_1, IntType>>>> {
static_assert(AscendC::Std::is_integral_v<IntType>, "Layout element type must be integral.");
static constexpr CubeFormat format = CubeFormat::NZ;
};
template <typename T, typename Layout>
inline constexpr CubeFormat LayoutToFormatV = LayoutToFormat<T, Layout>::format;
template <typename T, CubeFormat format>
struct FormatToLayout {
using type = CubeFormat;
};
template <typename T>
struct FormatToLayout<T, CubeFormat::ND> {
using type = AscendC::Layout<AscendC::Shape<int, int>, AscendC::Stride<int, _1>>;
};
template <typename T>
struct FormatToLayout<T, CubeFormat::ND_ALIGN> {
using type = AscendC::Layout<AscendC::Shape<int, int>, AscendC::Stride<int, _1>>;
};
template <typename T>
struct FormatToLayout<T, CubeFormat::NZ> {
static constexpr int C0_SIZE = C0_BYTE_SIZE / sizeof(T);
using type = AscendC::Layout<
AscendC::Shape<AscendC::Shape<Int<C0_NUM_PER_FRACTAL>, int>,
AscendC::Shape<Int<C0_SIZE>, int>
>,
AscendC::Stride<AscendC::Stride<Int<C0_SIZE>, Int<C0_NUM_PER_FRACTAL * C0_SIZE>>, AscendC::Stride<_1, int>>>;
};
template <typename T>
struct FormatToLayout<T, CubeFormat::ZN> {
static constexpr int C0_SIZE = C0_BYTE_SIZE / sizeof(T);
using type = AscendC::Layout<
AscendC::Shape<AscendC::Shape<Int<C0_SIZE>, int>, AscendC::Shape<_16, int>>,
AscendC::Stride<AscendC::Stride<_1, int>, AscendC::Stride<Int<C0_SIZE>, Int<C0_NUM_PER_FRACTAL * C0_SIZE>>>>;
};
template <typename T>
struct FormatToLayout<T, CubeFormat::ZZ> {
static constexpr int C0_SIZE = C0_BYTE_SIZE / sizeof(T);
using type = AscendC::Layout<
AscendC::Shape<AscendC::Shape<_16, int>, AscendC::Shape<Int<C0_SIZE>, int>>,
AscendC::Stride<AscendC::Stride<Int<C0_SIZE>, int>, AscendC::Stride<_1, Int<C0_NUM_PER_FRACTAL * C0_SIZE>>>>;
};
template <typename T, CubeFormat format>
using FormatToLayoutT = typename FormatToLayout<T, format>::type;
template <typename T, CubeFormat format>
__aicore__ constexpr inline decltype(auto) MakeLayoutByFormat(int row, int col)
{
static_assert(format == CubeFormat::ND || format == CubeFormat::ND_ALIGN || format == CubeFormat::NZ ||
format == CubeFormat::ZN || format == CubeFormat::ZZ,
"Unsupport format");
constexpr int c0Size = C0_BYTE_SIZE / sizeof(T);
if constexpr (format == CubeFormat::ND || format == CubeFormat::ND_ALIGN) {
return AscendC::MakeLayout(AscendC::MakeShape(row, col), AscendC::MakeStride(col, _1{}));
} else if constexpr (format == CubeFormat::NZ) {
return AscendC::MakeLayout(
AscendC::MakeShape(AscendC::MakeShape(_16{}, AscendC::Ceil(row, C0_NUM_PER_FRACTAL)),
AscendC::MakeShape(Int<c0Size>{}, AscendC::Ceil(col, c0Size))),
AscendC::MakeStride(AscendC::MakeStride(Int<c0Size>{}, Int<C0_NUM_PER_FRACTAL * c0Size>{}),
AscendC::MakeStride(_1{}, AscendC::CeilAlign(row, C0_NUM_PER_FRACTAL) * c0Size)));
} else if constexpr (format == CubeFormat::ZN) {
return AscendC::MakeLayout(
AscendC::MakeShape(AscendC::MakeShape(Int<c0Size>{}, AscendC::Ceil(row, c0Size)),
AscendC::MakeShape(_16{}, AscendC::Ceil(col, C0_NUM_PER_FRACTAL))),
AscendC::MakeStride(AscendC::MakeStride(_1{}, AscendC::CeilAlign(col, C0_NUM_PER_FRACTAL) * c0Size),
AscendC::MakeStride(Int<c0Size>{}, Int<C0_NUM_PER_FRACTAL * c0Size>{})));
} else {
return AscendC::MakeLayout(AscendC::MakeShape(AscendC::MakeShape(_16{}, AscendC::Ceil(row, C0_NUM_PER_FRACTAL)),
AscendC::MakeShape(Int<c0Size>{}, AscendC::Ceil(col, c0Size))),
AscendC::MakeStride(AscendC::MakeStride(Int<c0Size>{}, AscendC::CeilAlign(col, c0Size) *
C0_NUM_PER_FRACTAL),
AscendC::MakeStride(_1{}, Int<C0_NUM_PER_FRACTAL * c0Size>{})));
}
}
}
}
#endif