* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef TILING_PARAMS_H
#define TILING_PARAMS_H
#include <cstdint>
enum class LayoutTag : uint8_t { TagRowMajor = 0, TagColumnMajor = 1};
* Bit field layout description (little-endian):
* -------------------------------------------------------------------------
* | Bit Range | Size | Field Name | Description |
* |-----------|------|-----------------------|----------------------------|
* | 0-3 | 4 | layoutTagC | Layout tag for C matrix |
* | 4-7 | 4 | layoutTagB | Layout tag for B matrix |
* | 8-11 | 4 | layoutTagA | Layout tag for A matrix |
* | 12-15 | 4 | paddingTagC | Padding tag for C matrix |
* | 16-19 | 4 | paddingTagB | Padding tag for B matrix |
* | 20-23 | 4 | paddingTagA | Padding tag for A matrix |
* | 24-51 | 28 | reserveBit | Reserved for future use |
* | 52-55 | 4 | dtype | Data type specification |
* | 56-63 | 8 | templateKernelSerial | Template kernel serial ID |
* -------------------------------------------------------------------------
*/
union TilingKey {
uint64_t value;
struct {
uint64_t layoutTagC : 4;
uint64_t layoutTagB : 4;
uint64_t layoutTagA : 4;
uint64_t paddingTagC : 4;
uint64_t paddingTagB : 4;
uint64_t paddingTagA : 4;
uint64_t reserveBit : 28;
uint64_t dtype : 4;
uint64_t templateKernelSerial : 8;
} bits;
TilingKey() : value(0) {}
TilingKey(uint64_t v) : value(v) {}
uint8_t GetLayoutTagA() const {return bits.layoutTagA;}
uint8_t GetLayoutTagB() const {return bits.layoutTagB;}
uint8_t GetLayoutTagC() const {return bits.layoutTagC;}
uint8_t GetPaddingTagTagA() const {return bits.paddingTagA;}
uint8_t GetPaddingTagTagB() const {return bits.paddingTagB;}
uint8_t GetPaddingTagTagC() const {return bits.paddingTagC;}
uint8_t GetDtype() const {return bits.dtype;}
uint8_t GetKernelSerial() const {return bits.templateKernelSerial;}
void SetKernelSerial(uint8_t kernelSerial) { bits.templateKernelSerial = kernelSerial;}
void SetLayoutTagA(uint8_t layoutTagA) { bits.layoutTagA = layoutTagA & 0xF; }
void SetLayoutTagB(uint8_t layoutTagB) { bits.layoutTagB = layoutTagB & 0xF; }
void SetLayoutTagC(uint8_t layoutTagC) { bits.layoutTagC = layoutTagC & 0xF; }
void SetPaddingTagA(uint8_t paddingTagA) { bits.paddingTagA = paddingTagA & 0xF; }
void SetPaddingTagB(uint8_t paddingTagB) { bits.paddingTagB = paddingTagB & 0xF; }
void SetPaddingTagC(uint8_t paddingTagC) { bits.paddingTagC = paddingTagC & 0xF; }
void SetDtype(uint8_t dtype) { bits.dtype = dtype & 0xF; }
void SetTilingKey(uint8_t kernelSerial, uint8_t layoutTagA, uint8_t layoutTagB, uint8_t layoutTagC,
uint8_t paddingTagA, uint8_t paddingTagB, uint8_t paddingTagC, uint8_t dtype = 0)
{
SetKernelSerial(kernelSerial);
SetLayoutTagA(layoutTagA);
SetLayoutTagB(layoutTagB);
SetLayoutTagC(layoutTagC);
SetPaddingTagA(paddingTagA);
SetPaddingTagB(paddingTagB);
SetPaddingTagC(paddingTagC);
SetDtype(dtype);
}
};
struct TilingParams {
uint64_t strideA{0};
uint64_t strideB{0};
uint64_t strideC{0};
uint32_t m{0};
uint32_t n{0};
uint32_t k{0};
uint16_t m1{0};
uint16_t n1{0};
uint16_t k1{0};
uint8_t swizzleOffset{1};
uint8_t swizzleDirection{0};
uint16_t splitkFactor{1};
uint16_t m0{0};
uint16_t n0{0};
uint16_t k0{0};
uint16_t m1Factor{1};
uint16_t n1Factor{1};
uint8_t layoutTagA;
uint8_t layoutTagB;
uint8_t layoutTagC;
uint8_t paddingTagA;
uint8_t paddingTagB;
uint8_t paddingTagC;
uint8_t blockDim{0};
TilingKey tilingKey;
TilingParams() {}
TilingParams(uint32_t m_, uint32_t n_, uint32_t k_,
LayoutTag layoutTagA_, LayoutTag layoutTagB_, LayoutTag layoutTagC_)
: m(m_), n(n_), k(k_), layoutTagA(static_cast<uint8_t>(layoutTagA_)),
layoutTagB(static_cast<uint8_t>(layoutTagB_)), layoutTagC(static_cast<uint8_t>(layoutTagC_))
{
strideA = k;
strideB = n;
strideC = n;
if (static_cast<LayoutTag>(layoutTagA) == LayoutTag::TagColumnMajor) {
strideA = m;
}
if (static_cast<LayoutTag>(layoutTagB) == LayoutTag::TagColumnMajor) {
strideB = k;
}
if (static_cast<LayoutTag>(layoutTagC) == LayoutTag::TagColumnMajor) {
strideC = m;
}
}
TilingParams(uint32_t m_, uint32_t n_, uint32_t k_, size_t strideA_, size_t strideB_, size_t strideC_,
LayoutTag layoutTagA_, LayoutTag layoutTagB_, LayoutTag layoutTagC_)
: m(m_), n(n_), k(k_), strideA(strideA_), strideB(strideB_), strideC(strideC_),
layoutTagA(static_cast<uint8_t>(layoutTagA_)), layoutTagB(static_cast<uint8_t>(layoutTagB_)),
layoutTagC(static_cast<uint8_t>(layoutTagC_))
{}
void SetParams(uint32_t m_, uint32_t n_, uint32_t k_, size_t strideA_, size_t strideB_, size_t strideC_,
LayoutTag layoutTagA_, LayoutTag layoutTagB_, LayoutTag layoutTagC_)
{
m = m_;
n = n_;
k = k_;
strideA = strideA_;
strideB = strideB_;
strideC = strideC_;
layoutTagA = static_cast<uint8_t>(layoutTagA_);
layoutTagB = static_cast<uint8_t>(layoutTagB_);
layoutTagC = static_cast<uint8_t>(layoutTagC_);
}
void SetParams(uint32_t m_, uint32_t n_, uint32_t k_, LayoutTag layoutTagA_, LayoutTag layoutTagB_, LayoutTag layoutTagC_)
{
m = m_;
n = n_;
k = k_;
strideA = k;
strideB = n;
strideC = n;
if (static_cast<LayoutTag>(layoutTagA) == LayoutTag::TagColumnMajor) {
strideA = m;
}
if (static_cast<LayoutTag>(layoutTagB) == LayoutTag::TagColumnMajor) {
strideB = k;
}
if (static_cast<LayoutTag>(layoutTagC) == LayoutTag::TagColumnMajor) {
strideC = m;
}
layoutTagA = static_cast<uint8_t>(layoutTagA_);
layoutTagB = static_cast<uint8_t>(layoutTagB_);
layoutTagC = static_cast<uint8_t>(layoutTagC_);
}
};
#endif