"""
-------------------------------------------------------------------------
This file is part of the Vision SDK project.
Copyright (c) 2025 Huawei Technologies Co.,Ltd.
Vision SDK is licensed under Mulan PSL v2.
You can use this software according to the terms and conditions of the Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:
http://license.coscl.org.cn/MulanPSL2
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
-------------------------------------------------------------------------
Description: ai_server flask app.
Author: Vision SDK
Create: 2024
History: NA
"""
import base64
import os
import time
import uuid
from http import HTTPStatus
from functools import wraps
from flask import Flask
from flask import jsonify
from flask import request
from error_code import ErrorCode, ErrorMessage
from infer_request import InferRequest
from server_options_and_logger import logger as logging
from server_options_and_logger import server_option_instance
from tensor import Tensor, TensorKind
from token_bucket import token_bucket_instance
# Control characters that must not appear in the client address; ip_check
# rejects any request whose remote address contains one of them.
# NOTE: the "\u000D" ... "\u0008" entries are the same characters as the
# "\r", "\n", "\f", "\v", "\t", "\b" escapes listed before them; only "\u007F"
# (DEL) adds a new character.
INVALID_CHARS = [
"\n", "\f", "\r", "\b", "\t", "\v", "\u000D", "\u000A", "\u000C", "\u000B", "\u0009", "\u0008", "\u007F"
]
# Multiplier applied to the configured max_content_length to obtain a byte
# count (presumably max_content_length is expressed in KiB -- TODO confirm).
BYTE_RATE = 1024
# Upper bound, in bytes, used both as Flask's MAX_CONTENT_LENGTH and by the
# content_length_limit decorator (which also counts the header text).
MAX_BYTE_SIZE = server_option_instance.max_content_length * BYTE_RATE
# Maximum time, in seconds, an infer request handler waits for its request to
# be marked as processed.
TIME_OUT = 60
class StreamServerServiceApp:
    """Build the Flask application that exposes the inference server over HTTP.

    The instance keeps a reference to the infer server manager (queues, readiness
    state) and to the server options (per-task infer configs, server info), and
    registers the ``/v2`` routes on a Flask app created by ``base_app`` /
    ``create_app``.
    """

    # Number of keys every element of the request's "inputs" array must have:
    # name, id, shape, format, dataType and data.
    EXPECT_JSON_SIZE = 6

    def __init__(self, infer_server_manager, server_options):
        """Store the collaborators used by the route handlers.

        Args:
            infer_server_manager: object providing ``infer_configs``,
                ``request_queues`` and the live/ready query methods.
            server_options: object providing ``infer_configs``,
                ``server_info_json`` and ``max_request_rate``.
        """
        self.infer_server_manager = infer_server_manager
        self.server_options = server_options
        self.max_request_rate = "%s/minute" % str(server_options.max_request_rate)

    @staticmethod
    def fail_reply(error_code, status_code):
        """Return a JSON failure response.

        Args:
            error_code: key into ``ErrorMessage``; echoed as ``errorCode``.
            status_code: HTTP status set on the response.
        """
        response = jsonify(
            {'isSuccess': False, 'errorCode': error_code, "errorMsg": ErrorMessage[error_code]})
        response.status_code = status_code
        return response

    @staticmethod
    def _success_reply(output):
        """Return a JSON success response (HTTP 200) carrying *output* under ``outputs``."""
        response = jsonify(
            {'isSuccess': True, 'errorCode': ErrorCode.SUCCESS, "errorMsg": ErrorMessage[ErrorCode.SUCCESS],
             "outputs": output})
        response.status_code = HTTPStatus.OK
        return response

    @staticmethod
    def _reject(message):
        """Log *message* as an error and raise it as a ``ValueError``."""
        logging.error(message)
        raise ValueError(message)

    def base_app(self):
        """Create the Flask app with the read-only (GET) endpoints registered.

        Returns:
            The Flask application with the address check hook and the server,
            stream and model ``ready`` / ``config`` routes.
        """
        app = Flask(__name__)
        app.config["MAX_CONTENT_LENGTH"] = MAX_BYTE_SIZE
        app.config['SERVER_NAME'] = None

        @app.before_request
        def ip_check():
            """Reject requests whose remote address contains a control character."""
            remote_addr = request.remote_addr
            # remote_addr may be None when the server cannot determine the peer;
            # there is nothing to scan in that case, so let the request continue
            # instead of failing with a TypeError on the ``in`` test.
            if remote_addr is None:
                return None
            if any(invalid_char in remote_addr for invalid_char in INVALID_CHARS):
                logging.error(ErrorMessage[ErrorCode.INTERNAL_ERROR])
                return StreamServerServiceApp.fail_reply(ErrorCode.INTERNAL_ERROR, HTTPStatus.BAD_REQUEST)
            # Returning None lets Flask continue with normal request handling.
            return None

        @app.route('/v2', methods=['GET'])
        @app.route('/v2/', methods=['GET'])
        @log_request_info
        @content_length_limit
        @request_rate_limit
        def get_server_options():
            """Return the server information JSON."""
            return self._success_reply(self.server_options.server_info_json)

        @app.route('/v2/live', methods=['GET'])
        @app.route('/v2/live/', methods=['GET'])
        @log_request_info
        @content_length_limit
        @request_rate_limit
        def get_server_live():
            """Report the manager's liveness flag."""
            return self._success_reply({"isLive": self.infer_server_manager.get_server_live()})

        @app.route('/v2/ready', methods=['GET'])
        @app.route('/v2/ready/', methods=['GET'])
        @log_request_info
        @content_length_limit
        @request_rate_limit
        def get_server_ready():
            """Report the manager's readiness flag."""
            return self._success_reply({"isReady": self.infer_server_manager.get_server_ready()})

        @app.route('/v2/streams/<stream_name>/ready', methods=['GET'])
        @log_request_info
        @content_length_limit
        @request_rate_limit
        def get_stream_ready(stream_name):
            """Report readiness of the stream task named *stream_name*."""
            task_id = "streams%s" % stream_name
            if task_id not in self.server_options.infer_configs:
                logging.error("The task dose not exist")
                return self.fail_reply(ErrorCode.TARGET_NOT_EXISTS, HTTPStatus.BAD_REQUEST)
            return self._success_reply(self.infer_server_manager.get_stream_model_ready(task_id))

        @app.route('/v2/streams/<stream_name>/config', methods=['GET'])
        @log_request_info
        @content_length_limit
        @request_rate_limit
        def get_stream_config(stream_name):
            """Return the infer config of the stream task named *stream_name*."""
            task_id = "streams%s" % stream_name
            if task_id not in self.server_options.infer_configs:
                logging.error("The task dose not exist")
                return self.fail_reply(ErrorCode.TARGET_NOT_EXISTS, HTTPStatus.BAD_REQUEST)
            return self._success_reply(self.server_options.infer_configs[task_id])

        @app.route('/v2/models/<model_name>/ready', methods=['GET'])
        @log_request_info
        @content_length_limit
        @request_rate_limit
        def get_model_ready(model_name):
            """Report readiness of the model task named *model_name*."""
            task_id = "models%s" % model_name
            if task_id not in self.server_options.infer_configs:
                logging.error("The task dose not exist")
                # The failure reply must be returned; otherwise the handler falls
                # through and queries readiness of a task that does not exist.
                return self.fail_reply(ErrorCode.TARGET_NOT_EXISTS, HTTPStatus.BAD_REQUEST)
            return self._success_reply(self.infer_server_manager.get_stream_model_ready(task_id))

        @app.route('/v2/models/<model_name>/config', methods=['GET'])
        @log_request_info
        @content_length_limit
        @request_rate_limit
        def get_model_config(model_name):
            """Return the infer config of the model task named *model_name*."""
            task_id = "models%s" % model_name
            if task_id not in self.server_options.infer_configs:
                logging.error("The task dose not exist")
                # The failure reply must be returned; otherwise the lookup below
                # is attempted on the missing task_id.
                return self.fail_reply(ErrorCode.TARGET_NOT_EXISTS, HTTPStatus.BAD_REQUEST)
            return self._success_reply(self.server_options.infer_configs[task_id])

        return app

    def create_app(self):
        """Create the full Flask app: the ``base_app`` routes plus the infer endpoint.

        Returns:
            The Flask application, including ``POST /v2/<infer_type>/<name>/infer``.
        """
        app = self.base_app()

        @app.route('/v2/<infer_type>/<name>/infer', methods=['POST'])
        @log_request_info
        @content_length_limit
        @request_rate_limit
        def post_stream_infer(infer_type, name):
            """Validate an infer request, queue it and wait for its result."""
            if infer_type not in ("streams", "models"):
                logging.error("The task is not stream infer or model infer.")
                return self.fail_reply(ErrorCode.TARGET_NOT_EXISTS, HTTPStatus.BAD_REQUEST)
            task_id = "%s%s" % (infer_type, name)
            if task_id not in self.infer_server_manager.infer_configs:
                logging.error("The task dose not exist")
                return self.fail_reply(ErrorCode.TARGET_NOT_EXISTS, HTTPStatus.BAD_REQUEST)
            # NOTE(review): the comparison is exact, so a header such as
            # "application/json; charset=utf-8" is rejected, and the reply reuses
            # TARGET_NOT_EXISTS -- confirm both are intended.
            if request.headers.get("Content-Type") != "application/json":
                logging.error("The content type of request must be json")
                return self.fail_reply(ErrorCode.TARGET_NOT_EXISTS, HTTPStatus.BAD_REQUEST)
            try:
                # request.json raises on a malformed body; reading it inside the
                # try makes that case produce the INVALID_BODY reply as well.
                input_request = self._extract_input_json(request.json, task_id)
            except Exception as err_message:
                logging.error("Invalid Http request body, extract inputs item failed! Error message is %s", err_message)
                return self.fail_reply(ErrorCode.INVALID_BODY, HTTPStatus.BAD_REQUEST)
            request_queue = self.infer_server_manager.request_queues[task_id]
            if request_queue.full():
                logging.error("The stream infer queue is full, task is %s", task_id)
                return self.fail_reply(ErrorCode.CACHE_FULL, HTTPStatus.BAD_REQUEST)
            request_queue.push(input_request)

            # Poll once per second until the request is flagged as processed
            # (presumably by the consumer of the queue -- confirm) or TIME_OUT
            # seconds elapse. A monotonic clock keeps the measurement immune to
            # wall-clock adjustments.
            timed_out = False
            start_time = time.monotonic()
            while not input_request.is_processed:
                if time.monotonic() - start_time > TIME_OUT:
                    logging.error("Input request process time out!")
                    input_request.is_processed = True
                    timed_out = True
                    break
                time.sleep(1)
            # A timed-out request has no result to report, so it is answered as a
            # failure rather than as a success with whatever output is present.
            if timed_out or input_request.is_error:
                return self.fail_reply(ErrorCode.INFER_FAILED, HTTPStatus.BAD_REQUEST)
            return self._success_reply(input_request.get_output_json())

        return app

    def _extract_input_json(self, json_data, task_id):
        """Validate the request body and turn its first input into an InferRequest.

        Only the first element of ``json_data["inputs"]`` is used.

        Args:
            json_data: decoded JSON body of the request.
            task_id: key into ``self.server_options.infer_configs`` for the target task.

        Returns:
            An ``InferRequest`` for *task_id* holding one input tensor built from
            the request fields and the base64-decoded ``data``.

        Raises:
            ValueError: when the body or any field fails validation (each failure
                is also logged).
            Exception: whatever ``base64.b64decode`` or the tensor helpers raise.
        """
        if (not isinstance(json_data, dict) or not isinstance(json_data.get("inputs"), list)
                or not json_data["inputs"]):
            self._reject("Invalid Http request body, the request must has input fields and "
                         "inputs should be an valid array.")
        inputs = json_data["inputs"][0]
        if not isinstance(inputs, dict) or len(inputs) != self.EXPECT_JSON_SIZE:
            self._reject("Invalid input field, the input field can only contain 6 fields: name, id, shape, format,"
                         " dataType and data.")
        if not isinstance(inputs.get("name"), str):
            self._reject("Invalid input field, no name field is included or name field is not string type.")
        if not isinstance(inputs.get("dataType"), str):
            self._reject("Invalid input field, no dataType field is included or dataType field is not string type.")
        if not isinstance(inputs.get("id"), int):
            self._reject("Invalid input field, no id field is included or id field is not integer type.")
        if not isinstance(inputs.get("format"), str):
            self._reject("Invalid input field, no format field is included or format field is not string type.")
        if not isinstance(inputs.get("shape"), list):
            self._reject("No shape field or shape field is not list type.")
        if not all(isinstance(dim, int) for dim in inputs["shape"]):
            self._reject("Shape field is not integer type.")
        if not isinstance(inputs.get("data"), str):
            self._reject("Invalid input field, no data field is included or data field is not string type.")

        decoded_data = base64.b64decode(inputs["data"])
        input_tensor = Tensor(name=inputs["name"], id=inputs["id"], format=inputs["format"], shape=inputs["shape"],
                              data_type_str=inputs["dataType"])
        if not input_tensor.check_tensor_data(TensorKind.CLIENT_KIND):
            self._reject("Check input tensor data type failed!")
        if not input_tensor.check_tensor_data_size(len(decoded_data)):
            self._reject("Check input tensor data size failed!")
        input_tensor.set_data(len(decoded_data), decoded_data)
        input_request = InferRequest(self.server_options.infer_configs[task_id])
        input_request.add_input(inputs["id"], input_tensor)
        return input_request
