import os
import gc
import torch
import torch.distributed as dist
import torch.nn.functional as F
import time
import contextlib
from typing import Optional, Tuple, Dict, Any, Union, List
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.checkpoint.state_dict import get_model_state_dict, StateDictOptions
from mindspeed_llm.fsdp2.distributed.parallel_state import ParallelState
from mindspeed_llm.fsdp2.distributed.clip_grad_norm import clip_grad_norm
from mindspeed_llm.fsdp2.data.data_factory import DataManager
from mindspeed_llm.fsdp2.data.processor.processor_utils import IGNORE_INDEX
from mindspeed_llm.fsdp2.features.chunkloss import chunk_loss, calculate_lm_loss
from mindspeed_llm.fsdp2.checkpoint.utils import empty_cache, cleanup_old_checkpoints
from mindspeed_llm.fsdp2.utils.dist_op import all_reduce
from mindspeed_llm.fsdp2.utils.logging import get_logger
from mindspeed_llm.fsdp2.utils.train_monitor import TrainMonitor
from mindspeed_llm.fsdp2.utils.profiler import ProfilerConfig, ProfilerManager
# Module-level logger; the trainer below reports progress through its
# `info_rank0` method.
logger = get_logger(__name__)
class Trainer:
    """
    Orchestrates the training loop, coordinating Model, Optimizer, Scheduler, Data, and IO.

    Strictly follows the gradient accumulation and loop logic found in HuggingFace Transformers.
    `train()` is the entry point: it optionally resumes from a checkpoint, then
    runs epochs of optimizer updates, each made of one or more micro-batches
    processed by `training_step()`.
    """
def __init__(
    self,
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
    data_manager: DataManager,
    args,
    parallel_args,
    optimization_args,
    data_args,
    ckpt_manager,
    monitor: TrainMonitor,
    tokenizer=None,
):
    """
    Store the training collaborators and initialise loop bookkeeping.

    Args:
        model: Module to train.
        optimizer: Optimizer stepped once per accumulated update.
        lr_scheduler: Scheduler stepped right after the optimizer.
        data_manager: Used once here to create the training dataloader.
        args: Training arguments; the `profile*` fields configure the profiler.
        parallel_args: Parallelism settings (read for `cp_size` / `cp_type`).
        optimization_args: Optimization settings (read for `chunk_loss_size`).
        data_args: Data settings (read for `data_shared_file_system`).
        ckpt_manager: Object used to load and save checkpoints.
        monitor: Training monitor invoked on logging steps.
        tokenizer: Optional tokenizer saved next to exported weights.
    """
    # --- collaborators ---------------------------------------------------
    self.model = model
    self.optimizer = optimizer
    self.lr_scheduler = lr_scheduler
    self.train_dataloader = data_manager.create_train_dataloader()
    self.args = args
    self.parallel_args = parallel_args
    self.optimization_args = optimization_args
    self.data_args = data_args
    self.ckpt_manager = ckpt_manager
    self.train_monitor = monitor
    self.tokenizer = tokenizer

    # --- loop position ---------------------------------------------------
    self.epoch = 0
    self.global_step = 0

    # --- logging accumulators --------------------------------------------
    self._last_logged_step = 0
    self._total_loss_scalar = 0.0
    self._logging_loss_scalar = 0.0
    self._global_step_last_logged = 0
    self._last_logged_loss_scalar = 0.0
    self.batch_seqlens = []
    self._step_start_time = None

    # --- profiler --------------------------------------------------------
    # Fall back to rank 0 when torch.distributed has not been initialised.
    current_rank = dist.get_rank() if dist.is_initialized() else 0
    passthrough_fields = (
        "profile_step_start",
        "profile_step_end",
        "profile_ranks",
        "profile_level",
        "profile_export_type",
        "profile_data_simplification",
        "profile_with_cpu",
        "profile_with_stack",
        "profile_with_memory",
        "profile_record_shapes",
        "profile_save_path",
    )
    profiler_kwargs = {"enabled": args.profile}
    # These fields carry the same name on `args` and on ProfilerConfig.
    profiler_kwargs.update({name: getattr(args, name) for name in passthrough_fields})
    profiler_kwargs["current_rank"] = current_rank
    self.profiler_manager = ProfilerManager(ProfilerConfig(**profiler_kwargs))
