import os
from typing import cast
import torch
import torch_npu
from torch.distributed._tensor.experimental import register_sharding
from torch.distributed.tensor import DTensor, Partial, Replicate, Shard
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
from torch.distributed.tensor._op_schema import (
_is_out_variant_op,
OpInfo,
OpSchema,
OpStrategy,
OutputSharding,
PlacementStrategy,
RuntimeSchemaInfo,
TupleStrategy,
)
from torch.distributed.tensor._ops._matrix_ops import _mm_like_strategy
from torch.distributed.tensor._ops.utils import (
expand_to_full_mesh_op_strategy,
register_op_strategy,
)
try:
from torch.utils import _cxx_pytree as pytree
except ImportError:
from torch.utils import _pytree as pytree
from ._common import (
get_empty_local_results,
get_redistributed_local_args,
get_redistributed_local_kwargs,
)
# Shorthand handles to the ATen and NPU operator namespaces; the strategy
# registrations and handlers in this module look operators up through them.
aten = torch.ops.aten
npu = torch.ops.npu
def _get_max_shardable_dim(tensor, world_size=None):
    """Return the index of the largest dimension that splits evenly across ranks.

    Args:
        tensor: Object exposing a ``shape`` sequence of integer sizes.
        world_size: Number of ranks to divide by. When ``None`` (the default,
            matching the previous behavior) the value is read from
            ``torch.distributed.get_world_size()``.

    Returns:
        The index of the largest dimension whose size is a multiple of
        ``world_size`` (the first such index on ties), or ``-1`` when no
        dimension qualifies.
    """
    if world_size is None:
        world_size = torch.distributed.get_world_size()
    shape = tensor.shape
    divisible_dims = [idx for idx, dim in enumerate(shape) if dim % world_size == 0]
    if not divisible_dims:
        return -1
    # max() keeps the first maximal element, so ties resolve to the lowest index.
    return max(divisible_dims, key=lambda idx: shape[idx])
