# Commit 516cc34c, created on 2024-07-24 (commit history).
import time

# Wall-clock timestamp taken at module import, before the imports below are executed.
# autopipeline_profiling() later min-reduces it across ranks to report initialization time.
_TRAIN_START_TIME = time.time()
import json
import os.path
import gc
import copy
from functools import wraps
import torch
import torch.nn
import torch_npu
from megatron.training import print_rank_0
from megatron.training.arguments import parse_args
from megatron.core import parallel_state
from megatron.core.parallel_state import get_embedding_group
from megatron.training import get_args
from megatron.training import get_timers
from megatron.training import training
from megatron.training.training import print_datetime
from megatron.core.pipeline_parallel import p2p_communication
from megatron.core import mpu, tensor_parallel
from megatron.training.initialize import initialize_megatron, set_jit_fusion_options


class AutoPipeline:
    """Profiler that records per-module time and NPU memory during a few training steps.

    Measurements are stored in nested dicts ("contexts"): ``self.context`` is the
    root, and every registered sub-module gets its own dict that is appended to
    its parent's ``'layers'`` list. Profiling is active while
    ``profiling_step < stop_profiling_step``.
    """

    # Shared instance slot; filled lazily by ``get_auto_pipeline`` in this file.
    auto_pipeline = None

    def __init__(self, args):
        """Keep a deep copy of ``args`` and reset all profiling state."""
        # The copy is what ``restore_args_for_training`` reads back later.
        self.args = copy.deepcopy(args)
        self.context = {'module': []}
        self.modules_hooks = []
        self.profiling_step = 0
        self.stop_profiling_step = 5
        self.unit_mb = 1024 * 1024

    @staticmethod
    def get_memory_status():
        """Return ``(memory_allocated, memory_reserved)`` as reported by ``torch.npu``."""
        return torch.npu.memory_allocated(), torch.npu.memory_reserved()

    def _cal_tensor_size(self, tensor):
        """Return ``numel * element_size`` of ``tensor`` divided by ``self.unit_mb``.

        Falls back to ``0`` if the computation raises ``ZeroDivisionError``.
        """
        try:
            byte_count = tensor.numel() * tensor.element_size()
            return byte_count / self.unit_mb
        except ZeroDivisionError:
            return 0

    def pre_hook_func(self, state, sync: bool, *args, **kargs):
        """Record the starting memory, start time and input size into ``state``.

        Args:
            state: Dict that receives the ``'memory'``, ``'time'`` and ``'input'`` keys.
            sync: When true, synchronize the NPU before sampling memory.
            *args: Hook inputs. Tensors found directly, or one level deep inside
                tuples/lists, contribute to ``state['input']``.
            **kargs: Accepted but not inspected.
        """
        if sync:
            torch.npu.synchronize()
        allocated, _ = self.get_memory_status()
        # Reset the peak counter so the post hook can read this module's own peak.
        torch.npu.reset_max_memory_allocated()
        state['memory'] = allocated
        torch.npu.synchronize()
        state['time'] = time.time()

        found_tensors = []
        for item in args:
            if isinstance(item, torch.Tensor):
                found_tensors.append(item)
            elif isinstance(item, (tuple, list)):
                found_tensors.extend(t for t in item if isinstance(t, torch.Tensor))
        state['input'] = sum(self._cal_tensor_size(t) for t in found_tensors)

    def post_hook_func(self, state, sync: bool, *args, **kargs):
        """Turn the values stored by ``pre_hook_func`` into deltas and a running mean time.

        ``state['peak_memory']`` is the raw difference of the two allocator
        readings (it is not divided by ``unit_mb``), whereas ``state['memory']``
        is floor-divided by ``unit_mb``. ``state['time']`` is in milliseconds.
        The first invocation for a given ``state`` only initializes the
        counters; its duration is not added to ``'pre_total_time'``.
        """
        if sync:
            torch.npu.synchronize()
        allocated, _ = self.get_memory_status()
        peak = torch.npu.max_memory_allocated()
        state['peak_memory'] = peak - state['memory']
        state['memory'] = (allocated - state['memory']) // self.unit_mb

        if 'pre_total_time' not in state:
            # First call for this module: set up the accumulators only.
            state['forward_cnt'] = 0
            state['time'] = (time.time() - state['time']) * 1000
            state['pre_total_time'] = 0
            return

        state['forward_cnt'] += 1
        state['time'] = (time.time() - state['time']) * 1000
        state['pre_total_time'] += state['time']
        try:
            state['time'] = state['pre_total_time'] / state['forward_cnt']
        except ZeroDivisionError:
            state['time'] = 0

    def forward_pre_hook(self, name, parent_ctx, ctx):
        """Build a forward-pre hook for one module and link ``ctx`` into ``parent_ctx``.

        The linking (setting ``ctx['name']`` and appending to
        ``parent_ctx['layers']``) happens immediately, at hook-creation time, and
        only while profiling is still active.
        """
        profiling_active = self.profiling_step < self.stop_profiling_step
        if profiling_active:
            ctx['name'] = name
            if 'layers' in parent_ctx:
                parent_ctx['layers'].append(ctx)

        def hook(module, *args, **kargs):
            if self.profiling_step >= self.stop_profiling_step:
                return
            # Push this module on the stack of modules currently in forward.
            if 'module' in self.context:
                self.context['module'].append(ctx)
            self.pre_hook_func(ctx, True, *args, **kargs)

        return hook

    def forward_post_hook(self, ctx):
        """Build the forward hook that finalizes the measurements held in ``ctx``."""
        def hook(module, *args, **kargs):
            if self.profiling_step >= self.stop_profiling_step:
                return
            # Only positional hook arguments are forwarded here.
            self.post_hook_func(ctx, True, *args)
            if 'module' in self.context:
                self.context['module'].pop()

        return hook

    def register_recursive_hook(self, prefix_name, model, ctx):
        """Attach pre/post forward hooks to every descendant of ``model``.

        Args:
            prefix_name: Dotted path of ``model``; extended and passed down the
                recursion. The hooks themselves receive the short child name.
            model: Module whose ``named_children()`` are instrumented.
            ctx: Context dict of ``model``; gets a ``'layers'`` list if it has children.
        """
        for child_name, child in model.named_children():
            ctx.setdefault('layers', [])
            child_ctx = {}

            qualified = child_name if prefix_name == "" else prefix_name + "." + child_name
            handles = (
                child.register_forward_pre_hook(self.forward_pre_hook(child_name, ctx, child_ctx)),
                child.register_forward_hook(self.forward_post_hook(child_ctx)),
            )
            # Handles are kept so the caller can remove the hooks afterwards.
            self.modules_hooks.extend(handles)
            self.register_recursive_hook(qualified, child, child_ctx)

    def step_hook(self, model):
        """Advance the profiling step counter by one; ``model`` is not used."""
        self.profiling_step += 1

    def hook_step_func(self, step_func, models):
        """Wrap an optimizer ``step`` so each call also advances the profiler.

        The wrapper returns whatever ``step_func`` returns. While profiling is
        active it stores the allocated memory (floor-divided by ``unit_mb``)
        in ``self.context['used_mem']``. ``step_hook`` then runs once per model
        in ``models`` (or once if ``models`` is not a list).
        """
        def custom_step_func(*args, **kargs):
            outcome = step_func(*args, **kargs)
            if self.profiling_step < self.stop_profiling_step:
                allocated, _reserved = self.get_memory_status()
                self.context['used_mem'] = allocated // self.unit_mb
            for model in (models if isinstance(models, list) else [models]):
                self.step_hook(model)
            return outcome

        return custom_step_func

    def get_comm_time(self, config, sync: bool):
        """Store a point-to-point communication time in ``self.context['comm_time']``.

        Rank 0 times a single ``p2p_communication.send_backward`` call on a
        ``(seq_length, micro_batch_size, hidden_size)`` tensor of ones and stores
        the elapsed wall-clock milliseconds. All other ranks store the constant 0.028.
        """
        if torch.distributed.get_rank() != 0:
            # NOTE(review): hard-coded value; presumably a default in ms - confirm.
            self.context['comm_time'] = 0.028
            return

        if sync:
            torch.npu.synchronize()
        probe = torch.ones(self.args.seq_length, self.args.micro_batch_size, self.args.hidden_size)
        began = time.time()
        p2p_communication.send_backward(probe, config)
        self.context['comm_time'] = (time.time() - began) * 1000

    def get_modules_params_by_stages(self, init_memory, sync: bool):
        """Fill the ``first_stage_embed``, ``last_stage_embed`` and ``per_trans_layer_param`` context entries.

        With a pipeline size of exactly 2 the values are computed analytically
        from the saved arguments. Otherwise they come from memory readings
        broadcast from three ranks: rank 0, the last rank, and the rank whose
        index equals ``tensor_model_parallel_size``.

        Args:
            init_memory: Baseline subtracted from each source rank's reading.
            sync: When true, synchronize the NPU before the broadcasts
                (unused in the analytic branch).
        """
        if self.args.pipeline_model_parallel_size == 2:
            hidden = self.args.hidden_size
            heads = self.args.num_attention_heads
            embed = self.args.padded_vocab_size * hidden
            self.context['first_stage_embed'] = embed
            self.context['last_stage_embed'] = embed
            attention_block = 3 * hidden * heads * (hidden / heads) + hidden * hidden + hidden + hidden
            ffn_block = 3 * self.args.ffn_hidden_size * hidden + hidden + hidden
            per_layer = (attention_block + ffn_block) / self.args.tensor_model_parallel_size
            self.context['per_trans_layer_param'] = per_layer
            return

        if sync:
            torch.npu.synchronize()
        last_rank = torch.distributed.get_world_size() - 1
        layer_rank = self.args.tensor_model_parallel_size

        # Collective calls: every rank must issue them in this exact order.
        first_stage = self.broadcast_param_in_ranks(0, 0, init_memory)
        last_stage = self.broadcast_param_in_ranks(last_rank, 0, init_memory)
        per_layer = self.broadcast_param_in_ranks(layer_rank, 0, init_memory)

        self.context['first_stage_embed'] = first_stage - per_layer
        self.context['last_stage_embed'] = last_stage - per_layer
        self.context['per_trans_layer_param'] = per_layer

    def broadcast_param_in_ranks(self, src_rank, param, init_memory):
        """Broadcast ``src_rank``'s peak-memory delta to all ranks and return it.

        On ``src_rank`` the incoming ``param`` is replaced by
        ``max_memory_allocated / unit_mb - init_memory``. The value travels in a
        ``torch.cuda.IntTensor``, so the result is an integer.
        """
        if torch.distributed.get_rank() == src_rank:
            param = torch.npu.max_memory_allocated() / self.unit_mb - init_memory
        carrier = torch.cuda.IntTensor([param])
        torch.distributed.broadcast(carrier, src=src_rank)
        return carrier.item()

    def update_args_for_profiling(self):
        """Modify the global args in place for the short profiling run.

        Without virtual pipeline stages, the layer counts are set to the
        pipeline-parallel size. ``train_iters`` becomes ``stop_profiling_step``,
        ``save`` becomes ``False`` and ``log_interval`` becomes 10.
        """
        live_args = get_args()
        if live_args.num_layers_per_virtual_pipeline_stage is None:
            stage_count = self.args.pipeline_model_parallel_size
            live_args.num_layers = stage_count
            live_args.encoder_num_layers = stage_count
        live_args.train_iters = self.stop_profiling_step
        live_args.save = False
        live_args.log_interval = 10

    def restore_args_for_training(self):
        """Write the values saved in ``self.args`` back onto the global args.

        ``encoder_num_layers`` is restored from the saved ``num_layers`` value.
        """
        live_args = get_args()
        saved = self.args
        if live_args.num_layers_per_virtual_pipeline_stage is None:
            live_args.num_layers = saved.num_layers
            live_args.encoder_num_layers = saved.num_layers
        for field in ('train_iters', 'optimizer', 'save', 'log_interval'):
            setattr(live_args, field, getattr(saved, field))


