import contextlib
import torch
from fp16.fp16 import conversion_helper
import fp16
def fp32_to_float16(val, float16_convertor=None):
    """Convert every fp32 tensor in `val` to a 16-bit float type (fp16/bf16).

    Walking a (possibly nested) tuple/list structure is delegated to
    `conversion_helper`; the per-leaf conversion is done here.

    Args:
        val: a single value or a tuple/list of values. Leaves that are not
            fp32 tensors (other dtypes, non-tensor objects such as None) are
            returned unchanged.
        float16_convertor: callable applied to each fp32 tensor leaf, e.g.
            ``lambda t: t.half()`` or ``lambda t: t.bfloat16()``. When omitted,
            ``torch.Tensor.half`` is used so the function can also be called
            with a single argument.
            NOTE(review): the default exists because this function is installed
            as ``fp16.fp32_to_fp16`` at module level; confirm the existing
            callers of that name pass only `val`.

    Returns:
        `val` with every fp32 tensor leaf replaced by
        ``float16_convertor(leaf)``.
    """
    if float16_convertor is None:
        float16_convertor = torch.Tensor.half

    def half_conversion(leaf):
        # Parameter is a Tensor subclass, so a single isinstance check covers
        # plain tensors and parameters alike; a Parameter's `.dtype` is the
        # same as its `.data.dtype`, so no unwrapping is needed for the check.
        if isinstance(leaf, torch.Tensor) and leaf.dtype == torch.float32:
            return float16_convertor(leaf)
        return leaf

    return conversion_helper(val, half_conversion)
def float16_to_fp32(val):
    """Convert every fp16/bf16 tensor in `val` to fp32.

    Walking a (possibly nested) tuple/list structure is delegated to
    `conversion_helper`; the per-leaf conversion is done here.

    Args:
        val: a single value or a tuple/list of values. Leaves that are not
            float16/bfloat16 tensors (other dtypes, non-tensor objects such as
            None) are returned unchanged.

    Returns:
        `val` with every float16/bfloat16 tensor leaf replaced by
        ``leaf.float()``.
    """
    half_dtypes = (torch.float16, torch.bfloat16)

    def float_conversion(leaf):
        # Parameter is a Tensor subclass, so a single isinstance check covers
        # plain tensors and parameters alike; a Parameter's `.dtype` is the
        # same as its `.data.dtype`, so no unwrapping is needed for the check.
        if isinstance(leaf, torch.Tensor) and leaf.dtype in half_dtypes:
            return leaf.float()
        return leaf

    return conversion_helper(val, float_conversion)
# Rebind the conversion helpers exposed on the `fp16` package to the
# fp16/bf16-aware implementations defined in this module.
# NOTE(review): this only replaces the attributes on the `fp16` package object;
# code that already bound the old functions directly (e.g. via
# `from fp16 import fp32_to_fp16`, or module globals inside `fp16.fp16`) keeps
# the previous objects — confirm that is the intended scope of the patch.
setattr(fp16, "fp32_to_fp16", fp32_to_float16)
setattr(fp16, "fp16_to_fp32", float16_to_fp32)