import itertools
from typing import cast, Any, Dict, Tuple
import torch
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed._tensor.experimental import register_sharding
from torch.distributed._tensor.placement_types import DTensorSpec
from torch.distributed.tensor import DTensor, Replicate, Shard
from torch.distributed.tensor._op_schema import (
OpInfo,
OpSchema,
OutputSharding
)
from torch.distributed.tensor._redistribute import redistribute_local_tensor
import torch_npu
from ._common import (
get_redistributed_local_args,
get_redistributed_local_kwargs,
get_empty_local_results
)
# Shorthand for the `torch.ops.npu` operator namespace; the sharding
# registrations and handler table below look ops up through this alias.
npu = torch.ops.npu
@register_sharding(npu.npu_fusion_attention.default)
def npu_fusion_attention_strategy(query, key, value, head_num, input_layout, pse=None, padding_mask=None,
                                  atten_mask=None, scale=1.0, keep_prob=1.0, pre_tockens=2147483647,
                                  next_tockens=2147483647, inner_precise=0, prefix=None, actual_seq_qlen=None,
                                  actual_seq_kvlen=None, sparse_mode=0, gen_mask_parallel=True, sync=False,
                                  softmax_layout="", sink=None):
    """List the acceptable shardings for ``npu_fusion_attention``.

    Every entry is a ``(output_placements, input_placements)`` pair. The
    output list has 7 slots and the input list has 21 slots, one per
    parameter of this function in declaration order; ``None`` marks a slot
    that carries no placement.

    The fully replicated candidate is always offered. Candidates sharded
    along the ``'B'`` and ``'N'`` positions of ``input_layout`` are added
    only when ``pse``, ``padding_mask``, ``prefix``, ``actual_seq_qlen``,
    ``actual_seq_kvlen`` and ``sink`` are all ``None`` and ``keep_prob`` is
    not below 1.0.

    Returns:
        list of ``(output_placements, input_placements)`` tuples.
    """
    def replicate_if_given(arg):
        # Optional tensors get a placement only when they were supplied.
        return None if arg is None else Replicate()

    def sharded_candidate(layout_dim, fixed_dim):
        # `layout_dim` is the position inside `input_layout` (used for
        # query/key/value and the first output); `fixed_dim` is the dim used
        # for outputs 1-2 and for atten_mask regardless of the layout.
        mask_placement = None
        if atten_mask is not None:
            if atten_mask.ndim == 4 and atten_mask.shape[fixed_dim] != 1:
                mask_placement = Shard(fixed_dim)
            else:
                # Non-4D mask, or size 1 along this dim: keep it replicated.
                mask_placement = Replicate()
        output_placements = [
            Shard(layout_dim), Shard(fixed_dim), Shard(fixed_dim), Replicate(),
            None, None, None,
        ]
        input_placements = (
            [Shard(layout_dim) for _ in range(3)]   # query, key, value
            + [None] * 4                            # head_num, input_layout, pse, padding_mask
            + [mask_placement]                      # atten_mask
            + [None] * 13                           # scale .. softmax_layout, sink
        )
        return output_placements, input_placements

    candidates = [(
        [Replicate() for _ in range(4)] + [None] * 3,
        (
            [Replicate() for _ in range(3)]
            + [None, None]
            + [replicate_if_given(opt) for opt in (pse, padding_mask, atten_mask)]
            + [None] * 12
            + [replicate_if_given(sink)]
        ),
    )]

    # Sharded candidates are only offered for the plain configuration.
    blocking_args = (pse, padding_mask, prefix, actual_seq_qlen, actual_seq_kvlen, sink)
    if any(arg is not None for arg in blocking_args) or keep_prob < 1.0:
        return candidates

    if 'B' in input_layout:
        candidates.append(sharded_candidate(input_layout.index('B'), 0))
    if 'N' in input_layout:
        candidates.append(sharded_candidate(input_layout.index('N'), 1))
    return candidates
