839d0f2d创建于 2024年11月27日历史提交
diff -Nur '--exclude=.git' apex/apex/amp/amp.py apex-develop/apex/amp/amp.py
--- apex/apex/amp/amp.py	2023-04-06 10:36:26.964937605 +0800
+++ apex-develop/apex/amp/amp.py	2024-03-07 21:33:04.293391422 +0800
@@ -65,7 +65,14 @@
 
 
 # Top-level function to insert _all_ the hooks.
-def init(enabled=True, loss_scale="dynamic", enable_caching=True, verbose=False, allow_banned=False):
+def init(
+    enabled=True,
+    loss_scale="dynamic",
+    enable_caching=True,
+    verbose=False,
+    allow_banned=False,
+    user_cast_preferred=None):
+
     global _DECORATOR_HANDLE
 
     if not enabled:
@@ -76,7 +83,10 @@
     handle = AmpHandle(loss_scale, enable_caching, verbose)
 
     # 0) Force-{fp16, fp32} for user-annotated functions
+    _user_cast_registry = set()
     for mod, fn, cast_fn in _USER_CAST_REGISTRY:
+        if user_cast_preferred:
+            _user_cast_registry.add((mod, fn))
         try_caching = (cast_fn == utils.maybe_half)
         wrap.cached_cast(mod, fn, cast_fn, handle,
                          try_caching, verbose)
@@ -96,6 +106,8 @@
     for module, (list_name, cast_fn) in itertools.product(override_modules,
                                                           cast_table):
         for fn in getattr(module, list_name):
+            if user_cast_preferred and (module.MODULE, fn) in _user_cast_registry:
+                continue
             try_caching = (cast_fn == utils.maybe_half)
             wrap.cached_cast(module.MODULE, fn, cast_fn, handle,
                              try_caching, verbose)
diff -Nur '--exclude=.git' apex/apex/amp/_amp_state.py apex-develop/apex/amp/_amp_state.py
--- apex/apex/amp/_amp_state.py	2023-04-06 10:36:26.964937605 +0800
+++ apex-develop/apex/amp/_amp_state.py	2024-03-07 21:33:04.289391423 +0800
@@ -8,10 +8,10 @@
 TORCH_MAJOR = int(torch.__version__.split('.')[0])
 TORCH_MINOR = int(torch.__version__.split('.')[1])
 
-if TORCH_MAJOR == 0:
-    import collections.abc as container_abcs
-else:
+if TORCH_MAJOR == 1 and TORCH_MINOR < 9:
     from torch._six import container_abcs
+else:
+    import collections.abc as container_abcs
 
 
 class AmpState(object):
diff -Nur '--exclude=.git' apex/apex/amp/frontend.py apex-develop/apex/amp/frontend.py
--- apex/apex/amp/frontend.py	2023-04-06 10:36:26.964937605 +0800
+++ apex-develop/apex/amp/frontend.py	2024-03-07 21:33:04.293391422 +0800
@@ -19,6 +19,11 @@
             "keep_batchnorm_fp32" : None,
             "master_weights" : None,
             "loss_scale" : 1.0,
+            "combine_grad": None,
+            "combine_ddp": None,
+            "ddp_replica_count": 4,
+            "check_combined_tensors": None,
+            "user_cast_preferred":None,
             # Reserved for future functionality
             # "fused_optimizer" : False,
             # "enable_ddp_interop" : False,
@@ -91,6 +96,20 @@
                         self.options[name] = value
                     else:
                         self.options[name] = float(value)
+                elif name == "combine_grad" or name == "check_combined_tensors":
+                    if self.opt_level not in ["O1", "O2"] and value:
+                        warn_or_err("Currently, combine_grad=True or check_combined_tensors=True should only be set "
+                                    "by selecting opt_level='O1' or opt_level='O2'.")
+                    self.options[name] = value
+                elif name == "combine_ddp":
+                    if not self.combine_grad:
+                        warn_or_err("Combine_grad should be True when combine_ddp using.. \n")
+                    self.options[name] = value
+                elif name == "user_cast_preferred":
+                    if self.opt_level != "O1" and value:
+                        warn_or_err("Currently, user_cast_preferred=True should only be set by "
+                                    "selecting opt_level='O1'.")
+                    self.options[name] = value
                 else:
                     self.options[name] = value
         else:
@@ -161,6 +180,7 @@
         properties.keep_batchnorm_fp32 = None
         properties.master_weights = None
         properties.loss_scale = "dynamic"
+        properties.combine_grad = None
         # properties.fused_optimizer = False
         # properties.enable_ddp_interop = False
         return properties # modified in place so this isn't really necessary
@@ -205,8 +225,17 @@
     cast_model_outputs=None,
     num_losses=1,
     verbosity=1,
+    dynamic_init_scale=2.**16,
+    scale_growth_factor=2.,
+    scale_backoff_factor=0.5,
+    scale_window=2000,
     min_loss_scale=None,
-    max_loss_scale=2.**24
+    max_loss_scale=2.**24,
+    combine_grad=None,
+    combine_ddp=None,
+    ddp_replica_count=4,
+    user_cast_preferred=None,
+    check_combined_tensors=None
     ):
     """
     Initialize your models, optimizers, and the Torch tensor and functional namespace according to the
@@ -254,11 +283,32 @@
             support multiple losses/backward passes, but use a single global loss scale
             for all of them.
         verbosity (int, default=1):  Set to 0 to suppress Amp-related output.
+        dynamic_init_scale (float, optional, default=2.**16):  Initial dynamic loss scale factor.
+        scale_growth_factor (float, optional, default=2.0):  Factor by which the scale is multiplied
+            if no overflow occurs for ``scale_window`` consecutive iterations.
+            If dynamic loss scaling is not used, `scale_growth_factor` is ignored.
+        scale_backoff_factor (float, optional, default=0.5):  Factor by which the scale is multiplied
+            if overflow occurs in an iteration. If dynamic loss scaling is not used, `scale_backoff_factor` is ignored.
+        scale_window (int, optional, default=2000):  Number of consecutive iterations without overflow
+            that must occur for the scale to be multiplied by ``scale_growth_factor``.
+            If dynamic loss scaling is not used, `scale_window` is ignored.
         min_loss_scale (float, default=None):  Sets a floor for the loss scale values that can be chosen by dynamic
             loss scaling.  The default value of None means that no floor is imposed.
             If dynamic loss scaling is not used, `min_loss_scale` is ignored.
         max_loss_scale (float, default=2.**24):  Sets a ceiling for the loss scale values that can be chosen by
             dynamic loss scaling.  If dynamic loss scaling is not used, `max_loss_scale` is ignored.
+        combine_grad (bool, optional, default=None): If True, make gradients fused for unscale.
+        combine_ddp (bool, optional, default=None): If True, use combined gradients for data exchange,
+            accelerate multi-card training, and functionally replace DistributedDataParallel.
+        ddp_replica_count (bool, optional, default=4): Set the number of replicas of combined gradients.
+            Theoretically, the more replicas, the higher the degree of parallelism, but the time-consuming
+            distribution operation itself will lead to a decrease in performance even though the degree
+            of parallelism is improved. Therefore, we limit and optimize the replica size for data exchange.
+            The final number of replicas is not necessarily exactly the same as the set number
+        user_cast_preferred (bool, optional, default=None): If True in O1, user cast registry is preferred
+            rather than fp16 white- / black-list, to avoid redundant dtype cast.
+        check_combined_tensors (bool, optional, default=None): If True, check if the combined grads and combined params
+            are valid during training
 
     Returns:
         Model(s) and optimizer(s) modified according to the ``opt_level``.
@@ -306,6 +356,7 @@
         https://github.com/NVIDIA/apex/issues
     """
     _amp_state.opt_properties = Properties()
+    # Here add a switch to open combine tensor
     _amp_state.verbosity = verbosity
 
     if not enabled:
@@ -330,6 +381,10 @@
         for k, v in _amp_state.opt_properties.options.items():
             maybe_print("{:22} : {}".format(k, v), True)
 
+    _amp_state.dynamic_init_scale = dynamic_init_scale
+    _amp_state.scale_growth_factor = scale_growth_factor
+    _amp_state.scale_backoff_factor = scale_backoff_factor
+    _amp_state.scale_window = scale_window
     _amp_state.min_loss_scale = min_loss_scale
     _amp_state.max_loss_scale = max_loss_scale
 
@@ -350,6 +405,16 @@
         _amp_state.opt_properties.master_weights = master_weights
     if loss_scale is not None:
         _amp_state.opt_properties.loss_scale = loss_scale
+    if combine_grad is not None:
+        _amp_state.opt_properties.combine_grad = combine_grad
+    if combine_ddp is not None:
+        _amp_state.opt_properties.combine_ddp = combine_ddp
+    if ddp_replica_count is not None:
+        _amp_state.opt_properties.ddp_replica_count = ddp_replica_count
+    if user_cast_preferred is not None:
+        _amp_state.opt_properties.user_cast_preferred = user_cast_preferred
+    if check_combined_tensors is not None:
+        _amp_state.opt_properties.check_combined_tensors = check_combined_tensors
 
     maybe_print("After processing overrides, optimization options are:", True)
     for k, v in _amp_state.opt_properties.options.items():
diff -Nur '--exclude=.git' apex/apex/amp/handle.py apex-develop/apex/amp/handle.py
--- apex/apex/amp/handle.py	2023-04-06 10:36:26.964937605 +0800
+++ apex-develop/apex/amp/handle.py	2024-03-07 21:33:04.293391422 +0800
@@ -1,7 +1,24 @@
+# Copyright (c) 2020, Huawei Technologies.
+# Copyright (c) 2019, NVIDIA CORPORATION.
+# All rights reserved.
+#
+# Licensed under the BSD 3-Clause License  (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import contextlib
 import warnings
 import sys
 import torch
+import torch_npu
 
 from . import utils
 from .opt import OptimWrapper
@@ -110,6 +127,11 @@
                 if not optimizer._amp_stash.params_have_scaled_gradients:
                     optimizer._prepare_amp_backward()
 
+    is_support_inf_nan = hasattr(
+        torch_npu.npu.utils, 'is_support_inf_nan') and torch_npu.npu.utils.is_support_inf_nan()
+    if loss_scaler.dynamic and not is_support_inf_nan:
+        torch_npu.npu.clear_npu_overflow_flag()
+
     yield (loss.float())*loss_scale
 
     if delay_unscale:
@@ -119,6 +141,7 @@
         # FusedSGD may take care of unscaling as part of their step() methods.
         # if not isinstance(optimizers, FP16_Optimizer_for_fused):
             loss_scaler.clear_overflow_state()
+            loss_scaler.check_overflow_and_sync()
             for optimizer in optimizers:
                 optimizer._post_amp_backward(loss_scaler)
                 optimizer._amp_stash.params_have_scaled_gradients = False
@@ -142,8 +165,12 @@
                                 # Maybe skip should delegate to a method owned by the optimizers themselves.
                                 if hasattr(opt._amp_stash, "all_fp32_from_fp16_params"):
                                     # Clear the master grads that wouldn't be zeroed by model.zero_grad()
-                                    for param in opt._amp_stash.all_fp32_from_fp16_params:
-                                        param.grad = None
+                                    if opt.accelerate or opt.is_npu_fused_optimizer:
+                                        if opt._amp_stash.main_fp32_from_fp16_grad_combine is not None:
+                                            opt._amp_stash.main_fp32_from_fp16_grad_combine.zero_()
+                                    else:
+                                        for param in opt._amp_stash.all_fp32_from_fp16_params:
+                                            param.grad = None
                                 if hasattr(opt, "most_recent_scale"):
                                     opt.most_recent_scale = 1.0
                                     opt.scale_set_by_backward = False
diff -Nur '--exclude=.git' apex/apex/amp/_initialize.py apex-develop/apex/amp/_initialize.py
--- apex/apex/amp/_initialize.py	2023-04-06 10:36:26.964937605 +0800
+++ apex-develop/apex/amp/_initialize.py	2024-03-07 21:33:04.289391423 +0800
@@ -1,11 +1,27 @@
-import torch
-from torch._six import string_classes
-import functools
-import numpy as np
+# Copyright (c) 2020, Huawei Technologies.
+# Copyright (c) 2019, NVIDIA CORPORATION.
+# All rights reserved.
+#
+# Licensed under the BSD 3-Clause License  (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import sys
 from types import MethodType
+import functools
+import torch
+import torch.distributed as dist
+import numpy as np
 import warnings
-from ._amp_state import _amp_state, warn_or_err, container_abcs
+from ._amp_state import _amp_state, warn_or_err, container_abcs, maybe_print
 from .handle import disable_casts
 from .scaler import LossScaler
 from ._process_optimizer import _process_optimizer
@@ -18,11 +34,38 @@
     from ..parallel.LARC import LARC
 
 
+def zero_grad(self, set_to_none: bool = False) -> None:
+    r"""Patch for torch.nn.Module.zero_grad. For combined grad or NPU fused optimizers,
+    set_to_none must be False.
+
+    Args:
+        set_to_none (bool): instead of setting to zero, set the grads to None.
+            See :meth:`torch.optim.Optimizer.zero_grad` for details.
+    """
+
+    assert set_to_none is False, "For combined grad, `set_to_none` must be False."
+
+    if getattr(self, '_is_replica', False):
+        warnings.warn(
+            "Calling .zero_grad() from a module created with nn.DataParallel() has no effect. "
+            "The parameters are copied (in a differentiable manner) from the original module. "
+            "This means they are not leaf nodes in autograd and so don't accumulate gradients. "
+            "If you need gradients in your forward method, consider using autograd.grad instead.")
+
+    for p in self.parameters():
+        if p.grad is not None:
+            if p.grad.grad_fn is not None:
+                p.grad.detach_()
+            else:
+                p.grad.requires_grad_(False)
+            p.grad.zero_()
+
+
 def to_type(dtype, t):
     if isinstance(t, torch.Tensor):
-        if not t.is_cuda:
+        if not 'npu' in t.type():
             # This should not be a hard error, since it may be legitimate.
-            warnings.warn("An input tensor was not cuda.")
+            warnings.warn("An input tensor was not npu.")
         # GANs require this.
         # if t.requires_grad:
         #     warn_or_err("input data requires grad.  Since input data is not a model parameter,\n"
@@ -39,7 +82,7 @@
 def applier(value, fn):
     if isinstance(value, torch.Tensor):
         return fn(value)
-    elif isinstance(value, string_classes):
+    elif isinstance(value, str):
         return value
     elif isinstance(value, np.ndarray):
         return value
@@ -81,15 +124,15 @@
         for name, param in model.named_parameters():
             if param.is_floating_point():
                 if 'Half' in param.type():