def _handle_tensor_list_in_kwargs(kwargs: dict[str, object], op_info: OpInfo) -> None:
for key, value in kwargs.items():
if isinstance(value, list) and all(isinstance(e, DTensor) for e in value):
new_schema = []
new_local_tensors = []
for dtensor in value:
new_schema.append(dtensor._spec)
new_local_tensors.append(dtensor._local_tensor)
op_info.schema.kwargs_schema[key] = tuple(
new_schema
)
op_info.local_kwargs[key] = new_local_tensors
if os.getenv("TORCH_NPU_USE_COMPATIBLE_IMPL") != "1":
    # The custom matmul sharding rules are registered unless
    # TORCH_NPU_USE_COMPATIBLE_IMPL is set to "1".
    # NOTE(review): the file's indentation was lost; both matmul registrations
    # are assumed to live under this guard -- confirm against the repository.

    @register_sharding(aten.matmul.default)
    def custom_matmul_strategy(
        tensor1: DTensorSpec,
        tensor2: DTensorSpec,
    ):
        """List acceptable shardings for ``aten.matmul.default``.

        Args:
            tensor1: Spec of the left operand.
            tensor2: Spec of the right operand.

        Returns:
            A list of ``(output_placements, input_placements)`` pairs. The
            output list has one entry; the input list has one entry per
            operand. The fully replicated combination is always included;
            sharded combinations are added only when the sharded dimension is
            a multiple of ``mesh.size(0)``.
        """
        shape1 = tensor1.shape
        shape2 = tensor2.shape
        acceptable_shardings = []
        replicate_strategy = ([Replicate()], [Replicate(), Replicate()])
        acceptable_shardings.append(replicate_strategy)
        if len(shape1) == 1 and len(shape2) == 1:
            # Vector dot product: only the replicated strategy is offered.
            return acceptable_shardings
        elif len(shape1) == 1:
            # Vector @ (batched) matrix: the output drops tensor2's
            # second-to-last dimension.
            for i in range(len(shape2) - 2):
                if shape2[i] % tensor2.mesh.size(0) == 0:
                    # Shard a batch dimension of tensor2 and of the output.
                    strategy_1 = ([Shard(i)], [Replicate(), Shard(i)])
                    acceptable_shardings.append(strategy_1)
            if shape1[0] % tensor1.mesh.size(0) == 0:
                # Shard the contracted dimension; the output is a partial sum.
                strategy_2 = ([Partial()], [Shard(0), Shard(len(shape2) - 2)])
                acceptable_shardings.append(strategy_2)
            if shape2[len(shape2) - 1] % tensor2.mesh.size(0) == 0:
                # Shard tensor2's last dimension, which is the output's last.
                output_shape = shape2[:-2] + (shape2[-1],)
                strategy_3 = (
                    [Shard(len(output_shape) - 1)],
                    [Replicate(), Shard(len(shape2) - 1)],
                )
                acceptable_shardings.append(strategy_3)
        elif len(shape2) == 1:
            # (Batched) matrix @ vector: the output drops tensor1's last
            # dimension, so every remaining tensor1 dimension maps 1:1.
            for i in range(len(shape1) - 1):
                if shape1[i] % tensor1.mesh.size(0) == 0:
                    strategy_1 = ([Shard(i)], [Shard(i), Replicate()])
                    acceptable_shardings.append(strategy_1)
            if shape2[0] % tensor2.mesh.size(0) == 0:
                # Shard the contracted dimension; the output is a partial sum.
                strategy_2 = ([Partial()], [Shard(len(shape1) - 1), Shard(0)])
                acceptable_shardings.append(strategy_2)
        else:
            # Both operands have at least two dimensions; leading dimensions
            # are broadcast batch dimensions.
            output_shape = torch.broadcast_shapes(shape1[:-2], shape2[:-2]) + (
                shape1[-2],
                shape2[-1],
            )
            len1, len2 = len(shape1), len(shape2)
            diff = abs(len1 - len2)
            is_shape1_longer = len1 > len2
            # Batch dimensions present in both operands: ``i`` indexes the
            # shorter operand, ``i + diff`` the longer one and the output.
            for i in range(min(len1, len2) - 3, -1, -1):
                shape1_shardable = (
                    shape1[i + diff] % tensor1.mesh.size(0) == 0
                    if is_shape1_longer
                    else shape1[i] % tensor1.mesh.size(0) == 0
                )
                shape2_shardable = (
                    shape2[i] % tensor2.mesh.size(0) == 0
                    if is_shape1_longer
                    else shape2[i + diff] % tensor2.mesh.size(0) == 0
                )
                if shape1_shardable and shape2_shardable:
                    strategy_batch = (
                        ([Shard(i + diff)], [Shard(i + diff), Shard(i)])
                        if is_shape1_longer
                        else ([Shard(i + diff)], [Shard(i), Shard(i + diff)])
                    )
                    acceptable_shardings.append(strategy_batch)
            # Leading batch dimensions that only the longer operand has.
            for i in range(diff - 1, -1, -1):
                if is_shape1_longer and shape1[i] % tensor1.mesh.size(0) == 0:
                    strategy_batch = ([Shard(i)], [Shard(i), Replicate()])
                    acceptable_shardings.append(strategy_batch)
                elif not is_shape1_longer and shape2[i] % tensor2.mesh.size(0) == 0:
                    strategy_batch = ([Shard(i)], [Replicate(), Shard(i)])
                    acceptable_shardings.append(strategy_batch)
            if shape1[len(shape1) - 2] % tensor1.mesh.size(0) == 0:
                # Shard tensor1's second-to-last dimension (output rows).
                strategy_tensor1 = (
                    [Shard(len(output_shape) - 2)],
                    [Shard(len(shape1) - 2), Replicate()],
                )
                acceptable_shardings.append(strategy_tensor1)
            if shape2[len(shape2) - 1] % tensor2.mesh.size(0) == 0:
                # Shard tensor2's last dimension (output columns).
                strategy_tensor2 = (
                    [Shard(len(output_shape) - 1)],
                    [Replicate(), Shard(len(shape2) - 1)],
                )
                acceptable_shardings.append(strategy_tensor2)
            if shape1[len(shape1) - 1] % tensor1.mesh.size(0) == 0:
                # Shard the contracted dimension; the output is a partial sum.
                strategy_3 = (
                    [Partial()],
                    [Shard(len(shape1) - 1), Shard(len(shape2) - 2)],
                )
                acceptable_shardings.append(strategy_3)
        return acceptable_shardings

    @register_sharding(aten.matmul_backward.default)
    def custom_matmul_backward_strategy(
        grad: DTensorSpec,
        tensor1: DTensorSpec,
        tensor2: DTensorSpec,
        mask: list[bool],
    ):
        """List acceptable shardings for ``aten.matmul_backward.default``.

        Args:
            grad: Spec of the incoming gradient.
            tensor1: Spec of the forward left operand.
            tensor2: Spec of the forward right operand.
            mask: Gradient mask of the op; not read by this function.

        Returns:
            A list of ``(output_placements, input_placements)`` pairs. The
            output list has two entries (one per returned gradient) and the
            input list has four entries in the order ``grad``, ``tensor1``,
            ``tensor2``, ``mask``; the ``mask`` slot is always ``None``.
            Sharded combinations are added only when the relevant dimension is
            a multiple of ``mesh.size(0)``.
        """
        grad_dim = len(grad.shape)
        tensor1_dim = len(tensor1.shape)
        tensor2_dim = len(tensor2.shape)
        acceptable_shardings = []
        replicate_strategy = (
            [Replicate(), Replicate()],
            [Replicate(), Replicate(), Replicate(), None],
        )
        acceptable_shardings.append(replicate_strategy)
        if tensor1_dim == 1 and tensor2_dim == 1:
            # Vector dot product: only the replicated strategy is offered.
            return acceptable_shardings
        elif tensor1_dim >= 2 and (tensor2_dim == 1 or tensor2_dim == 2):
            if tensor2.shape[0] % tensor2.mesh.size(0) == 0:
                # Shard the contracted dimension of both operands.
                strategy_1 = (
                    [Shard(tensor1_dim - 1), Shard(0)],
                    [Replicate(), Shard(tensor1_dim - 1), Shard(0), None],
                )
                acceptable_shardings.append(strategy_1)
            for i in range(tensor1_dim - 1):
                if tensor1.shape[i] % tensor1.mesh.size(0) == 0:
                    # Shard a non-contracted dimension of tensor1 and grad;
                    # the second gradient becomes a partial sum.
                    strategy_2 = (
                        [Shard(i), Partial()],
                        [Shard(i), Shard(i), Replicate(), None],
                    )
                    acceptable_shardings.append(strategy_2)
            if tensor2_dim == 2 and tensor2.shape[1] % tensor2.mesh.size(0) == 0:
                # Shard tensor2's last dimension and grad's last dimension;
                # the first gradient becomes a partial sum.
                strategy_3 = (
                    [Partial(), Shard(1)],
                    [Shard(grad_dim - 1), Replicate(), Shard(1), None],
                )
                acceptable_shardings.append(strategy_3)
            return acceptable_shardings
        elif tensor2_dim >= 2 and (tensor1_dim == 1 or tensor1_dim == 2):
            is_special = tensor2_dim == 2 and tensor1_dim == 1
            if tensor1.shape[-1] % tensor1.mesh.size(0) == 0:
                # NOTE(review): in the ``is_special`` case the first gradient
                # is declared Shard(1) although tensor1 has one dimension --
                # confirm against the kernel's actual output shape.
                strategy_1 = (
                    [
                        Shard(tensor1_dim if is_special else tensor1_dim - 1),
                        Shard(tensor2_dim - 2),
                    ],
                    [Replicate(), Shard(tensor1_dim - 1), Shard(tensor2_dim - 2), None],
                )
                acceptable_shardings.append(strategy_1)
            if tensor2.shape[-1] % tensor2.mesh.size(0) == 0:
                strategy_2 = (
                    [Partial(), Shard(tensor2_dim - 1)],
                    [Shard(grad_dim - 1), Replicate(), Shard(tensor2_dim - 1), None],
                )
                acceptable_shardings.append(strategy_2)
            for i in range(tensor2_dim - 2):
                if tensor2.shape[i] % tensor2.mesh.size(0) == 0:
                    # The ``mask`` slot (None) is listed explicitly so this
                    # strategy has the same four inputs as every other one.
                    strategy_3 = (
                        [Partial(), Shard(i)],
                        [Shard(i), Replicate(), Shard(i), None],
                    )
                    acceptable_shardings.append(strategy_3)
            if tensor1_dim == 2 and tensor1.shape[0] % tensor1.mesh.size(0) == 0:
                strategy_4 = (
                    [Shard(0), Partial()],
                    [Shard(grad_dim - 2), Shard(0), Replicate(), None],
                )
                acceptable_shardings.append(strategy_4)
            return acceptable_shardings
        else:
            # Remaining case: both operands have at least three dimensions.
            # NOTE(review): several output placements below are indexed with
            # grad_dim (or i + diff) rather than the rank of the operand the
            # gradient belongs to; the two agree only when the ranks are
            # equal -- confirm the behavior for operands of different rank.
            if grad.shape[-1] % grad.mesh.size(0) == 0:
                strategy_1 = (
                    [Partial(), Shard(grad_dim - 1)],
                    [Shard(grad_dim - 1), Replicate(), Shard(tensor2_dim - 1), None],
                )
                acceptable_shardings.append(strategy_1)
            if grad.shape[-2] % grad.mesh.size(0) == 0:
                strategy_2 = (
                    [Shard(grad_dim - 2), Partial()],
                    [Shard(grad_dim - 2), Shard(tensor1_dim - 2), Replicate(), None],
                )
                acceptable_shardings.append(strategy_2)
            if tensor1.shape[-1] % tensor1.mesh.size(0) == 0:
                strategy_3 = (
                    [Shard(grad_dim - 1), Shard(grad_dim - 2)],
                    [Replicate(), Shard(tensor1_dim - 1), Shard(tensor2_dim - 2), None],
                )
                acceptable_shardings.append(strategy_3)
            diff = abs(tensor1_dim - tensor2_dim)
            is_shape1_longer = tensor1_dim > tensor2_dim
            # Batch dimensions present in both operands: ``i`` indexes the
            # shorter operand, ``i + diff`` the longer one.
            for i in range(min(tensor1_dim, tensor2_dim) - 3, -1, -1):
                shape1_shardable = (
                    tensor1.shape[i + diff] % tensor1.mesh.size(0) == 0
                    if is_shape1_longer
                    else tensor1.shape[i] % tensor1.mesh.size(0) == 0
                )
                shape2_shardable = (
                    tensor2.shape[i] % tensor2.mesh.size(0) == 0
                    if is_shape1_longer
                    else tensor2.shape[i + diff] % tensor2.mesh.size(0) == 0
                )
                if shape1_shardable and shape2_shardable:
                    strategy_batch = (
                        (
                            [Shard(i + diff), Shard(i + diff)],
                            [Shard(i + diff), Shard(i + diff), Shard(i), None],
                        )
                        if is_shape1_longer
                        else (
                            [Shard(i + diff), Shard(i + diff)],
                            [Shard(i + diff), Shard(i), Shard(i + diff), None],
                        )
                    )
                    acceptable_shardings.append(strategy_batch)
            # Leading batch dimensions that only the longer operand has; the
            # gradient of the shorter operand becomes a partial sum.
            for i in range(diff - 1, -1, -1):
                if is_shape1_longer and tensor1.shape[i] % tensor1.mesh.size(0) == 0:
                    strategy_batch = (
                        [Shard(i), Partial()],
                        [Shard(i), Shard(i), Replicate(), None],
                    )
                    acceptable_shardings.append(strategy_batch)
                elif (
                    not is_shape1_longer
                    and tensor2.shape[i] % tensor2.mesh.size(0) == 0
                ):
                    strategy_batch = (
                        [Partial(), Shard(i)],
                        [Shard(i), Replicate(), Shard(i), None],
                    )
                    acceptable_shardings.append(strategy_batch)
            return acceptable_shardings
