import builtins
import os
import sys
import torch
import torch.distributed as dist
from torch.autograd import Function
from torch.autograd.profiler import record_function
from torch.nn.parallel import DistributedDataParallel as DDP
# Optional communication extensions: each name is bound to the imported
# module when available and to False when the import fails.
try:
    import torch_ccl
except ImportError:
    torch_ccl = False

try:
    import torch_ucc
except ImportError:
    torch_ucc = False

# Process topology; -1 means init_distributed() has not run yet.
my_rank = my_size = my_local_rank = my_local_size = -1

# Set to True by init_distributed() once the all_to_all_single probe succeeds.
alltoall_supported = False

# Requested all-to-all implementation; empty string means "not specified".
a2a_impl = os.getenv("DLRM_ALLTOALL_IMPL", "")

# Shared Request instance, created by init_distributed().
myreq = None
def env2int(env_list, default=-1):
    """Return the first non-negative integer found among environment variables.

    Args:
        env_list: Iterable of environment variable names, checked in order.
        default: Value returned when no listed variable yields a
            non-negative integer.

    Returns:
        The integer value of the first variable that is set and >= 0,
        otherwise ``default``.

    Raises:
        ValueError: If a listed variable is set to text ``int()`` rejects.
    """
    for name in env_list:
        raw = os.environ.get(name)
        if raw is None:
            # Unset variables are skipped, just like negative values.
            continue
        parsed = int(raw)
        if parsed >= 0:
            return parsed
    return default
def get_my_slice(n):
    """Return the ``slice`` of ``range(n)`` assigned to this rank.

    ``n`` items are divided over ``my_size`` ranks; the first ``n % my_size``
    ranks each receive one extra item. Reads the module-level ``my_rank``
    and ``my_size``.
    """
    base, extra = divmod(n, my_size)
    start = my_rank * base + min(my_rank, extra)
    stop = (my_rank + 1) * base + min(my_rank + 1, extra)
    return slice(start, stop, 1)
def get_split_lengths(n):
    """Describe how ``n`` items are divided over ``my_size`` ranks.

    Returns:
        A ``(my_len, splits)`` tuple. When ``n`` divides evenly, ``splits``
        is ``None`` and ``my_len`` is the common share. Otherwise ``splits``
        lists every rank's share (the first ``n % my_size`` ranks get one
        extra item) and ``my_len`` is the entry for ``my_rank``.
    """
    base, extra = divmod(n, my_size)
    if extra:
        splits = [base + 1] * extra + [base] * (my_size - extra)
        return (splits[my_rank], splits)
    return (base, None)
