import json
import logging
import os
import subprocess
from user_config_loader import UserConfig
from utils import convert_args_dict_to_list, resolve_with_retry
# Configure the root logger once at import time: INFO level, timestamped records.
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s'
)
# User-config keys that get_args_from_user_config forwards for every role.
COMMON_KEYS = ['model_path', 'serve_name', 'enable_ep', 'dp_rpc_port', 'server_port']
# Roles that additionally receive PREFILL_KEYS; DECODE_ROLE receives DECODE_KEYS.
PREFILL_ROLES = ['prefill', 'union']
DECODE_ROLE = 'decode'
PREFILL_KEYS = ['prefill_dp_size', 'prefill_tp_size']
DECODE_KEYS = ['decode_dp_size', 'decode_tp_size']
# Hardware type identifiers; get_kv_port_base picks the KV port base from them.
HARDWARE_TYPE_A2 = 'module-910b-8'
HARDWARE_TYPES_A3 = ['module-a3-16', 'module-a3-16-super-pod']
# Names of the engine arguments that this module fills in itself
# (from the environment or from derived values) rather than via the args map.
DATA_PARALLEL_ADDRESS_KEY = 'data_parallel_address'
HOST_KEY = 'host'
HEADLESS_KEY = 'headless'
DATA_PARALLEL_START_RANK_KEY = 'data_parallel_start_rank'
DATA_PARALLEL_SIZE_LOCAL_KEY = 'data_parallel_size_local'
KV_TRANSFER_CONFIG_KEY = 'kv_transfer_config'
# deploy_type value for which pull_engine adds a KV transfer config.
DEPLOY_TYPE_PD_SEPARATE = 'pd_separate'
# Pod index (trailing number of POD_NAME) that get_args_from_env treats as the primary pod.
POD_INDEX_PRIMARY = 0
# Per engine type: maps a user-config key to the engine argument name it is emitted as.
# Both the prefill and the decode size keys map onto the same engine argument;
# get_args_from_user_config selects which of the two applies based on the role.
engine_type_args_map = {
'vllm': {
'model_path': 'model',
'serve_name': 'served_model_name',
'prefill_dp_size': 'data_parallel_size',
'decode_dp_size': 'data_parallel_size',
'prefill_tp_size': 'tensor_parallel_size',
'decode_tp_size': 'tensor_parallel_size',
'enable_ep': 'enable_expert_parallel',
'dp_rpc_port': 'data_parallel_rpc_port',
'server_port': 'port'
}
}
def get_args_from_user_config(role, user_config: UserConfig) -> dict:
    """Translate the user configuration into engine arguments for one role.

    The attributes of ``user_config.engine_common_config`` are renamed through
    ``engine_type_args_map``. Common keys are always forwarded; the prefill
    keys are forwarded for roles in ``PREFILL_ROLES`` and the decode keys for
    ``DECODE_ROLE``. When the role has a deploy config, the per-pod share of
    the data-parallel size is stored under ``DATA_PARALLEL_SIZE_LOCAL_KEY``.
    Finally the role-specific engine config, if any, is merged on top and may
    override earlier entries.

    Args:
        role: Instance role, compared against ``PREFILL_ROLES`` and
            ``DECODE_ROLE``. Any other role only receives the common keys.
        user_config: Loaded user configuration.

    Returns:
        Mapping from engine argument name to value.

    Raises:
        ValueError: If the engine type is unsupported, a selected key has no
            mapping for the engine type, ``single_instance_pod_num`` is not a
            positive int, or the data-parallel size is not divisible by it.
    """
    engine_type = user_config.engine_common_config.engine_type
    if engine_type not in engine_type_args_map:
        raise ValueError(f"Unsupported engine type: {engine_type}")
    args_map = engine_type_args_map[engine_type]
    engine_common_config_dict = vars(user_config.engine_common_config)

    # Decide once which user-config keys apply to this role instead of
    # repeating the same mapping code in three branches.
    selected_keys = list(COMMON_KEYS)
    if role in PREFILL_ROLES:
        selected_keys += PREFILL_KEYS
    elif role == DECODE_ROLE:
        selected_keys += DECODE_KEYS

    result = {}
    for key, value in engine_common_config_dict.items():
        if key not in selected_keys:
            continue
        if key not in args_map:
            raise ValueError(f"Key '{key}' not found in args_map for engine type '{engine_type}'")
        result[args_map[key]] = value

    # Pick the role-specific inputs; roles outside both groups are done here.
    if role in PREFILL_ROLES:
        deploy_config = user_config.deploy_config.prefill
        dp_size_key, role_desc = 'prefill_dp_size', 'prefill'
        engine_config = user_config.prefill_engine_config
    elif role == DECODE_ROLE:
        deploy_config = user_config.deploy_config.decode
        dp_size_key, role_desc = 'decode_dp_size', 'decode'
        engine_config = user_config.decode_engine_config
    else:
        return result

    if deploy_config:
        single_instance_pod_num = deploy_config.single_instance_pod_num
        # Enforce what the message promises: reject negatives and non-ints
        # too, not just 0, so bad values fail here rather than in the
        # modulo / floor-division below.
        if not isinstance(single_instance_pod_num, int) or single_instance_pod_num <= 0:
            raise ValueError("single_instance_pod_num must be int greater than 0")
        data_parallel_size = engine_common_config_dict.get(dp_size_key, 1)
        if data_parallel_size % single_instance_pod_num != 0:
            raise ValueError(
                f"For {role} role, {dp_size_key} ({data_parallel_size}) must be divisible by {role_desc}.single_instance_pod_num ({single_instance_pod_num})")
        result[DATA_PARALLEL_SIZE_LOCAL_KEY] = data_parallel_size // single_instance_pod_num

    if engine_config:
        # Role-specific engine settings are applied last, so they win.
        result.update(engine_config)
    return result
