from typing import Iterable, Union
import torch
import torch_npu
from . import _lazy_init, _lazy_call, device_count, current_device
# Names exported by ``from <this module> import *``.
__all__ = [
    'get_rng_state',
    'set_rng_state',
    'get_rng_state_all',
    'set_rng_state_all',
    'manual_seed',
    'manual_seed_all',
    'seed',
    'seed_all',
    'initial_seed',
]
def get_rng_state(device: Union[int, str, torch.device] = 'npu') -> torch.Tensor:
    r"""Fetch the state of the default random number generator of one NPU.

    The state is whatever ``get_state()`` of the selected default generator
    returns (a ByteTensor).

    Args:
        device (torch.device or int or str, optional): The device whose RNG
            state is wanted. An ``int`` is treated as an NPU index and a
            ``str`` is parsed with ``torch.device``. When the resulting
            device has no index, the current NPU device is used.
            Default: ``'npu'`` (i.e., ``torch.device('npu')``).

    Returns:
        torch.Tensor: The state of the selected default generator.

    .. warning::
        This function eagerly initializes NPU.
    """
    # Initialization happens first, before the argument is even inspected.
    _lazy_init()

    # Normalize int / str inputs to a torch.device; anything else is used as is.
    if isinstance(device, int):
        device = torch.device('npu', device)
    elif isinstance(device, str):
        device = torch.device(device)

    device_index = device.index
    if device_index is None:
        # No explicit index on the device: fall back to the current one.
        device_index = current_device()

    return torch_npu.npu.default_generators[device_index].get_state()
def get_rng_state_all():
    r"""Collect the RNG state of every device.

    Returns:
        list: One entry per device index in ``range(device_count())``, each
        obtained from :func:`get_rng_state` (a ByteTensor), in index order.
    """
    return [get_rng_state(device_index) for device_index in range(device_count())]
def set_rng_state(new_state: torch.Tensor, device: Union[int, str, torch.device] = 'npu') -> None:
    r"""Install a new state into the default random number generator of one NPU.

    The update is not applied directly; it is handed to ``_lazy_call`` as a
    callback.

    Args:
        new_state (torch.ByteTensor): The desired state. A contiguous clone
            is taken immediately, and that clone is what the callback applies.
        device (torch.device or int or str, optional): The device whose RNG
            state is set. An ``int`` is treated as an NPU index and a ``str``
            is parsed with ``torch.device``. When the resulting device has no
            index, the current NPU device (as seen when the callback runs) is
            used. Default: ``'npu'`` (i.e., ``torch.device('npu')``).
    """
    # Snapshot the state up front so the callback works on its own copy.
    state_snapshot = new_state.clone(memory_format=torch.contiguous_format)

    # Normalize int / str inputs to a torch.device; anything else is used as is.
    if isinstance(device, int):
        device = torch.device('npu', device)
    elif isinstance(device, str):
        device = torch.device(device)

    def _apply_state():
        # The index is resolved inside the callback, not when it is registered.
        device_index = device.index
        if device_index is None:
            device_index = current_device()
        torch_npu.npu.default_generators[device_index].set_state(state_snapshot)

    _lazy_call(_apply_state)
def set_rng_state_all(new_states):
    r"""Set the RNG state of several devices from an iterable of states.

    The state at position ``i`` of ``new_states`` is passed to
    :func:`set_rng_state` for device index ``i``.

    Args:
        new_states (Iterable of torch.ByteTensor): The desired state for each
            device, ordered by device index.
    """
    for device_index, rng_state in enumerate(new_states):
        set_rng_state(rng_state, device=device_index)
def manual_seed(seed):
    r"""Seed the default random number generator of the current NPU.

    The seeding is handed to ``_lazy_call`` as a callback rather than being
    performed directly. It's safe to call this function if NPU is not
    available; in that case, it is silently ignored.

    Args:
        seed (int): The desired seed. It is converted with ``int()`` before
            the callback is registered.

    .. warning::
        If you are working with a multi-NPU model, this function is insufficient
        to get determinism. To seed all NPUs, use :func:`manual_seed_all`.
    """
    seed_value = int(seed)

    def _seed_current_device():
        # Resolve the current device first, then look up its generator.
        device_index = current_device()
        generator = torch_npu.npu.default_generators[device_index]
        generator.manual_seed(seed_value)

    _lazy_call(_seed_current_device)
def manual_seed_all(seed):
    r"""Seed the default random number generator of every NPU with one value.

    The seeding is handed to ``_lazy_call`` as a callback rather than being
    performed directly. It's safe to call this function if NPU is not
    available; in that case, it is silently ignored.

    Args:
        seed (int): The desired seed. It is converted with ``int()`` before
            the callback is registered.
    """
    seed_value = int(seed)

    def _seed_every_device():
        for device_index in range(device_count()):
            generator = torch_npu.npu.default_generators[device_index]
            generator.manual_seed(seed_value)

    _lazy_call(_seed_every_device)
def seed():
    r"""Re-seed the current NPU's default generator with a random number.

    The re-seeding calls ``seed()`` on the generator and is handed to
    ``_lazy_call`` as a callback rather than being performed directly.
    It's safe to call this function if NPU is not available; in that
    case, it is silently ignored.

    .. warning::
        If you are working with a multi-NPU model, this function will only initialize
        the seed on one NPU. To initialize all NPUs, use :func:`seed_all`.
    """
    def _reseed_current_device():
        # Resolve the current device first, then look up its generator.
        device_index = current_device()
        generator = torch_npu.npu.default_generators[device_index]
        generator.seed()

    _lazy_call(_reseed_current_device)
def seed_all():
    r"""Re-seed the default generator of every NPU with one shared random number.

    Device 0's generator is re-seeded via ``seed()``; the value it then
    reports from ``initial_seed()`` is applied to every other device with
    ``manual_seed``. The work is handed to ``_lazy_call`` as a callback.
    It's safe to call this function if NPU is not available; in that
    case, it is silently ignored.
    """
    def _reseed_every_device():
        shared_seed = 0
        for device_index in range(device_count()):
            generator = torch_npu.npu.default_generators[device_index]
            if device_index > 0:
                # Later devices reuse the seed drawn on the first device.
                generator.manual_seed(shared_seed)
                continue
            # First device: draw a fresh seed and remember it for the rest.
            generator.seed()
            shared_seed = generator.initial_seed()

    _lazy_call(_reseed_every_device)
def initial_seed():
    r"""Report the seed of the current NPU's default generator.

    Returns:
        The value returned by ``initial_seed()`` of the default generator
        belonging to the current device.

    .. warning::
        This function eagerly initializes NPU.
    """
    _lazy_init()
    # Resolve the current device first, then look up its generator.
    device_index = current_device()
    return torch_npu.npu.default_generators[device_index].initial_seed()