def init_distributed(rank=-1, local_rank=-1, size=-1, use_npu=False, backend=""):
    """Initialise the module-level distributed state and the shared request.

    With ``size > 1`` the default process group is created through
    ``dist.init_process_group`` and ``dist.all_to_all_single`` is probed with
    a four-element tensor to decide whether it can be used. Otherwise the
    module is put into single-process mode (rank 0 of 1) and no process
    group is created. In both cases the topology is printed via
    ``print_all`` and ``myreq`` is rebound to a fresh ``Request``.

    Args:
        rank: Accepted but never read by this function; ``local_rank`` is
            what is passed to the process group as the rank.
        local_rank: Stored as both ``my_local_rank`` and ``my_rank``; when
            ``use_npu`` is true it is also the device index handed to
            ``torch.npu.set_device``.
        size: World size; stored as both ``my_size`` and ``my_local_size``.
        use_npu: When true, check the NPU count, select the device, and
            move the probe tensor to the NPU.
        backend: Backend name forwarded to ``dist.init_process_group``.

    Side effects:
        Rebinds ``my_rank``, ``my_size``, ``my_local_rank``,
        ``my_local_size``, ``alltoall_supported``, ``a2a_impl`` and
        ``myreq``. Calls ``sys.exit(1)`` when ``use_npu`` is true and
        ``size`` exceeds ``torch.npu.device_count()``.
    """
    global myreq
    global my_rank, my_size
    global my_local_rank, my_local_size
    global a2a_impl, alltoall_supported

    if not size > 1:
        # Single-process mode: behave as rank 0 in a world of one.
        my_rank, my_size = 0, 1
        my_local_rank, my_local_size = 0, 1
    else:
        my_local_rank, my_local_size = local_rank, size

        if use_npu:
            npu_count = torch.npu.device_count()
            if my_local_size > npu_count:
                print(
                    "Not sufficient NPUs available... local_size = %d, nnpus = %d"
                    % (my_local_size, npu_count)
                )
                sys.exit(1)
            torch.npu.set_device(my_local_rank)

        dist.init_process_group(backend, rank=local_rank, world_size=size)
        my_rank, my_size = local_rank, size
        if my_rank == 0:
            print("Running on %d ranks using %s backend" % (my_size, backend))

        if hasattr(dist, "all_to_all_single"):
            # Probe the primitive once; a RuntimeError means "not supported".
            try:
                probe = torch.zeros([4])
                if use_npu:
                    probe = probe.npu()
                dist.all_to_all_single(probe, probe)
            except RuntimeError as err:
                print("fail to enable all_to_all_single primitive: %s" % err)
            else:
                alltoall_supported = True

        if a2a_impl == "alltoall" and not alltoall_supported:
            print(
                "Requested DLRM_ALLTOALL_IMPL=%s but backend %s does not "
                "support it, use scatter/gather based alltoall"
                % (a2a_impl, backend)
            )
            a2a_impl = "scatter"
        if a2a_impl:
            print("Using DLRM_ALLTOALL_IMPL=%s" % a2a_impl)

    print_all(
        "world size: %d, current rank: %d, local rank: %d"
        % (my_size, my_rank, my_local_rank)
    )
    myreq = Request()
class Request:
    """Holder for an in-flight exchange: handle(s), tensors and wait function.

    ``req`` and ``tensor`` are filled in by the ``*_Req`` autograd functions;
    ``WaitFunction`` is the autograd ``Function`` class whose ``apply``
    completes the exchange.
    """

    def __init__(self):
        self.req = self.tensor = None
        self.WaitFunction = All2All_Scatter_Wait

    def wait(self):
        """Apply ``WaitFunction`` to the stored tensors, then clear the state.

        Returns:
            Whatever ``WaitFunction.apply`` returns.
        """
        wait_fn = self.WaitFunction
        result = wait_fn.apply(*self.tensor)
        self.req = self.tensor = None
        return result
class All2All_ScatterList_Req(Function):
    """Request half of the per-table scatter exchange.

    ``forward`` issues one asynchronous ``dist.scatter`` per (source rank,
    table) pair and stores the handles and receive buffers in ``myreq``.
    ``backward`` waits on whatever handles ``myreq`` currently holds and
    returns the tensors stored there as the input gradients.
    """

    @staticmethod
    def forward(ctx, a2a_info, *inputs):
        # A falsy slice list means an even split, expressed as one int.
        batch_splits = (
            a2a_info.global_batch_partition_slices or a2a_info.local_batch_num
        )
        tables_per_rank = (
            a2a_info.global_table_wise_parition_slices
            or [a2a_info.local_table_num] * my_size
        )

        received = []
        pending = []
        for src_rank in range(my_size):
            for table_idx in range(tables_per_rank[src_rank]):
                recv_buf = inputs[0].new_empty(
                    [a2a_info.local_batch_num, a2a_info.emb_dim]
                )
                # Only the source rank supplies data to scatter.
                if src_rank == my_rank:
                    chunks = list(inputs[table_idx].split(batch_splits, dim=0))
                else:
                    chunks = []
                handle = dist.scatter(
                    recv_buf, chunks, src=src_rank, async_op=True
                )
                received.append(recv_buf)
                pending.append(handle)

        myreq.req = pending
        myreq.tensor = tuple(received)
        myreq.a2a_info = a2a_info
        return myreq.tensor

    @staticmethod
    def backward(ctx, *grad_output):
        for handle in myreq.req:
            handle.wait()
        grads = myreq.tensor
        myreq.req = None
        myreq.tensor = None
        # First slot is the (non-differentiable) a2a_info argument.
        return (None, *grads)
