* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include <cstring>
#include <memory>
#include <string>
#include <vector>
#include "include/api/status.h"
#include "include/api/context.h"
#include "include/api/model.h"
namespace mindspore {
namespace lite {
namespace {
// Maximum number of leading elements printed for each output tensor.
constexpr int kNumPrintOfOutData = 20;

// Fills the data buffer of every input tensor with the value 1.0f.
//
// @param inputs Model input tensors (the caller passes Model::GetInputs()).
//               Each tensor's buffer is obtained through MutableData() and
//               overwritten in place.
// @return kSuccess when every tensor was filled; kLiteError as soon as a
//         tensor yields a null buffer from MutableData().
//
// NOTE(review): the fill is only meaningful as "1.0" for float32 inputs; for
// other data types the buffer receives the raw byte pattern of 1.0f. Exactly
// DataSize() bytes are written from a source buffer at least that large.
Status FillInputData(const std::vector<mindspore::MSTensor> &inputs) {
  for (auto tensor : inputs) {
    auto input_data = tensor.MutableData();
    if (input_data == nullptr) {
      std::cerr << "MallocData for inTensor failed.\n";
      return kLiteError;
    }
    // Size the source by bytes rather than by ElementNum(): for element types
    // wider than float (e.g. int64) DataSize() exceeds ElementNum() * 4, and
    // copying DataSize() bytes out of an ElementNum()-float buffer overreads it.
    const size_t byte_size = tensor.DataSize();
    if (byte_size == 0) {
      continue;  // Nothing to write; also avoids memcpy from an empty vector's data().
    }
    const std::vector<float> ones((byte_size + sizeof(float) - 1) / sizeof(float), 1.0f);
    std::memcpy(input_data, ones.data(), byte_size);
  }
  return kSuccess;
}
}  // namespace
// Builds the MindIR model given in argv[1] on a CPU device whose provider and
// provider device are both "Tutorial", runs one inference on inputs filled by
// FillInputData(), and prints each output tensor's name, byte size, element
// count and at most kNumPrintOfOutData leading values.
//
// @param argc Argument count forwarded from main().
// @param argv Argument vector forwarded from main(); argv[1] is the model path.
// @return kSuccess on success; otherwise the failing Status returned by the
//         MindSpore API, or kLiteError for argument/context/output errors.
Status CompileAndRun(int argc, const char **argv) {
  if (argc < 2) {
    std::cerr << "Model file must be provided.\n";
    return kLiteError;
  }
  auto context = std::make_shared<mindspore::Context>();
  if (context == nullptr) {
    std::cerr << "New context failed while running.\n";
    return kLiteError;
  }
  auto &device_list = context->MutableDeviceInfo();
  auto device_info = std::make_shared<CPUDeviceInfo>();
  device_info->SetProvider("Tutorial");
  device_info->SetProviderDevice("Tutorial");
  device_list.push_back(device_info);

  const std::string model_file(argv[1]);
  mindspore::Model model;
  auto ret = model.Build(model_file, kMindIR, context);
  if (ret != kSuccess) {
    std::cerr << "build model failed.\n";
    return ret;  // Propagate the real failure status, like the branches below.
  }

  auto inputs = model.GetInputs();
  ret = FillInputData(inputs);
  if (ret != kSuccess) {
    std::cerr << "Generate Random Input Data failed.\n";
    return ret;
  }

  std::vector<MSTensor> outputs;
  ret = model.Predict(inputs, &outputs);
  if (ret != kSuccess) {
    std::cerr << "run model failed.\n";
    return ret;
  }

  for (auto &tensor : outputs) {
    const auto element_num = tensor.ElementNum();
    std::cout << "tensor name is:" << tensor.Name() << " tensor size is:" << tensor.DataSize()
              << " tensor elements num is:" << element_num << std::endl;
    // NOTE(review): outputs are interpreted as float32; confirm for non-float models.
    auto out_data = std::static_pointer_cast<const float>(tensor.Data());
    if (out_data == nullptr) {
      std::cerr << "get output data failed.\n";
      return kLiteError;
    }
    std::cout << "output data is:";
    // Strict '<' so that exactly kNumPrintOfOutData values are printed at most.
    for (int i = 0; i < element_num && i < kNumPrintOfOutData; i++) {
      std::cout << out_data.get()[i] << " ";
    }
    std::cout << std::endl;
  }
  return kSuccess;
}
}
}
// Program entry point: forwards the command line to
// mindspore::lite::CompileAndRun and translates the resulting Status into a
// process exit code (0 on success, -1 on any failure).
int main(int argc, const char **argv) {
  const auto status = mindspore::lite::CompileAndRun(argc, argv);
  const bool failed = status != mindspore::kSuccess;
  if (!failed) {
    std::cout << "run success.\n";
    return 0;
  }
  std::cerr << "run failed.\n";
  return -1;
}