@register_op_strategy(
    npu.npu_grouped_matmul.default,
    schema_info=RuntimeSchemaInfo(
        static_kwargkey=[
            "bias",
            "scale",
            "offset",
            "antiquant_scale",
            "antiquant_offset",
            "per_token_scale",
            "group_list",
            "activation_input",
            "activation_quant_scale",
            "activation_quant_offset",
        ],
        needs_pytree=True,
    ),
)
@register_op_strategy(
    npu.npu_grouped_matmul.List,
    schema_info=RuntimeSchemaInfo(
        static_kwargkey=[
            "bias",
            "scale",
            "offset",
            "antiquant_scale",
            "antiquant_offset",
            "per_token_scale",
            "activation_input",
            "activation_quant_scale",
            "activation_quant_offset",
        ],
        needs_pytree=True,
    ),
)
def npu_grouped_matmul_strategy(op_schema: OpSchema) -> OpStrategy:
    """Build the sharding strategy for both ``npu_grouped_matmul`` overloads.

    Each single-mesh-dimension strategy is a flat list of placements ordered
    as: outputs, ``x`` tensors, ``weight`` tensors, ``bias`` tensors, then
    ``group_list`` (only for the ``default`` overload when it is provided).

    Args:
        op_schema: Strategy schema of the call. ``args_schema[0]`` and
            ``args_schema[1]`` are read as the ``x`` and ``weight`` tuple
            strategies; ``bias``, ``group_list``, ``split_item`` and
            ``group_type`` are read from ``kwargs_schema``.

    Returns:
        The strategies expanded over the full mesh. When there is a single
        output, each strategy's ``output_specs`` is wrapped in a one-element
        list.

    Raises:
        NotImplementedError: If ``group_type`` is given and greater than 0.
    """
    if op_schema.schema_info is None:
        op_schema.schema_info = RuntimeSchemaInfo(
            needs_pytree=True
        )
    x_src_strategy: TupleStrategy = op_schema.args_schema[0]
    x_num = len(x_src_strategy.childs)
    weight_src_strategy: TupleStrategy = op_schema.args_schema[1]
    weight_num = len(weight_src_strategy.childs)
    # ``bias`` may be absent or passed explicitly as None; both mean no bias.
    bias_src_strategy: TupleStrategy | list | None = op_schema.kwargs_schema.get(
        "bias"
    )
    if bias_src_strategy is None:
        bias_num = 0
    elif isinstance(bias_src_strategy, TupleStrategy):
        bias_num = len(bias_src_strategy.childs)
    else:
        bias_num = len(bias_src_strategy)
    group_list_num = (
        1
        if (
            op_schema.op == npu.npu_grouped_matmul.default
            and op_schema.kwargs_schema.get("group_list", None) is not None
        )
        else 0
    )
    split_item = op_schema.kwargs_schema.get("split_item", 0)
    # split_item 0/1 yields one output per weight, any other value a single one.
    y_num = (
        weight_num if split_item in (0, 1) else 1
    )

    def _expand(single_mesh_dim_strategies):
        """Expand over the full mesh and wrap a lone output spec in a list."""
        full_mesh_strategies = expand_to_full_mesh_op_strategy(
            op_schema.get_mesh_from_args(),
            op_schema,
            single_mesh_dim_strategies,
            input_index=y_num,
        )
        if y_num == 1:
            for strategy in full_mesh_strategies.strategies:
                strategy.output_specs = [strategy.output_specs]
        return full_mesh_strategies

    strategies = []
    all_replicate_strategy = [Replicate()] * y_num
    all_replicate_strategy.extend(
        [Replicate()] * (len(op_schema.args_strategy) + len(op_schema.kwargs_strategy))
    )
    strategies.append(all_replicate_strategy)
    unsupported_arguments = [
        "scale",
        "offset",
        "antiquant_scale",
        "antiquant_offset",
        "per_token_scale",
        "activation_input",
        "activation_quant_scale",
        "activation_quant_offset",
    ]
    for key in unsupported_arguments:
        schema = op_schema.kwargs_schema.get(key, None)
        if (
            schema is not None
            and isinstance(schema, TupleStrategy)
            and len(schema.childs) > 0
        ):
            # Any of these tensor lists being non-empty restricts the op to
            # the fully replicated strategy.
            return _expand(strategies)
    if (
        bias_num == 0
    ):
        # Without bias a partial input on either side gives a partial output.
        replicate_partial_strategy = [Partial()] * y_num
        replicate_partial_strategy.extend([Replicate()] * x_num)
        replicate_partial_strategy.extend([Partial()] * weight_num)
        replicate_partial_strategy.extend([Replicate()] * group_list_num)
        strategies.append(replicate_partial_strategy)
        partial_replicate_strategy = [Partial()] * y_num
        partial_replicate_strategy.extend([Partial()] * x_num)
        partial_replicate_strategy.extend([Replicate()] * weight_num)
        partial_replicate_strategy.extend([Replicate()] * group_list_num)
        strategies.append(partial_replicate_strategy)
    group_type = op_schema.kwargs_schema.get("group_type", None)
    if group_type is not None and group_type > 0:
        raise NotImplementedError(
            f"npu_grouped_matmul does not support group_type={group_type} now."
        )
    if x_num > 1 and weight_num > 1 and y_num > 1:
        # Multiple x, weight and y tensors. Each row is
        # (y placement, x placement, weight placement, bias placement).
        pair_strategies = []
        x_ndim = x_src_strategy.childs[0].ndim
        for i in range(x_ndim - 1):
            pair_strategies.append(
                [Shard(i), Shard(i), Replicate(), Replicate()]
            )
        pair_strategies.append([Shard(x_ndim - 1), Replicate(), Shard(1), Shard(0)])
        if bias_num == 0:
            # The bias placement is None here; it is repeated zero times below.
            pair_strategies.append([Partial(), Shard(x_ndim - 1), Shard(0), None])
        for y_spec, x_spec, weight_spec, bias_spec in pair_strategies:
            strategy = [y_spec] * y_num
            strategy.extend([x_spec] * x_num)
            strategy.extend([weight_spec] * weight_num)
            strategy.extend([bias_spec] * bias_num)
            strategy.extend([Replicate()] * group_list_num)
            strategies.append(strategy)
    elif (
        x_num == 1 and weight_num == 1 and y_num == 1
    ):
        # Single x, single weight, single y.
        if bias_num == 0:
            k_shard_strategy = [Partial(), Shard(1), Shard(1)]
            k_shard_strategy.extend([Replicate()] * group_list_num)
            strategies.append(k_shard_strategy)
        n_shard_strategy = [Shard(1), Replicate(), Shard(2)]
        n_shard_strategy.extend([Shard(1)] * bias_num)
        n_shard_strategy.extend([Replicate()] * group_list_num)
        strategies.append(n_shard_strategy)
    elif weight_num > 1:
        # Remaining combinations that still have several weights.
        if bias_num == 0:
            k_shard_strategy = [Partial()] * y_num
            k_shard_strategy.extend([Shard(1)] * x_num)
            k_shard_strategy.extend([Shard(0)] * weight_num)
            k_shard_strategy.extend([Replicate()] * group_list_num)
            strategies.append(k_shard_strategy)
        n_shard_strategy = [Shard(1)] * y_num
        n_shard_strategy.extend([Replicate()] * x_num)
        n_shard_strategy.extend([Shard(1)] * weight_num)
        n_shard_strategy.extend([Shard(0)] * bias_num)
        n_shard_strategy.extend([Replicate()] * group_list_num)
        strategies.append(n_shard_strategy)
    return _expand(strategies)
