"""
Python interface test for the heterogeneous device-tensor client (DsTensorClient).
"""
from __future__ import absolute_import
import argparse
import logging
from yr.datasystem import DsTensorClient
# Optional accelerator stacks. Each example runs only when every module it
# needs imported cleanly; a missing package disables that example instead of
# aborting the script.
try:
    # pyACL and NumPy are used by both examples, so losing either disables both.
    import acl
    import numpy
except ImportError:
    is_torch_exist = is_mindspore_exist = False
else:
    is_torch_exist = is_mindspore_exist = True

try:
    # The PyTorch example additionally needs torch plus the Ascend NPU plugin.
    import torch
    import torch_npu
except ImportError:
    is_torch_exist = False

try:
    # The MindSpore example additionally needs mindspore itself.
    import mindspore
except ImportError:
    is_mindspore_exist = False
class DsTensorClientExample:
    """Usage example / smoke test for the DsTensorClient device-tensor API.

    Reads the worker address and NPU device id from the command line, then
    performs a ``dev_mset`` followed by a ``dev_mget`` of one key using either
    PyTorch (torch_npu) or MindSpore tensors on the configured device.
    """

    def __init__(self, argv=None):
        """Parse connection settings and enable INFO-level logging.

        Args:
            argv: Optional list of command-line tokens to parse. ``None``
                (the default) parses ``sys.argv[1:]`` exactly as before.

        Raises:
            SystemExit: raised by argparse when required options are missing
                or malformed.
        """
        parser = argparse.ArgumentParser(description="DsTensorClient python interface Test")
        parser.add_argument("--host", required=True, help="The IP of worker service")
        parser.add_argument("--port", required=True, type=int, help="The port of worker service")
        parser.add_argument("--device_id", type=int, default=0, help="The device id")
        args = parser.parse_args(argv)
        self._host = args.host
        self._port = args.port
        self._device_id = args.device_id
        logging.basicConfig(level=logging.INFO)

    @staticmethod
    def _check_acl_ret(api_name, ret):
        """Log a warning when a pyACL call returns a non-zero status code.

        pyACL reports failures through an integer return value (0 on success,
        per the pyACL API reference) rather than by raising. The original code
        discarded these values, so failures went unnoticed. This only warns, so
        the example's control flow is unchanged.
        """
        if ret != 0:
            logging.warning("%s returned error code %s", api_name, ret)

    def _acl_init_device(self):
        """Initialise pyACL and select ``self._device_id`` as the current device."""
        self._check_acl_ret("acl.init", acl.init())
        self._check_acl_ret("acl.rt.set_device", acl.rt.set_device(self._device_id))

    def _dev_mset_and_dev_mget(self, in_tensors, out_tensors):
        """Store ``in_tensors`` under one key, then fetch that key into ``out_tensors``.

        Shared by the PyTorch and MindSpore examples, which differ only in
        how the tensors are created.

        Raises:
            RuntimeError: if the worker reports failed keys for either call.
        """
        key = "key"
        client = DsTensorClient(self._host, self._port, self._device_id)
        client.init()
        failed_keys = client.dev_mset([key], in_tensors)
        if failed_keys:
            raise RuntimeError(f"dev_mset failed, failed keys: {failed_keys}")
        sub_timeout_ms = 30_000
        failed_keys = client.dev_mget([key], out_tensors, sub_timeout_ms)
        if failed_keys:
            raise RuntimeError(f"dev_mget failed, failed keys: {failed_keys}")

    def torch_dev_mset_and_dev_mget_example(self):
        """Round-trip a (2, 3) float16 PyTorch NPU tensor via dev_mset/dev_mget.

        Raises:
            RuntimeError: if dev_mset or dev_mget reports failed keys.
        """
        logging.info("Start executing torch_dev_mset_and_dev_mget_example...")
        self._acl_init_device()
        try:
            device = f'npu:{self._device_id}'
            torch_npu.npu.set_device(device)
            in_tensors = [torch.rand((2, 3), dtype=torch.float16, device=device)]
            # Destination buffers passed to dev_mget; presumably filled in place
            # and required to match the stored shape/dtype (TODO: confirm).
            out_tensors = [torch.zeros((2, 3), dtype=torch.float16, device=device)]
            self._dev_mset_and_dev_mget(in_tensors, out_tensors)
        finally:
            # Tear ACL down even when the round trip raises.
            self._check_acl_ret("acl.finalize", acl.finalize())
        logging.info("Execute torch_dev_mset_and_dev_mget_example success.")

    def mindspore_dev_mset_and_dev_mget_example(self):
        """Round-trip a (2, 3) float32 MindSpore Ascend tensor via dev_mset/dev_mget.

        Raises:
            RuntimeError: if dev_mset or dev_mget reports failed keys.
        """
        logging.info("Start executing mindspore_dev_mset_and_dev_mget_example...")
        self._acl_init_device()
        try:
            mindspore.set_device(device_target="Ascend", device_id=self._device_id)
            data = numpy.random.rand(2, 3)
            # NOTE(review): the "+ 0" makes each tensor the result of a MindSpore
            # op. Presumably this forces materialisation on the Ascend device
            # rather than a host-backed wrapper; confirm this is required.
            in_tensors = [mindspore.Tensor(data, dtype=mindspore.float32) + 0]
            out_tensors = [mindspore.Tensor(numpy.ones(shape=[2, 3]), dtype=mindspore.float32) + 0]
            self._dev_mset_and_dev_mget(in_tensors, out_tensors)
        finally:
            # Tear ACL down even when the round trip raises.
            self._check_acl_ret("acl.finalize", acl.finalize())
        logging.info("Execute mindspore_dev_mset_and_dev_mget_example success.")
if __name__ == '__main__':
    example = DsTensorClientExample()
    # Each entry pairs an availability flag with the example it gates. The
    # examples run in this order, and any whose framework failed to import
    # is skipped.
    runners = [
        (is_torch_exist, example.torch_dev_mset_and_dev_mget_example),
        (is_mindspore_exist, example.mindspore_dev_mset_and_dev_mget_example),
    ]
    ran_any = False
    for available, run in runners:
        if not available:
            continue
        run()
        ran_any = True
    if not ran_any:
        logging.warning("No examples were executed.")