import logging
import os
import unittest
import numpy as np
import torch
import torch.nn as nn
from amct_pytorch.classic.graph_based.amct_pytorch.utils.module_info import (
ModuleInfo,
)
# Logger shared by this test module; named after the module itself.
logger = logging.getLogger(name=__name__)
class TestModuleInfo(unittest.TestCase):
    """
    The UT for ModuleInfo.get_wts_cout_cin.
    """

    @classmethod
    def setUpClass(cls):
        logger.info("TestModuleInfo start!")

    @classmethod
    def tearDownClass(cls):
        logger.info("TestModuleInfo end!")

    def test_get_wts_cout_cin(self):
        """Check the (cout_axis, cin_axis) pair reported for each weighted layer type."""
        # Expected axes mirror PyTorch's weight layouts: Conv*/Linear weights are
        # (out, in, ...), so cout is axis 0; ConvTranspose* weights are
        # (in, out, ...), so the axes are swapped.
        cases = [
            (torch.nn.Conv2d(1, 1, 1), (0, 1)),
            (torch.nn.Conv3d(1, 1, 1), (0, 1)),
            (torch.nn.Conv1d(1, 1, 1), (0, 1)),
            (torch.nn.ConvTranspose2d(1, 1, 1), (1, 0)),
            (torch.nn.ConvTranspose3d(1, 1, 1), (1, 0)),
            (torch.nn.Linear(1, 1), (0, 1)),
            (torch.nn.ConvTranspose1d(1, 1, 1), (1, 0)),
        ]
        for module, (expected_cout, expected_cin) in cases:
            # subTest keeps going after a failure and tags it with the layer type,
            # so one broken mapping does not hide the others.
            with self.subTest(module=type(module).__name__):
                cout_axis, cin_axis = ModuleInfo.get_wts_cout_cin(module)
                self.assertEqual(cout_axis, expected_cout)
                self.assertEqual(cin_axis, expected_cin)