-                    warn_or_err("Found param {} with type {}, expected torch.cuda.FloatTensor.\n"
+                    warn_or_err("Found param {} with type {}, expected torch.npu.FloatTensor.\n"
                         "When using amp.initialize, you do not need to call .half() on your model\n"
                         "before passing it, no matter what optimization level you choose.".format(
                         name, param.type()))
-                elif not param.is_cuda:
-                    warn_or_err("Found param {} with type {}, expected torch.cuda.FloatTensor.\n"
+                elif not 'npu' in param.type():
+                    warn_or_err("Found param {} with type {}, expected torch.npu.FloatTensor.\n"
                         "When using amp.initialize, you need to provide a model with parameters\n"
-                        "located on a CUDA device before passing it no matter what optimization level\n"
-                        "you chose. Use model.to('cuda') to use the default device.".format(
+                        "located on a Npu device before passing it no matter what optimization level\n"
+                        "you chose. Use model.to('npu') to use the default device.".format(
                         name, param.type()))
 
         # Backward compatibility for PyTorch 0.4
@@ -104,15 +147,15 @@
                 name, buf = obj, buf_iter[obj]
             if buf.is_floating_point():
                 if 'Half' in buf.type():
-                    warn_or_err("Found buffer {} with type {}, expected torch.cuda.FloatTensor.\n"
+                    warn_or_err("Found buffer {} with type {}, expected torch.npu.FloatTensor.\n"
                         "When using amp.initialize, you do not need to call .half() on your model\n"
                         "before passing it, no matter what optimization level you choose.".format(
                         name, buf.type()))
-                elif not buf.is_cuda:
-                    warn_or_err("Found buffer {} with type {}, expected torch.cuda.FloatTensor.\n"
+                elif not 'npu' in buf.type():
+                    warn_or_err("Found buffer {} with type {}, expected torch.npu.FloatTensor.\n"
                         "When using amp.initialize, you need to provide a model with buffers\n"
-                        "located on a CUDA device before passing it no matter what optimization level\n"
-                        "you chose. Use model.to('cuda') to use the default device.".format(
+                        "located on a Npu device before passing it no matter what optimization level\n"
+                        "you chose. Use model.to('npu') to use the default device.".format(
                         name, buf.type()))
 
 
@@ -227,12 +270,18 @@
     _amp_state.loss_scalers = []
     for _ in range(num_losses):
         _amp_state.loss_scalers.append(LossScaler(properties.loss_scale,
+                                                  init_scale=_amp_state.dynamic_init_scale,
+                                                  scale_growth_factor=_amp_state.scale_growth_factor,
+                                                  scale_backoff_factor=_amp_state.scale_backoff_factor,
+                                                  scale_window=_amp_state.scale_window,
                                                   min_loss_scale=_amp_state.min_loss_scale,
                                                   max_loss_scale=_amp_state.max_loss_scale))
 
     if properties.patch_torch_functions:
         # handle is unused here. It's accessible later through a global value anyway.
-        handle = amp_init(loss_scale=properties.loss_scale, verbose=(_amp_state.verbosity == 2))
+        handle = amp_init(loss_scale=properties.loss_scale,
+                          verbose=(_amp_state.verbosity == 2),
+                          user_cast_preferred=properties.user_cast_preferred)
         for optimizer in optimizers:
             # Disable Amp casting for the optimizer step, because it should only be
             # applied to FP32 master params anyway.
@@ -245,6 +294,24 @@
 
             optimizer.step = MethodType(patch_step(optimizer.step), optimizer)
 
+
+    is_npu_fused_optimizer = False
+    for optimizer in optimizers:
+        if hasattr(optimizer, 'is_npu_fused_optimizer') and optimizer.is_npu_fused_optimizer:
+            is_npu_fused_optimizer = True
+            break
+    if properties.combine_grad or is_npu_fused_optimizer:
+        torch.nn.Module.zero_grad = zero_grad
+        maybe_print(
+            "Warning: "
+            "Default value of `set_to_none` in torch.nn.Module.zero_grad() is set as False for combine grad, "
+            "which is True since torch 2.0.")
+
+    if properties.combine_ddp:
+        for model in models:
+            for name, param in model.named_parameters():
+                dist.broadcast(param, 0)
+
     if optimizers_was_list:
         if models_was_list:
             return models, optimizers
diff -Nur '--exclude=.git' apex/apex/amp/_process_optimizer.py apex-develop/apex/amp/_process_optimizer.py
--- apex/apex/amp/_process_optimizer.py	2023-04-06 10:36:26.964937605 +0800
+++ apex-develop/apex/amp/_process_optimizer.py	2024-03-07 21:33:04.289391423 +0800
@@ -1,9 +1,89 @@
+# Copyright (c) 2020, Huawei Technologies.
+# Copyright (c) 2019, NVIDIA CORPORATION.
+# All rights reserved.
+#
+# Licensed under the BSD 3-Clause License  (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import types
+import torch
+import torch_npu
+from change_data_ptr import change_data_ptr
+import torch.distributed as dist
+from ._amp_state import maybe_print
 from ..fp16_utils import master_params_to_model_params
 from ..multi_tensor_apply import multi_tensor_applier
-from ._amp_state import maybe_print
-import torch
 from ..optimizers import FusedSGD
+from ..contrib.combine_tensors import (
+    combine_npu,
+    get_part_combined_tensor,
+    is_combined_tensor_valid,
+    get_aligned_storage_size
+)
+
+TORCH_MAJOR = int(torch.__version__.split('.')[0])
+
+if TORCH_MAJOR == 1:
+    from torch._six import inf
+else:
+    from torch import inf
+
+
+def get_grad_combined_tensor_from_param(list_of_params):
+    if len(list_of_params) > 0 and list_of_params[0].grad is not None:
+        list_of_grad = []
+        for param in list_of_params:
+            if param.requires_grad:
+                list_of_grad.append(param.grad)
+        original_combined_tensor = combine_npu(list_of_grad)
+        return original_combined_tensor, list_of_grad
+    else:
+        return None, []
+
+
+def get_grad_combined_tensor_mask_from_param(list_of_params):
+    if len(list_of_params) > 0 and list_of_params[0].grad is not None:
+        list_of_grad_mask = []
+        for param in list_of_params:
+            if param.requires_grad:
+                grad_size = param.grad.size()
+                grad_format = torch_npu.get_npu_format(param)
+                list_of_grad_mask.append(torch_npu.npu_format_cast(torch.ones(grad_size).npu(), grad_format))
+        grad_combined_tensor_mask = combine_npu(list_of_grad_mask)
+        return grad_combined_tensor_mask
+    else:
+        return None
+
+
+def clip_grad_norm_fused(combined_grads, combined_grad_masks, max_norm, norm_type):
+    max_norm = float(max_norm)
+    norm_type = float(norm_type)
+    tmp_lst = []
+    if norm_type == inf:
+        for combined_grad, combined_grad_mask in zip(combined_grads, combined_grad_masks):
+            if combined_grad is not None:
+                tmp_lst.append(combined_grad.float().abs().mul_(combined_grad_mask).max())
+        total_norm = max(tmp_lst)
+    else:
+        for combined_grad, combined_grad_mask in zip(combined_grads, combined_grad_masks):
+            if combined_grad is not None:
+                tmp_lst.append(combined_grad.float().abs().pow(norm_type).mul_(combined_grad_mask).sum())
+        total_norm = torch.stack(tmp_lst).sum().pow(1/norm_type)
+    clip_coef = max_norm / (total_norm + 1e-6)
+    if clip_coef < 1:
+        for combined_grad in combined_grads:
+            if combined_grad is not None:
+                combined_grad.mul_(clip_coef)
+    return total_norm
 
 
 class AmpOptimizerState(object):
@@ -26,96 +106,117 @@
 
 
 def lazy_init_with_master_weights(self):
-        stash = self._amp_stash
-        stash.fp16_groups = []
-        stash.fp32_from_fp16_groups = []
-        stash.fp32_from_fp32_groups = []
-        for i, param_group in enumerate(self.param_groups):
-            # maybe_print("FP16_Optimizer processing param group {}:".format(i))
-            fp16_params_this_group = []
-            fp32_params_this_group = []
-            fp32_from_fp16_params_this_group = []
-            for i, param in enumerate(param_group['params']):
-                if param.requires_grad:
-                    if param.type() == 'torch.cuda.HalfTensor':
-                        # maybe_print("FP16_Optimizer received torch.cuda.HalfTensor with {}"
-                        #             .format(param.size()))
-                        fp16_params_this_group.append(param)
-                        master_param = param.detach().clone().float()
-                        master_param.requires_grad = True
-                        param_group['params'][i] = master_param
-                        fp32_from_fp16_params_this_group.append(master_param)
-                        # Reset existing state dict key to the new master param.
-                        # We still need to recast per-param state tensors, if any, to FP32.
-                        if param in self.state:
-                           self.state[master_param] = self.state.pop(param)
-                    elif param.type() == 'torch.cuda.FloatTensor':
-                        # maybe_print("FP16_Optimizer received torch.cuda.FloatTensor with {}"
-                        #             .format(param.size()))
-                        fp32_params_this_group.append(param)
-                        param_group['params'][i] = param
-                    else:
-                        raise TypeError("Optimizer's parameters must be either "
-                                        "torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
-                                        "Received {}".format(param.type()))
-
-            stash.fp16_groups.append(fp16_params_this_group)
-            stash.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group)
-            stash.fp32_from_fp32_groups.append(fp32_params_this_group)
+    stash = self._amp_stash
+    stash.fp16_groups = []
+    stash.fp32_from_fp16_groups = []
+    stash.fp32_from_fp32_groups = []
+    for i, param_group in enumerate(self.param_groups):
+        # maybe_print("FP16_Optimizer processing param group {}:".format(i))
+        fp16_params_this_group = []
+        fp32_params_this_group = []
+        fp32_from_fp16_params_this_group = []
+        for i, param in enumerate(param_group['params']):
+            if param.requires_grad:
+                if param.type() == 'torch.npu.HalfTensor':
+                    # maybe_print("FP16_Optimizer received torch.cuda.HalfTensor with {}"
+                    #             .format(param.size()))
+                    fp16_params_this_group.append(param)
+                    master_param = param.detach().clone().float()
+                    master_param.requires_grad = True
+                    param_group['params'][i] = master_param
+                    fp32_from_fp16_params_this_group.append(master_param)
+                    # Reset existing state dict key to the new master param.
+                    # We still need to recast per-param state tensors, if any, to FP32.
+                    if param in self.state:
+                        self.state[master_param] = self.state.pop(param)
+                elif param.type() == 'torch.npu.FloatTensor':
+                    # maybe_print("FP16_Optimizer received torch.cuda.FloatTensor with {}"
+                    #             .format(param.size()))
+                    fp32_params_this_group.append(param)
+                    param_group['params'][i] = param
+                else:
+                    raise TypeError("Optimizer's parameters must be either "
+                                    "torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
+                                    "Received {}".format(param.type()))
+
+        stash.fp16_groups.append(fp16_params_this_group)
+        stash.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group)
+        stash.fp32_from_fp32_groups.append(fp32_params_this_group)
 
-        stash.all_fp16_params = []
-        for group in stash.fp16_groups:
-            stash.all_fp16_params += group
-
-        stash.all_fp32_from_fp16_params = []
-        for group in stash.fp32_from_fp16_groups:
-            stash.all_fp32_from_fp16_params += group
-
-        stash.all_fp32_from_fp32_params = []
-        for group in stash.fp32_from_fp32_groups:
-            stash.all_fp32_from_fp32_params += group
-
-        # all_fp16_grad_stash is only needed for fused optimizers.
-        stash.all_fp16_grad_stash = [None for _ in stash.all_fp16_params]
-        # stash.all_fp32_from_fp16_grad_stash = [None for _ in stash.all_fp32_from_fp16_params]
-        stash.all_fp32_from_fp32_grad_stash = [None for _ in stash.all_fp32_from_fp32_params]
+    stash.all_fp16_params = []
+    for group in stash.fp16_groups:
+        stash.all_fp16_params += group
 
-        for param in stash.all_fp32_from_fp16_params:
-            param.grad = None
+    stash.all_fp32_from_fp16_params = []
+    for group in stash.fp32_from_fp16_groups:
+        stash.all_fp32_from_fp16_params += group
+
+    stash.all_fp32_from_fp32_params = []
+    for group in stash.fp32_from_fp32_groups:
+        stash.all_fp32_from_fp32_params += group
 
-        for param in stash.all_fp32_from_fp32_params:
-            param.grad = None
+    # all_fp16_grad_stash is only needed for fused optimizers.
+    stash.all_fp16_grad_stash = [None for _ in stash.all_fp16_params]
+    # stash.all_fp32_from_fp16_grad_stash = [None for _ in stash.all_fp32_from_fp16_params]
+    stash.all_fp32_from_fp32_grad_stash = [None for _ in stash.all_fp32_from_fp32_params]
 
-        # Leverage state_dict() and load_state_dict() to recast preexisting per-param state tensors
-        self.load_state_dict(self.state_dict())
+    for param in stash.all_fp32_from_fp16_params:
+        param.grad = None
 
+    for param in stash.all_fp32_from_fp32_params:
+        param.grad = None
+    
+    stash.main_fp16_grad_combine = None
+    stash.main_fp32_from_fp16_grad_combine = None
+    stash.main_fp32_from_fp32_grad_combine = None
+    stash.main_fp16_grad_combine_mask = None
+    stash.main_fp32_from_fp16_grad_combine_mask = None
+    stash.main_fp32_from_fp32_grad_combine_mask = None
+
+    stash.all_fp32_from_fp32_grad_stash_combine = None
+
+    stash.main_fp16_param_combine = None
+    stash.main_fp32_from_fp16_param_combine = None
+    stash.main_fp32_from_fp32_param_combine = None
+    # Leverage state_dict() and load_state_dict() to recast preexisting per-param state tensors
+    self.load_state_dict(self.state_dict())
+
+
+def post_backward_models_are_masters(scaler, params, stashed_grads, scale_override=None, 
+                                     main_grads_combined=None, stashed_grads_combined=None, 
+                                     use_npu_fused_optimizer=False, stashed_grads_are_zero=False, main_grads_list=None):
+    grads_have_scale, stashed_have_scale, out_scale = scaler.loss_scale(), 1.0, 1.0
 
-def post_backward_models_are_masters(scaler, params, stashed_grads, scale_override=None):
-        grads_have_scale, stashed_have_scale, out_scale = scaler.loss_scale(), 1.0, 1.0
+    # not much to do if scale == 1.0 and static scaling
+    if scaler.loss_scale() == 1.0 and not scaler.dynamic:
+        # Clear the stash.
+        for i in range(len(stashed_grads)):
+            stashed_grads[i] = None
+        return
 
-        # not much to do if scale == 1.0 and static scaling
-        if scaler.loss_scale() == 1.0 and not scaler.dynamic:
-            # Clear the stash.
-            for i in range(len(stashed_grads)):
-                stashed_grads[i] = None
-            return
-        
-        if scale_override is not None:
-            grads_have_scale, stashed_have_scale, out_scale = scale_override
+    if scale_override is not None:
+        grads_have_scale, stashed_have_scale, out_scale = scale_override
 
-        # This is a lot of python overhead...
+    # This is a lot of python overhead...
+    if main_grads_combined is not None:
+        scaler.unscale_with_stashed_combined(
+            main_grads_combined, 
+            stashed_grads_combined if not stashed_grads_are_zero else None,
+            scale_override=(grads_have_scale, stashed_have_scale, out_scale),
+            grads_list=main_grads_list)
+    else:
         grads_needing_unscale = []
         grads_needing_unscale_with_stash = []
         stashed = []
         for param, stashed_grad in zip(params, stashed_grads):
             if param.grad is None and stashed_grad is not None:
                 param.grad = stashed_grad
-            elif param.grad is not None and stashed_grad is None:
+            elif param.grad is not None and (stashed_grad is None or stashed_grads_are_zero):
                 grads_needing_unscale.append(param.grad)
             elif param.grad is not None and stashed_grad is not None:
                 grads_needing_unscale_with_stash.append(param.grad)
                 stashed.append(stashed_grad)
-            else: # param.grad is None and stashed_grad is None
+            else:  # param.grad is None and stashed_grad is None
                 continue
 
         # unscale() implements grads*(1/scale), so "scale" should be grads_have_scale/out_scale.
@@ -123,130 +224,358 @@
             scaler.unscale(
                 grads_needing_unscale,
                 grads_needing_unscale,
-                None, # unused_scale, currently present to avoid API breakage elsewhere
+                None,  # unused_scale, currently present to avoid API breakage elsewhere
                 models_are_masters=True,
-                scale_override=grads_have_scale/out_scale)
+                scale_override=grads_have_scale / out_scale)
 
         if len(grads_needing_unscale_with_stash) > 0:
             scaler.unscale_with_stashed(
                 grads_needing_unscale_with_stash,
                 stashed,
                 grads_needing_unscale_with_stash,
-                scale_override=(grads_have_scale, stashed_have_scale, out_scale))
+                scale_override=(grads_have_scale, stashed_have_scale, out_scale),
+                use_npu_fused_optimizer=use_npu_fused_optimizer)
 
-        # Clear the stash.
-        for i in range(len(stashed_grads)):
-            stashed_grads[i] = None
+        if not use_npu_fused_optimizer:
+            # Clear the stash.
+            for i in range(len(stashed_grads)):
+                stashed_grads[i] = None
 
 
 def prepare_backward_with_master_weights(self):
     stash = self._amp_stash
 
     self._amp_lazy_init()
+    self._check_already_combined_params_and_grads()
 
-    for i, param in enumerate(stash.all_fp16_params):
-        # Set up to leverage grad copy elision.
-        # This may behave differently from an unpatched optimizer if zero_grad is used and the param is unused.
-        param.grad = None
+    if (self.accelerate or self.is_npu_fused_optimizer) and stash.already_combined:
+        if stash.process_zero_grad:
+            return
 
-    # for i, param in enumerate(stash.all_fp32_from_fp16_params):
-    #     stash.all_fp32_from_fp16_grad_stash[i] = param.grad
+        if stash.main_fp16_grad_combine is not None:
+            stash.main_fp16_grad_combine.zero_()
 
-    for i, param in enumerate(stash.all_fp32_from_fp32_params):
-        stash.all_fp32_from_fp32_grad_stash[i] = param.grad
-        # Set up to leverage grad copy elision:
-        param.grad = None
+        if stash.main_fp32_from_fp32_grad_combine is not None:
+            stash.all_fp32_from_fp32_grad_stash_combine.copy_(stash.main_fp32_from_fp32_grad_combine)
+            stash.main_fp32_from_fp32_grad_combine.zero_()
+    else:
+        for i, param in enumerate(stash.all_fp16_params):
+            # Set up to leverage grad copy elision.
+            # This may behave differently from an unpatched optimizer if zero_grad is used and the param is unused.
+            param.grad = None
 
+        for i, param in enumerate(stash.all_fp32_from_fp32_params):
+            stash.all_fp32_from_fp32_grad_stash[i] = param.grad
+            # Set up to leverage grad copy elision:
+            param.grad = None
 
-def post_backward_with_master_weights(self, scaler):
+
+def combine_ddp_hook_func(name, param, target_grads_size_list, current_param_size_list,
+              name_dict, reduce_stream, partial_combined_grad_list, ready_reduce_index, world_size):
+    def hook_function(grad):
+        if ready_reduce_index:
+            index = ready_reduce_index.pop()
+            current_param_size_list[index] = 0
+            partial_combined_grad_list[index].div_(world_size)
+            reduce_stream.wait_stream(torch.npu.current_stream())
+            with torch.npu.stream(reduce_stream):
+                dist.all_reduce(partial_combined_grad_list[index])
+
+        current_param_size_list[name_dict[name]] += get_aligned_storage_size(param)
+        for i, _ in enumerate(current_param_size_list):
+            if current_param_size_list[i] == target_grads_size_list[i] and current_param_size_list[i] != 0:
+                ready_reduce_index.append(i)
+                break
+    return hook_function
+
+
+def init_combine_ddp_no_master_weights(self):
     stash = self._amp_stash
+    combined_grads_list = [stash.main_fp32_grad_combine]
+    params_list = [stash.all_fp32_params]
 
-    self._amp_lazy_init()
+    return self._init_combine_ddp_common(combined_grads_list, params_list)
 
-    # This is a lot of python overhead...
-    fp16_grads_needing_unscale = []
-    new_fp32_grads = []
-    fp16_grads_needing_unscale_with_stash = []
-    preexisting_fp32_grads = []
-    for fp16_param, fp32_param in zip(stash.all_fp16_params,
-                                      stash.all_fp32_from_fp16_params):
-        if fp16_param.grad is None and fp32_param.grad is not None:
+
+def init_combine_ddp_with_master_weights(self):
+    stash = self._amp_stash
+    combined_grads_list = [stash.main_fp16_grad_combine, stash.main_fp32_from_fp32_grad_combine]
+    params_list = [stash.all_fp16_params, stash.all_fp32_from_fp32_params]
+
+    return self._init_combine_ddp_common(combined_grads_list, params_list)
+
+
+def init_combine_ddp_common(self, combined_grads_list, params_list):
+    exchange_threshold_max = 24 * 1024 * 1024
+    exchange_threshold_min = 1 * 1024 * 1024
+    ddp_replica_count = self.ddp_replica_count
+    world_size = dist.get_world_size()
+    all_reduce_stream = torch.npu.Stream()
+    exchange_threshold_list = [0 for _ in combined_grads_list]
+    target_grads_size_lists = [[] for _ in combined_grads_list]
+    name_dict_list = [{} for _ in combined_grads_list]
+    partial_combined_grad_lists = [[] for _ in combined_grads_list]
+
+    for idx, combined_grads in enumerate(combined_grads_list):
+        if combined_grads is None:
             continue
-        elif fp16_param.grad is not None and fp32_param.grad is None:
-            fp32_param.grad = torch.empty_like(fp32_param)
-            fp16_grads_needing_unscale.append(fp16_param.grad)
-            new_fp32_grads.append(fp32_param.grad)
-        elif fp16_param.grad is not None and fp32_param.grad is not None:
-            fp16_grads_needing_unscale_with_stash.append(fp16_param.grad)
-            preexisting_fp32_grads.append(fp32_param.grad)
-        else: # fp16_param.grad is None and fp32_param.grad is None:
+
+        if combined_grads.dim() == 1:
+            combined_grads_len = combined_grads.shape[0]
+            tmp_combined_grads = torch.tensor(combined_grads_len, dtype=torch.float32, device=combined_grads.device)
+            gather_list = [torch.zeros(1, dtype=torch.float32).npu() for _ in range(world_size)]
+            dist.all_gather(gather_list, tmp_combined_grads)
+
+            for i in range(1, world_size):
+                if gather_list[0] != gather_list[i]:
+                    raise RuntimeError("When using combine_ddp, "
+                                       "combine_grad does not support inconsistent parameters in each rank. "
+                                       "Please consider using the consistent parameters of each rank instead.")
+
+        tmp_combined_grads_len = combined_grads.shape[0] // ddp_replica_count
+        exchange_threshold_list[idx] = min(tmp_combined_grads_len, exchange_threshold_max \
+            if combined_grads.type() == 'torch.npu.FloatTensor' else exchange_threshold_max * 2)
+        exchange_threshold_list[idx] = max(exchange_threshold_list[idx], exchange_threshold_min)
+        dist.all_reduce(combined_grads.div_(world_size))
+
+    for idx, params in enumerate(params_list):
+        target_grads_size_list = target_grads_size_lists[idx]
+        name_dict = name_dict_list[idx]
+        tmp_size = 0
+        name_order = 0
+        for param_idx, param in enumerate(params):
+            name = '%d_%d'%(idx, param_idx)
+            cur_size = get_aligned_storage_size(param)
+            if cur_size > exchange_threshold_list[idx] and tmp_size != 0:
+                target_grads_size_list.append(tmp_size)
+                tmp_size = 0
+                name_order += 1
+            tmp_size += cur_size
+            name_dict[name] = name_order
+            if tmp_size > exchange_threshold_list[idx]:
+                target_grads_size_list.append(tmp_size)
+                tmp_size = 0
+                name_order += 1
+        if tmp_size != 0:
+            target_grads_size_list.append(tmp_size)
+    maybe_print('Optimized combine_ddp replicas: {}'.format(target_grads_size_lists), rank0=True)
+
+    for idx, target_grads_size_list in enumerate(target_grads_size_lists):
+        combined_grads = combined_grads_list[idx]
+        if combined_grads is None:
             continue
 
-    if len(fp16_grads_needing_unscale) > 0:
-        scaler.unscale(
-            fp16_grads_needing_unscale,
-            new_fp32_grads,
-            scaler.loss_scale(),
-            models_are_masters=False)
-
-    if len(fp16_grads_needing_unscale_with_stash) > 0:
-        scaler.unscale_with_stashed(
-            fp16_grads_needing_unscale_with_stash,
-            preexisting_fp32_grads,
-            preexisting_fp32_grads)
-
-    # fp32 params can be treated as they would be in the "no_master_weights" case.
-    post_backward_models_are_masters(
-        scaler,
-        stash.all_fp32_from_fp32_params,
-        stash.all_fp32_from_fp32_grad_stash)
+        ptr_index = 0
+        partial_combined_grad_list = partial_combined_grad_lists[idx]
+        for target_grads_size in target_grads_size_list:
+            partial_combined_grad_list.append(get_part_combined_tensor(combined_grads, ptr_index, target_grads_size))
+            ptr_index += target_grads_size
+
+    current_param_size_lists = [[0] * len(target_grads_size_list) for target_grads_size_list in
+                               target_grads_size_lists]
+    ready_reduce_index_list = [[] for _ in combined_grads_list]
+
+    for idx, params in enumerate(params_list):
+        for param_idx, param in enumerate(params):
+            name = '%d_%d'%(idx, param_idx)
+            param.register_hook(
+                combine_ddp_hook_func(name, param, target_grads_size_lists[idx], current_param_size_lists[idx],
+                          name_dict_list[idx], all_reduce_stream, partial_combined_grad_lists[idx],
+                          ready_reduce_index_list[idx], world_size))
+
+    self.ready_reduce_index_list = ready_reduce_index_list
+    self.partial_combined_grad_lists = partial_combined_grad_lists
+    self.current_param_size_lists = current_param_size_lists
+    self.all_reduce_stream = all_reduce_stream
+    self.world_size = world_size
+
+
+def combine_ddp_all_reduce(self):
+    last_reduce_grad_list = []
+    for idx, partial_combined_grad_list in enumerate(self.partial_combined_grad_lists):
+        if partial_combined_grad_list:
+            last_reduce_grad = partial_combined_grad_list[self.ready_reduce_index_list[idx][0]]
+            last_reduce_grad.div_(self.world_size)
+            last_reduce_grad_list.append(last_reduce_grad)
+
+    torch.npu.current_stream().wait_stream(self.all_reduce_stream)
+    for idx, last_reduce_grad in enumerate(last_reduce_grad_list):
+        dist.all_reduce(last_reduce_grad)
+        self.current_param_size_lists[idx][self.ready_reduce_index_list[idx][0]] = 0
+        self.ready_reduce_index_list[idx].pop()
+
+def combine_ddp_proc(self):
+    if self.combine_ddp:
+        if not self.init_combine_ddp:
+            self._init_combine_ddp()
+            self.init_combine_ddp = True
+        else:
+            self._combine_ddp_all_reduce()
+
+def post_backward_with_master_weights(self, scaler):
+    stash = self._amp_stash
+
+    self._amp_lazy_init()
+    self._check_already_combined_params_and_grads()
+    self._amp_combined_init()
+    self._combine_ddp_proc()
+
+    if self.accelerate:
+        scaler.unscale_grad_O2(
+            model_grads_combined=stash.main_fp16_grad_combine,
+            stashed_master_grads_combined=stash.main_fp32_from_fp16_grad_combine if not stash.process_zero_grad else None,
+            master_grads_combined=stash.main_fp32_from_fp16_grad_combine,
+            master_grads=stash.fp32_from_fp16_grad_list,
+            model_grads=stash.fp16_grad_list)
+        if stash.main_fp32_from_fp32_grad_combine is not None:
+            scaler.unscale_grad_O2(
+                model_grads_combined=stash.main_fp32_from_fp32_grad_combine,
+                stashed_master_grads_combined=stash.all_fp32_from_fp32_grad_stash_combine if not stash.process_zero_grad else None,
+                master_grads_combined=stash.main_fp32_from_fp32_grad_combine,
+                model_grads=stash.fp32_from_fp32_grad_list)
+    else:
+        # This is a lot of python overhead...
+        fp16_grads_needing_unscale = []
+        new_fp32_grads = []
+        fp16_grads_needing_unscale_with_stash = []
+        preexisting_fp32_grads = []
+        for fp16_param, fp32_param in zip(stash.all_fp16_params,
+                                          stash.all_fp32_from_fp16_params):
+            if fp16_param.grad is None and fp32_param.grad is not None:
+                continue
+            elif fp16_param.grad is not None and fp32_param.grad is None:
+                fp32_param.grad = torch.empty_like(fp32_param)
+                fp16_grads_needing_unscale.append(fp16_param.grad)
+                new_fp32_grads.append(fp32_param.grad)
+            elif fp16_param.grad is not None and fp32_param.grad is not None:
+                if stash.process_zero_grad:
+                    fp16_grads_needing_unscale.append(fp16_param.grad)
+                    new_fp32_grads.append(fp32_param.grad)
+                else:
+                    fp16_grads_needing_unscale_with_stash.append(fp16_param.grad)
+                    preexisting_fp32_grads.append(fp32_param.grad)
+            else: # fp16_param.grad is None and fp32_param.grad is None:
+                continue
+
+        if len(fp16_grads_needing_unscale) > 0:
+            scaler.unscale(
+                fp16_grads_needing_unscale,
+                new_fp32_grads,
+                scaler.loss_scale(),
+                models_are_masters=False)
+
+        if len(fp16_grads_needing_unscale_with_stash) > 0:
+            scaler.unscale_with_stashed(
+                fp16_grads_needing_unscale_with_stash,
+                preexisting_fp32_grads,
+                preexisting_fp32_grads,
+                use_npu_fused_optimizer=self.is_npu_fused_optimizer)
+
+        # fp32 params can be treated as they would be in the "no_master_weights" case.
+        post_backward_models_are_masters(
+            scaler,
+            stash.all_fp32_from_fp32_params,
+            stash.all_fp32_from_fp32_grad_stash,
+            use_npu_fused_optimizer=self.is_npu_fused_optimizer,
+            stashed_grads_are_zero=stash.process_zero_grad)
+    
+    stash.process_zero_grad = False
 
 
 def lazy_init_no_master_weights(self):
     stash = self._amp_stash
     stash.all_fp16_params = []
     stash.all_fp32_params = []
+
+    check_param_require_grad = self.accelerate or self.is_npu_fused_optimizer
+
     for i, param_group in enumerate(self.param_groups):
         for i, param in enumerate(param_group['params']):
-            if param.type() == 'torch.cuda.HalfTensor':
+            if check_param_require_grad and not param.requires_grad:
+                continue
+
+            if param.type() == 'torch.npu.HalfTensor':
                 stash.all_fp16_params.append(param)
-            elif param.type() == 'torch.cuda.FloatTensor':
+            elif param.type() == 'torch.npu.FloatTensor':
                 stash.all_fp32_params.append(param)
             else:
                 raise TypeError("Optimizer's parameters must be either "
-                                "torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
+                                "torch.npu.FloatTensor or torch.npu.HalfTensor."
                                 "Received {}".format(param.type()))
 
     stash.all_fp16_grad_stash = [None for _ in stash.all_fp16_params]
     stash.all_fp32_grad_stash = [None for _ in stash.all_fp32_params]
 
+    stash.all_fp16_grad_stash_combine = None
+    stash.all_fp32_grad_stash_combine = None
+
+    stash.fp16_grad_list = []
+    stash.main_fp16_grad_combine = None
+    stash.main_fp16_grad_combine_mask = None
+
+    stash.fp32_grad_list = []
+    stash.main_fp32_grad_combine = None
+    stash.main_fp32_grad_combine_mask = None
+
+    stash.main_fp16_param_combine = None
+    stash.main_fp32_param_combine = None
+
 
 def prepare_backward_no_master_weights(self):
     stash = self._amp_stash
 
     self._amp_lazy_init()
+    self._check_already_combined_params_and_grads()
 
-    for i, param in enumerate(stash.all_fp16_params):
-        stash.all_fp16_grad_stash[i] = param.grad
-        # Set up to leverage grad copy elision:
-        param.grad = None
+    if (self.accelerate or self.is_npu_fused_optimizer) and stash.already_combined:
+        if stash.process_zero_grad:
+            return
 
-    for i, param in enumerate(stash.all_fp32_params):
-        stash.all_fp32_grad_stash[i] = param.grad
-        # Set up to leverage grad copy elision:
-        param.grad = None
+        if stash.main_fp16_grad_combine is not None:
+            stash.all_fp16_grad_stash_combine.copy_(stash.main_fp16_grad_combine)
+            stash.main_fp16_grad_combine.zero_()
+        if stash.main_fp32_grad_combine is not None:
+            stash.all_fp32_grad_stash_combine.copy_(stash.main_fp32_grad_combine)
+            stash.main_fp32_grad_combine.zero_()
+    else:
+        for i, param in enumerate(stash.all_fp16_params):
+            stash.all_fp16_grad_stash[i] = param.grad
+            # Set up to leverage grad copy elision:
+            param.grad = None
+
+        for i, param in enumerate(stash.all_fp32_params):
+            stash.all_fp32_grad_stash[i] = param.grad
+            # Set up to leverage grad copy elision:
+            param.grad = None
 
 
 def post_backward_no_master_weights(self, scaler):
     stash = self._amp_stash
 
     self._amp_lazy_init()
+    self._check_already_combined_params_and_grads()
+    self._amp_combined_init()
+    self._combine_ddp_proc()
+
+    if self.accelerate:
+        split_types = ((stash.main_fp16_grad_combine, stash.all_fp16_grad_stash_combine, stash.fp16_grad_list),
+                (stash.main_fp32_grad_combine, stash.all_fp32_grad_stash_combine, stash.fp32_grad_list))
+        for main_grads_combined, stash_grads_combined, main_grads_list in split_types:
+            if main_grads_combined is not None:
+                post_backward_models_are_masters(scaler, None, None, None, 
+                                                 main_grads_combined, stash_grads_combined,
+                                                 use_npu_fused_optimizer=self.is_npu_fused_optimizer,
+                                                 stashed_grads_are_zero=stash.process_zero_grad,
+                                                 main_grads_list=main_grads_list)
+    else:
+        split_types = ((stash.all_fp16_params, stash.all_fp16_grad_stash),
+                 (stash.all_fp32_params, stash.all_fp32_grad_stash))
 
-    split_types = ((stash.all_fp16_params, stash.all_fp16_grad_stash),
-             (stash.all_fp32_params, stash.all_fp32_grad_stash))
-
-    for params, stashed_grads in split_types:
-        post_backward_models_are_masters(scaler, params, stashed_grads)
+        for params, stashed_grads in split_types:
+            post_backward_models_are_masters(scaler, params, stashed_grads, 
+                                             use_npu_fused_optimizer=self.is_npu_fused_optimizer,
+                                             stashed_grads_are_zero=stash.process_zero_grad)
+    stash.process_zero_grad = False
 
 
 #####################################################################################
@@ -318,6 +647,546 @@
         stash.lazy_init_called = True
 
 
+@torch.no_grad()
+def combined_init_with_master_weights(self):
+    stash = self._amp_stash
+    if stash.already_combined:
+        return
+
+    if (not self.accelerate) and (not self.is_npu_fused_optimizer):
+        return
+
+    # fp32 from fp32
+    all_fp32_from_fp32_params, all_fp32_from_fp32_grad_stash = [], []
+    for param in stash.all_fp32_from_fp32_params:
+        if param.grad is not None:
+            if torch_npu.get_npu_format(param) != torch_npu.get_npu_format(param.grad):
+                param.grad = torch_npu.npu_format_cast(param.grad, torch_npu.get_npu_format(param)).contiguous()
+            all_fp32_from_fp32_params.append(param)
+            all_fp32_from_fp32_grad_stash.append(torch.zeros_like(param.grad))
+    stash.all_fp32_from_fp32_params = all_fp32_from_fp32_params
+    stash.all_fp32_from_fp32_grad_stash = all_fp32_from_fp32_grad_stash
+
+    if len(stash.all_fp32_from_fp32_grad_stash) > 0:
+        stash.all_fp32_from_fp32_grad_stash_combine = combine_npu(stash.all_fp32_from_fp32_grad_stash)
+
+    # fp32 from fp16
+    all_fp16_params, all_fp32_from_fp16_params = [], []
+    for fp16_param, fp32_from_fp16_param in zip(stash.all_fp16_params, stash.all_fp32_from_fp16_params):
+        if fp16_param.grad is not None:
+            if torch_npu.get_npu_format(fp16_param.grad) != torch_npu.get_npu_format(fp32_from_fp16_param):
+                fp16_param.grad = torch_npu.npu_format_cast(fp16_param.grad,
+                                                        torch_npu.get_npu_format(fp32_from_fp16_param)).contiguous()
+            fp32_from_fp16_param.grad = torch.zeros_like(fp32_from_fp16_param)
+            all_fp16_params.append(fp16_param)
+            all_fp32_from_fp16_params.append(fp32_from_fp16_param)
+    stash.all_fp16_params = all_fp16_params
+    stash.all_fp32_from_fp16_params = all_fp32_from_fp16_params
+
+    stash.main_fp16_grad_combine, stash.fp16_grad_list = get_grad_combined_tensor_from_param(stash.all_fp16_params)
+
+    stash.main_fp32_from_fp16_grad_combine, stash.fp32_from_fp16_grad_list = \
+        get_grad_combined_tensor_from_param(stash.all_fp32_from_fp16_params)
+    stash.main_fp32_from_fp32_grad_combine, stash.fp32_from_fp32_grad_list = \
+        get_grad_combined_tensor_from_param(stash.all_fp32_from_fp32_params)
+    # please do not change the order of tensor in this list.
+    stash.grads_list = [stash.main_fp16_grad_combine, 
+                        stash.main_fp32_from_fp16_grad_combine, 
+                        stash.main_fp32_from_fp32_grad_combine]
+
+    if self.is_npu_fused_optimizer:
+        # stash.main_fp16_param_combine = combine_npu(stash.all_fp16_params)
+        stash.main_fp32_from_fp16_param_combine = combine_npu(stash.all_fp32_from_fp16_params)
+        stash.main_fp32_from_fp32_param_combine = combine_npu(stash.all_fp32_from_fp32_params)
+    
+    stash.already_combined = True
+
+
+@torch.no_grad()
+def combined_init_no_master_weights(self):
+    stash = self._amp_stash
+    if stash.already_combined:
+        return
+
+    if (not self.accelerate) and (not self.is_npu_fused_optimizer):
+        return
+
+    all_fp16_params, all_fp16_grad_stash = [], []
+    for param in stash.all_fp16_params:
+        if param.grad is not None:
+            if torch_npu.get_npu_format(param) != torch_npu.get_npu_format(param.grad):
+                param.grad = torch_npu.npu_format_cast(param.grad, torch_npu.get_npu_format(param)).contiguous()
+            all_fp16_params.append(param)
+            all_fp16_grad_stash.append(torch.zeros_like(param.grad))
+
+    stash.all_fp16_params = all_fp16_params
+    stash.all_fp16_grad_stash = all_fp16_grad_stash
+
+    all_fp32_params, all_fp32_grad_stash = [], []
+    for param in stash.all_fp32_params:
+        if param.grad is not None:
+            if torch_npu.get_npu_format(param) != torch_npu.get_npu_format(param.grad):
+                param.grad = torch_npu.npu_format_cast(param.grad, torch_npu.get_npu_format(param)).contiguous()
+            all_fp32_params.append(param)
+            all_fp32_grad_stash.append(torch.zeros_like(param.grad))
+
+    stash.all_fp32_params = all_fp32_params
+    stash.all_fp32_grad_stash = all_fp32_grad_stash
+
+    if len(stash.all_fp16_grad_stash) > 0:
+        # if len == 0, avoid to create a useless combined tensor
+        stash.all_fp16_grad_stash_combine = combine_npu(stash.all_fp16_grad_stash, require_copy_value=False)
+    if len(stash.all_fp32_grad_stash) > 0:
+        stash.all_fp32_grad_stash_combine = combine_npu(stash.all_fp32_grad_stash, require_copy_value=False)
+
+    stash.main_fp16_grad_combine, stash.fp16_grad_list = get_grad_combined_tensor_from_param(stash.all_fp16_params)
+    stash.main_fp32_grad_combine, stash.fp32_grad_list = get_grad_combined_tensor_from_param(stash.all_fp32_params)
+    # please do not change the order of tensor in this list.
+    stash.grads_list = [stash.main_fp16_grad_combine, stash.main_fp32_grad_combine]
+
+    if self.is_npu_fused_optimizer:
+        # stash.main_fp16_param_combine = combine_npu(stash.all_fp16_params)
+        stash.main_fp32_param_combine = combine_npu(stash.all_fp32_params)
+
+    stash.already_combined = True
+
+
+def reset_all_combine_flags(self):
+    stash = self._amp_stash
+    stash.already_combined = False
+    stash.params_grads_are_combined_by_group = False
+    stash.param_states_are_combined_by_group = False
+
+
+def check_already_combined_params_and_grads_with_master_weights(self):
+    stash = self._amp_stash
+    if not self.check_combined_tensors or not stash.already_combined:
+        return
+
+    fp16_grad_list = []
+    for param in stash.all_fp16_params:
+        if param.requires_grad:
+            fp16_grad_list.append(param.grad)
+
+    fp32_from_fp16_grad_list = []
+    for param in stash.all_fp32_from_fp16_params:
+        if param.requires_grad:
+            fp32_from_fp16_grad_list.append(param.grad)
+
+    fp32_from_fp32_grad_list = []
+    for param in stash.all_fp32_from_fp32_params:
+        if param.requires_grad:
+            fp32_from_fp32_grad_list.append(param.grad)
+
+    if not is_combined_tensor_valid(stash.main_fp16_grad_combine, fp16_grad_list) or \
+        not is_combined_tensor_valid(stash.main_fp32_from_fp16_grad_combine, fp32_from_fp16_grad_list) or \
+        not is_combined_tensor_valid(stash.main_fp32_from_fp32_grad_combine, fp32_from_fp32_grad_list):
+        maybe_print("Combined grad has been destroyed and will be recombined afterwards, please check if "
+                    "there is any operation that may change the data_ptr/size/format of the grads.")
+        self._reset_all_combine_flags()
+        return
+
+    if self.is_npu_fused_optimizer:
+        if not is_combined_tensor_valid(stash.main_fp32_from_fp16_param_combine, stash.all_fp32_from_fp16_params) or \
+            not is_combined_tensor_valid(stash.main_fp32_from_fp32_param_combine, stash.all_fp32_from_fp32_params):
+            maybe_print("Combined param has been destroyed and will be recombined afterwards, please check if "
+                        "there is any operation that may change the data_ptr/size/format of the params.")
+            self._reset_all_combine_flags()
+            return
+
+
+def check_already_combined_params_and_grads_no_master_weights(self):
+    stash = self._amp_stash
+    if not self.check_combined_tensors or not stash.already_combined:
+        return
+
+    fp16_grad_list = []
+    for param in stash.all_fp16_params:
+        if param.requires_grad:
+            fp16_grad_list.append(param.grad)
+
+    fp32_grad_list = []
+    for param in stash.all_fp32_params:
+        if param.requires_grad:
+            fp32_grad_list.append(param.grad)
+
+    if not is_combined_tensor_valid(stash.main_fp16_grad_combine, fp16_grad_list) or \
+        not is_combined_tensor_valid(stash.main_fp32_grad_combine, fp32_grad_list):
+        maybe_print("Combined grad has been destroyed and will be recombined afterwards, please check if "
+                    "there is any operation that may change the data_ptr/size/format of the grads.")
+        self._reset_all_combine_flags()
+        return
+
+    if self.is_npu_fused_optimizer:
+        if not is_combined_tensor_valid(stash.main_fp32_param_combine, stash.all_fp32_params):
+            maybe_print("Combined param has been destroyed and will be recombined afterwards, please check if "
+                        "there is any operation that may change the data_ptr/size/format of the params.")
+            self._reset_all_combine_flags()
+            return
+
+
+def is_grad_in_combined_tensor(grad, combined_tensor):
+    if combined_tensor is None:
+        return False
+
+    combined_tensor_data_start_addr = combined_tensor.data_ptr()
+    combined_tensor_data_end_addr = combined_tensor.data_ptr() + \
+                                    combined_tensor.numel() * combined_tensor.element_size()
+    
+    if combined_tensor_data_start_addr <= grad.data_ptr() < combined_tensor_data_end_addr:
+        return True
+    else:
+        return False
+
+
+def combine_params_and_grads_by_group_with_master_weights(self):
+    stash = self._amp_stash
+    if stash.params_grads_are_combined_by_group:
+        return
+
+    self._amp_combined_init()
+    stash.combined_params_indexed_by_group = []
+    stash.combined_grads_indexed_by_group = []
+    stash.params_lists_indexed_by_group = []
+
+    combined_fp32_from_fp32_param = stash.main_fp32_from_fp32_param_combine
+    combined_fp32_from_fp16_param = stash.main_fp32_from_fp16_param_combine
+    combined_fp32_from_fp32_grad = stash.main_fp32_from_fp32_grad_combine
+    combined_fp32_from_fp16_grad = stash.main_fp32_from_fp16_grad_combine
+
+    combined_group_fp32_from_fp32_param_index, combined_group_fp32_from_fp16_param_index = 0, 0
+    combined_group_fp32_from_fp32_grad_index, combined_group_fp32_from_fp16_grad_index = 0, 0
+
+    group_num = 0
+    for group in self.param_groups:
+        group_num += 1
+
+        group_fp32_from_fp32_params = []
+        group_fp32_from_fp16_params = []
+        group_fp32_from_fp32_param_size, group_fp32_from_fp16_param_size = 0, 0
+        group_fp32_from_fp32_grad_size, group_fp32_from_fp16_grad_size = 0, 0
+
+        for p in group['params']:
+            if p.grad is None:
+                continue
+            param_size = get_aligned_storage_size(p)
+            grad_size = get_aligned_storage_size(p.grad)
+            if is_grad_in_combined_tensor(p.grad, combined_fp32_from_fp32_grad):
+                group_fp32_from_fp32_param_size += param_size
+                group_fp32_from_fp32_params.append(p)
+                group_fp32_from_fp32_grad_size += grad_size
+            else:
+                group_fp32_from_fp16_param_size += param_size
+                group_fp32_from_fp16_params.append(p)
+                group_fp32_from_fp16_grad_size += grad_size
+
+        combined_group_fp32_from_fp32_param = None
+        combined_group_fp32_from_fp16_param = None
+        combined_group_fp32_from_fp32_grad = None
+        combined_group_fp32_from_fp16_grad = None
+
+        combined_group_fp32_from_fp32_param = get_part_combined_tensor(combined_fp32_from_fp32_param,
+                                                                       combined_group_fp32_from_fp32_param_index,
+                                                                       group_fp32_from_fp32_param_size)
+        combined_group_fp32_from_fp16_param = get_part_combined_tensor(combined_fp32_from_fp16_param,
+                                                                       combined_group_fp32_from_fp16_param_index,
+                                                                       group_fp32_from_fp16_param_size)
+        combined_group_fp32_from_fp32_grad = get_part_combined_tensor(combined_fp32_from_fp32_grad, 
+                                                                      combined_group_fp32_from_fp32_grad_index,
+                                                                      group_fp32_from_fp32_grad_size)
+        combined_group_fp32_from_fp16_grad = get_part_combined_tensor(combined_fp32_from_fp16_grad, 
+                                                                      combined_group_fp32_from_fp16_grad_index,
+                                                                      group_fp32_from_fp16_grad_size)
+
+        combined_group_fp32_from_fp32_param_index += group_fp32_from_fp32_param_size
+        combined_group_fp32_from_fp16_param_index += group_fp32_from_fp16_param_size
+        combined_group_fp32_from_fp32_grad_index += group_fp32_from_fp32_grad_size
+        combined_group_fp32_from_fp16_grad_index += group_fp32_from_fp16_grad_size
+
+        combined_params = []
+        combined_grads = []
+        params_list = []
+
+        combined_params.append(combined_group_fp32_from_fp32_param)
+        combined_params.append(combined_group_fp32_from_fp16_param)
+        combined_grads.append(combined_group_fp32_from_fp32_grad)
+        combined_grads.append(combined_group_fp32_from_fp16_grad)
+        params_list.append(group_fp32_from_fp32_params)
+        params_list.append(group_fp32_from_fp16_params)
+
+        stash.combined_params_indexed_by_group.append(combined_params)
+        stash.combined_grads_indexed_by_group.append(combined_grads)
+        stash.params_lists_indexed_by_group.append(params_list)
+
+    maybe_print("group num: {}".format(group_num))
+    stash.params_grads_are_combined_by_group = True
+
+
+def combine_params_and_grads_by_group_no_master_weights(self):
+    stash = self._amp_stash
+    if stash.params_grads_are_combined_by_group:
+        return
+
+    self._amp_combined_init()
+    stash.combined_params_indexed_by_group = []
+    stash.combined_grads_indexed_by_group = []
+    stash.params_lists_indexed_by_group = []
+
+    combined_fp32_param = stash.main_fp32_param_combine
+    combined_fp32_grad = stash.main_fp32_grad_combine
+
+    combined_group_fp32_param_index = 0
+    combined_group_fp32_grad_index = 0
+
+    group_num = 0
+    for group in self.param_groups:
+        group_num += 1
+
+        group_fp32_params = []
+        group_fp32_param_size = 0
+        group_fp32_grad_size = 0
+
+        for p in group['params']:
+            if p.grad is None:
+                continue
+
+            param_size = get_aligned_storage_size(p)
+            group_fp32_param_size += param_size
+            group_fp32_params.append(p)
+
+            grad_size = get_aligned_storage_size(p.grad)
+            group_fp32_grad_size += grad_size
+
+        combined_group_fp32_param = None
+        combined_group_fp32_grad = None
+        combined_group_fp32_param = get_part_combined_tensor(combined_fp32_param, 
+                                                             combined_group_fp32_param_index,
+                                                             group_fp32_param_size)
+        combined_group_fp32_grad = get_part_combined_tensor(combined_fp32_grad, 
+                                                            combined_group_fp32_grad_index, 
+                                                            group_fp32_grad_size)
+        combined_group_fp32_param_index += group_fp32_param_size
+        combined_group_fp32_grad_index += group_fp32_grad_size
+
+        combined_params = []
+        combined_grads = []
+        params_list = []
+
+        combined_params.append(combined_group_fp32_param)
+        combined_grads.append(combined_group_fp32_grad)
+        params_list.append(group_fp32_params)
+
+        stash.combined_params_indexed_by_group.append(combined_params)
+        stash.combined_grads_indexed_by_group.append(combined_grads)
+        stash.params_lists_indexed_by_group.append(params_list)
+
+    maybe_print("group num: {}".format(group_num))
+    stash.params_grads_are_combined_by_group = True
+
+
+def new_zero_grad_with_master_weights(self):
+    stash = self._amp_stash
+    self._amp_lazy_init()
+    # Zero the model grads.
+    for param in stash.all_fp16_params:
+        if param.grad is not None:
+            param.grad.detach_()
+            param.grad.zero_()
+    for param in stash.all_fp32_from_fp32_params:
+        if param.grad is not None:
+            param.grad.detach_()
+            param.grad.zero_()
+    # Clear the master grads that are independent of model grads
+    for param in stash.all_fp32_from_fp16_params:
+        param.grad = None
+
+
+def new_zero_grad_accelerate_with_master_weights(self):
+    stash = self._amp_stash
+    self._amp_lazy_init()
+    self._check_already_combined_params_and_grads()
+    # Zero the model grads.
+    stash.process_zero_grad = True
+
+    if not stash.already_combined:
+        for param in stash.all_fp16_params:
+            if param.grad is not None:
+                param.grad.detach_()
+                param.grad.zero_()
+        for param in stash.all_fp32_from_fp32_params:
+            if param.grad is not None:
+                param.grad.detach_()
+                param.grad.zero_()
+        for param in stash.all_fp32_from_fp16_params:
+            if param.grad is not None:
+                param.grad.zero_()
+        return
+
+    if stash.main_fp16_grad_combine is not None:
+        stash.main_fp16_grad_combine.zero_()
+    if stash.main_fp32_from_fp32_grad_combine is not None:
+        stash.main_fp32_from_fp32_grad_combine.zero_()
+    # Clear the master grads that are independent of model grads
+    if stash.main_fp32_from_fp16_grad_combine is not None:
+        stash.main_fp32_from_fp16_grad_combine.zero_()
+
+
+def can_get_combined_tensors(self, name):
+    if name == 'params':
+        if not self.is_npu_fused_optimizer:
+            maybe_print("To get combined params, please use npu fused optimizer.")
+            return False
+    elif name == 'grads' or name == 'grad_masks':
+        if (not self.accelerate) and (not self.is_npu_fused_optimizer):
+            maybe_print("To get combined {}, please set combine_grad=True or use npu fused optimizer.".format(name))
+            return False
+    else:
+        maybe_print("{} are not supported to be combined.".format(name))
+        return False
+
+    stash = self._amp_stash
+    if not stash.already_combined:
+        maybe_print("Please get the combined {} after backward phase.".format(name))
+        return False
+    return True
+
+
+def get_model_combined_params(self):
+    stash = self._amp_stash
+    combined_params = []
+
+    if not self._can_get_combined_tensors('params'):
+        return combined_params
+
+    self._check_already_combined_params_and_grads()
+    self._amp_combined_init()
+
+    if stash.master_weights:
+        combined_params.append(stash.main_fp16_param_combine)
+        combined_params.append(stash.main_fp32_from_fp32_param_combine)
+    else:
+        combined_params.append(stash.main_fp32_param_combine)
+    return combined_params
+
+
+def get_model_combined_grads(self):
+    stash = self._amp_stash
+    combined_grads = []
+
+    if not self._can_get_combined_tensors('grads'):
+        return combined_grads
+
+    self._check_already_combined_params_and_grads()
+    self._amp_combined_init()
+
+    if stash.master_weights:
+        combined_grads.append(stash.main_fp16_grad_combine)
+        combined_grads.append(stash.main_fp32_from_fp32_grad_combine)
+    else:
+        combined_grads.append(stash.main_fp32_grad_combine)
+    return combined_grads
+
+
+def get_model_combined_grad_masks(self):
+    stash = self._amp_stash
+    combined_grad_masks = []
+
+    if not self._can_get_combined_tensors('grad_masks'):
+        return combined_grad_masks
+
+    if stash.master_weights:
+        if stash.main_fp16_grad_combine_mask is None:
+            stash.main_fp16_grad_combine_mask = \
+                get_grad_combined_tensor_mask_from_param(stash.all_fp16_params)
+            stash.main_fp32_from_fp32_grad_combine_mask = \
+                get_grad_combined_tensor_mask_from_param(stash.all_fp32_from_fp32_params)
+        combined_grad_masks.append(stash.main_fp16_grad_combine_mask)
+        combined_grad_masks.append(stash.main_fp32_from_fp32_grad_combine_mask)
+    else:
+        if stash.main_fp32_grad_combine_mask is None:
+            stash.main_fp32_grad_combine_mask = \
+                get_grad_combined_tensor_mask_from_param(stash.all_fp32_params)
+        combined_grad_masks.append(stash.main_fp32_grad_combine_mask)
+    return combined_grad_masks
+
+
+def get_optimizer_combined_params(self):
+    stash = self._amp_stash
+    combined_params = []
+
+    if not self._can_get_combined_tensors('params'):
+        return combined_params
+
+    self._check_already_combined_params_and_grads()
+    self._amp_combined_init()
+
+    if stash.master_weights:
+        combined_params.append(stash.main_fp32_from_fp16_param_combine)
+        combined_params.append(stash.main_fp32_from_fp32_param_combine)
+    else:
+        combined_params.append(stash.main_fp32_param_combine)
+    return combined_params
+
+
+def get_optimizer_combined_grads(self):
+    stash = self._amp_stash
+    combined_grads = []
+
+    if not self._can_get_combined_tensors('grads'):
+        return combined_grads
+
+    self._check_already_combined_params_and_grads()
+    self._amp_combined_init()
+
+    if stash.master_weights:
+        combined_grads.append(stash.main_fp32_from_fp16_grad_combine)
+        combined_grads.append(stash.main_fp32_from_fp32_grad_combine)
+    else:
+        combined_grads.append(stash.main_fp32_grad_combine)
+    return combined_grads
+
+
+def get_optimizer_combined_grad_masks(self):
+    stash = self._amp_stash
+    combined_grad_masks = []
+
+    if not self._can_get_combined_tensors('grad_masks'):
+        return combined_grad_masks
+
+    if stash.master_weights:
+        if stash.main_fp32_from_fp16_grad_combine_mask is None:
+            stash.main_fp32_from_fp16_grad_combine_mask = \
+                get_grad_combined_tensor_mask_from_param(stash.all_fp32_from_fp16_params)
+            stash.main_fp32_from_fp32_grad_combine_mask = \
+                get_grad_combined_tensor_mask_from_param(stash.all_fp32_from_fp32_params)
+        combined_grad_masks.append(stash.main_fp32_from_fp16_grad_combine_mask)
+        combined_grad_masks.append(stash.main_fp32_from_fp32_grad_combine_mask)
+    else:
+        if stash.main_fp32_grad_combine_mask is None:
+            stash.main_fp32_grad_combine_mask = \
+                get_grad_combined_tensor_mask_from_param(stash.all_fp32_params)
+        combined_grad_masks.append(stash.main_fp32_grad_combine_mask)
+    return combined_grad_masks
+
+
+def clip_model_grad_norm_fused(self, max_norm, norm_type=2):
+    stash = self._amp_stash
+    if stash.master_weights:
+        raise RuntimeError("clip_model_grad_norm_fused can only be used when opt_level='O1'")
+
+    combined_grads = self.get_model_combined_grads()
+    combined_grad_masks = self.get_model_combined_grad_masks()
+    total_norm = clip_grad_norm_fused(combined_grads, combined_grad_masks, max_norm, norm_type)
+    return total_norm
+
+
+def clip_optimizer_grad_norm_fused(self, max_norm, norm_type=2):
+    combined_grads = self.get_optimizer_combined_grads()
+    combined_grad_masks = self.get_optimizer_combined_grad_masks()
+    total_norm = clip_grad_norm_fused(combined_grads, combined_grad_masks, max_norm, norm_type)
+    return total_norm
+
+
 def _process_optimizer(optimizer, properties):
     if hasattr(optimizer, "_amp_stash"):
         raise RuntimeError("A given optimizer should only be passed through amp.initialize once.")
@@ -327,15 +1196,64 @@
     optimizer._amp_stash.lazy_init_called = False
     optimizer._amp_stash.already_patched = False
     optimizer._amp_stash.params_have_scaled_gradients = False
+    optimizer.accelerate = properties.combine_grad
+    optimizer.combine_ddp = properties.combine_ddp
+    optimizer.init_combine_ddp = False
+    optimizer.ddp_replica_count = properties.ddp_replica_count
+    optimizer.check_combined_tensors = properties.check_combined_tensors
+    optimizer._amp_stash.master_weights = properties.master_weights
+    optimizer._amp_stash.grads_list = []
+    optimizer._amp_stash.already_combined = False
+
+    optimizer._amp_stash.process_zero_grad = True
+
+    optimizer._amp_stash.params_grads_are_combined_by_group = False
+    optimizer._amp_stash.combined_params_indexed_by_group = []
+    optimizer._amp_stash.combined_grads_indexed_by_group = []
+    optimizer._amp_stash.params_lists_indexed_by_group = []
+    optimizer._amp_stash.param_states_are_combined_by_group = False
+    optimizer._amp_stash.combined_param_states_indexed_by_group = []
 
     for name in ("_lazy_init_maybe_master_weights",
                  "_master_params_to_model_params",
                  "_prepare_amp_backward",
                  "_post_amp_backward",
-                 "_amp_lazy_init"):
+                 "_amp_lazy_init",
+                 "_amp_combined_init",
+                 "_reset_all_combine_flags",
+                 "_check_already_combined_params_and_grads",
+                 "_combine_params_and_grads_by_group",
+                 "_can_get_combined_tensors",
+                 "get_model_combined_params",
+                 "get_model_combined_grads",
+                 "get_optimizer_combined_params",
+                 "get_optimizer_combined_grads"):
         if hasattr(optimizer, name):
             raise RuntimeError("Incoming optimizer already has {} defined.".format(name))
 
+    if hasattr(optimizer, "is_npu_fused_optimizer") and optimizer.is_npu_fused_optimizer is True:
+        maybe_print("Use npu fused optimizer")
+        if properties.opt_level != "O1" and properties.opt_level != "O2":
+            raise RuntimeError("Currently, npu fused optimizer can only be used when opt_level='O1' or opt_level='O2'")
+    else:
+        optimizer.is_npu_fused_optimizer = False
+
+    if properties.combine_grad or optimizer.is_npu_fused_optimizer:
+        if properties.opt_level == "O2" and properties.master_weights != True:
+            raise RuntimeError("With opt_level O2, master_weights should be True when combine_grad is True or "
+                               "npu fused optimizer is used")
+    else:
+        if properties.check_combined_tensors:
+            maybe_print("Because combine_grad != True and no npu fused optimizer is used, "
+                        "checking combined tensors function will not take effect!")
+
+    if optimizer.is_npu_fused_optimizer:
+        old_load_state_dict = optimizer.load_state_dict
+        def new_load_state_dict(self, state_dict):
+            old_load_state_dict(state_dict)
+            self._amp_stash.param_states_are_combined_by_group = False
+        optimizer.load_state_dict = types.MethodType(new_load_state_dict, optimizer)
+
     # TODO:  Centralize exposure and import error checking for the C backend.
     if multi_tensor_applier.available:
         import amp_C
@@ -352,34 +1270,31 @@
 
         old_step = optimizer.step
         def new_step(self, closure=None):
+            stash = self._amp_stash
             if closure is not None:
                 raise RuntimeError("Currently, Amp does not support closure use with optimizers.")
             retval = old_step()
             if not isinstance(self, FusedSGD):
                 self._master_params_to_model_params()
             # Clear the master grads that wouldn't be zeroed by model.zero_grad()
-            for param in self._amp_stash.all_fp32_from_fp16_params:
-                param.grad = None
+            if optimizer.accelerate or optimizer.is_npu_fused_optimizer:
+                if stash.main_fp32_from_fp16_grad_combine is not None:
+                    stash.main_fp32_from_fp16_grad_combine.zero_()
+            else:
+                for param in stash.all_fp32_from_fp16_params:
+                    param.grad = None
             return retval
         optimizer.step = types.MethodType(new_step, optimizer)
 
         old_zero_grad = optimizer.zero_grad
-        def new_zero_grad(self):
-            stash = self._amp_stash
-            self._amp_lazy_init()
-            # Zero the model grads.
-            for param in stash.all_fp16_params:
-                if param.grad is not None:
-                    param.grad.detach_()
-                    param.grad.zero_()
-            for param in stash.all_fp32_from_fp32_params:
-                if param.grad is not None:
-                    param.grad.detach_()
-                    param.grad.zero_()
-            # Clear the master grads that are independent of model grads
-            for param in self._amp_stash.all_fp32_from_fp16_params:
-                param.grad = None
-        optimizer.zero_grad = types.MethodType(new_zero_grad, optimizer)
+        if optimizer.accelerate or optimizer.is_npu_fused_optimizer:
+            optimizer.zero_grad = types.MethodType(new_zero_grad_accelerate_with_master_weights, optimizer)
+        else:
+            optimizer.zero_grad = types.MethodType(new_zero_grad_with_master_weights, optimizer)
+
+        if optimizer.is_npu_fused_optimizer:
+            optimizer._combine_params_and_grads_by_group = types.MethodType(
+                combine_params_and_grads_by_group_with_master_weights, optimizer)
 
         if isinstance(optimizer, FusedSGD):
             optimizer._prepare_amp_backward = types.MethodType(
@@ -391,10 +1306,39 @@
                 prepare_backward_with_master_weights, optimizer)
             optimizer._post_amp_backward = types.MethodType(
                 post_backward_with_master_weights, optimizer)
+            optimizer._init_combine_ddp = types.MethodType(
+                init_combine_ddp_with_master_weights, optimizer)
+        
+        optimizer._amp_combined_init = types.MethodType(combined_init_with_master_weights, optimizer)
+        optimizer._check_already_combined_params_and_grads = types.MethodType(
+            check_already_combined_params_and_grads_with_master_weights, optimizer)
     else:
         optimizer._lazy_init_maybe_master_weights = types.MethodType(
             lazy_init_no_master_weights, optimizer)
 
+        old_zero_grad = optimizer.zero_grad
+        if optimizer.accelerate or optimizer.is_npu_fused_optimizer:
+            def new_zero_grad_accelerate_no_master_weights(self):
+                stash = self._amp_stash
+                self._amp_lazy_init()
+                self._check_already_combined_params_and_grads()
+                # Zero the model grads.
+                stash.process_zero_grad = True
+
+                if not stash.already_combined:
+                    old_zero_grad()
+                    return
+
+                if stash.main_fp16_grad_combine is not None:
+                    stash.main_fp16_grad_combine.zero_()
+                if stash.main_fp32_grad_combine is not None:
+                    stash.main_fp32_grad_combine.zero_()
+            optimizer.zero_grad = types.MethodType(new_zero_grad_accelerate_no_master_weights, optimizer)
+
+        if optimizer.is_npu_fused_optimizer:
+            optimizer._combine_params_and_grads_by_group = types.MethodType(
+                combine_params_and_grads_by_group_no_master_weights, optimizer)
+
         if isinstance(optimizer, FusedSGD):
             optimizer._prepare_amp_backward = types.MethodType(
                 prepare_backward_no_master_weights_FusedSGD, optimizer)
@@ -405,8 +1349,27 @@
                 prepare_backward_no_master_weights, optimizer)
             optimizer._post_amp_backward = types.MethodType(
                 post_backward_no_master_weights, optimizer)
