import ctypes
import os
# Maximum allowed length of a communication group name, as measured by
# len() on the Python str (i.e. in characters); enforced by check_group.
MAX_GROUP_NAME_LEN = 127
def check_group(group):
    """Validate the name of a collective communication group.

    Args:
        group: candidate group name; must be a non-empty ``str`` of at most
            ``MAX_GROUP_NAME_LEN`` characters.

    Raises:
        ValueError: if ``group`` is not a ``str``, is longer than
            ``MAX_GROUP_NAME_LEN``, or is empty.

    Returns:
        None
    """
    if not isinstance(group, str):
        raise ValueError('group must be a python str')
    # NOTE(review): the limit is counted in characters, but c_str() sends the
    # name UTF-8 encoded, so a non-ASCII name may exceed the limit in bytes.
    # Confirm whether the C side limits characters or bytes.
    name_len = len(group)
    if name_len > MAX_GROUP_NAME_LEN:
        # Echo back only a bounded prefix of the oversized name.
        raise ValueError('group name is invalid. group: ' + group[:MAX_GROUP_NAME_LEN])
    if name_len == 0:
        raise ValueError('group name is empty.')
def check_rank_num(rank_num):
    """Validate a collective communication rank count.

    Callers in this module pass the value to ``ctypes.c_uint``. ctypes does
    no overflow checking, so out-of-range values would silently wrap; they
    are rejected here instead.

    Args:
        rank_num (int): number of ranks; must be in ``[1, UINT_MAX]``.

    Raises:
        ValueError: if ``rank_num`` is not an int (bools are rejected), is
            not positive, or does not fit in ``ctypes.c_uint``.

    Returns:
        None
    """
    # bool is a subclass of int, but True/False are never meaningful counts.
    if not isinstance(rank_num, int) or isinstance(rank_num, bool):
        raise ValueError('rank_num must be a python int')
    if rank_num <= 0:
        raise ValueError('rank_num[{}] is less than 0 or equal to 0'.format(rank_num))
    c_uint_max = (1 << (8 * ctypes.sizeof(ctypes.c_uint))) - 1
    if rank_num > c_uint_max:
        raise ValueError('rank_num[{}] exceeds the maximum value {}'.format(rank_num, c_uint_max))
def check_rank_id(rank_id):
    """Validate a collective communication rank id.

    Callers in this module pass rank ids to ``ctypes.c_uint`` (directly or
    via a ``c_uint`` array). ctypes does no overflow checking, so an id such
    as ``2**32`` would silently become 0; such ids are rejected here.

    Args:
        rank_id (int): rank id; must be in ``[0, UINT_MAX]``.

    Raises:
        ValueError: if ``rank_id`` is not an int (bools are rejected), is
            negative, or does not fit in ``ctypes.c_uint``.

    Returns:
        None
    """
    # bool is a subclass of int, but True/False are never meaningful rank ids.
    if not isinstance(rank_id, int) or isinstance(rank_id, bool):
        raise ValueError('rank_id must be a python int')
    if rank_id < 0:
        raise ValueError('rank_id[{}] is less than 0'.format(rank_id))
    c_uint_max = (1 << (8 * ctypes.sizeof(ctypes.c_uint))) - 1
    if rank_id > c_uint_max:
        raise ValueError('rank_id[{}] exceeds the maximum value {}'.format(rank_id, c_uint_max))
def load_lib():
    """Load the HCCL communication shared library ``libhcomm.so``.

    Returns:
        ctypes.CDLL: handle to the loaded library.

    Raises:
        ValueError: if the dynamic loader cannot load the library. The
            loader's ``OSError`` is chained as ``__cause__`` so the actual
            reason (missing file, unresolved symbol, ...) is not lost.
    """
    try:
        hccl_lib = ctypes.CDLL('libhcomm.so')
    except OSError as err:
        raise ValueError('load hccl lib error') from err
    return hccl_lib
# Load the HCCL library once, at module import time; every wrapper below calls
# into this handle. Because load_lib() raises ValueError when the library
# cannot be loaded, importing this module fails in that case.
HCCL_LIB_CTYPES = load_lib()
def c_str(string):
    """Return ``string`` encoded as UTF-8 and wrapped in a ``ctypes.c_char_p``."""
    encoded = string.encode('utf-8')
    return ctypes.c_char_p(encoded)
def c_array(ctype, values):
    """Build a fixed-size ctypes array of element type ``ctype`` holding ``values``."""
    array_type = ctype * len(values)
    return array_type(*values)
def create_group(group, rank_num, rank_ids):
    """Create a collective communication group through ``HcomCreateGroup``.

    Args:
        group (str): name of the group to create.
        rank_num (int): number of ranks in the group; must equal
            ``len(rank_ids)``.
        rank_ids (list or tuple of int): rank ids that make up the group.

    Raises:
        ValueError: if any argument is invalid, or if ``HcomCreateGroup``
            returns a non-zero status (the status is included in the message).

    Returns:
        None
    """
    check_group(group)
    check_rank_num(rank_num)
    # Tuples are accepted too: c_array only needs len() and unpacking.
    if not isinstance(rank_ids, (list, tuple)):
        raise ValueError('rank_ids must be a python list or tuple')
    if rank_num != len(rank_ids):
        raise ValueError('rank_num[{}] not equal to rank_ids len[{}].'.format(rank_num, len(rank_ids)))
    for rank_id in rank_ids:
        check_rank_id(rank_id)

    c_array_rank_ids = c_array(ctypes.c_uint, rank_ids)
    ret = HCCL_LIB_CTYPES.HcomCreateGroup(c_str(group), ctypes.c_uint(rank_num), c_array_rank_ids)
    if ret != 0:
        raise ValueError('create group error:{} ret[{}]'.format(group, ret))
