# 386d82fd, created on 2025-04-21 (commit history)
import torch.distributed as dist
import xfuser
import torch

# Module-level switch for LLM tensor parallelism. It starts disabled and is
# flipped to True by enable_llm_tensor_model_parallel(); the
# get_llm_tensor_model_parallel_* helpers read it to decide what to return.
_LLM_TP_ENABLE = False


def initialize_parall_group(ring_degree, ulysses_degree, tensor_parallel_degree, llm_tensor_parallel_degree=None):
    """Initialise the HCCL process group and xfuser's model-parallel groups.

    Args:
        ring_degree: Ring-attention degree forwarded to xfuser.
        ulysses_degree: Ulysses degree forwarded to xfuser.
        tensor_parallel_degree: Tensor-parallel degree forwarded to xfuser.
        llm_tensor_parallel_degree: Accepted for interface compatibility;
            this function does not use it.
    """
    dist.init_process_group("hccl")
    rank = dist.get_rank()

    xfuser.core.distributed.init_distributed_environment(
        rank=rank,
        world_size=dist.get_world_size(),
        backend="hccl",
    )

    xfuser.core.distributed.initialize_model_parallel(
        # xfuser factors the sequence-parallel group into ring x ulysses, so
        # the total degree must be their product. Passing only ulysses_degree
        # is correct solely when ring_degree == 1.
        sequence_parallel_degree=ring_degree * ulysses_degree,
        ring_degree=ring_degree,
        ulysses_degree=ulysses_degree,
        tensor_parallel_degree=tensor_parallel_degree,
    )

    # NOTE(review): the device index is the *global* rank; on a multi-node
    # launch a node-local rank would presumably be required -- confirm.
    torch.npu.set_device(rank)


def enable_llm_tensor_model_parallel():
    """Turn on LLM tensor parallelism by setting the module flag to True."""
    # `global` is required so the assignment rebinds the module-level flag
    # instead of creating a function-local variable.
    global _LLM_TP_ENABLE
    _LLM_TP_ENABLE = True


def get_llm_tensor_model_parallel_world_size():
    """Return the LLM tensor-parallel world size.

    Returns 1 while LLM tensor parallelism is disabled; otherwise the world
    size of the default ``torch.distributed`` process group.
    """
    if not _LLM_TP_ENABLE:
        return 1
    return dist.get_world_size()
    

def get_llm_tensor_model_parallel_rank():
    """Return this process's rank for LLM tensor parallelism.

    Returns 0 while LLM tensor parallelism is disabled; otherwise the rank
    in the default ``torch.distributed`` process group.
    """
    if not _LLM_TP_ENABLE:
        return 0
    return dist.get_rank()


def get_llm_tensor_model_parallel_group():
    """Return the process group used for LLM tensor parallelism.

    Returns None while LLM tensor parallelism is disabled; otherwise
    ``dist.group.WORLD``.
    """
    if not _LLM_TP_ENABLE:
        return None
    return dist.group.WORLD


def get_parallel_group():
    """Return the world group reported by ``xfuser.core.distributed``."""
    world_group = xfuser.core.distributed.get_world_group()
    return world_group


def get_sequence_parallel_world_size():
    """Return the sequence-parallel world size from xfuser's parallel state."""
    state = xfuser.core.distributed.parallel_state
    return state.get_sequence_parallel_world_size()


def get_sequence_parallel_rank():
    """Return this process's sequence-parallel rank from xfuser's parallel state."""
    state = xfuser.core.distributed.parallel_state
    return state.get_sequence_parallel_rank()


def get_sp_group():
    """Return the sequence-parallel group from xfuser's parallel state."""
    state = xfuser.core.distributed.parallel_state
    return state.get_sp_group()


def parallel_forward(fn_):
    """Decorate a forward method so it can run under sequence parallelism.

    The wrapped callable is invoked as ``fn_(self, hidden_states, *args,
    **kwargs)``. When the ``parallel`` keyword argument is truthy, the
    wrapper:

    1. keeps only this rank's chunk of ``hidden_states`` along ``dim=-2``;
    2. does the same for ``kwargs['attn_mask']`` when a mask is supplied;
    3. calls ``fn_`` on the local chunk;
    4. all-gathers the output along ``dim=-2`` over the sequence-parallel
       group.

    When ``parallel`` is falsy or omitted, ``fn_`` is called unchanged.

    Args:
        fn_: The forward function or method to wrap.

    Returns:
        The wrapping function, carrying ``fn_``'s name and docstring.
    """
    import functools

    @functools.wraps(fn_)
    def wrapper(_, hidden_states, *args, **kwargs):
        # A missing flag is treated as "not parallel" instead of raising
        # KeyError. Evaluated once: fn_ receives its own copy of kwargs and
        # cannot change the wrapper's view of the flag.
        parallel = kwargs.get('parallel', False)

        if parallel:
            world_size = get_sequence_parallel_world_size()
            rank = get_sequence_parallel_rank()
            hidden_states = torch.chunk(hidden_states, world_size, dim=-2)[rank]
            # Only split the mask when the caller actually provided one.
            attn_mask = kwargs.get('attn_mask')
            if attn_mask is not None:
                kwargs['attn_mask'] = torch.chunk(attn_mask, world_size, dim=-2)[rank]

        output = fn_(_, hidden_states, *args, **kwargs)

        if parallel:
            # all_gather is given a contiguous tensor, as in the original.
            output = get_sp_group().all_gather(output.contiguous(), dim=-2)

        return output

    return wrapper