+            optimizer._init_combine_ddp = types.MethodType(
+                init_combine_ddp_no_master_weights, optimizer)
+
+        optimizer._amp_combined_init = types.MethodType(combined_init_no_master_weights, optimizer)
+        optimizer._check_already_combined_params_and_grads = types.MethodType(
+            check_already_combined_params_and_grads_no_master_weights, optimizer)
 
     optimizer._amp_lazy_init = types.MethodType(_amp_lazy_init, optimizer)
+    optimizer._reset_all_combine_flags = types.MethodType(reset_all_combine_flags, optimizer)
+    optimizer._can_get_combined_tensors = types.MethodType(can_get_combined_tensors, optimizer)
+    optimizer.get_model_combined_params = types.MethodType(get_model_combined_params, optimizer)
+    optimizer.get_model_combined_grads = types.MethodType(get_model_combined_grads, optimizer)
+    optimizer.get_model_combined_grad_masks = types.MethodType(get_model_combined_grad_masks, optimizer)
+    optimizer.get_optimizer_combined_params = types.MethodType(get_optimizer_combined_params, optimizer)
+    optimizer.get_optimizer_combined_grads = types.MethodType(get_optimizer_combined_grads, optimizer)
+    optimizer.get_optimizer_combined_grad_masks = types.MethodType(get_optimizer_combined_grad_masks, optimizer)
+    optimizer.clip_model_grad_norm_fused = types.MethodType(clip_model_grad_norm_fused, optimizer)
+    optimizer.clip_optimizer_grad_norm_fused = types.MethodType(clip_optimizer_grad_norm_fused, optimizer)
+    optimizer._combine_ddp_proc = types.MethodType(combine_ddp_proc, optimizer)
+    optimizer._init_combine_ddp_common = types.MethodType(init_combine_ddp_common, optimizer)
+    optimizer._combine_ddp_all_reduce = types.MethodType(combine_ddp_all_reduce, optimizer)
 
     old_add_param_group = optimizer.add_param_group
 