def get_ip_of_pod0(user_config: UserConfig) -> str:
    """Resolve the IP address of pod 0 of the current instance.

    The hostname is assembled from the INFER_SERVICE_NAME,
    INFER_SERVICE_INDEX, INSTANCE_ROLE and INSTANCE_INDEX environment
    variables plus ``user_config.deploy_config.namespace`` and then passed to
    ``resolve_with_retry``.

    Args:
        user_config: Loaded user configuration; only the deploy namespace is read.

    Returns:
        The value returned by ``resolve_with_retry`` for the pod-0 hostname.

    Raises:
        ValueError: If a required environment variable is missing or the
            hostname cannot be resolved (``resolve_with_retry`` returned None).
    """
    env = {}
    # Same check order as the list of required variables, so the first
    # missing one is the one reported.
    for env_var in ('INFER_SERVICE_NAME', 'INFER_SERVICE_INDEX', 'INSTANCE_INDEX', 'INSTANCE_ROLE'):
        if env_var not in os.environ:
            raise ValueError(f"Required environment variable {env_var} is not set")
        env[env_var] = os.environ[env_var]

    namespace = user_config.deploy_config.namespace
    # The same "<service>-<service idx>-<role>-<instance idx>" stem appears in
    # both the pod name and the service name.
    stem = "-".join((
        env['INFER_SERVICE_NAME'],
        env['INFER_SERVICE_INDEX'],
        env['INSTANCE_ROLE'],
        env['INSTANCE_INDEX'],
    ))
    hostname = f"{stem}-0.service-{stem}.{namespace}.svc.cluster.local"

    resolved = resolve_with_retry(hostname)
    if resolved is not None:
        return resolved
    raise ValueError(f"Failed to resolve hostname {hostname}")
def get_args_from_env(user_config: UserConfig, dp_parallel_size_local: int) -> dict:
    """Derive per-pod engine arguments from the POD_IP and POD_NAME environment variables.

    POD_IP, when set, becomes the data-parallel address. The pod index is the
    number after the last ``-`` in POD_NAME:

    * the primary pod (``POD_INDEX_PRIMARY``) gets ``HOST_KEY`` set to POD_IP;
    * any other pod is marked headless, gets its data-parallel start rank
      (``pod_index * dp_parallel_size_local``) and has the data-parallel
      address replaced by the resolved IP of pod 0.

    Args:
        user_config: Loaded user configuration, forwarded to ``get_ip_of_pod0``.
        dp_parallel_size_local: Number of data-parallel ranks per pod.

    Returns:
        Mapping of engine argument names to values. If POD_NAME is unset or
        does not end in an integer, only the POD_IP-derived entry (if any) is
        returned; the parse failure is logged, not raised.

    Raises:
        ValueError: Propagated from ``get_ip_of_pod0`` when pod 0 cannot be
            located for a non-primary pod.
    """
    result = {}
    pod_ip = os.environ.get('POD_IP')
    if pod_ip:
        result[DATA_PARALLEL_ADDRESS_KEY] = pod_ip

    pod_name = os.environ.get('POD_NAME')
    if not pod_name:
        return result

    # Only the integer parse is best-effort. Keeping the try body this narrow
    # stops a ValueError from get_ip_of_pod0 being swallowed and misreported
    # as a POD_NAME parsing problem.
    try:
        pod_index = int(pod_name.rsplit('-', 1)[-1])
    except ValueError as e:
        logging.error(f"Could not extract pod_index from POD_NAME: {pod_name}")
        logging.error(f"Error details: {e}")
        return result

    if pod_index == POD_INDEX_PRIMARY:
        result[HOST_KEY] = pod_ip
    else:
        result[HEADLESS_KEY] = True
        result[DATA_PARALLEL_START_RANK_KEY] = pod_index * dp_parallel_size_local
        # Non-primary pods must point at pod 0, not at themselves; a failure
        # here is fatal rather than silently keeping this pod's own IP.
        result[DATA_PARALLEL_ADDRESS_KEY] = get_ip_of_pod0(user_config)
    return result
