import logging
from contextlib import nullcontext
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from mindspeed.fsdp.utils.log import print_rank
from mindspeed_mm.fsdp.distributed.fully_shard_parallel import pregather_fsdp_params
from mindspeed_mm.fsdp.optimizer.clip_grad_norm import clip_grad_norm
from mindspeed_mm.fsdp.tools.memory_profiler import memory_profiler
from mindspeed_mm.fsdp.utils.dtype import get_dtype
from mindspeed_mm.fsdp.utils.utils import get_time, move_to_device
from mindspeed_mm.fsdp.train.train_engine import TrainEngine
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class FunasrTrainEngine(TrainEngine):
    """FunASR-specific training engine with custom step logic and epoch/split loop."""

    def train_step(self, train_dataloader_iter):
        """Run the forward/backward passes for one optimizer step.

        For each of ``gradient_accumulation_steps`` micro-batches this pulls a
        batch via ``self.get_batch``, moves it with ``move_to_device``, calls
        the model (which returns a ``(loss, stats, weight)`` triple) and
        back-propagates ``loss / gradient_accumulation_steps``.

        When the model exposes ``no_sync`` it is entered for every micro-batch
        except the last one, so gradient synchronisation is deferred to the
        final backward pass.

        Args:
            train_dataloader_iter: Iterator handed to ``self.get_batch``.

        Returns:
            The result of ``self.average_losses_across_data_parallel_group``
            applied to the summed (detached) scaled losses.

        Raises:
            StopIteration: Propagated from ``self.get_batch`` if the iterator
                is exhausted; ``train`` relies on this to end a data split.
        """
        accum_steps = self.args.training.gradient_accumulation_steps

        # The target dtype does not change between micro-batches; resolve it once.
        param_dtype = self.args.parallel.fsdp_plan.param_dtype
        target_dtype = get_dtype(param_dtype) if param_dtype else None

        can_defer_sync = hasattr(self.model, 'no_sync')

        total_loss = 0.0
        for accum_step in range(accum_steps):
            batch = self.get_batch(train_dataloader_iter)
            batch = move_to_device(batch, target_dtype)

            is_last_micro_step = accum_step == accum_steps - 1
            if can_defer_sync and not is_last_micro_step:
                sync_context = self.model.no_sync
            else:
                sync_context = nullcontext

            with sync_context():
                # ``stats`` and ``weight`` are part of the model's return
                # signature but nothing downstream consumes them, so they are
                # not aggregated here.
                loss, _stats, _weight = self.model(**batch)
                scaled_loss = loss / accum_steps
                scaled_loss.backward()

            # Detach so the running total used for logging does not keep a
            # reference to the autograd graph of each micro-batch.
            total_loss += scaled_loss.detach()

        total_loss = self.average_losses_across_data_parallel_group([total_loss])
        return total_loss

    def train(self):
        """Run the FunASR training loop over epochs and data splits.

        For every (epoch, data split) pair a fresh training dataloader is
        obtained from ``self.train_dataloader.build_iter`` and consumed until
        it raises ``StopIteration`` or ``self.iteration`` reaches
        ``args.training.train_iters``. Logging, checkpointing and profiling
        go through ``self.training_log``, ``self.save``, ``self.profiler`` and
        ``memory_profiler``.

        Resume position is read from the optional attributes
        ``_current_epoch``, ``start_data_split_i`` and ``start_step``
        (each defaulting to 0); the latter two only apply to the first
        split/epoch visited and are reset afterwards.

        Raises:
            RuntimeError: If ``self.train_dataloader`` has no ``build_iter``.
        """
        if not hasattr(self.train_dataloader, 'build_iter'):
            raise RuntimeError("FunASR dataloader factory must have 'build_iter(epoch, data_split_i, start_step)' method")

        training_args = self.args.training
        self.model.train()

        current_epoch = getattr(self, '_current_epoch', 0)
        start_data_split_i = getattr(self, 'start_data_split_i', 0)
        start_step = getattr(self, 'start_step', 0)

        # Tracks the interval save so the final save does not repeat it.
        last_saved_iteration = None

        for epoch in range(current_epoch, training_args.max_epochs):
            for data_split_i in range(start_data_split_i, self.train_dataloader.data_split_num):
                # The validation loader returned by build_iter is not used here.
                dataloader_tr, _dataloader_val = self.train_dataloader.build_iter(
                    epoch=epoch,
                    data_split_i=data_split_i,
                    start_step=start_step,
                )

                # Only call set_epoch when the sampler exists and supports it.
                batch_sampler = getattr(dataloader_tr, 'batch_sampler', None)
                if hasattr(batch_sampler, 'set_epoch'):
                    batch_sampler.set_epoch(epoch)

                dataloader_iter = iter(dataloader_tr)

                while self.iteration < training_args.train_iters:
                    memory_profiler.step()
                    start_time = get_time(barrier=True)

                    if self.args.parallel.fsdp_plan.pregather and not isinstance(self.model, DDP):
                        pregather_fsdp_params(self.model)

                    try:
                        loss = self.train_step(dataloader_iter)
                    except StopIteration:
                        # The split may run dry in the middle of gradient
                        # accumulation; drop the partial gradients so they do
                        # not leak into the first step of the next split.
                        self.optimizer.zero_grad(set_to_none=True)
                        break

                    grad_norm = None
                    if training_args.clip_grad > 0:
                        grad_norm = clip_grad_norm(
                            self.model,
                            max_norm=training_args.clip_grad,
                            norm_type=training_args.clip_norm_type,
                            foreach=training_args.clip_grad_foreach,
                        )

                    # grad_norm stays None when clipping is disabled, and
                    # torch.isfinite(None) would raise, so check it first.
                    if grad_norm is not None and not torch.isfinite(grad_norm):
                        logger.warning(f"Non-finite grad_norm ({grad_norm}). Skipping update.")
                        self.optimizer.zero_grad(set_to_none=True)
                        # A skipped update does not count as an iteration.
                        continue

                    self.optimizer.step()
                    self.lr_scheduler.step()
                    self.optimizer.zero_grad(set_to_none=True)
                    self.profiler.step()

                    self.consumed_train_samples += training_args.global_batch_size
                    self.iteration += 1

                    elapsed_time = get_time(barrier=True) - start_time
                    curr_lr = self.lr_scheduler.get_last_lr()[0]

                    if self.iteration % training_args.log_interval == 0:
                        self.training_log(
                            self.iteration, elapsed_time, curr_lr,
                            self.consumed_train_samples, loss, loss.new_tensor(0.0), grad_norm
                        )

                    if (training_args.save and
                            training_args.save_interval > 0 and
                            self.iteration % training_args.save_interval == 0):
                        self.save(self.iteration, self.consumed_train_samples)
                        last_saved_iteration = self.iteration

                if self.iteration >= training_args.train_iters:
                    break
                # The resume offset only applies to the first split visited.
                start_step = 0

            # Later epochs always start from the first split.
            start_data_split_i = 0
            if self.iteration >= training_args.train_iters:
                break

        self.profiler.stop()
        memory_profiler.stop()

        # Final checkpoint, unless this exact iteration was just saved.
        if training_args.save and last_saved_iteration != self.iteration:
            self.save(self.iteration, self.consumed_train_samples)

    def save(self, iteration, consumed_train_samples):
        """Write a checkpoint through ``self.checkpointer``.

        The checkpoint holds the model, an ``extra_state`` dict (iteration,
        consumed samples, LR-scheduler state, optionally the FunASR dataloader
        state and the torch CPU RNG state) and, unless
        ``args.training.no_save_optim`` is set, the optimizer.

        Args:
            iteration: Iteration number recorded in the checkpoint.
            consumed_train_samples: Sample counter recorded in the checkpoint.
        """
        args = self.args
        extra_state = {
            "iteration": iteration,
            "consumed_train_samples": consumed_train_samples,
            "lr_scheduler": self.lr_scheduler.state_dict(),
        }

        # Dataloader state is optional: stored only when the engine holds a
        # ``_funasr_dataloader`` that can produce a state dict.
        funasr_dataloader = getattr(self, '_funasr_dataloader', None)
        if hasattr(self, '_funasr_dataloader') and hasattr(funasr_dataloader, 'state_dict'):
            extra_state["funasr_dataloader"] = funasr_dataloader.state_dict()

        if not args.training.no_save_rng:
            extra_state["torch_rng_state"] = torch.get_rng_state()

        state = {
            "model": self.model,
            "extra_state": extra_state,
        }
        if not args.training.no_save_optim:
            state["optimizer"] = self.optimizer

        self.checkpointer.save(args.training.save, state=state, iteration=iteration)
        torch.distributed.barrier()

    def load(self):
        """Restore training state from ``args.training.load``.

        The model is always loaded. The optimizer is loaded as well unless
        ``args.training.no_load_optim`` is set. When the checkpointer reports
        a non-release checkpoint, the iteration counters, LR-scheduler state,
        the FunASR dataloader state (if both saved and restorable) and,
        unless ``args.training.no_load_rng`` is set, the torch RNG state are
        restored from ``extra_state``.

        Returns:
            Tuple ``(iteration, consumed_train_samples)``; ``(0, 0)`` when the
            checkpointer reports a release checkpoint.
        """
        args = self.args

        state = {"model": self.model, "extra_state": {}}
        # Hand the optimizer to the checkpointer so ``no_load_optim`` actually
        # controls whether optimizer state is restored.
        if not args.training.no_load_optim:
            state["optimizer"] = self.optimizer

        release = self.checkpointer.load(path=args.training.load, state=state)

        if not release:
            extra_state = state["extra_state"]
            iteration = extra_state["iteration"]
            consumed_train_samples = extra_state["consumed_train_samples"]
            self.lr_scheduler.load_state_dict(extra_state["lr_scheduler"])

            # Counterpart of the optional "funasr_dataloader" entry written by
            # save(). NOTE(review): assumes the dataloader pairs state_dict()
            # with load_state_dict() -- confirm against its implementation.
            funasr_dataloader = getattr(self, '_funasr_dataloader', None)
            if "funasr_dataloader" in extra_state and hasattr(funasr_dataloader, 'load_state_dict'):
                funasr_dataloader.load_state_dict(extra_state["funasr_dataloader"])

            if not args.training.no_load_rng:
                if "torch_rng_state" not in extra_state:
                    print_rank(logger.warning, "No RNG state found in checkpoint, skipping RNG loading")
                else:
                    torch.set_rng_state(extra_state["torch_rng_state"])
        else:
            iteration, consumed_train_samples = 0, 0

        torch.distributed.barrier()
        return iteration, consumed_train_samples