@@ -435,13 +1398,13 @@
             fp32_from_fp16_params_this_group = []
             for i, param in enumerate(new_group['params']):
                 if param.requires_grad:
-                    if param.type() == 'torch.cuda.HalfTensor':
+                    if param.type() == 'torch.npu.HalfTensor':
                         fp16_params_this_group.append(param)
                         master_param = param.detach().clone().float()
                         master_param.requires_grad = True
                         new_group['params'][i] = master_param
                         fp32_from_fp16_params_this_group.append(master_param)
-                    elif param.type() == 'torch.cuda.FloatTensor':
+                    elif param.type() == 'torch.npu.FloatTensor':
                         fp32_params_this_group.append(param)
                         new_group['params'][i] = param
                     else:
@@ -457,24 +1420,13 @@
             stash.all_fp32_from_fp16_params += fp32_from_fp16_params_this_group
             stash.all_fp32_from_fp32_params += fp32_params_this_group
 
-            # stash.all_fp32_from_fp16_grad_stash = [None for _ in stash.all_fp32_from_fp16_params]
             stash.all_fp32_from_fp32_grad_stash += [None for _ in fp32_params_this_group]
-
-            # It should be ok to let params be added with existing .grad attributes.
-            # for param in fp16_params_this_group:
-            #     param.grad = None
-
-            # for param in fp32_from_fp16_params_this_group:
-            #     param.grad = None
-
-            # for param in stash.fp32_params_this_group:
-            #     param.grad = None
         else:
             for param in new_group['params']:
-                if param.type() == 'torch.cuda.HalfTensor':
+                if param.type() == 'torch.npu.HalfTensor':
                     stash.all_fp16_params.append(param)
                     stash.all_fp16_grad_stash.append(None)
-                elif param.type() == 'torch.cuda.FloatTensor':
+                elif param.type() == 'torch.npu.FloatTensor':
                     stash.all_fp32_params.append(param)
                     stash.all_fp32_grad_stash.append(None)
                 else:
diff -Nur '--exclude=.git' apex/apex/amp/scaler.py apex-develop/apex/amp/scaler.py
--- apex/apex/amp/scaler.py	2023-04-06 10:36:26.964937605 +0800
+++ apex-develop/apex/amp/scaler.py	2024-03-07 21:33:04.293391422 +0800
@@ -1,7 +1,27 @@
+# Copyright (c) 2020, Huawei Technologies.
+# Copyright (c) 2019, NVIDIA CORPORATION.
+# All rights reserved.
+#
+# Licensed under the BSD 3-Clause License  (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import torch
+import torch.distributed as dist
+import torch_npu
+
 from ..multi_tensor_apply import multi_tensor_applier
 from ._amp_state import _amp_state, master_params, maybe_print
 from itertools import product
+import importlib
 
 def scale_check_overflow_python(model_grad, master_grad, scale, check_overflow=False):
     # Exception handling for 18.04 compatibility
@@ -16,7 +36,8 @@
         master_grad.mul_(scale)
     return False
 
-def axpby_check_overflow_python(model_grad, stashed_grad, master_grad, a, b, check_overflow=False):
+def axpby_check_overflow_python(model_grad, stashed_grad, master_grad, a, b, use_npu_fused_optimizer, 
+                                check_overflow=False):
     # Exception handling for 18.04 compatibility
     if check_overflow:
         cpu_sum = float(model_grad.float().sum())
@@ -27,7 +48,10 @@
     #     master_grad.copy_(model_grad)
     assert stashed_grad.dtype == master_grad.dtype
     converted_model_grad = model_grad.data.to(master_grad.dtype)
-    master_grad.data = a*converted_model_grad.data + b*stashed_grad.data
+    if use_npu_fused_optimizer:
+        master_grad.data[:] = a*converted_model_grad.data + b*stashed_grad.data
+    else:
+        master_grad.data = a*converted_model_grad.data + b*stashed_grad.data
     return False
 
 class LossScaler(object):
@@ -38,10 +62,14 @@
     def __init__(self,
                  loss_scale,
                  init_scale=2.**16,
-                 scale_factor=2.,
+                 scale_growth_factor=2.,
+                 scale_backoff_factor=0.5,
                  scale_window=2000,
                  min_loss_scale=None,
                  max_loss_scale=2.**24):
+        self._is_support_inf_nan = hasattr(
+            torch_npu.npu.utils, 'is_support_inf_nan') and torch_npu.npu.utils.is_support_inf_nan()
+
         if loss_scale == "dynamic":
             self.dynamic = True
             self._loss_scale = min(max_loss_scale, init_scale)
@@ -50,30 +78,103 @@
             self._loss_scale = loss_scale
         self._max_loss_scale = max_loss_scale
         self._min_loss_scale = min_loss_scale
+        self._scale_growth_factor = scale_growth_factor
+        self._scale_backoff_factor = scale_backoff_factor
         self._scale_seq_len = scale_window
         self._unskipped = 0
         self._has_overflow = False
-        self._overflow_buf = torch.cuda.IntTensor([0])
+        self._overflow_checked = False
+        self._overflow_buf = torch.npu.FloatTensor([0.])
+        self._dist_overflow_count = torch.Tensor([0.]).to('npu')
+        self._dist_initialized = False
+
+        try:
+            if dist.is_initialized():
+                self._dist_initialized = True
+        except AttributeError as err:
+            maybe_print("torch.distributed has no attribute is_initialized")
+
         if multi_tensor_applier.available:
             import amp_C
             LossScaler.has_fused_kernel = multi_tensor_applier.available
             LossScaler.multi_tensor_scale_cuda = amp_C.multi_tensor_scale
             LossScaler.multi_tensor_axpby_cuda = amp_C.multi_tensor_axpby
         else:
-            if not LossScaler.warned_no_fused_kernel:
-                maybe_print(
-                    "Warning:  multi_tensor_applier fused unscale kernel is unavailable, "
-                    "possibly because apex was installed without --cuda_ext --cpp_ext. "
-                    "Using Python fallback.  Original ImportError was: " +
-                    repr(multi_tensor_applier.import_err),
-                    True)
             LossScaler.has_fused_kernel = False
             LossScaler.warned_no_fused_kernel = True
 
     def loss_scale(self):
         return self._loss_scale
 