def _infer_npu_grouped_matmul_kwargs(
op_schema: OpSchema, output_sharding: OutputSharding
) -> dict[str, DTensorSpec]:
output_spec = output_sharding.output_spec[0]
kwargs_spec = {}
for key, spec in op_schema.kwargs_schema.items():
is_tensor_or_tenor_list_like = isinstance(spec, DTensorSpec) or (
isinstance(spec, (list, tuple))
and len(spec) > 0
and isinstance(spec[0], DTensorSpec)
)
if not is_tensor_or_tenor_list_like:
kwargs_spec[key] = spec
continue
if key == "group_list":
target_placement = [Replicate() for _ in output_spec.placements]
kwargs_spec[key] = DTensorSpec(
mesh=spec.mesh,
placements=target_placement,
tensor_meta=spec.tensor_meta,
)
continue
if key == "bias":
target_placement = [
Shard(0) if placement == Shard(output_spec.ndim - 1) else Replicate()
for placement in output_spec.placements
]
else:
target_placement = [Replicate() for _ in output_spec.placements]
kwargs_spec[key] = [
DTensorSpec(
mesh=e.mesh, placements=target_placement, tensor_meta=e.tensor_meta
)
for e in spec
]
return kwargs_spec
def _npu_grouped_matmul_handler(
    op_call: torch._ops.OpOverload,
    args: tuple[object, ...],
    kwargs: dict[str, object],
) -> object:
    """Custom DTensor dispatch handler for ``npu_grouped_matmul``.

    Unwraps the call, exposes list-of-DTensor keyword arguments to the
    schema, propagates sharding, redistributes positional and keyword
    arguments, runs the op on the local shards and wraps the result with the
    propagated output spec. Ranks outside the compute mesh use empty local
    results instead of running the op.
    """
    dispatcher = DTensor._op_dispatcher
    op_info = dispatcher.unwrap_to_op_info(op_call, args, kwargs)
    _handle_tensor_list_in_kwargs(kwargs, op_info)

    # Override the schema's return-type query so it always reports a tensor.
    def _return_type_tensor():
        return True

    op_info.schema.return_type_tensor = _return_type_tensor
    dispatcher.sharding_propagator.propagate(op_info)
    output_sharding = op_info.output_sharding

    if op_info.compute_mesh.get_coordinate() is None:
        # This rank does not take part in the computation.
        local_results = get_empty_local_results(op_info, output_sharding)
    else:
        redistributed_args = get_redistributed_local_args(op_info, output_sharding)
        redistributed_kwargs = get_redistributed_local_kwargs(
            _infer_npu_grouped_matmul_kwargs, op_info, output_sharding
        )
        local_results = op_call(*redistributed_args, **redistributed_kwargs)
    return dispatcher.wrap(local_results, output_sharding.output_spec)