class All2All_ScatterList_Wait(Function):
    """Completion half of the per-table scatter exchange.

    ``forward`` blocks on the handles held by ``myreq`` and passes its
    arguments through unchanged. ``backward`` launches one asynchronous
    ``dist.gather`` per (destination rank, table) pair and leaves the handles
    and destination buffers in ``myreq``.
    """

    @staticmethod
    def forward(ctx, *output):
        ctx.a2a_info = myreq.a2a_info
        for handle in myreq.req:
            handle.wait()
        myreq.req = myreq.tensor = None
        return output

    @staticmethod
    def backward(ctx, *grad_output):
        info = ctx.a2a_info
        contiguous_grads = [g.contiguous() for g in grad_output]

        batch_splits = (
            info.global_batch_partition_slices
            or [info.local_batch_num] * my_size
        )
        tables_per_rank = (
            info.global_table_wise_parition_slices
            or [info.local_table_num] * my_size
        )

        # One destination buffer per table in ``local_table_num``.
        grad_inputs = [
            contiguous_grads[0].new_empty([info.batch_size, info.emb_dim])
            for _ in range(info.local_table_num)
        ]

        pending = []
        for dst_rank in range(my_size):
            for table_idx in range(tables_per_rank[dst_rank]):
                if dst_rank == my_rank:
                    recv_views = list(
                        grad_inputs[table_idx].split(batch_splits, dim=0)
                    )
                else:
                    recv_views = None
                # Gradients are consumed in issue order: one per gather so far.
                outgoing = contiguous_grads[len(pending)]
                pending.append(
                    dist.gather(outgoing, recv_views, dst=dst_rank, async_op=True)
                )

        myreq.req = pending
        myreq.tensor = grad_inputs
        return tuple(contiguous_grads)
class All2All_Scatter_Req(Function):
    """Request half of the scatter-based exchange (one scatter per rank).

    ``forward`` concatenates the inputs along dim 1, splits the result along
    dim 0, and issues one asynchronous ``dist.scatter`` per source rank.
    ``backward`` waits on the handles held by ``myreq`` and splits the stored
    gradient tensor back into chunks of width ``emb_dim``.
    """

    @staticmethod
    def forward(ctx, a2a_info, *inputs):
        # A falsy slice list means an even split, expressed as one int.
        batch_splits = (
            a2a_info.global_batch_partition_slices or a2a_info.local_batch_num
        )
        tables_per_rank = (
            a2a_info.global_table_wise_parition_slices
            or [a2a_info.local_table_num] * my_size
        )

        merged = torch.cat(inputs, dim=1)
        send_chunks = list(merged.split(batch_splits, dim=0))

        received = []
        pending = []
        for src_rank in range(my_size):
            width = tables_per_rank[src_rank] * a2a_info.emb_dim
            recv_buf = merged.new_empty([a2a_info.local_batch_num, width])
            # Only the source rank supplies data to scatter.
            payload = send_chunks if src_rank == my_rank else []
            handle = dist.scatter(recv_buf, payload, src=src_rank, async_op=True)
            received.append(recv_buf)
            pending.append(handle)

        myreq.req = pending
        myreq.tensor = tuple(received)
        myreq.a2a_info = a2a_info
        ctx.a2a_info = a2a_info
        return myreq.tensor

    @staticmethod
    def backward(ctx, *grad_output):
        for handle in myreq.req:
            handle.wait()
        myreq.req = None
        per_table_grads = myreq.tensor.split(ctx.a2a_info.emb_dim, dim=1)
        myreq.tensor = None
        # First slot is the (non-differentiable) a2a_info argument.
        return (None, *per_table_grads)
