* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "micro_api_call_factory.h"
#include "ascir_ops.h"
#include "micro_load_api_call.h"
#include "micro_dtype_utils.h"
namespace {
using codegen::DTYPE_SIZE_1BYTE;
using codegen::DTYPE_SIZE_2BYTE;
using codegen::DTYPE_SIZE_4BYTE;
using codegen::DTYPE_SIZE_8BYTE;
using codegen::GetDtypeSizeByName;
using codegen::GetUintDtypeNameBySize;

/**
 * Plan how a register load reaches the element byte size max_dtype_size.
 *
 * The load itself may widen once through an UNPACK load distribution; whatever widening is still missing after
 * the load is described as a sequence of explicit doubling steps.
 *
 * @param input_dtype_size         Byte size of the element type being loaded.
 * @param max_dtype_size           Byte size the loaded elements have to be widened to.
 * @param dist                     In/out load distribution name. When it is empty and widening is needed, it is set
 *                                 to the UNPACK mode matching input_dtype_size. A non-empty value supplied by the
 *                                 caller is kept, and the load result is then taken to stay at input_dtype_size.
 * @param remaining_unpack_times   Out. Number of doubling steps still needed after the load. Reset to 0 and written
 *                                 only when the load result is narrower than max_dtype_size.
 * @param remaining_dtype_sequence Out. One unsigned dtype name is appended per doubling step that lands on 2, 4 or
 *                                 8 bytes.
 * @param load_result_dtype        Out. Unsigned dtype name of the register right after the load.
 */
void DetermineLoadDistAndUnPackSequence(uint32_t input_dtype_size, uint32_t max_dtype_size, std::string &dist,
                                        int &remaining_unpack_times, std::vector<std::string> &remaining_dtype_sequence,
                                        std::string &load_result_dtype) {
  // Already at least as wide as required: plain load, dist untouched, no extra steps.
  if (max_dtype_size <= input_dtype_size) {
    load_result_dtype = GetUintDtypeNameBySize(input_dtype_size);
    return;
  }

  uint32_t load_result_size = input_dtype_size;
  if (dist.empty()) {
    if (input_dtype_size == DTYPE_SIZE_1BYTE) {
      // DIST_UNPACK4_B8 yields 4-byte elements; DIST_UNPACK_B8 yields 2-byte elements.
      const bool to_four_bytes = max_dtype_size >= DTYPE_SIZE_4BYTE;
      dist = to_four_bytes ? "DIST_UNPACK4_B8" : "DIST_UNPACK_B8";
      load_result_size = to_four_bytes ? DTYPE_SIZE_4BYTE : DTYPE_SIZE_2BYTE;
    } else if (input_dtype_size == DTYPE_SIZE_2BYTE) {
      dist = "DIST_UNPACK_B16";
      load_result_size = DTYPE_SIZE_4BYTE;
    } else if (input_dtype_size == DTYPE_SIZE_4BYTE) {
      dist = "DIST_UNPACK_B32";
      load_result_size = DTYPE_SIZE_8BYTE;
    }
  }
  load_result_dtype = GetUintDtypeNameBySize(load_result_size);
  if (load_result_size >= max_dtype_size) {
    return;
  }

  remaining_unpack_times = 0;
  uint32_t current_size = load_result_size;
  // The current_size != 0 check prevents an endless loop. A size of 0 can come from a zero-sized dtype, or from
  // unsigned wrap-around while doubling, and doubling 0 never reaches max_dtype_size.
  while (current_size != 0U && current_size < max_dtype_size) {
    current_size *= 2U;
    remaining_unpack_times++;
    if (current_size == DTYPE_SIZE_2BYTE) {
      remaining_dtype_sequence.push_back("uint16_t");
    } else if (current_size == DTYPE_SIZE_4BYTE) {
      remaining_dtype_sequence.push_back("uint32_t");
    } else if (current_size == DTYPE_SIZE_8BYTE) {
      remaining_dtype_sequence.push_back("uint64_t");
    }
  }
}
}
namespace codegen {
/**
 * Emit the Ascend C code for this load: one LoadAlign into the output register tensor, followed by any
 * UnPack calls still needed to reach param.max_dtype_size.
 *
 * @param tensor_mng Provides the output (register) tensor: its name is the load target and its dtype selects the
 *                   template arguments.
 * @param tpipe      Provides the input tensor, printed as the LoadAlign source operand followed by "+ offset".
 * @param param      Supplies offset and, when non-empty, max_dtype_size (the dtype name to widen towards).
 * @param result     Receives the generated code.
 * @return ge::SUCCESS, or the failure reported by GE_ASSERT_NOTNULL when a tensor lookup yields null.
 */
Status MicroLoadApiCall::Generate(const TensorManager &tensor_mng, const TPipe &tpipe, CallParam &param,
                                  string &result) {
  const auto tensor_id = GetOutputTensorIdByIndex(0);
  const auto tensor_ptr = tensor_mng.GetTensor(tensor_id);
  GE_ASSERT_NOTNULL(tensor_ptr);
  // Validate the source tensor before dereferencing it in the emitted LoadAlign call.
  const auto src_tensor_ptr = tpipe.GetTensor(this->GetInputTensorIdByIndex(0));
  GE_ASSERT_NOTNULL(src_tensor_ptr);

  const auto input_dtype = tensor_ptr->dtype_;
  std::string input_dtype_name;
  // NOTE(review): any status returned by DtypeName is ignored. If it can fail, input_dtype_name stays empty and the
  // emitted template argument list is malformed. Consider checking the result.
  Tensor::DtypeName(input_dtype, input_dtype_name);
  const uint32_t input_dtype_size = ge::GetSizeByDataType(input_dtype);

  int remaining_unpack_times = 0;
  std::vector<std::string> remaining_dtype_sequence;
  std::string load_result_dtype;
  if (param.max_dtype_size.empty()) {
    load_result_dtype = GetUintDtypeNameBySize(input_dtype_size);
  } else {
    // NOTE(review): this may assign this->dist_ (only when it is empty). A second Generate() call on the same
    // instance therefore plans the unpack steps differently. Confirm Generate() runs at most once per instance.
    const uint32_t max_dtype_size = GetDtypeSizeByName(param.max_dtype_size);
    DetermineLoadDistAndUnPackSequence(input_dtype_size, max_dtype_size, this->dist_, remaining_unpack_times,
                                       remaining_dtype_sequence, load_result_dtype);
  }

  // Template arguments are only spelled out when a non-default load distribution was chosen.
  std::string load_template_params;
  if (!this->dist_.empty()) {
    load_template_params = "<" + input_dtype_name + ", AscendC::MicroAPI::LoadDist::" + this->dist_ + ">";
  }

  std::stringstream ss;
  ss << "AscendC::MicroAPI::LoadAlign" << load_template_params << "(" << tensor_ptr->name << ", " << *src_tensor_ptr
     << " + " << param.offset << ");\n";
  // Each UnPack reads and writes the same register name, cast to the source and destination dtypes.
  for (int i = 0; i < remaining_unpack_times && i < static_cast<int>(remaining_dtype_sequence.size()); ++i) {
    const std::string &dst_dtype = remaining_dtype_sequence[i];
    if (dst_dtype == load_result_dtype) {
      continue;
    }
    ss << "AscendC::Reg::UnPack<" << dst_dtype << ", " << load_result_dtype << ">((AscendC::Reg::RegTensor<"
       << dst_dtype << ">&)" << tensor_ptr->name << ", (AscendC::Reg::RegTensor<" << load_result_dtype << ">&)"
       << tensor_ptr->name << ");\n";
    load_result_dtype = dst_dtype;
  }
  result = ss.str();
  return ge::SUCCESS;
}
/**
 * Node-level initialisation hook. This implementation reads nothing from the node and always succeeds.
 *
 * @param node Unused.
 * @return ge::SUCCESS.
 */
Status MicroLoadApiCall::Init(const ascir::NodeView &node) {
  static_cast<void>(node);  // intentionally unused
  return ge::SUCCESS;
}
/**
 * Select a DIST_BRC_* load distribution from the input tensor's vectorized strides.
 *
 * dist_ is changed only when the innermost stride is statically zero and at least one other stride is not
 * statically zero. In that case it becomes the DIST_BRC_* mode for the element size (presumably a broadcast
 * load). For sizes without such a mode it is cleared.
 *
 * @param tpipe Provides the input tensor (input index 0).
 * @return ge::SUCCESS, or the failure reported by GE_ASSERT_NOTNULL when the tensor lookup yields null.
 */
Status MicroLoadApiCall::UpdateDistModeByStrideInfo(const TPipe &tpipe) {
  const auto tensor_id = GetInputTensorIdByIndex(0);
  const Tensor *tensor_ptr = tpipe.GetTensor(tensor_id);
  GE_ASSERT_NOTNULL(tensor_ptr);

  const auto &strides = tensor_ptr->vectorized_strides;
  // No vectorized strides: leave dist_ alone. This also avoids calling back() on an empty container, which is UB.
  if (strides.empty()) {
    return ge::SUCCESS;
  }

  const auto is_static_zero = [](const ascir::SizeExpr &stride) {
    return af::SymbolicUtils::StaticCheckEq(stride.Simplify(), af::sym::kSymbolZero) == af::TriBool::kTrue;
  };
  if (!is_static_zero(strides.back())) {
    return ge::SUCCESS;
  }
  if (std::all_of(strides.begin(), strides.end(), is_static_zero)) {
    return ge::SUCCESS;
  }

  // Built once; looked up with find() so unknown sizes do not insert entries.
  static const std::map<int, std::string> kLoadBrcDistMode = {
      {DTYPE_SIZE_1BYTE, "DIST_BRC_B8"}, {DTYPE_SIZE_2BYTE, "DIST_BRC_B16"}, {DTYPE_SIZE_4BYTE, "DIST_BRC_B32"}};
  const auto dtype_size = ge::GetSizeByDataType(tensor_ptr->dtype);
  const auto it = kLoadBrcDistMode.find(dtype_size);
  // Element sizes without a BRC mode clear dist_, as the previous operator[] lookup did.
  this->dist_ = (it != kLoadBrcDistMode.end()) ? it->second : std::string();
  return ge::SUCCESS;
}
// Namespace-scope static object: constructing it registers MicroLoadApiCall under the name "MicroLoadApiCall"
// (presumably with the factory declared in micro_api_call_factory.h -- confirm).
static MicroApiCallRegister<MicroLoadApiCall> register_micro_load_api_call("MicroLoadApiCall");
}