import unittest
import torch
import torch.nn as nn
import torch_npu

from torch_npu.testing.testcase import TestCase, run_tests


class TestRecurrentLayers(TestCase):
    @unittest.skip("Temporarily skipping")
    def test_RNN(self):
        input1 = torch.randn(5, 3, 10).npu()
        h0 = torch.randn(2, 3, 20).npu()
        rnn = nn.RNN(10, 20, 2).npu()
        output, hn = rnn(input1, h0)
        self.assertEqual(output is not None, True)

    @unittest.skip("Temporarily skipping")
    def test_LSTM(self):
        input1 = torch.randn(5, 3, 10).npu()
        h0 = torch.randn(2, 3, 20).npu()
        c0 = torch.randn(2, 3, 20).npu()
        rnn = nn.LSTM(10, 20, 2).npu()
        output, (hn, cn) = rnn(input1, (h0, c0))
        self.assertEqual(output is not None, True)

    def test_GRU(self):
        input1 = torch.randn(5, 3, 10).npu()
        h0 = torch.randn(2, 3, 20).npu()
        rnn = nn.GRU(10, 20, 2).npu()
        output, hn = rnn(input1, h0)
        self.assertEqual(output is not None, True)

    @unittest.skip("Temporarily skipping")
    def test_RNNCell(self):
        input1 = torch.randn(6, 3, 10).npu()
        hx = torch.randn(3, 20).npu()
        output = []
        rnn = nn.RNNCell(10, 20).npu()
        for i in range(6):
            hx = rnn(input1[i], hx)
            output.append(hx)

    @unittest.skip("Temporarily skipping")
    def test_LSTMCell(self):
        input1 = torch.randn(2, 3, 10).npu()
        hx = torch.randn(3, 20).npu()
        cx = torch.randn(3, 20).npu()
        output = []
        rnn = nn.LSTMCell(10, 20).npu()
        for i in range(input1.size()[0]):
            hx, cx = rnn(input1[i], (hx, cx))
            output.append(hx)
        output = torch.stack(output, dim=0)
        self.assertEqual(output is not None, True)

    @unittest.skip("Temporarily skipping")
    def test_GRUCell(self):
        input1 = torch.randn(6, 3, 10).npu()
        hx = torch.randn(3, 20).npu()
        cx = torch.randn(3, 20).npu()
        output = []
        rnn = nn.GRUCell(10, 20).npu()
        for i in range(6):
            hx = rnn(input1[i], hx)
            output.append(hx)


if __name__ == "__main__":
    run_tests()