import os
import re
import sys
import json
import time
import shutil
import logging
import subprocess
from argparse import ArgumentParser
from typing import Dict, Any
from enum import Enum
# Configure the root logger once at import time: INFO level, timestamped records.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Key under which server entries are stored in the generated table.
SERVER_LIST = 'server_list'
# Maximum number of attempts retry_command makes before giving up.
MAX_RETRIES = 10
# Seconds retry_command sleeps between two consecutive attempts.
RETRY_INTERVAL = 3
class HardwareType(Enum):
    """Hardware generations this script can distinguish.

    The value of ``A2`` and ``A3`` is the marker substring that
    ``get_hardware_type`` looks for in the ``lspci`` output.
    """

    # Marker substring identifying A2 hardware in the lspci listing.
    A2 = 'd802'
    # Marker substring identifying A3 hardware in the lspci listing.
    A3 = 'd803'
    # Fallback used when no known marker was found.
    UNKNOWN = 'unknown'
def parse_args():
    """Parse the command-line options of this script.

    Returns:
        The argparse namespace; ``hccl_path`` holds the output path of the
        hccl config file and defaults to ``"hccl.json"``.
    """
    cli = ArgumentParser(description="Generate hccl config file")
    cli.add_argument(
        "--hccl_path",
        type=str,
        default="hccl.json",
        help="Manually specify the path of hccl config file",
    )
    return cli.parse_args()
def get_hardware_type():
    """Detect the hardware generation by scanning the ``lspci`` output.

    Returns:
        ``HardwareType.A2`` or ``HardwareType.A3`` when the corresponding
        marker string appears in the ``lspci`` listing (A2 is checked first).
        ``HardwareType.UNKNOWN`` when neither marker is present, or when
        ``lspci`` is missing, cannot be executed, exits non-zero, times out,
        or produces undecodable output; the failure is logged in that case.
    """
    try:
        lspci_path = shutil.which("lspci")
        if not lspci_path:
            raise ValueError("lspci not found!")
        # Pass an argv list so the resolved path is never split or
        # interpreted, even if it contains spaces.
        output = subprocess.check_output([lspci_path], text=True, timeout=5)
    except (OSError, subprocess.SubprocessError, ValueError) as e:
        # OSError: the binary could not be executed.
        # SubprocessError: covers CalledProcessError and TimeoutExpired.
        # ValueError: lspci missing (raised above) or output decoding failed.
        logging.error("get hardware type failed: %s", e)
        return HardwareType.UNKNOWN

    for candidate in (HardwareType.A2, HardwareType.A3):
        if candidate.value in output:
            return candidate
    return HardwareType.UNKNOWN
def get_visible_devices(device_glob="/dev/davinci*"):
    """List the numeric ids of the davinci device nodes that are present.

    Args:
        device_glob: Glob pattern used to find device nodes. Defaults to
            ``/dev/davinci*``; mainly overridable for testing.

    Returns:
        The device ids as strings, sorted by their numeric value (so ``"2"``
        comes before ``"10"``). Entries whose file name has no
        ``davinci<digits>`` part are ignored. An empty list is returned when
        nothing matches or when detection fails (the failure is logged).
    """
    try:
        import glob

        device_ids = []
        for device_path in glob.glob(device_glob):
            # Match on the file name only so digits in a parent directory
            # can never be mistaken for a device id.
            match = re.search(r'davinci(\d+)', os.path.basename(device_path))
            if match:
                device_ids.append(match.group(1))
        # Sort numerically: a plain string sort would put "10" before "2"
        # and scramble the rank order derived from this list.
        return sorted(device_ids, key=int)
    except Exception as e:
        # Best-effort detection: report the problem and fall back to "none".
        logging.error("Failed to detect visible devices: %s", e)
    return []
def _run_command(cmd_args):
result = subprocess.run(
cmd_args,
capture_output=True,
text=True,
check=False,
timeout=30,
)
if result.stdout:
return result.stdout.splitlines(keepends=True)
return []
def retry_command(cmd_args):
    """Run a command until it yields output, retrying on failure.

    The command is attempted up to ``MAX_RETRIES`` times, sleeping
    ``RETRY_INTERVAL`` seconds between attempts. An attempt counts as failed
    when it raises or when it produces no output; both cases are logged.

    Args:
        cmd_args: The command and its arguments, forwarded to ``_run_command``.

    Returns:
        The non-empty list of output lines from the first successful attempt.

    Raises:
        ValueError: If every attempt failed or returned nothing.
    """
    final_attempt = MAX_RETRIES - 1
    for attempt in range(MAX_RETRIES):
        attempt_no = attempt + 1
        try:
            lines = _run_command(cmd_args)
        except Exception as exc:
            logging.warning(
                "Command failed: %s, attempt %d/%d", exc, attempt_no, MAX_RETRIES
            )
        else:
            if lines:
                return lines
            logging.warning(
                "Command returned empty result, attempt %d/%d", attempt_no, MAX_RETRIES
            )
        # No point in waiting after the last attempt.
        if attempt != final_attempt:
            time.sleep(RETRY_INTERVAL)
    raise ValueError(f"Command failed after {MAX_RETRIES} attempts: {cmd_args}")