def get_kv_port_base(hardware_type: str) -> int:
    """Return the base KV port number for a hardware type.

    Args:
        hardware_type: Hardware type identifier from the deploy config.

    Returns:
        28000 for ``HARDWARE_TYPE_A2``, 36000 for any of ``HARDWARE_TYPES_A3``.

    Raises:
        ValueError: If the hardware type is neither of the above.
    """
    is_a2 = hardware_type == HARDWARE_TYPE_A2
    if not is_a2 and hardware_type not in HARDWARE_TYPES_A3:
        raise ValueError(f"Unsupported hardware_type: {hardware_type}")
    return 28000 if is_a2 else 36000
def generate_kv_transfer_config(role, user_config: UserConfig) -> str:
    """Build the KV transfer configuration for a prefill or decode instance.

    The engine id is the INSTANCE_INDEX environment variable offset by a
    role-dependent base: 0 for prefill, the prefill instance count for decode.
    The KV port is the hardware-dependent port base plus 100 per engine id.

    Args:
        role: ``'prefill'`` (KV producer) or ``'decode'`` (KV consumer).
        user_config: Loaded user configuration supplying hardware types, the
            prefill instance count and the prefill/decode parallel sizes.

    Returns:
        The configuration serialized as JSON with an indent of 2.

    Raises:
        ValueError: If INSTANCE_INDEX is not set, is not an integer, the role
            is unsupported, or the hardware type is unsupported.
    """
    raw_instance_index = os.environ.get('INSTANCE_INDEX')
    if raw_instance_index is None:
        raise ValueError("Required environment variable INSTANCE_INDEX is not set")
    instance_id = int(raw_instance_index)

    if role == 'prefill':
        kv_role, engine_id_base = 'kv_producer', 0
        hardware_type = user_config.deploy_config.prefill.hardware_type
    elif role == 'decode':
        kv_role = 'kv_consumer'
        # Decode engine ids continue after all prefill instances.
        engine_id_base = user_config.deploy_config.prefill.instance_count
        hardware_type = user_config.deploy_config.decode.hardware_type
    else:
        raise ValueError(f"Unsupported role for KV transfer config: {role}")
    kv_port_base = get_kv_port_base(hardware_type)

    engine_id = engine_id_base + instance_id
    # Each engine id gets its own block of 100 ports above the base.
    kv_port = kv_port_base + engine_id * 100

    common = user_config.engine_common_config
    prefill_sizes = {"dp_size": common.prefill_dp_size, "tp_size": common.prefill_tp_size}
    decode_sizes = {"dp_size": common.decode_dp_size, "tp_size": common.decode_tp_size}

    kv_transfer_config = {
        "kv_connector": "MooncakeLayerwiseConnector",
        "kv_role": kv_role,
        "kv_port": str(kv_port),
        "engine_id": str(engine_id),
        "kv_connector_extra_config": {
            "use_ascend_direct": True,
            "prefill": prefill_sizes,
            "decode": decode_sizes,
        },
    }
    return json.dumps(kv_transfer_config, indent=2)
def pull_engine(role, config_path):
    """Load the user config, assemble the ``vllm serve`` command and run it until it exits.

    Arguments come from three sources, later ones overriding earlier ones:
    the user config (``get_args_from_user_config``), the pod environment
    (``get_args_from_env``) and, for the ``DEPLOY_TYPE_PD_SEPARATE`` deploy
    type, the generated KV transfer config.

    The child process inherits this process's stdout/stderr, so its output
    is not captured here.

    Args:
        role: Instance role forwarded to the argument builders.
        config_path: Path handed to ``UserConfig.load_from_file``.

    Raises:
        subprocess.CalledProcessError: If ``vllm serve`` exits with a
            non-zero return code.
        Exception: Any other failure while building or launching the command
            is logged and re-raised unchanged.
    """
    try:
        user_config = UserConfig.load_from_file(config_path)
        args_dict = get_args_from_user_config(role, user_config)
        env_args = get_args_from_env(user_config, args_dict[DATA_PARALLEL_SIZE_LOCAL_KEY])
        args_dict.update(env_args)
        if user_config.engine_common_config.deploy_type == DEPLOY_TYPE_PD_SEPARATE:
            args_dict[KV_TRANSFER_CONFIG_KEY] = generate_kv_transfer_config(role, user_config)
        converted_args_list = convert_args_dict_to_list(args_dict)
        vllm_command = ['vllm', 'serve'] + converted_args_list
        logging.info(f"Starting vllm with command: {' '.join(vllm_command)}")
        process = subprocess.Popen(vllm_command, shell=False)
        # Log the PID while the process is actually running, not after it has exited.
        logging.info(f"vllm serve process started with PID: {process.pid}")
        # No pipes are attached, so there is no output to collect: just wait.
        returncode = process.wait()
        if returncode != 0:
            logging.error(f"vllm serve process failed with return code {returncode}")
            raise subprocess.CalledProcessError(returncode, vllm_command)
    except Exception as e:
        logging.error(f"Error in pull_engine: {e}")
        raise