from typing import Optional, List

import inspect

import warnings

import dataclasses

from functools import reduce



from torch.distributed.checkpoint.metadata import Metadata, STATE_DICT_TYPE

from torch.distributed.checkpoint.storage import StorageWriter, StorageReader

from torch.distributed.checkpoint.planner import SavePlanner, LoadPlanner

from torch.distributed.checkpoint.default_planner import DefaultSavePlanner, _EmptyStateDictLoadPlanner

from torch.distributed.checkpoint.logger import _dcp_method_logger

from torch.distributed.checkpoint.filesystem import _StoragePrefix





def partial_save_dcp_state_dict(

    state_dict: STATE_DICT_TYPE,

    storage_writer: StorageWriter,

    planner: Optional[SavePlanner] = None,

    part_idx: int = 0

):

    """

    Save a partial shard of a Distributed Checkpoint (DCP) state_dict.



    This function enables a single process (e.g., in a single-machine or testing context)

    to save its portion of model weights as torch_dcp format. It coordinates between

    a SavePlanner (which decides how tensors are laid out) and a StorageWriter (which

    handles actual I/O). The function returns global metadata (typically populated only

    by a coordinator rank in multi-process settings) and the results of write operations.



    Args:

        state_dict (STATE_DICT_TYPE): The local subset of the model state dict to save.

        storage_writer (StorageWriter): Handles writing data to persistent storage.

        planner (Optional[SavePlanner]): Custom save planner; uses DefaultSavePlanner if None.

        part_idx (int): Offset index used to generate unique storage prefixes (e.g., "__2_").



    Returns:

        Tuple[Optional[Metadata], List[Any]]:

            - global_metadata: checkpoint metadata.

            - all_writes: Results of the write operations.

    """



    # Use default planner if none is provided

    if planner is None:

        planner = DefaultSavePlanner()



    # Initialize global metadata; will be set during global planning phase

    global_metadata = None



    ckpt_kwargs = {}

    ckpt_id = getattr(storage_writer, "checkpoint_id", None)

    if ckpt_id is not None:

        ckpt_kwargs["checkpoint_id"] = ckpt_id



    @_dcp_method_logger(**ckpt_kwargs)

    def local_step():

        storage_meta = storage_writer.storage_meta()

        if "storage_meta" not in inspect.signature(planner.set_up_planner).parameters:

            warnings.warn(

                "The function definition for SavePlanner.set_up_planner has been updated"

                " to include the storage_meta argument. Please update your implementation"

                " to include this parameter."

            )

            planner.set_up_planner(state_dict, True)

        else:

            planner.set_up_planner(

                state_dict=state_dict,

                storage_meta=storage_meta,

                is_coordinator=True

            )

        storage_writer.set_up_storage_writer(True)



        local_plan = planner.create_local_plan()

        local_plan = storage_writer.prepare_local_plan(local_plan)

        return local_plan



    @_dcp_method_logger(**ckpt_kwargs)

    def global_step(all_local_plans):

        nonlocal global_metadata



        all_local_plans, global_metadata = planner.create_global_plan(all_local_plans)

        all_local_plans = storage_writer.prepare_global_plan(all_local_plans)

        all_local_plans = [

            dataclasses.replace(plan, storage_data=_StoragePrefix(f"__{i+part_idx}_"))

            for i, plan in enumerate(all_local_plans)

        ]

        return all_local_plans



    local_plan = local_step()

    all_local_plan = global_step([local_plan])[0]



    @_dcp_method_logger(**ckpt_kwargs)

    def write_data():

        final_local_plan = planner.finish_plan(all_local_plan)

        all_writes = storage_writer.write_data(final_local_plan, planner)

        all_writes.wait()

        return all_writes.value()



    all_writes = write_data()



    # Return metadata and write results for potential post-processing or validation

    return global_metadata, all_writes





