import os
import sys
import types
from dataclasses import dataclass, field, fields
import torch
try:
import torch_npu
from torch_npu.contrib import transfer_to_npu
except ImportError:
pass
from transformers import AutoConfig, AutoModelForCausalLM, is_torch_npu_available
from mindspeed_llm.fsdp2.models.model_factory import ModelFactory
from mindspeed_llm.fsdp2.optim.optimizer import OptimizerFactory
from mindspeed_llm.fsdp2.optim.scheduler import SchedulerFactory
from mindspeed_llm.fsdp2.checkpoint.checkpoint_manager import CheckpointManager
from mindspeed_llm.fsdp2.train.trainer import Trainer
from mindspeed_llm.fsdp2.data.data_factory import DataFactory
from mindspeed_llm.fsdp2.data.tokenizer import TokenizerFactory
from mindspeed_llm.fsdp2.data.template import get_template_and_fix_tokenizer
from mindspeed_llm.fsdp2.utils.logging import setup_global_logging, get_logger
from mindspeed_llm.fsdp2.utils.arguments import (
ModelArguments, DataArguments, ParallelArguments, TrainingArguments, OptimizationArguments, fsdp2_parse_args
)
from mindspeed_llm.fsdp2.utils.global_vars import set_args
from mindspeed_llm.fsdp2.utils.train_monitor import TrainMonitor
from mindspeed_llm.fsdp2.utils.device import set_accelerator_compatible
from mindspeed.fsdp.utils.random import set_seed
from mindspeed.fsdp.utils.torch_patch import apply_hccl_premul_sum_patch
from mindspeed_llm.fsdp2.utils.coverage import auto_coverage
logger = get_logger(__name__)
@dataclass
class Arguments:
    """Root container that groups every argument family used by the trainer.

    Each attribute holds one argument group and is created through that
    group's own constructor when no value is supplied. The attribute names
    (``model``, ``data``, ``parallel``, ``training``, ``optimization``) are
    the ones read back after parsing, so they must stay as they are.
    """

    # Model-related arguments.
    model: ModelArguments = field(default_factory=ModelArguments)
    # Dataset / data-pipeline arguments.
    data: DataArguments = field(default_factory=DataArguments)
    # Parallelism arguments.
    parallel: ParallelArguments = field(default_factory=ParallelArguments)
    # Training-loop arguments.
    training: TrainingArguments = field(default_factory=TrainingArguments)
    # Optimization arguments.
    optimization: OptimizationArguments = field(default_factory=OptimizationArguments)
