_FSDP2_EP_RANK = None
_FSDP2_EP_SIZE = None
_FSDP2_CHECK_MOE_FUNC = None
_FSDP2_EP_GROUP = None
_FSDP2_EP_FSDP_GROUP = None
def get_ep_rank():
global _FSDP2_EP_RANK
if _FSDP2_EP_RANK:
return _FSDP2_EP_RANK
else:
return 0
def get_ep_size():
global _FSDP2_EP_SIZE
if _FSDP2_EP_SIZE:
return _FSDP2_EP_SIZE
else:
return 1
def get_ep_group():
global _FSDP2_EP_GROUP
return _FSDP2_EP_GROUP
def get_ep_fsdp_group():
global _FSDP2_EP_FSDP_GROUP
return _FSDP2_EP_FSDP_GROUP
def get_check_moe_func():
global _FSDP2_CHECK_MOE_FUNC
if _FSDP2_CHECK_MOE_FUNC:
return _FSDP2_CHECK_MOE_FUNC
else:
return lambda _: False
def set_ep_rank(ep_rank):
global _FSDP2_EP_RANK
_FSDP2_EP_RANK = ep_rank
def set_ep_size(ep_size):
global _FSDP2_EP_SIZE
_FSDP2_EP_SIZE = ep_size
def set_check_moe_func(fn):
global _FSDP2_CHECK_MOE_FUNC
_FSDP2_CHECK_MOE_FUNC = fn
def set_ep_group(group):
global _FSDP2_EP_GROUP
_FSDP2_EP_GROUP = group
def set_ep_fsdp_group(group):
global _FSDP2_EP_FSDP_GROUP
_FSDP2_EP_FSDP_GROUP = group