import os
import re
from typing import (
Optional,
Callable,
List
)
from dataclasses import dataclass
from copy import deepcopy
from sys import maxsize
from collections import defaultdict
from megatron.training import get_args
import torch.distributed as dist
from megatron.core import parallel_state as ps
domains = ['tp', 'dp', 'pp', 'ep', 'cp']
def is_a3():
try:
cmd = 'npu-smi info -t board -i 0 -c 0 | grep Chip | grep Name'
chip_name = os.popen(cmd).read().strip()
is_910B = bool(re.search(r'910B', chip_name, re.IGNORECASE))
is_ascend910 = bool(re.search(r'Ascend910|Ascend 910', chip_name, re.IGNORECASE))
except Exception as e:
raise RuntimeError(f"Fail to get chip name : {str(e)}") from e
if is_ascend910:
return True
return False
is_a3_version = is_a3()
@dataclass
class ParallelCommDomain:
ip_list: List[List[str]]
rank_list: List[List[int]]
world_size: int
parallel_type: str
comm_amount: int
comm_amount_no_overlap: int
def generate_masked_orthogonal_rank_groups(
world_size: int, parallel_size: List[int],
mask: List[bool]) -> List[List[int]]:
def prefix_product(a: List[int], init=1) -> List[int]:
r = [init]
for v in a:
init = init * v
r.append(init)
return r
def inner_product(a: List[int], b: List[int]) -> int:
return sum([x * y for x, y in zip(a, b)])
def decompose(index, shape, stride=None):
'''
This function solve the math problem below:
There is an equation:
index = sum(idx[i] * stride[i])
And given the value of index, stride.
Return the idx.
This function will used to get the pp/dp/pp_rank
from group_index and rank_in_group.
'''
if stride is None:
stride = prefix_product(shape)
idx = [(index // d) % s for s, d in zip(shape, stride)]
idx_stride_sum = sum([x * y for x, y in zip(idx, stride[:-1])])
if idx_stride_sum != index:
raise ValueError(
"idx {} with shape {} mismatch the return idx {}".format(
index, shape, idx
)
)
return idx
masked_shape = [s for s, m in zip(parallel_size, mask) if m]
unmasked_shape = [s for s, m in zip(parallel_size, mask) if not m]
global_stride = prefix_product(parallel_size)
masked_stride = [d for d, m in zip(global_stride, mask) if m]
unmasked_stride = [d for d, m in zip(global_stride, mask) if not m]
group_size = prefix_product(masked_shape)[-1]
num_of_group = world_size // group_size
ranks = []
for group_index in range(num_of_group):
decomposed_group_idx = decompose(group_index, unmasked_shape)
rank = []
for rank_in_group in range(group_size):
decomposed_rank_idx = decompose(rank_in_group, masked_shape)
rank.append(
inner_product(decomposed_rank_idx, masked_stride) +
inner_product(decomposed_group_idx, unmasked_stride))
ranks.append(rank)
return ranks
class RankGenerator(object):
"""A class for generating rank groups for different modes of parallelism."""
def __init__(self,
tp: int,
ep: int,
dp: int,
pp: int,
cp: int,
order: str,
rank_offset: int = 0) -> None:
self.tp = tp
self.ep = ep
self.dp = dp
self.pp = pp
self.cp = cp
self.rank_offset = rank_offset
self.world_size = tp * dp * pp * cp
self.name_to_size = {
"tp": self.tp,
"pp": self.pp,
"dp": self.dp,
"ep": self.ep,
"cp": self.cp,
}
self.order = order
order = order.lower()
if 'ep' in order:
if 'ep-dp' not in order and 'dp-ep' not in order:
raise RuntimeError(
f"The ep and dp must be adjacent in order ({self.order}).")
for name in self.name_to_size.keys():
if name not in order and self.name_to_size[name] != 1:
raise RuntimeError(
f"The size of ({name}) is ({self.name_to_size[name]}), but you haven't"
f"specified the order ({self.order}).")
elif name not in order:
order = order + '-' + name
self.order_w_ep = order
self.order_wo_ep = '-'.join(
[token for token in order.split('-') if token != 'ep'])
self.ordered_size_wo_ep = []
self.ordered_size_w_ep = []
for token in order.split('-'):
if token == 'dp':
self.ordered_size_w_ep.append(self.dp // self.ep)
self.ordered_size_wo_ep.append(self.dp)
elif token == 'ep':
self.ordered_size_w_ep.append(self.ep)
else:
self.ordered_size_w_ep.append(self.name_to_size[token])
self.ordered_size_wo_ep.append(self.name_to_size[token])
def generate_target_parallelism_match_mask(self, parallelism_order: str, target_parallelism_tokens: str):
ordered_parallelism_tokens = parallelism_order.split('-')
target_parallelism_token_list = target_parallelism_tokens.split('-')
match_mask = [False] * len(ordered_parallelism_tokens)
for parallelism_identifier in target_parallelism_token_list:
match_mask[ordered_parallelism_tokens.index(parallelism_identifier)] = True
return match_mask
def get_ranks(self, token, independent_ep=False):
"""Get rank group by input token.
Args:
token (str):
Specify the ranks type that want to get. If we want
to obtain multiple parallel types, we can use a hyphen
'-' to separate them. For example, if we want to obtain
the TP_DP group, the token should be 'tp-dp'.
independent_ep (bool: True):
This flag controls whether we treat EP and DP independently.
EP shares ranks with DP, if we want to get ranks related to
EP, we should set the flag. For example, get_ranks('dp', True)
will get DP modulo EP group, and get_ranks('dp', False) will
get full DP group.
"""
if independent_ep:
parallel_size = self.ordered_size_w_ep
order = self.order_w_ep
else:
parallel_size = self.ordered_size_wo_ep
order = self.order_wo_ep
mask = self.generate_target_parallelism_match_mask(order, token)
ranks = generate_masked_orthogonal_rank_groups(self.world_size,
parallel_size, mask)
if self.rank_offset > 0:
for rank_group in ranks:
rank_group[:] = [rank + self.rank_offset for rank in rank_group]
return ranks
def RankGenerate():
args = get_args()
tp = args.tensor_model_parallel_size
ep = args.expert_model_parallel_size
dp = args.data_parallel_size
pp = args.pipeline_model_parallel_size
cp = args.context_parallel_size
g = RankGenerator(
tp=tp,
ep=ep,
dp=dp,
pp=pp,
cp=cp,
order='tp-cp-ep-dp-pp',
rank_offset=0,
)
return g
def get_tensor_parallel_comm_domain():
world_size = dist.get_world_size()
args = get_args()
rank_num = int(args.tensor_model_parallel_size)
seq_length = int(args.seq_length)
hidden_size = int(args.hidden_size)
num_layers = int(args.num_layers)
global_batch_size = int(args.global_batch_size)
micro_batch_size = int(args.micro_batch_size)
sequence_parallel = getattr(args, 'sequence_parallel', False)
use_ascend_mc2 = getattr(args, 'use_ascend_mc2', False)
use_ascend_coc = getattr(args, 'use_ascend_coc', False)
micro_batches = global_batch_size // micro_batch_size
comm_tp_num = 4
comm_parallel_ratio = (rank_num - 1) / rank_num
comm_broadcast = 4 * seq_length
comm_embedding = 4 * seq_length * hidden_size
comm_transformer = 2 * num_layers * comm_tp_num * seq_length * hidden_size
comm_vocab_parallel_ce = 2 * 3 * seq_length
if sequence_parallel:
comm_embedding += seq_length * hidden_size
comm_transformer += seq_length * hidden_size
comm_transformer += 2 * num_layers * seq_length * hidden_size
comm_tp_amount = (comm_broadcast + comm_embedding + comm_transformer +
comm_vocab_parallel_ce) * comm_parallel_ratio
comm_tp_groups_amount = micro_batches * comm_tp_amount * rank_num
comm_non_overlap = comm_broadcast + comm_vocab_parallel_ce
if use_ascend_mc2 or use_ascend_coc:
comm_non_overlap += seq_length * hidden_size
else:
comm_non_overlap += 2 * num_layers * comm_tp_num // 2 * seq_length * hidden_size
comm_non_overlap += 2 * num_layers * seq_length * hidden_size
comm_non_overlap_groups_amount = micro_batches * comm_non_overlap * comm_parallel_ratio * rank_num
g = RankGenerate()
tp_group_ranks = g.get_ranks('tp')
tp_groups_ips = None
num_tp_groups = world_size // rank_num
return ParallelCommDomain(
tp_groups_ips, tp_group_ranks, world_size, 'tp',
int(comm_tp_groups_amount) * num_tp_groups,
int(comm_non_overlap_groups_amount) * num_tp_groups)
def get_pipeline_parallel_comm_domain():
world_size = dist.get_world_size()
args = get_args()
rank_num = int(args.pipeline_model_parallel_size)
micro_batch_size = int(args.micro_batch_size)
tensor_model_parallel_size = int(args.tensor_model_parallel_size)
pipeline_model_parallel_size = int(args.pipeline_model_parallel_size)
seq_length = int(args.seq_length)
hidden_size = int(args.hidden_size)
num_layers = int(args.num_layers)
global_batch_size = int(args.global_batch_size)
sequence_parallel = getattr(args, 'sequence_parallel', False)
num_layers_per_virtual_stage = getattr(
args, 'num_layers_per_virtual_pipeline_stage', None)
num_model_chunks = (num_layers // pipeline_model_parallel_size)
pipeline_stage_num = pipeline_model_parallel_size
micro_batches = global_batch_size // micro_batch_size
if tensor_model_parallel_size > 1 and sequence_parallel:
seq_length = seq_length // tensor_model_parallel_size
if (num_layers_per_virtual_stage is not None
and int(num_layers_per_virtual_stage) < num_model_chunks):
pipeline_stage_num = num_layers // int(num_layers_per_virtual_stage)
else:
num_model_chunks = 1
comm_pp_num = 2
comm_parallel_ratio = (rank_num - 1) / rank_num
comm_recv_send = ((pipeline_stage_num - 1) * seq_length *
micro_batch_size * hidden_size) * comm_pp_num
if tensor_model_parallel_size == 1:
comm_broadcast = 4 * seq_length
comm_vocab_parallel_ce = 2 * 3 * seq_length
comm_pp_amount = (
comm_broadcast + comm_vocab_parallel_ce
) * comm_parallel_ratio * num_model_chunks + comm_recv_send
else:
comm_pp_amount = comm_recv_send
comm_pp_groups_amount = micro_batches * comm_pp_amount
comm_non_overlap_groups_amount = comm_pp_groups_amount
if (num_layers_per_virtual_stage is not None
and int(num_layers_per_virtual_stage) < num_model_chunks
and micro_batches > pipeline_model_parallel_size):
comm_non_overlap = 0
comm_non_overlap_stage = (seq_length * micro_batch_size *
hidden_size) * comm_pp_num
total_num_microbatches = num_model_chunks * micro_batches
for rankid in range(rank_num):
num_warmup_microbatches = (pipeline_model_parallel_size - rankid -
1) * 2
num_warmup_microbatches += (num_model_chunks -
1) * pipeline_model_parallel_size
num_warmup_microbatches = min(num_warmup_microbatches,
total_num_microbatches)
comm_non_overlap += num_warmup_microbatches * comm_non_overlap_stage
if tensor_model_parallel_size == 1:
comm_non_overlap_groups_amount = (
comm_broadcast + comm_vocab_parallel_ce
) * comm_parallel_ratio * num_model_chunks * micro_batches + comm_non_overlap
else:
comm_non_overlap_groups_amount = comm_non_overlap
g = RankGenerate()
pp_group_ranks = g.get_ranks('pp')
pp_groups_ips = None
num_pp_groups = world_size // rank_num
return ParallelCommDomain(pp_groups_ips, pp_group_ranks, world_size, 'pp',
comm_pp_groups_amount * num_pp_groups,
comm_non_overlap_groups_amount * num_pp_groups)
def get_data_parallel_comm_domain():
world_size = dist.get_world_size()
args = get_args()
rank_num = int(args.data_parallel_size)
hidden_size = int(args.hidden_size)
num_layers = int(args.num_layers)
kv_channels = int(args.kv_channels)
num_attention_heads = int(args.num_attention_heads)
ffn_hidden_size = int(args.ffn_hidden_size)
padded_vocab_size = int(args.padded_vocab_size)
num_experts = int(args.num_experts) if getattr(args, 'num_experts',
False) else 1
pipeline_model_parallel_size = int(args.pipeline_model_parallel_size)
tensor_model_parallel_size = int(args.tensor_model_parallel_size)
num_query_groups = (int(args.num_query_groups) if getattr(
args, 'num_query_groups', False) else num_attention_heads)
num_layers_per_virtual_stage = getattr(
args, 'num_layers_per_virtual_pipeline_stage', None)
use_distributed_optimizer = getattr(args, 'use_distributed_optimizer',
False)
overlap_grad_reduce = getattr(args, 'overlap_grad_reduce', False)
overlap_param_gather = getattr(args, 'overlap_param_gather', False)
group_query_attention = getattr(args, 'group_query_attention', False)
if not group_query_attention:
num_query_groups = num_attention_heads
swiglu = getattr(args, 'swiglu', False)
gated_linear_multiplier = 1.5 if swiglu else 1
num_model_chunks = (num_layers // pipeline_model_parallel_size)
untie_embeddings_and_output_weights = getattr(
args, 'untie_embeddings_and_output_weights', False)
query_projection_size = kv_channels * num_attention_heads
query_projection_ratio = query_projection_size / hidden_size
transformer_params = (
2
* num_layers * hidden_size ** 2 *
((1 + (num_query_groups / num_attention_heads)) *
query_projection_ratio
+ (ffn_hidden_size / hidden_size) * num_experts *
gated_linear_multiplier
+ 2 / hidden_size
+ 1 / (num_layers * hidden_size)
))
embedding_size = hidden_size * padded_vocab_size
pp_size = pipeline_model_parallel_size
tp_size = tensor_model_parallel_size
pipeline_stage_num = pp_size
if (num_layers_per_virtual_stage is not None
and int(num_layers_per_virtual_stage) < num_model_chunks):
pipeline_stage_num = num_layers // int(num_layers_per_virtual_stage)
if untie_embeddings_and_output_weights:
embedding_params = 2 * embedding_size
else:
embedding_params = embedding_size
num_total_parameters = (transformer_params / pp_size +
embedding_params) / tp_size
comm_dp_num = 2
comm_ratio = (rank_num - 1) / rank_num
comm_dp_amount = comm_dp_num * num_total_parameters * comm_ratio
comm_dp_groups_amount = comm_dp_amount * tp_size * pipeline_stage_num
comm_non_overlap_groups_amount = comm_dp_groups_amount
if (use_distributed_optimizer and overlap_grad_reduce):
comm_non_overlap_groups_amount -= (comm_dp_groups_amount * 0.5)
if (use_distributed_optimizer and overlap_param_gather):
comm_non_overlap_groups_amount -= (comm_dp_groups_amount * 0.5)
g = RankGenerate()
dp_group_ranks = g.get_ranks('dp')
dp_groups_ips = None
num_dp_groups = world_size // rank_num
return ParallelCommDomain(dp_groups_ips, dp_group_ranks, world_size, 'dp',
comm_dp_groups_amount * num_dp_groups,
comm_non_overlap_groups_amount * num_dp_groups)
def get_context_parallel_comm_domain():
world_size = dist.get_world_size()
args = get_args()
rank_num = int(args.context_parallel_size)
cp_groups_ips = None
g = RankGenerate()
cp_group_ranks = g.get_ranks('cp')
seq_length = args.seq_length
hidden_size = args.hidden_size
micro_batch_size = args.micro_batch_size
num_layers = args.num_layers
if args.context_parallel_algo == 'megatron_cp_algo':
comm_per_layer = 0
comm_per_layer += seq_length * micro_batch_size * hidden_size * (
rank_num - 1) / rank_num * 2
comm_per_layer += seq_length * micro_batch_size * hidden_size * (
rank_num - 1) / rank_num * 2 * 2
comm_cp_groups_amount = comm_per_layer * num_layers
comm_non_overlap_groups_amount = 0
elif args.context_parallel_algo == 'ulysses_cp_algo':
comm_per_layer = (
rank_num -
1) / rank_num * seq_length * micro_batch_size * hidden_size * 3
comm_cp_groups_amount = comm_per_layer * num_layers * 2
comm_non_overlap_groups_amount = comm_cp_groups_amount
else:
ring_degree = rank_num // args.ulysses_degree_in_cp
fix_sub_seq_length = seq_length // ring_degree
ulysses_comm_per_layer = (
rank_num - 1
) / rank_num * fix_sub_seq_length * micro_batch_size * hidden_size * 3 * 2
ring_amount_per_layer = 0
ring_comm_amount_per_layer = 0
ring_comm_amount_per_layer += seq_length * micro_batch_size * hidden_size * (
rank_num - 1) / rank_num * 2
ring_comm_amount_per_layer += seq_length * micro_batch_size * hidden_size * (
rank_num - 1) / rank_num * 2 * 2
ring_comm_amount_per_layer /= args.ulysses_degree_in_cp
comm_per_layer = ulysses_comm_per_layer + ring_comm_amount_per_layer
comm_cp_groups_amount = comm_per_layer * num_layers
comm_non_overlap_groups_amount = ulysses_comm_per_layer * num_layers
return ParallelCommDomain(cp_groups_ips, cp_group_ranks, world_size, 'cp',
comm_cp_groups_amount,
comm_non_overlap_groups_amount)
def get_expert_parallel_comm_domain():
world_size = dist.get_world_size()
args = get_args()
rank_num = int(args.expert_model_parallel_size)
num_ep_groups = args.data_parallel_size // rank_num
ep_groups_ips = None
g = RankGenerate()
ep_group_ranks = g.get_ranks('ep', independent_ep=True)
num_ep_groups = world_size // rank_num
topk = args.moe_router_topk
seq_length = args.seq_length
hidden_size = args.hidden_size
micro_batch_size = args.micro_batch_size
num_layers = args.num_layers
if args.moe_token_dispatcher_type == "alltoall":
num_tokens = micro_batch_size * seq_length * topk
ep_comm_per_layer = num_tokens * (
rank_num - 1) / rank_num * hidden_size * rank_num * 2
comm_ep_groups_amount = ep_comm_per_layer * num_layers
comm_non_overlap_groups_amount = comm_ep_groups_amount
elif args.moe_token_dispatcher_type == "allgather":
num_tokens = micro_batch_size * seq_length * topk
ep_comm_per_layer = num_tokens * (rank_num -
1) * hidden_size * rank_num * 2
comm_ep_groups_amount = ep_comm_per_layer * num_layers
comm_non_overlap_groups_amount = comm_ep_groups_amount
return ParallelCommDomain(ep_groups_ips, ep_group_ranks, world_size, 'ep',
comm_ep_groups_amount * num_ep_groups,
comm_non_overlap_groups_amount * num_ep_groups)
def get_overlap_time_dict():
time_overlap = {}
keys = [
(x, y)
for x in domains
for y in domains
]
for key in keys:
time_overlap[key] = 0
args = get_args()
time_overlap[('pp', 'tp')] = 1
time_overlap[('pp', 'dp')] = 1
time_overlap[('pp', 'cp')] = 1
time_overlap[('pp', 'ep')] = 1
time_overlap[('tp', 'pp')] = 1
time_overlap[('dp', 'pp')] = 1
time_overlap[('cp', 'pp')] = 1
time_overlap[('ep', 'pp')] = 1
if args.overlap_grad_reduce or args.overlap_param_gather:
time_overlap[('dp', 'tp')] = 1
time_overlap[('dp', 'pp')] = 1
time_overlap[('dp', 'cp')] = 1
time_overlap[('dp', 'ep')] = 1
time_overlap[('tp', 'dp')] = 1
time_overlap[('pp', 'dp')] = 1
time_overlap[('cp', 'dp')] = 1
time_overlap[('ep', 'dp')] = 1
return time_overlap
def get_overlap_space_dict(domain_partition_information, link_type="SDMA"):
boundary_roce_910b = 8
boundary_roce_910_93 = os.environ.get('SuperNodeDieNum', 384)
if is_a3_version:
if link_type == "SDMA":
cross_boundary = []
for domain in domains:
cross_flag = is_adjacent_two_node_group(domain_partition_information[domain])
if not cross_flag:
cross_boundary.append(domain)
return overlap_space_padding(cross_boundary)
elif link_type == "ROCE":
cross_boundary = []
for domain in domains:
cross_flag = is_cross_boundary(domain_partition_information[domain], boundary_roce_910_93)
if cross_flag:
cross_boundary.append(domain)
return overlap_space_padding(cross_boundary)
else:
raise ValueError(f"Unsupported link type: {link_type}, only 'SDMA' and 'ROCE' are supported")
else:
cross_boundary = []
for domain in domains:
cross_flag = is_cross_boundary(domain_partition_information[domain], boundary_roce_910b)
if cross_flag:
cross_boundary.append(domain)
return overlap_space_padding(cross_boundary)
def overlap_space_padding(cross_boundary):
space_overlap = {}
keys = [
(x, y)
for x in domains
for y in domains
]
for key in keys:
space_overlap[key] = 0
if not cross_boundary or len(cross_boundary) == 1:
return space_overlap
for domain_a in cross_boundary:
for domain_b in cross_boundary:
if domain_a != domain_b:
space_overlap[(domain_a, domain_b)] = 1
return space_overlap
def is_cross_boundary(comm_domains: list[list[int]], boundary: int) -> bool:
if boundary <= 0:
raise ValueError("Boundary value must be a positive integer")
for domain in comm_domains:
if not domain:
continue
machine_ids = {rank // boundary for rank in domain}
if len(machine_ids) > 1:
return True
return False
def is_adjacent_two_node_group(rank_groups):
for group in rank_groups:
if not isinstance(group, list) or len(group) != 2:
return False
if not (isinstance(group[0], int) and isinstance(group[1], int)):
return False
if abs(group[1] - group[0]) != 1:
return False
if group[0] % 2 != 0:
return False
return True