from typing import List, Tuple
import torch
# Open a FRAGMENT of the "fsdp" operator namespace. No op schemas are defined in
# this file; "chunk_cat", "all_gather_copy_in" and "split_with_sizes_copy" are
# assumed to be declared elsewhere (presumably by torch.distributed's FSDP).
# This fragment only attaches PrivateUse1 kernels to them.
lib = torch.library.Library("fsdp", "FRAGMENT")
@torch.library.impl(lib, "chunk_cat", "PrivateUse1")
def chunk_cat(
    tensors: List[torch.Tensor],
    dim: int,
    num_chunks: int,
    out: torch.Tensor,
) -> None:
    """PrivateUse1 kernel for ``fsdp::chunk_cat``.

    Delegates to ``torch._chunk_cat`` with every input made contiguous and
    writes the result into ``out`` in place.

    Args:
        tensors: Input tensors; each is made contiguous before the call.
        dim: Dimension forwarded to ``torch._chunk_cat``.
        num_chunks: Chunk count forwarded to ``torch._chunk_cat``.
        out: Destination tensor. It is always updated in place, including
            when it is not contiguous.
    """
    # Inputs are made contiguous before the call (kept from the original;
    # presumably the backend kernel requires contiguous inputs).
    tensors = [tensor.contiguous() for tensor in tensors]
    if out.is_contiguous():
        torch._chunk_cat(tensors, dim, num_chunks, out=out)
        return
    # For a non-contiguous ``out``, ``out.contiguous()`` returns a *new*
    # tensor. Writing only into that copy would leave the caller's ``out``
    # untouched. Compute into a contiguous staging buffer, then copy back.
    staging = torch.empty_like(out, memory_format=torch.contiguous_format)
    torch._chunk_cat(tensors, dim, num_chunks, out=staging)
    out.copy_(staging)
@torch.library.impl(lib, "all_gather_copy_in", "PrivateUse1")
def all_gather_copy_in_npu(
    all_gather_inputs: List[torch.Tensor],
    inp_split_sizes: List[int],
    all_gather_input_numel: int,
    world_size: int,
    rank: int,
    dtype: torch.dtype,
    device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """PrivateUse1 kernel for ``fsdp::all_gather_copy_in``.

    Allocates a flat buffer of ``all_gather_input_numel * world_size``
    elements and copies this rank's inputs into that buffer's
    ``rank``-th slice.

    Args:
        all_gather_inputs: Tensors copied, in order, into this rank's slice.
        inp_split_sizes: Number of elements of the slice assigned to each
            input; passed to ``torch.split``.
        all_gather_input_numel: Size of one rank's slice.
        world_size: Number of slices in the output buffer.
        rank: Index of the slice this rank fills.
        dtype: dtype of the allocated buffer.
        device: Device of the allocated buffer.

    Returns:
        ``(all_gather_input, all_gather_output)``: this rank's slice (a view
        into the buffer) and the full flat buffer.
    """
    all_gather_output = torch.empty(
        (all_gather_input_numel * world_size,), dtype=dtype, device=device
    )
    all_gather_input = all_gather_output.narrow(
        0, all_gather_input_numel * rank, all_gather_input_numel
    )
    # Nothing to copy: the original indexed element [0] here and raised
    # IndexError, and foreach ops reject empty tensor lists.
    if not all_gather_inputs:
        return all_gather_input, all_gather_output
    foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes)
    # Request a non-blocking copy only when *every* source already lives on
    # the destination device. The original checked only the first input, so
    # a later input on another device (e.g. host memory) could still get a
    # non-blocking cross-device copy. Comparing against the real tensor
    # device, not ``device``, avoids index-less device mismatches.
    dst_device = all_gather_output.device
    same_device = all(src.device == dst_device for src in all_gather_inputs)
    with torch.no_grad():
        if same_device:
            torch._foreach_copy_(
                foreach_copy_dsts, all_gather_inputs, non_blocking=True
            )
        else:
            torch._foreach_copy_(foreach_copy_dsts, all_gather_inputs)
    return all_gather_input, all_gather_output
@torch.library.impl(lib, "split_with_sizes_copy", "PrivateUse1")
def split_with_sizes_copy(
    all_gather_output: torch.Tensor,
    all_gather_input_split_sizes: List[int],
    dim: int,
    out: List[torch.Tensor],
) -> None:
    """PrivateUse1 kernel for ``fsdp::split_with_sizes_copy``.

    Splits ``all_gather_output`` along ``dim`` into pieces of the given sizes
    and writes them into the tensors of ``out`` in place, by forwarding to
    ``torch.split_with_sizes_copy``.
    """
    copy_splits = torch.split_with_sizes_copy
    copy_splits(all_gather_output, all_gather_input_split_sizes, dim=dim, out=out)