import torch
from torch.distributed._tensor import DeviceMesh
from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta
from torch.distributed.tensor._op_schema import OpSchema
from torch.distributed.tensor._ops._common_rules import einop_rule, pointwise_rule
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
with_comms,
)
import torch_npu
from torch_npu.testing.common_distributed import with_comms, skipIfUnsupportMultiNPU
aten = torch.ops.aten


# NOTE: the `with_comms` decorator used below is the torch_npu version: it is
# imported after (and therefore shadows) the one imported from common_dtensor.
class CommonRulesTest(DTensorTestBase):
    """Sharding-propagation tests for DTensor's ``einop_rule`` and ``pointwise_rule``.

    Each test builds ``DTensorSpec`` inputs on a device mesh, calls a propagation
    rule directly on an ``OpSchema`` and checks either the propagated output spec
    or the redistribute suggestion returned when no output spec is produced.
    """

    @property
    def world_size(self) -> int:
        # Must agree with the skipIfUnsupportMultiNPU(4) guard on every test.
        return 4

    def _gen_tensor_meta(self, shape):
        """Return the ``TensorMeta`` (shape, stride, dtype) of an empty tensor of ``shape``."""
        empty_tensor = torch.empty(shape)
        return TensorMeta(
            empty_tensor.shape,
            empty_tensor.stride(),
            empty_tensor.dtype,
        )

    def _gen_spec(self, mesh, dim_map, shape, sums=()):
        """Build a ``DTensorSpec`` through ``DTensorSpec.from_dim_map``.

        Args:
            mesh: device mesh the spec is placed on.
            dim_map: for each tensor dim, the mesh dim it is sharded on
                (``-1`` = not sharded).
            shape: global tensor shape used for the attached ``TensorMeta``;
                its length should equal ``len(dim_map)``.
            sums: mesh dims on which the tensor carries a pending partial sum.
        """
        return DTensorSpec.from_dim_map(
            mesh,
            dim_map,
            list(sums),
            tensor_meta=self._gen_tensor_meta(torch.Size(shape)),
        )

    def _build_1d_mesh(self):
        """Return a 1-D ``DeviceMesh`` spanning all ``world_size`` ranks."""
        return DeviceMesh(self.device_type, torch.arange(self.world_size))

    def _build_2d_mesh(self):
        """Return a 2-D ``DeviceMesh`` of shape ``(2, world_size // 2)``.

        With ``world_size == 4`` this is the 2x2 mesh the tests expect; fixing
        the first dim at 2 keeps the reshape valid for any even ``world_size``.
        """
        mesh_shape = torch.arange(self.world_size).reshape(2, self.world_size // 2)
        return DeviceMesh(self.device_type, mesh_shape)

    @skipIfUnsupportMultiNPU(4)
    @with_comms
    def test_einop_basic_propagation(self):
        """mm with replicated, row/column-sharded and contraction-sharded inputs."""
        mesh = self._build_1d_mesh()
        mm_call = aten.mm.default

        def propagate(mat1_dims, mat2_dims):
            # mat1 is (8, 4) and mat2 is (4, 8) in every case below.
            mat1_spec = self._gen_spec(mesh, mat1_dims, [8, 4])
            mat2_spec = self._gen_spec(mesh, mat2_dims, [4, 8])
            return einop_rule(
                "mk,kn->mn", OpSchema(mm_call, (mat1_spec, mat2_spec), {})
            ).output_spec

        # Column-sharded mat2 -> column-sharded output.
        output_spec = propagate([-1, -1], [-1, 0])
        self.assertIsNotNone(output_spec)
        self.assertEqual(output_spec.dim_map, [-1, 0])

        # Row-sharded mat1 -> row-sharded output.
        output_spec = propagate([0, -1], [-1, -1])
        self.assertIsNotNone(output_spec)
        self.assertEqual(output_spec.dim_map, [0, -1])

        # Both inputs sharded on the contracted dim k -> partial-sum output.
        output_spec = propagate([-1, 0], [0, -1])
        self.assertIsNotNone(output_spec)
        self.assertTrue(output_spec.placements[0].is_partial())

    @skipIfUnsupportMultiNPU(4)
    @with_comms
    def test_einop_pointwise_propagation(self):
        """Pointwise add through einop_rule, with and without broadcasting."""
        mesh = self._build_1d_mesh()
        add_call = aten.add.Tensor

        # Same-shape addition keeps the shared sharding.
        mat1_spec = self._gen_spec(mesh, [0, -1], [8, 8])
        output_sharding = einop_rule(
            "ij,ij->ij", OpSchema(add_call, (mat1_spec, mat1_spec), {})
        )
        output_spec = output_sharding.output_spec
        self.assertIsNotNone(output_spec)
        self.assertEqual(output_spec.dim_map, [0, -1])

        # Broadcast addition of a 1-D tensor along the last dim. The meta must
        # be rank 3 to match the 3-entry dim_map, and its `k` size must equal
        # mat2's size (2) for the "ijk,k" equation to be consistent.
        mat1_spec = self._gen_spec(mesh, [-1, 0, -1], [8, 4, 2])
        mat2_spec = self._gen_spec(mesh, [-1], [2])
        output_sharding = einop_rule(
            "ijk,k->ijk", OpSchema(add_call, (mat1_spec, mat2_spec), {})
        )
        output_spec = output_sharding.output_spec
        self.assertIsNotNone(output_spec)
        self.assertEqual(output_spec.dim_map, [-1, 0, -1])

        # Broadcast to a common shape through an explicit size-1 dim.
        mat1_spec = self._gen_spec(mesh, [0, -1, -1], [8, 8, 8])
        mat2_spec = self._gen_spec(mesh, [-1, -1], [1, 8])
        output_sharding = einop_rule(
            "ijk,1k->ijk", OpSchema(add_call, (mat1_spec, mat2_spec), {})
        )
        output_spec = output_sharding.output_spec
        self.assertIsNotNone(output_spec)
        self.assertEqual(output_spec.dim_map, [0, -1, -1])

    @skipIfUnsupportMultiNPU(4)
    @with_comms
    def test_einop_merge_sharding(self):
        """mm on a 2-D mesh: row sharding (mesh 0) and column sharding (mesh 1) merge."""
        mesh = self._build_2d_mesh()
        mat1_spec = self._gen_spec(mesh, [0, -1], [8, 4])
        mat2_spec = self._gen_spec(mesh, [-1, 1], [4, 8])
        output_sharding = einop_rule(
            "mk,kn->mn", OpSchema(aten.mm.default, (mat1_spec, mat2_spec), {})
        )
        output_spec = output_sharding.output_spec
        self.assertIsNotNone(output_spec)
        self.assertEqual(output_spec.dim_map, [0, 1])

    @skipIfUnsupportMultiNPU(4)
    @with_comms
    def test_einop_linearity(self):
        """Partial-sum inputs are resolved differently with and without linearity."""
        mesh = self._build_2d_mesh()
        mm_call = aten.mm.default

        # mat1 carries a pending partial sum on mesh dim 1.
        mat1_spec = self._gen_spec(mesh, [0, -1], [8, 4], sums=[1])
        mat2_spec = self._gen_spec(mesh, [-1, -1], [4, 8])

        # Without linearity the partial input does not propagate: no output
        # spec, and the suggestion makes mat1 non-partial on mesh dim 1.
        output_sharding = einop_rule(
            "mk,kn->mn", OpSchema(mm_call, (mat1_spec, mat2_spec), {})
        )
        self.assertIsNone(output_sharding.output_spec)
        suggestions = output_sharding.redistribute_schema
        self.assertIsNotNone(suggestions)
        self.assertFalse(suggestions.args_schema[0].placements[1].is_partial())

        # With linearity on mm, the suggestion instead turns mat2 partial on
        # mesh dim 1 so the partial sum can flow through.
        output_sharding = einop_rule(
            "mk,kn->mn",
            OpSchema(mm_call, (mat1_spec, mat2_spec), {}),
            linearity=True,
        )
        self.assertIsNone(output_sharding.output_spec)
        suggestions = output_sharding.redistribute_schema
        self.assertIsNotNone(suggestions)
        self.assertTrue(suggestions.args_schema[1].placements[1].is_partial())

        # Same behaviour for a pointwise add with linearity.
        add_call = aten.add.Tensor
        mat1_spec = self._gen_spec(mesh, [0, -1], [8, 6], sums=[1])
        mat2_spec = self._gen_spec(mesh, [0, -1], [8, 6])
        output_sharding = einop_rule(
            "ij,ij->ij",
            OpSchema(add_call, (mat1_spec, mat2_spec), {}),
            linearity=True,
        )
        self.assertIsNone(output_sharding.output_spec)
        suggestions = output_sharding.redistribute_schema
        self.assertIsNotNone(suggestions)
        self.assertTrue(suggestions.args_schema[1].placements[1].is_partial())

    @skipIfUnsupportMultiNPU(4)
    @with_comms
    def test_einop_multi_sharding_on_mesh_dim(self):
        """mm where mesh dim 0 shards both m (mat1) and k (mat2)."""
        mesh = self._build_1d_mesh()
        mat1_spec = self._gen_spec(mesh, [0, -1], [8, 12])
        mat2_spec = self._gen_spec(mesh, [0, -1], [12, 4])
        output_sharding = einop_rule(
            "mk,kn->mn", OpSchema(aten.mm.default, (mat1_spec, mat2_spec), {})
        )
        self.assertIsNone(output_sharding.output_spec)
        schema_suggestion = output_sharding.redistribute_schema
        self.assertIsNotNone(schema_suggestion)
        # The suggestion keeps mat1's sharding and replicates mat2.
        self.assertEqual(schema_suggestion.args_schema[0].dim_map, [0, -1])
        self.assertEqual(schema_suggestion.args_schema[1].dim_map, [-1, -1])

    @skipIfUnsupportMultiNPU(4)
    @with_comms
    def test_einop_errors(self):
        """The same tensor dim sharded on different mesh dims raises."""
        mesh = self._build_2d_mesh()
        mat1_spec = self._gen_spec(mesh, [0, -1], [8, 4])
        mat2_spec = self._gen_spec(mesh, [1, -1], [8, 4])
        with self.assertRaisesRegex(RuntimeError, "sharded two different ways:"):
            einop_rule(
                "ij,ij->ij",
                OpSchema(aten.add.Tensor, (mat1_spec, mat2_spec), {}),
            )

    @skipIfUnsupportMultiNPU(4)
    @with_comms
    def test_pointwise_rules_broadcasting(self):
        """where() over 1-D, 0-D and (1, 1) inputs broadcasts to a 2-D output."""
        mesh = self._build_1d_mesh()
        condition = self._gen_spec(mesh, [0], [8])
        self_tensor = self._gen_spec(mesh, [], [])
        other_tensor = self._gen_spec(mesh, [-1, -1], [1, 1])
        output_sharding = pointwise_rule(
            OpSchema(aten.where.self, (condition, self_tensor, other_tensor), {})
        )
        output_spec = output_sharding.output_spec
        self.assertIsNotNone(output_spec)
        self.assertEqual(output_spec.dim_map, [-1, 0])

    @skipIfUnsupportMultiNPU(4)
    @with_comms
    def test_pointwise_rules_suggestion(self):
        """Non-tensor args survive unchanged in the redistribute suggestion."""
        mesh = self._build_1d_mesh()
        mat1_spec = self._gen_spec(mesh, [-1, -1], [8, 4])
        mat2_spec = self._gen_spec(mesh, [-1, 0], [8, 4])
        # lerp.Scalar takes a scalar weight (-1) as its third argument.
        output_sharding = pointwise_rule(
            OpSchema(aten.lerp.Scalar, (mat1_spec, mat2_spec, -1), {})
        )
        self.assertIsNone(output_sharding.output_spec)
        schema_suggestion = output_sharding.redistribute_schema
        self.assertIsNotNone(schema_suggestion)
        self.assertEqual(len(schema_suggestion.args_schema), 3)
        self.assertEqual(schema_suggestion.args_schema[2], -1)

    @skipIfUnsupportMultiNPU(4)
    @with_comms
    def test_pointwise_multi_sharding_on_mesh_dim(self):
        """Pointwise add on a 2-D mesh with broadcasting and mesh-dim conflicts."""
        mesh = self._build_2d_mesh()
        add_call = aten.add.Tensor

        # Simple broadcast: the 1-D input lines up with mat1's last dim, which
        # is sharded on the same mesh dim, so propagation succeeds.
        mat1_spec = self._gen_spec(mesh, [-1, 0], [20, 6])
        mat2_spec = self._gen_spec(mesh, [0], [6])
        output_sharding = pointwise_rule(OpSchema(add_call, (mat1_spec, mat2_spec), {}))
        output_spec = output_sharding.output_spec
        self.assertIsNotNone(output_spec)
        self.assertEqual(output_spec.dim_map, [-1, 0])

        # After broadcasting, mesh dim 0 would shard two different output dims
        # (mat1's leading dim and mat2's first dim); the suggestion replicates
        # mat1 on mesh dim 0 and keeps mat2 unchanged.
        mat1_dims, mat2_dims = [0, -1, -1, 1], [0, -1, 1]
        mat1_spec = self._gen_spec(mesh, mat1_dims, [12, 1, 1, 8])
        mat2_spec = self._gen_spec(mesh, mat2_dims, [12, 4, 8])
        output_sharding = pointwise_rule(OpSchema(add_call, (mat1_spec, mat2_spec), {}))
        self.assertIsNone(output_sharding.output_spec)
        schema_suggestion = output_sharding.redistribute_schema
        self.assertIsNotNone(schema_suggestion)
        self.assertEqual(schema_suggestion.args_schema[0].dim_map, [-1, -1, -1, 1])
        self.assertEqual(schema_suggestion.args_schema[1].dim_map, mat2_dims)

    @skipIfUnsupportMultiNPU(4)
    @with_comms
    def test_pointwise_enforce_sharding_multi_sharding_on_mesh_dim(self):
        """In-place add_: both suggested specs adopt the first argument's dim_map."""
        mesh = self._build_2d_mesh()
        mat1_dims, mat2_dims = [0, -1, 1], [-1, -1, 0]
        mat1_spec = self._gen_spec(mesh, mat1_dims, [12, 4, 8])
        mat2_spec = self._gen_spec(mesh, mat2_dims, [12, 1, 8])
        output_sharding = pointwise_rule(
            OpSchema(aten.add_.Tensor, (mat1_spec, mat2_spec), {})
        )
        self.assertIsNone(output_sharding.output_spec)
        schema_suggestion = output_sharding.redistribute_schema
        self.assertIsNotNone(schema_suggestion)
        self.assertEqual(schema_suggestion.args_schema[0].dim_map, mat1_dims)
        self.assertEqual(schema_suggestion.args_schema[1].dim_map, mat1_dims)


if __name__ == "__main__":
    run_tests()