import math
# Buffer size per communication group, keyed by group name (e.g. "tp", "pp").
# Filled in place by parse_hccl_buffer_string() and hccl_buffer_auto_adaptive().
_HCCL_GROUP_BUFFER = {}
# Op-mode value per communication group, keyed by group name.
# Filled in place by parse_hccl_op_mode_string().
_HCCL_OP_MODE = {}
def _parse_key_value_string(input_string, target_dict, dict_name="config"):
    """Parse a ``key:value`` list and store the validated pairs in *target_dict*.

    Pairs are separated by ``,`` when the string contains a comma and by ``;``
    otherwise. Whitespace around keys and values is ignored, and empty
    segments (for example the one left by a trailing separator) are skipped.

    The update is all-or-nothing: *target_dict* is only touched after every
    pair has been validated, so a bad entry never leaves it half-updated.

    Args:
        input_string: The raw option string. ``None`` or an empty string is a
            no-op.
        target_dict: Mutable mapping updated in place with ``key -> int``.
        dict_name: Label of the option being parsed; used in error messages.

    Raises:
        RuntimeError: If a segment is not a single ``key:value`` pair, the key
            is not one of the allowed group names, or the value is not an
            integer greater than 0.
    """
    if not input_string:
        return

    allowed_keys = frozenset((
        "dp",
        "dp_cp",
        "cp",
        "mp",
        "mp_exp",
        "tp",
        "pp",
        "embd",
        "tp_dp_cp",
        "tp_dp",
        "tp_cp",
        "tp_exp",
        "tp_ep_mp",
        "exp",
        "ep",
        "dp_modulo_exp",
        "pp_new_stream",
        "cp2",
        "cp_ulysses",
        "cp_ring",
        "cp_ring_intra",
        "cp_ring_intra_overlap",
        "nd1_dim1",
        "ag_x_sd_rcv_overlap",
        "nd1_dim2",
        "ag_y_sd_rcv_overlap",
        "nd2_dim1",
        "nd2_dim2",
        "default_group",
    ))

    # A comma anywhere in the string selects ',' as the separator; otherwise ';'.
    separator = ',' if ',' in input_string else ';'

    # Collect into a scratch dict first so a failure cannot leave target_dict
    # partially updated.
    parsed = {}
    for part in input_string.split(separator):
        if not part.strip():
            # Tolerate trailing / doubled separators such as "tp:200;".
            continue

        key_value = part.split(':')
        if len(key_value) != 2:
            raise RuntimeError(
                f"The value format of {dict_name} is not valid. "
                f"Expected 'key:value' pairs separated by ';' or ','"
            )

        key = key_value[0].strip().replace(' ', '')
        value_str = key_value[1].strip().replace(' ', '')
        if key not in allowed_keys:
            raise RuntimeError(f"Key '{key}' is not allowed in {dict_name}")

        try:
            value = int(value_str)
        except ValueError as e:
            raise RuntimeError(f"'{value_str}' is not a valid positive integer for key '{key}'") from e
        if value <= 0:
            raise RuntimeError(f"Value {value} for key '{key}' must be greater than 0")

        parsed[key] = value

    target_dict.update(parsed)
def parse_hccl_buffer_string(hccl_group_buffer):
    """Parse the ``--hccl-group-buffer`` option into ``_HCCL_GROUP_BUFFER``.

    Args:
        hccl_group_buffer: Raw option string of ``key:value`` pairs; ``None``
            or an empty string leaves the mapping untouched.

    Raises:
        RuntimeError: Propagated from ``_parse_key_value_string`` when the
            string is malformed.
    """
    _parse_key_value_string(
        input_string=hccl_group_buffer,
        target_dict=_HCCL_GROUP_BUFFER,
        dict_name="--hccl-group-buffer",
    )
def parse_hccl_op_mode_string(hccl_op_mode_string):
    """Parse the ``--hccl-op-mode`` option into ``_HCCL_OP_MODE``.

    Args:
        hccl_op_mode_string: Raw option string of ``key:value`` pairs;
            ``None`` or an empty string leaves the mapping untouched.

    Raises:
        RuntimeError: Propagated from ``_parse_key_value_string`` when the
            string is malformed.
    """
    _parse_key_value_string(
        input_string=hccl_op_mode_string,
        target_dict=_HCCL_OP_MODE,
        dict_name="--hccl-op-mode",
    )