@register_sharding(npu.npu_all_gather_base_mm.default)
def npu_all_gather_base_mm_strategy(
    x1,
    x2,
    hcom,
    world_size,
    bias=None,
    x1_scale=None,
    x2_scale=None,
    gather_index=0,
    gather_output=True,
    comm_turn=0,
    output_dtype=None,
    comm_mode=None,
):
    """List acceptable shardings for ``npu_all_gather_base_mm``.

    Returns two ``(output_placements, input_placements)`` pairs. Both shard
    ``x1`` on dim 0; the first replicates ``x2`` and both outputs, the second
    shards ``x2`` on dim 1 and the first output on dim 1. Optional tensors
    (``bias``, ``x1_scale``, ``x2_scale``) get a placement only when they are
    provided; every other argument slot is ``None``.

    Raises:
        NotImplementedError: If ``gather_index`` is not 0.
    """
    if gather_index != 0:
        raise NotImplementedError(
            f"npu_all_gather_base_mm only support gather_index=0 now, but got {gather_index}."
        )

    def _when_given(value, placement):
        """Return ``placement`` for a provided optional tensor, else None."""
        return None if value is None else placement

    # Input slots follow the signature order: x1, x2, hcom, world_size, bias,
    # x1_scale, x2_scale, then the five trailing non-tensor arguments.
    x2_replicated = (
        [Replicate(), Replicate()],
        [
            Shard(0),
            Replicate(),
            None,
            None,
            _when_given(bias, Replicate()),
            _when_given(x1_scale, Shard(0)),
            _when_given(x2_scale, Replicate()),
            None,
            None,
            None,
            None,
            None,
        ],
    )
    x2_column_sharded = (
        [Shard(1), Replicate()],
        [
            Shard(0),
            Shard(1),
            None,
            None,
            _when_given(bias, Shard(0)),
            _when_given(x1_scale, Shard(0)),
            _when_given(x2_scale, Shard(1)),
            None,
            None,
            None,
            None,
            None,
        ],
    )
    return [x2_replicated, x2_column_sharded]
