from functools import wraps
from itertools import product
from contextlib import contextmanager
from typing import List
import os
import sys
import unittest
import json
import stat
import atexit
import threading
import tempfile
import functools
import numpy as np
import torch
import torch_npu
from torch_npu.utils._path_manager import PathManager
@contextmanager
def freeze_rng_state():
    """Context manager that puts torch's RNG state back when the block ends.

    The state returned by ``torch.get_rng_state()`` is captured on entry and
    written back with ``torch.set_rng_state()`` on exit, so random draws made
    inside the ``with`` body do not influence draws made afterwards.

    The restore runs in a ``finally`` clause, so the state is also put back
    when the body raises.
    """
    saved_state = torch.get_rng_state()
    try:
        yield
    finally:
        # Always restore, even if the managed block failed.
        torch.set_rng_state(saved_state)
def iter_indices(tensor):
    """Return an iterable over the indices of ``tensor``.

    * 0-dim tensor: an empty ``range`` (nothing to iterate).
    * 1-dim tensor: ``range(n)``, yielding plain integers.
    * otherwise: an ``itertools.product`` over one ``range`` per dimension,
      yielding one index tuple per element.
    """
    ndim = tensor.dim()
    if ndim == 0:
        return range(0)
    if ndim == 1:
        return range(tensor.size(0))
    return product(*map(range, tensor.size()))
def is_iterable(obj):
    """Return True when ``iter(obj)`` succeeds, False when it raises TypeError."""
    try:
        iter(obj)
    except TypeError:
        return False
    else:
        return True
def set_npu_device():
    """Make the device chosen by ``get_npu_device()`` the current NPU device.

    Returns:
        The device string that was passed to ``torch.npu.set_device``.
    """
    selected = get_npu_device()
    torch.npu.set_device(selected)
    return selected
def get_npu_device():
    """Return the NPU device string selected through the environment.

    Reads the ``SET_NPU_DEVICE`` environment variable. When it is unset the
    result is ``"npu:0"``; otherwise the variable's value is appended verbatim
    to ``"npu:"``.
    """
    device_id = os.environ.get('SET_NPU_DEVICE')
    return "npu:0" if device_id is None else f"npu:{device_id}"
def create_common_tensor(item, minValue, maxValue, device=None):
    """Create a CPU tensor and a device tensor holding the same random data.

    Args:
        item: indexable ``(dtype, npu_format, shape)``. ``dtype`` is handed to
            numpy's ``astype``, ``shape`` to ``np.random.uniform`` and
            ``npu_format`` to ``torch_npu.npu_format_cast`` (skipped when it
            is ``-1``).
        minValue: lower bound passed to ``np.random.uniform``.
        maxValue: upper bound passed to ``np.random.uniform``.
        device: target device for the second tensor; when ``None`` the result
            of ``get_npu_device()`` is used.

    Returns:
        Tuple ``(cpu_input, npu_input)``.
    """
    target = get_npu_device() if device is None else device
    dtype, npu_format, shape = item[0], item[1], item[2]

    samples = np.random.uniform(minValue, maxValue, shape).astype(dtype)
    cpu_input = torch.from_numpy(samples)
    npu_input = torch.from_numpy(samples).to(target)
    # -1 means "leave the storage format alone".
    if npu_format != -1:
        npu_input = torch_npu.npu_format_cast(npu_input, npu_format)
    return cpu_input, npu_input
def __generate_2args_broadcast_cases(device=None):
    """Yield ``(cpu_x, cpu_y, npu_x, npu_y)`` tuples for two broadcast cases.

    ``x`` always has shape ``(2, 3, 4, 5)``; ``y`` has shape ``(1, 1, 1)`` in
    the first case and ``(4, 5)`` in the second. All data is float32 drawn
    from ``np.random.rand``; fresh data is drawn for each case when the
    generator is advanced.

    Args:
        device: target device for the ``npu_*`` tensors; when ``None`` the
            result of ``get_npu_device()`` is used.
    """
    if device is None:
        device = get_npu_device()

    x_shape = (2, 3, 4, 5)
    for y_shape in ((1, 1, 1), (4, 5)):
        x_data = np.random.rand(*x_shape).astype(np.float32)
        y_data = np.random.rand(*y_shape).astype(np.float32)
        cpu_x = torch.from_numpy(x_data)
        cpu_y = torch.from_numpy(y_data)
        npu_x = torch.from_numpy(x_data).to(device)
        npu_y = torch.from_numpy(y_data).to(device)
        yield cpu_x, cpu_y, npu_x, npu_y