def save_metadata(

    global_metadata,

    all_writes,

    storage_writer: StorageWriter

):

    """

    Finalize a Distributed Checkpoint (DCP) by saving its global metadata.



    Args:

        global_metadata: The consolidated metadata describing the entire checkpoint.

                         Typically generated during the global planning phase.

        all_writes: Results from previous tensor write operations (e.g., list of

                    written file names or async write futures). Used by the writer

                    to finalize references in the metadata.

        storage_writer (StorageWriter): The writer responsible for persisting data

                                       and metadata to the underlying storage backend.

    """

    ckpt_kwargs = {}

    ckpt_id = getattr(storage_writer, "checkpoint_id", None)

    if ckpt_id is not None:

        ckpt_kwargs["checkpoint_id"] = ckpt_id



    @_dcp_method_logger(**ckpt_kwargs)

    def finish_checkpoint():

        storage_writer.finish(metadata=global_metadata, results=all_writes)

        return global_metadata



    return finish_checkpoint()





def merge_meta_info(

    global_meta_infos: List[Metadata],

):

    """

    Merge multiple DCP (Distributed Checkpoint) metadata objects into a single unified Metadata instance.



    This function is typically used when a checkpoint has been saved in multiple shards or parts

    (e.g., via partial saves), each producing its own Metadata object. The merge combines:

      - `state_dict_metadata`: mapping of tensor names to their storage/sharding info,

      - `planner_data`: auxiliary data used by the SavePlanner (e.g., layout hints, version info).



    It assumes that keys across shards are disjoint (i.e., no overlapping tensor names),

    so simple dictionary merging via `**` is safe.



    Args:

        global_meta_infos (List[Metadata]): A list of Metadata objects from individual shards.

                                            Must be non-empty to produce a valid result.



    Returns:

        Metadata: A merged Metadata object containing the union of all input metadata.

                  Returns None if the input list is empty.

    """



    # Use functools.reduce to iteratively merge all Metadata instances.

    # Start with the first element as the accumulator, then merge in the rest one by one.

    merged_data = reduce(

        lambda acc, x: acc.__class__(

            state_dict_metadata={**acc.state_dict_metadata, **x.state_dict_metadata},

            planner_data={**acc.planner_data, **x.planner_data}

        ),

        global_meta_infos[1:],  # All elements after the first

        global_meta_infos[0]    # Initial accumulator

    ) if global_meta_infos else None    # Only merge if the list is non-empty



    return merged_data





def load_metadata(

    storage_reader: StorageReader

):

    """

    Load the global metadata of a Distributed Checkpoint (DCP) from persistent storage.



    This function uses a `StorageReader` to read the checkpoint's central metadata file,

    which typically contains:

      - `state_dict_metadata`

      - `planner_data`



    The metadata is essential for correctly reconstructing the full state dict during loading.



    Args:

        storage_reader (StorageReader): An object capable of reading checkpoint data from

                                        the underlying storage backend (e.g., filesystem, cloud).



    Returns:

        Metadata: The deserialized global metadata object describing the entire checkpoint.

    """



    ckpt_kwargs = {}

    ckpt_id = getattr(storage_reader, "checkpoint_id", None)

    if ckpt_id is not None:

        ckpt_kwargs["checkpoint_id"] = ckpt_id



    @_dcp_method_logger(**ckpt_kwargs)

    def read_metadata():

        metadata = storage_reader.read_metadata()

        return metadata



    return read_metadata()





