"""
RESTful API resource for model training; provides the POST method that retrains the workload-characterization model.
"""
import logging
import os
import re
import shutil

from flask import abort
from flask_restful import Resource

from analysis.optimizer.workload_characterization import WorkloadCharacterization
from analysis.engine.parser import TRAIN_POST_PARSER
from analysis.default_config import TRAINING_MODEL_PATH
# Logger for this module, named after the module via __name__.
LOGGER = logging.getLogger(__name__)
class Training(Resource):
    """Flask-RESTful resource that retrains the workload-characterization model.

    POST arguments are read through ``TRAIN_POST_PARSER`` using the field
    names stored in the class attributes below.
    """

    # Request field whose value is passed to WorkloadCharacterization(...).
    model_path = "modelpath"
    # Request field holding the file name of the model to produce.
    model_name = "modelname"
    # Request field whose value is passed as the data argument of retrain().
    data_path = "datapath"

    def post(self):
        """Retrain the characterization model and save it under a new name.

        :return: ``({}, 200)`` on success; ``(message, 400)`` when the model
            name is invalid or a model with that name already exists.
        :raises HTTPException: raised by ``flask.abort(500)`` when creating
            the characterization object or retraining fails.
        """
        args = TRAIN_POST_PARSER.parse_args()
        LOGGER.info(args)
        model_path = args.get(self.model_path)
        model_name = args.get(self.model_name)
        data_path = args.get(self.data_path)

        valid, err = valid_model_name(model_name)
        if not valid:
            return "Illegal model name provide: {}".format(err), 400

        # os.path.join works whether or not TRAINING_MODEL_PATH ends with a
        # separator. model_name was validated above, so it cannot contain
        # path separators or "..".
        output_path = os.path.join(TRAINING_MODEL_PATH, model_name)
        # Reject duplicates before doing any characterization setup work.
        if os.path.exists(output_path):
            return "File already exists!", 400

        try:
            # Construct inside the try so constructor failures are logged and
            # reported through the same 500 path as retraining failures.
            characterization = WorkloadCharacterization(model_path)
            characterization.retrain(data_path, output_path)
        except Exception:  # pylint: disable=broad-except
            # Keep the traceback in the log; the client only sees a generic 500.
            LOGGER.exception("failed to retrain model %s", model_name)
            abort(500)
        return {}, 200
# File stems of the built-in models; a retrained model must not reuse them.
_RESERVED_MODEL_STEMS = frozenset((
    "scaler",
    "aencoder",
    "tencoder",
    "default_clf",
    "total_clf",
    "throughput_performance_clf",
))
# ASCII letters/digits, with underscores allowed only between them (never as
# the first or last character). \Z anchors at the true end of the string,
# unlike "$", which would also accept a trailing newline.
_MODEL_STEM_RE = re.compile(r"[A-Za-z0-9]+(?:_+[A-Za-z0-9]+)*\Z")


def valid_model_name(name):
    """Check whether ``name`` is an acceptable file name for a retrained model.

    A valid name has the extension ``.m`` and a stem that is not a reserved
    built-in model name. The stem must consist only of ASCII letters, digits
    and underscores, and an underscore may not be its first or last character.

    :param name: requested model file name, e.g. ``"my_model.m"``.
    :return: ``(True, None)`` if the name is valid, otherwise
        ``(False, reason)`` with a human-readable reason.
    """
    # A missing request field arrives as None; os.path.splitext(None) would
    # raise TypeError and surface as a 500 instead of a validation error.
    if name is None:
        return False, "model name should be provided"
    file_name, file_ext = os.path.splitext(name)
    if file_ext != ".m":
        return False, "the ext of model name should be .m"
    if file_name in _RESERVED_MODEL_STEMS:
        return False, "model name cannot be set as default_clf/scaler/aencoder/tencoder/throughput_performance_clf/total_clf"
    if _MODEL_STEM_RE.match(file_name) is None:
        return False, "model name cannot contains special character"
    return True, None