def check_equal_model_configs(args, parsed_contents):
    """Return the index of the first stored model config that matches ``args``.

    An entry matches when its ``"model_configs"`` dict agrees with ``args`` on
    ``hidden_size``, ``ffn_hidden_size``, ``seq_length`` and
    ``num_attention_heads``.

    Args:
        args: Object exposing the four attributes above.
        parsed_contents: Sequence of dicts, each holding a ``"model_configs"`` dict.

    Returns:
        The zero-based index of the first match, or ``-1`` if nothing matches.
    """
    compared_fields = ("hidden_size", "ffn_hidden_size", "seq_length", "num_attention_heads")
    for position, entry in enumerate(parsed_contents):
        if all(getattr(args, field) == entry["model_configs"][field] for field in compared_fields):
            return position
    return -1


def check_equal_parallel_configs(args, parsed_content):
    """Look up a stored pipeline policy whose parallel settings match ``args``.

    Each entry of ``parsed_content["autopipeline_policy"]`` is compared on
    ``num_layers``, ``pipeline_model_parallel_size``,
    ``tensor_model_parallel_size`` and on ``args.save_memory_ratio`` against the
    entry's ``"ratio"``.

    Returns:
        ``(num_layer_list, recompute_module_list, recompute_type)`` of the first
        matching entry, or ``(None, None, None)`` when none matches.
    """
    # (attribute on args, key in the stored policy)
    attr_to_key = (
        ("num_layers", "num_layers"),
        ("pipeline_model_parallel_size", "pipeline_model_parallel_size"),
        ("tensor_model_parallel_size", "tensor_model_parallel_size"),
        ("save_memory_ratio", "ratio"),
    )
    for policy in parsed_content["autopipeline_policy"]:
        if all(getattr(args, attr) == policy[key] for attr, key in attr_to_key):
            return policy["num_layer_list"], policy["recompute_module_list"], policy["recompute_type"]
    return None, None, None