def _parse_colon_value(output_lines, description):
    """Extract the value from the first ``key:value`` line of tool output.

    Args:
        output_lines: Lines printed by the tool; only the first is inspected.
        description: Short label for the tool output, used in error messages.

    Returns:
        The text between the first and second ``:`` of the first line, with
        all spaces and newlines removed.

    Raises:
        ValueError: If there is no output or the first line has no ``:``.
    """
    if not output_lines:
        raise ValueError(f"Empty {description} output")
    first_line = output_lines[0]
    fields = first_line.split(":")
    if len(fields) < 2:
        raise ValueError(
            f"Unexpected {description} output (no ':' separator): {first_line!r}"
        )
    return fields[1].replace('\n', '').replace(' ', '')


def main():
    """Collect device information and write the hccl config file.

    Steps: detect the visible devices and the hardware type, query each
    device's IP via ``hccn_tool`` (and, on A3 hardware, its super device id
    via ``npu-smi``), assemble the table and dump it as JSON to the path
    given by ``--hccl_path``. ``HOST_IP`` and ``POD_IP`` are read from the
    environment and default to ``127.0.0.1``.

    Raises:
        ValueError: If the hardware type is unknown, a query command keeps
            failing, or a command's output cannot be parsed.
    """
    logging.info("start %s", __file__)
    args = parse_args()

    visible_devices = get_visible_devices()
    logging.info('Detected visible_devices: %s', visible_devices)
    if not visible_devices:
        # Not fatal, but the resulting table will contain no devices.
        logging.warning("No visible devices detected; device list will be empty")

    hardware_type = get_hardware_type()
    if hardware_type == HardwareType.UNKNOWN:
        raise ValueError("unknown hardware type!")
    logging.info('Detected hardware_type: %s', hardware_type)
    is_a3 = hardware_type == HardwareType.A3

    host_ip = os.getenv('HOST_IP', '127.0.0.1')
    pod_ip = os.getenv('POD_IP', '127.0.0.1')
    logging.info('host_ip: %s', host_ip)
    logging.info('pod_ip: %s', pod_ip)

    device_ips: Dict[Any, Any] = {}
    device_sdids: Dict[Any, Any] = {}
    for device_id in visible_devices:
        ret_ip = retry_command(["hccn_tool", "-i", str(device_id), "-ip", "-g"])
        logging.info("device_id: %s, device_ip_info: %s", device_id, ret_ip)
        device_ips[device_id] = _parse_colon_value(ret_ip, "hccn_tool device ip")
        if is_a3:
            # The device id is split into a card index and a chip index
            # (two chips per card) for the npu-smi query.
            card_id, chip_id = divmod(int(device_id), 2)
            ret_sdid = retry_command(
                ["npu-smi", "info", "-t", "spod-info", "-i", str(card_id), "-c", str(chip_id)]
            )
            logging.info("device_id: %s, super_device_id: %s", device_id, ret_sdid)
            device_sdids[device_id] = _parse_colon_value(ret_sdid, "npu-smi spod-info")

    hccn_table = {'version': '1.0', 'server_count': '1', SERVER_LIST: []}
    device_list = []
    # Ranks are assigned in the order the devices were detected.
    for rank_id, device_id in enumerate(visible_devices):
        device_ip = device_ips[device_id]
        device_info = {'device_id': device_id, 'device_ip': device_ip, 'rank_id': str(rank_id)}
        if is_a3:
            device_info['super_device_id'] = device_sdids[device_id]
        device_list.append(device_info)
        logging.info('rank_id: %s, device_id: %s, device_ip: %s', rank_id, device_id, device_ip)

    hccn_table[SERVER_LIST].append({'server_id': host_ip, 'container_ip': pod_ip, 'device': device_list})
    if is_a3:
        hccn_table['super_pod_list'] = [{"super_pod_id": "0", SERVER_LIST: [{"server_id": host_ip}]}]
    hccn_table['status'] = 'completed'

    with open(args.hccl_path, 'w', encoding='utf-8') as table_fp:
        json.dump(hccn_table, table_fp, indent=4)
    sys.stdout.flush()
    logging.info("Completed: hccl file was save in : %s", args.hccl_path)
# Run the generator only when executed as a script, not when imported.
if __name__ == "__main__":
    main()