import os
import torch
import torch.distributed as dist
from typing import Any, Type
from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig

from mindspeed_llm.fsdp2.models.model_registry import ModelRegistry
from mindspeed_llm.fsdp2.distributed.mindspeed_parallel_engine import MindSpeedParallelEngine
from mindspeed_llm.fsdp2.distributed.parallel_engine_config import (
    ParallelEngineConfig,
    FSDPPlanConfig,
    TPPlanConfig,
    EPPlanConfig,
    CPPlanConfig,
    QuantizeConfig
)

from mindspeed_llm.fsdp2.utils.logging import get_logger
from mindspeed_llm.fsdp2.models.model_loader import ModelLoader

logger = get_logger(__name__)


# ==============================================================================
# ModelFactory
# ==============================================================================
class ModelFactory:
    """
    Responsible for building HuggingFace native models and wrapping them 
    as MindSpeed FSDP instances based on parallelization arguments.
    
    Supports two initialization modes controlled by model_args.init_model_with_meta_device:
    - False: Load model fully on CPU (original behavior)
    - True: Create empty model on meta device, load weights after FSDP wrapping
    """

    @staticmethod
    def create(model_args, parallel_args):
        """
        Creates a MindSpeed FSDP wrapped model.
        
        Args:
            model_args: Contains model_name_or_path, trust_remote_code, train_from_scratch, 
                        init_model_with_meta_device, etc.
            parallel_args: Contains tp_size, fsdp_size, recompute, ep_size, etc.
        """
        # 1. Setup Device
        # Ensure NPU is being used
        if torch.accelerator.is_available():
            local_rank = int(os.environ.get("LOCAL_RANK", 0))
            target_device = torch.device(torch.accelerator.current_accelerator().type, local_rank)
            torch.accelerator.set_device(target_device)
        else:
            target_device = torch.device("cpu")

        # 2. Determine initialization device based on init_model_with_meta_device flag
        use_meta_device = getattr(model_args, 'init_model_with_meta_device', False)
        init_device = "meta" if use_meta_device else "cpu"
        logger.info_rank0(f"> Model initialization device: {init_device} (init_model_with_meta_device={use_meta_device})")

        # 3. Load HF Config
        logger.info_rank0(f"> Loading AutoConfig from {model_args.model_name_or_path}...")
        trust_remote_code = model_args.trust_remote_code
        hf_config = AutoConfig.from_pretrained(
            model_args.model_name_or_path,
            trust_remote_code=trust_remote_code
        )

        # 4. Load HF Model
        # Decide loading method based on init_device and whether training from scratch or fine-tuning.
        # if model_id is configured, load model according to model_id.
        model_cls = None
        if getattr(model_args, 'model_id', None):
            logger.info_rank0(f"> Using factory mode with model_id: {model_args.model_id}")
            model_cls = ModelRegistry.get_model_class(model_args.model_id)
            if hasattr(model_cls, 'register_patches'):
                model_cls.register_patches(model_args)

        # Use ModelLoader to create model based on init_device
        loader = ModelLoader(model_args, init_device=init_device)
        model, weights_path = loader.create_model(model_cls=model_cls)

        # 5. Build MindSpeed FSDP Configuration
        # Dynamically calculate Data Parallel (DP) Size
        world_size = dist.get_world_size() if dist.is_initialized() else 1
        
        # Guard against division by zero if args are not set correctly
        tp_size = parallel_args.tp_size
        fsdp_size = parallel_args.fsdp_size
        cp_size = parallel_args.cp_size
        dp_size = world_size // (tp_size * fsdp_size * cp_size)

        parallel_config = ModelFactory._build_parallel_config(model_args, parallel_args, dp_size)

        # 6. Wrap & Move
        logger.info_rank0(f"> Wrapping model with MindSpeed FSDP (TP={tp_size},CP={cp_size}, FSDP={fsdp_size})...")

        # MindSpeed FSDP will shard and wrap the CPU model based on the config.
        # The wrapped model automatically handles forward/backward communication.
        # Pass init_device and weights_path for meta device support.
        model = MindSpeedParallelEngine(
            config=parallel_config, 
            model=model,
            init_device=init_device,
            weights_path=weights_path
        )

        # 7. Move to target device
        # For cpu mode: move to NPU after wrapping
        # For meta mode: weights are already loaded to device in MindSpeedParallelEngine
        if init_device == "cpu":
            model = model.to(target_device)

        return model

    @staticmethod
    def _build_parallel_config(model_args, parallel_args, dp_size) -> 'ParallelEngineConfig':
        """
        Builds the Config based on parallel_args and hardcoded layer name rules.
        Note: The wildcards here (e.g., 'model.layers.{*}') are suitable for standard structures like Llama/Qwen.
        If using other non-standard models, these strings might need adjustment.
        """
        # --- 1. FSDP Plan ---
        # Requirement: Apply FSDP to transformer layers
        apply_modules = {
            parallel_args.fsdp_modules[0]: {'reshard_after_forward': parallel_args.reshard_after_forward,
                                            'shard_placement_fn': parallel_args.shard_placement_fn},
        }
        for modules in parallel_args.fsdp_modules[1:]:
            apply_modules[modules] = {'reshard_after_forward': parallel_args.reshard_after_forward,}
        fsdp_plan = FSDPPlanConfig(
            ignored_modules=parallel_args.ignored_modules if parallel_args.ignored_modules else [],
            apply_modules= apply_modules,
            param_dtype=parallel_args.param_dtype,
            reduce_dtype=parallel_args.reduce_dtype,
            num_to_forward_prefetch=parallel_args.num_to_forward_prefetch,
            num_to_backward_prefetch=parallel_args.num_to_backward_prefetch
        )

        # --- 2. Tensor Parallel Plan ---
        # Requirement: Column Parallel for Q/K/V/Gate/Up, Row Parallel for O/Down
        tp_plan = TPPlanConfig(
            colwise_parallel=parallel_args.tp_colwise,
            rowwise_parallel=parallel_args.tp_rowwise
        )

        # --- 3. Expert Parallel Plan ---
        # For Mixture-of-Experts (MoE) models
        ep_size = parallel_args.ep_size
        ep_fsdp_size = parallel_args.ep_fsdp_size

        ep_plan = EPPlanConfig(
            apply_modules=parallel_args.ep_modules,
            apply_efsdp_modules=parallel_args.ep_fsdp_modules,
            dispatcher=parallel_args.ep_dispatcher,
        )


        cp_plan = CPPlanConfig(
            context_parallel_type=parallel_args.cp_type,
            is_pack=getattr(model_args, "pack", False)
        )

        # --- 4. Recompute Plan ---
        # Activation Checkpointing
        recompute_plan = parallel_args.recompute_modules if parallel_args.recompute else []

        # --- 5. Quantization Config ---
        quantization_plan = QuantizeConfig(
            recipe_name=model_args.quant_recipe_name,
            apply_modules=model_args.quant_apply_modules,
            ignored_modules=model_args.quant_ignored_modules,
            quant_converters=model_args.quant_converters,
            enable_fsdp_low_precision_all_gather=model_args.enable_fsdp_low_precision_all_gather,
            fsdp_low_precision_all_gather_mode=model_args.fsdp_low_precision_all_gather_mode,
        )

        # --- 6. Assemble Config ---
        # Get parallel sizes safely
        tp_size = parallel_args.tp_size
        fsdp_size = parallel_args.fsdp_size

        config = ParallelEngineConfig(
            # Parallelism parameters
            data_parallel_size=dp_size,

            fully_shard_parallel_size=fsdp_size,
            fsdp_plan=fsdp_plan,

            tensor_parallel_size=tp_size,
            tp_plan=tp_plan,

            # Expert Parallelism
            expert_parallel_size=ep_size,
            expert_fully_shard_parallel_size=ep_fsdp_size,
            expert_data_parallel_size=dp_size,  # Usually EP data parallel size matches global or has specific logic
            ep_plan=ep_plan,

            # Context Parallelism
            context_parallel_size=parallel_args.cp_size,
            context_parallel_type=parallel_args.cp_type,
            cp_plan=cp_plan,

            # Recomputation
            recompute=parallel_args.recompute,
            recompute_plan=recompute_plan,

            # Quantization
            quantization_plan = quantization_plan
        )

        return config


# ==============================================================================
# [Facade] AutoModelFactory
# Unified entry point used by AutoTrainer
# ==============================================================================
class AutoModelFactory:
    """
    Unified Factory for creating models.
    Dispatches to ModelFactory or McoreModelFactory based on the environment.
    """

    @staticmethod
    def create(*args, **kwargs):
        """
        Factory method that forwards arguments to the specific implementation.
        """
        return ModelFactory.create(*args, **kwargs)