# 56b31fce: created on 2021-09-03 (commit history)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Context of auto parallel"""
import threading

import mindspore.context as context
import mindspore.log as logger
from mindspore.parallel._dp_allreduce_fusion import _set_fusion_strategy_by_idx, _set_fusion_strategy_by_size
from mindspore.parallel._ps_context import _is_role_pserver
from mindspore._c_expression import AutoParallelContext
from mindspore._checkparam import args_type_check, Validator

_MAX_GROUP_NAME_LEN = 127
_DEFAULT_HCCL_FUSION_GROUP_NAME = "hccl_world_groupsum1"
_DEFAULT_NCCL_FUSION_GROUP_NAME = "nccl_world_groupsum1"


class _AutoParallelContext:
    """
    _AutoParallelContext is the environment in which operations are executed.

    It is a Python facade over the handle returned by ``AutoParallelContext.get_instance()``:
    setters validate their argument and forward it to the handle, getters read it back.

    Note:
        Creating a context by instantiating this class directly is not recommended.
        Use auto_parallel_context() to get the context, since the context is a singleton.
    """
    _instance = None
    _instance_lock = threading.Lock()

    def __init__(self):
        # NOTE(review): __init__ runs on every _AutoParallelContext() call (even though __new__
        # returns the cached instance), so re-instantiating resets _dataset_strategy_using_str.
        # Prefer auto_parallel_context().
        self._context_handle = AutoParallelContext.get_instance()
        # True while the dataset strategy was given as a string ("full_batch"/"data_parallel"),
        # which is stored through the handle's full_batch flag; False once a tuple strategy
        # has been pushed to the handle.
        self._dataset_strategy_using_str = True

    def __new__(cls):
        # Double-checked locking: the unlocked test keeps the common path cheap, the second
        # test under the lock prevents two racing threads from both creating an instance.
        if cls._instance is None:
            with cls._instance_lock:
                if cls._instance is None:
                    cls._instance = object.__new__(cls)
        return cls._instance

    def check_context_handle(self):
        """
        Check context handle.

        Raises:
            ValueError: If the context handle is none.
        """
        if self._context_handle is None:
            raise ValueError("Context handle is none in context!!!")

    def set_device_num(self, device_num):
        """
        Set device num for auto parallel.

        Args:
            device_num (int): The device number.

        Raises:
            ValueError: If the device num is not in [1, 4096].
        """
        self.check_context_handle()
        if device_num < 1 or device_num > 4096:
            raise ValueError("Device num must be in [1, 4096], but got {}".format(device_num))
        self._context_handle.set_device_num(device_num)

    def get_device_num(self):
        """Get device num."""
        self.check_context_handle()
        return self._context_handle.get_device_num()

    def set_global_rank(self, global_rank):
        """
        Set global rank for auto parallel.

        Args:
            global_rank (int): The rank id of current rank.

        Raises:
            ValueError: If the global rank is not in [0, 4095].
        """
        self.check_context_handle()
        if global_rank < 0 or global_rank > 4095:
            raise ValueError("Global rank must be in [0, 4095], but got {}".format(global_rank))
        self._context_handle.set_global_rank(global_rank)

    def get_global_rank(self):
        """Get current rank id."""
        self.check_context_handle()
        return self._context_handle.get_global_rank()

    def set_pipeline_stages(self, stages):
        """
        Set the stages of the pipeline.

        Args:
            stages (int): Number of pipeline stages, at least 1.

        Raises:
            TypeError: If stages is a bool or not an int.
            ValueError: If stages is less than 1.
            RuntimeError: If stages > 1 while the device target is "GPU".
        """
        # bool is a subclass of int, so it is rejected before the generic int check.
        if isinstance(stages, bool):
            raise TypeError("The type of pipeline_stage_num must be int, but got bool.")
        if not isinstance(stages, int):
            raise TypeError("The type of pipeline_stage_num must be int.")
        if stages < 1:
            raise ValueError("pipeline_stage_num can't be less than 1.")
        backend = context.get_context("device_target")
        if backend == "GPU" and stages > 1:
            raise RuntimeError("Now GPU don't support pipeline parallel.")
        self.check_context_handle()
        self._context_handle.set_pipeline_stage_split_num(stages)

    def get_pipeline_stages(self):
        """Get the stages of the pipeline."""
        self.check_context_handle()
        return self._context_handle.get_pipeline_stage_split_num()

    def set_gradients_mean(self, gradients_mean):
        """
        Set gradients_mean flag.

        Note:
            If gradients_mean is true, it will insert a div operator after parameter gradients allreduce.

        Args:
            gradients_mean (bool): The gradients_mean flag.
        """
        self.check_context_handle()
        self._context_handle.set_gradients_mean(gradients_mean)

    def get_gradients_mean(self):
        """Get gradients_mean flag."""
        self.check_context_handle()
        return self._context_handle.get_gradients_mean()

    def set_gradient_fp32_sync(self, gradient_fp32_sync):
        """
        Set gradient_fp32_sync.

        Note:
            If gradient_fp32_sync is true,
            it will convert tensor type from fp16 to fp32 before parameter gradients allreduce.

        Args:
            gradient_fp32_sync (bool): The gradient_fp32_sync flag.
        """
        self.check_context_handle()
        self._context_handle.set_gradient_fp32_sync(gradient_fp32_sync)

    def get_gradient_fp32_sync(self):
        """Get gradient_fp32_sync flag."""
        self.check_context_handle()
        return self._context_handle.get_gradient_fp32_sync()

    def set_loss_repeated_mean(self, loss_repeated_mean):
        """
        Set loss_repeated_mean flag.

        Note:
            If loss_repeated_mean is true,
            Distributed automatic differentiation will perform a mean operator
            in backward in the case of repeated calculations.

        Args:
            loss_repeated_mean (bool): The loss_repeated_mean flag.

        Raises:
            TypeError: If loss_repeated_mean is not a bool.
        """
        if not isinstance(loss_repeated_mean, bool):
            raise TypeError(f"The type of loss_repeated_mean must be bool, but got {type(loss_repeated_mean)}.")
        self.check_context_handle()
        self._context_handle.set_loss_repeated_mean(loss_repeated_mean)

    def get_loss_repeated_mean(self):
        """Get loss_repeated_mean flag."""
        self.check_context_handle()
        return self._context_handle.get_loss_repeated_mean()

    def set_parallel_mode(self, parallel_mode):
        """
        Set parallel mode for auto parallel.

        Args:
            parallel_mode (str): The parallel mode of auto parallel.

        Raises:
            ValueError: If parallel mode is not supported.
        """
        self.check_context_handle()
        # The handle reports an unsupported mode by returning False.
        ret = self._context_handle.set_parallel_mode(parallel_mode)
        if ret is False:
            raise ValueError("Parallel mode does not support {}".format(parallel_mode))

    def get_parallel_mode(self):
        """Get parallel mode; a parameter-server process always reports stand-alone mode."""
        self.check_context_handle()
        if _is_role_pserver():
            return context.ParallelMode.STAND_ALONE
        return self._context_handle.get_parallel_mode()

    def set_strategy_search_mode(self, auto_parallel_search_mode):
        """
        Set search mode of strategy.

        Args:
            auto_parallel_search_mode (str): The search mode of strategy.

        Raises:
            ValueError: If the search mode is not supported.
        """
        self.check_context_handle()
        ret = self._context_handle.set_strategy_search_mode(auto_parallel_search_mode)
        if ret is False:
            raise ValueError("Strategy search mode does not support {}".format(auto_parallel_search_mode))

    def get_strategy_search_mode(self):
        """Get search mode of strategy."""
        self.check_context_handle()
        return self._context_handle.get_strategy_search_mode()

    def set_parameter_broadcast(self, parameter_broadcast):
        """
        Set parameter broadcast.

        Args:
            parameter_broadcast (bool): Parameter broadcast or not.
        """
        self.check_context_handle()
        self._context_handle.set_parameter_broadcast(parameter_broadcast)

    def get_parameter_broadcast(self):
        """Get parameter broadcast flag."""
        self.check_context_handle()
        return self._context_handle.get_parameter_broadcast()

    def set_strategy_ckpt_load_file(self, strategy_ckpt_load_file):
        """
        Set strategy checkpoint load path.

        Args:
            strategy_ckpt_load_file (str): Path to load parallel strategy checkpoint.
        """
        self.check_context_handle()
        self._context_handle.set_strategy_ckpt_load_file(strategy_ckpt_load_file)

    def get_strategy_ckpt_load_file(self):
        """Get strategy checkpoint load path."""
        self.check_context_handle()
        return self._context_handle.get_strategy_ckpt_load_file()

    def set_full_batch(self, full_batch):
        """
        Set whether load full batch on each device.

        Args:
            full_batch (bool): True if load full batch on each device.
        """
        self.check_context_handle()
        self._context_handle.set_full_batch(full_batch)

    def get_full_batch(self):
        """Get whether load full batch on each device; always False on a parameter-server process."""
        self.check_context_handle()
        if _is_role_pserver():
            return False
        return self._context_handle.get_full_batch()

    def set_dataset_strategy(self, dataset_strategy):
        """
        Set dataset sharding strategy.

        Args:
            dataset_strategy (str or tuple(tuple)): Either "full_batch"/"data_parallel", or a tuple
                of tuples of ints giving an explicit sharding strategy.

        Raises:
            ValueError: If a string other than "full_batch" or "data_parallel" is given.
            TypeError: If the strategy is neither a str nor a tuple of int tuples.
        """
        self.check_context_handle()
        if isinstance(dataset_strategy, str):
            if dataset_strategy not in ("full_batch", "data_parallel"):
                raise ValueError("The dataset_strategy string should be 'full_batch' or 'data_parallel', "
                                 "otherwise, incoming tuple(tuple) type strategy")
            # String strategies are stored through the full_batch flag of the handle.
            self._context_handle.set_full_batch(dataset_strategy == "full_batch")
            self._dataset_strategy_using_str = True
            return
        if not isinstance(dataset_strategy, tuple):
            raise TypeError(f'strategy must be str or tuple type, but got:{type(dataset_strategy)}')
        for ele in dataset_strategy:
            if not isinstance(ele, tuple):
                raise TypeError(f'The element of strategy must be tuple type, but got:{type(ele)}')
            for dim in ele:
                if not isinstance(dim, int):
                    raise TypeError(f'The dim of each strategy value must be int type, but got:{type(dim)}')
        self._dataset_strategy_using_str = False
        self._context_handle.set_dataset_strategy(dataset_strategy)

    def get_dataset_strategy(self):
        """Get dataset sharding strategy: "full_batch"/"data_parallel", or the explicit tuple strategy."""
        self.check_context_handle()
        if self._dataset_strategy_using_str:
            if self._context_handle.get_full_batch():
                return "full_batch"
            return "data_parallel"
        return self._context_handle.get_dataset_strategy()

    def set_grad_accumulation_step(self, grad_accumulation_step):
        """
        Set grad accumulation step.

        Args:
            grad_accumulation_step (int): The grad accumulation step, validated as a positive int.
        """
        self.check_context_handle()
        Validator.check_positive_int(grad_accumulation_step)
        self._context_handle.set_grad_accumulation_step(grad_accumulation_step)

    def get_grad_accumulation_step(self):
        """Get grad accumulation step."""
        self.check_context_handle()
        return self._context_handle.get_grad_accumulation_step()

    @staticmethod
    def _make_parent_dir(file_path):
        """Create the parent directory of ``file_path`` if it does not exist yet."""
        import os
        dir_path = os.path.dirname(file_path)
        if dir_path and not os.path.exists(dir_path):
            # exist_ok avoids a FileExistsError when another process (e.g. another rank)
            # creates the same directory between the check above and this call.
            os.makedirs(dir_path, exist_ok=True)

    def set_strategy_ckpt_save_file(self, strategy_ckpt_save_file):
        """
        Set strategy checkpoint save path, creating its parent directory if needed.

        Args:
            strategy_ckpt_save_file (str): Path to save parallel strategy checkpoint.
        """
        self.check_context_handle()
        self._make_parent_dir(strategy_ckpt_save_file)
        self._context_handle.set_strategy_ckpt_save_file(strategy_ckpt_save_file)

    def get_strategy_ckpt_save_file(self):
        """Get strategy checkpoint save path."""
        self.check_context_handle()
        return self._context_handle.get_strategy_ckpt_save_file()

    def set_group_ckpt_save_file(self, group_ckpt_save_file):
        """
        Set group checkpoint save path, creating its parent directory if needed.

        Args:
            group_ckpt_save_file (str): Path to save parallel group checkpoint.
        """
        self.check_context_handle()
        self._make_parent_dir(group_ckpt_save_file)
        self._context_handle.set_group_ckpt_save_file(group_ckpt_save_file)

    def get_parameter_broadcast_is_set(self):
        """Get parameter broadcast is set or not."""
        self.check_context_handle()
        return self._context_handle.get_parameter_broadcast_is_set()

    def set_all_reduce_fusion_split_indices(self, indices, group=""):
        """
        Set allreduce fusion strategy by parameters indices.

        Args:
            indices (list): Indices list, non-empty, unique and sorted in ascending order.
            group (str): The communication group of hccl/nccl.

        Raises:
            ValueError: If indices is empty, has duplicates or is not sorted in ascending order.
            TypeError: If indices is not a python list or an item of it is not an int.
            TypeError: If group is not a python str.
            ValueError: If group is longer than _MAX_GROUP_NAME_LEN characters.
        """
        self.check_context_handle()
        if not indices:
            raise ValueError('indices can not be empty')
        if not isinstance(indices, list):
            raise TypeError('indices must be a python list')
        for index in indices:
            # bool is a subclass of int, so it has to be rejected explicitly.
            if not isinstance(index, int) or isinstance(index, bool):
                raise TypeError(f"The type of index must be int, but got {type(index)}.")

        if len(set(indices)) != len(indices):
            raise ValueError('indices has duplicate elements')

        if sorted(indices) != indices:
            raise ValueError('elements in indices must be sorted in ascending order')

        new_group = self._check_and_default_group(group)

        self._context_handle.set_all_reduce_fusion_split_indices(indices, new_group)
        # NOTE(review): set_all_reduce_fusion_split_sizes applies its fusion strategy on Ascend
        # without the enable_ge check; confirm whether this asymmetry is intended.
        if context.get_context("device_target") == "Ascend" and context.get_context("enable_ge"):
            _set_fusion_strategy_by_idx(indices)

    def get_all_reduce_fusion_split_indices(self, group=""):
        """
        Get allreduce fusion split indices.

        Args:
            group (str): The communication group of hccl/nccl.

        Returns:
            Return split indices list according to the group.

        Raises:
            TypeError: If group is not a python str.
        """
        self.check_context_handle()
        new_group = self._check_and_default_group(group)
        return self._context_handle.get_all_reduce_fusion_split_indices(new_group)

    def set_all_reduce_fusion_split_sizes(self, sizes, group=""):
        """
        Set allreduce fusion strategy by parameters data sizes.

        Args:
            sizes (list): Sizes list.
            group (str): The communication group of hccl/nccl.

        Raises:
            TypeError: If sizes is not a python list or an item of it is not an int.
            TypeError: If group is not a python str.
        """
        self.check_context_handle()
        if not isinstance(sizes, list):
            raise TypeError('sizes must be a python list')
        for size in sizes:
            # bool is a subclass of int, so it has to be rejected explicitly.
            if not isinstance(size, int) or isinstance(size, bool):
                raise TypeError(f"The type of size must be int, but got {type(size)}.")

        new_group = self._check_and_default_group(group)
        self._context_handle.set_all_reduce_fusion_split_sizes(sizes, new_group)
        if context.get_context("device_target") == "Ascend":
            _set_fusion_strategy_by_size(sizes)

    def get_all_reduce_fusion_split_sizes(self, group=""):
        """
        Get allreduce fusion split sizes.

        Args:
            group (str): The communication group of hccl/nccl.

        Returns:
            Return split sizes list according to the group.

        Raises:
            TypeError: If group is not a python str.
        """
        self.check_context_handle()
        new_group = self._check_and_default_group(group)
        return self._context_handle.get_all_reduce_fusion_split_sizes(new_group)

    def set_enable_all_reduce_fusion(self, enable_all_reduce_fusion):
        """
        Set enable/disable all reduce fusion.

        Args:
            enable_all_reduce_fusion (bool): Enable/disable all reduce fusion.

        Raises:
            TypeError: If enable_all_reduce_fusion is not a bool.
        """
        self.check_context_handle()
        if not isinstance(enable_all_reduce_fusion, bool):
            raise TypeError('enable_all_reduce_fusion is invalid type')
        self._context_handle.set_enable_all_reduce_fusion(enable_all_reduce_fusion)

    def get_enable_all_reduce_fusion(self):
        """Get all reduce fusion flag."""
        self.check_context_handle()
        return self._context_handle.get_enable_all_reduce_fusion()

    def get_device_num_is_set(self):
        """Get device number is set or not."""
        self.check_context_handle()
        return self._context_handle.get_device_num_is_set()

    def get_global_rank_is_set(self):
        """Get global rank is set or not."""
        self.check_context_handle()
        return self._context_handle.get_global_rank_is_set()

    def set_enable_parallel_optimizer(self, enable_parallel_optimizer):
        """
        Set enable/disable parallel optimizer.

        Args:
            enable_parallel_optimizer (bool): Enable/disable parallel optimizer.

        Raises:
            TypeError: If enable_parallel_optimizer is not a bool.
        """
        self.check_context_handle()
        if not isinstance(enable_parallel_optimizer, bool):
            raise TypeError('enable_parallel_optimizer is invalid type')
        self._context_handle.set_enable_parallel_optimizer(enable_parallel_optimizer)

    def get_enable_parallel_optimizer(self):
        """Get parallel optimizer flag."""
        self.check_context_handle()
        return self._context_handle.get_enable_parallel_optimizer()

    def set_sharding_propagation(self, sharding_propagation):
        """
        Set the value of sharding strategy propagation in AUTO_PARALLEL mode. If True, the strategy-configured operators
        will propagate the strategies to other operators with minimum redistribution cost; otherwise, the algorithm
        will search the desired strategies.
        Default: False.

        Args:
            sharding_propagation (bool): Enable/disable strategy propagation.

        Raises:
            TypeError: If sharding_propagation is not a bool.
        """
        self.check_context_handle()
        if not isinstance(sharding_propagation, bool):
            raise TypeError("'sharding_propagation' is an invalid type.")
        self._context_handle.set_sharding_propagation(sharding_propagation)

    def get_sharding_propagation(self):
        """Get the value of sharding strategy propagation."""
        self.check_context_handle()
        return self._context_handle.get_sharding_propagation()

    def set_enable_alltoall(self, enable_a2a):
        """
        Set the value of enabling AllToAll. If False, AllGather and Split are used to circumvent AllToAll.
        Default: False.

        Args:
            enable_a2a (bool): Enable/disable AllToAll.

        Raises:
            TypeError: If enable_a2a is not a bool.
        """
        self.check_context_handle()
        if not isinstance(enable_a2a, bool):
            raise TypeError("'enable_a2a' is an invalid type.")
        self._context_handle.set_enable_alltoall(enable_a2a)

    def get_enable_alltoall(self):
        """Get the value of enabling AllToAll."""
        self.check_context_handle()
        return self._context_handle.get_enable_alltoall()

    def set_communi_parallel_mode(self, communi_parallel_mode):
        """
        Set communication parallel mode.

        Args:
            communi_parallel_mode (str): The communication parallel mode.

        Raises:
            TypeError: If communi_parallel_mode is not a str.
            ValueError: If parallel mode is not supported.
        """
        if not isinstance(communi_parallel_mode, str):
            raise TypeError(f"The type of communi_parallel_mode must be str, "
                            f"but got {type(communi_parallel_mode)}.")
        self.check_context_handle()
        ret = self._context_handle.set_communi_parallel_mode(communi_parallel_mode)
        if ret is False:
            raise ValueError("Communication parallel mode does not support {}".format(communi_parallel_mode))

    def get_communi_parallel_mode(self):
        """Get communication parallel mode."""
        self.check_context_handle()
        return self._context_handle.get_communi_parallel_mode()

    def set_optimizer_weight_shard_size(self, optimizer_weight_shard_size):
        """
        Set optimizer_weight_shard_size.

        Values <= 1 are ignored with a warning and leave the current setting untouched.

        Args:
            optimizer_weight_shard_size (int): Opt shard group size when not globally use parallel
                                               optimizer across devices.

        Raises:
            TypeError: If optimizer_weight_shard_size is not an int (bool is rejected too).
        """
        self.check_context_handle()
        if not isinstance(optimizer_weight_shard_size, int) or isinstance(optimizer_weight_shard_size, bool):
            raise TypeError(f"The type of optimizer_weight_shard_size must be int, "
                            f"but got {type(optimizer_weight_shard_size)}.")
        if optimizer_weight_shard_size <= 1:
            logger.warning("The setting 'optimizer_weight_shard_size' is invalid. "
                           "Please use the integer larger than 1.")
            return
        self._context_handle.set_optimizer_weight_shard_size(optimizer_weight_shard_size)

    def get_optimizer_weight_shard_size(self):
        """Get optimizer_weight_shard_size."""
        self.check_context_handle()
        return self._context_handle.get_optimizer_weight_shard_size()

    def set_optimizer_weight_shard_aggregated_save(self, optimizer_weight_shard_aggregated_save):
        """
        Set optimizer_weight_shard_aggregated_save.

        Args:
            optimizer_weight_shard_aggregated_save (bool): Whether to save the weight shards in aggregated
                                                           form when parallel optimizer is enabled.

        Raises:
            TypeError: If optimizer_weight_shard_aggregated_save is not a bool.
        """
        self.check_context_handle()
        if not isinstance(optimizer_weight_shard_aggregated_save, bool):
            raise TypeError('optimizer_weight_shard_aggregated_save is invalid type')
        self._context_handle.set_optimizer_weight_shard_aggregated_save(optimizer_weight_shard_aggregated_save)

    def get_optimizer_weight_shard_aggregated_save(self):
        """Get optimizer_weight_shard_aggregated_save."""
        self.check_context_handle()
        return self._context_handle.get_optimizer_weight_shard_aggregated_save()

    def reset(self):
        """Reset all settings, including the Python-side dataset strategy mode."""
        self.check_context_handle()
        self._context_handle.reset()
        # The handle's reset cannot clear this Python-side flag; without restoring it,
        # get_dataset_strategy() would keep returning the handle's tuple strategy after a reset
        # instead of the string form derived from full_batch.
        self._dataset_strategy_using_str = True

    def _check_and_default_group(self, group):
        """
        Validate the given group; if group is empty, return a default fusion group.

        Args:
            group (str): Communication group name, at most _MAX_GROUP_NAME_LEN characters.

        Returns:
            str, the group itself, or the HCCL default group on Ascend / the NCCL default otherwise.

        Raises:
            TypeError: If group is not a python str.
            ValueError: If group is longer than _MAX_GROUP_NAME_LEN.
        """
        if not isinstance(group, str):
            raise TypeError('Group must be a python str')
        if len(group) > _MAX_GROUP_NAME_LEN:
            raise ValueError(f'Group name len is out of range {_MAX_GROUP_NAME_LEN}')

        if group == "":
            if context.get_context("device_target") == "Ascend":
                group = _DEFAULT_HCCL_FUSION_GROUP_NAME
            else:
                group = _DEFAULT_NCCL_FUSION_GROUP_NAME
        return group