@register_sharding(npu.npu_fusion_attention_grad.default)
def npu_fusion_attention_grad_strategy(query, key, value, dy, head_num, input_layout, pse=None, padding_mask=None,
                                       atten_mask=None, softmax_max=None, softmax_sum=None, softmax_in=None,
                                       attention_in=None, scale_value=1., keep_prob=1., pre_tockens=2147483647,
                                       next_tockens=2147483647, inner_precise=0, seed=0, offset=0, numels=0,
                                       prefix=None, actual_seq_qlen=None, actual_seq_kvlen=None, sparse_mode=0,
                                       gen_mask_parallel=True, sync=False, softmax_layout="", sink=None):
    """List the acceptable shardings for ``npu_fusion_attention_grad``.

    Every entry is a ``(output_placements, input_placements)`` pair. The
    output list has 5 slots and the input list has 29 slots, one per
    parameter of this function in declaration order; ``None`` marks a slot
    that carries no placement.

    The fully replicated candidate is always offered. Candidates sharded
    along the ``'B'`` and ``'N'`` positions of ``input_layout`` are added
    only when ``pse``, ``padding_mask``, ``prefix``, ``actual_seq_qlen``,
    ``actual_seq_kvlen`` and ``sink`` are all ``None`` and ``keep_prob`` is
    not below 1.0.

    Returns:
        list of ``(output_placements, input_placements)`` tuples.
    """
    def replicate_if_given(arg):
        # Optional tensors get a placement only when they were supplied.
        return None if arg is None else Replicate()

    def sharded_candidate(layout_dim, fixed_dim):
        # `layout_dim` follows `input_layout` (query/key/value/dy and
        # attention_in); `fixed_dim` is used for atten_mask, softmax_max and
        # softmax_sum regardless of the layout.
        mask_placement = None
        if atten_mask is not None:
            if atten_mask.ndim == 4 and atten_mask.shape[fixed_dim] != 1:
                mask_placement = Shard(fixed_dim)
            else:
                mask_placement = Replicate()
        output_placements = [Shard(layout_dim) for _ in range(3)]
        output_placements.append(None if pse is None else Replicate())
        output_placements.append(Replicate())
        input_placements = (
            [Shard(layout_dim) for _ in range(4)]   # query, key, value, dy
            + [None] * 4                            # head_num, input_layout, pse, padding_mask
            + [
                mask_placement,
                None if softmax_max is None else Shard(fixed_dim),
                None if softmax_sum is None else Shard(fixed_dim),
                replicate_if_given(softmax_in),
                None if attention_in is None else Shard(layout_dim),
            ]
            + [None] * 16                           # scale_value .. softmax_layout, sink
        )
        return output_placements, input_placements

    optional_tensors = (
        pse, padding_mask, atten_mask, softmax_max, softmax_sum, softmax_in, attention_in,
    )
    candidates = [(
        [Replicate(), Replicate(), Replicate(), replicate_if_given(pse), Replicate()],
        (
            [Replicate() for _ in range(4)]
            + [None, None]
            + [replicate_if_given(opt) for opt in optional_tensors]
            + [None] * 15
            + [replicate_if_given(sink)]
        ),
    )]

    # Sharded candidates are only offered for the plain configuration.
    blocking_args = (pse, padding_mask, prefix, actual_seq_qlen, actual_seq_kvlen, sink)
    if any(arg is not None for arg in blocking_args) or keep_prob < 1.0:
        return candidates

    if 'B' in input_layout:
        candidates.append(sharded_candidate(input_layout.index('B'), 0))
    if 'N' in input_layout:
        candidates.append(sharded_candidate(input_layout.index('N'), 1))
    return candidates