def destroy_group(group):
    """Destroy a collective communication group through ``HcomDestroyGroup``.

    Args:
        group (str): name of the group to destroy.

    Raises:
        ValueError: if ``group`` is invalid, or if ``HcomDestroyGroup``
            returns a non-zero status (the status is included in the message,
            as the other wrappers in this module do).

    Returns:
        None
    """
    check_group(group)
    ret = HCCL_LIB_CTYPES.HcomDestroyGroup(c_str(group))
    if ret != 0:
        raise ValueError('destroy group error :{} ret[{}]'.format(group, ret))
def get_rank_size(group="hccl_world_group"):
    """Return the rank count of ``group`` as reported by ``HcomGetRankSize``.

    Args:
        group (str): communication group name.

    Returns:
        int: the value written by the library into its output ``c_uint``.

    Raises:
        ValueError: if ``group`` is invalid or the library call returns a
            non-zero status.
    """
    check_group(group)
    size_out = ctypes.c_uint()
    status = HCCL_LIB_CTYPES.HcomGetRankSize(c_str(group), ctypes.byref(size_out))
    if status != 0:
        raise ValueError('get rank size error. ret[{}]'.format(status))
    return size_out.value
def get_rank_id(group="hccl_world_group"):
    """Return the rank id reported by ``HcomGetRankId`` for ``group``.

    Args:
        group (str): communication group name.

    Returns:
        int: the value written by the library into its output ``c_uint``.

    Raises:
        ValueError: if ``group`` is invalid or the library call returns a
            non-zero status.
    """
    check_group(group)
    group_name = c_str(group)
    rank_out = ctypes.c_uint()
    status = HCCL_LIB_CTYPES.HcomGetRankId(group_name, ctypes.byref(rank_out))
    if status == 0:
        return rank_out.value
    raise ValueError('get rank id error. ret[{}]'.format(status))
def get_local_rank_size(group="hccl_world_group"):
    """Return the local rank count reported by ``HcomGetLocalRankSize``.

    Args:
        group (str): communication group name.

    Returns:
        int: the value written by the library into its output ``c_uint``.

    Raises:
        ValueError: if ``group`` is invalid or the library call returns a
            non-zero status.
    """
    check_group(group)
    local_size = ctypes.c_uint()
    status = HCCL_LIB_CTYPES.HcomGetLocalRankSize(c_str(group), ctypes.byref(local_size))
    if status != 0:
        raise ValueError('get local rank size error. ret[{}]'.format(status))
    return local_size.value
def get_local_rank_id(group="hccl_world_group"):
    """Return the local rank id reported by ``HcomGetLocalRankId``.

    Args:
        group (str): communication group name.

    Returns:
        int: the value written by the library into its output ``c_uint``.

    Raises:
        ValueError: if ``group`` is invalid or the library call returns a
            non-zero status.
    """
    check_group(group)
    group_name = c_str(group)
    local_rank = ctypes.c_uint()
    status = HCCL_LIB_CTYPES.HcomGetLocalRankId(group_name, ctypes.byref(local_rank))
    if status == 0:
        return local_rank.value
    raise ValueError('get local rank id error. ret[{}]'.format(status))
def get_world_rank_from_group_rank(group, group_rank_id):
    """Map a rank id inside ``group`` to a world rank id.

    The mapping is performed by ``HcomGetWorldRankFromGroupRank``.

    Args:
        group (str): communication group name.
        group_rank_id (int): rank id within ``group``.

    Returns:
        int: the value written by the library into its output ``c_uint``.

    Raises:
        ValueError: if an argument is invalid or the library call returns a
            non-zero status.
    """
    check_group(group)
    check_rank_id(group_rank_id)
    world_rank = ctypes.c_uint()
    status = HCCL_LIB_CTYPES.HcomGetWorldRankFromGroupRank(
        c_str(group),
        ctypes.c_uint(group_rank_id),
        ctypes.byref(world_rank),
    )
    if status != 0:
        raise ValueError('get world rank from group rank error. ret[{}]'.format(status))
    return world_rank.value
def get_group_rank_from_world_rank(world_rank_id, group):
    """Map a world rank id to the corresponding rank id inside ``group``.

    The mapping is performed by ``HcomGetGroupRankFromWorldRank``. Note that
    this wrapper takes ``(world_rank_id, group)``, the reverse argument order
    of ``get_world_rank_from_group_rank``.

    Args:
        world_rank_id (int): world rank id to translate.
        group (str): communication group name.

    Returns:
        int: the value written by the library into its output ``c_uint``.

    Raises:
        ValueError: if an argument is invalid or the library call returns a
            non-zero status.
    """
    check_group(group)
    check_rank_id(world_rank_id)
    group_rank = ctypes.c_uint()
    # The C entry point takes the world rank first, then the group name.
    status = HCCL_LIB_CTYPES.HcomGetGroupRankFromWorldRank(
        ctypes.c_uint(world_rank_id),
        c_str(group),
        ctypes.byref(group_rank),
    )
    if status == 0:
        return group_rank.value
    raise ValueError('get group rank from world rank error. ret[{}]'.format(status))