import logging
from datetime import datetime
import torch
from mindspeed.fsdp.utils.log import print_rank
from mindspeed_mm.fsdp.utils.dtype import get_dtype
from mindspeed_mm.fsdp.distributed.fully_shard_parallel import pregather_fsdp_params
from mindspeed_mm.fsdp.distributed.parallel_state import get_parallel_state
from mindspeed_mm.fsdp.utils.utils import move_to_device, get_time, configure_hsdp_gradient_sync, tensor_to_dtensor, \
report_memory
from mindspeed_mm.fsdp.data.data_utils.utils import build_iterations
from mindspeed_mm.fsdp.optimizer.clip_grad_norm import clip_grad_norm
from mindspeed_mm.fsdp.tools.profiler import Profiler
from mindspeed_mm.fsdp.tools.memory_profiler import memory_profiler
from mindspeed_mm.fsdp.loss.loss_func import build_loss_func
from mindspeed_mm.fsdp.params.argument import Arguments
from mindspeed_mm.fsdp.utils.lora_utils import load_state_dict
from mindspeed_mm.fsdp.data.dataloader.dataloader import Preloader
from mindspeed_mm.fsdp.checkpoint.hf_load_utils import load_hf_checkpoint, looks_like_hf_weight_dir
from mindspeed_mm.fsdp.utils.constants import MEMORY_REPORT_ITERATION
# Module-level logger; this file passes its methods to ``print_rank`` for output.
logger = logging.getLogger(__name__)
class TrainEngine:
    """Training engine that manages the main training loop and operations.

    The engine holds the model, optimizer, LR scheduler, dataloader and
    checkpointer, optionally restores state from ``args.training.load`` when
    constructed, and exposes :meth:`train` as the entry point of the loop.
    """

    def __init__(
        self,
        args: "Arguments",
        train_dataloader,
        model,
        optimizer,
        scheduler,
        checkpointer,
        lora_weight_manager=None,
        **kwargs,
    ):
        """Store the training components, resume if requested, start profiling.

        Args:
            args: Parsed configuration; the ``training``, ``tools``,
                ``parallel``, ``features``, ``data`` and ``model`` sections
                are read by this class.
            train_dataloader: Dataloader handed to ``build_iterations``.
            model: Model to train; called as ``model(**batch, use_cache=False)``.
            optimizer: Optimizer stepped once per training iteration.
            scheduler: LR scheduler stepped once per training iteration.
            checkpointer: Object providing ``load(...)`` and ``save(...)``.
            lora_weight_manager: Optional object providing ``save_lora_only``;
                used by :meth:`save` when LoRA is enabled.
            **kwargs: Accepted for call-site compatibility and ignored.
        """
        self.args = args
        self.model = model
        self.train_dataloader = train_dataloader
        self.optimizer = optimizer
        self.lr_scheduler = scheduler
        self.checkpointer = checkpointer
        self.lora_weight_manager = lora_weight_manager

        self.iteration, self.consumed_train_samples = 0, 0
        if args.training.load:
            self.iteration, self.consumed_train_samples = self.load()

        if (
            args.training.init_model_with_meta_device
            and args.training.lora.enable
            and args.training.lora.pretrained_lora_path
        ):
            self._reload_pretrained_lora()

        self.profiler = Profiler(args.tools.profile)
        self.profiler.start()

    def _reload_pretrained_lora(self):
        """Copy pretrained LoRA tensors into the matching model parameters."""
        lora_path = self.args.training.lora.pretrained_lora_path
        lora_state_dict = load_state_dict(lora_path)
        model_state_dict = self.model.state_dict()

        reloaded = 0
        skipped_keys = []
        # The copies are plain weight initialisation and must not be tracked
        # by autograd.
        with torch.no_grad():
            for key, value in lora_state_dict.items():
                if key not in model_state_dict:
                    skipped_keys.append(key)
                    continue
                target_tensor = model_state_dict[key]
                device_mesh = getattr(target_tensor, "device_mesh", None)
                placements = getattr(target_tensor, "placements", None)
                if device_mesh is not None and placements is not None:
                    # The target exposes a mesh/placements pair, so the source
                    # is converted with the same layout before copying.
                    target_tensor.copy_(tensor_to_dtensor(value, device_mesh, placements))
                else:
                    target_tensor.copy_(value)
                reloaded += 1

        if skipped_keys:
            print_rank(
                logger.warning,
                f"Skipped {len(skipped_keys)} LoRA parameters from {lora_path} that are missing "
                f"from the model state dict (first: {skipped_keys[0]})",
            )
        # Report only what was actually copied, not every key in the file.
        print_rank(
            logger.info,
            f"Reloaded {reloaded} LoRA parameters from {lora_path}",
        )

    def average_losses_across_data_parallel_group(self, losses):
        """Average each loss over the data-parallel group.

        Args:
            losses: Iterable of loss tensors; each is flattened with ``view(1)``.

        Returns:
            A 1-D tensor with one entry per input loss: the all-reduced value
            divided by the size of the data-parallel group.
        """
        dp_group = get_parallel_state().get_dp_group()
        averaged_losses = torch.cat([loss.clone().detach().view(1) for loss in losses])
        torch.distributed.all_reduce(averaged_losses, group=dp_group)
        return averaged_losses / torch.distributed.get_world_size(group=dp_group)

    def get_batch(self, data_iterator):
        """Fetch the next batch from ``data_iterator``.

        Unless preloading is enabled in ``args.data.dataloader_param``, the
        batch is passed through ``move_to_device`` together with the configured
        FSDP parameter dtype (or ``None`` when no dtype is configured).

        Raises:
            ValueError: If ``data_iterator`` is ``None``.
        """
        if data_iterator is None:
            raise ValueError("Data iterator is None. Unable to retrieve batch.")
        batch = next(data_iterator)

        data_args = self.args.data
        if not data_args or not data_args.dataloader_param.enable_preload:
            param_dtype = self.args.parallel.fsdp_plan.param_dtype
            batch = move_to_device(batch, get_dtype(param_dtype) if param_dtype else None)
        return batch

    def set_loss_func(self, batch_data):
        """Build the loss function for this batch and attach it to the model.

        Does nothing when the configured loss type is ``"raw"``. Otherwise
        ``batch_data`` may be mutated in place: ``total_chunk_size`` is added
        when dynamic chunk loss is enabled, and ``output_router_logits=True``
        is added when the router aux-loss coefficient is positive.

        Args:
            batch_data: Mutable mapping that is later unpacked into the model
                call; it is also unpacked into ``build_loss_func``.
        """
        features = self.args.features
        if features.loss_cfg.loss_type == "raw":
            return

        chunk_size = features.chunkloss_plan.chunk_size if features.enable_chunk_loss else None
        if features.enable_dynamic_chunk_loss:
            batch_data['total_chunk_size'] = features.chunkloss_plan.total_chunk_size

        # Plain assignment covers both the "already has the attribute" and the
        # "attribute is new" cases.
        self.model.loss_function = build_loss_func(
            features.loss_cfg.loss_type, chunk_size=chunk_size, **batch_data
        )

        if features.loss_cfg.router_aux_loss_coef > 0.0:
            batch_data.update(output_router_logits=True)

    def train_step(self, train_dataloader_iter):
        """Perform a single training step with gradient accumulation.

        Runs ``gradient_accumulation_steps`` forward/backward passes. Each
        micro-step loss is divided by the number of accumulation steps before
        ``backward()``.

        Args:
            train_dataloader_iter: Iterator consumed by :meth:`get_batch`.

        Returns:
            dict mapping log names to Python floats: ``"loss"`` (averaged over
            the data-parallel group, without the MTP contribution), one
            ``"mtp_<k> loss"`` per MTP head when the model output has
            ``mtp_loss``, and ``"aux loss"`` (local, not all-reduced) when the
            model output has ``aux_loss``.
        """
        args = self.args
        accumulation_steps = args.training.gradient_accumulation_steps
        total_loss = 0
        total_aux_loss = None
        all_mtp_loss = None

        for step in range(accumulation_steps):
            batch_data = self.get_batch(train_dataloader_iter)
            self.set_loss_func(batch_data)
            is_last_step = step == accumulation_steps - 1
            configure_hsdp_gradient_sync(self.model, is_last_step)

            output = self.model(**batch_data, use_cache=False)
            loss = output.loss / accumulation_steps
            # Logging accumulators only need values, so they are detached to
            # avoid keeping autograd nodes alive across micro-steps.
            total_loss = total_loss + loss.detach()

            mtp_loss = getattr(output, 'mtp_loss', None)
            if mtp_loss is not None:
                final_mtp_loss = torch.mean(torch.stack(mtp_loss)) / accumulation_steps
                # The scaled MTP term joins the backward pass but is reported
                # separately, not inside "loss".
                loss = loss + final_mtp_loss * getattr(args.model, 'mtp_loss_scaling_factor', 0.1)
                if all_mtp_loss is None:
                    all_mtp_loss = [torch.zeros_like(head_loss) for head_loss in mtp_loss]
                for i, head_loss in enumerate(mtp_loss):
                    all_mtp_loss[i] = all_mtp_loss[i] + head_loss.detach() / accumulation_steps

            loss.backward()

            aux_loss = getattr(output, 'aux_loss', None)
            if aux_loss is not None:
                aux_loss = aux_loss.detach() / accumulation_steps
                total_aux_loss = aux_loss if total_aux_loss is None else total_aux_loss + aux_loss

        loss_dict = {}
        total_loss = self.average_losses_across_data_parallel_group([total_loss])
        loss_dict["loss"] = total_loss.item()
        if all_mtp_loss:
            for i, head_total in enumerate(all_mtp_loss):
                averaged = self.average_losses_across_data_parallel_group([head_total])
                loss_dict[f"mtp_{i + 1} loss"] = averaged.item()
        # Compare against None: tensor truthiness would hide an aux loss that
        # is exactly zero.
        if total_aux_loss is not None:
            loss_dict["aux loss"] = total_aux_loss.item()
        return loss_dict

    def train(self):
        """Main training loop.

        Iterates until ``self.iteration`` reaches ``args.training.train_iters``.
        Each iteration runs one :meth:`train_step`, clips gradients, steps the
        optimizer and scheduler, then logs and checkpoints at the configured
        intervals. After the loop the profilers are stopped and a final
        checkpoint is written when ``args.training.save`` is set.
        """
        args = self.args
        training_args = args.training

        train_dataloader_iter, _, _ = build_iterations(self.train_dataloader)
        param_dtype = get_dtype(args.parallel.fsdp_plan.param_dtype) if args.parallel.fsdp_plan.param_dtype else None
        if args.data and args.data.dataloader_param.enable_preload:
            train_dataloader_iter = Preloader(train_dataloader_iter, param_dtype=param_dtype)

        self.model.train()
        # Read before the scheduler step, so the logged value is the LR the
        # optimizer used in that iteration.
        curr_step_lr = self.lr_scheduler.get_last_lr()[0]
        last_saved_iteration = None

        while self.iteration < training_args.train_iters:
            memory_profiler.step()
            start_time = get_time(barrier=True)

            if args.parallel.fsdp_plan.pregather:
                pregather_fsdp_params(self.model)

            loss_dict = self.train_step(train_dataloader_iter)
            grad_norm = clip_grad_norm(
                self.model, max_norm=training_args.clip_grad, foreach=training_args.clip_grad_foreach
            )
            self.optimizer.step()
            self.lr_scheduler.step()
            self.optimizer.zero_grad()

            self.consumed_train_samples += training_args.global_batch_size
            self.iteration += 1
            elapsed_time_per_iteration = get_time(barrier=True) - start_time
            self.profiler.step()

            # A non-positive interval disables logging, mirroring the
            # save_interval guard, instead of raising ZeroDivisionError.
            if training_args.log_interval > 0 and self.iteration % training_args.log_interval == 0:
                self.training_log(
                    self.iteration,
                    elapsed_time_per_iteration,
                    curr_step_lr,
                    self.consumed_train_samples,
                    loss_dict,
                    grad_norm,
                )
            if self.iteration == MEMORY_REPORT_ITERATION:
                report_memory('(after {} iterations)'.format(self.iteration))
            curr_step_lr = self.lr_scheduler.get_last_lr()[0]

            if (
                training_args.save
                and training_args.save_interval > 0
                and self.iteration % training_args.save_interval == 0
            ):
                self.save(self.iteration, self.consumed_train_samples)
                last_saved_iteration = self.iteration

        self.profiler.stop()
        memory_profiler.stop()
        # Skip the final save when the last loop iteration already wrote it.
        if training_args.save and last_saved_iteration != self.iteration:
            self.save(self.iteration, self.consumed_train_samples)

    def training_log(
        self, iteration, elapsed_time_per_iteration, curr_step_lr, consumed_train_samples, loss_dict, grad_norm
    ):
        """Format one progress line and emit it through ``print_rank``.

        Args:
            iteration: Current iteration number (int).
            elapsed_time_per_iteration: Iteration time in seconds.
            curr_step_lr: Learning rate to report.
            consumed_train_samples: Total samples consumed so far (int).
            loss_dict: Mapping of loss name to float value.
            grad_norm: Gradient norm, or ``None`` to omit it.
        """
        log_string = self._format_training_log(
            iteration, elapsed_time_per_iteration, curr_step_lr, consumed_train_samples, loss_dict, grad_norm
        )
        print_rank(logger.info, log_string)

    def _format_training_log(
        self, iteration, elapsed_time_per_iteration, curr_step_lr, consumed_train_samples, loss_dict, grad_norm
    ):
        """Return the progress line for :meth:`training_log` without printing it."""
        training_args = self.args.training
        parts = [
            f" [{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}]",
            ' iteration {:8d}/{:8d} |'.format(iteration, training_args.train_iters),
            ' consumed samples: {:12d} |'.format(consumed_train_samples),
            # The elapsed time arrives in seconds and is reported in ms.
            ' elapsed time per iteration (ms): {:.1f} |'.format(elapsed_time_per_iteration * 1000.0),
            ' learning rate: {:.6E} |'.format(curr_step_lr),
            ' global batch size: {:5d} |'.format(training_args.global_batch_size),
        ]
        parts.extend(f" {name}: {value:.6E} |" for name, value in loss_dict.items())
        if grad_norm is not None:
            parts.append(' grad norm: {:.3f} |'.format(grad_norm))
        return "".join(parts)

    def load(self):
        """Load checkpoint and restore training state.

        The format is ``args.training.load_format``; ``"auto"`` resolves to
        ``"hf"`` when ``looks_like_hf_weight_dir`` accepts the load path and to
        ``"dcp"`` otherwise. When the loader returns a falsy ``release`` flag,
        the iteration, consumed samples, LR scheduler, dataloader and
        (optionally) torch RNG state are restored from ``extra_state``.

        Returns:
            Tuple ``(iteration, consumed_train_samples)``; ``(0, 0)`` when the
            loader reports a release checkpoint.
        """
        training_args = self.args.training
        iteration, consumed_train_samples = 0, 0

        state = {"model": self.model, "extra_state": {}}
        if not training_args.no_load_optim:
            state["optimizer"] = self.optimizer

        fmt = training_args.load_format
        if fmt == "auto":
            fmt = "hf" if looks_like_hf_weight_dir(training_args.load) else "dcp"

        if fmt == "hf":
            # NOTE(review): this path never populates state["extra_state"], so
            # it presumably relies on load_hf_checkpoint returning a truthy
            # release flag - confirm against its implementation.
            release = load_hf_checkpoint(
                self.model, training_args.load,
                load_rank0_and_broadcast=training_args.load_rank0_and_broadcast,
                enable_lora=training_args.lora.enable,
                load_strict=training_args.load_strict,
            )
        else:
            release = self.checkpointer.load(
                path=training_args.load,
                state=state,
                load_rank0_and_broadcast=training_args.load_rank0_and_broadcast,
                load_strict=training_args.load_strict,
                enable_lora=training_args.lora.enable,
            )

        if not release:
            extra_state = state["extra_state"]
            iteration = extra_state["iteration"]
            consumed_train_samples = extra_state["consumed_train_samples"]
            self.lr_scheduler.load_state_dict(extra_state["lr_scheduler"])
            if self.train_dataloader is not None:
                self.train_dataloader.load_state_dict(extra_state["train_dataloader"])
            if not training_args.no_load_rng:
                if "torch_rng_state" not in extra_state:
                    print_rank(logger.warning, "No RNG state found in checkpoint, skipping RNG loading")
                else:
                    torch.set_rng_state(extra_state["torch_rng_state"])

        torch.distributed.barrier()
        return iteration, consumed_train_samples

    def save(self, iteration, consumed_train_samples):
        """Save checkpoint with model, optimizer, and training state.

        With LoRA enabled only ``lora_weight_manager.save_lora_only`` is called
        and the full checkpoint is not written. Otherwise the model, training
        counters, LR scheduler state, dataloader state (when a dataloader is
        present) and, unless disabled, optimizer and torch RNG state are passed
        to ``checkpointer.save``, followed by a distributed barrier.

        Args:
            iteration: Iteration number recorded in the checkpoint.
            consumed_train_samples: Sample counter recorded in the checkpoint.
        """
        training_args = self.args.training

        if training_args.lora.enable:
            if self.lora_weight_manager is not None:
                self.lora_weight_manager.save_lora_only(
                    save_path=training_args.save,
                    iteration=iteration,
                )
            else:
                # Make the skipped save visible instead of returning silently.
                print_rank(
                    logger.warning,
                    "LoRA is enabled but no lora_weight_manager was provided; skipping checkpoint save",
                )
            return

        extra_state = {
            "iteration": iteration,
            "consumed_train_samples": consumed_train_samples,
            "lr_scheduler": self.lr_scheduler.state_dict(),
        }
        # load() tolerates a missing dataloader, so save() does the same.
        if self.train_dataloader is not None:
            extra_state["train_dataloader"] = self.train_dataloader.state_dict()
        if not training_args.no_save_rng:
            extra_state["torch_rng_state"] = torch.get_rng_state()

        state = {"model": self.model, "extra_state": extra_state}
        if not training_args.no_save_optim:
            state["optimizer"] = self.optimizer

        self.checkpointer.save(training_args.save, state=state, iteration=iteration)
        torch.distributed.barrier()