from abc import ABC, abstractmethod
from mindspeed.te.pytorch.fp8 import MatmulKey
from mindspeed.te.pytorch.module_typing import FP8Metadata
class CommOverlapConfig:
@classmethod
def get_tp_size(cls):
if cls.tp_size is None:
from megatron.core.parallel_state import get_tensor_model_parallel_world_size
cls.tp_size = get_tensor_model_parallel_world_size()
return cls.tp_size
@classmethod
def init_tp_size(cls, tp_size):
cls.tp_size = tp_size
@classmethod
def get_tp_group(cls):
if cls.tp_group is None:
from megatron.core.parallel_state import get_tensor_model_parallel_group
cls.tp_group = get_tensor_model_parallel_group()
return cls.tp_group
@classmethod
def init_tp_group(cls, tp_group):
cls.tp_group = tp_group
tp_size = None
tp_group = None
save_allgather_input = True
parallel_num = 2
class CommOverlapOps(ABC):
@staticmethod
@abstractmethod
def allgather_matmul(input_, weight, bias, fp8_meta: FP8Metadata, key: MatmulKey):
...
@staticmethod
@abstractmethod
def matmul_reduce_scatter(input_, weight, bias, fp8_meta: FP8Metadata, key: MatmulKey):
...
@staticmethod
@abstractmethod
def matmul_all_reduce(input_, weight, bias, fp8_meta: FP8Metadata, key: MatmulKey, fp8_enable: bool):
...
COMM_OVERLAP_CONFIG = CommOverlapConfig()