"""test jit."""
import os
from unittest import mock

import pytest
import numpy as np
import mindspore as ms
from mindspore.nn import Cell
from mindformers.tools.utils import is_pynative
from mindformers.models.utils import jit
class JitCell(Cell):
    """Minimal cell used to exercise the ``jit`` decorator.

    ``construct`` is wrapped with ``jit`` (imported from
    ``mindformers.models.utils``) and returns the product of its two inputs.
    """

    @jit
    def construct(self, x, y):
        """Return ``x * y``."""
        product = x * y
        return product
class TestJit:
    """Tests for the ``jit`` decorator applied to ``JitCell.construct``.

    Every test sets the environment variables it needs inside a
    ``mock.patch.dict(os.environ, ...)`` context, so ``os.environ`` is restored
    when the test ends and the settings cannot leak into other tests that run
    in the same process.
    """

    @staticmethod
    def _check_product():
        """Run ``JitCell`` on two random float32 matrices and verify the output.

        The expected value is the NumPy product of the same inputs.
        ``np.array_equal`` is used so that a result with the wrong shape fails
        instead of being silently broadcast by ``==``.
        """
        cell = JitCell()
        x = np.random.rand(10, 10).astype(np.float32)
        y = np.random.rand(10, 10).astype(np.float32)
        res = cell(ms.Tensor(x).astype(ms.float32), ms.Tensor(y).astype(ms.float32))
        assert np.array_equal(res.asnumpy(), x * y)

    @pytest.mark.level0
    @pytest.mark.platform_arm_ascend910b_training
    @pytest.mark.env_onecard
    def test_jit_graph_mode(self):
        """test jit for graph mode."""
        with mock.patch.dict(os.environ, {"RUN_MODE": "predict"}):
            # NOTE(review): presumably is_pynative() consults ENFORCE_EAGER (the
            # pynative test below relies on that) -- confirm. Dropping any
            # leftover value keeps this test independent of execution order;
            # patch.dict puts the original value back on exit.
            os.environ.pop("ENFORCE_EAGER", None)
            self._check_product()

    @pytest.mark.level0
    @pytest.mark.platform_arm_ascend910b_training
    @pytest.mark.env_onecard
    def test_jit_pynative_mode(self):
        """test jit for pynative mode."""
        env = {"RUN_MODE": "predict", "ENFORCE_EAGER": "True"}
        with mock.patch.dict(os.environ, env):
            # With ENFORCE_EAGER set, is_pynative() is expected to report True.
            assert is_pynative()
            self._check_product()