* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef OPS_MATH_DEV_TESTS_UT_COMMON_INFERSHAPE_CONTEXT_FAKER_H
#define OPS_MATH_DEV_TESTS_UT_COMMON_INFERSHAPE_CONTEXT_FAKER_H
#include "op_infer_shape_context_builder.h"
#include "any_value.h"
namespace gert {
class InfershapeContextPara {
public:
class TensorDescription {
public:
TensorDescription(const gert::StorageShape& shape, ge::DataType dtype, ge::Format format, bool isConst = false,
void* constValue = nullptr) :
shape_(shape), dtype_(dtype), format_(format), isConst_(isConst), constValue_(constValue) {}
public:
gert::StorageShape shape_;
ge::DataType dtype_ = ge::DT_FLOAT;
ge::Format format_ = ge::FORMAT_ND;
bool isConst_ = false;
void* constValue_ = nullptr;
};
class OpAttr {
public:
OpAttr(const std::string& attrName, const Ops::Math::AnyValue& attr) : attrName_(attrName), attr_(attr) {}
public:
std::string attrName_;
Ops::Math::AnyValue attr_;
};
public:
InfershapeContextPara(const std::string& opName,
const std::vector<TensorDescription>& inputTensorDesc,
const std::vector<TensorDescription>& outputTensorDesc,
const std::vector<OpAttr>& attrs,
const std::vector<uint32_t>& inputInstanceNum = {},
const std::vector<uint32_t>& outputInstanceNum = {}) :
opName_(opName),
inputTensorDesc_(inputTensorDesc),
outputTensorDesc_(outputTensorDesc),
attrs_(attrs),
inputInstanceNum_(inputInstanceNum),
outputInstanceNum_(outputInstanceNum) {}
InfershapeContextPara(const std::string& opName,
const std::vector<TensorDescription>& inputTensorDesc,
const std::vector<TensorDescription>& outputTensorDesc,
const std::vector<uint32_t>& inputInstanceNum = {},
const std::vector<uint32_t>& outputInstanceNum = {}) :
opName_(opName),
inputTensorDesc_(inputTensorDesc),
outputTensorDesc_(outputTensorDesc),
inputInstanceNum_(inputInstanceNum),
outputInstanceNum_(outputInstanceNum) {}
public:
std::string opName_;
std::vector<uint32_t> inputInstanceNum_;
std::vector<uint32_t> outputInstanceNum_;
std::vector<TensorDescription> inputTensorDesc_;
std::vector<TensorDescription> outputTensorDesc_;
std::vector<OpAttr> attrs_;
};
class InferShapeContextFaker : public OpInferShapeContextBuilder {
public:
InferShapeContextFaker& SetOpType(const std::string opType);
InferShapeContextFaker& NodeIoNum(size_t inputNum, size_t outputNum);
* only one can be choosed from NodeIoNum */
InferShapeContextFaker& IrInstanceNum(const std::vector<uint32_t>& inputInstanceNum,
const std::vector<uint32_t>& outputInstanceNum);
InferShapeContextFaker& NodeInputTd(int32_t index, ge::DataType dtype, ge::Format originFormat,
ge::Format storageFormat);
InferShapeContextFaker& NodeOutputTd(int32_t index, ge::DataType dtype, ge::Format originFormat,
ge::Format storageFormat);
InferShapeContextFaker& Attr(const std::string& attrName, bool attr) {
OpInferShapeContextBuilder::AppendAttr(attr);
return *this;
}
InferShapeContextFaker& Attr(const std::string& attrName, int64_t attr) {
OpInferShapeContextBuilder::AppendAttr(attr);
return *this;
}
InferShapeContextFaker& Attr(const std::string& attrName, float attr) {
OpInferShapeContextBuilder::AppendAttr(attr);
return *this;
}
InferShapeContextFaker& Attr(const std::string& attrName, const ge::AscendString& attr) {
OpInferShapeContextBuilder::AppendAttr(attr);
return *this;
}
InferShapeContextFaker& Attr(const std::string& attrName, const std::vector<bool>& attr) {
OpInferShapeContextBuilder::AppendAttr(attr);
return *this;
}
InferShapeContextFaker& Attr(const std::string& attrName, const std::vector<int64_t>& attr) {
OpInferShapeContextBuilder::AppendAttr(attr);
return *this;
}
InferShapeContextFaker& Attr(const std::string& attrName, const std::vector<float>& attr) {
OpInferShapeContextBuilder::AppendAttr(attr);
return *this;
}
InferShapeContextFaker& Attr(const std::string& attrName, const std::vector<ge::AscendString>& attr) {
OpInferShapeContextBuilder::AppendAttr(attr);
return *this;
}
InferShapeContextFaker& Attr(const std::string& attrName, const std::vector<std::vector<int64_t>>& attr) {
OpInferShapeContextBuilder::AppendAttr(attr);
return *this;
}
InferShapeContextFaker& InputTensors(const std::vector<Tensor *>& inputTensors);
InferShapeContextFaker& OutputShapes(const std::vector<StorageShape *>& outputShapes);
ContextHolder<InferShapeContext> Build();
};
}
#endif