import os
import torch
from torchtitan.config import ConfigManager
from torchtitan.tools.logging import init_logger, logger
import torchtitan_npu
from torchtitan_npu.converters.registry import has_npu_converter
# Model names for which main() leaves FlexAttention layer configs untouched
# (the FlexAttention -> ScaledDotProductAttention rewrite is skipped).
_SKIP_FLEX_TO_SDPA_REWRITE_MODELS = {"vlm"}
# Model names for which main() requires the inductor_npu_ext package when
# compile is enabled; every other model name must instead carry the
# converter named by _BYPASS_TRITON_CODEGEN.
_INDUCTOR_NPU_EXT_MODELS = {"deepseek_v3", "deepseek_v4", "deepseek_v32", "vlm"}
# Converter name that main() looks up in config.model_converters when
# compile is enabled.
_BYPASS_TRITON_CODEGEN = "npu_bypass_triton_codegen"
def _has_model_converter(model_converters, name: str) -> bool:
if model_converters is None or not hasattr(model_converters, "converters"):
return False
return has_npu_converter(model_converters.converters, name)
def _uses_inductor_npu_ext(model_name: str) -> bool:
    """Return True when *model_name* is listed in ``_INDUCTOR_NPU_EXT_MODELS``."""
    is_listed = model_name in _INDUCTOR_NPU_EXT_MODELS
    return is_listed
def _compile_requires_bypass_triton_codegen(model_name: str) -> bool:
    """Return True for models that do not go through inductor_npu_ext.

    This is the exact negation of ``_uses_inductor_npu_ext(model_name)``.
    """
    if _uses_inductor_npu_ext(model_name):
        return False
    return True
def _rewrite_flex_attention_to_sdpa(config, model_name: str) -> None:
    """Replace FlexAttention layer configs with ScaledDotProductAttention in place.

    Nothing is changed when there is no model spec, when *model_name* is in
    ``_SKIP_FLEX_TO_SDPA_REWRITE_MODELS``, or when the model config is not a
    ``Decoder.Config``. Each rewritten layer also gets ``mask_type`` set to
    ``"causal"``.
    """
    # Imported at call time (as in the original function-scope imports) and
    # before any guard, so an import failure surfaces unconditionally.
    from torchtitan.models.common import FlexAttention, ScaledDotProductAttention
    from torchtitan.models.common.decoder import Decoder

    if not config.model_spec or model_name in _SKIP_FLEX_TO_SDPA_REWRITE_MODELS:
        return
    if not isinstance(config.model_spec.model, Decoder.Config):
        return

    replaced = 0
    for layer_cfg in config.model_spec.model.layers:
        if isinstance(layer_cfg.attention.inner_attention, FlexAttention.Config):
            layer_cfg.attention.inner_attention = ScaledDotProductAttention.Config()
            layer_cfg.attention.mask_type = "causal"
            replaced += 1

    # Log once, and only when a layer was actually rewritten, so the message
    # is neither repeated per layer nor emitted when nothing changed.
    if replaced:
        logger.info(
            "Replaced FlexAttention with ScaledDotProductAttention for NPU compatibility"
        )


def _patch_deepseek_v3_fusion_attention_meta() -> None:
    """Swap in the project's meta implementation of npu_fusion_attention_forward.

    When the torch_npu meta registrations cannot be imported, the patch is
    skipped and an info message is logged.
    """
    try:
        from torch_npu.op_plugin.meta._meta_registrations import (
            npu_fusion_attention_forward as original_meta_func,
        )
    except ImportError:
        logger.info(
            "torch_npu meta registrations not available, skipping compile patch"
        )
        return

    from torchtitan_npu.patches.torch_npu._meta_registrations import (
        npu_fusion_attention_forward,
    )

    # Replace the code object on the existing function rather than rebinding
    # a name: the torch_npu function object itself is modified.
    original_meta_func.__code__ = npu_fusion_attention_forward.__code__


