import torch
from torch.distributed.tensor._op_schema import OpSchema, RuntimeSchemaInfo
from torch.distributed.tensor._ops._pointwise_ops import pointwise_strategy
from torch.distributed.tensor._ops.utils import register_op_strategy
# Short handles for the two operator namespaces referenced by the tables below.
# NOTE(review): the `npu` overloads used later are presumably registered by the
# torch_npu extension being imported first — confirm against the package init.
aten, npu = torch.ops.aten, torch.ops.npu
# Op overload -> linearity code. `custom_linear_pointwise_strategy` looks the
# running op up in this table and forwards the value to `pointwise_strategy`
# as its `linearity` argument.
#
# Every entry currently shares the code 0, so the table is built from a plain
# sequence of overloads. Switch back to an explicit dict literal if an op with
# a different code is ever added.
# NOTE(review): 0 presumably denotes a unary op that propagates partial
# placements (upstream DTensor convention) — confirm against the installed
# torch version.
custom_linear_pointwise_ops = dict.fromkeys(
    (
        npu.npu_dtype_cast.default,
        npu._npu_dtype_cast.default,
        npu.npu_dtype_cast_backward.default,
        npu._npu_dtype_cast_backward.default,
    ),
    0,
)
def custom_linear_pointwise_strategy(op_schema: OpSchema):
    """Build the pointwise sharding strategy for a linear custom op.

    The op carried by ``op_schema`` is looked up in
    ``custom_linear_pointwise_ops``; the code found there is passed to
    ``pointwise_strategy`` as ``linearity``. Ops missing from the table fall
    back to ``-1``.

    Args:
        op_schema: Schema of the op invocation being propagated.

    Returns:
        Whatever ``pointwise_strategy`` returns for ``op_schema`` with the
        resolved linearity code.
    """
    # NOTE(review): -1 presumably means "no linearity" upstream — confirm.
    return pointwise_strategy(
        op_schema,
        linearity=custom_linear_pointwise_ops.get(op_schema.op, -1),
    )
# Register `custom_linear_pointwise_strategy` as the sharding strategy of every
# op listed in `custom_linear_pointwise_ops`. This runs at import time.
for op in custom_linear_pointwise_ops:
    # A fresh schema-info object per op; it declares the `out` kwarg as static.
    _schema_info = RuntimeSchemaInfo(static_kwargkey=["out"])
    # `register_op_strategy` returns a decorator; apply it to the strategy.
    _register = register_op_strategy(op, schema_info=_schema_info)
    _register(custom_linear_pointwise_strategy)
# Ops that are registered with the plain `pointwise_strategy` (no linearity
# code supplied). Grouped by operator namespace; order is preserved.
custom_pointwise_ops = [
    # Overloads from the `aten` namespace.
    aten.isclose.default,
    aten.isfinite.default,
] + [
    # Overloads from the `npu` namespace.
    npu.fast_gelu.default,
    npu.npu_fast_gelu.default,
    # NOTE(review): the name suggests a layer-norm op, which is usually not
    # element-wise — confirm that pointwise propagation is intended here.
    npu.npu_layer_norm_eval.default,
    npu.npu_fast_gelu_backward.default,
]
# Register the stock `pointwise_strategy` as the sharding strategy of every op
# listed in `custom_pointwise_ops`. This runs at import time.
for op in custom_pointwise_ops:
    # A fresh schema-info object per op; it declares the `out` kwarg as static.
    _schema_info = RuntimeSchemaInfo(static_kwargkey=["out"])
    # `register_op_strategy` returns a decorator; apply it to the strategy.
    _register = register_op_strategy(op, schema_info=_schema_info)
    _register(pointwise_strategy)