def partial_load_dcp_state_dict(

    metadata: Metadata,

    storage_reader: StorageReader,

    planner: Optional[LoadPlanner] = None,

):

    """

    Load a partial subset of a Distributed Checkpoint (DCP) state dictionary.



    This function is designed for scenarios where only a portion of the full model

    needs to be loaded (e.g., loading specific layers or shards). The input `metadata`

    describes only the relevant tensors to load—not the entire checkpoint.



    It uses a LoadPlanner and StorageReader to:

      - Plan which tensors to load based on the provided metadata,

      - Read the corresponding data from storage,

      - Populate a local `state_dict` in-place.



    Note: The resulting `state_dict` will contain only the keys covered by the input `metadata`.



    Args:

        metadata (Metadata): Partial metadata describing the subset of tensors to load.

                             Must include entries in `state_dict_metadata` for the desired keys.

        storage_reader (StorageReader): Handles reading tensor data from persistent storage.

        planner (Optional[LoadPlanner]): Custom load planner. If not provided, uses

                                        `_EmptyStateDictLoadPlanner`, which initializes an

                                        empty state dict and populates it during loading.



    Returns:

        STATE_DICT_TYPE: A state dictionary containing only the tensors specified in `metadata`.

    """



    # Initialize an empty state dict; it will be populated in-place by the planner during loading

    state_dict: STATE_DICT_TYPE = {}



    if planner is None:

        planner = _EmptyStateDictLoadPlanner()



    ckpt_kwargs = {}

    ckpt_id = getattr(storage_reader, "checkpoint_id", None)

    if ckpt_id is not None:

        ckpt_kwargs["checkpoint_id"] = ckpt_id



    @_dcp_method_logger(**ckpt_kwargs)

    def local_step():

        planner.set_up_planner(state_dict, metadata, True)

        storage_reader.set_up_storage_reader(metadata, True)



        local_plan = planner.create_local_plan()

        local_plan = storage_reader.prepare_local_plan(local_plan)

        return local_plan



    @_dcp_method_logger(**ckpt_kwargs)

    def global_step(all_local_plans):

        all_local_plans = planner.create_global_plan(all_local_plans)

        all_local_plans = storage_reader.prepare_global_plan(all_local_plans)

        return all_local_plans



    local_plan = local_step()

    central_plan = global_step([local_plan])[0]



    @_dcp_method_logger(**ckpt_kwargs)

    def read_data():

        final_local_plan = planner.finish_plan(central_plan)

        all_reads = storage_reader.read_data(final_local_plan, planner)

        all_reads.wait()



    read_data()

    return state_dict





def extract_metadata(

    selected_keys: List[str],

    metadata: Metadata

):

    """

    Extract a partial Metadata object containing only the entries corresponding to the given keys.



    This function filters a full DCP (Distributed Checkpoint) Metadata instance to produce a

    reduced version that includes only the tensors (or state dict keys) specified in `selected_keys`.

    It selectively subsets three core components of the metadata:

      - `state_dict_metadata`

      - `storage_data`

      - `planner_data`



    The resulting partial metadata can be used to load or save only a subset of the checkpoint,

    enabling efficient partial operations (e.g., loading specific layers of a large model).



    Args:

        selected_keys (List[str]): A list of fully qualified names (FQNs) of tensors to retain.

                                   Only metadata entries matching these keys will be included.

        metadata (Metadata): The complete metadata object from a DCP checkpoint.



    Returns:

        Metadata: A new Metadata instance containing only the entries associated with `selected_keys`.

    """



    # select metadata

    partial_state_dict_metadata = {}

    partial_storage_data_metadata = {}

    partial_planner_data = {}



    # Filter meta_items include only entries whose keys are in selected_keys

    for dcp_key, tensor_storage_metadata in metadata.state_dict_metadata.items():

        if dcp_key in selected_keys:

            partial_state_dict_metadata.update({dcp_key: tensor_storage_metadata})

    for metadataindex, storage_info in metadata.storage_data.items():

        if metadataindex.fqn in selected_keys:

            partial_storage_data_metadata.update({metadataindex: storage_info})

    for dcp_key, state_dict_key_tuple in metadata.planner_data.items():

        if dcp_key in selected_keys:

            partial_planner_data.update({dcp_key: state_dict_key_tuple})



    # Construct and return a new Metadata object with the filtered components

    partial_metadata = Metadata(

        state_dict_metadata=partial_state_dict_metadata,

        storage_data=partial_storage_data_metadata,

        planner_data=partial_planner_data

    )



    return partial_metadata