def check_skip_profiling(args, config_file):
    """Decide whether profiling can be skipped because a matching policy is cached.

    Args:
        args: Arguments compared against the cached model and parallel configs.
        config_file: Path of the JSON policy cache.

    Returns:
        ``(True, [policy])`` when the file exists and holds a matching entry with
        a non-empty ``num_layer_list``; the policy is the tuple
        ``(num_layer_list, recompute_module_list, (0, [0]), recompute_type)``.
        ``(False, None)`` in every other case.
    """
    if not os.path.exists(config_file):
        return False, None

    with open(config_file) as handle:
        raw_text = handle.read()
    records = json.loads(raw_text)

    matched = check_equal_model_configs(args, records)
    if matched == -1:
        return False, None

    layer_split, recompute_modules, recompute_kind = check_equal_parallel_configs(args, records[matched])
    if not layer_split:
        return False, None
    return True, [(layer_split, recompute_modules, (0, [0]), recompute_kind)]


def set_recompute_mode(models):
    """Wrap ``forward`` of selected sub-modules with activation checkpointing.

    A sub-module is selected when its name from ``named_modules()`` consists
    only of digits and is not ``"0"``.

    Args:
        models: Iterable of modules exposing ``named_modules()``.
    """
    for model in models:
        for module_name, submodule in model.named_modules():
            # NOTE(review): named_modules() yields full dotted paths, so only
            # names that are purely numeric are selected - confirm this is intended.
            if not str.isdigit(module_name) or module_name == "0":
                continue
            submodule.forward = hook_checkpoint_forward(submodule.forward)