class All2All_Scatter_Wait(Function):
    """Completion half of the scatter-based exchange.

    ``forward`` blocks on the handles held by ``myreq`` and passes its
    arguments through unchanged. ``backward`` issues one asynchronous
    ``dist.gather`` per destination rank and leaves the handles and the
    destination buffer in ``myreq``.
    """

    @staticmethod
    def forward(ctx, *output):
        ctx.a2a_info = myreq.a2a_info
        for handle in myreq.req:
            handle.wait()
        myreq.req = None
        myreq.tensor = None
        return output

    @staticmethod
    def backward(ctx, *grad_output):
        # One incoming gradient is expected per rank.
        assert len(grad_output) == my_size
        send_grads = [g.contiguous() for g in grad_output]
        a2a_info = ctx.a2a_info

        # A falsy slice list means an even split, expressed as one int.
        batch_split_lengths = (
            a2a_info.global_batch_partition_slices
            if a2a_info.global_batch_partition_slices
            else a2a_info.local_batch_num
        )

        grad_input = grad_output[0].new_empty(
            [a2a_info.batch_size, a2a_info.emb_dim * a2a_info.local_table_num]
        )
        # Views into grad_input, so gathered data lands in it directly.
        recv_views = list(grad_input.split(batch_split_lengths, dim=0))

        req_list = []
        for dst_rank in range(my_size):
            req = dist.gather(
                send_grads[dst_rank],
                recv_views if dst_rank == my_rank else [],
                dst=dst_rank,
                async_op=True,
            )
            req_list.append(req)

        myreq.req = req_list
        myreq.tensor = grad_input
        return grad_output
class All2All_Req(Function):
    """Request half of the exchange built on ``dist.all_to_all_single``.

    ``forward`` flattens the concatenated inputs, starts one asynchronous
    ``all_to_all_single`` and records the handle, output buffer and scaled
    split lengths. ``backward`` waits on the handle held by ``myreq`` and
    reshapes the stored gradient into chunks of width ``emb_dim``.
    """

    @staticmethod
    def forward(ctx, a2a_info, *inputs):
        with record_function("DLRM alltoall_req_fwd_single"):
            # Split lengths are converted from rows/tables to element counts;
            # falsy values are passed through unchanged.
            batch_split_lengths = a2a_info.global_batch_partition_slices
            if batch_split_lengths:
                elems_per_row = a2a_info.emb_dim * a2a_info.local_table_num
                batch_split_lengths = [
                    rows * elems_per_row for rows in batch_split_lengths
                ]

            table_split_lengths = a2a_info.global_table_wise_parition_slices
            if table_split_lengths:
                elems_per_table = a2a_info.local_batch_num * a2a_info.emb_dim
                table_split_lengths = [
                    tables * elems_per_table for tables in table_split_lengths
                ]

            flat_input = torch.cat(inputs, dim=1).view([-1])
            total_out = (
                a2a_info.global_table_num
                * a2a_info.local_batch_num
                * a2a_info.emb_dim
            )
            flat_output = flat_input.new_empty([total_out])

            myreq.req = dist.all_to_all_single(
                flat_output,
                flat_input,
                table_split_lengths,
                batch_split_lengths,
                async_op=True,
            )
            myreq.tensor = (flat_output,)

            # Stash the scaled lengths for the wait function.
            a2a_info.batch_split_lengths = batch_split_lengths
            a2a_info.table_split_lengths = table_split_lengths
            myreq.a2a_info = a2a_info
            ctx.a2a_info = a2a_info
            return myreq.tensor

    @staticmethod
    def backward(ctx, *grad_output):
        with record_function("DLRM alltoall_req_bwd_single"):
            info = ctx.a2a_info
            myreq.req.wait()
            myreq.req = None
            flat_grad = myreq.tensor
            chunks = flat_grad.view([info.batch_size, -1]).split(
                info.emb_dim, dim=1
            )
            per_table_grads = [chunk.contiguous() for chunk in chunks]
            myreq.tensor = None
            # First slot is the (non-differentiable) a2a_info argument.
            return (None, *per_table_grads)