def test_2args_broadcast(fn):
    """Run ``fn`` on every generated broadcast case, on CPU and on the device.

    Args:
        fn: callable taking two tensors and returning a tensor.

    Returns:
        A list with one ``[cpu_out, npu_out]`` pair of numpy arrays per case
        produced by ``__generate_2args_broadcast_cases()``.
    """
    return [
        [fn(cpu_x, cpu_y).numpy(), fn(npu_x, npu_y).to("cpu").numpy()]
        for cpu_x, cpu_y, npu_x, npu_y in __generate_2args_broadcast_cases()
    ]
def create_dtype_tensor(shape, dtype, npu_format=-1, min_value=-5, max_value=5, no_zero=False, device=None):
    """Create a CPU tensor and a device tensor with random data for ``dtype``.

    Data generation by ``dtype``:
        * ``torch.bool``  -> random 0/1 values cast to bool
        * ``torch.half``  -> uniform values in [min_value, max_value) as float16
        * ``torch.float`` -> uniform values in [min_value, max_value) as float32
        * anything else   -> random integers in [min_value, max_value] as int32

    Args:
        shape: shape of the generated array.
        dtype: torch dtype selecting one of the branches above.
        npu_format: format passed to ``torch_npu.npu_format_cast``; the cast
            is applied only when this is not ``-1`` and ``dtype`` is
            ``torch.float`` or ``torch.half``.
        min_value: lower bound of the generated values (ignored for bool).
        max_value: upper bound of the generated values (ignored for bool).
        no_zero: when true, every element equal to zero is replaced by one.
        device: target device for the second tensor; when ``None`` the result
            of ``get_npu_device()`` is used.

    Returns:
        Tuple ``(cpu_input, npu_input)``.
    """
    if device is None:
        device = get_npu_device()

    is_floating = dtype in (torch.float, torch.half)
    if dtype == torch.bool:
        data = np.random.randint(0, 2, size=shape).astype(bool)
    elif is_floating:
        np_float = np.float16 if dtype == torch.half else np.float32
        data = np.random.uniform(min_value, max_value, shape).astype(np_float)
    else:
        data = np.random.randint(min_value, max_value + 1, size=shape).astype(np.int32)

    if no_zero:
        # Swap zeros for ones so the data is safe to use as a divisor.
        data = np.where(data != 0, data, np.ones_like(data))

    cpu_input = torch.from_numpy(data)
    npu_input = torch.from_numpy(data).to(device)
    if is_floating and npu_format != -1:
        npu_input = torch_npu.npu_format_cast(npu_input, npu_format)
    return cpu_input, npu_input
def check_operators_in_prof(expected_operators, prof, unexpected_operators=None):
    """Check that a profile contains all expected operators and no unexpected one.

    Every profiled key has the substring ``"cached_"`` removed before it is
    compared against the operator names.

    Args:
        expected_operators: operator names that must all show up in the
            profile. A name listed k times needs k matching profile entries.
            The sequence is not modified.
        prof: object whose ``key_averages()`` returns a sequence of items
            exposing a string attribute ``key``.
        unexpected_operators: operator names that must not show up; ``None``
            means there are none.

    Returns:
        ``True`` when no unexpected operator was profiled and every expected
        operator was matched; ``False`` otherwise, including when the profile
        is empty (a hint is printed in that case).
    """
    unexpected_operators = unexpected_operators or []
    prof_key_averages = prof.key_averages()
    if not prof_key_averages:
        print("torch profiling is empty, please check it")
        return False

    # Work on a copy so the caller's list is left untouched and can be reused.
    remaining = list(expected_operators)
    for prof_item in prof_key_averages:
        op_name = prof_item.key.replace("cached_", "")
        if op_name in unexpected_operators:
            return False
        if op_name in remaining:
            remaining.remove(op_name)

    # Nothing left over means every expected operator was seen.
    return not remaining
