* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_INCLUDE_CONVERTER_H_
#define MINDSPORE_LITE_INCLUDE_CONVERTER_H_
#include <map>
#include <string>
#include <vector>
#include <memory>
#include "include/api/format.h"
#include "include/api/status.h"
#include "include/registry/converter_context.h"
#include "include/api/dual_abi_helper.h"
namespace mindspore {
// Opaque holder of the converter's state; its definition is not part of this header.
struct ConverterPara;

/// \brief C++ interface of the MindSpore Lite model converter.
///
/// Public members that take or return std::string / std::map<std::string, ...> are declared `inline` and
/// defined at the end of this header. Each of them only converts its arguments/result with the helpers
/// from include/api/dual_abi_helper.h (StringToChar, CharToString, MapStringToChar, ...) and forwards to
/// the private std::vector<char>-based overload of the same name.
/// NOTE(review): presumably this keeps std::string out of the exported (MS_API) symbols so that the
/// library and its callers can be built with different std::string ABIs — confirm in dual_abi_helper.h.
///
/// All non-inline members are implemented outside this header; the object's state sits behind `data_`.
///
/// NOTE(review): several getters below (GetInputDataType, GetOutputDataType, GetNoFusion,
/// GetOptimizeTransformer, GetDevice, GetDeviceId, GetRankId, GetProvider, GetChipName) are not
/// const-qualified. Adding `const` now would change exported symbols (or, for the inline ones, requires
/// the private *Char getters to become const first), so they are intentionally left unchanged.
class MS_API Converter {
 public:
  /// \brief Default constructor (implemented outside this header).
  Converter();

  /// \brief Constructs a converter for the given model.
  ///
  /// \param[in] fmk_type    Framework type of the input model.
  /// \param[in] model_file  Model file to convert.
  /// \param[in] output_file Output file; defaults to "" (NOTE(review): meaning of "" is not visible here).
  /// \param[in] weight_file Weight file; defaults to "" (NOTE(review): meaning of "" is not visible here).
  inline Converter(converter::FmkType fmk_type, const std::string &model_file, const std::string &output_file = "",
                   const std::string &weight_file = "");
  ~Converter() = default;

  /// \brief Setter / getter pair for the configuration file.
  inline void SetConfigFile(const std::string &config_file);
  inline std::string GetConfigFile() const;

  /// \brief Sets the key/value pairs \p config for configuration section \p section.
  inline void SetConfigInfo(const std::string &section, const std::map<std::string, std::string> &config);
  /// \return The configuration, keyed by section name and then by option key.
  inline std::map<std::string, std::map<std::string, std::string>> GetConfigInfo() const;

  /// \brief Setter / getter pair for the `weight_fp16` option.
  void SetWeightFp16(bool weight_fp16);
  bool GetWeightFp16() const;

  /// \brief Setter / getter pair for the input shapes, keyed by input name.
  inline void SetInputShape(const std::map<std::string, std::vector<int64_t>> &input_shape);
  inline std::map<std::string, std::vector<int64_t>> GetInputShape() const;

  /// \brief Setter / getter pair for the input format.
  void SetInputFormat(Format format);
  Format GetInputFormat() const;

  /// \brief Setter for the output format.
  /// NOTE(review): no matching GetOutputFormat() is declared in this header.
  void SetOutputFormat(Format format);

  /// \brief Setter / getter pairs for the input and output data types.
  void SetInputDataType(DataType data_type);
  DataType GetInputDataType();
  void SetOutputDataType(DataType data_type);
  DataType GetOutputDataType();

  /// \brief Setter / getter pair for the save (model) type.
  void SetSaveType(ModelType save_type);
  ModelType GetSaveType() const;

  /// \brief Decryption settings. The key has a setter only; the mode has a setter and a getter.
  inline void SetDecryptKey(const std::string &key);
  inline void SetDecryptMode(const std::string &mode);
  inline std::string GetDecryptMode() const;

  /// \brief Encryption settings. The key has a setter only.
  void SetEnableEncryption(bool encryption);
  bool GetEnableEncryption() const;
  inline void SetEncryptKey(const std::string &key);

  /// \brief Setter / getter pairs for boolean conversion flags.
  void SetInfer(bool infer);
  bool GetInfer() const;
  void SetTrainModel(bool train_model);
  bool GetTrainModel() const;
  void SetNoFusion(bool no_fusion);
  bool GetNoFusion();
  void SetOptimizeTransformer(bool optimize_transformer);
  bool GetOptimizeTransformer();

  /// \brief Setter / getter pairs for device, device id, rank id, provider and chip name.
  inline void SetDevice(const std::string &device);
  inline std::string GetDevice();
  void SetDeviceId(int32_t device_id);
  int32_t GetDeviceId();
  void SetRankId(int32_t rank_id);
  int32_t GetRankId();
  inline void SetProvider(const std::string &provider);
  inline std::string GetProvider();
  inline void SetChipName(const std::string &chip_name);
  inline std::string GetChipName();

  /// \brief Converts the model configured on this object (implemented outside this header).
  Status Convert();

  /// \brief Converts the configured model and returns the result as a raw buffer.
  /// \param[out] data_size Presumably receives the size of the returned buffer — confirm.
  /// NOTE(review): who owns / frees the returned pointer is not specified in this header.
  void *Convert(size_t *data_size);

  /// \brief Converts the given model in one call.
  /// Arguments mirror the std::string constructor above.
  inline Status Convert(converter::FmkType fmk_type, const std::string &model_file, const std::string &output_file = "",
                        const std::string &weight_file = "");

