"""
-------------------------------------------------------------------------
This file is part of the MindStudio project.
Copyright (c) 2025 Huawei Technologies Co.,Ltd.
MindStudio is licensed under Mulan PSL v2.
You can use this software according to the terms and conditions of the Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:
http://license.coscl.org.cn/MulanPSL2
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
-------------------------------------------------------------------------
"""
import asyncio
import logging
import time
import warnings
from enum import Enum
from abc import ABC, abstractmethod
from utils import sys_util
from utils.logutil import server_logger
import common
class ServerState(Enum):
    """Lifecycle states of a managed server, as driven by ``BaseServer``."""

    # Freshly constructed; ``start()`` has not been called yet.
    Initial = 0
    # ``start()`` is in progress and the start-up checks have not passed yet.
    Starting = 1
    # Started and idle: the only state from which a server can be picked.
    Pending = 2
    # Picked and in use; ``release()`` moves it back to ``Pending``.
    Working = 3
    # Terminated (explicitly, or after the start-up checks failed).
    Exited = 4
# Number of start-up checks BaseServer.start() performs before giving up.
DEFAULT_CHECK_SERVER_PS_MAX_RETRIES = 5
# Seconds BaseServer.start() sleeps before each start-up check.
DEFAULT_CHECK_SERVER_PS_DELAY = 1
# Sentinel port value: the manager replaces it with a free port from its configured range.
DEFAULT_RANDOM_SERVER_PORT = -1
class ServerEvent(Enum):
    """Events passed to a server's ``server_event_callback``."""

    # Start-up checks passed; the server is now pending (idle).
    server_started_successfully = 0
    # Every start-up check failed and the server was terminated.
    server_started_failed = 1
    # NOTE(review): not emitted in this module - presumably raised by subclasses
    # when a server shuts itself down after being idle; confirm.
    server_exited_on_idle = 2
    # Emitted whenever the server is terminated through ``_terminate``.
    server_exited_by_command = 3
    # NOTE(review): not emitted in this module - presumably an abnormal exit; confirm.
    server_exited_on_error = 4
    # Pick succeeded / was refused or raised.
    server_picked_up_successfully = 5
    server_picked_up_failed = 6
    # Release succeeded / was refused or raised.
    server_released_successfully = 7
    server_released_failed = 8
    # The server's state is about to change; the message carries "from X to Y".
    server_state_changed = 9
    # NOTE(review): not emitted in this module - presumably an idle-timeout signal; confirm.
    server_reached_max_idle_time = 10
