import os
import torch
from torchtitan.config.manager import ConfigManager
from torchtitan.tools.logging import init_logger, logger
from torchtitan.train import Trainer
from torchtitan_npu.train import (
_patch_for_garbage_collection_run,
_patch_for_parallel_dims_build_mesh,
_patch_torchtitan_model_reshape_for_broadcast,
)
if __name__ == "__main__":
    # Script entry point: parse the job config, apply the NPU patches, then
    # build a Trainer and either write a seed checkpoint or run training.
    init_logger()
    config_manager = ConfigManager()
    config = config_manager.parse_args()
    # Stays None until construction succeeds, so the error path at the bottom
    # knows whether there is a trainer to close.
    trainer: Trainer | None = None

    # These patches are applied unconditionally and before the Trainer is
    # constructed.
    _patch_for_garbage_collection_run()
    _patch_for_parallel_dims_build_mesh()
    _patch_torchtitan_model_reshape_for_broadcast()

    model_name = config.model.name
    converters = config.model.converters

    # NOTE(review): the nesting of the branches below was reconstructed from a
    # source whose indentation had been lost; confirm it against the original
    # file (in particular that the inductor_npu_ext check covers every
    # DeepSeek variant and not only deepseek_v3).
    if config.compile.enable:
        if model_name in ("deepseek_v3", "deepseek_v4", "deepseek_v32"):
            if model_name == "deepseek_v3":
                from torch_npu.op_plugin.meta._meta_registrations import (
                    npu_fusion_attention_forward as original_meta_func,
                )
                from torchtitan_npu.patches.torch_npu._meta_registrations import (
                    npu_fusion_attention_forward,
                )

                # Swap the code object in place rather than rebinding the name:
                # every existing reference to torch_npu's meta function then
                # executes the torchtitan_npu implementation.
                original_meta_func.__code__ = npu_fusion_attention_forward.__code__

            # The module is not referenced afterwards; the import acts as an
            # availability check. Any failure while importing is reported as a
            # configuration error, with the original exception chained.
            try:
                import inductor_npu_ext  # noqa: F401
            except Exception as e:
                raise RuntimeError(
                    f"compile.enable is True for {model_name} model but inductor_npu_ext is not available. "
                    "Please install inductor_npu_ext before enabling compile. "
                    "See docs/torch_compile.md for installation instructions."
                ) from e

            # With compile enabled, DeepSeek models must not list the
            # npu_bypass_triton_codegen converter ...
            if "npu_bypass_triton_codegen" in converters:
                raise RuntimeError(
                    f"{model_name} model with compile.enable=True should not use npu_bypass_triton_codegen. "
                    "Please remove 'npu_bypass_triton_codegen' from model.converters in your config."
                )
        # ... whereas every other model must list it.
        elif "npu_bypass_triton_codegen" not in converters:
            raise RuntimeError(
                f"{model_name} model with compile.enable=True requires npu_bypass_triton_codegen. "
                "Please add 'npu_bypass_triton_codegen' to model.converters in your config."
            )

    # Extra train-step / init patches applied only for deepseek_v32 and
    # deepseek_v4.
    if model_name in ("deepseek_v32", "deepseek_v4"):
        from torchtitan_npu.train import (
            _patch_init_for_dsa_set_loss_scale,
            _patch_train_step_for_dsa_indexer_loss,
        )

        _patch_train_step_for_dsa_indexer_loss()
        _patch_init_for_dsa_set_loss_scale()

    # NOTE(review): treated as unconditional (applies to every model); confirm
    # it was not meant to sit inside the deepseek_v32/deepseek_v4 branch above.
    from torchtitan_npu.train import _patch_for_train_npu_memory

    _patch_for_train_npu_memory()

    # Model-specific checkpoint patches.
    if model_name == "llama4":
        from torchtitan_npu.tools.checkpoint_patch import (
            patch_llama4_checkpoint_support,
        )

        patch_llama4_checkpoint_support()

    if model_name == "deepseek_v3":
        from torchtitan_npu.tools.checkpoint_patch import patch_dsv3_checkpoint_support

        patch_dsv3_checkpoint_support()

    try:
        trainer = Trainer(config)

        if config.checkpoint.create_seed_checkpoint:
            # Explicit raises instead of `assert`: asserts are removed when
            # Python runs with -O, which would silently skip these checks.
            # WORLD_SIZE is read from the environment and must be set.
            if int(os.environ["WORLD_SIZE"]) != 1:
                raise RuntimeError(
                    "Must create seed checkpoint using a single device, to disable sharding."
                )
            if not config.checkpoint.enable:
                raise RuntimeError(
                    "Must enable checkpointing when creating a seed checkpoint."
                )
            trainer.checkpointer.save(curr_step=0, last_step=True)
            logger.info("Created seed checkpoint")
        else:
            trainer.train()
    except Exception:
        # Close the trainer if it was constructed, then let the original error
        # propagate. The process group is not destroyed on this path.
        if trainer:
            trainer.close()
        raise
    else:
        # Clean exit: close the trainer, then tear down the process group.
        trainer.close()
        torch.distributed.destroy_process_group()
        logger.info("Process group destroyed")