 private:
  // std::vector<char> counterparts of the public std::string API. The inline public overloads defined at
  // the end of this header convert their arguments and forward to these.
  Converter(converter::FmkType fmk_type, const std::vector<char> &model_file, const std::vector<char> &output_file,
            const std::vector<char> &weight_file);
  void SetConfigFile(const std::vector<char> &config_file);
  std::vector<char> GetConfigFileChar() const;
  void SetConfigInfo(const std::vector<char> &section, const std::map<std::vector<char>, std::vector<char>> &config);
  std::map<std::vector<char>, std::map<std::vector<char>, std::vector<char>>> GetConfigInfoChar() const;
  void SetInputShape(const std::map<std::vector<char>, std::vector<int64_t>> &input_shape);
  std::map<std::vector<char>, std::vector<int64_t>> GetInputShapeChar() const;
  void SetDecryptKey(const std::vector<char> &key);
  void SetDecryptMode(const std::vector<char> &mode);
  std::vector<char> GetDecryptModeChar() const;
  void SetEncryptKey(const std::vector<char> &key);
  void SetDevice(const std::vector<char> &device);
  std::vector<char> GetDeviceChar();
  void SetProvider(const std::vector<char> &provider);
  std::vector<char> GetProviderChar();
  void SetChipName(const std::vector<char> &chip_name);
  std::vector<char> GetChipNameChar();
  Status Convert(converter::FmkType fmk_type, const std::vector<char> &model_file, const std::vector<char> &output_file,
                 const std::vector<char> &weight_file);

  // Opaque converter state (ConverterPara is only forward-declared here).
  std::shared_ptr<ConverterPara> data_;
};
Converter::Converter(converter::FmkType fmk_type, const std::string &model_file, const std::string &output_file,
const std::string &weight_file)
: Converter(fmk_type, StringToChar(model_file), StringToChar(output_file), StringToChar(weight_file)) {}
void Converter::SetConfigFile(const std::string &config_file) { SetConfigFile(StringToChar(config_file)); }
// Reads the config file name through the private std::vector<char> getter and converts it back.
std::string Converter::GetConfigFile() const {
  const std::vector<char> config_file_chars = GetConfigFileChar();
  return CharToString(config_file_chars);
}
// std::string convenience overload of SetConfigInfo. Converts the section name with StringToChar and
// the key/value map with MapStringToVectorChar, then forwards to the private overload that takes
// std::vector<char> / std::map<std::vector<char>, std::vector<char>>.
// (Fix: the reference parameter was corrupted to "§ion" — an HTML-entity mangling of "&section" —
// which made this definition ill-formed.)
void Converter::SetConfigInfo(const std::string &section, const std::map<std::string, std::string> &config) {
  SetConfigInfo(StringToChar(section), MapStringToVectorChar(config));
}
// Fetches the nested char-keyed configuration map and converts every level back to std::string.
std::map<std::string, std::map<std::string, std::string>> Converter::GetConfigInfo() const {
  const auto char_config = GetConfigInfoChar();
  return MapMapCharToString(char_config);
}
// Re-keys the shape map by std::vector<char> and hands it to the private overload.
void Converter::SetInputShape(const std::map<std::string, std::vector<int64_t>> &input_shape) {
  const auto char_keyed_shape = MapStringToChar(input_shape);
  SetInputShape(char_keyed_shape);
}
// Reads the char-keyed shape map and converts its keys back to std::string.
std::map<std::string, std::vector<int64_t>> Converter::GetInputShape() const {
  const auto char_keyed_shape = GetInputShapeChar();
  return MapCharToString(char_keyed_shape);
}
void Converter::SetDecryptKey(const std::string &key) { SetDecryptKey(StringToChar(key)); }
void Converter::SetDecryptMode(const std::string &mode) { SetDecryptMode(StringToChar(mode)); }
// Reads the decryption mode via the private getter and converts it to std::string.
std::string Converter::GetDecryptMode() const {
  const std::vector<char> mode_chars = GetDecryptModeChar();
  return CharToString(mode_chars);
}
void Converter::SetEncryptKey(const std::string &key) { SetEncryptKey(StringToChar(key)); }
void Converter::SetDevice(const std::string &device) { SetDevice(StringToChar(device)); }
// Reads the device name via the private getter and converts it to std::string.
std::string Converter::GetDevice() {
  const std::vector<char> device_chars = GetDeviceChar();
  return CharToString(device_chars);
}
void Converter::SetProvider(const std::string &provider) { SetProvider(StringToChar(provider)); }
// Reads the provider name via the private getter and converts it to std::string.
std::string Converter::GetProvider() {
  const std::vector<char> provider_chars = GetProviderChar();
  return CharToString(provider_chars);
}
void Converter::SetChipName(const std::string &chip_name) { SetChipName(StringToChar(chip_name)); }
// Reads the chip name via the private getter and converts it to std::string.
std::string Converter::GetChipName() {
  const std::vector<char> chip_name_chars = GetChipNameChar();
  return CharToString(chip_name_chars);
}
// One-shot std::string overload of Convert: converts the three file arguments to std::vector<char>
// and returns whatever Status the private std::vector<char> overload produces.
Status Converter::Convert(converter::FmkType fmk_type, const std::string &model_file, const std::string &output_file,
                          const std::string &weight_file) {
  const std::vector<char> model_chars = StringToChar(model_file);
  const std::vector<char> output_chars = StringToChar(output_file);
  const std::vector<char> weight_chars = StringToChar(weight_file);
  return Convert(fmk_type, model_chars, output_chars, weight_chars);
}
}
#endif