class BaseServer(ABC):
    """Abstract base class for one managed server bound to ``host:port``.

    State machine implemented here::

        Initial --start()--> Starting --checks pass--> Pending
        Pending --pick()--> Working --release()--> Pending
        any state --terminate() / failed start checks--> Exited

    Subclasses provide the process-specific hooks ``_initialize``,
    ``_do_start``, ``_check_start``, ``_do_pick``, ``_do_release`` and
    ``_do_terminate``. Lifecycle changes are reported through
    ``server_event_callback(event, server, msg="")``.
    """

    def __init__(self,
                 port: int,
                 host: str = '127.0.0.1',
                 server_event_callback=lambda event, server, msg="": None):
        """Create the server object in the ``Initial`` state.

        Args:
            port: Port the server will use; validated with
                ``sys_util.is_port_in_user_space``.
            host: Host part of the server's network location.
            server_event_callback: Callable invoked as
                ``callback(event, server, msg="")`` for every ``ServerEvent``.

        Raises:
            ValueError: If ``port`` is rejected by the port validation.
        """
        self.state = ServerState.Initial
        self.port = port
        if not sys_util.is_port_in_user_space(port):
            raise ValueError(f'Port {port} is not valid')
        self.host = host
        self.protocol = 'ws/http'
        self._async_lock = asyncio.Lock()
        self._subprocess = None
        self._initialized = False
        self._server_event_callback = server_event_callback
        self._initialize()

    def start(self):
        """Start the server and block until it is confirmed running or gives up.

        The server must be in the ``Initial`` state and initialized. After
        ``_do_start`` succeeds, ``_check_start`` is polled up to
        ``DEFAULT_CHECK_SERVER_PS_MAX_RETRIES`` times, sleeping
        ``DEFAULT_CHECK_SERVER_PS_DELAY`` seconds before each check.

        Returns:
            bool: ``True`` when the server reached ``Pending``; ``False`` otherwise.
        """
        if not self._is_initial():
            err = f"Server {self.net_location()} state: {self.state} cannot be start again."
            server_logger.warning(err)
            return False
        if not self._initialized:
            err = f"Server {self.net_location()} not initialized"
            server_logger.warning(err)
            return False
        self._set_state(ServerState.Starting)
        if not self._do_start():
            # NOTE(review): the state stays ``Starting`` and no failure event is
            # emitted on this path - confirm whether that is intended.
            server_logger.error(f"Server {self.net_location()} process failed to start.")
            return False
        # The trailing space keeps the two sentences apart in the log output.
        server_logger.info(
            f"Server {self.net_location()} is in the process of starting up. "
            f"Beginning to check whether it has started successfully.")
        for i in range(DEFAULT_CHECK_SERVER_PS_MAX_RETRIES):
            time.sleep(DEFAULT_CHECK_SERVER_PS_DELAY)
            is_running = self._check_start()
            server_logger.info(
                f"Performing {i + 1}/{DEFAULT_CHECK_SERVER_PS_MAX_RETRIES} checks")
            if is_running:
                self._set_state(ServerState.Pending)
                self._server_event_callback(ServerEvent.server_started_successfully, self)
                server_logger.info(f'Server {self.net_location()} start successfully')
                return True
            server_logger.warning(
                f"Completed {i + 1}/{DEFAULT_CHECK_SERVER_PS_MAX_RETRIES} checks. Check result: {is_running}")
        # All checks failed: clean up whatever was started and report the failure.
        self._terminate()
        self._server_event_callback(ServerEvent.server_started_failed, self)
        server_logger.error(f'Server {self.net_location()} failed to start.')
        return False

    def get_state(self):
        """Return the current ``ServerState`` without taking the async lock."""
        return self.state

    async def terminate(self, *args, **kwargs):
        """Terminate the server under the async lock; extra arguments are ignored."""
        async with self._async_lock:
            self._terminate()

    async def pick(self):
        """Move a pending server to ``Working`` under the async lock.

        Returns:
            This server on success, ``None`` if it could not be picked.
        """
        async with self._async_lock:
            # Propagate the helper's result, mirroring ``release()``.
            return self._pick()

    async def release(self):
        """Move a working server back to ``Pending`` under the async lock.

        Returns:
            bool: ``True`` on success, ``False`` otherwise.
        """
        async with self._async_lock:
            return self._release()

    async def is_pending(self):
        """Return whether the state is ``Pending`` (checked under the async lock)."""
        async with self._async_lock:
            return self.state == ServerState.Pending

    async def is_working(self):
        """Return whether the state is ``Working`` (checked under the async lock)."""
        async with self._async_lock:
            return self.state == ServerState.Working

    async def is_exited(self):
        """Return whether the state is ``Exited`` (checked under the async lock)."""
        async with self._async_lock:
            return self.state == ServerState.Exited

    async def is_initial(self):
        """Return whether the state is ``Initial`` (checked under the async lock)."""
        async with self._async_lock:
            return self.state == ServerState.Initial

    async def is_starting(self):
        """Return whether the state is ``Starting`` (checked under the async lock)."""
        async with self._async_lock:
            return self.state == ServerState.Starting

    def net_location(self):
        """Return the ``"host:port"`` string identifying this server."""
        return f"{self.host}:{self.port}"

    async def set_state(self, state: ServerState):
        """Set the state under the async lock, emitting a state-changed event."""
        async with self._async_lock:
            self._set_state(state)

    def can_be_remove(self):
        """Return whether the server is in ``Initial``, ``Exited`` or ``Pending``."""
        return self._is_initial() or self._is_exited() or self._is_pending()

    @abstractmethod
    def _initialize(self):
        """Prepare the server; this base body marks the instance as initialized."""
        self._initialized = True

    @abstractmethod
    def _check_start(self) -> bool:
        """Return whether the started server is actually running."""
        ...

    @abstractmethod
    def _do_terminate(self):
        """Stop the underlying server."""
        ...

    @abstractmethod
    def _do_start(self):
        """Launch the underlying server; a falsy return means the launch failed."""
        ...

    @abstractmethod
    def _do_pick(self):
        """Hook executed when the server is picked; may raise to veto the pick."""
        ...

    @abstractmethod
    def _do_release(self):
        """Hook executed when the server is released; may raise to veto the release."""
        ...

    def _terminate(self):
        """Stop the server, mark it ``Exited`` and emit ``server_exited_by_command``."""
        self._do_terminate()
        self._set_state(ServerState.Exited)
        self._server_event_callback(ServerEvent.server_exited_by_command, self)

    def _pick(self):
        """Lock-free implementation of ``pick``; returns ``self`` or ``None``."""
        if not self._is_pending():
            self._server_event_callback(ServerEvent.server_picked_up_failed, self,
                                        f"Cannot pick server: {self.net_location()} in state: {self.state}")
            return None
        try:
            self._do_pick()
            self._set_state(ServerState.Working)
        except Exception as e:
            self._server_event_callback(ServerEvent.server_picked_up_failed, self,
                                        f"Pick server: {self.net_location()} failed: {e}")
            return None
        self._server_event_callback(ServerEvent.server_picked_up_successfully, self)
        return self

    def _release(self):
        """Lock-free implementation of ``release``; returns ``True`` or ``False``."""
        if not self._is_working():
            self._server_event_callback(ServerEvent.server_released_failed, self,
                                        f"Cannot release server: {self.net_location()} in state: {self.state}")
            return False
        try:
            self._do_release()
            self._set_state(ServerState.Pending)
        except Exception as e:
            # Report a *release* failure (the original reused the pick event here).
            self._server_event_callback(ServerEvent.server_released_failed, self,
                                        f"Release server: {self.net_location()} failed: {e}")
            return False
        self._server_event_callback(ServerEvent.server_released_successfully, self)
        return True

    def _set_state(self, state: ServerState):
        """Switch to ``state``, emitting ``server_state_changed`` when it differs."""
        if self.state == state:
            return
        # NOTE(review): the callback runs before the assignment, so listeners
        # still observe the old ``self.state`` - confirm this ordering is intended.
        self._server_event_callback(ServerEvent.server_state_changed, self,
                                    f"from {self.state} to {state}")
        self.state = state

    def _is_pending(self):
        """Return whether the state is ``Pending`` (no locking)."""
        return self.state == ServerState.Pending

    def _is_working(self):
        """Return whether the state is ``Working`` (no locking)."""
        return self.state == ServerState.Working

    def _is_exited(self):
        """Return whether the state is ``Exited`` (no locking)."""
        return self.state == ServerState.Exited

    def _is_starting(self):
        """Return whether the state is ``Starting`` (no locking)."""
        return self.state == ServerState.Starting

    def _is_initial(self):
        """Return whether the state is ``Initial`` (no locking)."""
        return self.state == ServerState.Initial
