import logging
import unittest
import torch
from amct_pytorch.common.utils.data_utils import float_to_fp4e2m1
def gloden_float_cast_to_float4_e2m1(tensor):
    """Reference ("golden") rounding of a float tensor onto the FP4 E2M1 grid.

    Every element is mapped to one of the magnitudes
    0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0 and then given the sign of the
    input. Values exactly halfway between two levels resolve as follows:
    0.25 -> 0, 0.75 -> 1.0, 1.25 -> 1.0, 1.75 -> 2.0, 2.5 -> 2.0,
    3.5 -> 4.0, 5.0 -> 4.0. Finite magnitudes above 5.0 saturate to 6.0.
    NaN, +Inf and -Inf in the input are reproduced unchanged in the output.

    Args:
        tensor: input torch tensor; it is not modified.

    Returns:
        A new tensor created with ``torch.zeros_like(tensor)`` (same shape
        and dtype) holding the rounded values.
    """
    magnitude = tensor.abs()
    quantized = torch.zeros_like(tensor)

    # (selection mask, magnitude written where the mask is set). The masks
    # are mutually exclusive; NaN matches none of them and stays 0 until the
    # special-value pass at the end.
    level_table = (
        (magnitude <= 0.25, 0),
        ((magnitude > 0.25) & (magnitude < 0.75), 0.5),
        ((magnitude >= 0.75) & (magnitude <= 1.25), 1.0),
        ((magnitude > 1.25) & (magnitude < 1.75), 1.5),
        ((magnitude >= 1.75) & (magnitude <= 2.5), 2.0),
        ((magnitude > 2.5) & (magnitude < 3.5), 3.0),
        ((magnitude >= 3.5) & (magnitude <= 5.0), 4.0),
        (magnitude > 5.0, 6.0),
    )
    for selected, level in level_table:
        quantized.masked_fill_(selected, level)

    # Re-apply the sign of the input to the unsigned magnitudes.
    quantized *= torch.sign(tensor)

    # Special values are written last so they override the bucketed result.
    quantized.masked_fill_(torch.isnan(tensor), float('nan'))
    quantized.masked_fill_(torch.isposinf(tensor), float('inf'))
    quantized.masked_fill_(torch.isneginf(tensor), float('-inf'))
    return quantized
# Module-level logger; the test class below logs suite start/end through it.
logger = logging.getLogger(__name__)
class Testfloat2fp4e2m1(unittest.TestCase):
    """Unit tests for the float -> FP4 E2M1 data transformation.

    Each test feeds seeded random tensors to ``float_to_fp4e2m1`` and checks
    the result against ``gloden_float_cast_to_float4_e2m1``, both for the
    element values and for the output dtype (which must equal the input
    dtype).
    """

    @classmethod
    def setUpClass(cls):
        """Log the start of this test class."""
        logger.info('Testfloat2fp4e2m1 START!')

    @classmethod
    def tearDownClass(cls):
        """Log the end of this test class."""
        logger.info('Testfloat2fp4e2m1 END!')

    def _assert_matches_golden(self, dtype, seeds):
        """Compare ``float_to_fp4e2m1`` with the golden cast for every seed.

        Args:
            dtype: torch dtype of the generated (100, 100) input tensor.
            seeds: iterable of ints passed to ``torch.manual_seed``.
        """
        for seed in seeds:
            # One sub-test per seed: a failure reports which dtype/seed
            # broke, and the remaining seeds are still exercised.
            with self.subTest(dtype=dtype, seed=seed):
                torch.manual_seed(seed)
                ori_tensor = torch.randn((100, 100), dtype=dtype)
                test_fp4 = float_to_fp4e2m1(ori_tensor)
                golden_fp4 = gloden_float_cast_to_float4_e2m1(ori_tensor)
                self.assertEqual(test_fp4.tolist(), golden_fp4.tolist())
                self.assertEqual(test_fp4.dtype, dtype)

    def test_float16_tensor_trans_2_fp4e2m1_success(self):
        """float16 inputs, seeds 0..9: values match golden, dtype preserved."""
        self._assert_matches_golden(torch.float16, range(0, 10))

    def test_bfloat16_tensor_trans_2_fp4e2m1_success(self):
        """bfloat16 inputs, seeds 10..19: values match golden, dtype preserved."""
        self._assert_matches_golden(torch.bfloat16, range(10, 20))