# Copyright (c) 2026, Huawei Technologies Co., Ltd. All rights reserved.
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# pylint: skip-file

import logging
from typing import Any, Callable, List, Optional

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from megatron.core.dist_checkpointing.dict_utils import nested_values
from megatron.core.dist_checkpointing.mapping import LocalNonpersistentObject, ShardedStateDict
from megatron.core.optimizer.clip_grads import count_zeros_fp32, get_grad_norm_fp32
from megatron.core.optimizer.optimizer import (
    ChainedOptimizer,
    Float16OptimizerWithFloat16Params,
    FP32Optimizer,
)

from mindspeed.core.optimizer.muon.utils import LegacyProcessGroupCollection, get_pg_rank, get_pg_size


# Module-level logger; load_state_dict uses it to report checkpoint-format conversions.
logger = logging.getLogger(__name__)


class LayerWiseDistributedOptimizer(ChainedOptimizer):
    """Layer-wise distributed optimizer for Megatron-core models.

    Experimental distributed optimizer wrapper that distributes whole parameters across
    data-parallel ranks. It is implemented as a ChainedOptimizer so that multiple optimizers
    (e.g. Muon + AdamW) can be combined. When using it, keep all Megatron
    distributed-optimizer related options OFF.

    How LayerWiseDistributedOptimizer works:
    1. Params are split into per-rank lists; each rank keeps only its shard in its optimizers.
    2. Megatron DDP all-reduces grads, so every rank still holds the full model and grads.
    3. The optimizers' param groups are rewritten so only params owned by this DP rank are updated.
    4. Grad norm and zero counting reduce their metrics globally in ``step``.
    5. The chained optimizers do a regular update, which only touches the local shard.
    6. Updated params are all-gathered to every rank (or, with ``overlap_param_gather``,
       the all-gather is deferred to the DDP bucket forward pre-hooks).
    """

    def __init__(
        self,
        optimizers: List[torch.optim.Optimizer],
        config,
        pg_collection: Optional[Any] = None,
        init_state_fn_list: Optional[List[Callable]] = None,
        model_chunks: Optional[List] = None,
    ) -> None:
        """
        Initialize LayerWiseDistributedOptimizer.

        Args:
            optimizers: Base torch optimizers (NOT pre-wrapped Megatron optimizers). Their
                param groups are modified in place so they only hold this rank's params.
            config: OptimizerConfig. ``bf16`` selects Float16OptimizerWithFloat16Params
                wrapping (otherwise FP32Optimizer); ``overlap_param_gather`` defers the
                param all-gather to the DDP bucket infrastructure.
            pg_collection: ProcessGroupCollection; defaults to LegacyProcessGroupCollection().
            init_state_fn_list: Optional per-optimizer init state functions; must have the
                same length as ``optimizers``.
            model_chunks: DDP-wrapped model chunks (required when overlap_param_gather is True).

        Raises:
            ValueError: If ``init_state_fn_list`` and ``optimizers`` differ in length, or if
                ``overlap_param_gather`` is enabled without ``model_chunks``.
            TypeError: If any optimizer is already wrapped by a Megatron optimizer.
        """
        self.pg_collection = LegacyProcessGroupCollection() if pg_collection is None else pg_collection
        self.dp_group = self.pg_collection.dp_cp
        self.expt_dp_group = self.pg_collection.expt_dp
        self.dp_cp_params_list = None
        self.dp_params_list = None
        self.expt_dp_params_list = None
        self.overlap_param_gather = getattr(config, "overlap_param_gather", False)
        # NOTE(review): ChainedOptimizer.__init__ (called below) may reassign
        # self.model_chunks; confirm before relying on this value after construction.
        self.model_chunks = model_chunks

        # Validate every argument BEFORE shard_params rewrites the optimizers' param groups,
        # so a rejected call leaves the caller's optimizers untouched.
        if init_state_fn_list is None:
            init_state_fn_list = [None] * len(optimizers)
        if len(init_state_fn_list) != len(optimizers):
            raise ValueError("init_state_fn_list must match optimizers length")
        for optimizer in optimizers:
            if isinstance(optimizer, (Float16OptimizerWithFloat16Params, FP32Optimizer)):
                raise TypeError(
                    "LayerWiseDistributedOptimizer expects base torch optimizers, "
                    f"got {type(optimizer).__name__}. Do not pre-wrap with Megatron optimizers."
                )
        if self.overlap_param_gather and model_chunks is None:
            raise ValueError("model_chunks must be provided if overlap_param_gather is True")

        self.shard_params(optimizers)
        if self.overlap_param_gather:
            self.set_bucket_layerwise_params_list(model_chunks)

        # NOTE(review): only ``bf16`` is consulted here; any other config (including fp16)
        # takes the FP32Optimizer path. Confirm fp16 is rejected upstream if unsupported.
        use_bf16 = getattr(config, "bf16", False)
        wrapped_optimizers = []
        for optimizer, init_state_fn in zip(optimizers, init_state_fn_list):
            if use_bf16:
                # Wrapping happens *after* shard_params so master weights are only
                # created for the local shard.
                wrapped_optimizer = Float16OptimizerWithFloat16Params(optimizer, config, None, init_state_fn)
            else:
                wrapped_optimizer = FP32Optimizer(optimizer, config, init_state_fn)
            # Match Megatron dev's grad-norm duplicate filtering: non-TP params
            # are counted only on tensor-parallel rank 0, even for expert groups.
            wrapped_optimizer.tp_group = self.pg_collection.tp
            wrapped_optimizers.append(wrapped_optimizer)

        super().__init__(wrapped_optimizers)
        self.config = config

    def shard_params(self, optimizers: List[torch.optim.Optimizer]) -> None:
        """Assign every param to one owning rank and keep only local params in the optimizers.

        Params from all optimizers' param groups are sorted by ``numel`` (ascending) and dealt
        out in ping-pong order (0, 1, ..., n-1, n-1, ..., 0, then repeat) to balance memory.
        Groups flagged ``is_expert_parallel`` are dealt over the expt_dp group; all others are
        dealt over the dp_cp group. Example with 4 ranks and sorted params p0-p9:
        dp_cp_params_list == [[p0, p7, p8], [p1, p6, p9], [p2, p5], [p3, p4]].

        Side effects: rewrites ``group["params"]`` of every param group in place, and fills
        ``dp_cp_params_list`` / ``dp_params_list`` / ``expt_dp_params_list``. Each of these is
        left as None when its group has a single rank; the expert list is also None when
        there are no expert params. When both groups have a single rank, nothing is changed.
        """
        dp_cp_size = get_pg_size(self.pg_collection.dp_cp)
        expt_dp_size = get_pg_size(self.pg_collection.expt_dp)
        # Nothing to shard: every rank owns every param.
        if dp_cp_size == 1 and expt_dp_size == 1:
            return

        # Ping-pong owner sequences, e.g. [0, 1, 2, 3, 3, 2, 1, 0] for 4 ranks.
        dp_cp_loop = list(range(dp_cp_size)) + list(range(dp_cp_size))[::-1]
        expt_dp_loop = list(range(expt_dp_size)) + list(range(expt_dp_size))[::-1]
        dp_cp_idx, expt_dp_idx = 0, 0
        self.dp_cp_params_list = [[] for _ in range(dp_cp_size)]
        self.expt_dp_params_list = [[] for _ in range(expt_dp_size)]

        param_groups = [group for optimizer in optimizers for group in optimizer.param_groups]

        # Every rank computes the assignment independently, so this relies on all ranks seeing
        # the same params in the same order (list.sort is stable, ties keep that order).
        param_list = [
            (param, group_index)
            for group_index, group in enumerate(param_groups)
            for param in group["params"]
        ]
        param_list.sort(key=lambda item: item[0].numel())

        param_groups_this_rank = [[] for _ in param_groups]
        dp_cp_rank = get_pg_rank(self.pg_collection.dp_cp)
        expt_dp_rank = get_pg_rank(self.pg_collection.expt_dp)
        for param, group_index in param_list:
            if param_groups[group_index].get("is_expert_parallel", False):
                owner = expt_dp_loop[expt_dp_idx]
                if owner == expt_dp_rank:
                    param_groups_this_rank[group_index].append(param)
                self.expt_dp_params_list[owner].append(param)
                expt_dp_idx = (expt_dp_idx + 1) % len(expt_dp_loop)
            else:
                owner = dp_cp_loop[dp_cp_idx]
                if owner == dp_cp_rank:
                    param_groups_this_rank[group_index].append(param)
                self.dp_cp_params_list[owner].append(param)
                dp_cp_idx = (dp_cp_idx + 1) % len(dp_cp_loop)

        # Restrict each group to the params this rank owns.
        for group, params in zip(param_groups, param_groups_this_rank):
            group["params"] = params

        self.dp_params_list = self.dp_cp_params_list
        if dp_cp_size == 1:
            self.dp_cp_params_list = None
            self.dp_params_list = None
        # Ping-pong hands the first expert param to rank 0, so an empty rank-0 list
        # means there are no expert params at all.
        if expt_dp_size == 1 or not self.expt_dp_params_list or len(self.expt_dp_params_list[0]) == 0:
            self.expt_dp_params_list = None

    @staticmethod
    def _split_bucket_params_by_rank(bucket, full_params_list):
        """Return per-rank lists of ``bucket``'s params, ordered as in ``full_params_list``.

        ``full_params_list`` is a per-rank ownership list built by ``shard_params``. When it is
        None, the group is not sharded (a single rank owns every param). In that case a single
        list containing all bucket params is returned: no all-gather is needed, but the bucket's
        data structures must still be initialized.
        """
        if full_params_list is None:
            return [list(bucket.params_list)]
        # Tensors hash by id, so this is an identity-based membership test.
        bucket_params = set(bucket.params)
        return [[param for param in rank_params if param in bucket_params] for rank_params in full_params_list]

    def set_bucket_layerwise_params_list(self, model_chunks):
        """Map sharded params to DDP buckets for async all-gather.

        For each bucket in each model chunk's (dense and expert-parallel) bucket groups, build
        per-rank param lists by cross-referencing the layer-wise sharded param lists with the
        bucket's params, then hand them to ``bucket.set_layerwise_params_list``.

        Args:
            model_chunks: DDP-wrapped model chunks with ``bucket_groups`` and, optionally,
                ``expert_parallel_bucket_groups``.
        """
        for model_chunk in model_chunks:
            for group in model_chunk.bucket_groups:
                for bucket in group.buckets:
                    bucket.set_layerwise_params_list(
                        self._split_bucket_params_by_rank(bucket, self.dp_cp_params_list)
                    )
            for group in getattr(model_chunk, "expert_parallel_bucket_groups", []):
                for bucket in group.buckets:
                    bucket.set_layerwise_params_list(
                        self._split_bucket_params_by_rank(bucket, self.expt_dp_params_list)
                    )

    @torch.no_grad()
    def allgather_params(self) -> None:
        """All-gather updated params so every rank holds the full, updated model.

        Each rank flattens the params it owns and all-gathers the flat buffers over the owning
        process group. It then copies the other ranks' slices into its own model params. This
        runs for the dp_cp lists and the expert (expt_dp) lists, skipping whichever is None.
        """

        def _allgather_helper(params_list, group):
            rank = get_pg_rank(group)
            dp_size = get_pg_size(group)
            flat_sizes = [sum(p.numel() for p in params) for params in params_list]
            # shard_params builds the same lists on every rank, so flat_sizes agrees everywhere.
            # All ranks therefore take this exit together and no collective is left unmatched.
            if max(flat_sizes) == 0:
                return

            # Take device/dtype from the first non-empty list. Rank 0's list may be empty,
            # e.g. a dp_cp list when every param is an expert param.
            # NOTE(review): assumes all params in params_list share one dtype; mixed dtypes
            # would need one gather per dtype.
            reference = next(params[0] for params in params_list if params)
            device, dtype = reference.device, reference.dtype

            local_params = params_list[rank]
            src = (
                _flatten_dense_tensors(local_params)
                if local_params
                else torch.empty(0, device=device, dtype=dtype)
            )
            # Per-rank receive buffers with actual sizes (no padding); src is reused for the
            # local slot. NOTE(review): this relies on the backend supporting uneven
            # all_gather sizes (PyTorch's NCCL backend does); confirm for other backends.
            gather_list = [
                src if i == rank else torch.empty(flat_sizes[i], device=device, dtype=dtype)
                for i in range(dp_size)
            ]
            torch.distributed.all_gather(gather_list, src, group=group)

            # Copy the other ranks' updated params into this rank's model params.
            for idx, params in enumerate(params_list):
                if idx == rank or not params:
                    continue
                for updated_p, model_p in zip(_unflatten_dense_tensors(gather_list[idx], params), params):
                    model_p.data.copy_(updated_p)

        if self.dp_cp_params_list is not None:
            _allgather_helper(self.dp_cp_params_list, self.pg_collection.dp_cp)
        if self.expt_dp_params_list is not None:
            _allgather_helper(self.expt_dp_params_list, self.pg_collection.expt_dp)

    @torch.no_grad()
    def get_grad_norm(self):
        """Return the grad norm over the main grads of all chained optimizers.

        Each rank only holds grads for its local shard. As with the distributed optimizer, the
        norm is therefore always aggregated globally: ``grad_stats_parallel_group=None`` is
        passed (presumably the default/world group inside get_grad_norm_fp32; confirm).
        """
        grads_for_norm = []
        for optimizer in self.chained_optimizers:
            grads_for_norm += optimizer.get_main_grads_for_grad_norm()
        return get_grad_norm_fp32(grads_for_norm, grad_stats_parallel_group=None)

    @torch.no_grad()
    def count_zeros(self):
        """Return the number of zero grads over all chained optimizers' local params.

        Like ``get_grad_norm``, this passes ``grad_stats_parallel_group=None`` so the count is
        aggregated globally. Decoupled grads are used when the config enables
        ``use_precision_aware_optimizer_no_fp8_or_ds_fp8`` (falling back to
        ``use_precision_aware_optimizer``).
        """
        params = []
        for optimizer in self.chained_optimizers:
            params += optimizer.get_parameters()
        return count_zeros_fp32(
            params,
            grad_stats_parallel_group=None,
            use_decoupled_grad=getattr(
                self.config,
                "use_precision_aware_optimizer_no_fp8_or_ds_fp8",
                getattr(self.config, "use_precision_aware_optimizer", False),
            ),
        )

    @torch.no_grad()
    def step(self):
        """Step all chained optimizers on their local shards, then sync params across ranks.

        Returns:
            The ``(update_successful, grad_norm, num_zeros_in_grad)`` tuple from
            ChainedOptimizer.step.
        """
        update_successful, grad_norm, num_zeros_in_grad = super().step()

        # The all-gather is a collective, so it runs regardless of update_successful: gating
        # it on a per-rank result could leave ranks waiting on mismatched collectives. With
        # overlap_param_gather, it is deferred to the forward pre-hooks via DDP buckets.
        if not self.overlap_param_gather:
            self.allgather_params()
        return update_successful, grad_norm, num_zeros_in_grad

    # TODO(deyuf): need to improve dist checkpointing design to properly handle this
    # fp32_from_fp16_params is list, each sub list could be empty if group is empty
    # this breaks dist checkpointing assumption since extract_sharded_base drop list structure
    # for now, we convert it to dict with index as key and convert back in load_state_dict
    def load_state_dict(self, state_dict):
        """Load optimizer state, restoring ``fp32_from_fp16_params`` from dict to list form.

        ``sharded_state_dict`` stores ``fp32_from_fp16_params`` as a dict keyed by group index.
        This converts it back to the list (ordered by index) before delegating to
        ChainedOptimizer.load_state_dict.

        Args:
            state_dict: A single optimizer state when there is one chained optimizer.
                Otherwise, a dict or list/tuple of per-optimizer states.
        """
        if len(self.chained_optimizers) == 1:
            per_optimizer_states = [state_dict]
        elif isinstance(state_dict, dict):
            per_optimizer_states = list(state_dict.values())
        elif isinstance(state_dict, (list, tuple)):
            per_optimizer_states = state_dict
        else:
            # Unknown container: leave it for ChainedOptimizer.load_state_dict to handle.
            per_optimizer_states = []
        for state in per_optimizer_states:
            if "fp32_from_fp16_params" in state and isinstance(state["fp32_from_fp16_params"], dict):
                logger.info("[layerwise] converting fp32_from_fp16_params from dict to list")
                state["fp32_from_fp16_params"] = [
                    value for _, value in sorted(state["fp32_from_fp16_params"].items(), key=lambda kv: kv[0])
                ]
        super().load_state_dict(state_dict)

    def sharded_state_dict(self, model_sharded_state_dict: ShardedStateDict, is_loading: bool = False, **kwargs):
        """
        Sharded state dict for torch_dist format checkpointing.
        For fixed DP usage only, set replica_id to 0 for all ShardedTensor.

        Also wraps empty per-group containers in LocalNonpersistentObject so they are neither
        saved nor loaded. ``fp32_from_fp16_params`` is converted to an index-keyed dict (undone
        in load_state_dict). Each param group's metadata is replaced with one taken from a
        rank whose group is non-empty, gathered over the default process group.
        """
        sharded_state_dict = super().sharded_state_dict(model_sharded_state_dict, is_loading, **kwargs)

        # For fixed DP usage only: mark every rank's entry as the main replica along DP.
        for sharded_base in nested_values(sharded_state_dict):
            if hasattr(sharded_base, "replica_id"):
                replica_id = sharded_base.replica_id
                if not (isinstance(replica_id, int) or len(replica_id) == 3):
                    raise ValueError(f"Expected replica_id as int or (PP, TP, DP), got: {sharded_base}")
                sharded_base.replica_id = 0 if isinstance(replica_id, int) else (*replica_id[:2], 0)

        # Later code assumes a per-optimizer mapping, but ChainedOptimizer returns the bare
        # state when there is only one optimizer.
        if len(self.chained_optimizers) == 1:
            wrapped_sharded_state_dict = {1: sharded_state_dict}
        else:
            wrapped_sharded_state_dict = sharded_state_dict

        for state in wrapped_sharded_state_dict.values():
            # Wrap empty containers into LocalNonpersistentObject so they won't be saved/loaded.
            # ``params`` is already wrapped; only fp32_from_fp16_params and state need handling.
            if "fp32_from_fp16_params" in state:
                state["fp32_from_fp16_params"][:] = [
                    group if group else LocalNonpersistentObject(group) for group in state["fp32_from_fp16_params"]
                ]
                state["fp32_from_fp16_params"] = {
                    idx: value for idx, value in enumerate(state["fp32_from_fp16_params"])
                }
            if not state["optimizer"]["state"]:
                state["optimizer"]["state"] = LocalNonpersistentObject(state["optimizer"]["state"])
            # Group keys (e.g. 'step') might be missing or stale on ranks whose group is empty,
            # so take the metadata from the first rank whose group has params.
            for idx, group in enumerate(state["optimizer"]["param_groups"]):
                # Keep the local param tensor so only metadata is gathered.
                local_params = group.pop("params")
                group["params"] = bool(local_params.unwrap())
                # all_gather_object uses the default group: every rank must reach this call
                # the same number of times.
                all_rank_groups = [None for _ in range(torch.distributed.get_world_size())]
                torch.distributed.all_gather_object(all_rank_groups, group)
                nonempty_rank_group = next(
                    (rank_group for rank_group in all_rank_groups if rank_group["params"]),
                    group,
                )
                nonempty_rank_group["params"] = local_params
                state["optimizer"]["param_groups"][idx] = nonempty_rank_group
        return sharded_state_dict

    def save_state_dict_to_file(self, filename: str) -> None:
        """Save the parameter state of the optimizer. For torch format only.

        Args:
            filename: The filename to save the parameter state to. The ChainedOptimizer
                state_dict is written as-is (no fp32_from_fp16_params dict conversion).
        """
        torch.save(super().state_dict(), filename)

    def load_state_dict_from_file(self, filename: str) -> None:
        """Load the parameter state of the optimizer. For torch format only.

        Args:
            filename: File written by ``save_state_dict_to_file``. It is loaded with
                ``weights_only=True`` and passed straight to ChainedOptimizer.load_state_dict.
        """
        super().load_state_dict(torch.load(filename, weights_only=True))