/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License")
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
/*!
 * \file cast_tiling.h
 * \brief Tiling data definition and tiling class declaration for the Cast operator.
 */
#ifndef OPS_BUILD_IN_OP_TILING_RUNTIME_CAST_TILING_H
#define OPS_BUILD_IN_OP_TILING_RUNTIME_CAST_TILING_H
#include "platform/platform_info.h"
#include "register/op_impl_registry.h"
#include "register/tilingdata_base.h"
#include "op_host/tiling_base_class.h"
#include "log/log.h"
namespace optiling {
/**
 * Platform information carried by the Cast tiling code.
 *
 * Both members default to zero so that a default-constructed instance never
 * exposes indeterminate values. The struct remains an aggregate (C++14 and
 * later), so brace-initialization such as `CastCompileInfo{n, size}` still works.
 */
struct CastCompileInfo {
    uint64_t coreNum{0};  // number of cores; populated by code outside this header
    uint64_t ubSize{0};   // UB size (presumably in bytes — confirm against the code that fills it)
};
// Tiling data for the Cast operator, declared through the tiling-data macros
// (presumably provided by register/tilingdata_base.h). Every field is an int64_t.
// NOTE(review): the declaration order presumably defines the serialized layout
// consumed by the kernel side — do not reorder, insert, or remove fields without
// updating the kernel accordingly (confirm against the macro definitions).
// NOTE(review): the per-field remarks below are inferred from the field names only;
// verify them against the tiling implementation before relying on them.
BEGIN_TILING_DATA_DEF(CastTilingData)
// Block-level split (names suggest: number of blocks used and the size of a non-tail block).
TILING_DATA_FIELD_DEF(int64_t, blockNum);
TILING_DATA_FIELD_DEF(int64_t, ubFormer);
TILING_DATA_FIELD_DEF(int64_t, blockFormer);
// UB loop counts and tail sizes, kept separately for "former" blocks and the tail block.
TILING_DATA_FIELD_DEF(int64_t, ubLoopOfFormerBlock);
TILING_DATA_FIELD_DEF(int64_t, ubLoopOfTailBlock);
TILING_DATA_FIELD_DEF(int64_t, ubTailOfFormerBlock);
TILING_DATA_FIELD_DEF(int64_t, ubTailOfTailBlock);
// Register-level copy steps and loop counts (presumably per-iteration strides for copy-in/copy-out).
TILING_DATA_FIELD_DEF(int64_t, regCopyInStep);
TILING_DATA_FIELD_DEF(int64_t, regCopyOutStep);
TILING_DATA_FIELD_DEF(int64_t, ubFormerRegLoop);
TILING_DATA_FIELD_DEF(int64_t, ubTailOfFormerRegLoop);
TILING_DATA_FIELD_DEF(int64_t, ubTailOfTailRegLoop);
END_TILING_DATA_DEF;
// Associates CastTilingData with the operator name `Cast`.
REGISTER_TILING_DATA_CLASS(Cast, CastTilingData);
/**
 * One cast policy entry: a (source dtype, destination dtype) pair together with
 * the small integer codes that describe how the cast is carried out.
 *
 * Every member has a default initializer, so a default-constructed entry
 * (for example CastTiling::policy_ before a policy is selected) holds
 * well-defined values instead of indeterminate ones.
 * NOTE(review): the meaning of the individual uint8_t codes is defined by the
 * tiling implementation and the kernel, not by this header.
 */
struct CastMapSt {
    ge::DataType srcType_{ge::DT_UNDEFINED};  // source data type
    ge::DataType dstType_{ge::DT_UNDEFINED};  // destination data type
    uint8_t id_{0};                           // identifier of this entry
    uint8_t srcMapType_{0};                   // mapped code for the source type
    uint8_t dstMapType_{0};                   // mapped code for the destination type
    uint8_t midType_{0};                      // intermediate type code (presumably for two-step casts — confirm)
    uint8_t castMode1_{0};                    // mode code of the first cast
    uint8_t castMode2_{0};                    // mode code of the second cast
    uint8_t regCopyInMode_{0};                // register copy-in mode code
    uint8_t regCopyOutMode_{0};               // register copy-out mode code

    /** Creates an entry with every member set to its default initializer. */
    CastMapSt() = default;

    /** Creates a fully specified entry; each argument is stored in the member of the same name. */
    CastMapSt(ge::DataType srcType, ge::DataType dstType, uint8_t id,
              uint8_t srcMapType, uint8_t dstMapType, uint8_t midType,
              uint8_t castMode1, uint8_t castMode2,
              uint8_t regCopyInMode, uint8_t regCopyOutMode)
        : srcType_(srcType), dstType_(dstType), id_(id), srcMapType_(srcMapType),
          dstMapType_(dstMapType), midType_(midType), castMode1_(castMode1), castMode2_(castMode2),
          regCopyInMode_(regCopyInMode), regCopyOutMode_(regCopyOutMode) {}
};
/**
 * Tiling class for the Cast operator.
 *
 * Extends Ops::Base::TilingBaseClass and overrides its tiling hooks. Only the
 * declarations live in this header; the hook and helper bodies are defined elsewhere.
 */
class CastTiling : public Ops::Base::TilingBaseClass {
public:
    /** Forwards the tiling context to the base class. */
    explicit CastTiling(gert::TilingContext *context) : Ops::Base::TilingBaseClass(context) {}

protected:
    // Hooks overridden from Ops::Base::TilingBaseClass.
    bool IsCapable() override;
    ge::graphStatus GetPlatformInfo() override;
    ge::graphStatus GetShapeAttrsInfo() override;
    ge::graphStatus DoOpTiling() override;
    ge::graphStatus DoLibApiTiling() override;
    uint64_t GetTilingKey() const override;
    ge::graphStatus GetWorkspaceSize() override;
    ge::graphStatus PostTiling() override;

private:
    // Internal helpers; oneLoopCopyInBitSize is an in/out reference parameter.
    int64_t GetUbCopyStep(uint8_t inType, uint8_t outType,
                          uint8_t copyType, int64_t &oneLoopCopyInBitSize);
    int64_t GetDtypeBitSize(uint8_t dtype);
    int64_t GetGeDtypeBitSize(ge::DataType dtype);
    int64_t GetUbFormer(int64_t inputTypeBitSize, int64_t outputTypeBitSize);
    bool IsSimt();
    ge::DataType TransAclToGeDataType(int32_t aclType);

    // Tiling state; the scalar members start from the defaults given here.
    int64_t coreNum_ = 0;
    int64_t ubSize_ = 0;
    int64_t vlBitSize_ = 2048;  // presumably the vector length in bits — confirm
    int64_t shapeSize_ = 0;
    CastTilingData tilingData_;
    CastMapSt policy_;
};
}
#endif