class BaseServerManager(ABC):
    """Abstract manager for a pool of ``BaseServer`` objects keyed by ``"host:port"``.

    Subclasses implement ``_do_add_server`` (create a server) and
    ``_do_remove_server`` (clean up after one has been dropped from the pool).
    Public coroutines serialize access to the pool through an ``asyncio.Lock``.
    """

    def __init__(self, core_server_size: int = common.CORE_SERVER_SIZE, max_server_size: int = common.MAX_SERVER_SIZE,
                 random_port_range: tuple = common.PROFILER_SERVER_DEFAULT_PORT_RANGE):
        """Validate the pool limits and create an empty pool.

        Args:
            core_server_size: Number of servers ``start()`` brings up; must
                satisfy ``0 < core_server_size <= max_server_size``.
            max_server_size: Upper bound on the number of pooled servers.
            random_port_range: ``(low, high)`` inclusive range used both for
                random port selection and for registration checks.

        Raises:
            ValueError: If the sizes are inconsistent, the range is not a
                2-tuple, or the range is smaller than ``core_server_size``.
        """
        if not 0 < core_server_size <= max_server_size:
            raise ValueError(f'core_server_size {core_server_size} is not valid')
        if len(random_port_range) != 2:
            raise ValueError("Invalid port range, illegal range tuple")
        if random_port_range[1] - random_port_range[0] + 1 < core_server_size:
            raise ValueError("Invalid port range, the number of available values should be "
                             "at least the number of the core server size.")
        self._server_map = {}
        self._max_server_map_size = max_server_size
        self._core_server_size = core_server_size
        self._random_port_range = random_port_range
        self._lock = asyncio.Lock()
        # Strong references to fire-and-forget tasks: the event loop itself only
        # keeps weak references, so an unreferenced task may be collected mid-run.
        self._background_tasks = set()

    async def start(self):
        """Bring up servers until the pool holds ``core_server_size`` of them."""
        async with self._lock:
            server_logger.info("Trying to start the profiler server manager...")
            server_logger.info(f"Trying to start {self._core_server_size} core servers.")
            # NOTE(review): a persistently failing ``_add_server`` is retried
            # without limit while the lock is held - confirm this is intended.
            while len(self._server_map) < self._core_server_size:
                server_logger.info(f"Starting the {len(self._server_map) + 1}/{self._core_server_size} server")
                if not await self._add_server(is_core_server=True):
                    server_logger.warning(
                        f"Failed to add the {len(self._server_map) + 1}/{self._core_server_size} server.")

    async def add_server(self, port: int = DEFAULT_RANDOM_SERVER_PORT, host: str = '127.0.0.1', *args, **kwargs):
        """Add a non-core server under the lock; returns ``True`` on success."""
        async with self._lock:
            return await self._add_server(port, host, False, *args, **kwargs)

    async def is_available_port_for_register(self, port: int):
        """Return whether ``port`` is inside the port range and not yet registered."""
        async with self._lock:
            return self._is_available_port_register(port)

    async def register_server(self, server: BaseServer):
        """Register ``server`` in the pool under the lock; returns ``True`` on success."""
        async with self._lock:
            return self._register_server(server)

    async def pick_server_by_location(self, server_net_location):
        """Pick the server at ``server_net_location``.

        Returns:
            The result of the server's ``pick()``, or ``False`` when no server
            is registered at that location.
        """
        async with self._lock:
            return await self._pick_server_by_location(server_net_location)

    async def release_server_by_location(self, server_net_location):
        """Release the server at ``server_net_location``; ``False`` if unknown or refused."""
        async with self._lock:
            return await self._release_server_by_location(server_net_location)

    async def find_available_random_port(self) -> int:
        """Find a free port inside the configured range.

        Raises:
            RuntimeError: If no port was found within ``max_server_size * 3`` attempts.
        """
        random_port = await sys_util.find_available_port_in_range(*self._random_port_range, host="127.0.0.1",
                                                                  max_try_times=self._max_server_map_size * 3)
        if random_port == -1:
            raise RuntimeError(
                f"Unable to find an available random port after {self._max_server_map_size * 3} attempts.")
        return random_port

    async def get_idle_server(self, auto_create: bool = False):
        """Return the first pending server, or ``None`` when there is none.

        Args:
            auto_create: When no server is idle, schedule the creation of a new
                one in the background. The call still returns ``None``.
        """
        async with self._lock:
            for server in self._server_map.values():
                if await server.is_pending():
                    return server
            server_logger.info("There is no idle server.")
            if auto_create:
                server_logger.info("Trying to start a new server in background.")
                create_task = asyncio.create_task(self._add_server())
                # Keep the task alive until it finishes and surface its outcome.
                self._background_tasks.add(create_task)
                create_task.add_done_callback(self._on_background_add_done)
        return None

    async def remove_server(self, server_net_location: str):
        """Terminate and drop the server at ``server_net_location``.

        Raises:
            RuntimeError: If removal would shrink the pool below the core size,
                or the location is not registered.
            RuntimeWarning: If the server is currently pending.
        """
        async with self._lock:
            server_count = len(self._server_map)
            if server_count < self._core_server_size + 1:
                raise RuntimeError(
                    f"Cannot remove server: {server_net_location}, there are only {self._core_server_size} servers.")
            if server_net_location not in self._server_map:
                raise RuntimeError(f"Cannot remove server[{server_net_location}], not found.")
            # ``is_pending`` is a coroutine function: without ``await`` the
            # (always truthy) coroutine object made this check fire every time.
            # NOTE(review): ``BaseServer.can_be_remove`` treats pending servers
            # as removable, which contradicts this guard - confirm the intent.
            if await self._server_map[server_net_location].is_pending():
                raise RuntimeWarning(f"Try to remove server[{server_net_location}] in pending.")
            server = self._server_map.pop(server_net_location)
            await server.terminate()
            self._do_remove_server(server)

    @abstractmethod
    def _do_add_server(self, port: int, host: str = '127.0.0.1', *args, **kwargs) -> bool:
        """Create a server on ``host:port``; return ``True`` on success."""
        ...

    @abstractmethod
    def _do_remove_server(self, server: BaseServer):
        """Clean up after ``server`` has been removed from the pool."""
        ...

    def _on_background_add_done(self, task):
        """Drop a finished background task and log the error it raised, if any."""
        self._background_tasks.discard(task)
        if task.cancelled():
            return
        error = task.exception()
        if error is not None:
            server_logger.warning(f"Background server creation failed: {error}")

    def _is_available_port_register(self, port: int, host: str = '127.0.0.1'):
        """Return whether ``host:port`` is in the port range and not yet registered."""
        if not port:
            return False
        is_in_range = self._random_port_range[0] <= port <= self._random_port_range[1]
        is_unregistered = f"{host}:{port}" not in self._server_map
        return is_in_range and is_unregistered

    def _remove_server(self, server_net_location: str):
        """Drop a removable server from the pool without terminating it.

        Returns:
            bool: ``True`` if the server was removed; ``False`` if it is unknown
            or its state does not allow removal.
        """
        if server_net_location not in self._server_map:
            server_logger.error(f"Cannot remove server[{server_net_location}], not found.")
            return False
        server = self._server_map[server_net_location]
        if not server.can_be_remove():
            server_logger.warning(f"Try to remove server[{server_net_location}] in state {server.get_state()}.")
            return False
        server = self._server_map.pop(server_net_location)
        self._do_remove_server(server)
        return True

    async def _add_server(self,
                          port: int = DEFAULT_RANDOM_SERVER_PORT,
                          host: str = '127.0.0.1',
                          is_core_server: bool = False,
                          *args, **kwargs):
        """Create one server via ``_do_add_server``; the caller handles locking.

        A random free port is chosen when ``port`` is the random-port sentinel
        or ``host:port`` is already registered.

        Returns:
            bool: ``True`` if the server was added; ``False`` if the pool is
            full or creation failed.
        """
        if port == DEFAULT_RANDOM_SERVER_PORT or f"{host}:{port}" in self._server_map:
            port = await self.find_available_random_port()
        current_server_size = len(self._server_map)
        if current_server_size >= self._max_server_map_size:
            server_logger.warning("Max server size reached, cannot add more server.")
            return False
        server_logger.info(f"Trying to create and run new server: {host}:{port}")
        if self._do_add_server(port, host, is_core_server, *args, **kwargs):
            server_logger.info(f"Server {host}:{port} added successfully.")
            return True
        server_logger.warning(f"Failed to added server: {host}:{port}")
        return False

    def _register_server(self, server: BaseServer):
        """Store ``server`` under its ``"host:port"`` key.

        Returns:
            bool: ``False`` if the location is unavailable or the server has
            already exited; ``True`` once stored.
        """
        # Check the same host the map key below is built from.
        if not self._is_available_port_register(server.port, server.host):
            server_logger.warning(f"Cannot register server[{server.host}:{server.port}], port is already registered.")
            return False
        if server.state == ServerState.Exited:
            server_logger.error(f"Cannot register server[{server.host}:{server.port}], server is already exited.")
            return False
        self._server_map[f"{server.host}:{server.port}"] = server
        return True

    async def _pick_server_by_location(self, server_net_location):
        """Lock-free implementation of ``pick_server_by_location``."""
        if server_net_location not in self._server_map:
            server_logger.error(f"Cannot find server[{server_net_location}]")
            return False
        return await self._server_map[server_net_location].pick()

    async def _release_server_by_location(self, server_net_location):
        """Lock-free implementation of ``release_server_by_location``."""
        if server_net_location not in self._server_map:
            server_logger.error(f"Cannot find server[{server_net_location}]")
            return False
        return await self._server_map[server_net_location].release()