def content_length_limit(func):
    """Decorate a view so oversized requests are refused before it runs.

    The size is the summed length of every header name and value plus the
    length of the request body; when it exceeds MAX_BYTE_SIZE the view is
    skipped and a LARGE_CONTENT failure reply (HTTP 413) is returned.
    """
    @wraps(func)
    def guarded(*args, **kwargs):
        # Each item yielded by request.headers is iterated part by part
        # (name and value), and the lengths of all parts are added up.
        header_size = sum(len(part) for header in request.headers for part in header)
        total_size = header_size + len(request.get_data())
        if total_size <= MAX_BYTE_SIZE:
            return func(*args, **kwargs)
        logging.error("Illegal request, content length with head is larger than %s byte.", str(MAX_BYTE_SIZE))
        return StreamServerServiceApp.fail_reply(ErrorCode.LARGE_CONTENT, HTTPStatus.REQUEST_ENTITY_TOO_LARGE)
    return guarded
def request_rate_limit(func):
    """Decorate a view so it only runs when the token bucket grants one token.

    When no token can be consumed the view is skipped and a RATE_EXCEEDED
    failure reply (HTTP 400) is returned instead.
    """
    @wraps(func)
    def throttled(*args, **kwargs):
        if token_bucket_instance.consume_token(1):
            return func(*args, **kwargs)
        logging.error("Too many request.")
        return StreamServerServiceApp.fail_reply(ErrorCode.RATE_EXCEEDED, HTTPStatus.BAD_REQUEST)
    return throttled
def log_request_info(func):
    """Decorate a view so every call logs the caller's address and client id.

    The client id is taken from the ``uuid`` request header; a missing header
    or one that is not a valid UUID leaves the id as None (the latter is also
    logged as an error). The view is always invoked afterwards.
    """
    @wraps(func)
    def logged(*args, **kwargs):
        client_id = None
        raw_uuid = request.headers.get('uuid')
        if raw_uuid:
            try:
                client_id = uuid.UUID(raw_uuid)
            except ValueError:
                logging.error('The client id format is wrong.')
        logging.info("Remote IP: %s Client ID: %s - - %s %s",
                     request.remote_addr, client_id, request.method, request.path)
        return func(*args, **kwargs)
    return logged