+    def check_overflow_and_sync(self):
+        if self.dynamic and self._is_support_inf_nan:
+            return
+        if self.dynamic:
+            if not self._overflow_checked:
+                self._has_overflow = torch_npu.npu.get_npu_overflow_flag()
+                self._overflow_checked = True
+
+            if self._dist_initialized:
+                if self._has_overflow:
+                    self._dist_overflow_count.add_(1)
+                    dist.all_reduce(self._dist_overflow_count)
+                    self._dist_overflow_count.zero_()
+                else:
+                    dist.all_reduce(self._dist_overflow_count)
+                    if self._dist_overflow_count.item() != 0:
+                        self._has_overflow = True
+                    self._dist_overflow_count.zero_()
+        else:
+            self._has_overflow = False
+
+    def check_grads_overflow_with_inf(self, model_grads):
+        if not self.dynamic or not self._is_support_inf_nan:
+            return False
+
+        model_grads_valid = list(filter(lambda x: x is not None, model_grads))
+        torch._amp_foreach_non_finite_check_and_unscale_(model_grads_valid, self._overflow_buf, torch.tensor(1.).npu())
+        self._has_overflow = self._overflow_buf.item() > 0
+        self._overflow_buf.zero_()
+
+        return self._has_overflow
+
+    def unscale_foreach(self, model_grads, master_grads, scale):
+        if not self._is_support_inf_nan and self._has_overflow:
+            return
+
+        model_grads_valid = []
+        for model, master in zip(model_grads, master_grads):
+            if model is not None:
+                if not LossScaler.warned_unscaling_non_fp32_grad:
+                    if master.dtype != torch.float32:
+                        maybe_print(
+                            "Attempting to unscale a grad with type {} ".format(master.type()) +
+                            "Unscaling non-fp32 grads may indicate an error. "
+                            "When using Amp, you don't need to call .half() on your model.")
+                        LossScaler.warned_unscaling_non_fp32_grad = True
+                model_grads_valid.append(model)
+
+        if self.dynamic:
+            torch._amp_foreach_non_finite_check_and_unscale_(model_grads_valid, self._overflow_buf, torch.tensor(1./scale).npu())
+            self._has_overflow = self._overflow_buf.item() > 0
+            self._overflow_buf.zero_()
+            if not self._has_overflow:
+                for model, master in zip(model_grads, master_grads):
+                    if model is not None and master is not model:
+                        master.copy_(model)
+            return
+
+        for model, master in zip(model_grads, master_grads):
+            if model is not None:
+                if master is not model:
+                    master.copy_(model)
+                if scale != 1.0:
+                    master.mul_(1./scale)
+
     def unscale_python(self, model_grads, master_grads, scale):
+        if not self._is_support_inf_nan and self._has_overflow:
+            return
+
         for model, master in zip(model_grads, master_grads):
             if model is not None:
                 if not LossScaler.warned_unscaling_non_fp32_grad:
@@ -86,7 +187,7 @@
                 self._has_overflow = scale_check_overflow_python(model,
                                                                  master,
                                                                  1./scale,
-                                                                 self.dynamic)
+                                                                 self.dynamic and self._is_support_inf_nan)
                 if self._has_overflow and self.dynamic:
                     break
 
@@ -116,19 +217,73 @@
                                  [model_grads, master_grads],
                                  1./scale)
         else:
-            self.unscale_python(model_grads, master_grads, scale)
-
+            if self._is_support_inf_nan:
+                self.unscale_foreach(model_grads, master_grads, scale)
+            else:
+                self.unscale_python(model_grads, master_grads, scale)
+        
         # Defer to update_scale
         # If the fused kernel is available, we only need one D2H memcopy and sync.
         # if LossScaler.has_fused_kernel and self.dynamic and not self._has_overflow:
         #     self._has_overflow = self._overflow_buf.item()
 
+    def unscale_with_stashed_foreach(self,
+                                    model_grads,
+                                    stashed_master_grads,
+                                    master_grads,
+                                    a,
+                                    b,
+                                    use_npu_fused_optimizer):
+        if not self._is_support_inf_nan and self._has_overflow:
+            return
+
+        model_grads_valid = []
+        stashed_master_grads_valid = []
+        master_grads_valid = []
+        for model, stashed, master in zip(model_grads, stashed_master_grads, master_grads):
+            if model is None and stashed is None:
+                continue
+            assert stashed.dtype == master.dtype
+            if not LossScaler.warned_unscaling_non_fp32_grad:
+                if master.dtype != torch.float32:
+                    maybe_print(
+                        "Attempting to unscale a grad with type {} ".format(master.type()) +
+                        "Unscaling non-fp32 grads may indicate an error. "
+                        "When using Amp, you don't need to call .half() on your model.")
+                    LossScaler.warned_unscaling_non_fp32_grad = True
+            model_grads_valid.append(model)
+            stashed_master_grads_valid.append(stashed)
+            master_grads_valid.append(master)
+
+        if self.dynamic:
+            with torch.no_grad():
+                torch._amp_foreach_non_finite_check_and_unscale_(model_grads_valid, self._overflow_buf, torch.tensor(a).npu())
+            self._has_overflow = self._overflow_buf.item() > 0
+            self._overflow_buf.zero_()
+            if self._has_overflow:
+                return
+
+        for model_grad, master_grad, stashed_grad in zip(
+            model_grads_valid, master_grads_valid, stashed_master_grads_valid):
+
+            converted_model_grad = model_grad.data.to(master_grad.dtype)
+            if not self.dynamic:
+                converted_model_grad.data = a*converted_model_grad.data
+            if use_npu_fused_optimizer:
+                master_grad.data[:] = converted_model_grad.data + b*stashed_grad.data
+            else:
+                master_grad.data = converted_model_grad.data + b*stashed_grad.data
+
     def unscale_with_stashed_python(self,
                                     model_grads,
                                     stashed_master_grads,
                                     master_grads,
                                     a,
-                                    b):
+                                    b,
+                                    use_npu_fused_optimizer):
+        if not self._is_support_inf_nan and self._has_overflow:
+            return
+
         for model, stashed, master in zip(model_grads, stashed_master_grads, master_grads):
             if model is None and stashed is None:
                 continue
@@ -145,7 +300,8 @@
                                                                  master,
                                                                  a,
                                                                  b,
-                                                                 self.dynamic)
+                                                                 use_npu_fused_optimizer,
+                                                                 self.dynamic and self._is_support_inf_nan)
                 if self._has_overflow and self.dynamic:
                     break
 
@@ -153,7 +309,8 @@
                              model_grads,
                              stashed_master_grads,
                              master_grads,
-                             scale_override=None):
+                             scale_override=None,
+                             use_npu_fused_optimizer=False):
         if self._has_overflow:
             return
 
@@ -177,19 +334,87 @@
                                  out_scale/stashed_have_scale, # 1.0,
                                  0) # check only arg 0, aka the incoming model grads, for infs
         else:
-            self.unscale_with_stashed_python(model_grads,
-                                             stashed_master_grads,
-                                             master_grads,
-                                             out_scale/grads_have_scale,
-                                             out_scale/stashed_have_scale)
+            if self._is_support_inf_nan:
+                self.unscale_with_stashed_foreach(model_grads,
+                                                 stashed_master_grads,
+                                                 master_grads,
+                                                 out_scale/grads_have_scale,
+                                                 out_scale/stashed_have_scale,
+                                                 use_npu_fused_optimizer)
+            else:
+                self.unscale_with_stashed_python(model_grads,
+                                                 stashed_master_grads,
+                                                 master_grads,
+                                                 out_scale/grads_have_scale,
+                                                 out_scale/stashed_have_scale,
+                                                 use_npu_fused_optimizer)
 
         # Defer to update_scale
         # If the fused kernel is available, we only need one D2H memcopy and sync.
         # if LossScaler.has_fused_kernel and self.dynamic and not self._has_overflow:
         #     self._has_overflow = self._overflow_buf.item()
 
+    def unscale_with_stashed_combined(self,
+                                      grads_combined,
+                                      stashed_grads_combined,
+                                      scale_override=None,
+                                      grads_list=None):
+        if self._has_overflow:
+            return
+
+        if grads_list is not None and self.check_grads_overflow_with_inf(grads_list):
+            return
+        
+        grads_have_scale, stashed_have_scale, out_scale = self._loss_scale, 1.0, 1.0
+        if scale_override is not None:
+            grads_have_scale, stashed_have_scale, out_scale = scale_override
+
+        if stashed_grads_combined is None:
+            grads_combined.data[:] = grads_combined.mul_(out_scale/grads_have_scale)
+        else:
+            grads_combined.data[:] = grads_combined.mul_(out_scale/grads_have_scale) + stashed_grads_combined
+
+    def unscale_grad_O2(self,
+                        model_grads_combined=None,
+                        stashed_master_grads_combined=None,
+                        master_grads_combined=None,
+                        scale_override=None,
+                        master_grads=None,
+                        model_grads=None):
+
+        if master_grads_combined is None:
+            return
+
+        if self._has_overflow:
+            return
+
+        if model_grads is not None and self.check_grads_overflow_with_inf(model_grads):
+            return
+
+        grads_have_scale, stashed_have_scale, out_scale = self._loss_scale, 1.0, 1.0
+        if scale_override is not None:
+            grads_have_scale, stashed_have_scale, out_scale = scale_override
+
+        if stashed_master_grads_combined is not None and \
+                master_grads_combined.data_ptr() == stashed_master_grads_combined.data_ptr() and \
+                master_grads_combined.numel() == stashed_master_grads_combined.numel():
+            stashed_master_grads_combined = master_grads_combined.clone()
+
+        if master_grads_combined is not model_grads_combined:
+            if master_grads_combined.numel() == model_grads_combined.numel():
+                master_grads_combined.copy_(model_grads_combined)
+            else:
+                for master, model in zip(master_grads, model_grads):
+                    master.copy_(model)
+        master_grads_combined.mul_(out_scale/grads_have_scale)
+
+        if stashed_master_grads_combined is not None:
+            assert stashed_master_grads_combined.dtype == master_grads_combined.dtype
+            master_grads_combined.add_(stashed_master_grads_combined)
+
     def clear_overflow_state(self):
         self._has_overflow = False
+        self._overflow_checked = False
         if self.has_fused_kernel:
             self._overflow_buf.zero_()
 
@@ -202,16 +427,16 @@
         if self._has_overflow and self.dynamic:
             should_skip = True
             if(self._min_loss_scale):
-                self._loss_scale = max(self._min_loss_scale, self._loss_scale/2.)
+                self._loss_scale = max(self._min_loss_scale, self._loss_scale * self._scale_backoff_factor)
             else:
-                self._loss_scale = self._loss_scale/2.
+                self._loss_scale = self._loss_scale * self._scale_backoff_factor
             self._unskipped = 0
         else:
             should_skip = False
             self._unskipped += 1
 
         if self._unskipped == self._scale_seq_len and self.dynamic:
-            self._loss_scale = min(self._max_loss_scale, self._loss_scale*2.)
+            self._loss_scale = min(self._max_loss_scale, self._loss_scale * self._scale_growth_factor)
             self._unskipped = 0
 
         return should_skip
diff -Nur '--exclude=.git' apex/apex/amp/utils.py apex-develop/apex/amp/utils.py
--- apex/apex/amp/utils.py	2023-04-06 10:36:26.964937605 +0800
+++ apex-develop/apex/amp/utils.py	2024-03-07 21:33:04.293391422 +0800
@@ -1,3 +1,19 @@
+# Copyright (c) 2020, Huawei Technologies.
+# Copyright (c) 2019, NVIDIA CORPORATION.
+# All rights reserved.
+#
+# Licensed under the BSD 3-Clause License  (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from . import compat
 
 import functools
@@ -55,7 +71,7 @@
     if is_nested(x):
         return type(x)([maybe_half(y) for y in x])
 
-    if not x.is_cuda or type_string(x) == 'HalfTensor':
+    if not 'npu' in x.type()  or type_string(x) == 'HalfTensor':
         return x
     else:
         if verbose:
@@ -66,7 +82,7 @@
     if is_nested(x):
         return type(x)([maybe_float(y) for y in x])
 
-    if not x.is_cuda or type_string(x) == 'FloatTensor':
+    if not 'npu' in x.type() or type_string(x) == 'FloatTensor':
         return x
     else:
         if verbose:
@@ -94,7 +110,7 @@
         cached_x = cache[x]
         if x.requires_grad and cached_x.requires_grad:
             # Make sure x is actually cached_x's autograd parent.
-            if cached_x.grad_fn.next_functions[1][0].variable is not x:
+            if cached_x.grad_fn.next_functions[0][0].variable is not x:
                 raise RuntimeError("x and cache[x] both require grad, but x is not "
                                    "cache[x]'s parent.  This is likely an error.")
         # During eval, it's possible to end up caching casted weights with
diff -Nur '--exclude=.git' apex/apex/amp/wrap.py apex-develop/apex/amp/wrap.py
--- apex/apex/amp/wrap.py	2023-04-06 10:36:26.964937605 +0800
+++ apex-develop/apex/amp/wrap.py	2024-03-07 21:33:04.293391422 +0800
@@ -249,7 +249,7 @@
 
         new_args = []
         for i, arg in enumerate(args):
-            if i == params_idx:
+            if i == params_idx and torch.cuda.is_available():
                 num_params = sum([x.numel() for x in arg])
                 fp16_weight_buf = args[0].new_empty((num_params,),
                                                     dtype=torch.half)
diff -Nur '--exclude=.git' apex/apex/normalization/fused_layer_norm.py apex-develop/apex/normalization/fused_layer_norm.py
--- apex/apex/normalization/fused_layer_norm.py	2023-04-06 10:36:26.964937605 +0800
+++ apex-develop/apex/normalization/fused_layer_norm.py	2024-03-07 21:33:04.305391422 +0800
@@ -1,3 +1,19 @@
+# Copyright (c) 2023, Huawei Technologies.
+# Copyright (c) 2019, NVIDIA CORPORATION.
+# All rights reserved.
+#
+# Licensed under the BSD 3-Clause License  (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import math
 import torch
 import numbers
@@ -130,7 +146,7 @@
         super(FusedLayerNorm, self).__init__()
 
         global fused_layer_norm_cuda
-        fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
+        fused_layer_norm_cuda = None
 
         if isinstance(normalized_shape, numbers.Integral):
             normalized_shape = (normalized_shape,)
@@ -151,9 +167,10 @@
             init.zeros_(self.bias)
 
     def forward(self, input):
-        if not input.is_cuda:
-            return  F.layer_norm(
-                input, self.normalized_shape, self.weight, self.bias, self.eps)
+        if not input.is_cuda or fused_layer_norm_cuda is None:
+            with torch.autocast(device_type='npu', enabled=False):
+                return  F.layer_norm(
+                    input, self.normalized_shape, self.weight, self.bias, self.eps)
         if self.elementwise_affine:
           return FusedLayerNormAffineFunction.apply(
               input, self.weight, self.bias, self.normalized_shape,self.eps)
diff -Nur '--exclude=.git' apex/apex/optimizers/fused_adagrad.py apex-develop/apex/optimizers/fused_adagrad.py
--- apex/apex/optimizers/fused_adagrad.py	2023-04-06 10:36:26.964937605 +0800
+++ apex-develop/apex/optimizers/fused_adagrad.py	2024-03-07 21:33:04.305391422 +0800
@@ -37,8 +37,6 @@
         adagrad_w_mode (boolean, optional): Apply L2 regularization or weight decay
             True for decoupled weight decay (also known as AdamW) (default: False)
 