class SkipIfNoLapack(object):
    """Decorator that skips the wrapped callable when LAPACK is unavailable.

    Availability is read from ``torch._C.has_lapack`` at call time. When it is
    false, ``unittest.SkipTest`` is raised; otherwise the wrapped callable
    runs and its return value is passed through.
    """

    def __call__(self, fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            if not torch._C.has_lapack:
                raise unittest.SkipTest('PyTorch compiled without Lapack')
            # Propagate the result instead of silently dropping it.
            return fn(*args, **kwargs)
        return wrapper
class SkipIfNotRegistered(object):
    """Build a ``unittest`` skip decorator for tests that need a caffe2 operator.

    The import of ``caffe2.python.core`` happens inside ``__call__`` so that
    merely defining tests does not require caffe2.

    Args:
        op_name: operator name looked up in ``core._REGISTERED_OPERATORS``.
            Optional for backward compatibility with ``SkipIfNotRegistered()``;
            the default ``None`` is not expected to match any registered
            operator, so the resulting decorator skips.

    Usage:
        @SkipIfNotRegistered('MyOp')('MyOp is not linked!')
        def test_my_op(self): ...

    Calling the instance with a message returns:
        * ``unittest.skipIf(op_name not in core._REGISTERED_OPERATORS, message)``
          when caffe2 can be imported;
        * an unconditional ``unittest.skip`` when the import fails.
    """

    def __init__(self, op_name=None):
        self.op_name = op_name

    def __call__(self, message):
        try:
            from caffe2.python import core
            # Compare the operator name, not the decorator instance itself.
            skipper = unittest.skipIf(self.op_name not in core._REGISTERED_OPERATORS, message)
        except ImportError:
            skipper = unittest.skip("Cannot import `caffe2.python.core`")
        return skipper
class SupportedDevices:
    """Decorator that runs a test method only on the listed NPU device types.

    The device type is derived from ``torch_npu.npu.get_device_name(0)``:
        * the first 10 characters by default;
        * the first 12 characters when those 10 are exactly ``"Ascend910_"``;
        * the first 9 characters when they start with ``"Ascend950"``.

    When the derived name is not in ``supported_devices`` the wrapped method
    raises ``unittest.SkipTest`` instead of running.
    """

    def __init__(self, supported_devices: List[str]) -> None:
        self.supported_devices = supported_devices

    def __call__(self, fn):
        @wraps(fn)
        def device_guard(slf, *args, **kwargs):
            full_name = torch_npu.npu.get_device_name(0)
            device_name = full_name[:10]
            if device_name == "Ascend910_":
                device_name = full_name[:12]
            elif device_name.startswith("Ascend950"):
                device_name = full_name[:9]

            if device_name in self.supported_devices:
                return fn(slf, *args, **kwargs)
            reason = f"Only run on {repr(self.supported_devices)}, current device is {device_name}."
            raise unittest.SkipTest(reason)
        return device_guard
class SkipIfNotGteCANNVersion:
    """Decorator that skips a test method on a too-old CANN component version.

    Args:
        base_version: version handed to
            ``torch_npu.npu.utils._is_gte_cann_version``.
        module: component name forwarded to both version helpers; defaults to
            ``"CANN"``.

    When the version check is false the wrapped method raises
    ``unittest.SkipTest`` with a message naming the required and the current
    version; otherwise the method runs and its result is returned.
    """

    def __init__(self, base_version, module="CANN"):
        self.base_version = base_version
        self.module = module

    def __call__(self, fn):
        @wraps(fn)
        def version_guard(slf, *args, **kwargs):
            current_version = torch_npu.utils.get_cann_version(module=self.module)
            is_gte = torch_npu.npu.utils._is_gte_cann_version(self.base_version, module=self.module)
            if is_gte:
                return fn(slf, *args, **kwargs)
            reason = f"Only run on {self.base_version} or greater, current cann version is {current_version}."
            raise unittest.SkipTest(reason)
        return version_guard
# Performance testing is opt-in: it is enabled when the PERF_TEST_ENABLE
# environment variable is one of ON/1/YES/TRUE/Y (compared case-insensitively).
PERF_TEST_ENABLE = os.environ.get('PERF_TEST_ENABLE', '').upper() in ('ON', '1', 'YES', 'TRUE', 'Y')
# Resolved path of the baseline JSON file; PERF_BASELINE_FILE overrides the
# default of "performance_baseline.json" in the current working directory.
PERF_BASELINE_FILE = os.path.realpath(
    os.environ.get("PERF_BASELINE_FILE", os.path.join(os.getcwd(), "performance_baseline.json"))
)
class Baseline(object):
    """Dictionary of baseline values backed by a JSON file.

    On construction the file is loaded when it already exists; otherwise the
    store starts empty. Writes to the in-memory dictionary and to the file are
    serialized through an internal lock; reads are not locked.
    """

    def __init__(self, baselineFile):
        """Load ``baselineFile`` if it exists.

        Args:
            baselineFile: path of the JSON file used for loading and saving.
        """
        self._baseline = {}
        self._baselineFile = baselineFile
        self._mutex = threading.Lock()
        if os.path.exists(self._baselineFile):
            # Project-level validation of the path before it is opened.
            PathManager.check_directory_path_readable(self._baselineFile)
            with open(self._baselineFile, "r") as f:
                self._baseline = json.load(f)

    def get_baseline(self, resourceId):
        """Return the value stored for ``resourceId``, or ``None`` if absent."""
        return self._baseline.get(resourceId)

    def set_baseline(self, resourceId, baseline):
        """Store ``baseline`` under ``resourceId`` (in memory only)."""
        with self._mutex:
            self._baseline[resourceId] = baseline

    def save_baseline(self):
        """Write the whole store to the baseline file as JSON.

        The file is created with owner read/write permission when missing.
        """
        # O_TRUNC is required: os.fdopen(fd, "w") does not truncate an
        # already-open descriptor, so without it a shorter dump would leave
        # stale bytes from the previous content and corrupt the JSON.
        flags = os.O_RDWR | os.O_CREAT | os.O_TRUNC
        mode = stat.S_IWUSR | stat.S_IRUSR
        with self._mutex:
            with os.fdopen(os.open(self._baselineFile, flags, mode), "w") as f:
                json.dump(self._baseline, f)
# Module-wide baseline store bound to the configured baseline file.
PerfBaseline = Baseline(PERF_BASELINE_FILE)


@atexit.register
def dump_baseline():
    """Persist the module-wide baseline at interpreter exit.

    Does nothing unless performance testing is enabled through
    ``PERF_TEST_ENABLE``.
    """
    if not PERF_TEST_ENABLE:
        return
    PerfBaseline.save_baseline()
def _create_scaling_models_optimizers(device="npu", optimizer_ctor=torch.optim.SGD, optimizer_kwargs=None):
    """Create two identically-initialized models, each with its own optimizer.

    Both models are ``Sequential(Linear(8, 8), Linear(8, 8))`` placed on
    ``device``; the parameters of the control model are copied into the
    scaling model so they start out equal.

    Args:
        device: device both models are moved to.
        optimizer_ctor: optimizer class or factory, called as
            ``optimizer_ctor(parameters, **kwargs)``.
        optimizer_kwargs: extra optimizer keyword arguments; they are merged
            over the default ``{"lr": 1.0}``.

    Returns:
        Tuple ``(mod_control, mod_scaling, opt_control, opt_scaling)``.
    """
    def _two_layer_net():
        layers = (torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))
        return torch.nn.Sequential(*layers).to(device=device)

    mod_control = _two_layer_net()
    mod_scaling = _two_layer_net()
    with torch.no_grad():
        for source, target in zip(mod_control.parameters(), mod_scaling.parameters()):
            target.copy_(source)

    kwargs = {"lr": 1.0}
    if optimizer_kwargs is not None:
        kwargs.update(optimizer_kwargs)

    opt_control, opt_scaling = (
        optimizer_ctor(module.parameters(), **kwargs) for module in (mod_control, mod_scaling)
    )
    return mod_control, mod_scaling, opt_control, opt_scaling