class All2All_Wait(Function):
    """Completion half of the exchange built on ``dist.all_to_all_single``.

    ``forward`` blocks on the handle held by ``myreq`` and splits the flat
    output into 2-D tensors with ``local_batch_num`` rows. ``backward``
    flattens the incoming gradients, starts the reverse
    ``all_to_all_single`` and leaves the handle and destination buffer in
    ``myreq``.
    """

    @staticmethod
    def forward(ctx, *output):
        with record_function("DLRM alltoall_wait_fwd_single"):
            info = myreq.a2a_info
            ctx.a2a_info = info
            myreq.req.wait()
            myreq.req = myreq.tensor = None

            # Fall back to a uniform chunk size when no lengths were recorded.
            chunk_lengths = info.table_split_lengths or (
                info.local_table_num * info.local_batch_num * info.emb_dim
            )
            flat_chunks = output[0].split(chunk_lengths)
            return tuple(
                chunk.view([info.local_batch_num, -1]) for chunk in flat_chunks
            )

    @staticmethod
    def backward(ctx, *grad_outputs):
        with record_function("DLRM alltoall_wait_bwd_single"):
            info = ctx.a2a_info
            flat_parts = [g.contiguous().view([-1]) for g in grad_outputs]
            grad_output = torch.cat(flat_parts)

            total_in = info.batch_size * info.local_table_num * info.emb_dim
            grad_input = grad_output.new_empty([total_in])

            # Split roles are swapped relative to the forward exchange.
            myreq.req = dist.all_to_all_single(
                grad_input,
                grad_output,
                info.batch_split_lengths,
                info.table_split_lengths,
                async_op=True,
            )
            myreq.tensor = grad_input
            return (grad_output,)
class AllGather(Function):
    """Autograd-aware ``dist.all_gather`` along an arbitrary dimension.

    ``forward`` gathers ``input`` from every rank and joins the pieces along
    ``dim``. ``backward`` returns the slice of the incoming gradient that
    corresponds to this rank's contribution.
    """

    @staticmethod
    def forward(ctx, input, global_lengths, dim=0):
        """Gather ``input`` from all ranks along ``dim``.

        Args:
            input: Local tensor; made contiguous before communication.
            global_lengths: Size along ``dim`` on each rank, as a list or
                tuple with one entry per rank, or a single value that is
                replicated for every rank.
            dim: Dimension along which the pieces are joined.

        Returns:
            The gathered tensor.
        """
        if not isinstance(global_lengths, (list, tuple)):
            global_lengths = [global_lengths] * my_size
        assert len(global_lengths) == my_size
        assert global_lengths[my_rank] == input.size(dim)

        # Saved so backward can cut this rank's piece out of the gradient.
        ctx.dim = dim
        ctx.local_start = sum(global_lengths[:my_rank])
        ctx.local_length = global_lengths[my_rank]

        output_size = list(input.size())
        input = input.contiguous()

        if dim == 0:
            # Gather straight into views of the final output buffer.
            output_size[dim] = sum(global_lengths)
            output = input.new_empty(output_size)
            gather_list = list(output.split(global_lengths, dim=0))
            dist.all_gather(gather_list, input)
            return output

        # For other dims, receive into one buffer per rank (sized by that
        # rank's length) and concatenate afterwards.
        gather_list = []
        for length in global_lengths:
            output_size[dim] = length
            gather_list.append(input.new_empty(output_size))
        dist.all_gather(gather_list, input)
        return torch.cat(gather_list, dim=dim)

    @staticmethod
    def backward(ctx, grad_output):
        """Return this rank's slice of ``grad_output``; other inputs get None."""
        grad_input = grad_output.narrow(ctx.dim, ctx.local_start, ctx.local_length)
        return (grad_input, None, None)
class All2AllInfo:
    """Plain attribute bag describing one all-to-all exchange.

    Instances start empty; ``alltoall`` and the autograd functions attach
    the fields they need as ordinary attributes.
    """
