from typing import Optional, cast
import logging
import torch
from torch.nn import _reduction as _Reduction
from torch.distributed.tensor._dtensor_spec import DTensorSpec
from torch.distributed.tensor._op_schema import (
OpSchema,
OpStrategy,
PlacementStrategy,
PlacementList,
RuntimeSchemaInfo,
)
from torch.distributed.tensor._ops.utils import (
generate_redistribute_costs,
register_op_strategy,
normalize_dim,
expand_to_full_mesh_op_strategy,
)
from torch.distributed.tensor._ops._math_ops import (
_replicate_dims_start_at,
_infer_reduce_dims_map,
map_placements_after_reduction,
)
from torch.distributed.tensor._utils import normalize_to_torch_size
from torch.distributed.tensor import Partial, Replicate, Shard
from torch.distributed.tensor.experimental import register_sharding
logger = logging.getLogger("torch.distributed.tensor")
npu = torch.ops.npu
aten = torch.ops.aten
@register_op_strategy(npu.npu_rms_norm.default)
def npu_rms_norm_strategy(op_schema: OpSchema) -> OpStrategy:
mesh = op_schema.get_mesh_from_args(validate=False)
expected_args_len = 2
(
input_strategy,
gamma_strategy,
) = op_schema.args_schema[:expected_args_len]
normalized_shape = gamma_strategy.shape
normalized_size = normalize_to_torch_size(normalized_shape)
input_ndim = input_strategy.ndim
axis = input_ndim - len(normalized_size)
output_strategy = OpStrategy([])
for idx, input_placement_strategy in enumerate(input_strategy.strategies):
op_args_target_specs = []
redistribute_costs = []
input_src_spec = input_placement_strategy.output_spec
input_target_spec = DTensorSpec(
mesh=mesh,
placements=_replicate_dims_start_at(input_src_spec.placements, axis),
tensor_meta=input_src_spec.tensor_meta,
)
op_args_target_specs.append(input_target_spec)
redistribute_costs.append(
generate_redistribute_costs(input_strategy, input_target_spec)
)
if gamma_strategy is not None:
gamma_src_spec = gamma_strategy.strategies[idx].output_spec
gamma_target_spec = DTensorSpec(
mesh=mesh,
placements=_replicate_dims_start_at(gamma_src_spec.placements),
tensor_meta=gamma_src_spec.tensor_meta,
)
op_args_target_specs.append(gamma_target_spec)
redistribute_costs.append(
generate_redistribute_costs(gamma_strategy, gamma_target_spec)
)
output_target_spec = input_target_spec
output_strategy.strategies.append(
PlacementStrategy(
output_specs=output_target_spec,
input_specs=op_args_target_specs,
redistribute_cost=redistribute_costs,
)
)
return output_strategy
@register_op_strategy(npu.npu_rms_norm_backward.default)
def npu_rms_norm_backward_strategy(op_schema: OpSchema) -> OpStrategy:
mesh = op_schema.get_mesh_from_args(validate=False)
(
grad_out_strategy,
input_strategy,
gamma_strategy,
rstd_strategy,
) = op_schema.args_schema
normalized_shape = gamma_strategy.shape
normalized_size = normalize_to_torch_size(normalized_shape)
input_ndim = input_strategy.ndim
axis = input_ndim - len(normalized_size)
outer_dims = list(range(axis))
out_tuple_strategy = OpStrategy([])
for idx, input_placement_strategy in enumerate(input_strategy.strategies):
output_specs_list: list[Optional[DTensorSpec]] = []
input_specs_list: list[DTensorSpec] = []
redistribute_costs = []
input_src_spec = input_placement_strategy.output_spec
grad_out_target_spec = DTensorSpec(
mesh=mesh,
placements=_replicate_dims_start_at(input_src_spec.placements, axis),
tensor_meta=input_src_spec.tensor_meta,
)
input_specs_list.append(grad_out_target_spec)
redistribute_costs.append(
generate_redistribute_costs(grad_out_strategy, grad_out_target_spec)
)
output_specs_list.append(grad_out_target_spec)
input_target_spec = DTensorSpec(
mesh=mesh,
placements=_replicate_dims_start_at(input_src_spec.placements, axis),
tensor_meta=input_src_spec.tensor_meta,
)
input_specs_list.append(input_target_spec)
redistribute_costs.append(
generate_redistribute_costs(input_strategy, input_target_spec)
)
if gamma_strategy is not None:
gamma_src_spec = gamma_strategy.strategies[idx].output_spec
input_specs_list.append(gamma_src_spec)
redistribute_costs.append([0.0 for _ in gamma_strategy.strategies])
inp_placements = _replicate_dims_start_at(input_src_spec.placements, axis)
reduce_dims_map = _infer_reduce_dims_map(
outer_dims, input_src_spec.ndim, False
)
out_placements = map_placements_after_reduction(
inp_placements, outer_dims, reduce_dims_map, "sum"
)
gamma_out_spec = DTensorSpec(
mesh=mesh,
placements=out_placements,
tensor_meta=gamma_src_spec.tensor_meta,
)
output_specs_list.append(gamma_out_spec)
else:
output_specs_list.append(None)
rstd_src_spec = rstd_strategy.strategies[idx].output_spec
input_specs_list.append(rstd_src_spec)
redistribute_costs.append([0.0 for _ in rstd_strategy.strategies])
out_tuple_strategy.strategies.append(
PlacementStrategy(
output_specs=tuple(output_specs_list),
input_specs=input_specs_list,
redistribute_cost=redistribute_costs,
)
)
return out_tuple_strategy
@register_op_strategy(npu.npu_add_rms_norm.default)
def npu_add_rms_norm_strategy(op_schema: OpSchema) -> OpStrategy:
mesh = op_schema.get_mesh_from_args(validate=False)
expected_args_len = 3
(
x1_strategy,
x2_strategy,
gamma_strategy,
) = op_schema.args_schema[:expected_args_len]
normalized_shape = gamma_strategy.shape
normalized_size = normalize_to_torch_size(normalized_shape)
input_ndim = x1_strategy.ndim
axis = input_ndim - len(normalized_size)
output_strategy = OpStrategy([])
for idx, x1_placement_strategy in enumerate(x1_strategy.strategies):
op_args_target_specs = []
redistribute_costs = []
x1_src_spec = x1_placement_strategy.output_spec
x1_target_placements = _replicate_dims_start_at(x1_src_spec.placements, axis)
x1_target_spec = DTensorSpec(
mesh=mesh,
placements=x1_target_placements,
tensor_meta=x1_src_spec.tensor_meta,
)
op_args_target_specs.append(x1_target_spec)
redistribute_costs.append(
generate_redistribute_costs(x1_strategy, x1_target_spec)
)
if x2_strategy is not None:
x2_src_spec = x2_strategy.strategies[idx].output_spec
x2_target_spec = DTensorSpec(
mesh,
placements=x1_target_placements,
tensor_meta=x2_src_spec.tensor_meta,
)
op_args_target_specs.append(x2_target_spec)
redistribute_costs.append(
generate_redistribute_costs(x2_strategy, x2_target_spec)
)
if gamma_strategy is not None:
gamma_src_spec = gamma_strategy.strategies[idx].output_spec
gamma_target_spec = DTensorSpec(
mesh=mesh,
placements=_replicate_dims_start_at(gamma_src_spec.placements),
tensor_meta=gamma_src_spec.tensor_meta,
)
op_args_target_specs.append(gamma_target_spec)
redistribute_costs.append(
generate_redistribute_costs(gamma_strategy, gamma_target_spec)
)
y_target_spec = x1_target_spec
rstd_target_spec = DTensorSpec(
mesh=mesh,
placements=x1_target_placements[:axis],
tensor_meta=x1_src_spec.tensor_meta,
)
x_target_spec = x1_target_spec
output_target_spec = (y_target_spec, rstd_target_spec, x_target_spec)
output_strategy.strategies.append(
PlacementStrategy(
output_specs=output_target_spec,
input_specs=op_args_target_specs,
redistribute_cost=redistribute_costs,
)
)
return output_strategy
@register_sharding(npu.npu_rotary_mul.default)
def npu_rotary_mul_strategy(x, r1, r2, rotary_mode="half"):
acceptable_shardings = []
replicate_strategy = (
[Replicate()],
[Replicate(), Replicate(), Replicate(), None]
)
acceptable_shardings.append(replicate_strategy)
for i in range(x.ndim - 1):
if r1.shape[i] == 1:
sharding_strategy = (
[Shard(i)],
[Shard(i), Replicate(), Replicate(), None]
)
else:
sharding_strategy = (
[Shard(i)],
[Shard(i), Shard(i), Shard(i), None]
)
acceptable_shardings.append(sharding_strategy)
return acceptable_shardings
@register_sharding(npu.npu_rotary_mul_backward.default)
def npu_rotary_mul_backward_strategy(grad, x, r1, r2, rotary_mode='half'):
acceptable_shardings = []
replicate_strategy = (
[Replicate(), Replicate(), Replicate()],
[Replicate(), Replicate(), Replicate(), Replicate(), None]
)
acceptable_shardings.append(replicate_strategy)
for i in range(x.ndim - 1):
if r1.shape[i] == 1:
sharding_strategy = (
[Shard(i), Partial(), Partial()],
[Shard(i), Shard(i), Replicate(), Replicate(), None]
)
else:
sharding_strategy = (
[Shard(i), Shard(i), Shard(i)],
[Shard(i), Shard(i), Shard(i), Shard(i), None]
)
acceptable_shardings.append(sharding_strategy)
return acceptable_shardings
@register_sharding(npu.npu_gather_backward.default)
def custom_npu_gather_backward_strategy(grad, input_shape, dim, index, sparse_grad=False):
dim = normalize_dim(dim, len(input_shape))
index_shape = index.shape
grad_shape = grad.shape
acceptable_sharding = []
replicate_strategy = ([Replicate()], [Replicate(), None, None, Replicate(), None])
acceptable_sharding.append(replicate_strategy)
if len(input_shape) == len(index_shape) == len(grad_shape):
for d in range(index.ndim):
if d != dim and input_shape[d] == index_shape[d] == grad_shape[d]:
sharding_strategy = (
[Shard(d)],
[Shard(d),
None,
None,
Shard(d),
None]
)
acceptable_sharding.append(sharding_strategy)
return acceptable_sharding
@register_sharding(npu.npu_swiglu.default)
def custom_npu_swiglu_strategy(x, dim=-1):
dim = normalize_dim(dim, x.ndim)
acceptable_sharding = []
replicate_strategy = ([Replicate()], [Replicate(), None])
acceptable_sharding.append(replicate_strategy)
for i in range(x.ndim):
if i != dim:
sharding_strategy = (
[Shard(i)],
[Shard(i), None]
)
acceptable_sharding.append(sharding_strategy)
return acceptable_sharding
@register_sharding(npu.npu_swiglu_backward.default)
def custom_npu_swiglu_backward_strategy(grad, x, dim=-1):
dim = normalize_dim(dim, x.ndim)
acceptable_sharding = []
replicate_strategy = (
[Replicate(), None],
[Replicate(), Replicate(), None]
)
acceptable_sharding.append(replicate_strategy)
for i in range(x.ndim):
if i != dim:
sharding_strategy = (
[Shard(i), None],
[Shard(i), Shard(i), None]
)
acceptable_sharding.append(sharding_strategy)
return acceptable_sharding
@register_sharding(npu.npu_conv2d.default)
def custom_npu_conv2d_strategy(x, weight, bias, stride, padding, dilation, groups):
acceptable_shardings = []
replicate_strategy = (
[
Replicate()
],
[
Replicate(),
Replicate(),
None if bias is None else Replicate(),
None, None, None, None
]
)
acceptable_shardings.append(replicate_strategy)
N_sharding_strategy = (
[Shard(0)],
[Shard(0), Replicate(), None if bias is None else Replicate(), None, None, None, None]
)
acceptable_shardings.append(N_sharding_strategy)
if groups == 1:
Co_sharding_strategy = (
[Shard(1)],
[Replicate(), Shard(0), None if bias is None else Shard(0), None, None, None, None]
)
acceptable_shardings.append(Co_sharding_strategy)
return acceptable_shardings
@register_sharding(npu.npu_conv2d_backward.default)
def custom_npu_conv2d_backward_strategy(x, grad_output, weight, stride, padding, dilation, groups, output_mask):
grad_output_placement = grad_output.placements[0]
acceptable_shardings = []
replicate_strategy = (
[
Replicate(),
Replicate(),
Replicate() if output_mask[2] else None
],
[
Replicate(),
Replicate(),
Replicate(),
None, None, None, None, None
]
)
acceptable_shardings.append(replicate_strategy)
if (isinstance(grad_output_placement, Shard) and grad_output_placement.dim == 0):
N_sharding_strategy = (
[Shard(0), Partial(), Partial() if output_mask[2] else None],
[Shard(0), Shard(0), Replicate(), None, None, None, None, None]
)
acceptable_shardings.append(N_sharding_strategy)
if (isinstance(grad_output_placement, Shard) and grad_output_placement.dim == 1) and (groups == 1):
Co_sharding_strategy = (
[Partial(), Shard(0), Shard(0) if output_mask[2] else None],
[Replicate(), Shard(1), Shard(0), None, None, None, None, None]
)
acceptable_shardings.append(Co_sharding_strategy)
return acceptable_shardings
@register_sharding(npu.npu_grouped_matmul_add_.default)
def custom_grouped_matmul_add__strategy(y, x, weight, group_list, transpose_x=True, transpose_weight=False, group_type=2):
acceptable_shardings = []
y_placement = y.placements[0]
if (isinstance(y_placement, Shard) and y_placement.dim == 0):
raise ValueError(
"The 0th dimension of the output matrix cannot be split."
)
replicate_strategy = (
[
Replicate()
],
[
Replicate(),
Replicate(),
Replicate(),
Replicate(),
None, None, None
]
)
acceptable_shardings.append(replicate_strategy)
D_shard_strategy = (
[Shard(1)],
[Shard(1), Replicate(), Shard(1), Replicate(), None, None, None]
)
acceptable_shardings.append(D_shard_strategy)
return acceptable_shardings
def is_tensor_evenly_shardable(shape, spec):
"""Check if the shape is evenly shardable according to the spec."""
if not isinstance(spec, DTensorSpec):
raise TypeError(
f"Expected 'spec' to be DTensorSpec instance, got {type(spec).__name__} instead."
)
if len(shape) == 0:
raise ValueError("'shape' must have at least 1 dimension (empty shape is invalid).")
if len(spec.placements) == 0:
raise ValueError("'spec.placements' cannot be empty (must have at least one placement).")
shards_map = [1] * len(shape)
for i, placement in enumerate(spec.placements):
if placement.is_shard():
shard_dim = cast(Shard, placement).dim
shards_map[shard_dim] *= spec.mesh.size(i)
for i, dim_size in enumerate(shape):
if shards_map[i] > 1 and (dim_size % shards_map[i] != 0):
return False
return True
@register_op_strategy(npu.npu_cross_entropy_loss.default, schema_info=RuntimeSchemaInfo(3))
def custom_cross_entropy_loss_sharding(op_schema: OpSchema):
single_mesh_dim_strategies = []
args_schema = op_schema.args_schema
input_strategy = args_schema[0] if len(args_schema) > 0 else None
target_strategy = args_schema[1] if len(args_schema) > 1 else None
weight_strategy = args_schema[2] if len(args_schema) > 2 else None
reduction = args_schema[3] if len(args_schema) > 3 else 'mean'
mesh = input_strategy.mesh
all_replicate: PlacementList = [
Replicate(),
Replicate(),
Replicate(),
Replicate(),
Replicate(),
Replicate()
]
if weight_strategy is not None:
all_replicate.append(Replicate())
single_mesh_dim_strategies.append(all_replicate)
if reduction == 'none':
N_replicate_strategy = [Shard(0), Shard(0), Replicate(), Replicate(), Shard(0), Shard(0)]
if weight_strategy is not None:
N_replicate_strategy.append(Replicate())
single_mesh_dim_strategies.append(N_replicate_strategy)
elif reduction == 'mean':
target_shape = target_strategy.strategies[0].output_spec.shape
target_spec = target_strategy.strategies[0].output_spec
if weight_strategy is None and is_tensor_evenly_shardable(target_shape, target_spec):
N_replicate_strategy = [Partial("avg"), Shard(0), Replicate(), Replicate(), Shard(0), Shard(0)]
single_mesh_dim_strategies.append(N_replicate_strategy)
else:
N_replicate_strategy = [Partial("sum"), Shard(0), Replicate(), Replicate(), Shard(0), Shard(0)]
if weight_strategy is not None:
N_replicate_strategy.append(Replicate())
single_mesh_dim_strategies.append(N_replicate_strategy)
return expand_to_full_mesh_op_strategy(
mesh, op_schema, single_mesh_dim_strategies, input_index=4
)
@register_op_strategy(npu.npu_cross_entropy_loss_backward.default, schema_info=RuntimeSchemaInfo(6))
def custom_cross_entropy_loss_backward_strategy(op_schema: OpSchema):
single_mesh_dim_strategies = []
args_schema = op_schema.args_schema
grad_loss_strategy = args_schema[0] if len(args_schema) > 0 else None
weight_strategy = args_schema[3] if len(args_schema) > 3 else None
lse_for_zloss_strategy = args_schema[5] if len(args_schema) > 5 else None
reduction = args_schema[6] if len(args_schema) > 6 else 'mean'
mesh = grad_loss_strategy.mesh
all_replicate: PlacementList = [
Replicate(),
Replicate(),
Replicate(),
Replicate()
]
if weight_strategy is not None:
all_replicate.append(Replicate())
if lse_for_zloss_strategy is not None:
all_replicate.append(Replicate())
single_mesh_dim_strategies.append(all_replicate)
if reduction == 'none':
N_replicate_strategy = [Shard(0), Shard(0), Shard(0), Shard(0)]
if weight_strategy is not None:
N_replicate_strategy.append(Replicate())
if lse_for_zloss_strategy is not None:
N_replicate_strategy.append(Replicate())
single_mesh_dim_strategies.append(N_replicate_strategy)
elif reduction == 'sum':
N_replicate_strategy = [Shard(0), Replicate(), Shard(0), Shard(0)]
if weight_strategy is not None:
N_replicate_strategy.append(Replicate())
if lse_for_zloss_strategy is not None:
N_replicate_strategy.append(Replicate())
single_mesh_dim_strategies.append(N_replicate_strategy)
return expand_to_full_mesh_op_strategy(
mesh, op_schema, single_mesh_dim_strategies, input_index=1
)
@register_sharding(aten.repeat_interleave.self_int)
def custom_npu_repeat_interleave_self_int_strategy(x, repeat, dim=None, output_size=None):
acceptable_shardings = []
replicate_strategy = (
[
Replicate()
],
[
Replicate(),
None, None, None
]
)
acceptable_shardings.append(replicate_strategy)
if not is_tensor_evenly_shardable(x.shape, x):
if dim is not None:
for i in range(x.ndim):
if i != dim:
sharding_strategy = (
[Shard(i)],
[Shard(i), None, None, None]
)
acceptable_shardings.append(sharding_strategy)
else:
if dim is None:
sharding_strategy = (
[Shard(0)],
[Shard(0), None, None, None]
)
acceptable_shardings.append(sharding_strategy)
else:
for i in range(x.ndim):
sharding_strategy = (
[Shard(i)],
[Shard(i), None, None, None]
)
acceptable_shardings.append(sharding_strategy)
return acceptable_shardings
@register_sharding(npu.repeat_interleave_backward_int.default)
def custom_npu_repeat_interleave_backward_int_strategy(grad, x, repeats, dim=None):
acceptable_shardings = []
replicate_strategy = (
[
Replicate()
],
[
Replicate(),
Replicate(),
None, None
]
)
acceptable_shardings.append(replicate_strategy)
if dim is not None and is_tensor_evenly_shardable(x.shape, x):
for i in range(x.ndim):
grad_placement = grad.placements[0]
if (isinstance(grad_placement, Shard)
and dim == grad_placement.dim
and not is_tensor_evenly_shardable(grad.shape, grad)
):
continue
else:
sharding_strategy = (
[Shard(i)],
[Shard(i), Shard(i), None, None]
)
acceptable_shardings.append(sharding_strategy)
return acceptable_shardings
@register_sharding(aten.kl_div.default)
def custom_KLDivLoss_forward_strategy(x, target, reduction=_Reduction.get_enum("mean"), log_target=False):
acceptable_shardings = []
replicate_strategy = (
[
Replicate()
],
[
Replicate(),
Replicate(),
None, None
]
)
acceptable_shardings.append(replicate_strategy)
for dim in range(x.ndim):
if reduction == _Reduction.get_enum("none"):
shard_strategy = (
[Shard(dim)],
[Shard(dim), Shard(dim), None, None]
)
elif reduction == _Reduction.get_enum("mean") and x.shape[dim] % x.mesh.size(0) == 0:
shard_strategy = (
[Partial("avg")],
[Shard(dim), Shard(dim), None, None]
)
elif reduction == _Reduction.get_enum("sum"):
shard_strategy = (
[Partial("sum")],
[Shard(dim), Shard(dim), None, None]
)
else:
logger.warning(
"%d is not a valid value for reduction. "
"Parameters supported by reduction: none: 0, mean: 1, sum & batchmean: 2.",
reduction
)
return acceptable_shardings
acceptable_shardings.append(shard_strategy)
return acceptable_shardings
@register_sharding(npu.kl_div_backward.default)
def custom_KLDivLoss_backward_strategy(grad_out, x, target, reduction=_Reduction.get_enum("mean"), log_target=False):
acceptable_shardings = []
replicate_strategy = (
[
Replicate()
],
[
Replicate(),
Replicate(),
Replicate(),
None, None
]
)
acceptable_shardings.append(replicate_strategy)
for dim in range(x.ndim):
if reduction == _Reduction.get_enum("none"):
shard_strategy = (
[Shard(dim)],
[Shard(dim), Shard(dim), Shard(dim), None, None]
)
elif reduction == _Reduction.get_enum("sum"):
shard_strategy = (
[Shard(dim)],
[Replicate(), Shard(dim), Shard(dim), None, None]
)
else:
if reduction != _Reduction.get_enum("mean"):
logger.warning(
"%d is not a valid value for reduction. "
"Parameters supported by reduction: none: 0, mean: 1, sum & batchmean: 2.",
reduction
)
return acceptable_shardings
acceptable_shardings.append(shard_strategy)
return acceptable_shardings