_auto_parallel_context = None


def auto_parallel_context():
    """
    Return the module-wide _AutoParallelContext, building it on the first call.

    Returns:
        _AutoParallelContext, the global auto parallel context.
    """
    global _auto_parallel_context
    if _auto_parallel_context is not None:
        return _auto_parallel_context
    _auto_parallel_context = _AutoParallelContext()
    return _auto_parallel_context


# Keyword accepted by _set_auto_parallel_context -> bound setter of the global context.
# The setter attributes are resolved once, when this module is imported.
_set_auto_parallel_context_func_map = {
    keyword: getattr(auto_parallel_context(), setter_name)
    for keyword, setter_name in (
        ("device_num", "set_device_num"),
        ("global_rank", "set_global_rank"),
        ("gradients_mean", "set_gradients_mean"),
        ("gradient_fp32_sync", "set_gradient_fp32_sync"),
        ("loss_repeated_mean", "set_loss_repeated_mean"),
        ("pipeline_stages", "set_pipeline_stages"),
        ("parallel_mode", "set_parallel_mode"),
        ("auto_parallel_search_mode", "set_strategy_search_mode"),
        ("parameter_broadcast", "set_parameter_broadcast"),
        ("strategy_ckpt_load_file", "set_strategy_ckpt_load_file"),
        ("strategy_ckpt_save_file", "set_strategy_ckpt_save_file"),
        ("group_ckpt_save_file", "set_group_ckpt_save_file"),
        ("full_batch", "set_full_batch"),
        ("dataset_strategy", "set_dataset_strategy"),
        ("enable_parallel_optimizer", "set_enable_parallel_optimizer"),
        ("grad_accumulation_step", "set_grad_accumulation_step"),
        ("all_reduce_fusion_config", "set_all_reduce_fusion_split_indices"),
        ("communi_parallel_mode", "set_communi_parallel_mode"),
        ("optimizer_weight_shard_size", "set_optimizer_weight_shard_size"),
        ("optimizer_weight_shard_aggregated_save", "set_optimizer_weight_shard_aggregated_save"),
        ("sharding_propagation", "set_sharding_propagation"),
        ("enable_alltoall", "set_enable_alltoall"),
    )
}