def alltoall(inputs, per_rank_table_splits):
    """Start an asynchronous all-to-all exchange of embedding outputs.

    Builds an ``All2AllInfo`` from the shape of ``inputs[0]`` and the given
    table splits, launches the exchange with the implementation selected by
    ``a2a_impl`` / ``alltoall_supported``, and sets the matching wait
    function on the shared request.

    Args:
        inputs: Sequence of 2-D tensors; ``inputs[0].size()`` must unpack
            into ``(batch_size, emb_dim)``.
        per_rank_table_splits: Number of tables per rank, or a falsy value
            meaning every rank holds ``len(inputs)`` tables.

    Returns:
        The module-level request object; call its ``wait()`` to obtain the
        exchanged tensors.

    Raises:
        ValueError: If ``DLRM_ALLTOALL_IMPL`` names an unknown
            implementation. No exchange is started in that case, so handing
            back the stale shared request would only fail later in
            ``wait()`` with an unrelated error.
    """
    batch_size, emb_dim = inputs[0].size()

    a2a_info = All2AllInfo()
    a2a_info.local_table_num = len(inputs)
    a2a_info.global_table_wise_parition_slices = per_rank_table_splits
    (
        a2a_info.local_batch_num,
        a2a_info.global_batch_partition_slices,
    ) = get_split_lengths(batch_size)
    a2a_info.emb_dim = emb_dim
    a2a_info.batch_size = batch_size
    a2a_info.global_table_num = (
        sum(per_rank_table_splits)
        if per_rank_table_splits
        else a2a_info.local_table_num * my_size
    )

    # The *_Req functions publish their results through ``myreq``, so the
    # value returned by ``apply`` is not needed here.
    if (a2a_impl == "" and alltoall_supported) or a2a_impl == "alltoall":
        All2All_Req.apply(a2a_info, *inputs)
        myreq.WaitFunction = All2All_Wait
    elif a2a_impl in ("", "scatter"):
        All2All_Scatter_Req.apply(a2a_info, *inputs)
        myreq.WaitFunction = All2All_Scatter_Wait
    elif a2a_impl == "scatter_list":
        All2All_ScatterList_Req.apply(a2a_info, *inputs)
        myreq.WaitFunction = All2All_ScatterList_Wait
    else:
        raise ValueError(
            "Unknown value set for DLRM_ALLTOALL_IMPL (%s), "
            "please use one of [alltoall, scatter, scatter_list]" % a2a_impl
        )
    return myreq
def all_gather(input, lengths, dim=0):
    """Gather ``input`` from every rank along ``dim`` via ``AllGather``.

    Args:
        input: Local tensor to contribute.
        lengths: Per-rank sizes along ``dim``; a falsy value means every
            rank contributes the same size as this rank.
        dim: Dimension along which the pieces are joined.

    Returns:
        The result of ``AllGather.apply``.
    """
    if not lengths:
        # The default must be measured along ``dim``: AllGather.forward
        # asserts that this rank's length equals ``input.size(dim)``.
        lengths = [input.size(dim)] * my_size
    return AllGather.apply(input, lengths, dim)
def barrier():
    """Call ``dist.barrier()`` when ``my_size`` is greater than 1.

    Does nothing otherwise.
    """
    if my_size <= 1:
        return
    dist.barrier()
# Keep a handle on the original ``print`` before it is replaced below.
orig_print = builtins.print


def rank0_print(*args, **kwargs):
    """Print only when ``my_rank <= 0``, unless ``print_all=True`` is passed.

    Accepts the same arguments as the original ``print`` plus an optional
    ``print_all`` keyword that forces output regardless of rank.
    """
    # ``print_all`` must be removed before forwarding: the built-in print
    # rejects unknown keyword arguments with a TypeError.
    force = kwargs.pop("print_all", False)
    if my_rank <= 0 or force:
        orig_print(*args, **kwargs)


# From here on, every plain ``print`` call in the process is rank-filtered.
builtins.print = rank0_print


def print_all(*args, **kwargs):
    """Print unconditionally, bypassing the rank filter."""
    orig_print(*args, **kwargs)