def _infer_npu_all_gather_base_mm_kwargs(
op_schema: OpSchema, output_sharding: OutputSharding
) -> dict[str, DTensorSpec]:
output_spec = output_sharding.output_spec[0]
kwargs_spec = {}
for key, spec in op_schema.kwargs_schema.items():
if not isinstance(spec, DTensorSpec):
kwargs_spec[key] = spec
continue
target_placement = []
for placement in output_spec.placements:
if placement == Replicate():
if key == "x1_scale":
target_placement.append(Shard(0))
else:
target_placement.append(Replicate())
elif placement == Shard(1):
if key == "x2_scale":
target_placement.append(Shard(1))
else:
target_placement.append(Shard(0))
else:
raise ValueError(
f"Unexpected output placement {placement} for npu_all_gather_base_mm."
)
kwargs_spec[key] = DTensorSpec(
mesh=spec.mesh, placements=target_placement, tensor_meta=spec.tensor_meta
)
return kwargs_spec
@register_sharding(npu.npu_mm_reduce_scatter_base.default)
def npu_mm_reduce_scatter_base_strategy(
    x1,
    x2,
    hcom,
    world_size,
    reduce_op="sum",
    bias=None,
    x1_scale=None,
    x2_scale=None,
    comm_turn=0,
    output_dtype=None,
    comm_mode=None,
):
    """List acceptable shardings for ``npu_mm_reduce_scatter_base``.

    Returns a single ``(output_placements, input_placements)`` pair: ``x1``
    sharded on dim 1, ``x2`` on dim 0 and the output on dim 0. Optional
    tensors (``bias``, ``x1_scale``, ``x2_scale``) get a placement only when
    they are provided; every other argument slot is ``None``.

    Raises:
        NotImplementedError: If ``reduce_op`` is not ``"sum"``.
    """
    if reduce_op != "sum":
        raise NotImplementedError(
            f"npu_mm_reduce_scatter_base only support reduce_op='sum' now, but got {reduce_op}."
        )

    def _when_given(value, placement):
        """Return ``placement`` for a provided optional tensor, else None."""
        return None if value is None else placement

    # Input slots follow the signature order: x1, x2, hcom, world_size,
    # reduce_op, bias, x1_scale, x2_scale, then three non-tensor arguments.
    input_placements = [
        Shard(1),
        Shard(0),
        None,
        None,
        None,
        _when_given(bias, Shard(0)),
        _when_given(x1_scale, Shard(1)),
        _when_given(x2_scale, Shard(0)),
        None,
        None,
        None,
    ]
    return [([Shard(0)], input_placements)]
