import os

import torch.distributed as dist
from mindspeed.auto_settings.utils.logger import get_logger
from mindspeed.auto_settings.module.parse.profiling_parse.profiling_config import ProfilingModelInfo
from mindspeed.auto_settings.module.parse.profiling_parse.profiling_parse import ProfilingParser
from mindspeed.auto_settings.utils.file_utils import restricted_read, restricted_write


class GatherNodeProfiling:
    """
    Collect per-node profiling parse results and fuse them into one model.

    Each node parses its own profiling data (``parse_node_pkl``) and stores the
    result under ``<profiling_file_path>/pkl_path`` as ``node_<i>.pkl``. With
    pipeline parallelism across several nodes, the parsed results are exchanged
    via ``all_gather_object`` so every participating rank writes all nodes'
    files. ``fuse_node_pkl`` later loads those files and merges them.
    """

    def __init__(self, profiling_file_path):
        self.profiling_file_path = profiling_file_path
        self.fusion_model = ProfilingModelInfo()
        # stage ids already merged into fusion_model (used to skip duplicates)
        self.stage_id_list = []
        self.logger = get_logger('profiling_parser')

    @staticmethod
    def _extend_stage_lists(source, target):
        """Append each per-stage list of ``target`` onto ``source`` in place."""
        source.time.extend(target.time)
        source.start_memory.extend(target.start_memory)
        source.peak_memory.extend(target.peak_memory)
        source.communication_info.extend(target.communication_info)
        source.operator_info.extend(target.operator_info)

    @staticmethod
    def _node_sort_key(file_name):
        """Sort key ordering ``<prefix>_<int>.<ext>`` names by integer index.

        Names that do not end in ``_<digits>`` sort after the indexed ones,
        alphabetically.
        """
        stem = os.path.splitext(file_name)[0]
        prefix, sep, index = stem.rpartition('_')
        if sep and index.isdigit():
            return (0, prefix, int(index), file_name)
        return (1, stem, 0, file_name)

    def _pkl_dir(self):
        """Return the directory holding the per-node ``node_<i>.pkl`` files."""
        return os.path.join(self.profiling_file_path, 'pkl_path')

    def fuse_node_pkl(self):
        """
        Load the per-node pickle files and fuse them into a single model.

        Files in ``<profiling_file_path>/pkl_path`` are visited in node-index
        order. With more than one file, each loaded model is merged into
        ``self.fusion_model`` (once per distinct ``stage_id``); with exactly
        one file, that file's model replaces ``self.fusion_model``.

        Returns:
            fusion_model: ProfilingModelInfo

        Raises:
            FileNotFoundError: if the pickle directory is missing or empty.
        """
        pkl_path = self._pkl_dir()
        # Numeric order: a plain lexicographic sort would put node_10.pkl
        # before node_2.pkl and merge stages out of order.
        pkl_files = sorted(os.listdir(pkl_path), key=self._node_sort_key)
        if not pkl_files:
            raise FileNotFoundError(f'No profiling pkl file found in {pkl_path}.')
        if len(pkl_files) > 1:
            self.logger.info('Get pp profiling parse result.')
            for pkl_file in pkl_files:
                pkl_model = restricted_read(os.path.join(pkl_path, pkl_file))
                self._fuse_models(pkl_model)
        else:
            self.fusion_model = restricted_read(os.path.join(pkl_path, pkl_files[0]))
        return self.fusion_model

    def parse_node_pkl(self, args):
        """
        Parse this node's profiling data and persist the result as pickle files.

        The search config is read from ``at_<node_rank>.pkl`` in the parent
        directory of ``profiling_file_path``. When
        ``pipeline_model_parallel_size > 1`` and the parser reports more than
        one node, results are exchanged with ``all_gather_object`` over a GLOO
        group of ranks ``i * devices_per_node`` and every node's result is
        written as ``node_<i>.pkl``. Otherwise only this node's result is
        written, as ``node_<node_rank>.pkl``.

        Args:
            args: namespace providing at least ``node_rank`` and
                ``pipeline_model_parallel_size``; also forwarded to
                ``ProfilingParser``.
        """
        parent_dir = os.path.dirname(self.profiling_file_path)
        at_node_path = os.path.join(parent_dir, f'at_{args.node_rank}.pkl')
        cfg = restricted_read(at_node_path)
        profiling_parser = ProfilingParser(self.profiling_file_path, search_cfg=cfg, args=args)
        profiling_res = profiling_parser.parser()
        pkl_path = self._pkl_dir()
        if args.pipeline_model_parallel_size > 1 and profiling_parser.nodes > 1:
            # NOTE(review): assumes global ranks are laid out node-major, so
            # ``i * devices_per_node`` is the first rank of node i — confirm.
            ranks = [i * profiling_parser.devices_per_node for i in range(profiling_parser.nodes)]
            profiling_group = dist.new_group(ranks, backend=dist.Backend.GLOO)
            gather_objects = [None] * profiling_parser.nodes
            dist.all_gather_object(gather_objects, profiling_res, group=profiling_group)
            # exist_ok: tolerate the directory already existing (e.g. created
            # concurrently by another node on a shared filesystem).
            os.makedirs(pkl_path, exist_ok=True)
            for node_idx, node_res in enumerate(gather_objects):
                restricted_write(os.path.join(pkl_path, f'node_{node_idx}.pkl'), node_res)

            dist.barrier(group=profiling_group)
            dist.destroy_process_group(group=profiling_group)
        else:
            os.makedirs(pkl_path, exist_ok=True)
            restricted_write(os.path.join(pkl_path, f'node_{args.node_rank}.pkl'), profiling_res)

    def _fuse_models(self, new_model):
        """Merge ``new_model`` into ``self.fusion_model`` unless its stage_id was already merged."""
        if new_model.stage_id not in self.stage_id_list:
            self.stage_id_list.append(new_model.stage_id)
            self.fusion_model.extend_stage_info(new_model)