class MindSpeedAutoTrainer:
    """
    AutoTrainer: Dependency Injection Container.
    Based on FSDP2 Arguments (HfArgumentParser style).

    Constructing an instance parses the arguments, initializes the device
    and the process group, builds every training component through its
    factory and finally wires all of them into a ``Trainer``.
    """

    # Step count handed to the LR scheduler when ``max_steps`` is not positive.
    _FALLBACK_TRAIN_STEPS = 100000

    def __init__(self):
        # Arguments come first: the seed they carry is needed by _initialize.
        self._parse_args()
        self._initialize(seed=self.training_args.seed)
        # The rank can only be queried once _initialize has created the
        # process group.
        self.rank = torch.distributed.get_rank()
        self._print_parsed_args()

        # Components are built in dependency order: the template needs the
        # tokenizer, the data manager needs both, the scheduler needs the
        # optimizer, and the optimizer needs the model.
        self.model = self._build_model()
        self.tokenizer = self._build_tokenizer()
        self.template = get_template_and_fix_tokenizer(self.tokenizer, self.data_args)
        self.data_manager = self._build_data_manager(self.tokenizer, self.template)
        self.optimizer = self._build_optimizer(self.model)
        self.lr_scheduler = self._build_scheduler(self.optimizer)
        self.checkpoint_manager = self._build_checkpointer()
        self.train_monitor = self._build_monitor()

        self.trainer = Trainer(
            model=self.model,
            optimizer=self.optimizer,
            lr_scheduler=self.lr_scheduler,
            data_manager=self.data_manager,
            args=self.training_args,
            parallel_args=self.parallel_args,
            optimization_args=self.optimization_args,
            data_args=self.data_args,
            ckpt_manager=self.checkpoint_manager,
            monitor=self.train_monitor,
            tokenizer=self.tokenizer,
        )

    @staticmethod
    def _initialize(seed: int):
        """
        Static initialization method: set up the accelerator, logging, the
        random seed and the distributed process group without depending on
        ``self``.

        The local rank, rank and world size are read from the ``LOCAL_RANK``,
        ``RANK`` and ``WORLD_SIZE`` environment variables; missing values are
        filled in with single-process defaults.

        Args:
            seed: Seed forwarded to ``set_seed`` (deterministic mode enabled).

        Raises:
            RuntimeError: If neither an NPU nor a CUDA device is available.
        """
        if is_torch_npu_available():
            fallback = torch.npu
            dist_backend = "hccl"
            apply_hccl_premul_sum_patch()
        elif torch.cuda.is_available():
            fallback = torch.cuda
            dist_backend = "nccl"
        else:
            # Fail with a clear message instead of an UnboundLocalError on
            # ``fallback`` / ``dist_backend`` further down.
            raise RuntimeError(
                "No supported accelerator found: FSDP2 training requires an "
                "NPU (torch_npu) or a CUDA device."
            )

        set_accelerator_compatible(fallback)
        setup_global_logging(level="INFO")

        # Without LOCAL_RANK (or with it set to -1) use device 0 and export
        # that choice so LOCAL_RANK is always present afterwards.
        env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
        if env_local_rank != -1:
            target_device_index = env_local_rank
        else:
            target_device_index = 0
            os.environ["LOCAL_RANK"] = str(target_device_index)

        torch.accelerator.set_device_index(target_device_index)
        torch.accelerator.set_device(target_device_index)
        set_seed(seed, set_deterministic=True)

        # Single-process defaults when no launcher provided RANK.
        if "RANK" not in os.environ:
            os.environ["RANK"] = "0"
            os.environ["WORLD_SIZE"] = "1"

        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group(
                backend=dist_backend,
                rank=rank,
                world_size=world_size,
            )

    def train(self):
        """Run training, resuming from ``training_args.resume_from_checkpoint``."""
        self.trainer.train(resume_from_checkpoint=self.training_args.resume_from_checkpoint)

    def _parse_args(self):
        """Parse all argument groups and publish a merged view of them.

        Stores each group on ``self`` (``model_args``, ``data_args``,
        ``parallel_args``, ``training_args``, ``optimization_args``) and
        builds ``self.args``, a flat namespace holding the attributes of all
        groups, which is also registered through ``set_args``.
        """
        root_args = fsdp2_parse_args(Arguments)
        self.model_args = root_args.model
        self.data_args = root_args.data
        self.parallel_args = root_args.parallel
        self.training_args = root_args.training
        self.optimization_args = root_args.optimization

        groups = [
            root_args.model,
            root_args.data,
            root_args.parallel,
            root_args.training,
            root_args.optimization,
        ]
        # Groups are merged in list order, so if two groups share an
        # attribute name the later group's value wins.
        merged = {}
        for group in groups:
            merged.update(group.__dict__)
        self.args = types.SimpleNamespace(**merged)
        set_args(self.args)

    def _print_parsed_args(self):
        """Log every field of the model, data, parallel and training groups."""
        # NOTE(review): the optimization group is parsed but not printed
        # here — confirm whether that omission is intentional.
        arg_modules = [
            ("ModelArguments", self.model_args),
            ("DataArguments", self.data_args),
            ("ParallelArguments", self.parallel_args),
            ("TrainingArguments", self.training_args),
        ]
        for module_name, arg_instance in arg_modules:
            logger.info_plain_rank0(f"\n {module_name}")
            logger.info_plain_rank0("-" * 60)
            for field_info in fields(arg_instance):
                val = getattr(arg_instance, field_info.name)
                logger.info_plain_rank0(
                    f" {field_info.name:<30} {val if val is not None else 'None'}"
                )

    def _build_tokenizer(self):
        """Create the tokenizer from the model arguments."""
        logger.info_rank0("> Building Tokenizer...")
        return TokenizerFactory.create(self.model_args)

    def _build_model(self):
        """Create the model from the model and parallel arguments."""
        logger.info_rank0("> Building FSDP2 Model...")
        return ModelFactory.create(self.model_args, self.parallel_args)

    def _build_optimizer(self, model):
        """Create the optimizer for ``model`` from the training arguments."""
        logger.info_rank0("> Building Optimizer...")
        return OptimizerFactory.create(
            model=model,
            ep_size=self.parallel_args.ep_size,
            lr=self.training_args.lr,
            optimizer_type=self.training_args.optimizer,
            weight_decay=self.training_args.weight_decay,
            betas=(self.training_args.adam_beta1, self.training_args.adam_beta2),
            adam_epsilon=self.training_args.adam_epsilon,
        )

    def _build_scheduler(self, optimizer):
        """Create the LR scheduler for ``optimizer``.

        When ``training_args.max_steps`` is not positive the scheduler is
        configured with ``_FALLBACK_TRAIN_STEPS`` steps instead.
        """
        logger.info_rank0("> Building LR Scheduler...")
        if self.training_args.max_steps > 0:
            max_steps = self.training_args.max_steps
        else:
            max_steps = self._FALLBACK_TRAIN_STEPS
        return SchedulerFactory.create(
            optimizer=optimizer,
            train_steps=max_steps,
            lr=self.training_args.lr,
            lr_decay_style=self.training_args.lr_scheduler_type,
            lr_warmup_ratio=self.training_args.warmup_ratio,
            lr_min=self.training_args.min_lr,
        )

    def _build_data_manager(self, tokenizer, template):
        """Create the data manager; the stage is fixed to ``"sft"``."""
        logger.info_rank0("> Building DataFactory...")
        return DataFactory.create(
            data_manager_type=self.data_args.data_manager_type,
            model_args=self.model_args,
            data_args=self.data_args,
            parallel_args=self.parallel_args,
            training_args=self.training_args,
            stage="sft",
            tokenizer=tokenizer,
            template=template,
        )

    def _build_monitor(self):
        """Create the training monitor from the model's Hugging Face config."""
        logger.info_rank0("> Building Monitor...")
        # NOTE(review): trust_remote_code is hard-coded to True, which allows
        # custom code shipped with the model to run — confirm this is intended.
        hf_config = AutoConfig.from_pretrained(
            self.model_args.model_name_or_path,
            trust_remote_code=True,
        )
        return TrainMonitor(self.training_args, hf_config)

    def _build_checkpointer(self):
        """Return the ``CheckpointManager`` class itself (not an instance)."""
        logger.info_rank0("> Building Checkpointer...")
        # NOTE(review): presumably the Trainer instantiates or calls this
        # class on its own — verify against the Trainer implementation.
        return CheckpointManager
class AutoTrainer:
    """
    Unified entry point for Training.

    Always delegates to ``MindSpeedAutoTrainer`` (the MindSpeed FSDP
    backend); no other backend is selected by this class.
    """

    def __init__(self):
        """Log the backend choice and build the underlying trainer."""
        logger.info_rank0(">>> [AutoTrainer] Initializing MindSpeed FSDP backend...")
        self.trainer = MindSpeedAutoTrainer()

    def train(self):
        """Delegate to the implementation"""
        self.trainer.train()
@auto_coverage
def main():
    """Build the ``AutoTrainer`` and start training."""
    AutoTrainer().train()


# Start training only when this file is executed as a script.
if __name__ == "__main__":
    main()