/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
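/*!
 * \file tbe_tiling_api.h
 * \brief Run-info, compile-info and tiling-data structures plus the GetTbeTiling entry points
 *        declared for the Conv3D backprop V2 operators.
 */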
#ifndef TBE_TILING_API_H
#define TBE_TILING_API_H
#include <cstdint>
#include <exe_graph/runtime/tiling_context.h>
#include <tiling/platform/platform_ascendc.h>
#include "graph/utils/type_utils.h"
#include "platform/platform_infos_def.h"
namespace optiling {
/**
 * Tiling values filled in by the GetTbeTiling overloads.
 *
 * Aggregate of sixteen int32_t members. None of them has a default member
 * initializer, so a default-initialized instance holds indeterminate values;
 * write `Conv3dBackpropV2TBETilingData data{};` to obtain an all-zero object.
 * Declaration order (and therefore memory layout and aggregate-initialization
 * order) must not be changed.
 */
struct Conv3dBackpropV2TBETilingData {
    // NOTE(review): presumably the M/K/N tile sizes used at the L0 level -- confirm with the tiling code.
    int32_t m_l0, k_l0, n_l0;

    // NOTE(review): the "_al1" / "_bl1" suffixes look like A-side / B-side L1 settings -- confirm.
    int32_t m_al1, n_bl1, k_al1, k_bl1;

    // NOTE(review): "db_" presumably stands for double-buffer switches -- confirm.
    int32_t db_l0c, db_al1, db_bl1;

    // NOTE(review): presumably per-axis split factors -- confirm with the tiling code.
    int32_t batch_dim, d_dim, group_dim, m_dim, n_dim, k_dim;
};
/**
 * Run-time description passed to the Conv3dBpFilterV2RunInfo overload of GetTbeTiling.
 *
 * Every member now carries a default member initializer. Previously only the
 * dtype-related members did, so `Conv3dBpFilterV2RunInfo info;` left all other
 * fields indeterminate. Member order is unchanged, so aggregate initialization
 * and layout are the same as before.
 */
struct Conv3dBpFilterV2RunInfo {
    int32_t batch = 0;
    // NOTE(review): co/ci presumably mean output/input channel counts -- confirm.
    int32_t co = 0;
    int32_t ci = 0;
    int32_t cout1_g = 0;
    int32_t cin1_g = 0;
    // NOTE(review): dout/ho/wo and di/hi/wi look like output and input D/H/W extents -- confirm.
    int32_t dout = 0;
    int32_t wo = 0;
    int32_t ho = 0;
    int32_t wi = 0;
    int32_t hi = 0;
    int32_t di = 0;
    // NOTE(review): kw/kh/kd presumably are the kernel W/H/D extents -- confirm.
    int32_t kw = 0;
    int32_t kh = 0;
    int32_t kd = 0;
    int32_t real_g = 0;
    int32_t stride_w = 0;
    int32_t stride_h = 0;
    int32_t stride_d = 0;
    // NOTE(review): pad suffixes presumably mean left/right/up/down/front/back -- confirm.
    int32_t pad_l = 0;
    int32_t pad_r = 0;
    int32_t pad_u = 0;
    int32_t pad_d = 0;
    int32_t pad_f = 0;
    int32_t pad_b = 0;
    int32_t dilation_w = 0;
    int32_t dilation_h = 0;
    int32_t dilation_d = 0;
    int32_t ci1 = 0;
    uint64_t bl1_bound = 0;
    int32_t batch_dout_single_core = 0;
    uint32_t k0 = 0;
    uint32_t m0 = 0;
    uint32_t n0 = 0;
    uint32_t hf32Flag = 0;
    // Data types default to float16, with matching 2-byte element sizes below.
    ge::DataType a_dtype = ge::DT_FLOAT16;
    ge::DataType b_dtype = ge::DT_FLOAT16;
    ge::DataType c_dtype = ge::DT_FLOAT16;
    int32_t a_dtype_bytes = 2;
    int32_t b_dtype_bytes = 2;
    int32_t c_dtype_bytes = 2;
    uint32_t core_num = 0;
};
/**
 * Run-time description passed to the Conv3dBpInputV2RunInfo overload of GetTbeTiling.
 *
 * Every member now carries a default member initializer. Previously only the
 * *_dtype_bytes members and initOutputFlag did, so `Conv3dBpInputV2RunInfo info;`
 * left all other fields indeterminate. Member order is unchanged, so aggregate
 * initialization and layout are the same as before.
 */
struct Conv3dBpInputV2RunInfo {
    int32_t batch_n = 0;
    int32_t real_g = 0;
    // NOTE(review): "dedx_*" presumably describes the input-gradient tensor -- confirm.
    int32_t dedx_d = 0;
    int32_t dedx_cin = 0;
    int32_t dedx_cin1 = 0;
    int32_t dedx_cin1_g = 0;
    int32_t dedx_h = 0;
    int32_t dedx_w = 0;
    // NOTE(review): "dedy_*" presumably describes the output-gradient tensor -- confirm.
    int32_t dedy_d = 0;
    int32_t dedy_cout = 0;
    int32_t dedy_cout1 = 0;
    int32_t dedy_cout1_g = 0;
    int32_t dedy_h = 0;
    int32_t dedy_w = 0;
    int32_t kernel_d = 0;
    int32_t kernel_h = 0;
    int32_t kernel_w = 0;
    int32_t stride_d = 0;
    int32_t stride_h = 0;
    int32_t stride_w = 0;
    // NOTE(review): pad_h / pad_t presumably mean head / tail (depth axis), the rest
    // up / down / left / right -- confirm against the code that fills this struct.
    int32_t pad_h = 0;
    int32_t pad_t = 0;
    int32_t pad_u = 0;
    int32_t pad_d = 0;
    int32_t pad_l = 0;
    int32_t pad_r = 0;
    int32_t dilation_d = 0;
    int32_t dilation_h = 0;
    int32_t dilation_w = 0;
    int32_t backprop_pad_h = 0;
    int32_t backprop_pad_t = 0;
    int32_t backprop_pad_u = 0;
    int32_t backprop_pad_d = 0;
    int32_t backprop_pad_l = 0;
    int32_t backprop_pad_r = 0;
    int32_t hf32_flag = 0;
    // Element sizes default to 2 bytes.
    int32_t a_dtype_bytes = 2;
    int32_t b_dtype_bytes = 2;
    int32_t c_dtype_bytes = 2;
    int32_t initOutputFlag = 0;
};
/**
 * Platform information record for the Conv3D backprop V2 tiling code.
 *
 * Every member has a default: sizes and counts start at zero, the SoC string is
 * empty, the short SoC version starts as ASCEND910B, and each capability flag
 * starts at the value written next to it.
 */
struct Conv3DBackpropV2CompileInfo {
    // SoC identification.
    std::string soc_version{""};
    platform_ascendc::SocVersion shortSocVersion{platform_ascendc::SocVersion::ASCEND910B};

    // Core count and buffer sizes (NOTE(review): units are presumably bytes -- confirm).
    uint32_t core_num{0};
    uint64_t ub_size{0};
    uint64_t l1_size{0};
    uint64_t l2_size{0};
    uint64_t l0a_size{0};
    uint64_t l0b_size{0};
    uint64_t l0c_size{0};
    uint64_t bt_size{0};
    int32_t cube_freq{0};

    // Capability flags.
    bool load3d_constraints{true};
    bool intrinsic_data_move_l12ub{true};
    bool intrinsic_matmul_ub_to_ub{false};
    bool intrinsic_conv_ub_to_ub{false};
    bool intrinsic_data_move_l0c2ub{true};
    bool intrinsic_fix_pipe_l0c2out{false};
    bool intrinsic_fix_pipe_l0c2ub{false};
    bool intrinsic_data_move_out2l1_nd2nz{false};
    bool intrinsic_data_move_l12bt_bf16{false};
};
/**
 * Selects which V2 operator a GetTbeTiling call is made for.
 *
 * Unscoped enumeration with size_t as its underlying type; the enumerators are
 * numbered consecutively from zero.
 */
enum OpTypeV2 : size_t {
    kConv3DBackpropFilterV2 = 0,
    kConv3DBackpropInputV2 = 1,
    kConv3DTransposeV2 = 2,
};
/**
 * GetTbeTiling overload taking a Conv3dBpFilterV2RunInfo.
 *
 * @param context         Tiling context (read-only pointer).
 * @param runInfoForV2    Run info, passed by mutable reference.
 * @param tbeTilingForV2  Tiling data, passed by mutable reference.
 * @return bool; NOTE(review): presumably true on success -- confirm with the definition.
 */
bool GetTbeTiling(const gert::TilingContext* context,
                  Conv3dBpFilterV2RunInfo& runInfoForV2,
                  Conv3dBackpropV2TBETilingData& tbeTilingForV2);

/**
 * GetTbeTiling overload taking a Conv3dBpInputV2RunInfo and an operator selector.
 *
 * @param context         Tiling context (mutable pointer).
 * @param runInfoV2       Run info, passed by mutable reference.
 * @param tbeTilingForV2  Tiling data, passed by mutable reference.
 * @param opType          Operator the call is made for.
 * @return bool; NOTE(review): presumably true on success -- confirm with the definition.
 */
bool GetTbeTiling(gert::TilingContext* context,
                  Conv3dBpInputV2RunInfo& runInfoV2,
                  Conv3dBackpropV2TBETilingData& tbeTilingForV2,
                  const optiling::OpTypeV2 opType);

/**
 * GetTbeTiling overload that takes no run-info argument.
 *
 * @param context         Tiling context (mutable pointer).
 * @param tbeTilingForV2  Tiling data, passed by mutable reference.
 * @param opType          Operator the call is made for.
 * @return bool; NOTE(review): presumably true on success -- confirm with the definition.
 */
bool GetTbeTiling(gert::TilingContext* context,
                  Conv3dBackpropV2TBETilingData& tbeTilingForV2,
                  const optiling::OpTypeV2 opType);
}
#endif