* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file pow_tiling_arch35.h
* \brief
*/
#ifndef POW_CANNDEV_OPS_BUILT_IN_OP_TILING_RUNTIME_POW_POW_TILING_H_
#define POW_CANNDEV_OPS_BUILT_IN_OP_TILING_RUNTIME_POW_POW_TILING_H_
#include <map>
#include <tuple>
#include "register/op_impl_registry.h"
#include "op_host/tiling_base_class.h"
#include "register/tilingdata_base.h"
#include "atvoss/broadcast/broadcast_tiling.h"
#include "log/log.h"
namespace optiling
{
const int INPUT_MAX_DIMS = 8;
using namespace Ops::Base;
BEGIN_TILING_DATA_DEF(PowBaseTilingData)
TILING_DATA_FIELD_DEF(int64_t, blockNum);
TILING_DATA_FIELD_DEF(int64_t, ubSize);
END_TILING_DATA_DEF;
REGISTER_TILING_DATA_CLASS(Pow, PowBaseTilingData)
BEGIN_TILING_DATA_DEF(PowTensorScalarTilingData)
TILING_DATA_FIELD_DEF(int64_t, blockNum);
TILING_DATA_FIELD_DEF(int64_t, blockFactor);
TILING_DATA_FIELD_DEF(int64_t, blockTail);
TILING_DATA_FIELD_DEF(int64_t, ubOuter);
TILING_DATA_FIELD_DEF(int64_t, ubFactor);
TILING_DATA_FIELD_DEF(int64_t, ubTail);
TILING_DATA_FIELD_DEF(int64_t, totalSize);
TILING_DATA_FIELD_DEF(int64_t, ubSize);
END_TILING_DATA_DEF;
REGISTER_TILING_DATA_CLASS(Pow_1001, PowTensorScalarTilingData)
REGISTER_TILING_DATA_CLASS(Pow_2001, PowTensorScalarTilingData)
REGISTER_TILING_DATA_CLASS(Pow_3001, PowTensorScalarTilingData)
REGISTER_TILING_DATA_CLASS(Pow_4001, PowTensorScalarTilingData)
REGISTER_TILING_DATA_CLASS(Pow_5001, PowTensorScalarTilingData)
REGISTER_TILING_DATA_CLASS(Pow_6001, PowTensorScalarTilingData)
REGISTER_TILING_DATA_CLASS(Pow_7001, PowTensorScalarTilingData)
BEGIN_TILING_DATA_DEF(PowTensorTensorTilingData)
TILING_DATA_FIELD_DEF(int64_t, blockFormer);
TILING_DATA_FIELD_DEF(int64_t, ubFormer);
TILING_DATA_FIELD_DEF(int64_t, ubOuter);
TILING_DATA_FIELD_DEF(int64_t, ubTail);
TILING_DATA_FIELD_DEF(int64_t, blockTail);
TILING_DATA_FIELD_DEF(int64_t, shapeLen);
TILING_DATA_FIELD_DEF(int64_t, ubSplitAxis);
TILING_DATA_FIELD_DEF(int64_t, dimProductBeforeUbInner);
TILING_DATA_FIELD_DEF(int64_t, elemNum);
TILING_DATA_FIELD_DEF_ARR(int64_t, INPUT_MAX_DIMS, input0Dims);
TILING_DATA_FIELD_DEF_ARR(int64_t, INPUT_MAX_DIMS, input1Dims);
TILING_DATA_FIELD_DEF_ARR(int64_t, INPUT_MAX_DIMS, outputDims);
TILING_DATA_FIELD_DEF_ARR(int64_t, INPUT_MAX_DIMS, input0Strides);
TILING_DATA_FIELD_DEF_ARR(int64_t, INPUT_MAX_DIMS, input1Strides);
TILING_DATA_FIELD_DEF_ARR(int64_t, INPUT_MAX_DIMS, outputStrides);
END_TILING_DATA_DEF;
REGISTER_TILING_DATA_CLASS(Pow_100000001000100, PowTensorTensorTilingData)
REGISTER_TILING_DATA_CLASS(Pow_100000001001100, PowTensorTensorTilingData)
REGISTER_TILING_DATA_CLASS(Pow_200000001000100, PowTensorTensorTilingData)
REGISTER_TILING_DATA_CLASS(Pow_200000001001100, PowTensorTensorTilingData)
REGISTER_TILING_DATA_CLASS(Pow_300000001000100, PowTensorTensorTilingData)
REGISTER_TILING_DATA_CLASS(Pow_300000001001100, PowTensorTensorTilingData)
REGISTER_TILING_DATA_CLASS(Pow_400000001000100, PowTensorTensorTilingData)
REGISTER_TILING_DATA_CLASS(Pow_400000001001100, PowTensorTensorTilingData)
REGISTER_TILING_DATA_CLASS(Pow_500000001000100, PowTensorTensorTilingData)
REGISTER_TILING_DATA_CLASS(Pow_500000001001100, PowTensorTensorTilingData)
REGISTER_TILING_DATA_CLASS(Pow_600000001000100, PowTensorTensorTilingData)
REGISTER_TILING_DATA_CLASS(Pow_600000001001100, PowTensorTensorTilingData)
REGISTER_TILING_DATA_CLASS(Pow_700000001000100, PowTensorTensorTilingData)
REGISTER_TILING_DATA_CLASS(Pow_700000001001100, PowTensorTensorTilingData)
struct PowCompileInfo {
int64_t coreNum = 0;
int64_t ubSize = 0;
bool isRegBase = false;
int64_t vectorLength = 0;
int64_t blockSize = 0;
};
struct InputParams {
int64_t coreNum;
int64_t ubSize;
bool isRegBase = false;
int64_t vectorLength = 0;
int64_t blockSize = 0;
gert::Shape baseShape;
gert::Shape expShape;
gert::Shape powShape;
ge::DataType baseDtype;
ge::DataType expDtype;
ge::DataType powDtype;
int64_t baseDtypeSize;
int64_t computeDtypeSize;
};
class PowTilingBase : public Ops::Base::TilingBaseClass
{
public:
explicit PowTilingBase(gert::TilingContext* context) : Ops::Base::TilingBaseClass(context)
{
}
~PowTilingBase() override
{
}
InputParams params_;
protected:
bool IsCapable() override;
ge::graphStatus GetPlatformInfo() override;
ge::graphStatus GetShapeAttrsInfo() override;
ge::graphStatus DoOpTiling() override;
ge::graphStatus DoLibApiTiling() override;
uint64_t GetTilingKey() const override;
ge::graphStatus GetWorkspaceSize() override;
ge::graphStatus PostTiling() override;
};
class PowTensorScalarTiling : public PowTilingBase
{
public:
explicit PowTensorScalarTiling(gert::TilingContext* context) : PowTilingBase(context)
{
}
~PowTensorScalarTiling() override
{
}
PowTensorScalarTilingData tilingData_;
protected:
bool IsCapable() override;
uint64_t GetTilingKey() const override;
ge::graphStatus DoOpTiling() override;
ge::graphStatus PostTiling() override;
int64_t GetUbFactor(int64_t dim0, int64_t oriUbFactor) const;
};
class PowTensorTensorTiling : public PowTilingBase
{
public:
explicit PowTensorTensorTiling(gert::TilingContext* context) : PowTilingBase(context)
{
}
~PowTensorTensorTiling() override
{
}
PowTensorTensorTilingData tilingData_;
protected:
bool IsCapable() override;
uint64_t GetTilingKey() const override;
ge::graphStatus DoOpTiling() override;
ge::graphStatus PostTiling() override;
void SetOpKey();
uint64_t GetOpKey(ge::DataType inputXDtype, ge::DataType inputYDtype, ge::DataType outputZDtype);
uint64_t GenerateTilingKey(uint64_t innerKey);
std::map<uint64_t, BroadcastComputeParams> GetComputeMap(uint64_t opKey);
uint64_t opKey_{0};
uint64_t blockNum{0};
std::map<std::tuple<ge::DataType, ge::DataType, ge::DataType>, uint64_t> opKeys;
int64_t input0Dims[INPUT_MAX_DIMS] = {0};
int64_t input1Dims[INPUT_MAX_DIMS] = {0};
int64_t outputDims[INPUT_MAX_DIMS] = {0};
int64_t input0Strides[INPUT_MAX_DIMS] = {0};
int64_t input1Strides[INPUT_MAX_DIMS] = {0};
int64_t outputStrides[INPUT_MAX_DIMS] = {0};
};
}
#endif