@staticmethod
def _build_chunk_loss(labels, ignore_index=-100, chunk_size=1024):
shift_labels = labels
loss_mask = shift_labels != ignore_index
alpha = loss_mask.sum()
reduction = "sum"
chunk_labels = torch.split(shift_labels, chunk_size, dim=1)
loss_ctx_kwargs = [
{
"shift_labels": chunk_labels[i],
"ignore_index": ignore_index,
"reduction": reduction,
"alpha": alpha,
}
for i in range(len(chunk_labels))
]
def loss_ctx(hidden_states, head_weight, head_bias):
return chunk_loss(
hidden_states,
head_weight,
head_bias,
loss_forward=calculate_lm_loss,
loss_kwargs_chunks=loss_ctx_kwargs,
chunk_size=chunk_size
)
return loss_ctx, loss_mask
def train(self, resume_from_checkpoint: Optional[str] = None):
    """
    Run the main training loop.

    Optionally restores model / optimizer / scheduler / RNG state from
    `resume_from_checkpoint`, then iterates over epochs and optimizer
    updates. Each update consumes up to `gradient_accumulation_steps`
    micro-batches, clips gradients, steps the optimizer and LR scheduler,
    and logs or saves on the configured step intervals. After the loop the
    model may additionally be exported as HuggingFace weights.

    Args:
        resume_from_checkpoint: Path handed to `ckpt_manager.load`. A falsy
            value starts training from step 0.
    """
    args = self.args
    train_dataloader = self.train_dataloader
    ps = ParallelState()
    reduce_group = ps.get_fsdp_group() if ps.is_fsdp_enable() else None

    grad_accum_steps = args.gradient_accumulation_steps
    steps_in_epoch = len(train_dataloader)
    # Ceil division: a trailing partial group of micro-batches is one more update.
    total_updates_per_epoch = steps_in_epoch // grad_accum_steps + int(
        steps_in_epoch % grad_accum_steps > 0
    )
    total_steps = args.max_steps if args.max_steps > 0 else (total_updates_per_epoch * args.num_train_epochs)
    # Number of micro-batches in the last update of every epoch.
    remainder = steps_in_epoch % grad_accum_steps or grad_accum_steps

    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    save_checkpoint_path = None
    global_batch_size = args.per_device_train_batch_size * grad_accum_steps * ps.get_group_size('dp_fsdp')

    logger.info_rank0("***** Running training (FSDP2) *****")
    logger.info_rank0(f" Num examples = {len(train_dataloader.dataset)}")
    logger.info_rank0(f" Num Epochs = {args.num_train_epochs}")
    logger.info_rank0(f" Total Batch Size = {global_batch_size}")
    logger.info_rank0(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info_rank0(f" Total optimization steps = {total_steps}")

    if resume_from_checkpoint:
        state = {"model": self.model, "optimizer": self.optimizer, "extra_state": {}}
        self.ckpt_manager.load(resume_from_checkpoint, state)
        extra_state = state.get("extra_state", {})
        if extra_state:
            self.global_step = extra_state["global_step"]
            self._global_step_last_logged = extra_state["_global_step_last_logged"]
            # Checkpoints store `_last_logged_step` under its own key. Read it
            # back from there; fall back to the key used previously so older
            # checkpoints that lack it still load.
            self._last_logged_step = extra_state.get(
                "_last_logged_step", extra_state["_global_step_last_logged"]
            )
            # NOTE(review): `train_metric` is written on save but not restored
            # here, and the running loss totals restart at 0 - confirm this is
            # the intended resume behaviour for the monitor.
            epochs_trained = int(self.global_step // total_updates_per_epoch)
            steps_trained_in_current_epoch = self.global_step % total_updates_per_epoch
            self.lr_scheduler.load_state_dict(extra_state["lr_scheduler"])
            torch.set_rng_state(extra_state["torch_rng_state"])
            dist.barrier()
            logger.info_rank0(f"Load distributed checkpoint from {resume_from_checkpoint} successfully!")
            logger.info_rank0(f"Resuming from epoch {epochs_trained}, step {steps_trained_in_current_epoch}")
            logger.info_rank0(f"Global step = {self.global_step}")
        else:
            logger.info_rank0("Loaded model weights only, starting training from step 0")

    self.model.train()
    train_start_time = time.time()
    self._step_start_time = time.time()
    if self.profiler_manager.profiler is not None:
        self.profiler_manager.start()

    # The wrapped model does not change during training, so resolve the
    # object offering `no_sync` once instead of once per micro-batch.
    fsdp_root = self._get_fsdp_root()
    supports_no_sync = hasattr(fsdp_root, "no_sync")

    for epoch in range(epochs_trained, int(args.num_train_epochs)):
        self.epoch = epoch
        if hasattr(train_dataloader.sampler, "set_epoch"):
            train_dataloader.sampler.set_epoch(epoch)
        epoch_iterator = iter(train_dataloader)

        start_update_step = 0
        if epoch == epochs_trained and steps_trained_in_current_epoch > 0:
            # Resuming mid-epoch: drop the micro-batches already consumed.
            steps_to_skip = steps_trained_in_current_epoch * grad_accum_steps
            logger.info_rank0(f" Skipping {steps_to_skip} batches in epoch {epoch}")
            for _ in range(steps_to_skip):
                try:
                    next(epoch_iterator)
                except StopIteration:
                    break
            start_update_step = steps_trained_in_current_epoch

        for update_step in range(start_update_step, total_updates_per_epoch):
            is_last_update = update_step == (total_updates_per_epoch - 1)
            num_batches = remainder if is_last_update else grad_accum_steps
            batch_samples, batch_seqlens, num_items_in_batch = self.get_batch_samples_func()(
                ps, epoch_iterator, num_batches
            )

            current_step_loss = 0.0
            last_micro_index = len(batch_samples) - 1
            for i, inputs in enumerate(batch_samples):
                # Only the last micro-batch of an update synchronises gradients.
                skip_sync = i != last_micro_index and supports_no_sync
                sync_context = fsdp_root.no_sync() if skip_sync else contextlib.nullcontext()
                with sync_context:
                    loss = self.training_step(inputs, num_items_in_batch)
                # NaN / inf losses are left out of the logged running total.
                if torch.isfinite(loss):
                    current_step_loss += loss.item()

            grad_norm = clip_grad_norm(self.model, args.max_grad_norm)
            if isinstance(grad_norm, torch.Tensor):
                grad_norm = grad_norm.item()

            self.optimizer.step()
            self.lr_scheduler.step()
            self.optimizer.zero_grad()
            self.global_step += 1

            if self.profiler_manager.profiler is not None:
                self.profiler_manager.step()

            reduced_loss, reduced_grad_norm = all_reduce(
                (current_step_loss, grad_norm),
                group=reduce_group
            )
            self._total_loss_scalar += reduced_loss
            self.batch_seqlens.extend(batch_seqlens)

            if self.global_step % args.logging_steps == 0:
                _, record_info = self.train_monitor.step(
                    self.epoch, self.lr_scheduler, global_batch_size,
                    reduced_grad_norm, self.batch_seqlens,
                    self._step_start_time, total_steps, self.global_step,
                    self._last_logged_step, self._total_loss_scalar,
                    self._last_logged_loss_scalar)
                self._step_start_time = record_info['time']
                self._last_logged_loss_scalar = record_info['logged_loss']
                self._last_logged_step = record_info['logged_step']
                self.batch_seqlens.clear()

            if args.save_steps > 0 and self.global_step % args.save_steps == 0:
                # Keep the previous path when saving was skipped.
                save_checkpoint_path = self._save_training_checkpoint() or save_checkpoint_path

            if self.global_step >= total_steps:
                break

        steps_trained_in_current_epoch = 0

        # Avoid writing the same step twice when the step-based save just fired.
        already_saved = (args.save_steps > 0 and self.global_step % args.save_steps == 0)
        if args.save_epochs and (epoch + 1) % args.save_epochs == 0 and not already_saved:
            save_checkpoint_path = self._save_training_checkpoint() or save_checkpoint_path

        if self.global_step >= total_steps:
            break

    if self.profiler_manager.profiler is not None:
        self.profiler_manager.stop()

    # HF export only runs when at least one checkpoint was written in this run.
    if args.save_hf_weights and save_checkpoint_path is not None:
        options = StateDictOptions(full_state_dict=True, cpu_offload=True)
        model_state_dict = get_model_state_dict(self.model.model, options=options)
        if dist.get_rank() == 0:
            model_configs = [self.model.model.config, self.tokenizer]
            self.ckpt_manager.save_model_weights(
                args.output_dir, model_state_dict, model_configs=model_configs
            )
            logger.info_rank0(f"Huggingface checkpoint saved at {args.output_dir} successfully!")
        del model_state_dict
        gc.collect()
        empty_cache()
        dist.barrier()

    if dist.get_rank() == 0 and args.output_dir:
        self.ckpt_manager.save_args(args, args.output_dir)
        logger.info_rank0(f"Training arguments saved at {args.output_dir} successfully!")
    logger.info_rank0(f"Training completed in {time.time() - train_start_time:.2f}s")

def _save_training_checkpoint(self) -> Optional[str]:
    """
    Save a distributed checkpoint for the current `global_step`.

    Shared by the step-based and epoch-based save points in `train()`.

    Returns:
        The checkpoint directory that was written, or None when
        `args.output_dir` is not set and saving was skipped.
    """
    args = self.args
    if not args.output_dir:
        logger.info_rank0("output_dir is not set, skipping checkpoint saving")
        return None

    empty_cache()
    checkpoint_path = os.path.join(args.output_dir, "checkpoints", f"global_step_{self.global_step}")
    state = {
        "model": self.model,
        "optimizer": self.optimizer,
        "extra_state": {
            "global_step": self.global_step,
            "_global_step_last_logged": self._global_step_last_logged,
            "_last_logged_step": self._last_logged_step,
            "lr_scheduler": self.lr_scheduler.state_dict(),
            "train_metric": self.train_monitor.state_dict(),
            "torch_rng_state": torch.get_rng_state(),
        },
    }
    self.ckpt_manager.save(
        path=checkpoint_path,
        state=state,
        save_only_model=args.save_only_model,
        global_steps=self.global_step,
    )
    dist.barrier()
    logger.info_rank0(f"Distributed checkpoint saved at {checkpoint_path} successfully!")
    cleanup_old_checkpoints(args, self.data_args.data_shared_file_system)
    return checkpoint_path
def training_step(self, inputs: Dict[str, Any], num_items_in_batch: Optional[int]) -> torch.Tensor:
"""
Performs a single forward and backward pass.
"""
self.model.train()
if hasattr(self.optimizer, "train") and callable(self.optimizer.train):
self.optimizer.train()
loss = self._compute_loss(inputs, return_outputs=False, num_items_in_batch=num_items_in_batch)
del inputs
if torch.cuda.device_count() > 1:
loss = loss.mean()
loss.backward()
return loss.detach()
def get_batch_samples_func(self):
if self.parallel_args.cp_size > 1:
if self.parallel_args.cp_type == "ulysses":
return self._get_batch_samples_ulysses
else:
raise ValueError(f"Unsupported cp_type: '{self.parallel_args.cp_type}' when cp_size > 1.")
else:
return self._get_batch_samples
def _get_batch_samples_ulysses(self, parallel_state, epoch_iterator, num_batches):
    """
    Fetch micro-batches and keep only this context-parallel rank's slice.

    Builds on `_get_batch_samples`. For every sample it adds `shift_labels`
    (labels padded on the right with IGNORE_INDEX, first position dropped)
    and, if missing, `position_ids`. Every non-None entry except
    `attention_mask` is then narrowed along dim 1 to this rank's contiguous
    chunk; when the length is not divisible by the cp group size, the
    lowest ranks each take one extra element.

    Args:
        parallel_state: Provides the "cp" group size and this rank's index.
        epoch_iterator: Data iterator for the current epoch.
        num_batches: Number of micro-batches to fetch.

    Returns:
        tuple: `(batch, batch_seqlens, num_items_in_batch)`, the last two
        exactly as returned by `_get_batch_samples`.
    """
    world = parallel_state.get_group_size("cp")
    rank = parallel_state.get_rank("cp")
    batch, batch_seqlens, num_items_in_batch = self._get_batch_samples(parallel_state, epoch_iterator, num_batches)

    seq_dim = 1
    for sample in batch:
        padded_labels = torch.nn.functional.pad(sample['labels'], (0, 1), value=IGNORE_INDEX)
        shift_labels = padded_labels[..., 1:]
        sample['shift_labels'] = shift_labels
        if "position_ids" not in sample:
            sample['position_ids'] = torch.arange(
                0, shift_labels.shape[1], device=shift_labels.device
            ).unsqueeze(0)

        # Only values of existing keys are replaced below, so iterating the
        # dict while assigning is safe.
        for key, val in sample.items():
            if key == 'attention_mask' or val is None:
                continue
            base, extra = divmod(val.size(seq_dim), world)
            # The first `extra` ranks own `base + 1` elements, the rest `base`.
            start_idx = rank * base + min(rank, extra)
            length = base + 1 if rank < extra else base
            piece = val.narrow(seq_dim, start_idx, length)
            if key == 'shift_labels':
                piece = piece.contiguous()
            device = torch.accelerator.current_device()
            sample[key] = piece.to(device, non_blocking=True)
    return batch, batch_seqlens, num_items_in_batch
def _get_batch_samples(self, parallel_state, epoch_iterator, num_batches):
    """
    Pull up to `num_batches` micro-batches from the iterator and move them to the device.

    Stops early, without raising, when the iterator is exhausted.

    Args:
        parallel_state: Parallel state object (unused, kept for interface consistency)
        epoch_iterator: Data iterator for the current epoch
        num_batches: Number of batches to fetch

    Returns:
        tuple:
            batch_samples (List[Dict[str, torch.Tensor]]):
                The fetched micro-batches, e.g. [{"input_ids": ..., "labels": ..., ...}, ...]
            batch_seqlens (List[int]):
                One length per sample across all fetched micro-batches: the
                attention-mask sum when a mask is present, otherwise the
                number of labels different from -100.
            num_items_in_batch (int | None):
                Result of `_get_num_items_in_batch` for the fetched micro-batches.
    """
    required_keys = ("input_ids", "labels")
    optional_keys = ("attention_mask", "position_ids")

    batch_samples = []
    batch_seqlens = []
    for _ in range(num_batches):
        try:
            # NOTE(review): upstream PyTorch documents
            # `torch.accelerator.current_device_index()`; confirm that
            # `current_device` is available in the target environment.
            device = torch.accelerator.current_device()
            data = next(epoch_iterator)
        except StopIteration:
            break

        for key in required_keys:
            data[key] = data[key].to(device, non_blocking=True)
        for key in optional_keys:
            if key in data:
                data[key] = data[key].to(device, non_blocking=True)
        batch_samples.append(data)

        if "attention_mask" in data:
            sample_seqlens = data["attention_mask"].sum(-1).tolist()
        elif "labels" in data:
            sample_seqlens = (data["labels"].ne(-100)).sum(-1).tolist()
        else:
            sample_seqlens = []
        batch_seqlens.extend(sample_seqlens)

    num_items_in_batch = self._get_num_items_in_batch(batch_samples)
    return batch_samples, batch_seqlens, num_items_in_batch
def _get_num_items_in_batch(self, batch_samples):
    """
    Count label positions different from -100 over the given micro-batches.

    When torch.distributed is initialised the count is summed over the
    "dp_fsdp" group. If `args.non_data_parallel_size` is set and truthy, the
    result is floor-divided by it.

    Returns:
        int | None: The count, or None when there are no micro-batches, the
        first one has no "labels" entry, or the labels cannot be counted.
    """
    device = torch.accelerator.current_device()
    ps = ParallelState()

    if not batch_samples or "labels" not in batch_samples[0]:
        return None

    try:
        token_count = sum((sample["labels"].ne(-100)).sum().item() for sample in batch_samples)
    except (TypeError, AttributeError):
        # Best effort: labels that cannot be counted simply yield no count.
        return None

    if dist.is_initialized():
        count_tensor = torch.tensor(token_count, device=device, dtype=torch.int64)
        dist.all_reduce(count_tensor, op=dist.ReduceOp.SUM, group=ps.get_group("dp_fsdp"))
        token_count = count_tensor.item()

    pc = getattr(self.args, "non_data_parallel_size", None)
    if pc:
        token_count = token_count // pc
    return token_count
def _compute_loss(self, inputs, return_outputs=False, num_items_in_batch=None):
    """
    Run the model on one micro-batch and return its loss.

    In the 'pt' stage `inputs['labels']` is replaced by None before the
    forward pass (the caller's dict is modified), and the loss is computed
    here from the logits unless the model output already contains "loss".
    In other stages the loss is taken from the model output and, when
    torch.distributed is initialised, multiplied in place by the "dp_fsdp"
    group size.

    Args:
        inputs: Model inputs; must contain "labels".
        return_outputs: When True, also return the raw model outputs.
        num_items_in_batch: Forwarded to the model (and to the pretrain loss)
            when not None.

    Returns:
        `loss`, or `(loss, outputs)` when `return_outputs` is True.

    Raises:
        ValueError: If a dict-like model output is expected to carry "loss"
            but does not.
    """
    args = self.args
    device = torch.accelerator.current_device()
    extra_kwargs = {} if num_items_in_batch is None else {"num_items_in_batch": num_items_in_batch}

    labels = inputs['labels'].to(device, non_blocking=True)
    is_pretrain = args.stage == 'pt'
    if is_pretrain:
        inputs['labels'] = None

    chunk_loss_size = self.optimization_args.chunk_loss_size
    if chunk_loss_size and is_pretrain:
        loss_ctx, loss_mask = self._build_chunk_loss(labels, chunk_size=chunk_loss_size)
        extra_kwargs['loss_ctx'] = loss_ctx
        extra_kwargs['loss_mask'] = loss_mask

    outputs = self.model(**{**inputs, **extra_kwargs})

    if is_pretrain and "loss" not in outputs:
        logits = outputs.logits.contiguous().float()
        loss = self._compute_language_model_pretrain_loss(logits, labels, **extra_kwargs)
    elif isinstance(outputs, dict):
        if "loss" not in outputs:
            raise ValueError(f"Model outputs have no loss key: {list(outputs.keys())}")
        loss = outputs["loss"]
    else:
        loss = outputs[0]

    ps = ParallelState()
    if dist.is_initialized() and not is_pretrain:
        # Kept in place so the tensor held by `outputs` is the one that is scaled.
        loss *= ps.get_group_size("dp_fsdp")
    return (loss, outputs) if return_outputs else loss
def _compute_language_model_pretrain_loss(self, logits, labels, ignore_index: int = -100, **kwargs) -> torch.Tensor:
    """
    Cross-entropy loss for the pretrain stage, computed from logits.

    `labels` are used as given (no shift is applied here) and flattened to
    match `logits` viewed as `(-1, vocab)`.

    With a "cp" group larger than 1, the summed loss and the count of
    non-ignored labels are both all-reduced over that group and the result
    is their ratio. Otherwise the loss is the default-reduction cross
    entropy, or the unreduced (`reduction='none'`) one when
    `args.calculate_per_token_loss` is set.

    Args:
        logits: Model logits; the last dimension is the class dimension.
        labels: Target ids with as many elements as logits has rows once flattened.
        ignore_index: Label value excluded from the loss.
        **kwargs: Accepted and ignored.
    """
    args = self.args
    flat_labels = labels.reshape(-1)
    flat_logits = logits.view(-1, logits.shape[-1])

    ps = ParallelState()
    if ps.get_group_size("cp") > 1:
        loss = F.cross_entropy(flat_logits, flat_labels, reduction='sum', ignore_index=ignore_index)
        valid_labels = (labels.ne(ignore_index)).sum()
        # Count first, then loss: the collective order must match on all ranks.
        dist.all_reduce(valid_labels, op=dist.ReduceOp.SUM, group=ps.get_group("cp"))
        dist.all_reduce(loss, op=dist.ReduceOp.SUM, group=ps.get_group("cp"))
        return loss / valid_labels.item()

    if args.calculate_per_token_loss:
        return F.cross_entropy(flat_logits, flat_labels, reduction='none', ignore_index=ignore_index)
    return F.cross_entropy(flat_logits, flat_labels, ignore_index=ignore_index)
def _log_metrics(self, grad_norm=None, batch_size=0):
"""
Logs training metrics:
1. Calculate average loss since last log.
2. Calculate throughput (elapsed time).
3. Log current Grad Norm.
"""
step_diff = self.global_step - self._global_step_last_logged
if step_diff == 0:
return
avg_loss = (self._total_loss_scalar - self._logging_loss_scalar) / step_diff
self._logging_loss_scalar = self._total_loss_scalar
self._global_step_last_logged = self.global_step
current_time = time.time()
if self._step_start_time is None:
elapsed_time_seconds = 0.0
else:
elapsed_time_seconds = current_time - self._step_start_time
self._step_start_time = current_time
elapsed_time_per_iteration_ms = (elapsed_time_seconds / step_diff) * 1000
throughput = (batch_size / (elapsed_time_seconds / step_diff)) if elapsed_time_seconds > 0 else 0.0
metrics = {
"loss": avg_loss,
"lr": self.lr_scheduler.get_last_lr()[0],
"epoch": self.epoch,
"global_step": self.global_step
}
if grad_norm is not None:
metrics["grad_norm"] = grad_norm
consumed_samples = self.global_step * batch_size
logger.info_rank0(
f"iteration: {self.global_step} | "
f"consumed samples: {consumed_samples} | "
f"elapsed time per iteration (ms): {elapsed_time_per_iteration_ms:.1f} | "
f"throughput(samples/s): {throughput:.1f} | "
f"learning rate: {metrics['lr']:.6E} | "
f"global batch size: {batch_size:4d} | "
f"lm loss: {avg_loss:.4f} | "
f"grad norm: {grad_norm:.4f}"
)
def _get_fsdp_root(self):
"""
Helper to get the inner FSDP module to access context managers like `no_sync`.
Handles different wrapping depths.
"""
if isinstance(self.model, FSDP):
return self.model
if hasattr(self.model, "module"):
if isinstance(self.model.module, FSDP):
return self.model.module
if hasattr(self.model.module, "model") and isinstance(self.model.module.model, FSDP):
return self.model.module.model
return self.model