# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is derived from torchtitan,
# https://github.com/pytorch/torchtitan/blob/v0.2.2/torchtitan/train.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os

import torch

from torchtitan.config.manager import ConfigManager
from torchtitan.tools.logging import init_logger, logger
from torchtitan.train import Trainer

from torchtitan_npu.train import (
    _patch_for_garbage_collection_run,
    _patch_for_parallel_dims_build_mesh,
    _patch_torchtitan_model_reshape_for_broadcast,
)

if __name__ == "__main__":
    # Script entry point: parse the job config, apply the torchtitan_npu
    # patches, validate the compile/converter combination for the selected
    # model, then either create a seed checkpoint or run training.
    init_logger()
    config_manager = ConfigManager()
    config = config_manager.parse_args()
    # Stays None until construction succeeds so the error path below knows
    # whether there is anything to close.
    trainer: Trainer | None = None

    # These patches are applied unconditionally, before the Trainer is built.
    _patch_for_garbage_collection_run()
    _patch_for_parallel_dims_build_mesh()
    _patch_torchtitan_model_reshape_for_broadcast()

    if config.compile.enable:
        if config.model.name in ("deepseek_v3", "deepseek_v4", "deepseek_v32"):
            if config.model.name == "deepseek_v3":
                # pyrefly: ignore [missing-import]
                from torch_npu.op_plugin.meta._meta_registrations import (
                    npu_fusion_attention_forward as original_meta_func,
                )

                # Lazy imports to avoid requiring NPU hardware at module load time
                from torchtitan_npu.patches.torch_npu._meta_registrations import (
                    npu_fusion_attention_forward,
                )

                # MLA performs shape inference according to the value tensor.
                # The code object is swapped in place so that existing
                # references to the torch_npu function run the replacement.
                original_meta_func.__code__ = npu_fusion_attention_forward.__code__

            # Compiling these models needs inductor_npu_ext; any failure while
            # importing it is reported as one actionable configuration error.
            try:
                # pyrefly: ignore [missing-import]
                import inductor_npu_ext  # noqa: F401
            except Exception as e:
                raise RuntimeError(
                    f"compile.enable is True for {config.model.name} model but inductor_npu_ext is not available. "
                    "Please install inductor_npu_ext before enabling compile. "
                    "See docs/torch_compile.md for installation instructions."
                ) from e

            if "npu_bypass_triton_codegen" in config.model.converters:
                raise RuntimeError(
                    f"{config.model.name} model with compile.enable=True should not use npu_bypass_triton_codegen. "
                    "Please remove 'npu_bypass_triton_codegen' from model.converters in your config."
                )
        else:
            # Every other model must list the bypass converter when compiling.
            if "npu_bypass_triton_codegen" not in config.model.converters:
                raise RuntimeError(
                    f"{config.model.name} model with compile.enable=True requires npu_bypass_triton_codegen. "
                    "Please add 'npu_bypass_triton_codegen' to model.converters in your config."
                )

    if config.model.name in ("deepseek_v32", "deepseek_v4"):
        from torchtitan_npu.train import (
            _patch_for_train_npu_memory,
            _patch_init_for_dsa_set_loss_scale,
            _patch_train_step_for_dsa_indexer_loss,
        )

        _patch_train_step_for_dsa_indexer_loss()
        _patch_init_for_dsa_set_loss_scale()
        _patch_for_train_npu_memory()

    if config.model.name == "llama4":
        from torchtitan_npu.tools.checkpoint_patch import (
            patch_llama4_checkpoint_support,
        )

        patch_llama4_checkpoint_support()

    if config.model.name == "deepseek_v3":
        from torchtitan_npu.tools.checkpoint_patch import patch_dsv3_checkpoint_support

        patch_dsv3_checkpoint_support()

    try:
        trainer = Trainer(config)

        if config.checkpoint.create_seed_checkpoint:
            # Explicit raises instead of asserts: asserts are removed under
            # `python -O`, which would silently skip these preconditions.
            if int(os.environ["WORLD_SIZE"]) != 1:
                raise RuntimeError(
                    "Must create seed checkpoint using a single device, to disable sharding."
                )
            if not config.checkpoint.enable:
                raise RuntimeError(
                    "Must enable checkpointing when creating a seed checkpoint."
                )
            trainer.checkpointer.save(curr_step=0, last_step=True)
            logger.info("Created seed checkpoint")
        else:
            trainer.train()
    except Exception:
        # Close the trainer if it was constructed, then propagate the error.
        # The process group is only destroyed on the success path below.
        if trainer is not None:
            trainer.close()
        raise
    else:
        trainer.close()
        torch.distributed.destroy_process_group()
        logger.info("Process group destroyed")