import torch
from megatron.training import get_args
from megatron.core import mpu
from megatron.core.transformer.module import fp32_to_float16, float16_to_fp32, Float16Module
def _transform_model_dtype(module, transform, ignore_child_modules):
for child_module in module.children():
if child_module in ignore_child_modules:
continue
child_module = transform(child_module)
return module
def float16Module_init(self, config, module):
super(Float16Module, self).__init__(config)
ae_config = getattr(config.mm.model, "ae", None)
ae_float32 = False
if ae_config is not None and getattr(ae_config, "dtype", None) == torch.float32:
module.ae = module.ae.float()
ae_float32 = True
if config.fp16:
if ae_float32:
self.add_module("module", _transform_model_dtype(module, lambda module: module.half(), [module.ae]))
else:
self.add_module("module", module.half())
def float16_convertor(val):
return val.half()
elif config.bf16:
if ae_float32:
self.add_module("module", _transform_model_dtype(module, lambda module: module.bfloat16(), [module.ae]))
else:
self.add_module("module", module.bfloat16())
def float16_convertor(val):
return val.bfloat16()
else:
raise Exception("should not be here")
self.float16_convertor = float16_convertor
def float16Module_forward(self, *inputs, **kwargs):
args = get_args()
if mpu.is_pipeline_first_stage():
ae_config = getattr(args.mm.model, "ae", None)
if ae_config is not None and getattr(ae_config, "dtype", None) != torch.float32:
inputs = fp32_to_float16(inputs, self.float16_convertor)
outputs = self.module(*inputs, **kwargs)
if mpu.is_pipeline_last_stage():
outputs = float16_to_fp32(outputs)
return outputs