-    .. _Adaptive Subgradient Methods for Online Learning and Stochastic
-        Optimization: http://jmlr.org/papers/v12/duchi11a.html
     """
     def __init__(self, params, lr=1e-2, eps=1e-10,
                  weight_decay=0., set_grad_none=True, adagrad_w_mode=False):
diff -Nur '--exclude=.git' apex/apex/optimizers/fused_adam.py apex-develop/apex/optimizers/fused_adam.py
--- apex/apex/optimizers/fused_adam.py	2023-04-06 10:36:26.964937605 +0800
+++ apex-develop/apex/optimizers/fused_adam.py	2024-03-07 21:33:04.305391422 +0800
@@ -53,10 +53,6 @@
         set_grad_none (bool, optional): whether set grad to None when zero_grad()
             method is called. (default: True)
 
-    .. _Adam - A Method for Stochastic Optimization:
-        https://arxiv.org/abs/1412.6980
-    .. _On the Convergence of Adam and Beyond:
-        https://openreview.net/forum?id=ryQu7f-RZ
     """
 
     def __init__(self, params, lr=1e-3, bias_correction=True,
diff -Nur '--exclude=.git' apex/apex/optimizers/fused_lamb.py apex-develop/apex/optimizers/fused_lamb.py
--- apex/apex/optimizers/fused_lamb.py	2023-04-06 10:36:26.964937605 +0800
+++ apex-develop/apex/optimizers/fused_lamb.py	2024-03-07 21:33:04.305391422 +0800
@@ -54,10 +54,6 @@
         use_nvlamb (boolean, optional): Apply adaptive learning rate to 0.0
             weight decay parameter (default: False)
 
-    .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes:
-        https://arxiv.org/abs/1904.00962
-    .. _On the Convergence of Adam and Beyond:
-        https://openreview.net/forum?id=ryQu7f-RZ
     """
 
     def __init__(self, params, lr=1e-3, bias_correction=True,
diff -Nur '--exclude=.git' apex/apex/optimizers/fused_novograd.py apex-develop/apex/optimizers/fused_novograd.py
--- apex/apex/optimizers/fused_novograd.py	2023-04-06 10:36:26.964937605 +0800
+++ apex-develop/apex/optimizers/fused_novograd.py	2024-03-07 21:33:04.305391422 +0800
@@ -30,7 +30,6 @@
     In general, ``opt_level="O1"`` is recommended.
 
     It has been proposed in `Jasper: An End-to-End Convolutional Neural Acoustic Model`_.
-    More info: https://nvidia.github.io/OpenSeq2Seq/html/optimizers.html#novograd
 
     Arguments:
         params (iterable): iterable of parameters to optimize or dicts defining
@@ -58,10 +57,6 @@
         set_grad_none (bool, optional): whether set grad to None when zero_grad()
             method is called. (default: True)
 
-    .. _Jasper - An End-to-End Convolutional Neural Acoustic Model:
-        https://arxiv.org/abs/1904.03288
-    .. _On the Convergence of Adam and Beyond:
-        https://openreview.net/forum?id=ryQu7f-RZ
     """
 
     def __init__(self, params, lr=1e-3, bias_correction=True,
diff -Nur '--exclude=.git' apex/apex/optimizers/fused_sgd.py apex-develop/apex/optimizers/fused_sgd.py
--- apex/apex/optimizers/fused_sgd.py	2023-04-06 10:36:26.964937605 +0800
+++ apex-develop/apex/optimizers/fused_sgd.py	2024-03-07 21:33:04.305391422 +0800
@@ -48,7 +48,6 @@
         >>> loss_fn(model(input), target).backward()
         >>> optimizer.step()
 
-    __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf
 
     .. note::
         The implementation of SGD with Momentum/Nesterov subtly differs from
diff -Nur '--exclude=.git' apex/apex/optimizers/__init__.py apex-develop/apex/optimizers/__init__.py
--- apex/apex/optimizers/__init__.py	2023-04-06 10:36:26.964937605 +0800
+++ apex-develop/apex/optimizers/__init__.py	2024-03-07 21:33:04.305391422 +0800
@@ -2,4 +2,14 @@
 from .fused_adam import FusedAdam
 from .fused_novograd import FusedNovoGrad
 from .fused_lamb import FusedLAMB
-from .fused_adagrad import FusedAdagrad
\ No newline at end of file
+from .fused_adagrad import FusedAdagrad
+from .npu_fused_sgd import NpuFusedSGD
+from .npu_fused_adam import NpuFusedAdam
+from .npu_fused_bert_adam import NpuFusedBertAdam
+from .npu_fused_adadelta import NpuFusedAdadelta
+from .npu_fused_lamb import NpuFusedLamb
+from .lamb import Lamb
+from .npu_fused_adamw import NpuFusedAdamW
+from .npu_fused_adamp import NpuFusedAdamP
+from .npu_fused_rmsprop import NpuFusedRMSprop
+from .npu_fused_rmsprop_tf import NpuFusedRMSpropTF
diff -Nur '--exclude=.git' apex/csrc/flatten_unflatten.cpp apex-develop/csrc/flatten_unflatten.cpp
--- apex/csrc/flatten_unflatten.cpp	2023-04-06 10:36:26.968937605 +0800
+++ apex-develop/csrc/flatten_unflatten.cpp	2024-03-07 21:33:04.309391422 +0800
@@ -1,3 +1,18 @@
+/*
+ * Copyright (c) 2020, Huawei Technologies.All rights reserved.
+ * Licensed under the BSD 3-Clause License  (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://opensource.org/licenses/BSD-3-Clause
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
 #include <torch/extension.h>
 #include <torch/csrc/utils/tensor_flatten.h>
 // https://github.com/pytorch/pytorch/blob/master/torch/csrc/utils/tensor_flatten.h
@@ -12,7 +27,7 @@
   return torch::utils::unflatten_dense_tensors(flat, tensors);
 }
 
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+PYBIND11_MODULE(apex_C, m) {
   m.def("flatten", &flatten, "Flatten dense tensors");
   m.def("unflatten", &unflatten, "Unflatten dense tensors");
 }
diff -Nur '--exclude=.git' apex/.gitignore apex-develop/.gitignore
--- apex/.gitignore	2023-04-06 10:36:26.964937605 +0800
+++ apex-develop/.gitignore	1970-01-01 08:00:00.000000000 +0800
@@ -1,5 +0,0 @@
-apex.egg-info
-dist
-build
-docs/build
-*~
\ No newline at end of file
diff -Nur '--exclude=.git' apex/.gitmodules apex-develop/.gitmodules
--- apex/.gitmodules	2023-04-06 10:36:26.964937605 +0800
+++ apex-develop/.gitmodules	1970-01-01 08:00:00.000000000 +0800
@@ -1,4 +0,0 @@
-[submodule "apex/contrib/csrc/multihead_attn/cutlass"]
-	path = apex/contrib/csrc/multihead_attn/cutlass
-	url = https://github.com/NVIDIA/cutlass.git
-	branch = v1.2.0
diff -Nur '--exclude=.git' apex/setup.py apex-develop/setup.py
--- apex/setup.py	2023-04-06 10:36:26.964937605 +0800
+++ apex-develop/setup.py	2024-03-07 21:33:04.313391421 +0800
@@ -1,55 +1,92 @@
-import torch
-from torch.utils import cpp_extension
-from setuptools import setup, find_packages
-import subprocess
+# Copyright (c) 2020, Huawei Technologies.
+# Copyright (c) 2019, NVIDIA CORPORATION.
+# All rights reserved.
+#
+# Licensed under the BSD 3-Clause License  (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 
 import sys
 import warnings
 import os
+import glob
+import subprocess
+from setuptools.command.build_ext import build_ext
+from setuptools import setup, find_packages, Extension
+
+import torch
 
 # ninja build does not work unless include_dirs are abs path
 this_dir = os.path.dirname(os.path.abspath(__file__))
 
-def get_cuda_bare_metal_version(cuda_dir):
-    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
-    output = raw_output.split()
-    release_idx = output.index("release") + 1
-    release = output[release_idx].split(".")
-    bare_metal_major = release[0]
-    bare_metal_minor = release[1][0]
-
-    return raw_output, bare_metal_major, bare_metal_minor
-
-if not torch.cuda.is_available():
-    # https://github.com/NVIDIA/apex/issues/486
-    # Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(),
-    # which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command).
-    print('\nWarning: Torch did not find available GPUs on this system.\n',
-          'If your intention is to cross-compile, this is not an error.\n'
-          'By default, Apex will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\n'
-          'Volta (compute capability 7.0), Turing (compute capability 7.5),\n'
-          'and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n'
-          'If you wish to cross-compile for a single specific architecture,\n'
-          'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n')
-    if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
-        _, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
-        if int(bare_metal_major) == 11:
-            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
-        else:
-            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5"
+cmdclass = {}
+ext_modules = []
+
+extras = {}
 
-print("\n\ntorch.__version__  = {}\n\n".format(torch.__version__))
 TORCH_MAJOR = int(torch.__version__.split('.')[0])
 TORCH_MINOR = int(torch.__version__.split('.')[1])
 
-if TORCH_MAJOR == 0 and TORCH_MINOR < 4:
-      raise RuntimeError("Apex requires Pytorch 0.4 or newer.\n" +
-                         "The latest stable release can be obtained from https://pytorch.org/")
+secure_compile_args = ['-fPIE', '-fPIC', '-fstack-protector-all', '-Wall', '-D__FILENAME__=\"$(notdir $(abspath $<))\"']
+
+if (TORCH_MAJOR == 2 and TORCH_MINOR >= 1) or TORCH_MAJOR > 2 :
+    secure_compile_args.append('-std=c++17')
+
+secure_link_args = ['-Wl,-z,now', '-Wl,-z,relro', '-Wl,-z,noexecstack', '-s']
+
+def get_package_dir():
+    if '--user' in sys.argv:
+        package_dir = site.getusersitepackages()
+    else:
+        py_version = f'{sys.version_info.major}.{sys.version_info.minor}'
+        package_dir = f'{sys.prefix}/lib/python{py_version}/site-packages'
+    return package_dir
+
+
+def CppExtension(name, sources, *args, **kwargs):
+    r'''
+    Creates a :class:`setuptools.Extension` for C++.
+    '''
+    package_dir = get_package_dir()
+    temp_include_dirs = kwargs.get('include_dirs', [])
+    temp_include_dirs.append(os.path.join(package_dir, 'torch/include'))
+    temp_include_dirs.append(os.path.join(package_dir, 'torch/include/torch/csrc/api/include'))
+    kwargs['include_dirs'] = temp_include_dirs
+
+    temp_library_dirs = kwargs.get('library_dirs', [])
+    temp_library_dirs.append(os.path.join(package_dir, 'torch/lib'))
+    kwargs['library_dirs'] = temp_library_dirs
+
+    libraries = kwargs.get('libraries', [])
+    libraries.append('c10')
+    libraries.append('torch')
+    libraries.append('torch_cpu')
+    libraries.append('torch_python')
+    kwargs['libraries'] = libraries
+    kwargs['language'] = 'c++'
+    return Extension(name, sources, *args, **kwargs)
+
+
+class BuildExtension(build_ext, object):
+
+    def build_extensions(self):
+        if self.compiler and '-Wstrict-prototypes' in self.compiler.compiler_so:
+            self.compiler.compiler_so.remove('-Wstrict-prototypes')
+
+        if self.compiler and '-g' in self.compiler.compiler_so:
+            self.compiler.compiler_so.remove('-g')
+
+        return super(BuildExtension, self).build_extensions()
 
-cmdclass = {}
-ext_modules = []
 
-extras = {}
 if "--pyprof" in sys.argv:
     string = "\n\nPyprof has been moved to its own dedicated repository and will " + \
              "soon be removed from Apex.  Please visit\n" + \
@@ -67,344 +104,43 @@
     warnings.warn("Option --pyprof not specified. Not installing PyProf dependencies!")
 
 if "--cpp_ext" in sys.argv or "--cuda_ext" in sys.argv:
-    if TORCH_MAJOR == 0:
-        raise RuntimeError("--cpp_ext requires Pytorch 1.0 or later, "
-                           "found torch.__version__ = {}".format(torch.__version__))
-    from torch.utils.cpp_extension import BuildExtension
     cmdclass['build_ext'] = BuildExtension
 
 if "--cpp_ext" in sys.argv:
-    from torch.utils.cpp_extension import CppExtension
     sys.argv.remove("--cpp_ext")
     ext_modules.append(
         CppExtension('apex_C',
-                     ['csrc/flatten_unflatten.cpp',]))
+                     ['csrc/flatten_unflatten.cpp',],
+                     extra_compile_args=secure_compile_args,
+                     extra_link_args=secure_link_args))
 
-def get_cuda_bare_metal_version(cuda_dir):
-    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
-    output = raw_output.split()
-    release_idx = output.index("release") + 1
-    release = output[release_idx].split(".")
-    bare_metal_major = release[0]
-    bare_metal_minor = release[1][0]
-
-    return raw_output, bare_metal_major, bare_metal_minor
-
-def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
-    raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
-    torch_binary_major = torch.version.cuda.split(".")[0]
-    torch_binary_minor = torch.version.cuda.split(".")[1]
-
-    print("\nCompiling cuda extensions with")
-    print(raw_output + "from " + cuda_dir + "/bin\n")
-
-    if (bare_metal_major != torch_binary_major) or (bare_metal_minor != torch_binary_minor):
-        raise RuntimeError("Cuda extensions are being compiled with a version of Cuda that does " +
-                           "not match the version used to compile Pytorch binaries.  " +
-                           "Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda) +
-                           "In some cases, a minor-version mismatch will not cause later errors:  " +
-                           "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798.  "
-                           "You can try commenting out this check (at your own risk).")
-
-
-# Set up macros for forward/backward compatibility hack around
-# https://github.com/pytorch/pytorch/commit/4404762d7dd955383acee92e6f06b48144a0742e
-# and
-# https://github.com/NVIDIA/apex/issues/456
-# https://github.com/pytorch/pytorch/commit/eb7b39e02f7d75c26d8a795ea8c7fd911334da7e#diff-4632522f237f1e4e728cb824300403ac
-version_ge_1_1 = []
-if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 0):
-    version_ge_1_1 = ['-DVERSION_GE_1_1']
-version_ge_1_3 = []
-if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2):
-    version_ge_1_3 = ['-DVERSION_GE_1_3']
-version_ge_1_5 = []
-if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 4):
-    version_ge_1_5 = ['-DVERSION_GE_1_5']
-version_dependent_macros = version_ge_1_1 + version_ge_1_3 + version_ge_1_5
+    ext_modules.append(
+        CppExtension('change_data_ptr',
+                     ['csrc/combine_tensors/change_dataptr.cpp',],
+                     extra_compile_args=secure_compile_args,
+                     extra_link_args=secure_link_args))
 
 if "--distributed_lamb" in sys.argv:
-    from torch.utils.cpp_extension import CUDAExtension
-    sys.argv.remove("--distributed_lamb")
-
-    from torch.utils.cpp_extension import BuildExtension
     cmdclass['build_ext'] = BuildExtension
 
-    if torch.utils.cpp_extension.CUDA_HOME is None:
-        raise RuntimeError("--distributed_lamb was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
-    else:
-        ext_modules.append(
-            CUDAExtension(name='distributed_lamb_cuda',
-                          sources=['apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp',
-                                   'apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb_kernel.cu'],
-                          include_dirs=[os.path.join(this_dir, 'csrc')],
-                          extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
-                                              'nvcc':['-O3',
-                                                      '--use_fast_math'] + version_dependent_macros}))
-
-if "--cuda_ext" in sys.argv:
-    from torch.utils.cpp_extension import CUDAExtension
-    sys.argv.remove("--cuda_ext")
-
-    if torch.utils.cpp_extension.CUDA_HOME is None:
-        raise RuntimeError("--cuda_ext was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
-    else:
-        check_cuda_torch_binary_vs_bare_metal(torch.utils.cpp_extension.CUDA_HOME)
-
-        ext_modules.append(
-            CUDAExtension(name='amp_C',
-                          sources=['csrc/amp_C_frontend.cpp',
-                                   'csrc/multi_tensor_sgd_kernel.cu',
-                                   'csrc/multi_tensor_scale_kernel.cu',
-                                   'csrc/multi_tensor_axpby_kernel.cu',
-                                   'csrc/multi_tensor_l2norm_kernel.cu',
-                                   'csrc/multi_tensor_lamb_stage_1.cu',
-                                   'csrc/multi_tensor_lamb_stage_2.cu',
-                                   'csrc/multi_tensor_adam.cu',
-                                   'csrc/multi_tensor_adagrad.cu',
-                                   'csrc/multi_tensor_novograd.cu',
-                                   'csrc/multi_tensor_lamb.cu'],
-                          extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
-                                              'nvcc':['-lineinfo',
-                                                      '-O3',
-                                                      # '--resource-usage',
-                                                      '--use_fast_math'] + version_dependent_macros}))
-        ext_modules.append(
-            CUDAExtension(name='syncbn',
-                          sources=['csrc/syncbn.cpp',
-                                   'csrc/welford.cu'],
-                          extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
-                                              'nvcc':['-O3'] + version_dependent_macros}))
-
-        ext_modules.append(
-            CUDAExtension(name='fused_layer_norm_cuda',
-                          sources=['csrc/layer_norm_cuda.cpp',
-                                   'csrc/layer_norm_cuda_kernel.cu'],
-                          extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
-                                              'nvcc':['-maxrregcount=50',
-                                                      '-O3',
-                                                      '--use_fast_math'] + version_dependent_macros}))
-
-        ext_modules.append(
-            CUDAExtension(name='mlp_cuda',
-                          sources=['csrc/mlp.cpp',
-                                   'csrc/mlp_cuda.cu'],
-                          extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
-                                              'nvcc':['-O3'] + version_dependent_macros}))
-
 if "--bnp" in sys.argv:
-    from torch.utils.cpp_extension import CUDAExtension
-    sys.argv.remove("--bnp")
-
-    from torch.utils.cpp_extension import BuildExtension
     cmdclass['build_ext'] = BuildExtension
 
-    if torch.utils.cpp_extension.CUDA_HOME is None:
-        raise RuntimeError("--bnp was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
-    else:
-        ext_modules.append(
-            CUDAExtension(name='bnp',
-                          sources=['apex/contrib/csrc/groupbn/batch_norm.cu',
-                                   'apex/contrib/csrc/groupbn/ipc.cu',
-                                   'apex/contrib/csrc/groupbn/interface.cpp',
-                                   'apex/contrib/csrc/groupbn/batch_norm_add_relu.cu'],
-                          include_dirs=[os.path.join(this_dir, 'csrc')],
-                          extra_compile_args={'cxx': [] + version_dependent_macros,
-                                              'nvcc':['-DCUDA_HAS_FP16=1',
-                                                      '-D__CUDA_NO_HALF_OPERATORS__',
-                                                      '-D__CUDA_NO_HALF_CONVERSIONS__',
-                                                      '-D__CUDA_NO_HALF2_OPERATORS__'] + version_dependent_macros}))
-
 if "--xentropy" in sys.argv:
-    from torch.utils.cpp_extension import CUDAExtension
-    sys.argv.remove("--xentropy")
-
-    from torch.utils.cpp_extension import BuildExtension
     cmdclass['build_ext'] = BuildExtension
 
-    if torch.utils.cpp_extension.CUDA_HOME is None:
-        raise RuntimeError("--xentropy was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
-    else:
-        ext_modules.append(
-            CUDAExtension(name='xentropy_cuda',
-                          sources=['apex/contrib/csrc/xentropy/interface.cpp',
-                                   'apex/contrib/csrc/xentropy/xentropy_kernel.cu'],
-                          include_dirs=[os.path.join(this_dir, 'csrc')],
-                          extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
-                                              'nvcc':['-O3'] + version_dependent_macros}))
-
 if "--deprecated_fused_adam" in sys.argv:
-    from torch.utils.cpp_extension import CUDAExtension
-    sys.argv.remove("--deprecated_fused_adam")
-
-    from torch.utils.cpp_extension import BuildExtension
     cmdclass['build_ext'] = BuildExtension
 
-    if torch.utils.cpp_extension.CUDA_HOME is None:
-        raise RuntimeError("--deprecated_fused_adam was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
-    else:
-        ext_modules.append(
-            CUDAExtension(name='fused_adam_cuda',
-                          sources=['apex/contrib/csrc/optimizers/fused_adam_cuda.cpp',
-                                   'apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu'],
-                          include_dirs=[os.path.join(this_dir, 'csrc')],
-                          extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
-                                              'nvcc':['-O3',
-                                                      '--use_fast_math'] + version_dependent_macros}))
-
 if "--deprecated_fused_lamb" in sys.argv:
-    from torch.utils.cpp_extension import CUDAExtension
-    sys.argv.remove("--deprecated_fused_lamb")
-
-    from torch.utils.cpp_extension import BuildExtension
     cmdclass['build_ext'] = BuildExtension
 
-    if torch.utils.cpp_extension.CUDA_HOME is None:
-        raise RuntimeError("--deprecated_fused_lamb was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
-    else:
-        ext_modules.append(
-            CUDAExtension(name='fused_lamb_cuda',
-                          sources=['apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp',
-                                   'apex/contrib/csrc/optimizers/fused_lamb_cuda_kernel.cu',
-                                   'csrc/multi_tensor_l2norm_kernel.cu'],
-                          include_dirs=[os.path.join(this_dir, 'csrc')],
-                          extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
-                                              'nvcc':['-O3',
-                                                      '--use_fast_math'] + version_dependent_macros}))
-
-# Check, if ATen/CUDAGenerator.h is found, otherwise use the new ATen/CUDAGeneratorImpl.h, due to breaking change in https://github.com/pytorch/pytorch/pull/36026 
-generator_flag = []
-torch_dir = torch.__path__[0]
-if os.path.exists(os.path.join(torch_dir, 'include', 'ATen', 'CUDAGenerator.h')):
-    generator_flag = ['-DOLD_GENERATOR']
-
-
 if "--fast_multihead_attn" in sys.argv:
-    from torch.utils.cpp_extension import CUDAExtension
-    sys.argv.remove("--fast_multihead_attn")
-
-    from torch.utils.cpp_extension import BuildExtension
     cmdclass['build_ext'] = BuildExtension.with_options(use_ninja=False)
 
-    if torch.utils.cpp_extension.CUDA_HOME is None:
-        raise RuntimeError("--fast_multihead_attn was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
-    else:
-        # Check, if CUDA11 is installed for compute capability 8.0
-        cc_flag = []
-        _, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
-        if int(bare_metal_major) >= 11:
-            cc_flag.append('-gencode')
-            cc_flag.append('arch=compute_80,code=sm_80')
-
-        subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/multihead_attn/cutlass"])
-        ext_modules.append(
-            CUDAExtension(name='fast_additive_mask_softmax_dropout',
-                          sources=['apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout.cpp',
-                                   'apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu'],
-                          extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
-                                              'nvcc':['-O3',
-                                                      '-gencode', 'arch=compute_70,code=sm_70',
-                                                      '-I./apex/contrib/csrc/multihead_attn/cutlass/',
-                                                      '-U__CUDA_NO_HALF_OPERATORS__',
-                                                      '-U__CUDA_NO_HALF_CONVERSIONS__',
-                                                      '--expt-relaxed-constexpr',
-                                                      '--expt-extended-lambda',
-                                                      '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag}))
-        ext_modules.append(
-            CUDAExtension(name='fast_mask_softmax_dropout',
-                          sources=['apex/contrib/csrc/multihead_attn/masked_softmax_dropout.cpp',
-                                   'apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cuda.cu'],
-                          extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
-                                              'nvcc':['-O3',
-                                                      '-gencode', 'arch=compute_70,code=sm_70',
-                                                      '-I./apex/contrib/csrc/multihead_attn/cutlass/',
-                                                      '-U__CUDA_NO_HALF_OPERATORS__',
-                                                      '-U__CUDA_NO_HALF_CONVERSIONS__',
-                                                      '--expt-relaxed-constexpr',
-                                                      '--expt-extended-lambda',
-                                                      '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag}))
-        ext_modules.append(
-            CUDAExtension(name='fast_self_multihead_attn_bias_additive_mask',
-                          sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask.cpp',
-                                   'apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu'],
-                          extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
-                                              'nvcc':['-O3',
-                                                      '-gencode', 'arch=compute_70,code=sm_70',
-                                                      '-I./apex/contrib/csrc/multihead_attn/cutlass/',
-                                                      '-U__CUDA_NO_HALF_OPERATORS__',
-                                                      '-U__CUDA_NO_HALF_CONVERSIONS__',
-                                                      '--expt-relaxed-constexpr',
-                                                      '--expt-extended-lambda',
-                                                      '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag}))
-        ext_modules.append(
-            CUDAExtension(name='fast_self_multihead_attn_bias',
-                          sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_bias.cpp',
-                                   'apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu'],
-                          extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
-                                              'nvcc':['-O3',
-                                                      '-gencode', 'arch=compute_70,code=sm_70',
-                                                      '-I./apex/contrib/csrc/multihead_attn/cutlass/',
-                                                      '-U__CUDA_NO_HALF_OPERATORS__',
-                                                      '-U__CUDA_NO_HALF_CONVERSIONS__',
-                                                      '--expt-relaxed-constexpr',
-                                                      '--expt-extended-lambda',
-                                                      '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag}))
-        ext_modules.append(
-            CUDAExtension(name='fast_self_multihead_attn',
-                          sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn.cpp',
-                                   'apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu'],
-                          extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
-                                              'nvcc':['-O3',
-                                                      '-gencode', 'arch=compute_70,code=sm_70',
-                                                      '-I./apex/contrib/csrc/multihead_attn/cutlass/',
-                                                      '-U__CUDA_NO_HALF_OPERATORS__',
-                                                      '-U__CUDA_NO_HALF_CONVERSIONS__',
-                                                      '--expt-relaxed-constexpr',
-                                                      '--expt-extended-lambda',
-                                                      '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag}))
-        ext_modules.append(
-            CUDAExtension(name='fast_self_multihead_attn_norm_add',
-                          sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add.cpp',
-                                   'apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu'],
-                          extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
-                                              'nvcc':['-O3',
-                                                      '-gencode', 'arch=compute_70,code=sm_70',
-                                                      '-I./apex/contrib/csrc/multihead_attn/cutlass/',
-                                                      '-U__CUDA_NO_HALF_OPERATORS__',
-                                                      '-U__CUDA_NO_HALF_CONVERSIONS__',
-                                                      '--expt-relaxed-constexpr',
-                                                      '--expt-extended-lambda',
-                                                      '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag}))
-        ext_modules.append(
-            CUDAExtension(name='fast_encdec_multihead_attn',
-                          sources=['apex/contrib/csrc/multihead_attn/encdec_multihead_attn.cpp',
-                                   'apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu'],
-                          extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
-                                              'nvcc':['-O3',
-                                                      '-gencode', 'arch=compute_70,code=sm_70',
-                                                      '-I./apex/contrib/csrc/multihead_attn/cutlass/',
-                                                      '-U__CUDA_NO_HALF_OPERATORS__',
-                                                      '-U__CUDA_NO_HALF_CONVERSIONS__',
-                                                      '--expt-relaxed-constexpr',
-                                                      '--expt-extended-lambda',
-                                                      '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag}))
-        ext_modules.append(
-            CUDAExtension(name='fast_encdec_multihead_attn_norm_add',
-                          sources=['apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add.cpp',
-                                   'apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu'],
-                          extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
-                                              'nvcc':['-O3',
-                                                      '-gencode', 'arch=compute_70,code=sm_70',
-                                                      '-I./apex/contrib/csrc/multihead_attn/cutlass/',
-                                                      '-U__CUDA_NO_HALF_OPERATORS__',
-                                                      '-U__CUDA_NO_HALF_CONVERSIONS__',
-                                                      '--expt-relaxed-constexpr',
-                                                      '--expt-extended-lambda',
-                                                      '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag}))
-
 setup(
     name='apex',
-    version='0.1',
+    version='0.1+ascend',
     packages=find_packages(exclude=('build',
                                     'csrc',
                                     'include',
diff -Nur '--exclude=.git' apex/tests/distributed/amp_master_params/amp_master_params.py apex-develop/tests/distributed/amp_master_params/amp_master_params.py
--- apex/tests/distributed/amp_master_params/amp_master_params.py	2023-04-06 10:36:26.964937605 +0800
+++ apex-develop/tests/distributed/amp_master_params/amp_master_params.py	2024-03-07 21:33:04.313391421 +0800
@@ -34,8 +34,6 @@
 
 # Each process receives its own batch of "fake input data" and "fake target data."
 # The "training loop" in each process just uses this fake batch over and over.
-# https://github.com/NVIDIA/apex/tree/master/examples/imagenet provides a more realistic
-# example of distributed data sampling for both training and validation.
 x = torch.randn(N, D_in, device='cuda')
 y = torch.randn(N, D_out, device='cuda')
 
diff -Nur '--exclude=.git' apex/tests/distributed/synced_batchnorm/test_groups.py apex-develop/tests/distributed/synced_batchnorm/test_groups.py
--- apex/tests/distributed/synced_batchnorm/test_groups.py	2023-04-06 10:36:26.968937605 +0800
+++ apex-develop/tests/distributed/synced_batchnorm/test_groups.py	2024-03-07 21:33:04.313391421 +0800
@@ -105,7 +105,6 @@
 out_bn.backward(grad_bn)
 # compensating the averaging over processes done by DDP
 # in order to produce mathematically equivalent result
-# https://github.com/NVIDIA/apex/issues/134#issuecomment-458307368
 for param in bn.parameters():
     param.grad = param.grad / args.group_size
 bn_opt = optim.SGD(bn.parameters(), lr=1.0)
diff -Nur '--exclude=.git' apex/tests/distributed/synced_batchnorm/two_gpu_unit_test.py apex-develop/tests/distributed/synced_batchnorm/two_gpu_unit_test.py
--- apex/tests/distributed/synced_batchnorm/two_gpu_unit_test.py	2023-04-06 10:36:26.968937605 +0800
+++ apex-develop/tests/distributed/synced_batchnorm/two_gpu_unit_test.py	2024-03-07 21:33:04.313391421 +0800
@@ -94,7 +94,6 @@
 out_bn.backward(grad_bn)
 # compensating the averaging over processes done by DDP
 # in order to produce mathematically equivalent result
-# https://github.com/NVIDIA/apex/issues/134#issuecomment-458307368
 for param in bn.parameters():
     param.grad = param.grad / args.world_size
 bn_opt = optim.SGD(bn.parameters(), lr=1.0)
diff -Nur '--exclude=.git' apex/tests/L0/run_amp/test_add_param_group.py apex-develop/tests/L0/run_amp/test_add_param_group.py
--- apex/tests/L0/run_amp/test_add_param_group.py	2023-04-06 10:36:26.968937605 +0800
+++ apex-develop/tests/L0/run_amp/test_add_param_group.py	2024-03-07 21:33:04.313391421 +0800
@@ -1,3 +1,19 @@
+# Copyright (c) 2020, Huawei Technologies.
+# Copyright (c) 2019, NVIDIA CORPORATION.
+# All rights reserved.
+#
+# Licensed under the BSD 3-Clause License  (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import unittest
 
 import functools as ft
@@ -9,16 +25,20 @@
 from torch import nn
 import torch.nn.functional as F
 from torch.nn import Parameter
+import numpy as np
 
-from utils import common_init, HALF, FLOAT,\
-    ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT
+from utils import common_init
+import sys
+sys.path.append('../')
+import device
 
 class MyModel(torch.nn.Module):
     def __init__(self, unique):
         super(MyModel, self).__init__()
         self.weight0 = Parameter(unique +
-            torch.arange(2, device='cuda', dtype=torch.float32))
-        self.weight1 = Parameter(1. + unique + torch.arange(2, device='cuda', dtype=torch.float16))
+            torch.from_numpy(np.arange(2, dtype=np.float32)))
+        self.weight1 = Parameter(1. + unique +
+            torch.from_numpy(np.arange(2, dtype=np.float16)).to(device.CALCULATE_DEVICE ))
 
     @staticmethod
     def ops(input, weight0, weight1):
@@ -33,7 +53,8 @@
 
 class TestAddParamGroup(unittest.TestCase):
     def setUp(self):
-        self.x = torch.ones((2), device='cuda', dtype=torch.float32)
+        self.device = device.CALCULATE_DEVICE
+        self.x = torch.ones((2), device=self.device, dtype=torch.float32)
         common_init(self)
 
     def tearDown(self):
@@ -54,8 +75,8 @@
         for opt_level in ("O0", "O1", "O2", "O3"):
           for zero_before_add in (True, False):
             for try_accumulation in (True, False):
-              model0 = MyModel(1)
-              model1 = MyModel(2)
+              model0 = MyModel(1).to(self.device)
+              model1 = MyModel(2).to(self.device)
 
               optimizer = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25}],
                                           momentum=0.125)
@@ -89,8 +110,8 @@
                                  [param.data.clone() for param in model1.parameters()]
 
               for how_to_zero in "none", "model", "optimizer":
-                  model0 = MyModel(1)
-                  model1 = MyModel(2)
+                  model0 = MyModel(1).to(self.device)
+                  model1 = MyModel(2).to(self.device)
 
                   optimizer = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25}],
                                               momentum=0.125)
@@ -139,7 +160,8 @@
                                  [param.data.clone() for param in model1.parameters()]
 
                   for reference, final in zip(reference_params, final_params):
-                      self.assertTrue(torch.allclose(reference.to(final.dtype), final),
+                      final = final.to(torch.float32)
+                      self.assertTrue(torch.allclose(reference.to(final.dtype).to('cpu'), final.to('cpu')),
                                       "opt_level = {}, how_to_zero = {}, zero_before_add = {}".format(
                                       opt_level, how_to_zero, zero_before_add))
 
diff -Nur '--exclude=.git' apex/tests/L0/run_amp/test_basic_casts.py apex-develop/tests/L0/run_amp/test_basic_casts.py
--- apex/tests/L0/run_amp/test_basic_casts.py	2023-04-06 10:36:26.968937605 +0800
+++ apex-develop/tests/L0/run_amp/test_basic_casts.py	2024-03-07 21:33:04.313391421 +0800
@@ -1,3 +1,19 @@
+# Copyright (c) 2020, Huawei Technologies.
+# Copyright (c) 2019, NVIDIA CORPORATION.
+# All rights reserved.
+#
+# Licensed under the BSD 3-Clause License  (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import unittest
 
 import functools as ft
@@ -7,73 +23,89 @@
 import torch
 from torch import nn
 import torch.nn.functional as F
+import numpy as np
+
+from utils import common_init, generate_data
+import utils
+
+import sys
+sys.path.append('../')
+import device
+
+npu_input_grad = None
 
-from utils import common_init, HALF, FLOAT,\
-    ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT
+def npu_input_grad_hook(grad):
+   global npu_input_grad
+   npu_input_grad = grad.to('cpu')
 
 def run_layer_test(test_case, fns, expected, input_shape, test_backward=True):
     for fn, typ in it.product(fns, expected.keys()):
-        x = torch.randn(input_shape, dtype=typ).requires_grad_()
+        x = generate_data(0, 10, input_shape, typ).requires_grad_()
+        x = x.to(test_case.device)
+        x.register_hook(npu_input_grad_hook)
         y = fn(x)
         test_case.assertEqual(y.type(), expected[typ])
         if test_backward:
-            y.float().sum().backward()
-            test_case.assertEqual(x.grad.type(), MATCH_INPUT[typ])
+            y.float().sum().backward(retain_graph=True)
+            test_case.assertEqual(npu_input_grad.type().split(".")[-1], utils.MATCH_INPUT[typ].split(".")[-1])
 
 class TestBasicCasts(unittest.TestCase):
     def setUp(self):
         self.handle = amp.init(enabled=True)
+        self.device = device.CALCULATE_DEVICE
         common_init(self)
 
     def tearDown(self):
         self.handle._deactivate()
 
     def test_linear_is_half(self):
-        m = nn.Linear(self.h, self.h)
+        m = nn.Linear(self.h, self.h).to(self.device)
         f = ft.partial(F.linear, weight=m.weight, bias=m.bias)
-        run_layer_test(self, [m, f], ALWAYS_HALF, (self.b, self.h))
+        run_layer_test(self, [m, f], utils.ALWAYS_HALF, (self.b, self.h))
 
     def test_conv2d_is_half(self):
-        m = nn.Conv2d(self.c, self.c, self.k)
+        m = nn.Conv2d(self.c, self.c, self.k).to(self.device)
         f = ft.partial(F.conv2d, weight=m.weight, bias=m.bias)
-        run_layer_test(self, [m, f], ALWAYS_HALF, (self.b, self.c, self.h, self.h))
+        run_layer_test(self, [m, f], utils.ALWAYS_HALF, (self.b, self.c, self.h, self.h))
 
     def test_softmax_is_float(self):
-        m = nn.Softmax(dim=1)
+        m = nn.Softmax(dim=1).to(self.device)
         f = ft.partial(F.softmax, dim=1)
-        run_layer_test(self, [m, f], ALWAYS_FLOAT, (self.b, self.h))
+        run_layer_test(self, [m, f], utils.ALWAYS_FLOAT, (self.b, self.h))
 
+    @unittest.skipIf(device.is_npu(),"NPU does not support group_norm in half")
     def test_group_norm_is_float(self):
-        m = nn.GroupNorm(num_groups=4, num_channels=self.c)
-        run_layer_test(self, [m], ALWAYS_FLOAT, (self.b, self.c, self.h, self.h))
+        m = nn.GroupNorm(num_groups=4, num_channels=self.c).to(self.device)
+        run_layer_test(self, [m], utils.ALWAYS_FLOAT, (self.b, self.c, self.h, self.h))
 
     def test_mse_loss_is_float(self):
         shape = (self.b, self.h)
-        target = torch.randn(shape)
-        mod = nn.MSELoss()
+        target = torch.randn(shape).to(self.device)
+        mod = nn.MSELoss().to(self.device)
         m = lambda x: mod(x, target)
         f = ft.partial(F.mse_loss, target=target)
-        run_layer_test(self, [m], ALWAYS_FLOAT, shape)
+        run_layer_test(self, [m], utils.ALWAYS_FLOAT, shape)
 
     def test_relu_is_match(self):
-        run_layer_test(self, [nn.ReLU(), F.relu], MATCH_INPUT, (self.b, self.h))
+        run_layer_test(self, [nn.ReLU(), F.relu], utils.MATCH_INPUT, (self.b, self.h))
 
     def test_batch_norm_is_match(self):
-        m = nn.BatchNorm2d(num_features=self.c)
+        m = nn.BatchNorm2d(num_features=self.c).to(self.device)
         f = ft.partial(F.batch_norm, running_mean=m.running_mean, running_var=m.running_var,
                        weight=m.weight, bias=m.bias, training=True)
-        run_layer_test(self, [m], MATCH_INPUT, (self.b, self.c, self.h, self.h))
+        run_layer_test(self, [m], utils.MATCH_INPUT, (self.b, self.c, self.h, self.h))
 
         # Test forward-only for BN inference
         m.eval()
         f = ft.partial(F.batch_norm, running_mean=m.running_mean, running_var=m.running_var,
                        weight=m.weight, bias=m.bias, training=False)
-        run_layer_test(self, [m, f], MATCH_INPUT, (self.b, self.c, self.h, self.h),
+        run_layer_test(self, [m, f], utils.MATCH_INPUT, (self.b, self.c, self.h, self.h),
                             test_backward=False)
 
 class TestBannedMethods(unittest.TestCase):
     def setUp(self):
         self.handle = amp.init(enabled=True)
+        self.device = device.CALCULATE_DEVICE
         common_init(self)
 
     def tearDown(self):
@@ -81,12 +113,12 @@
 
     def bce_common(self, assertion):
         shape = (self.b, self.h)
-        target = torch.rand(shape)
-        mod = nn.BCELoss()
+        target = torch.rand(shape).to(self.device)
+        mod = nn.BCELoss().to(self.device)
         m = lambda x: mod(x, target)
         f = ft.partial(F.binary_cross_entropy, target=target)
         for fn in [m, f]:
-            x = torch.rand(shape, dtype=torch.half)
+            x = generate_data(0, 10, shape, np.float16).to(self.device)
             assertion(fn, x)
 
     def test_bce_raises_by_default(self):
@@ -96,36 +128,37 @@
     def test_bce_is_float_with_allow_banned(self):
         self.handle._deactivate()
         self.handle = amp.init(enabled=True, allow_banned=True)
-        assertion = lambda fn, x: self.assertEqual(fn(x).type(), FLOAT)
+        assertion = lambda fn, x: self.assertEqual(fn(x).type(), utils.FLOAT)
         self.bce_common(assertion)
 
 class TestTensorCasts(unittest.TestCase):
     def setUp(self):
         self.handle = amp.init(enabled=True)
+        self.device = device.CALCULATE_DEVICE
         common_init(self)
 
     def tearDown(self):
         self.handle._deactivate()
 
     def test_matmul_method_is_half(self):
-        other = torch.randn(self.h, self.h)
+        other = torch.randn(self.h, self.h).to(self.device)
         lhs = lambda x: x.matmul(other)
         rhs = lambda x: other.matmul(x)
-        run_layer_test(self, [lhs, rhs], ALWAYS_HALF, (self.h, self.h))
+        run_layer_test(self, [lhs, rhs], utils.ALWAYS_HALF, (self.h, self.h))
 
     def test_matmul_op_is_half(self):
-        other = torch.randn(self.h, self.h)
+        other = torch.randn(self.h, self.h).to(self.device)
         lhs = lambda x: x @ other
         rhs = lambda x: other @ x
-        run_layer_test(self, [lhs, rhs], ALWAYS_HALF, (self.h, self.h))
+        run_layer_test(self, [lhs, rhs], utils.ALWAYS_HALF, (self.h, self.h))
 
     def test_pow_method_is_float(self):
         fn = lambda x: x.pow(2.)
-        run_layer_test(self, [fn], ALWAYS_FLOAT, (self.b, self.h))
+        run_layer_test(self, [fn], utils.ALWAYS_FLOAT, (self.b, self.h))
 
     def test_pow_op_is_float(self):
         fn = lambda x: x ** 2.
-        run_layer_test(self, [fn], ALWAYS_FLOAT, (self.b, self.h))
+        run_layer_test(self, [fn], utils.ALWAYS_FLOAT, (self.b, self.h))
 
     def test_cpu_is_float(self):
         fn = lambda x: x.cpu()
@@ -135,7 +168,7 @@
 
     def test_sum_is_float(self):
         fn = lambda x: x.sum()
-        run_layer_test(self, [fn], ALWAYS_FLOAT, (self.b, self.h))
+        run_layer_test(self, [fn], utils.ALWAYS_FLOAT, (self.b, self.h))
 
     # TODO: maybe more tests on disabled casting?
 
diff -Nur '--exclude=.git' apex/tests/L0/run_amp/test_cache.py apex-develop/tests/L0/run_amp/test_cache.py
--- apex/tests/L0/run_amp/test_cache.py	2023-04-06 10:36:26.968937605 +0800
+++ apex-develop/tests/L0/run_amp/test_cache.py	2024-03-07 21:33:04.313391421 +0800
@@ -1,3 +1,19 @@
+# Copyright (c) 2020, Huawei Technologies.
+# Copyright (c) 2019, NVIDIA CORPORATION.
+# All rights reserved.
+#
+# Licensed under the BSD 3-Clause License  (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import unittest
 
 import functools as ft
@@ -8,9 +24,16 @@
 import torch
 from torch import nn
 import torch.nn.functional as F
+import numpy as np
+import sys
+sys.path.append('../')
+import device
+import utils
 
 from utils import common_init, HALF, FLOAT,\
-    ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT
+    ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT,\
+    generate_data
+    
 
 def get_reference_grad(i, w, ops):
     # Creating new tensors ensures, among other things, that the new tensors are not in the cache.
@@ -24,7 +47,8 @@
 class WhitelistModule(torch.nn.Module):
     def __init__(self, dtype):
         super(WhitelistModule, self).__init__()
-        self.weight = torch.nn.Parameter(torch.arange(8*8, device='cuda', dtype=dtype).view(8,8))
+        weight_parameter = torch.from_numpy(np.arange(8*8, dtype=dtype)).view(8,8).to(device.CALCULATE_DEVICE)
+        self.weight = torch.nn.Parameter(weight_parameter)
 
     @staticmethod
     def ops(input, weight):
@@ -37,7 +61,8 @@
 class BlacklistModule(torch.nn.Module):
     def __init__(self, dtype):
         super(BlacklistModule, self).__init__()
-        self.weight = torch.nn.Parameter(torch.arange(2*8, device='cuda', dtype=dtype).view(2,8))
+        weight_parameter = torch.from_numpy(np.arange(2*8, dtype=dtype)).view(2,8).to(device.CALCULATE_DEVICE)
+        self.weight = torch.nn.Parameter(weight_parameter)
 
     @staticmethod
     def ops(input, weight):
@@ -50,7 +75,8 @@
 class PromoteModule(torch.nn.Module):
     def __init__(self, dtype):
         super(PromoteModule, self).__init__()
-        self.weight = torch.nn.Parameter(torch.arange(2*8, device='cuda', dtype=dtype).view(2,8))
+        weight_parameter = torch.from_numpy(np.arange(2*8, dtype=dtype)).view(2,8).to(device.CALCULATE_DEVICE)
+        self.weight = torch.nn.Parameter(weight_parameter)
 
     @staticmethod
     def ops(input, weight):
@@ -61,14 +87,14 @@
 
 class TestCache(unittest.TestCase):
     def setUp(self):
-        self.x = torch.ones((2, 8), device='cuda', dtype=torch.float32)
+        self.x = torch.ones((2, 8), dtype=torch.float32).to(device.CALCULATE_DEVICE)
         common_init(self)
 
     def tearDown(self):
         pass
 
     def train_eval_train_test(self, module, t):
-        model = module(t).cuda()
+        model = module(t).to(device.CALCULATE_DEVICE)
         optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
 
         _amp_state.allow_incoming_model_not_fp32 = True
@@ -91,10 +117,10 @@
         
             # Currently there's no difference in the allclose calls, so no need for branching,
             # but I'm keeping this in case we want different tolerances for fp16 and fp32 checks. 
-            if model.weight.grad.type() == "torch.cuda.HalfTensor":
-                self.assertTrue(torch.allclose(model.weight.grad.float(), reference_grad))
-            elif model.weight.grad.type() == "torch.cuda.FloatTensor":
-                self.assertTrue(torch.allclose(model.weight.grad.float(), reference_grad))
+            if model.weight.grad.type() == utils.HALF:
+                self.assertTrue(torch.allclose(model.weight.grad.float().to('cpu'), reference_grad.to('cpu')))
+            elif model.weight.grad.type() == utils.FLOAT:
+                self.assertTrue(torch.allclose(model.weight.grad.float().to('cpu'), reference_grad.to('cpu')))
             else:
                 raise RuntimeError("model.weight.grad.type = {}".format(model.weight.grad.type()))
 
@@ -115,22 +141,25 @@
     # I could easily have these as a set of for loops in a single test,
     # instead of going for granularity.
     def test_whitelist_module_fp16_weight(self):
-        self.train_eval_train_test(WhitelistModule, torch.float16)
+        self.train_eval_train_test(WhitelistModule, np.float16)
+
 
     def test_whitelist_module_fp32_weight(self):
-        self.train_eval_train_test(WhitelistModule, torch.float32)
+        self.train_eval_train_test(WhitelistModule, np.float32)
+
 
     def test_blacklist_module_fp16_weight(self):
-        self.train_eval_train_test(BlacklistModule, torch.float16)
+        self.train_eval_train_test(BlacklistModule, np.float16)
+
 
     def test_blacklist_module_fp32_weight(self):
-        self.train_eval_train_test(BlacklistModule, torch.float32)
+        self.train_eval_train_test(BlacklistModule, np.float32)
 
     def test_promote_module_fp16_weight(self):
-        self.train_eval_train_test(PromoteModule, torch.float16)
+        self.train_eval_train_test(PromoteModule, np.float16)
 
     def test_promote_module_fp32_weight(self):
-        self.train_eval_train_test(PromoteModule, torch.float32)
+        self.train_eval_train_test(PromoteModule, np.float32)
 
 
 if __name__ == '__main__':
diff -Nur '--exclude=.git' apex/tests/L0/run_amp/test_checkpointing.py apex-develop/tests/L0/run_amp/test_checkpointing.py
--- apex/tests/L0/run_amp/test_checkpointing.py	2023-04-06 10:36:26.968937605 +0800
+++ apex-develop/tests/L0/run_amp/test_checkpointing.py	2024-03-07 21:33:04.313391421 +0800
@@ -1,3 +1,19 @@
+# Copyright (c) 2020, Huawei Technologies.
+# Copyright (c) 2019, NVIDIA CORPORATION.
+# All rights reserved.
+#
+# Licensed under the BSD 3-Clause License  (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import unittest
 
 import torch
@@ -7,9 +23,8 @@
 
 from apex import amp
 
-
 from utils import common_init, FLOAT
-
+import utils
 
 class MyModel(torch.nn.Module):
     def __init__(self):
@@ -40,7 +55,7 @@
             if 'num_batches_tracked' in key:
                 continue
             param = state_dict[key]
-            self.assertEqual(param.type(), FLOAT,
+            self.assertEqual(param.type(), utils.FLOAT,
                              'Parameter in state_dict not FLOAT')
 
     def train_step(self, model, optimizer, data, loss_ids):
diff -Nur '--exclude=.git' apex/tests/L0/run_amp/test_larc.py apex-develop/tests/L0/run_amp/test_larc.py
--- apex/tests/L0/run_amp/test_larc.py	2023-04-06 10:36:26.968937605 +0800
+++ apex-develop/tests/L0/run_amp/test_larc.py	2024-03-07 21:33:04.313391421 +0800
@@ -1,5 +1,5 @@
 import unittest
-
+import sys
 import torch
 from torch import nn
 from torch.nn import Parameter
@@ -8,12 +8,14 @@
 from apex.parallel.LARC import LARC
 from utils import common_init
 
+sys.path.append('../')
+import device
 
 class MyModel(torch.nn.Module):
     def __init__(self, unique):
         super(MyModel, self).__init__()
         self.weight0 = Parameter(
-            unique + torch.arange(2, device="cuda", dtype=torch.float32)
+            unique + torch.arange(2, device=device.CALCULATE_DEVICE, dtype=torch.float32)
         )
 
     def forward(self, input):
@@ -22,7 +24,7 @@
 
 class TestLARC(unittest.TestCase):
     def setUp(self):
-        self.x = torch.ones((2), device="cuda", dtype=torch.float32)
+        self.x = torch.ones((2), device=device.CALCULATE_DEVICE, dtype=torch.float32)
         common_init(self)
 
     def tearDown(self):
@@ -39,7 +41,7 @@
             )
 
             model, optimizer = amp.initialize(
-                model, optimizer, opt_level=opt_level, verbosity=0
+                model, optimizer, opt_level=opt_level, loss_scale=1024, verbosity=0
             )
 
             optimizer.zero_grad()
diff -Nur '--exclude=.git' apex/tests/L0/run_amp/test_promotion.py apex-develop/tests/L0/run_amp/test_promotion.py
--- apex/tests/L0/run_amp/test_promotion.py	2023-04-06 10:36:26.968937605 +0800
+++ apex-develop/tests/L0/run_amp/test_promotion.py	2024-03-07 21:33:04.313391421 +0800
@@ -1,3 +1,19 @@
+# Copyright (c) 2020, Huawei Technologies.
+# Copyright (c) 2019, NVIDIA CORPORATION.
+# All rights reserved.
+#
+# Licensed under the BSD 3-Clause License  (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import unittest
 
 import itertools as it
@@ -7,11 +23,17 @@
 from torch import nn
 import torch.nn.functional as F
 
-from utils import common_init, HALF, FLOAT, DTYPES
+from utils import common_init, HALF, FLOAT, DTYPES,\
+    generate_data
+import utils
+import sys
+sys.path.append('../')
+import device
 
 class TestPromotion(unittest.TestCase):
     def setUp(self):
         self.handle = amp.init(enabled=True)
+        self.device = device.CALCULATE_DEVICE
         common_init(self)
 
     def tearDown(self):
@@ -20,12 +42,13 @@
     def run_binary_promote_test(self, fns, input_shape, x_inplace=False):
         type_pairs = it.product(DTYPES, DTYPES)
         for fn, (xtype, ytype) in it.product(fns, type_pairs):
-            x = torch.randn(input_shape, dtype=xtype).requires_grad_()
+            x = generate_data(0, 10, input_shape, xtype).requires_grad_()
             x_leaf = x
             if x_inplace:
                 # We need a non-leaf to call in place on
                 x = x.clone()
-            y = torch.randn(input_shape, dtype=ytype)
+            y = generate_data(0, 10, input_shape, dtype=ytype).to(self.device)
+            x = x.to(self.device)
             out = fn(x, y)
             if x_inplace:
                 # In place: always match xtype
@@ -33,9 +56,9 @@
             else:
                 # Out of place: match widest type
                 if xtype == torch.float or ytype == torch.float:
-                    self.assertEqual(out.type(), FLOAT)
+                    self.assertEqual(out.type(), utils.FLOAT)
                 else:
-                    self.assertEqual(out.type(), HALF)
+                    self.assertEqual(out.type(), utils.HALF)
             out.float().sum().backward()
             self.assertEqual(x_leaf.grad.dtype, xtype)
 
@@ -51,19 +74,19 @@
 
     def test_cat_matches_widest(self):
         shape = self.b
-        ys = [torch.randn(shape, dtype=torch.half) for _ in range(5)]
-        x_float = torch.randn(shape)
+        ys = [generate_data(0, 10, shape, dtype=torch.half).to(self.device) for _ in range(5)]
+        x_float = generate_data(0, 10, shape, dtype=torch.float).to(self.device)
         out = torch.cat(ys + [x_float])
-        self.assertEqual(out.type(), FLOAT)
-        x_half = torch.randn(shape, dtype=torch.half)
+        self.assertEqual(out.type(), utils.FLOAT)
+        x_half = generate_data(0, 10, shape, dtype=torch.half).to(self.device)
         out = torch.cat(ys + [x_half])
-        self.assertEqual(out.type(), HALF)
+        self.assertEqual(out.type(), utils.HALF)
 
     def test_inplace_exp_is_error_for_half(self):
-        xs = torch.randn(self.b)
+        xs = generate_data(0, 10, self.b, dtype=torch.float).to(self.device)
         xs.exp_()
-        self.assertEqual(xs.type(), FLOAT)
-        xs = torch.randn(self.b, dtype=torch.half)
+        self.assertEqual(xs.type(), utils.FLOAT)
+        xs = generate_data(0, 10, self.b, dtype=torch.half).to(self.device)
         with self.assertRaises(NotImplementedError):
             xs.exp_()
 
diff -Nur '--exclude=.git' apex/tests/L0/run_amp/test_rnn.py apex-develop/tests/L0/run_amp/test_rnn.py
--- apex/tests/L0/run_amp/test_rnn.py	2023-04-06 10:36:26.968937605 +0800
+++ apex-develop/tests/L0/run_amp/test_rnn.py	2024-03-07 21:33:04.313391421 +0800
@@ -1,3 +1,19 @@
+# Copyright (c) 2020, Huawei Technologies.
+# Copyright (c) 2019, NVIDIA CORPORATION.
+# All rights reserved.
+#
+# Licensed under the BSD 3-Clause License  (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import unittest
 
 from apex import amp
@@ -5,7 +21,8 @@
 import torch
 from torch import nn
 
-from utils import common_init, HALF
+from utils import common_init
+import utils
 
 class TestRnnCells(unittest.TestCase):
     def setUp(self):
@@ -34,7 +51,7 @@
                     output = hidden
                 outputs.append(output)
             for y in outputs:
-                self.assertEqual(y.type(), HALF)
+                self.assertEqual(y.type(), utils.HALF)
             outputs[-1].float().sum().backward()
             for i, x in enumerate(xs):
                 self.assertEqual(x.grad.dtype, x.dtype)
@@ -69,7 +86,7 @@
             else:
                 hidden = hidden_fn()
             output, _ = rnn(x, hidden)
-            self.assertEqual(output.type(), HALF)
+            self.assertEqual(output.type(), utils.HALF)
             output[-1, :, :].float().sum().backward()
             self.assertEqual(x.grad.dtype, x.dtype)
 
@@ -108,7 +125,7 @@
             torch.set_default_tensor_type(torch.cuda.FloatTensor)
             hidden = torch.zeros((num_layers, self.b, self.h), dtype=typ)
             output, _ = rnn(packed_seq, hidden)
-            self.assertEqual(output.data.type(), HALF)
+            self.assertEqual(output.data.type(), utils.HALF)
             output.data.float().sum().backward()
             self.assertEqual(x.grad.dtype, x.dtype)
 
diff -Nur '--exclude=.git' apex/tests/L0/run_amp/utils.py apex-develop/tests/L0/run_amp/utils.py
--- apex/tests/L0/run_amp/utils.py	2023-04-06 10:36:26.968937605 +0800
+++ apex-develop/tests/L0/run_amp/utils.py	2024-03-07 21:33:04.313391421 +0800
@@ -1,7 +1,28 @@
+# Copyright (c) 2020, Huawei Technologies.
+# Copyright (c) 2019, NVIDIA CORPORATION.
+# All rights reserved.
+#
+# Licensed under the BSD 3-Clause License  (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import torch
+import numpy as np
+
+import sys
+sys.path.append('../')
+import device
 
-HALF = 'torch.cuda.HalfTensor'
-FLOAT = 'torch.cuda.FloatTensor'
+HALF = 'torch.npu.HalfTensor'
+FLOAT = 'torch.npu.FloatTensor'
 
 DTYPES = [torch.half, torch.float]
 
@@ -18,4 +39,28 @@
     test_case.c = 16
     test_case.k = 3
     test_case.t = 10
-    torch.set_default_tensor_type(torch.cuda.FloatTensor)
+    global HALF, FLOAT, DTYPES, ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT
+    if device.is_npu():
+        HALF = 'torch.npu.HalfTensor'
+        FLOAT = 'torch.npu.FloatTensor'
+        torch.set_default_tensor_type(torch.FloatTensor)
+    else:
+        HALF = 'torch.cuda.HalfTensor'
+        FLOAT = 'torch.cuda.FloatTensor'
+        torch.set_default_tensor_type(torch.cuda.FloatTensor)
+
+    ALWAYS_HALF = {torch.float: HALF,
+                   torch.half: HALF}
+    ALWAYS_FLOAT = {torch.float: FLOAT,
+                    torch.half: FLOAT}
+    MATCH_INPUT = {torch.float: FLOAT,
+                   torch.half: HALF}
+
+def generate_data(min, max, shape, dtype):
+    if dtype == torch.float32:
+        dtype = np.float32
+    if dtype == torch.float16:
+        dtype = np.float16
+    input1 = np.random.uniform(min, max, shape).astype(dtype)
+    npu_input1 = torch.from_numpy(input1)
+    return npu_input1
\ No newline at end of file
diff -Nur '--exclude=.git' apex/tests/L0/run_optimizers/test_lamb.py apex-develop/tests/L0/run_optimizers/test_lamb.py
--- apex/tests/L0/run_optimizers/test_lamb.py	2023-04-06 10:36:26.968937605 +0800
+++ apex-develop/tests/L0/run_optimizers/test_lamb.py	2024-03-07 21:33:04.313391421 +0800
@@ -22,8 +22,6 @@
             numerical stability (default: 1e-6)
         weight_decay (float, optional): weight decay (L2 penalty) (default: 0.01)
 
-    .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes:
-        https://arxiv.org/abs/1904.00962
     """
 
     def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01):
diff -Nur '--exclude=.git' apex/tests/L0/run_test.py apex-develop/tests/L0/run_test.py
--- apex/tests/L0/run_test.py	2023-04-06 10:36:26.968937605 +0800
+++ apex-develop/tests/L0/run_test.py	2024-03-07 21:33:04.313391421 +0800
@@ -1,20 +1,72 @@
+# Copyright (c) 2020, Huawei Technologies.
+# Copyright (c) 2019, NVIDIA CORPORATION.
+# All rights reserved.
+#
+# Licensed under the BSD 3-Clause License  (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import unittest
 import sys
-
-test_dirs = ["run_amp", "run_fp16util", "run_optimizers", "run_fused_layer_norm", "run_pyprof_nvtx", "run_pyprof_data", "run_mlp"]
+import device
+import torch
+import torch_npu
+import argparse
 
 runner = unittest.TextTestRunner(verbosity=2)
-
 errcode = 0
 
-for test_dir in test_dirs:
-    suite = unittest.TestLoader().discover(test_dir)
-
-    print("\nExecuting tests from " + test_dir)
+parser = argparse.ArgumentParser()
+parser.add_argument('--npu',
+                default=0,
+                type=int,
+                help='NPU id to use.')
+args = parser.parse_args()
+
+device.CALCULATE_DEVICE = "npu:{}".format(args.npu)
+torch.npu.set_device(device.CALCULATE_DEVICE)
+
+if device.is_npu():
+    sys.path.append('./run_amp')
+    sys.path.append('../../apex/contrib/test/')
+    from test_basic_casts import TestBannedMethods, TestTensorCasts, TestBasicCasts
+    from test_cache import TestCache
+    from test_promotion import TestPromotion
+    from test_larc import TestLARC
+    from test_combine_tensors import TestCombineTensors
+    test_dirs = ["run_amp"]
+    suite=unittest.TestSuite()
+    suite.addTest(unittest.TestLoader().loadTestsFromTestCase(TestBannedMethods))
+    suite.addTest(unittest.TestLoader().loadTestsFromTestCase(TestTensorCasts))
+    suite.addTest(unittest.TestLoader().loadTestsFromTestCase(TestBasicCasts))
+    suite.addTest(unittest.TestLoader().loadTestsFromTestCase(TestCache))
+    suite.addTest(unittest.TestLoader().loadTestsFromTestCase(TestPromotion))
+    suite.addTest(unittest.TestLoader().loadTestsFromTestCase(TestLARC))
+    suite.addTest(unittest.TestLoader().loadTestsFromTestCase(TestCombineTensors))
 
     result = runner.run(suite)
-
     if not result.wasSuccessful():
         errcode = 1
+    sys.exit(errcode)
+else:
+    test_dirs = ["run_amp", "run_fp16util", "run_optimizers", "run_fused_layer_norm", "run_pyprof_nvtx", "run_pyprof_data", "run_mlp"]
+
+    for test_dir in test_dirs:
+        suite = unittest.TestLoader().discover(test_dir)
+
+        print("\nExecuting tests from " + test_dir)
+
+        result = runner.run(suite)
+
+        if not result.wasSuccessful():
+            errcode = 1
 
-sys.exit(errcode)
+    sys.exit(errcode)
diff -Nur '--exclude=.git' apex/tests/L1/common/main_amp.py apex-develop/tests/L1/common/main_amp.py
--- apex/tests/L1/common/main_amp.py	2023-04-06 10:36:26.968937605 +0800
+++ apex-develop/tests/L1/common/main_amp.py	2024-03-07 21:33:04.313391421 +0800
@@ -1,3 +1,19 @@
+# Copyright (c) 2020, Huawei Technologies.
+# Copyright (c) 2019, NVIDIA CORPORATION.
+# All rights reserved.
+#
+# Licensed under the BSD 3-Clause License  (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import argparse
 import os
 import shutil
@@ -23,7 +39,9 @@
     from apex import amp, optimizers
     from apex.multi_tensor_apply import multi_tensor_applier
 except ImportError:
-    raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")
+    raise ImportError("Please install apex")
+
+CALCULATE_DEVICE = "npu:0"
 
 model_names = sorted(name for name in models.__dict__
                      if name.islower() and not name.startswith("__")
@@ -73,7 +91,9 @@
 parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)
 parser.add_argument('--loss-scale', type=str, default=None)
 parser.add_argument('--fused-adam', action='store_true')
-
+parser.add_argument('--npu-fused-sgd', action='store_true')
+parser.add_argument('--combine-grad', action='store_true')
+parser.add_argument('--npu', default=None, type=int, help ='NPU id to use.')
 parser.add_argument('--prints-to-process', type=int, default=10)
 
 cudnn.benchmark = True
@@ -99,7 +119,6 @@
 
 # Let multi_tensor_applier be the canary in the coalmine
 # that verifies if the backend is what we think it is
-assert multi_tensor_applier.available == args.has_ext 
 
 print("opt_level = {}".format(args.opt_level))
 print("keep_batchnorm_fp32 = {}".format(args.keep_batchnorm_fp32), type(args.keep_batchnorm_fp32))
@@ -124,6 +143,12 @@
     args.gpu = 0
     args.world_size = 1
 
+    global CALCULATE_DEVICE
+    if args.npu is not None:
+        CALCULATE_DEVICE = "npu:{}".format(args.npu)
+    torch.npu.set_device(CALCULATE_DEVICE)
+    print("use ",CALCULATE_DEVICE)
+
     if args.distributed:
         args.gpu = args.local_rank % torch.cuda.device_count()
         torch.cuda.set_device(args.gpu)
@@ -139,32 +164,40 @@
         model = models.__dict__[args.arch](pretrained=True)
     else:
         print("=> creating model '{}'".format(args.arch))
-        model = models.__dict__[args.arch]()
+        model = models.__dict__[args.arch](zero_init_residual=True)
 
     if args.sync_bn:
         import apex
         print("using apex synced BN")
         model = apex.parallel.convert_syncbn_model(model)
 
-    model = model.cuda()
-
+    model = model.to(CALCULATE_DEVICE)
     # Scale learning rate based on global batch size
     args.lr = args.lr*float(args.batch_size*args.world_size)/256. 
     if args.fused_adam:
         optimizer = optimizers.FusedAdam(model.parameters())
+    elif args.npu_fused_sgd:
+        optimizer = optimizers.NpuFusedSGD(
+            [{'params': [param for name, param in model.named_parameters() if name[-4:] == 'bias'],
+              'weight_decay': 0.0},
+            {'params': [param for name, param in model.named_parameters() if name[-4:] != 'bias'],
+             'weight_decay': args.weight_decay}],
+            args.lr, momentum=args.momentum)
     else:
-        optimizer = torch.optim.SGD(model.parameters(), args.lr,
-                                    momentum=args.momentum,
-                                    weight_decay=args.weight_decay)
+        optimizer = torch.optim.SGD(
+            [{'params': [param for name, param in model.named_parameters() if name[-4:] == 'bias'],
+              'weight_decay': 0.0},
+             {'params': [param for name, param in model.named_parameters() if name[-4:] != 'bias'],
+              'weight_decay': args.weight_decay}],
+            args.lr, momentum=args.momentum)
 
     model, optimizer = amp.initialize(
         model, optimizer,
-        # enabled=False,
         opt_level=args.opt_level,
-        keep_batchnorm_fp32=args.keep_batchnorm_fp32,
-        loss_scale=args.loss_scale
+        loss_scale=args.loss_scale,
+        combine_grad=args.combine_grad,
+        verbosity=1
         )
-
     if args.distributed:
         # By default, apex.parallel.DistributedDataParallel overlaps communication with 
         # computation in the backward pass.
@@ -173,8 +206,7 @@
         model = DDP(model, delay_allreduce=True)
 
     # define loss function (criterion) and optimizer
-    criterion = nn.CrossEntropyLoss().cuda()
-
+    criterion = nn.CrossEntropyLoss().to(CALCULATE_DEVICE)
     # Optionally resume from a checkpoint
     if args.resume:
         # Use a local scope to avoid dangling references
@@ -203,17 +235,22 @@
         crop_size = 224
         val_size = 256
 
+    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                     std=[0.229, 0.224, 0.225])
+
     train_dataset = datasets.ImageFolder(
         traindir,
         transforms.Compose([
             transforms.RandomResizedCrop(crop_size),
             transforms.RandomHorizontalFlip(),
-            # transforms.ToTensor(), Too slow
-            # normalize,
+            transforms.ToTensor(),
+            normalize,
         ]))
     val_dataset = datasets.ImageFolder(valdir, transforms.Compose([
             transforms.Resize(val_size),
             transforms.CenterCrop(crop_size),
+            transforms.ToTensor(),
+            normalize,
         ]))
 
     train_sampler = None
@@ -224,14 +261,13 @@
 
     train_loader = torch.utils.data.DataLoader(
         train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
-        num_workers=args.workers, pin_memory=True, sampler=train_sampler, collate_fn=fast_collate)
+        num_workers=args.workers, pin_memory=True, sampler=train_sampler)
 
     val_loader = torch.utils.data.DataLoader(
         val_dataset,
         batch_size=args.batch_size, shuffle=False,
         num_workers=args.workers, pin_memory=True,
-        sampler=val_sampler,
-        collate_fn=fast_collate)
+        sampler=val_sampler)
 
     if args.evaluate:
         validate(val_loader, model, criterion)
@@ -312,12 +348,7 @@
                      "Loss" : [],
                      "Speed" : []}
 
-    prefetcher = data_prefetcher(train_loader)
-    input, target = prefetcher.next()
-    i = -1
-    while input is not None:
-        i += 1
-
+    for i, (images, target) in enumerate(train_loader):
         # No learning rate warmup for this test, to expose bitwise inaccuracies more quickly
         # adjust_learning_rate(optimizer, epoch, i, len(train_loader))
 
@@ -328,7 +359,9 @@
         data_time.update(time.time() - end)
 
         # compute output
-        output = model(input)
+        images = images.to(CALCULATE_DEVICE, non_blocking=True)
+        target = target.to(torch.int32).to(CALCULATE_DEVICE, non_blocking=True)
+        output = model(images)
         loss = criterion(output, target)
 
         # measure accuracy and record loss
@@ -341,9 +374,9 @@
         else:
             reduced_loss = loss.data
 
-        losses.update(to_python_float(reduced_loss), input.size(0))
-        top1.update(to_python_float(prec1), input.size(0))
-        top5.update(to_python_float(prec5), input.size(0))
+        losses.update(to_python_float(reduced_loss), images.size(0))
+        top1.update(to_python_float(prec1), images.size(0))
+        top5.update(to_python_float(prec5), images.size(0))
 
         # compute gradient and do SGD step
         optimizer.zero_grad()
@@ -354,12 +387,8 @@
         # for param in model.parameters():
         #     print(param.data.double().sum().item(), param.grad.data.double().sum().item())
 
-        # torch.cuda.synchronize()
-        torch.cuda.nvtx.range_push("step")
         optimizer.step()
-        torch.cuda.nvtx.range_pop()
 
-        torch.cuda.synchronize()
         # measure elapsed time
         batch_time.update(time.time() - end)
 
@@ -367,7 +396,6 @@
 
         # If you decide to refactor this test, like examples/imagenet, to sample the loss every
         # print_freq iterations, make sure to move this prefetching below the accuracy calculation.
-        input, target = prefetcher.next()
 
         if i % args.print_freq == 0 and i > 1:
             if args.local_rank == 0:
@@ -388,10 +416,10 @@
             run_info_dict["Speed"].append(args.world_size * args.batch_size / batch_time.val)
             if len(run_info_dict["Loss"]) == args.prints_to_process:
                 if args.local_rank == 0:
+
                     torch.save(run_info_dict,
-                               str(args.has_ext) + "_" + str(args.opt_level) + "_" +
-                               str(args.loss_scale) + "_" + str(args.keep_batchnorm_fp32) + "_" +
-                               str(args.fused_adam))
+                               str(args.combine_grad) + "_" + str(args.opt_level) + "_" +
+                               str(args.loss_scale) + "_" + str(args.npu_fused_sgd))
                 quit()
 
 
@@ -405,16 +433,12 @@
     model.eval()
 
     end = time.time()
-
-    prefetcher = data_prefetcher(val_loader)
-    input, target = prefetcher.next()
-    i = -1
-    while input is not None:
-        i += 1
-
+    for i, (images, target) in enumerate(val_loader):
+        images = images.to(CALCULATE_DEVICE, non_blocking=True)
+        target = target.to(torch.int32).to(CALCULATE_DEVICE, non_blocking=True)
         # compute output
         with torch.no_grad():
-            output = model(input)
+            output = model(images)
             loss = criterion(output, target)
 
         # measure accuracy and record loss
@@ -427,9 +451,9 @@
         else:
             reduced_loss = loss.data
 
-        losses.update(to_python_float(reduced_loss), input.size(0))
-        top1.update(to_python_float(prec1), input.size(0))
-        top5.update(to_python_float(prec5), input.size(0))
+        losses.update(to_python_float(reduced_loss), images.size(0))
+        top1.update(to_python_float(prec1), images.size(0))
+        top5.update(to_python_float(prec5), images.size(0))
 
         # measure elapsed time
         batch_time.update(time.time() - end)
@@ -448,8 +472,6 @@
                    batch_time=batch_time, loss=losses,
                    top1=top1, top5=top5))
 
-        input, target = prefetcher.next()
-
     print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
           .format(top1=top1, top5=top5))
 
diff -Nur '--exclude=.git' apex/tests/L1/cross_product/run.sh apex-develop/tests/L1/cross_product/run.sh
--- apex/tests/L1/cross_product/run.sh	2023-04-06 10:36:26.968937605 +0800
+++ apex-develop/tests/L1/cross_product/run.sh	2024-03-07 21:33:04.313391421 +0800
@@ -3,4 +3,5 @@
 # DATADIR="/home/mcarilli/Desktop/pt18data/apex_stale/examples/imagenet/bare_metal_train_val/"
 # DATADIR="/opt/home/apex/examples/imagenet/"
 cp ../common/* .
-bash run_test.sh single_gpu $1
+# bash run_test.sh single_gpu $1
+bash run_test_npu.sh single_npu $1 $2