def _create_scaling_case(device="npu", dtype=torch.float, optimizer_ctor=torch.optim.SGD, optimizer_kwargs=None):
    """Assemble models, optimizers, data and loss for a gradient-scaling test.

    Args:
        device: device for the random data, the loss module and (through
            ``_create_scaling_models_optimizers``) the models.
        dtype: dtype of the generated data tensors.
        optimizer_ctor: forwarded to ``_create_scaling_models_optimizers``.
        optimizer_kwargs: forwarded to ``_create_scaling_models_optimizers``.

    Returns:
        The tuple from ``_create_scaling_models_optimizers`` extended with
        ``(data, loss_fn, skip_iter)``: ``data`` is a list of four pairs of
        ``(8, 8)`` tensors from ``torch.randn``, ``loss_fn`` is an
        ``MSELoss`` module and ``skip_iter`` is the constant ``2``.
    """
    data = [
        (torch.randn((8, 8), dtype=dtype, device=device), torch.randn((8, 8), dtype=dtype, device=device))
        for _ in range(4)
    ]
    # Honor the requested device instead of hard-coding the NPU, so the loss
    # module lives where the data and the models live.
    loss_fn = torch.nn.MSELoss().to(device=device)
    skip_iter = 2
    return _create_scaling_models_optimizers(
        device=device, optimizer_ctor=optimizer_ctor, optimizer_kwargs=optimizer_kwargs,
    ) + (data, loss_fn, skip_iter)
@functools.lru_cache
def get_cycles_per_ms() -> float:
    """Measure and return approximate number of cycles per millisecond for torch_npu.npu._sleep

    Ten measurements are taken; after sorting, the two smallest and the two
    largest are discarded and the mean of the remaining six is returned. The
    result is cached by ``functools.lru_cache``, so the measurement runs once
    per process.
    """
    # Imported here because the module-level imports do not provide ``mean``.
    from statistics import mean

    def measure() -> float:
        # Time a fixed-length device sleep between two recorded events.
        start = torch_npu.npu.Event(enable_timing=True)
        end = torch_npu.npu.Event(enable_timing=True)
        start.record()
        torch_npu.npu._sleep(1000000)
        end.record()
        end.synchronize()
        cycles_per_ms = 1000000 / start.elapsed_time(end)
        return cycles_per_ms

    num = 10
    vals = sorted(measure() for _ in range(num))
    # Trim the extremes on both sides to damp outliers.
    return mean(vals[2 : num - 2])