def _validate_compile_setup(config, model_name: str) -> None:
    """Check that the compile-related configuration is consistent for *model_name*.

    Raises:
        RuntimeError: if compile is enabled and either
            * the model is in ``_INDUCTOR_NPU_EXT_MODELS`` but
              ``inductor_npu_ext`` cannot be imported, or the bypass
              converter is configured; or
            * the model is not in ``_INDUCTOR_NPU_EXT_MODELS`` and the bypass
              converter is missing.
    """
    if not config.compile.enable:
        return

    if config.activation_checkpoint.mode != "none":
        logger.warning(
            "There might be performance issues with activation checkpointing and torch.compile enabled!"
        )

    has_bypass_triton_codegen = _has_model_converter(
        config.model_converters,
        _BYPASS_TRITON_CODEGEN,
    )

    if not _uses_inductor_npu_ext(model_name):
        if not has_bypass_triton_codegen:
            raise RuntimeError(
                f"{model_name} model with compile.enable=True requires npu_bypass_triton_codegen. "
                "Please add 'npu_bypass_triton_codegen' to model.converters in your config."
            )
        return

    if model_name == "deepseek_v3":
        _patch_deepseek_v3_fusion_attention_meta()

    try:
        # Imported only to verify availability; the name itself is unused.
        import inductor_npu_ext  # noqa: F401
    except Exception as e:
        raise RuntimeError(
            f"compile.enable is True for {model_name} model but inductor_npu_ext is not available. "
            "Please install inductor_npu_ext before enabling compile. "
            "See docs/torch_compile.md for installation instructions."
        ) from e

    if has_bypass_triton_codegen:
        raise RuntimeError(
            f"{model_name} model with compile.enable=True should not use npu_bypass_triton_codegen. "
            "Please remove 'npu_bypass_triton_codegen' from model.converters in your config."
        )


def _apply_model_specific_patches(model_name: str) -> None:
    """Apply the training/checkpoint patches keyed on *model_name*."""
    if model_name in ("deepseek_v32", "deepseek_v4"):
        from torchtitan_npu.train import (
            _patch_init_for_dsa_set_loss_scale,
            _patch_train_step_for_dsa_indexer_loss,
        )

        _patch_train_step_for_dsa_indexer_loss()
        _patch_init_for_dsa_set_loss_scale()

    if model_name == "llama4":
        from torchtitan_npu.tools.checkpoint_patch import (
            patch_llama4_checkpoint_support,
        )

        patch_llama4_checkpoint_support()

    if model_name == "deepseek_v3":
        logger.warning(
            "deepseek_v3 checkpoint patch is temporarily disabled due to config system migration."
        )


def main() -> None:
    """Main entry point for NPU training with new config system.

    Parses the configuration, adapts it for NPU (attention rewrite, compile
    validation, model-specific patches), builds the trainer and then either
    writes a seed checkpoint or runs training. The trainer is closed on both
    the success and the failure path.

    Raises:
        RuntimeError: for inconsistent compile configuration, or when a seed
            checkpoint is requested with ``WORLD_SIZE`` != 1 or with
            checkpointing disabled.
    """
    init_logger()
    config_manager = ConfigManager()
    config = config_manager.parse_args()
    trainer = None
    model_name = config.model_spec.name if config.model_spec else "unknown"

    _rewrite_flex_attention_to_sdpa(config, model_name)
    _validate_compile_setup(config, model_name)
    _apply_model_specific_patches(model_name)

    try:
        trainer = config.build()
        if config.checkpoint.create_seed_checkpoint:
            # Explicit raises instead of ``assert``: assertions are removed
            # under ``python -O`` and these checks must always run.
            if int(os.environ["WORLD_SIZE"]) != 1:
                raise RuntimeError(
                    "Must create seed checkpoint using a single device, to disable sharding."
                )
            if not config.checkpoint.enable:
                raise RuntimeError(
                    "Must enable checkpointing when creating a seed checkpoint."
                )
            trainer.checkpointer.save(curr_step=0, last_step=True)
            logger.info("Created seed checkpoint")
        else:
            trainer.train()
    except Exception:
        # ``trainer`` is still None when config.build() itself failed.
        if trainer:
            trainer.close()
        raise
    else:
        trainer.close()
        if torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()
            logger.info("Process group destroyed")
# Run the training entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()