@register_sharding(npu.npu_fusion_attention_v3.default)
def npu_fusion_attention_v3_strategy(query, key, value, head_num, input_layout, pse=None, padding_mask=None,
                                     atten_mask=None, scale=1.0, keep_prob=1.0, pre_tockens=2147483647,
                                     next_tockens=2147483647, inner_precise=0, prefix=None, actual_seq_qlen=None,
                                     actual_seq_kvlen=None, sparse_mode=0, gen_mask_parallel=True, sync=False,
                                     softmax_layout="", sink=None):
    """List the acceptable shardings for ``npu_fusion_attention_v3``.

    Every entry is a ``(output_placements, input_placements)`` pair. The
    output list has 6 slots (all of them carry a placement) and the input
    list has 21 slots, one per parameter of this function in declaration
    order; ``None`` marks a slot that carries no placement.

    The fully replicated candidate is always offered. Candidates sharded
    along the ``'B'`` and ``'N'`` positions of ``input_layout`` are added
    only when ``pse``, ``padding_mask``, ``prefix``, ``actual_seq_qlen``,
    ``actual_seq_kvlen`` and ``sink`` are all ``None`` and ``keep_prob`` is
    not below 1.0.

    Returns:
        list of ``(output_placements, input_placements)`` tuples.
    """
    def replicate_if_given(arg):
        # Optional tensors get a placement only when they were supplied.
        return None if arg is None else Replicate()

    def sharded_candidate(layout_dim, fixed_dim):
        # `layout_dim` follows `input_layout` (query/key/value and output 0);
        # `fixed_dim` is used for outputs 1-2 and for atten_mask.
        mask_placement = None
        if atten_mask is not None:
            if atten_mask.ndim == 4 and atten_mask.shape[fixed_dim] != 1:
                mask_placement = Shard(fixed_dim)
            else:
                mask_placement = Replicate()
        output_placements = [
            Shard(layout_dim), Shard(fixed_dim), Shard(fixed_dim),
            Replicate(), Replicate(), Replicate(),
        ]
        input_placements = (
            [Shard(layout_dim) for _ in range(3)]   # query, key, value
            + [None] * 4                            # head_num, input_layout, pse, padding_mask
            + [mask_placement]                      # atten_mask
            + [None] * 6                            # scale .. prefix
            + [replicate_if_given(actual_seq_qlen), replicate_if_given(actual_seq_kvlen)]
            + [None] * 5                            # sparse_mode .. softmax_layout, sink
        )
        return output_placements, input_placements

    candidates = [(
        [Replicate() for _ in range(6)],
        (
            [Replicate() for _ in range(3)]
            + [None, None]
            + [replicate_if_given(opt) for opt in (pse, padding_mask, atten_mask)]
            + [None] * 6
            + [replicate_if_given(actual_seq_qlen), replicate_if_given(actual_seq_kvlen)]
            + [None] * 4
            + [replicate_if_given(sink)]
        ),
    )]

    # Sharded candidates are only offered for the plain configuration.
    blocking_args = (pse, padding_mask, prefix, actual_seq_qlen, actual_seq_kvlen, sink)
    if any(arg is not None for arg in blocking_args) or keep_prob < 1.0:
        return candidates

    if 'B' in input_layout:
        candidates.append(sharded_candidate(input_layout.index('B'), 0))
    if 'N' in input_layout:
        candidates.append(sharded_candidate(input_layout.index('N'), 1))
    return candidates
