from torch._utils import _get_device_module
from torch.testing._internal.distributed._tensor.common_dtensor import DTensorTestBase
class NPUDTensorTestBase(DTensorTestBase):
@property
def device_type(self):
return "npu"
@property
def world_size(self):
device_count = _get_device_module(self.device_type).device_count()
device_num = 4
if device_count > 1:
device_num = min(device_num, device_count)
return device_num