def _infer_npu_mm_reduce_scatter_base_kwargs(
op_schema: OpSchema, output_sharding: OutputSharding
) -> dict[str, DTensorSpec]:
output_spec = output_sharding.output_spec
kwargs_spec = {}
for key, spec in op_schema.kwargs_schema.items():
if not isinstance(spec, DTensorSpec):
kwargs_spec[key] = spec
continue
target_placement = []
for placement in output_spec.placements:
if placement == Shard(0):
if key == "x1_scale":
target_placement.append(Shard(1))
else:
target_placement.append(Shard(0))
else:
raise ValueError(
f"Unexpected output placement {placement} for npu_mm_reduce_scatter_base."
)
kwargs_spec[key] = DTensorSpec(
mesh=spec.mesh, placements=target_placement, tensor_meta=spec.tensor_meta
)
return kwargs_spec
def npu_comm_mm_fusion_handler(
    op_call: torch._ops.OpOverload,
    args: tuple[object, ...],
    kwargs: dict[str, object],
) -> object:
    """Custom DTensor dispatch handler for the fused communication+matmul ops.

    Handles ``npu_all_gather_base_mm`` and ``npu_mm_reduce_scatter_base``.
    After sharding propagation, dim 0 of each output's tensor meta is divided
    (all-gather) or multiplied (reduce-scatter) by the ``world_size`` taken
    from ``args[3]``. The op then runs on redistributed local arguments, or
    empty local results are produced on ranks outside the compute mesh.
    """
    dispatcher = DTensor._op_dispatcher
    op_info = dispatcher.unwrap_to_op_info(op_call, args, kwargs)
    dispatcher.sharding_propagator.propagate(op_info)
    output_sharding = op_info.output_sharding

    is_all_gather = op_call == npu.npu_all_gather_base_mm.default
    is_reduce_scatter = (
        not is_all_gather and op_call == npu.npu_mm_reduce_scatter_base.default
    )

    def get_output_meta(tensor_meta, dim, world_size):
        """Return ``tensor_meta`` with ``dim`` rescaled by ``world_size``."""
        if world_size == 0:
            return tensor_meta
        dims = list(tensor_meta.shape)
        if is_all_gather:
            dims[dim] = dims[dim] // world_size
        elif is_reduce_scatter:
            dims[dim] = dims[dim] * world_size
        # Stride and dtype are carried over unchanged.
        return TensorMeta(
            shape=torch.Size(dims),
            stride=tensor_meta.stride,
            dtype=tensor_meta.dtype,
        )

    if is_all_gather:
        world_size = args[3]
        # The all-gather variant has several outputs; rescale each of them.
        for spec in output_sharding.output_spec:
            spec.tensor_meta = get_output_meta(spec.tensor_meta, 0, world_size)
    elif is_reduce_scatter:
        world_size = args[3]
        spec = output_sharding.output_spec
        spec.tensor_meta = get_output_meta(spec.tensor_meta, 0, world_size)

    if op_info.compute_mesh.get_coordinate() is None:
        # This rank does not take part in the computation.
        local_results = get_empty_local_results(op_info, output_sharding)
    else:
        local_args = get_redistributed_local_args(op_info, output_sharding)
        if is_all_gather:
            local_kwargs = get_redistributed_local_kwargs(
                _infer_npu_all_gather_base_mm_kwargs, op_info, output_sharding
            )
        elif is_reduce_scatter:
            local_kwargs = get_redistributed_local_kwargs(
                _infer_npu_mm_reduce_scatter_base_kwargs, op_info, output_sharding
            )
        else:
            local_kwargs = op_info.local_kwargs
        local_results = op_call(*local_args, **local_kwargs)
    return dispatcher.wrap(local_results, output_sharding.output_spec)
@register_op_strategy([npu.npu_apply_adam_w.default, npu.npu_apply_adam_w.out])
def npu_apply_adam_w_strategy(op_schema: OpSchema) -> OpStrategy:
    """Build the single placement strategy for ``npu_apply_adam_w``.

    The reference spec comes from the first ``out`` tensor when ``out`` is
    among the keyword arguments, otherwise from the positional argument at
    index 7. That spec is used for index 7; a non-None argument at index 8
    gets the same mesh and placements with its own tensor meta; every other
    ``OpStrategy`` argument keeps the output spec of its first strategy.
    Output specs are those of the ``out`` tensors (empty without ``out``).
    """
    grad_pos = 7
    max_grad_norm_pos = 8
    positional = op_schema.args_schema
    keyword = op_schema.kwargs_schema

    grad_strategy: OpStrategy = positional[grad_pos]
    has_out = "out" in keyword
    if has_out:
        reference_spec: DTensorSpec = keyword["out"].childs[0].strategies[0].output_spec
    else:
        reference_spec: DTensorSpec = grad_strategy.strategies[0].output_spec

    input_specs = []
    for position, arg in enumerate(positional):
        if position == grad_pos:
            input_specs.append(reference_spec)
        elif position == max_grad_norm_pos and arg is not None:
            # Follow the reference sharding but keep this argument's own meta.
            input_specs.append(
                DTensorSpec(
                    mesh=reference_spec.mesh,
                    placements=reference_spec.placements,
                    tensor_meta=arg.tensor_meta,
                )
            )
        elif isinstance(arg, OpStrategy):
            input_specs.append(arg.strategies[0].output_spec)

    if has_out:
        output_specs = [child.strategies[0].output_spec for child in keyword["out"].childs]
    else:
        output_specs = []

    return OpStrategy(
        [PlacementStrategy(output_specs=tuple(output_specs), input_specs=input_specs)]
    )
