import logging
from mindspeed_mm.fsdp.data.datasets.funasr.funasr_dataset import (
build_funasr_dataloader_factory,
)
from mindspeed_mm.fsdp.distributed.torch_parallelize import ParallelApplier
from mindspeed_mm.fsdp.models.funasr.modeling_funasr import get_funasr_model
from mindspeed_mm.fsdp.params.argument import Arguments, parse_args
from mindspeed_mm.fsdp.tools.memory_profiler import memory_profiler
from mindspeed_mm.fsdp.tools.profiler import Profiler
from mindspeed_mm.fsdp.train.trainer import Trainer
from mindspeed_mm.fsdp.tasks.funasr.train_engine import FunasrTrainEngine
from mindspeed_mm.config.config_manager import ConfigManager
logger = logging.getLogger(__name__)


class FunasrTrainer(Trainer):
    """Trainer for FunASR models on the FSDP training stack.

    Builds the FunASR model, tokenizer, frontend, optimizer, learning-rate
    scheduler and dataloader, then hands them to a ``FunasrTrainEngine``
    stored on ``self.trainer``.
    """

    def __init__(self, args: Arguments):
        """Set up every training component and wrap them in a train engine.

        Args:
            args: Parsed run configuration.
        """
        self.args = args
        # NOTE(review): ``Trainer.__init__`` is not called here; ``initialize()``
        # is inherited from ``Trainer`` and presumably performs the base-class
        # setup instead — confirm against the base class.
        self.initialize()
        self.setup_funasr_training()
        self.trainer = FunasrTrainEngine(
            args=args,
            train_dataloader=self._funasr_dataloader,
            model=self.model,
            optimizer=self.optimizer,
            scheduler=self.lr_scheduler,
            checkpointer=self.checkpointer,
        )
        # Hand the starting position (reset to zero in setup_funasr_training)
        # to the engine that runs the loop.
        self.trainer._current_epoch = self._current_epoch
        self.trainer.start_data_split_i = self.start_data_split_i
        self.trainer.start_step = self.start_step

    def setup_funasr_training(self):
        """Setup training for FunASR models using official FunASR dataloader logic.

        Populates ``checkpointer``, ``model_parallel_applier``, ``model``,
        ``tokenizer``, ``frontend``, ``optimizer``, ``lr_scheduler`` and
        ``_funasr_dataloader``, resets the progress counters to zero and
        starts the profiler.
        """
        self.checkpointer = self.get_checkpointer()
        self.model_parallel_applier = ParallelApplier(self.args.parallel, self.args.training)
        # The tokenizer and frontend returned alongside the model are passed
        # on to the dataloader factory below.
        self.model, self.tokenizer, self.frontend = get_funasr_model(
            self.args.model, self.model_parallel_applier
        )
        # Called before the optimizer/scheduler are built; presumably it
        # derives iteration counts they depend on — TODO confirm.
        self._validate_and_set_train_iters(self.args)
        self.optimizer = self.get_optimizer()
        # get_scheduler() wraps self.optimizer, so it must come after it.
        self.lr_scheduler = self.get_scheduler()
        self._funasr_dataloader = build_funasr_dataloader_factory(
            self.args.data, self.frontend, self.tokenizer
        )

        # Progress counters always start from zero here.
        self._current_epoch = 0
        self.start_data_split_i = 0
        self.start_step = 0
        self.iteration = 0
        self.consumed_train_samples = 0

        memory_profiler.reset(self.args.tools.memory_profile)
        # NOTE(review): the profiler is started here but not stopped in this
        # class; presumably the training loop stops it — confirm.
        self.profiler = Profiler(self.args.tools.profile)
        self.profiler.start()

    def get_scheduler(self):
        """Build the learning-rate scheduler named by ``args.training.scheduler``.

        The scheduler class is looked up in FunASR's ``scheduler_classes``
        registry. It is instantiated with ``self.optimizer`` plus the keyword
        arguments in ``args.training.scheduler_conf``; an unset (``None``)
        conf is treated as no extra arguments. Requires ``self.optimizer`` to
        exist already.

        Returns:
            The instantiated scheduler.

        Raises:
            ValueError: If the scheduler name is not in ``scheduler_classes``.
        """
        # Imported inside the method, so ``funasr`` is only needed when a
        # scheduler is actually built.
        from funasr.schedulers import scheduler_classes

        # Use the module logger: the root-level ``logging.info`` would
        # implicitly call ``basicConfig()`` when the root logger has no handlers.
        logger.info("Build scheduler")
        scheduler_name = self.args.training.scheduler
        if scheduler_name not in scheduler_classes:
            raise ValueError(
                f"Invalid scheduler name: {scheduler_name}. "
                f"Available schedulers: {list(scheduler_classes.keys())}"
            )
        scheduler_class = scheduler_classes[scheduler_name]
        # Guard against ``**None`` when no scheduler_conf is configured.
        scheduler_conf = self.args.training.scheduler_conf or {}
        return scheduler_class(self.optimizer, **scheduler_conf)
if __name__ == "__main__":
    # Script entry point: load the run configuration, then build and run the trainer.
    parsed_args = ConfigManager(config_class=Arguments).load_and_parse()
    FunasrTrainer(args=parsed_args).train()