/**
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
/**
 * @brief Tile shape scaling utility — scales the K tile dimension proportionally
 * to the element byte width relative to a baseline dtype.
 *
 * Usage:
 *   using L1 = TileShapeScaler<half, half, GemmShape<128, 256, 256>>::type;
 *   // actual == baseline (half) → KScale=1, L1 = GemmShape<128, 256, 256>
 *
 *   using L1 = TileShapeScaler<float, half, GemmShape<128, 256, 256>>::type;
 *   // float vs half → KScale=2, L1 = GemmShape<128, 256, 128>
 *
 *   using L1 = TileShapeScaler<ElementA, half, GemmShape<128, 256, 256>>::type;
 *   // ElementA is the actual dtype derived from the JIT macros
 *
 * Note: the template argument order is <ActualDtype, BaselineDtype, BaseShape>;
 * an actual dtype narrower than the baseline is rejected at compile time.
 */
#ifndef OPTEST_KERNELS_COMMON_TILE_SHAPE_SCALER_H
#define OPTEST_KERNELS_COMMON_TILE_SHAPE_SCALER_H
#include <cstdint>
#include <type_traits>
namespace CatlassKernel {
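/**
 * @brief Compute the K-dimension scale factor: sizeof(Actual) / sizeof(Baseline).
 *
 * KScale > 1 means the actual element is wider, so the K tile is shrunk
 * proportionally to stay within the UB budget.
 * Requires sizeof(Actual) to be divisible by sizeof(Baseline), and K to be
 * divisible by KScale.
 */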
template <typename ActualDtype, typename BaselineDtype>
struct KScaleFactor {
    // Only whole-number width ratios are accepted; this also rejects an actual
    // dtype that is narrower than the baseline (the remainder is then non-zero).
    static_assert((sizeof(ActualDtype) % sizeof(BaselineDtype)) == 0,
                  "sizeof(Actual) must be a multiple of sizeof(Baseline)");

    // Ratio of the actual element width to the baseline element width, in bytes.
    static constexpr uint32_t value =
        static_cast<uint32_t>(sizeof(ActualDtype) / sizeof(BaselineDtype));
};
/**
 * @brief Divide the K of GemmShape<M, N, K> by KScale; M and N are unchanged.
 *
 * Shape types other than GemmShape are passed through unmodified.
 */
template <typename Shape, uint32_t KScale>
struct ScaleKTile {
    // Generic fallback: a shape type without a dedicated specialization is
    // forwarded as-is, regardless of the divisor.
    using type = Shape;

    // KScale is unsigned, so "positive" is the same as "non-zero".
    static_assert(KScale != 0U, "KScale must be positive");
};
// Partial specialization for Catlass::GemmShape<M, N, K>: yields the same shape
// with K divided by KScale; M and N are carried over untouched.
template <uint32_t M, uint32_t N, uint32_t K, uint32_t KScale>
struct ScaleKTile<Catlass::GemmShape<M, N, K>, KScale> {
    // The primary template is not instantiated when this specialization is
    // selected, so its zero-divisor check has to be repeated here.
    static_assert(KScale > 0, "KScale must be positive");
    // The `KScale == 0 ||` guard keeps this condition a valid constant
    // expression for a zero divisor, so the assert above reports that case.
    static_assert(KScale == 0 || K % KScale == 0, "K must be divisible by KScale");
    using type = Catlass::GemmShape<M, N, K / KScale>;
};
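/**
 * @brief Scale a tile shape in proportion to the element width.
 *
 * @tparam ActualDtype   Actual data type (e.g. ElementA derived from the JIT macros)
 * @tparam BaselineDtype Baseline type (usually half, i.e. the dtype the shape was designed for)
 * @tparam BaseShape     Baseline tile shape (e.g. GemmShape<128, 256, 256>)
 *
 * Result: if ActualDtype is N times wider than BaselineDtype, the K tile shrinks to 1/N.
 */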
template <typename ActualDtype, typename BaselineDtype, typename BaseShape>
struct TileShapeScaler {
static constexpr uint32_t kScale = KScaleFactor<ActualDtype, BaselineDtype>::value;
using type = typename ScaleKTile<BaseShape, kScale>::type;
};
}
#endif