def _npu_apply_adam_w_handler(
    op_call: torch._ops.OpOverload,
    args: tuple[object, ...],
    kwargs: dict[str, object],
) -> object:
    """Custom DTensor dispatch handler for ``npu_apply_adam_w``.

    Unwraps the call, exposes list-of-DTensor keyword arguments to the
    schema, propagates sharding and, on participating ranks, redistributes
    the local arguments if required and runs ``torch_npu.npu_apply_adam_w``.

    Returns:
        For out-variant overloads, the ``out`` DTensors passed by the caller
        with their specs replaced by the propagated output specs (a tuple
        when there are several, otherwise the single DTensor). For other
        overloads, the local results wrapped with the propagated output spec;
        ranks outside the compute mesh wrap empty local results.
    """
    op_info = DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs)
    _handle_tensor_list_in_kwargs(kwargs, op_info)
    DTensor._op_dispatcher.sharding_propagator.propagate(op_info)
    output_sharding = op_info.output_sharding
    mesh = op_info.compute_mesh
    participating = mesh.get_coordinate() is not None
    if participating:
        if output_sharding.needs_redistribute:
            DTensor._op_dispatcher.redistribute_local_args(
                op_info, output_sharding.redistribute_schema
            )
        local_args = (
            pytree.tree_unflatten(
                cast(list[object], op_info.local_args), op_info.args_tree_spec
            )
            if op_info.args_tree_spec
            else op_info.local_args
        )
        local_results = torch_npu.npu_apply_adam_w(*local_args, **op_info.local_kwargs)
    if _is_out_variant_op(op_call):
        # Out variants return the caller's own ``out`` DTensors; the local
        # results are not used on this path.
        output_specs = (
            (output_sharding.output_spec,)
            if not isinstance(output_sharding.output_spec, tuple)
            else output_sharding.output_spec
        )
        out_dts = []
        spec_idx = 0
        for argument in op_call._schema.arguments:
            if argument.name == "out":
                for value in kwargs[argument.name]:
                    out_dt = cast(DTensor, value)
                    out_dt._spec = cast(DTensorSpec, output_specs[spec_idx])
                    out_dts.append(out_dt)
                    spec_idx += 1
        return tuple(out_dts) if len(out_dts) > 1 else out_dts[0]
    if not participating:
        # Previously ``local_results`` was unbound here; use empty local
        # results, as the other handlers in this module do.
        local_results = get_empty_local_results(op_info, output_sharding)
    return DTensor._op_dispatcher.wrap(local_results, output_sharding.output_spec)
@register_op_strategy(aten.native_dropout.default)
def custom_dropout_forward_sharding(op_schema: OpSchema):
    """Build the sharding strategy for ``aten.native_dropout.default``.

    When the input's first placement is a shard whose dimension is smaller
    than ``mesh.size(0)``, a single strategy is returned that keeps the input
    spec for the input and the first output and uses ``Shard(0)`` for the
    second output. Otherwise the candidates are full replication plus, for
    every input dimension ``d``, first output ``Shard(d)``, second output
    ``Shard(0)`` and input ``Shard(d)``, expanded over the full mesh.
    """
    args_schema = op_schema.args_schema
    input_strategy = args_schema[0] if args_schema else None
    input_spec = input_strategy.strategies[0].output_spec

    first_placement = input_spec.placements[0]
    if first_placement.is_shard():
        shard_dim = cast(Shard, first_placement).dim
        if input_spec.shape[shard_dim] < input_spec.mesh.size(0):
            # Fewer elements than ranks on the sharded dimension: keep the
            # input's sharding as is.
            second_output_spec = DTensorSpec(mesh=input_spec.mesh, placements=[Shard(0)])
            return OpStrategy(
                [
                    PlacementStrategy(
                        output_specs=[input_spec, second_output_spec],
                        input_specs=[input_spec],
                    )
                ]
            )

    # Each row is [first output, second output, input].
    single_mesh_dim_strategies = [[Replicate(), Replicate(), Replicate()]]
    single_mesh_dim_strategies.extend(
        [Shard(dim), Shard(0), Shard(dim)] for dim in range(input_spec.ndim)
    )
    return expand_to_full_mesh_op_strategy(
        input_spec.mesh, op_schema, single_mesh_dim_strategies, input_index=2
    )
@register_op_strategy(aten.native_dropout_backward.default)
def custom_dropout_backward_sharding(op_schema: OpSchema) -> OpStrategy:
    """Build the sharding strategy for ``aten.native_dropout_backward.default``.

    Every ``OpStrategy`` argument keeps the output spec of its first
    strategy, and the output reuses the spec of the first argument.
    """
    positional = op_schema.args_schema
    kept_input_specs = [
        arg.strategies[0].output_spec
        for arg in positional
        if isinstance(arg, OpStrategy)
    ]
    only_strategy = PlacementStrategy(
        output_specs=positional[0].strategies[0].output_spec,
        input_specs=kept_input_specs,
    )
    return OpStrategy([only_strategy])
@register_op_strategy(npu.npu_bmmV2.default)
def custom_bmm_strategy(op_schema: OpSchema):
    """Build the ``npu_bmmV2`` strategy from the ``bmk,bkn->bmn`` matmul rule."""
    return _mm_like_strategy(
        "bmk,bkn->bmn", op_schema.get_mesh_from_args(), op_schema
    )
# Operator overloads that are dispatched through the custom handlers defined
# in this module.
customized_ops = {
    npu.npu_grouped_matmul.default: _npu_grouped_matmul_handler,
    npu.npu_grouped_matmul.List: _npu_grouped_matmul_handler,
    npu.npu_apply_adam_w.out: _npu_apply_adam_w_handler,
    npu.npu_all_gather_base_mm.default: npu_comm_mm_fusion_handler,
    npu.npu_mm_reduce_scatter_base.default: npu_comm_mm_fusion_handler,
}
# Install the handlers by replacing the dispatcher's table with a new dict;
# entries from ``customized_ops`` win over existing ones on key collisions,
# and the previous table object (``old_handlers``) is left unmodified.
old_handlers = DTensor._op_dispatcher._custom_op_handlers
DTensor._op_dispatcher._custom_op_handlers = {**old_handlers, **customized_ops}