@register_sharding(npu.npu_fusion_attention_grad_v3.default)
def npu_fusion_attention_grad_v3_strategy(query, key, value, dy, head_num, input_layout, pse=None, padding_mask=None,
                                          atten_mask=None, softmax_max=None, softmax_sum=None, softmax_in=None,
                                          attention_in=None, scale_value=1., keep_prob=1., pre_tockens=2147483647,
                                          next_tockens=2147483647, inner_precise=0, seed=None, offset=None,
                                          prefix=None, actual_seq_qlen=None, actual_seq_kvlen=None, sparse_mode=0,
                                          gen_mask_parallel=True, sync=False, softmax_layout="", sink=None):
    """List the acceptable shardings for ``npu_fusion_attention_grad_v3``.

    Every entry is a ``(output_placements, input_placements)`` pair. The
    output list has 5 slots (all of them carry a placement) and the input
    list has 28 slots, one per parameter of this function in declaration
    order; ``None`` marks a slot that carries no placement.

    The fully replicated candidate is always offered. Candidates sharded
    along the ``'B'`` and ``'N'`` positions of ``input_layout`` are added
    only when ``pse``, ``padding_mask``, ``prefix``, ``actual_seq_qlen``,
    ``actual_seq_kvlen`` and ``sink`` are all ``None`` and ``keep_prob`` is
    not below 1.0.

    Returns:
        list of ``(output_placements, input_placements)`` tuples.
    """
    def replicate_if_given(arg):
        # Optional tensors get a placement only when they were supplied.
        return None if arg is None else Replicate()

    def trailing_inputs(sink_placement):
        # Placements for `scale_value` through `sink` (15 slots), which are
        # laid out the same way in every candidate except for the last slot.
        return (
            [None] * 5                              # scale_value .. inner_precise
            + [replicate_if_given(seed), replicate_if_given(offset)]
            + [None]                                # prefix
            + [replicate_if_given(actual_seq_qlen), replicate_if_given(actual_seq_kvlen)]
            + [None] * 4                            # sparse_mode .. softmax_layout
            + [sink_placement]
        )

    def sharded_candidate(layout_dim, fixed_dim):
        # `layout_dim` follows `input_layout` (query/key/value/dy and
        # attention_in); `fixed_dim` is used for atten_mask, softmax_max and
        # softmax_sum regardless of the layout.
        mask_placement = None
        if atten_mask is not None:
            if atten_mask.ndim == 4 and atten_mask.shape[fixed_dim] != 1:
                mask_placement = Shard(fixed_dim)
            else:
                mask_placement = Replicate()
        output_placements = [Shard(layout_dim) for _ in range(3)] + [Replicate(), Replicate()]
        input_placements = (
            [Shard(layout_dim) for _ in range(4)]   # query, key, value, dy
            + [None] * 4                            # head_num, input_layout, pse, padding_mask
            + [
                mask_placement,
                None if softmax_max is None else Shard(fixed_dim),
                None if softmax_sum is None else Shard(fixed_dim),
                replicate_if_given(softmax_in),
                None if attention_in is None else Shard(layout_dim),
            ]
            + trailing_inputs(None)
        )
        return output_placements, input_placements

    optional_tensors = (
        pse, padding_mask, atten_mask, softmax_max, softmax_sum, softmax_in, attention_in,
    )
    candidates = [(
        [Replicate() for _ in range(5)],
        (
            [Replicate() for _ in range(4)]
            + [None, None]
            + [replicate_if_given(opt) for opt in optional_tensors]
            + trailing_inputs(replicate_if_given(sink))
        ),
    )]

    # Sharded candidates are only offered for the plain configuration.
    blocking_args = (pse, padding_mask, prefix, actual_seq_qlen, actual_seq_kvlen, sink)
    if any(arg is not None for arg in blocking_args) or keep_prob < 1.0:
        return candidates

    if 'B' in input_layout:
        candidates.append(sharded_candidate(input_layout.index('B'), 0))
    if 'N' in input_layout:
        candidates.append(sharded_candidate(input_layout.index('N'), 1))
    return candidates
