from typing import cast, Callable, Dict, Sequence, Tuple
import torch
from torch.distributed.tensor import DTensor
from torch.distributed.tensor._dtensor_spec import DTensorSpec
from torch.distributed.tensor._redistribute import redistribute_local_tensor
from torch.distributed.tensor._op_schema import (
OpInfo,
OpSchema,
OutputSharding
)
try:
from torch.utils import _cxx_pytree as pytree
except ImportError:
from torch.utils import _pytree as pytree
def get_redistributed_local_args(
op_info: OpInfo,
output_sharding: OutputSharding
) -> Tuple[object, ...]:
if output_sharding.needs_redistribute:
DTensor._op_dispatcher.redistribute_local_args(
op_info,
output_sharding.redistribute_schema
)
local_args = (
pytree.tree_unflatten(
cast(list[object], op_info.local_args), op_info.args_tree_spec
)
if op_info.args_tree_spec
else op_info.local_args
)
local_args = cast(tuple[object, ...], local_args)
return local_args
def get_redistributed_local_kwargs(
kwargs_spec_infer_func: Callable[[OpSchema, OutputSharding], Dict[str, DTensorSpec]],
op_info: OpInfo,
output_sharding: OutputSharding
) -> None:
src_kwargs_spec = op_info.schema.kwargs_schema
target_kwargs_spec = kwargs_spec_infer_func(op_info.schema, output_sharding)
new_local_kwargs = {}
def _redistribute(local_value, src_spec, target_spec):
if isinstance(target_spec, DTensorSpec) and src_spec.placements != target_spec.placements:
return redistribute_local_tensor(local_value, src_spec, target_spec)
else:
return local_value
schema_info = op_info.schema.schema_info
for key, target_spec in target_kwargs_spec.items():
local_value = op_info.local_kwargs[key]
src_spec = src_kwargs_spec[key]
if isinstance(target_spec, (list, tuple)) and schema_info is not None and schema_info.needs_pytree:
new_local_kwargs[key] = [
_redistribute(val, src, dst)
for val, src, dst in zip(local_value, src_spec, target_spec)
]
else:
new_local_kwargs[key] = _redistribute(local_value, src_spec, target_spec)
return new_local_kwargs
def get_empty_local_results(op_info: OpInfo, output_sharding: OutputSharding) -> object:
spec = output_sharding.output_spec
ret_list = op_info.schema.op._schema.returns
local_results = None
if spec is not None:
def default_tensor(spec: DTensorSpec) -> torch.Tensor:
if spec.tensor_meta is not None:
shape = spec.tensor_meta.shape
dtype = spec.tensor_meta.dtype
if len(shape) == 0:
return torch.zeros((), dtype=dtype)
else:
return torch.tensor([], dtype=dtype)
else:
raise RuntimeError(f"{spec} has no tensor metadata.")
if isinstance(spec, DTensorSpec):
local_results = default_tensor(spec)
elif isinstance(spec, Sequence):
local_results = [default_tensor(s) if s is not None else None for s in spec]
if not isinstance(local_results, list):
raise RuntimeError("local_results is not a list")
if None in local_results:
ret_type = str(ret_list[0].type)
raise NotImplementedError(
f"return type {ret_type} in DTensor op is not supported"
)
return local_results