# Keyword accepted by _get_auto_parallel_context -> bound getter of the global context.
# The getter attributes are resolved once, when this module is imported.
_get_auto_parallel_context_func_map = {
    keyword: getattr(auto_parallel_context(), getter_name)
    for keyword, getter_name in (
        ("device_num", "get_device_num"),
        ("global_rank", "get_global_rank"),
        ("gradients_mean", "get_gradients_mean"),
        ("gradient_fp32_sync", "get_gradient_fp32_sync"),
        ("loss_repeated_mean", "get_loss_repeated_mean"),
        ("pipeline_stages", "get_pipeline_stages"),
        ("parallel_mode", "get_parallel_mode"),
        ("auto_parallel_search_mode", "get_strategy_search_mode"),
        ("parameter_broadcast", "get_parameter_broadcast"),
        ("strategy_ckpt_load_file", "get_strategy_ckpt_load_file"),
        ("strategy_ckpt_save_file", "get_strategy_ckpt_save_file"),
        ("full_batch", "get_full_batch"),
        ("dataset_strategy", "get_dataset_strategy"),
        ("enable_parallel_optimizer", "get_enable_parallel_optimizer"),
        ("grad_accumulation_step", "get_grad_accumulation_step"),
        ("all_reduce_fusion_config", "get_all_reduce_fusion_split_indices"),
        ("communi_parallel_mode", "get_communi_parallel_mode"),
        ("optimizer_weight_shard_size", "get_optimizer_weight_shard_size"),
        ("optimizer_weight_shard_aggregated_save", "get_optimizer_weight_shard_aggregated_save"),
        ("sharding_propagation", "get_sharding_propagation"),
        ("enable_alltoall", "get_enable_alltoall"),
    )
}