def hccl_buffer_auto_adaptive(args):
    """Derive per-group HCCL buffer sizes from *args* and record them.

    Results are written into the module-level ``_HCCL_GROUP_BUFFER`` mapping
    (existing entries for the same keys are overwritten); nothing is returned.
    Computed sizes have the form ``2 * ceil(volume / 1024 / 1024)``.
    NOTE(review): the unit is presumably MB given the scaling -- confirm.

    Args:
        args: Namespace-like object. Always read: ``seq_length``,
            ``micro_batch_size``, ``hidden_size``, ``context_parallel_size``,
            ``tensor_model_parallel_size``, ``expert_model_parallel_size``,
            ``moe_router_topk``, ``moe_token_dispatcher_type``,
            ``context_parallel_algo``, ``num_attention_heads``,
            ``group_query_attention``, ``num_query_groups``,
            ``sequence_parallel`` and
            ``hccl_ep_group_buffer_adaptive_factor``. Read only on some paths:
            ``moe_tp_extend_ep``, ``cp_window_size``,
            ``use_cp_send_recv_overlap`` and ``ulysses_degree_in_cp``.
    """
    seq_len = args.seq_length
    mbs = args.micro_batch_size
    hidden = args.hidden_size
    cp_size = args.context_parallel_size
    tp_size = args.tensor_model_parallel_size
    ep_size = args.expert_model_parallel_size
    topk = args.moe_router_topk
    dispatcher = args.moe_token_dispatcher_type
    cp_algo = args.context_parallel_algo
    num_heads = args.num_attention_heads
    use_gqa = args.group_query_attention
    kv_groups = args.num_query_groups

    buffers = _HCCL_GROUP_BUFFER
    # Size used for expert-related groups when the adaptive factor is not > 0.
    fallback_size = 200

    def doubled_ceil(volume):
        # Every computed size is twice the rounded-up volume.
        return 2 * math.ceil(volume)

    # NOTE: the operator order inside each volume expression is kept exactly
    # as written; reordering float operations could change the ceil() result.

    # ---- TP group -------------------------------------------------------
    dense_size = doubled_ceil(seq_len / cp_size * mbs * hidden / 1024 / 1024)
    seq_parallel = args.sequence_parallel
    buffers['tp'] = dense_size if seq_parallel else dense_size * 2
    if dispatcher == 'alltoall_seq':
        # With the alltoall_seq dispatcher the TP buffer must also cover the
        # MoE volume, so keep the larger of the two sizes.
        ep_factor = args.hccl_ep_group_buffer_adaptive_factor
        if ep_factor > 0:
            moe_size = doubled_ceil(
                ep_factor * seq_len / cp_size / tp_size * mbs * hidden / 1024 / 1024 * topk
            )
        else:
            moe_size = fallback_size
        buffers['tp'] = max(moe_size, buffers['tp'])

    # ---- PP groups ------------------------------------------------------
    if seq_parallel:
        pp_size = doubled_ceil(seq_len / cp_size / tp_size * mbs * hidden / 1024 / 1024)
    else:
        # Same expression as the dense TP volume above.
        pp_size = dense_size
    for group in ('pp', 'pp_new_stream'):
        buffers[group] = pp_size

    # ---- MP groups: fixed size -------------------------------------------
    for group in ('mp', 'mp_exp'):
        buffers[group] = 10

    # ---- EP group -------------------------------------------------------
    ep_factor = args.hccl_ep_group_buffer_adaptive_factor
    if ep_factor > 0:
        # The factor only enables the adaptive size here; it is not a multiplier.
        buffers['exp'] = doubled_ceil(
            seq_len / cp_size / tp_size * mbs * hidden / 1024 / 1024 * topk
        )
    else:
        buffers['exp'] = fallback_size

    # ---- TP-EP group (only for some dispatcher configurations) ----------
    tp_ep_size = None
    if dispatcher == 'allgather':
        if ep_factor > 0:
            tp_ep_size = doubled_ceil(
                ep_factor * seq_len / cp_size * mbs * hidden * ep_size / 1024 / 1024
            )
        else:
            tp_ep_size = fallback_size
    elif dispatcher == 'alltoall_seq' and args.moe_tp_extend_ep:
        if ep_factor > 0:
            tp_ep_size = doubled_ceil(
                ep_factor * seq_len / cp_size / tp_size * mbs * hidden * topk / 1024 / 1024
            )
        else:
            tp_ep_size = fallback_size
    if tp_ep_size is not None:
        buffers['tp_exp'] = tp_ep_size

    # ---- TP-CP group: fixed size ----------------------------------------
    buffers['tp_cp'] = 10

    # ---- CP groups, depending on the context-parallel algorithm ---------
    if cp_algo is None or cp_algo == 'ulysses_cp_algo':
        buffers['cp'] = doubled_ceil(seq_len / cp_size * mbs * hidden / tp_size / 1024 / 1024)

    elif cp_algo == 'megatron_cp_algo':
        per_head = seq_len / cp_size * mbs * hidden / num_heads
        if use_gqa:
            per_head = per_head * kv_groups
        overlap_size = doubled_ceil(per_head / tp_size / 1024 / 1024)
        # Without send/recv overlap the groups get twice the overlap size.
        plain_size = 2 * overlap_size

        windowed = args.cp_window_size > 1
        if args.use_cp_send_recv_overlap:
            groups = ['cp2', 'cp']
            if windowed:
                groups += ['cp_ring_intra', 'cp_ring_intra_overlap']
            chosen = overlap_size
        else:
            groups = ['cp']
            if windowed:
                groups.append('cp_ring_intra')
            chosen = plain_size
        for group in groups:
            buffers[group] = chosen

    elif cp_algo == 'hybrid_cp_algo':
        ulysses_degree = args.ulysses_degree_in_cp
        ring_degree = cp_size / ulysses_degree

        ulysses_size = doubled_ceil(
            seq_len / ulysses_degree * mbs * hidden / tp_size / 1024 / 1024
        )
        ring_per_head = seq_len / ring_degree * mbs * hidden / num_heads
        if use_gqa:
            ring_per_head = ring_per_head * kv_groups
        ring_size = doubled_ceil(ring_per_head / tp_size / 1024 / 1024)

        windowed = args.cp_window_size > 1
        if args.use_cp_send_recv_overlap:
            ring_groups = ['cp_ring', 'cp2']
            if windowed:
                ring_groups += ['cp_ring_intra', 'cp_ring_intra_overlap']
            ring_entry = ring_size
        else:
            ring_groups = ['cp_ring']
            if windowed:
                ring_groups.append('cp_ring_intra')
            # Without send/recv overlap the ring groups get a doubled size.
            ring_entry = ring_size * 2

        buffers['cp_ulysses'] = ulysses_size
        for group in ring_groups:
            buffers[group] = ring_entry
        # The plain 'cp' group gets a fixed size in hybrid mode.
        buffers['cp'] = 10