import ctypes
import os
# Longest group name (in characters) accepted by the validators in this module.
MAX_GROUP_NAME_LEN = 127
# Largest value representable in an unsigned 32-bit integer; used as the
# upper bound when validating index values.
MAX_VALUE_UINT32 = 2 ** 32 - 1
def load_lib():
    """Load the ``libhcomm.so`` shared library.

    Returns:
        ctypes.CDLL: handle to the loaded library.

    Raises:
        ValueError: if the library cannot be loaded. The underlying
            ``OSError`` is attached as ``__cause__`` so the loader's
            diagnostic (missing file, unresolved symbol, ...) is kept.
    """
    try:
        return ctypes.CDLL('libhcomm.so')
    except OSError as err:
        # Keep the historical exception type and message for callers, but
        # chain the loader error explicitly instead of discarding it.
        raise ValueError('load hccl lib error') from err
# Module-level library handle, loaded once when this module is imported.
# load_lib() raises ValueError if the shared library cannot be loaded, so
# importing this module fails in that case.
HCCL_LIB_CTYPES = load_lib()
def c_str(string):
    """Encode *string* as UTF-8 and wrap the bytes in a ``ctypes.c_char_p``."""
    encoded = string.encode('utf-8')
    return ctypes.c_char_p(encoded)
def c_array(ctype, values):
    """Build a ctypes array of element type *ctype* holding *values*."""
    # Multiplying a ctypes type by a length yields the fixed-size array type.
    array_type = ctype * len(values)
    return array_type(*values)
def set_split_strategy_by_idx(idxList, group="hccl_world_group"):
    """Validate the arguments and pass them to ``HcomSetGradFusionByIndex``.

    Args:
        idxList (list): non-empty list of python ints, each within
            0..MAX_VALUE_UINT32, in strictly ascending order.
        group (str): non-empty group name of at most MAX_GROUP_NAME_LEN
            characters.

    Raises:
        ValueError: if an argument fails validation, or if the library
            call returns a non-zero status.
    """
    # --- group name checks ---
    if not isinstance(group, str):
        raise ValueError('group must be a python str')
    name_len = len(group)
    if name_len > MAX_GROUP_NAME_LEN:
        raise ValueError('group name len[{}] too long,'.format(name_len)
                         + ' Max len[{}].'.format(MAX_GROUP_NAME_LEN))
    if name_len == 0:
        raise ValueError('group name is empty.')

    # --- index list checks ---
    if not isinstance(idxList, list):
        raise ValueError('idxList must be a python list')
    if len(idxList) == 0:
        raise ValueError('idxList length is 0')

    for value in idxList:
        if not isinstance(value, int):
            raise ValueError('idx val[{}] in idxList is not python int type.'.format(value))
        if not 0 <= value <= MAX_VALUE_UINT32:
            raise ValueError('idx val[{}] in idxList is an out-of-range value,'
                             ' the correct value range is 0 to {}'.format(value, MAX_VALUE_UINT32))

    # Every element must be strictly smaller than its successor.
    if any(left >= right for left, right in zip(idxList, idxList[1:])):
        raise ValueError('idx in idxList is not ascending')

    # --- marshal arguments and call into the library ---
    group_arg = c_str(group)
    count_arg = ctypes.c_uint(len(idxList))
    indices_arg = c_array(ctypes.c_uint, idxList)
    status = HCCL_LIB_CTYPES.HcomSetGradFusionByIndex(group_arg, count_arg, indices_arg)
    if status != 0:
        raise ValueError('split error. ret[{}]'.format(status))
def set_split_strategy_by_size(dataSizeList, group="hccl_world_group"):
    """Validate the arguments and pass them to ``HcomSetGradFusionBySize``.

    Args:
        dataSizeList (list): non-empty list of non-negative python ints or
            floats, interpreted as percentages that must sum to 100.
        group (str): non-empty group name of at most MAX_GROUP_NAME_LEN
            characters.

    Raises:
        ValueError: if an argument fails validation, or if the library
            call returns a non-zero status.
    """
    # Absolute tolerance (in percentage points) for the sum check, so that
    # float lists which total 100 only up to rounding error are accepted.
    sum_tolerance = 1e-6

    if not isinstance(group, str):
        raise ValueError('group must be a python str')
    if len(group) > MAX_GROUP_NAME_LEN:
        # Format the integer itself; wrapping it in braces would format a
        # set and render the limit as "{127}".
        raise ValueError('group name len[{}] too long,'.format(len(group))
                         + ' Max len[{}].'.format(MAX_GROUP_NAME_LEN))
    if len(group) == 0:
        raise ValueError('group name is empty.')

    if not isinstance(dataSizeList, list):
        raise ValueError('dataSizeList must be a python list')
    if len(dataSizeList) == 0:
        raise ValueError('dataSizeList length is 0')

    for dataSize in dataSizeList:
        if not isinstance(dataSize, (int, float)):
            raise ValueError('dataSize val[{}] in dataSizeList is not python int or float type.'.format(dataSize))
        if dataSize < 0:
            raise ValueError('dataSize val[{}] in dataSizeList cannot be a negative number.'.format(dataSize))

    # Written as "not (... <= tol)" so that a NaN or infinite total is
    # still rejected (every comparison involving NaN is False).
    if not abs(sum(dataSizeList) - 100) <= sum_tolerance:
        raise ValueError('size percentage list sum is not 100%')

    c_array_sizeList = c_array(ctypes.c_float, dataSizeList)
    c_size_num = ctypes.c_uint(len(dataSizeList))
    c_group = c_str(group)
    ret = HCCL_LIB_CTYPES.HcomSetGradFusionBySize(c_group, c_size_num, c_array_sizeList)
    if ret != 0:
        raise ValueError('split error, ret[{}]'.format(ret))