@args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool,
                 loss_repeated_mean=bool, parallel_mode=str, auto_parallel_search_mode=str,
                 parameter_broadcast=bool, strategy_ckpt_load_file=str,
                 strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool,
                 grad_accumulation_step=int, all_reduce_fusion_config=list, group_ckpt_save_file=str,
                 communi_parallel_mode=str, optimizer_weight_shard_size=int,
                 optimizer_weight_shard_aggregated_save=bool,
                 sharding_propagation=bool, enable_alltoall=bool)
def _set_auto_parallel_context(**kwargs):
    """
    Set auto parallel context.

    Note:
        Attribute name is required for setting attributes.

    Args:
        device_num (int): Available device number, the value must be in [1, 4096]. Default: 1.
        global_rank (int): Global rank id, the value must be in [0, 4095]. Default: 0.
        gradients_mean (bool): Whether to perform mean operator after all-reduce of mirror. Default: False.
        loss_repeated_mean (bool): Whether to perform mean operator in backward in the case of repeated
                        calculations. Default: True.
        gradient_fp32_sync (bool): Gradients allreduce by fp32 even though gradients is fp16 if this flag is True.
                        Default: True.
        parallel_mode (str): There are five kinds of parallel modes, "stand_alone", "data_parallel",
                     "hybrid_parallel", "semi_auto_parallel" and "auto_parallel". Default: "stand_alone".

                     - stand_alone: Only one processor working.

                     - data_parallel: Distributing the data across different processors.

                     - hybrid_parallel: Achieving data parallelism and model parallelism manually.

                     - semi_auto_parallel: Achieving data parallelism and model parallelism by
                       setting parallel strategies.

                     - auto_parallel: Achieving parallelism automatically.
        auto_parallel_search_mode (str): There are two kinds of search modes, "recursive_programming"
                     and "dynamic_programming". Default: "dynamic_programming".

                     - recursive_programming: Recursive programming search mode.

                     - dynamic_programming: Dynamic programming search mode.
        parameter_broadcast (bool): Indicating whether to broadcast parameters before training.
                       "stand_alone", "semi_auto_parallel" and "auto_parallel" do not support parameter
                       broadcast. Default: False.
        strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: ''
        strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: ''
        group_ckpt_save_file (str): The path to save parallel group checkpoint. Default: ''
        full_batch (bool): Whether to load the whole batch on each device. Default: False.
        dataset_strategy (Union[str, tuple]): Dataset sharding strategy, either "full_batch",
                        "data_parallel" or a tuple of int tuples. Default: "data_parallel".
        enable_parallel_optimizer (bool): Enable using optimizer segmentation or not. Default: False.
        grad_accumulation_step (int): Gradient accumulation step; must be a positive integer.
        all_reduce_fusion_config (list): Set allreduce fusion strategy by parameters indices.
        pipeline_stages (int): Set the stage information for pipeline parallel. This indicates how
                        the devices are distributed along the pipeline. The total devices will be divided into
                        'pipeline_stages' stages. This currently could only be used when
                        parallel mode semi_auto_parallel is enabled. Values below 1 are rejected.
                        Default: 0 (NOTE(review): conflicts with the >= 1 check; confirm the effective default).
        communi_parallel_mode (str): There are three kinds of communication parallel modes, "all_group_parallel",
                     "same_server_group_parallel" and "no_group_parallel". Default: "all_group_parallel".

                     - all_group_parallel: All communication groups are in parallel.

                     - same_server_group_parallel: Only the communication groups within the same server are parallel.

                     - no_group_parallel: All communication groups are not parallel.
        optimizer_weight_shard_size (int): Set optimizer shard group size when not fully use parallel optimizer.
                                    It should be larger than one and less than or equal to the data parallel size.
                                    Default: -1, which means fully use parallel optimizer in data parallel dimension.
        optimizer_weight_shard_aggregated_save (bool): Whether to save the weight shards in aggregated form when
                                                       parallel optimizer is enabled. Default: False.
        sharding_propagation (bool): Set the value of sharding strategy propagation in AUTO_PARALLEL mode. If True,
                                    the strategy-configured operators will propagate the strategies to other
                                    operators with minimum redistribution cost; otherwise, the algorithm will
                                    search the desired strategies. Default: False.
        enable_alltoall (bool): Set the value of enabling AllToAll. If False, AllGather and Split are used to
                                circumvent AllToAll. Default: False.

    Raises:
        ValueError: If input key is not attribute in auto parallel context.
    """
    for key, value in kwargs.items():
        # Single lookup; the map only holds bound methods, so None means "unknown keyword".
        set_func = _set_auto_parallel_context_func_map.get(key)
        if set_func is None:
            raise ValueError("Set context keyword %s is not recognized!" % key)
        set_func(value)


def _get_auto_parallel_context(attr_key):
    """
    Read one auto parallel context attribute by its keyword.

    Args:
        attr_key (str): Keyword of the attribute, as listed in _get_auto_parallel_context_func_map.

    Returns:
        Whatever the matching getter of the global context returns.

    Raises:
        ValueError: If attr_key is not a known auto parallel context attribute.
    """
    try:
        getter = _get_auto_parallel_context_func_map[attr_key]
    except KeyError:
        raise ValueError("Get context keyword %s is not recognized!" % attr_key) from None
    return getter()


def _reset_auto_parallel_context():
    """
    Restore the global auto parallel context to its defaults by calling its reset() method.

    Documented defaults of the main attributes:

    - device_num: 1.
    - global_rank: 0.
    - gradients_mean: False.
    - gradient_fp32_sync: True.
    - parallel_mode: "stand_alone".
    - parameter_broadcast: False.
    - strategy_ckpt_load_file: ""
    - strategy_ckpt_save_file: ""
    - enable_parallel_optimizer: False
    - auto_parallel_search_mode: dynamic_programming
    - pipeline_stages: 0 (NOTE(review): set_pipeline_stages rejects values below 1; confirm)
    """
    ctx = auto_parallel_context()
    ctx.reset()