import sys
import time
from abc import ABC, abstractmethod
import torch
import megatron
from megatron.training import get_args, print_rank_0, get_timers
from megatron.training.training import (
print_datetime,
get_one_logger,
append_to_progress_log,
evaluate_and_print_results
)
from megatron.training.arguments import core_transformer_config_from_args
from megatron.training.yaml_arguments import core_transformer_config_from_yaml
from megatron.core.transformer.spec_utils import import_module
from megatron.core.models.gpt.gpt_layer_specs import (
get_gpt_layer_local_spec,
get_gpt_layer_with_transformer_engine_spec,
get_gpt_mtp_block_spec,
)
from megatron.core.models.gpt import GPTModel
from megatron.training.checkpointing import save_checkpoint
from mindspeed_llm.training.training import build_train_args
from mindspeed_llm.training.training import train
from mindspeed_llm.training.initialize import set_jit_fusion_options
from mindspeed_llm.tasks.posttrain.utils import train_valid_test_datasets_provider
# Wall-clock timestamp (seconds since the epoch) captured when this module is
# imported. BaseTrainer.initialize() uses it to report how long initialization
# took, and BaseTrainer.synchronize_start_time() overwrites it with the minimum
# value reduced across the distributed processes.
_TRAIN_START_TIME = time.time()
class BaseTrainer(ABC):
    """
    Abstract base class providing the shared skeleton for training large language models.

    Subclasses implement the task-specific hooks (`get_batch`, `loss_func`,
    `forward_step`); this class builds the training arguments from them and
    drives the train / validation / test sequence.

    Core methods:
    - `__init__`: Stores the global args/timers and immediately calls `initialize`.
    - `initialize`: Logs the job start, sets JIT fusion options, synchronizes the
      start time and builds the training arguments.
    - `model_provider`: Builds the GPT model to be trained.
    - `get_batch`: Retrieves a batch of data from the data iterator (abstract).
    - `loss_func`: Computes the loss function (abstract).
    - `forward_step`: Performs a forward pass step, computing the loss (abstract).
    - `train`: Runs training, then optional validation and test evaluation.
    """

    def __init__(self, process_non_loss_data_func=None):
        """
        Stores the global configuration and runs `initialize`.

        Args:
            process_non_loss_data_func: Optional callable that is forwarded
                unchanged to `build_train_args`.
        """
        self.args = get_args()
        self.timers = get_timers()
        self.process_non_loss_data_func = process_non_loss_data_func
        self.train_args = None
        self.model_type = None
        self.test_data_iterator_list = None
        self.train_valid_test_datasets_provider = train_valid_test_datasets_provider
        # NOTE: initialization happens eagerly, so `self.model_type` is still
        # None when `build_train_args` is called unless `initialize` is overridden.
        self.initialize()

    def initialize(self):
        """Sets up necessary configurations and logging, then builds the training arguments.

        Populates `self.train_args` and `self.test_data_iterator_list` from the
        return value of `build_train_args`.
        """
        # This sets an attribute on the shared provider function object itself.
        self.train_valid_test_datasets_provider.is_distributed = True
        self.log_initialization()
        set_jit_fusion_options()
        self.synchronize_start_time()
        print_rank_0('time to initialize megatron (seconds): {:.3f}'.format(time.time() - _TRAIN_START_TIME))

        # Both metrics are derived from the same (synchronized) start timestamp,
        # expressed in milliseconds.
        start_time_ms = round(_TRAIN_START_TIME * 1000.0)
        app_metrics = {
            'app_start_time': start_time_ms,
            'app_model_init_start_time': start_time_ms,
        }

        self.train_args, self.test_data_iterator_list = build_train_args(
            self.args,
            self.timers,
            self.train_valid_test_datasets_provider,
            self.model_provider,
            self.model_type,
            self.forward_step,
            self.process_non_loss_data_func,
            app_metrics
        )

    def log_initialization(self):
        """Appends "Starting job" to the progress log when `args.log_progress` is set."""
        if self.args.log_progress:
            append_to_progress_log("Starting job")

    def synchronize_start_time(self):
        """Synchronize training start time across all distributed processes.

        Replaces the module-level `_TRAIN_START_TIME` with the minimum value
        obtained from a MIN all-reduce over every process's own start time.
        """
        global _TRAIN_START_TIME
        # NOTE(review): float32 has a spacing of 128 s for values around 1.7e9
        # (current epoch seconds), so the synchronized timestamp is coarse and
        # the reported initialization time can be off by tens of seconds.
        # Widening the dtype would fix that - confirm the collective backend in
        # use supports the wider type before changing it.
        start_time_tensor = torch.tensor([_TRAIN_START_TIME], dtype=torch.float, device='cuda')
        torch.distributed.all_reduce(start_time_tensor, op=torch.distributed.ReduceOp.MIN)
        _TRAIN_START_TIME = start_time_tensor.item()

    def model_provider(self, pre_process, post_process):
        """
        Builds the model.

        If `use_mcore_models` is set, the Megatron Core GPT model is returned;
        otherwise the legacy GPT model is returned.

        Args:
            pre_process (bool): Forwarded to the model constructor; set to True
                if this stage needs to compute embeddings.
            post_process (bool): Forwarded to the model constructor; set to True
                if this stage needs to compute output logits/loss.

        Returns:
            Union[GPTModel, megatron.legacy.model.GPTModel]: The returned model

        Raises:
            ValueError: If the legacy model is requested while
                `context_parallel_size` is not 1.
        """
        args = get_args()
        use_te = args.transformer_impl == "transformer_engine"

        print_rank_0('building GPT model ...')
        # A YAML configuration, when supplied, takes precedence over CLI arguments.
        if args.yaml_cfg is not None:
            config = core_transformer_config_from_yaml(args, "language_model")
        else:
            config = core_transformer_config_from_args(args)

        if args.use_mcore_models:
            # An explicit spec wins; otherwise pick the built-in spec that
            # matches the transformer implementation.
            if args.spec is not None:
                transformer_layer_spec = import_module(args.spec)
            elif use_te:
                transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
                    args.num_experts, args.moe_grouped_gemm, args.qk_layernorm)
            else:
                transformer_layer_spec = get_gpt_layer_local_spec(
                    args.num_experts, args.moe_grouped_gemm, args.qk_layernorm)

            # The multi-token-prediction block is only built when requested.
            mtp_block_spec = None
            if args.mtp_num_layers is not None:
                mtp_block_spec = get_gpt_mtp_block_spec(config, transformer_layer_spec, use_transformer_engine=use_te)

            model = GPTModel(
                config=config,
                transformer_layer_spec=transformer_layer_spec,
                vocab_size=args.padded_vocab_size,
                max_sequence_length=args.max_position_embeddings,
                pre_process=pre_process,
                post_process=post_process,
                fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
                parallel_output=True,
                share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
                position_embedding_type=args.position_embedding_type,
                rotary_percent=args.rotary_percent,
                seq_len_interpolation_factor=args.rotary_seq_len_interpolation_factor,
                mtp_block_spec=mtp_block_spec,
            )
        else:
            if args.context_parallel_size != 1:
                raise ValueError("Context parallelism is only supported with Megatron Core!")

            # Import the submodule explicitly: `import megatron` alone does not
            # load `megatron.legacy.model`, so attribute access would only work
            # if some other module had imported it as a side effect.
            from megatron.legacy.model import GPTModel as LegacyGPTModel

            model = LegacyGPTModel(
                config,
                num_tokentypes=0,
                parallel_output=True,
                pre_process=pre_process,
                post_process=post_process
            )

        return model

    @staticmethod
    @abstractmethod
    def get_batch(data_iterator):
        """
        Retrieves a batch of data from the data iterator.
        Called during each forward step.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError("Subclasses must implement this method.")

    @abstractmethod
    def loss_func(self, input_tensor, output_tensor):
        """
        Computes the loss function.
        Called during each forward step.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError("Subclasses must implement this method.")

    @abstractmethod
    def forward_step(self, data_iterator, model):
        """
        Performs a forward pass and computes the loss.
        Called during each training iteration.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError("Subclasses must implement this method.")

    def train(self):
        """
        Runs training (unless `skip_train` is set), then optional validation and test evaluation.

        Raises:
            ValueError: If the cyclic dataloader is used with a retro project
                directory but `retro_cyclic_train_iters` is not provided.
        """
        args = get_args()
        test_data_iterator = self.test_data_iterator_list[0]
        (forward_step_func, model, optimizer, opt_param_scheduler,
         train_data_iterator, valid_data_iterator,
         process_non_loss_data_func, config) = self.train_args

        if not args.skip_train:
            print_rank_0('training ...')

            if args.dataloader_type == 'cyclic' and args.retro_project_dir:
                if args.retro_cyclic_train_iters is None:
                    raise ValueError("retro_cyclic_train_iters must be provided.")
                args.train_iters = args.retro_cyclic_train_iters
                print_rank_0("retro cyclic train iters : %d" % args.train_iters)

            iteration = 0
            # Bound up front so the name always exists; it is only read below
            # when training actually ran (iteration != 0).
            num_floating_point_operations_so_far = 0
            if args.do_train and args.train_iters > 0:
                iteration, num_floating_point_operations_so_far = train(*self.train_args)

            print_datetime('after training is done')

            # Save a final checkpoint only when the last iteration does not fall
            # on a save interval (presumably those were already saved by
            # `train` - confirm against its implementation).
            if args.save and iteration != 0 and iteration % args.save_interval != 0:
                save_checkpoint(
                    iteration,
                    model,
                    optimizer,
                    opt_param_scheduler,
                    num_floating_point_operations_so_far
                )
        else:
            print_rank_0('skipping training (--skip-train is on) ...')
            iteration = args.iteration

        if args.do_valid:
            prefix = f'iteration {iteration} on validation set'
            evaluate_and_print_results(prefix, forward_step_func,
                                       valid_data_iterator, model,
                                       iteration, process_non_loss_data_func, config,
                                       verbose=True, write_to_tensorboard=not args.skip_train)

        if args.do_test:
            prefix = f'iteration {iteration} on test set'
            evaluate_and_print_results(prefix, forward_step_func,
                                       test_data_iterator, model,
                                       iteration, process_non_loss_data_func, config,
                                       verbose=True, write_to_tensorboard=not args.skip_train)