* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "library_helper.h"

#include <cstring>
#include <type_traits>

#include <opdev/bfloat16.h>
namespace Catlass {
namespace {
using namespace Library;

// Copies `value` byte-for-byte into the raw destination buffer `data`.
// std::memcpy is used instead of `*reinterpret_cast<T*>(data) = value`:
// the latter assumes a live T object already exists at `data` and that the
// buffer is suitably aligned for T, neither of which a plain uint8_t*
// guarantees. Requires `data` to point to at least sizeof(T) writable bytes.
template<typename T>
void StoreLayout(const T &value, uint8_t *data)
{
    static_assert(std::is_trivially_copyable_v<T>,
                  "layout types must be trivially copyable to be stored as raw bytes");
    std::memcpy(data, &value, sizeof(T));
}

// Builds a layout of type T via T::MakeLayout<Element>(a, b), where Element is
// the C++ element type selected by `dataType`, and stores it into `data`.
// For a `dataType` not listed below, `data` is left untouched.
template<typename T>
void ConstructLayoutT(DataType dataType, uint32_t a, uint32_t b, uint8_t *data)
{
    switch (dataType) {
        case DataType::U8:
            StoreLayout<T>(T::template MakeLayout<uint8_t>(a, b), data);
            break;
        case DataType::Int8:
            StoreLayout<T>(T::template MakeLayout<int8_t>(a, b), data);
            break;
        case DataType::Int32:
            StoreLayout<T>(T::template MakeLayout<int32_t>(a, b), data);
            break;
        case DataType::Fp16:
            StoreLayout<T>(T::template MakeLayout<op::fp16_t>(a, b), data);
            break;
        case DataType::Bf16:
            StoreLayout<T>(T::template MakeLayout<op::bfloat16>(a, b), data);
            break;
        case DataType::Fp32:
            StoreLayout<T>(T::template MakeLayout<float>(a, b), data);
            break;
        default:
            break;
    }
}
}
size_t LibraryHelper::GetDataTypeSize(LibraryHelper::DataType dataType)
{
    // Width in bytes of one element of `dataType`; 0 for unrecognized values.
    constexpr size_t ONE_BYTE = 1;
    constexpr size_t TWO_BYTES = 2;
    constexpr size_t FOUR_BYTES = 4;

    size_t elementBytes = 0;
    switch (dataType) {
        case DataType::U8:
        case DataType::Int8:
            elementBytes = ONE_BYTE;
            break;
        case DataType::Fp16:
        case DataType::Bf16:
            elementBytes = TWO_BYTES;
            break;
        case DataType::Int32:
        case DataType::Fp32:
            elementBytes = FOUR_BYTES;
            break;
        default:
            break;
    }
    return elementBytes;
}
size_t LibraryHelper::GetLayoutSize(LibraryHelper::LayoutType layoutType)
{
    // sizeof the layout struct that corresponds to `layoutType`;
    // 0 when the layout type is not recognized.
    size_t layoutBytes = 0;
    switch (layoutType) {
        case LayoutType::RowMajor:
            layoutBytes = sizeof(layout::RowMajor);
            break;
        case LayoutType::ColumnMajor:
            layoutBytes = sizeof(layout::ColumnMajor);
            break;
        case LayoutType::nZ:
            layoutBytes = sizeof(layout::nZ);
            break;
        case LayoutType::zN:
            layoutBytes = sizeof(layout::zN);
            break;
        case LayoutType::zZ:
            layoutBytes = sizeof(layout::zZ);
            break;
        case LayoutType::PaddingRowMajor:
            layoutBytes = sizeof(layout::PaddingRowMajor);
            break;
        case LayoutType::PaddingColumnMajor:
            layoutBytes = sizeof(layout::PaddingColumnMajor);
            break;
        case LayoutType::nN:
            layoutBytes = sizeof(layout::nN);
            break;
        default:
            break;
    }
    return layoutBytes;
}
std::string_view LibraryHelper::GetDataTypeStr(LibraryHelper::DataType dataType)
{
    // Short textual name of `dataType` (the same spellings GetDataTypeEnum
    // accepts); an empty view for unrecognized values.
    std::string_view typeName = "";
    switch (dataType) {
        case DataType::U8:
            typeName = "u8";
            break;
        case DataType::Int8:
            typeName = "int8";
            break;
        case DataType::Int32:
            typeName = "int32";
            break;
        case DataType::Fp16:
            typeName = "fp16";
            break;
        case DataType::Bf16:
            typeName = "bf16";
            break;
        case DataType::Fp32:
            typeName = "fp32";
            break;
        default:
            break;
    }
    return typeName;
}
std::string_view LibraryHelper::GetLayoutStr(LibraryHelper::LayoutType layoutType)
{
    // Short textual name of `layoutType` (the same spellings GetLayoutEnum
    // accepts); an empty view for unrecognized values.
    std::string_view layoutName = "";
    switch (layoutType) {
        case LayoutType::RowMajor:
            layoutName = "row";
            break;
        case LayoutType::ColumnMajor:
            layoutName = "column";
            break;
        case LayoutType::nZ:
            layoutName = "nZ";
            break;
        case LayoutType::zN:
            layoutName = "zN";
            break;
        case LayoutType::zZ:
            layoutName = "zZ";
            break;
        case LayoutType::PaddingRowMajor:
            layoutName = "padding_row_major";
            break;
        case LayoutType::PaddingColumnMajor:
            layoutName = "padding_column_major";
            break;
        case LayoutType::nN:
            layoutName = "nN";
            break;
        default:
            break;
    }
    return layoutName;
}
LibraryHelper::DataType LibraryHelper::GetDataTypeEnum(std::string_view str)
{
    // Parses a data-type name (as produced by GetDataTypeStr) back into its
    // enum value; returns DataType::Invalid for unknown names.
    //
    // The table is built once on first use and is `const` so it can never be
    // modified afterwards. Its keys view string literals, which have static
    // storage duration, so the string_view keys cannot dangle.
    static const std::unordered_map<std::string_view, DataType> STR_TO_DTYPE = {
        {"u8", DataType::U8},
        {"int8", DataType::Int8},
        {"int32", DataType::Int32},
        {"fp16", DataType::Fp16},
        {"bf16", DataType::Bf16},
        {"fp32", DataType::Fp32},
    };
    const auto it = STR_TO_DTYPE.find(str);
    if (it == STR_TO_DTYPE.end()) {
        return DataType::Invalid;
    }
    return it->second;
}
LibraryHelper::LayoutType LibraryHelper::GetLayoutEnum(std::string_view str)
{
    // Parses a layout name (as produced by GetLayoutStr) back into its enum
    // value; returns LayoutType::Invalid for unknown names.
    //
    // The table is built once on first use and is `const` so it can never be
    // modified afterwards. Its keys view string literals, which have static
    // storage duration, so the string_view keys cannot dangle.
    static const std::unordered_map<std::string_view, LayoutType> STR_TO_LAYOUT = {
        {"row", LayoutType::RowMajor},
        {"column", LayoutType::ColumnMajor},
        {"nZ", LayoutType::nZ},
        {"zN", LayoutType::zN},
        {"zZ", LayoutType::zZ},
        {"padding_row_major", LayoutType::PaddingRowMajor},
        {"padding_column_major", LayoutType::PaddingColumnMajor},
        {"nN", LayoutType::nN},
    };
    const auto it = STR_TO_LAYOUT.find(str);
    if (it == STR_TO_LAYOUT.end()) {
        return LayoutType::Invalid;
    }
    return it->second;
}
void LibraryHelper::ConstructLayout(LayoutType layoutType, DataType dataType, uint32_t a, uint32_t b, uint8_t *data)
{
    // Builds the layout selected by `layoutType` from (a, b) and writes it into
    // the raw buffer `data`. The element type `dataType` is consulted only by
    // the nZ / zN / zZ / nN layouts. For an unrecognized `layoutType`, `data`
    // is left untouched.
    //
    // Layouts are copied in with std::memcpy rather than assigned through
    // reinterpret_cast<Layout*>(data): the buffer is not guaranteed to hold a
    // live Layout object or to be suitably aligned for one. Assumes `data`
    // points to at least GetLayoutSize(layoutType) writable bytes.
    const auto store = [data](const auto &value) {
        using Layout = std::decay_t<decltype(value)>;
        static_assert(std::is_trivially_copyable_v<Layout>,
                      "layout types must be trivially copyable to be stored as raw bytes");
        std::memcpy(data, &value, sizeof(Layout));
    };

    switch (layoutType) {
        case LayoutType::RowMajor:
            store(layout::RowMajor(a, b));
            break;
        case LayoutType::ColumnMajor:
            store(layout::ColumnMajor(a, b));
            break;
        case LayoutType::nZ:
            ConstructLayoutT<layout::nZ>(dataType, a, b, data);
            break;
        case LayoutType::zN:
            ConstructLayoutT<layout::zN>(dataType, a, b, data);
            break;
        case LayoutType::zZ:
            ConstructLayoutT<layout::zZ>(dataType, a, b, data);
            break;
        case LayoutType::PaddingRowMajor:
            store(layout::PaddingRowMajor(a, b));
            break;
        case LayoutType::PaddingColumnMajor:
            store(layout::PaddingColumnMajor(a, b));
            break;
        case LayoutType::nN:
            ConstructLayoutT<layout::nN>(dataType, a, b, data);
            break;
        default:
            break;
    }
}
}