import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_npu

from torch_npu.testing.testcase import TestCase, run_tests


class TestSparseFunctions(TestCase):
    def test_embedding(self):
        input1 = torch.tensor([[0, 1, 1, 2], [3, 5, 7, 11]], dtype=torch.long)
        embd = nn.Embedding(20, 20)
        weight = embd.weight
        npu_input = input1.npu().int()
        npu_weight = weight.npu()
        cpu_output = F.embedding(input1, weight)
        npu_output = F.embedding(npu_input, npu_weight)

        self.assertRtolEqual(cpu_output.detach().numpy(), npu_output.detach().cpu().numpy())

    def test_one_hot(self):
        input1 = torch.arange(0, 5) % 3
        npu_input = input1.npu().int()

        cpu_output = F.one_hot(input1)
        npu_output = F.one_hot(npu_input)

        self.assertRtolEqual(cpu_output.detach().int().numpy(), npu_output.detach().cpu().numpy())


if __name__ == "__main__":
    run_tests()