def hook_checkpoint_forward(forward_func):
    """Return a replacement forward that runs ``forward_func`` via ``tensor_parallel.checkpoint``.

    Positional arguments are handed to ``tensor_parallel.checkpoint``; keyword
    arguments are captured by a closure and re-applied when the checkpointed
    function is invoked.
    """
    def custom_forward(*args, **kargs):
        def inside_forward(*positional):
            # Keyword arguments come from the enclosing call, not from checkpoint().
            return forward_func(*positional, **kargs)

        # The second positional argument to checkpoint() is passed as None.
        return tensor_parallel.checkpoint(inside_forward, None, *args)

    return custom_forward


def get_auto_pipeline(args):
    """Return the shared ``AutoPipeline`` instance, creating it from ``args`` on first use.

    Once an instance exists, ``args`` is ignored on later calls.
    """
    instance = AutoPipeline.auto_pipeline
    if instance is None:
        instance = AutoPipeline(args)
        AutoPipeline.auto_pipeline = instance
    return instance


def initialize_cfg_from_args_wrapper(initialize_cfg_from_args):
    """Decorate ``initialize_cfg_from_args`` so it is skipped during automated-pipeline profiling.

    The wrapped function is not called when ``get_args().automated_pipeline`` is
    truthy and ``mindspeed.core.training.policy`` is falsy. In all other cases it
    is called with the original arguments.
    """
    @wraps(initialize_cfg_from_args)
    def wrapper(*args, **kwargs):
        from mindspeed.core import training as mc_training
        runtime_args = get_args()
        if runtime_args.automated_pipeline and not mc_training.policy:
            return
        # NOTE(review): the wrapped function's return value is discarded.
        initialize_cfg_from_args(*args, **kwargs)
    return wrapper


