# Owner(s): ["module: dynamo"]
import functools
import unittest
import torch
import torch._dynamo.test_case
import torch_npu

# Decorator factory: ``@requires_npu()`` skips the decorated test with the
# reason "requires npu" when no NPU is available. Availability is probed once,
# when this module is imported, not each time the decorator is applied.
requires_npu = functools.partial(
    unittest.skipUnless,
    torch.npu.is_available(),
    "requires npu",
)


class StreamintoDynamoTests(torch._dynamo.test_case.TestCase):
    """Dynamo tests for functions that create and use NPU streams."""

    @requires_npu()
    def test_stream(self):
        """Compile a function that computes on a side stream and compare with eager.

        The function is compiled with the ``aot_eager`` backend and
        ``fullgraph=True``; its result must match the uncompiled call.
        """

        def model_1(x):
            a = x * x
            s = torch.npu.Stream()
            # Order the new stream after the work already queued on the
            # current stream (the multiplication above).
            s.wait_stream(torch.npu.current_stream())
            with torch.npu.stream(s):
                b = x + a
            return b

        inp = torch.randn(2, 8).npu()
        m = torch.compile(model_1, backend="aot_eager", fullgraph=True)
        output = m(inp)
        output1 = model_1(inp)

        # The additions were issued on side streams that nothing joins back to
        # the current stream, so wait for the device before reading results.
        torch.npu.synchronize()

        # torch.allclose only returns a bool; without an assertion the test
        # would pass even when the compiled and eager outputs differ.
        self.assertTrue(torch.allclose(output, output1))


if __name__ == "__main__":
    # Script entry point: hand control to the Dynamo test runner, reached
    # through the torch._dynamo.test_case module imported at the top of the file.
    torch._dynamo.test_case.run_tests()