def _infer_npu_fusion_attention_grad_kwargs_spec(
op_schema: OpSchema,
output_sharding: OutputSharding
) -> Dict[str, DTensorSpec]:
input_layout = op_schema.args_schema[5]
batch_dim = input_layout.index('B') if 'B' in input_layout else None
dp_shard = Shard(batch_dim) if batch_dim is not None else None
head_dim = input_layout.index('N') if 'N' in input_layout else None
tp_shard = Shard(head_dim) if head_dim is not None else None
output_spec = output_sharding.output_spec[0]
kwargs_spec = {}
for key, spec in op_schema.kwargs_schema.items():
if not isinstance(spec, DTensorSpec):
kwargs_spec[key] = spec
continue
target_placement = []
for placement in output_spec.placements:
if placement == Replicate():
target_placement.append(Replicate())
elif placement == dp_shard:
if key == 'atten_mask':
atten_mask = op_schema.kwargs_schema[key]
if atten_mask.ndim == 4 and atten_mask.shape[0] != 1:
target_placement.append(dp_shard)
else:
target_placement.append(Replicate())
elif key == 'softmax_max' or key == 'softmax_sum':
target_placement.append(Shard(0))
elif key == 'attention_in':
target_placement.append(dp_shard)
else:
target_placement.append(Replicate())
elif placement == tp_shard:
if key == 'atten_mask':
atten_mask = op_schema.kwargs_schema[key]
if atten_mask.ndim == 4 and atten_mask.shape[1] != 1:
target_placement.append(tp_shard)
else:
target_placement.append(Replicate())
elif key == 'softmax_max' or key == 'softmax_sum':
target_placement.append(Shard(1))
elif key == 'attention_in':
target_placement.append(tp_shard)
else:
target_placement.append(Replicate())
else:
raise ValueError(
f"Unexpected placement {placement} for npu_fusion_attention_grad in layout {input_layout}."
)
kwargs_spec[key] = DTensorSpec(
mesh=spec.mesh,
placements=target_placement,
tensor_meta=spec.tensor_meta
)
return kwargs_spec
def _npu_fusion_attention_handler(
    op_call: torch._ops.OpOverload,
    args: Tuple[object, ...],
    kwargs: Dict[str, object],
) -> object:
    """Custom DTensor handler for the NPU fusion-attention forward/grad ops.

    Steps performed:
      1. Wrap any plain ``torch.Tensor`` argument as a replicated ``DTensor``
         on the device mesh of the first argument.
      2. Build the op info and run sharding propagation through the DTensor
         dispatcher.
      3. On ranks that are part of the compute mesh, obtain local arguments,
         overwrite the ``head_num`` positional argument with the local size
         of the query along the ``'N'`` position of ``input_layout`` (if the
         layout has one), and call the matching ``torch_npu`` function.
      4. On ranks outside the mesh, build placeholder results instead.
      5. Wrap the local results back according to the propagated output spec.

    Args:
        op_call: the op overload being dispatched.
        args: positional arguments of the call.
        kwargs: keyword arguments of the call.

    Returns:
        Whatever ``DTensor._op_dispatcher.wrap`` produces for the local
        results and the propagated output spec.

    Raises:
        NotImplementedError: if ``op_call`` is not one of the four supported
            fusion-attention overloads (participating ranks only).
        RuntimeError: on non-participating ranks, if an output spec has no
            tensor metadata.
    """
    def npu_attention_input_fn(
        mesh: DeviceMesh, *args: Tuple[Any, ...], **kwargs: Dict[str, Any]
    ) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
        # Promote every plain tensor (positional or keyword) to a replicated
        # DTensor on `mesh`; everything else is passed through untouched.
        all_args = []
        for arg in itertools.chain(args, kwargs.values()):
            if isinstance(arg, torch.Tensor) and not isinstance(arg, DTensor):
                arg = DTensor.from_local(arg, mesh, [Replicate()], run_check=False)
            all_args.append(arg)
        # Split the flat list back into positional and keyword parts,
        # preserving the original keyword order.
        new_args = tuple(all_args[0: len(args)])
        new_kwargs = dict(zip(kwargs.keys(), all_args[len(args):]))
        return new_args, new_kwargs
    runtime_schema_info = (
        DTensor._op_dispatcher.sharding_propagator.op_to_schema_info.get(op_call, None)
    )
    if runtime_schema_info is not None and runtime_schema_info.needs_pytree:
        # Flatten only to locate the first leaf argument; `args_spec` is not
        # used afterwards in this function.
        try:
            from torch.utils import _cxx_pytree as pytree
        except ImportError:
            from torch.utils import _pytree as pytree
        from typing import Sequence
        tree_args, args_spec = pytree.tree_flatten(args)
        args_list: Sequence[object] = tree_args
    else:
        args_list, args_spec = args, None
    # The mesh is taken from the first argument, so that argument must expose
    # `device_mesh` (i.e. already be a DTensor).
    args, kwargs = npu_attention_input_fn(args_list[0].device_mesh, *args, **kwargs)
    op_info = DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs)
    DTensor._op_dispatcher.sharding_propagator.propagate(op_info)
    output_sharding = op_info.output_sharding
    mesh = op_info.compute_mesh
    # A rank with no coordinate in the compute mesh does not run the kernel.
    participating = mesh.get_coordinate() is not None
    if participating:
        # NOTE(review): presumably returns the positional local tensors after
        # redistribution to the propagated input specs -- confirm in ._common.
        local_args = get_redistributed_local_args(op_info, output_sharding)
        local_kwargs = op_info.local_kwargs
        if op_call == npu.npu_fusion_attention.default or op_call == npu.npu_fusion_attention_v3.default:
            # Forward ops: positional index 3 is head_num, index 4 is input_layout.
            input_layout = op_info.local_args[4]
            if 'N' in input_layout:
                # Replace head_num with the local query size along 'N' so it
                # matches the (possibly sharded) local tensor.
                head_dim = input_layout.index('N')
                local_args = list(local_args)
                local_query = local_args[0]
                local_args[3] = local_query.size(head_dim)
                local_args = tuple(local_args)
            if op_call == npu.npu_fusion_attention.default:
                local_results = torch_npu.npu_fusion_attention(
                    *local_args, **local_kwargs
                )
            else:
                local_results = torch_npu.npu_fusion_attention_v3(
                    *local_args, **local_kwargs
                )
        elif op_call == npu.npu_fusion_attention_grad.default or op_call == npu.npu_fusion_attention_grad_v3.default:
            # Grad ops: keyword arguments get their own target specs from
            # _infer_npu_fusion_attention_grad_kwargs_spec.
            # NOTE(review): presumably the helper redistributes the keyword
            # tensors to those specs -- confirm in ._common.
            local_kwargs = get_redistributed_local_kwargs(
                _infer_npu_fusion_attention_grad_kwargs_spec, op_info, output_sharding
            )
            # Grad ops: positional index 4 is head_num, index 5 is input_layout.
            input_layout = op_info.local_args[5]
            if 'N' in input_layout:
                head_dim = input_layout.index('N')
                local_args = list(local_args)
                local_query = local_args[0]
                local_args[4] = local_query.size(head_dim)
                local_args = tuple(local_args)
            if op_call == npu.npu_fusion_attention_grad.default:
                local_results = torch_npu.npu_fusion_attention_grad(
                    *local_args, **local_kwargs
                )
            else:
                local_results = torch_npu.npu_fusion_attention_grad_v3(
                    *local_args, **local_kwargs
                )
        else:
            raise NotImplementedError(
                "_npu_fusion_attention_handler only supports npu_fusion_attention and npu_fusion_attention_grad now."
            )
    else:
        # Non-participating rank: fabricate placeholder outputs so the result
        # can still be wrapped.
        spec = output_sharding.output_spec
        def default_tensor(spec: DTensorSpec) -> torch.Tensor:
            if spec.tensor_meta is not None:
                shape = spec.tensor_meta.shape
                dtype = spec.tensor_meta.dtype
                if len(shape) == 0:
                    # Scalar output: a 0-dim zero tensor.
                    return torch.zeros((), dtype=dtype)
                else:
                    # Non-scalar output: an empty tensor of the right dtype.
                    return torch.tensor([], dtype=dtype)
            else:
                raise RuntimeError(f"{spec} has no tensor metadata.")
        # Output slots without a spec are filled with the integer 0.
        local_results = [default_tensor(s) if s is not None else 0 for s in spec]
    return DTensor._op_dispatcher.wrap(local_results, output_sharding.output_spec)
# All four fusion-attention overloads are routed through the same handler.
customized_ops = {
    overload: _npu_fusion_attention_handler
    for overload in (
        npu.npu_fusion_attention.default,
        npu.npu_fusion_attention_grad.default,
        npu.npu_fusion_attention_v3.default,
        npu.npu_fusion_attention_grad_v3.default,
    )
}
# Install on top of the handlers already registered on the DTensor
# dispatcher; on a key clash the entries defined here win.
old_handlers = DTensor._op_dispatcher._custom_op_handlers
DTensor._op_dispatcher._custom_op_handlers = {**old_handlers, **customized_ops}