def autopipeline_profiling(model_provider, model_type, forward_step_func, train_valid_test_dataset_provider,
                           process_non_loss_data_func, args):
    """Run a short profiling training pass for the automated pipeline, or reuse a cached policy.

    If ``autopipeline_config.json`` (a relative path) holds a policy matching
    ``args``, nothing is profiled. Otherwise Megatron is initialized, the model
    and optimizer are built with profiling arguments (see
    ``AutoPipeline.update_args_for_profiling``), forward hooks are registered on
    every sub-module, ``training.train`` is run, and afterwards the hooks are
    removed, the arguments restored, and the storages of optimizer state and
    parameters resized to zero.

    Args:
        model_provider: Passed through to ``training.setup_model_and_optimizer``.
        model_type: Passed through to ``training.setup_model_and_optimizer``.
        forward_step_func: Passed through to ``training.train``.
        train_valid_test_dataset_provider: Passed through to
            ``training.build_train_valid_test_data_iterators``.
        process_non_loss_data_func: Passed through to ``training.train``.
        args: Arguments used for the cached-policy lookup. The name is rebound to
            ``get_args()`` once Megatron has been initialized.

    Returns:
        ``(pipelining.context, policy)`` after profiling; ``policy`` is ``None`` on
        this path because ``check_skip_profiling`` returned ``(False, None)``.
        ``(None, policy)`` when a cached policy was found.
    """
    is_skip, policy = check_skip_profiling(args, config_file="autopipeline_config.json")
    if not is_skip:
        initialize_megatron(extra_args_provider=None,
                            args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
        set_jit_fusion_options()
        # Take the earliest module-import timestamp across all ranks as the common start time.
        global _TRAIN_START_TIME
        start_time_tensor = torch.cuda.DoubleTensor([_TRAIN_START_TIME])
        torch.distributed.all_reduce(start_time_tensor,
                                     op=torch.distributed.ReduceOp.MIN)
        _TRAIN_START_TIME = start_time_tensor.item()
        print_rank_0('time to initialize megatron (seconds): {:.3f}'.format(
            time.time() - _TRAIN_START_TIME))
        print_datetime('after megatron is initialized')
        # From here on `args` is the global Megatron args object, not the parameter passed in.
        args = get_args()
        pipelining = get_auto_pipeline(args)
        pipelining.update_args_for_profiling()
        # Peak allocation (in units of unit_mb) before the model is built; used as the
        # baseline in get_modules_params_by_stages.
        init_memory = torch.npu.max_memory_allocated() / pipelining.unit_mb
        models, optimizer, lr_scheduler = training.setup_model_and_optimizer(model_provider, model_type)
        # Wrap optimizer.step so every step also advances the profiler's step counter.
        optimizer.step = pipelining.hook_step_func(optimizer.step, models)
        config = training.get_model_config(models[0])

        if args.virtual_pipeline_model_parallel_size is not None:
            # One train/valid iterator per model chunk, built under that chunk's virtual rank.
            train_data_iterator = []
            valid_data_iterator = []
            for i in range(len(models)):
                mpu.set_virtual_pipeline_model_parallel_rank(i)
                iterators = training.build_train_valid_test_data_iterators(
                    train_valid_test_dataset_provider)
                train_data_iterator.append(iterators[0])
                valid_data_iterator.append(iterators[1])
        else:
            train_data_iterator, valid_data_iterator, _ = training.build_train_valid_test_data_iterators(
                train_valid_test_dataset_provider)
        # Register profiling hooks on every sub-module; all models share pipelining.context as root.
        if isinstance(models, list):
            for model in models:
                pipelining.register_recursive_hook("module", model, pipelining.context)
        else:
            pipelining.register_recursive_hook("module", models, pipelining.context)
        pipelining.get_modules_params_by_stages(init_memory, sync=True)
        set_recompute_mode(models)
        checkpointing_context = {}
        # Profiling run; update_args_for_profiling set args.train_iters to stop_profiling_step.
        training.train(forward_step_func, models, optimizer, lr_scheduler, train_data_iterator, valid_data_iterator,
                       process_non_loss_data_func, config, checkpointing_context)
        pipelining.get_comm_time(config, sync=True)

        timers = get_timers()
        if timers('interval-time'):
            timers('interval-time').stop(barrier=True)

        # Undo the instrumentation and put the original arguments back.
        for hook_handle in pipelining.modules_hooks:
            hook_handle.remove()
        pipelining.modules_hooks.clear()
        pipelining.restore_args_for_training()

        # Drop gradients and resize the storage of every optimizer-state key, optimizer
        # parameter and model tensor to zero - presumably to free device memory before
        # the real training run.
        # NOTE(review): the return value of each .detach() call below is discarded.
        for key, value in optimizer.optimizer.state.items():
            key.detach()
            key.grad = None
            key.storage().resize_(0)
            if "momentum_buffer" in value:
                value["momentum_buffer"].detach()
                value["momentum_buffer"].grad = None
                value["momentum_buffer"].storage().resize_(0)
        for ofg in optimizer.param_groups:
            if "params" in ofg:
                for og in ofg["params"]:
                    og.detach()
                    og.grad = None
                    og.storage().resize_(0)
        for md in models:
            for param in md.parameters():
                param.detach()
                param.grad = None
                param.storage().resize_(0)
            # NOTE(review): md.state_dict() is called again for every access in this loop.
            for param_tensor in md.state_dict():
                if md.state_dict()[param_tensor] is not None:
                    md.state_dict()[param_tensor].detach()
                    md.state_dict()[param_tensor].grad = None
                    md.state_dict()[param_tensor].storage().resize_(0)

        gc.collect()
        torch_npu.npu.empty_cache()
        # Fixed 5-second pause; the reason is not evident from this file.
        time.sleep(5)
        return pipelining.context, policy
    else:
        print_rank_0("[INFO] Found existed automated pipeline policy, apply it directly.")
        return None, policy