import os
from dataclasses import fields
from rec_sdk_common.log.log import LoggingProxy as logger
from rec_sdk_common.communication.hccl.hccl_info import (
get_device_id,
get_rank_id,
get_rank_size,
)
from mxrec.python.constants.constants import CommNodeInfo
from mxrec.python.config.parser import TomlParser
from mxrec.python.config.config import (
get_comm_node_info,
get_log_level,
get_use_ranktable,
)
def init(toml_path: str):
    """Run the MxRec start-up sequence.

    The steps run in a fixed order: the TOML parser is set up from
    ``toml_path``, then the logger, then the Ascend-related environment
    variables. A completion message is logged at the end.

    Args:
        toml_path: Path handed to ``_parser_init`` (and from there to
            ``TomlParser.set_instance``).
    """
    # The parser goes first; the later steps take no arguments and
    # presumably read their settings from the parsed config (TODO confirm).
    _parser_init(toml_path)
    for setup_step in (_logger_init, _ascend_env_init):
        setup_step()
    logger.info("MxRec initialization is complete.")
def _parser_init(path: str):
    """Hand ``path`` to ``TomlParser.set_instance``.

    Args:
        path: Location of the TOML configuration file.
    """
    TomlParser.set_instance(path)
def _logger_init():
    """Set up the logger with the level returned by ``get_log_level()``."""
    logger.set_instance(get_log_level())
def _ascend_env_init():
    """Export the process environment variables used for Ascend communication.

    Two modes, selected by ``get_use_ranktable()``:

    * ranktable mode: ``RANK_ID`` and ``RANK_SIZE`` are always overwritten
      with the values from ``get_rank_id()`` and ``get_rank_size()``.
    * otherwise: each dataclass field of the object returned by
      ``get_comm_node_info()`` is exported under its upper-cased field
      name, but only when that variable is not already set, so values
      supplied by the caller's environment take precedence.

    ``ASCEND_DEVICE_ID`` is then written from ``get_device_id()``.

    Raises:
        ValueError: If a communication-node info object is not a
            ``CommNodeInfo`` instance.
    """
    if not get_use_ranktable():
        # Kept as a list so that more info objects can be exported the same way.
        env_info = [get_comm_node_info()]
        for info in env_info:
            if not isinstance(info, CommNodeInfo):
                raise ValueError(
                    f"the environment info must be dataclass, but got {info}"
                )
            for spec in fields(info):
                env_name = spec.name.upper()
                # Never override a variable that is already defined.
                if os.getenv(env_name) is None:
                    env_var = getattr(info, spec.name)
                    os.environ[env_name] = str(env_var)
                    logger.info(
                        "The environment variable %s is set to %s.", env_name, env_var
                    )
    else:
        current_rank = get_rank_id()
        os.environ["RANK_ID"] = str(current_rank)
        logger.info("The environment variable RANK_ID is set to %s.", current_rank)

        world_size = get_rank_size()
        os.environ["RANK_SIZE"] = str(world_size)
        logger.info("The environment variable RANK_SIZE is set to %s.", world_size)

    # NOTE(review): assumed to run in both modes; the original indentation
    # of this tail section was not available -- confirm.
    ascend_device = str(get_device_id())
    os.environ["ASCEND_DEVICE_ID"] = ascend_device
    logger.info(
        "The environment variable ASCEND_DEVICE_ID is set to %s.", ascend_device
    )