diff --git a/megatron/core/distributed/__init__.py b/megatron/core/distributed/__init__.py
index fe26e8b4..4451f277 100644
--- a/megatron/core/distributed/__init__.py
+++ b/megatron/core/distributed/__init__.py
@@ -11,3 +11,15 @@ from .finalize_model_grads import finalize_model_grads
 from .fsdp.mcore_fsdp_adapter import FullyShardedDataParallel
 from .torch_fully_sharded_data_parallel import TorchFullyShardedDataParallel
 from .torch_fully_sharded_data_parallel_config import TorchFullyShardedDataParallelConfig
+
+# Backward compatibility patch for FSDP module reorganization
+import sys
+import importlib.util
+
+spec = importlib.util.find_spec('megatron.core.distributed.fsdp.src.megatron_fsdp')
+if spec:
+    custom_fsdp = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(custom_fsdp)
+    sys.modules['megatron.core.distributed.custom_fsdp'] = custom_fsdp
+    if hasattr(custom_fsdp, 'MegatronFSDP'):
+        custom_fsdp.FullyShardedDataParallel = custom_fsdp.MegatronFSDP
diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py
index 99c3edc0..26ea5cb4 100644
--- a/megatron/core/extensions/transformer_engine.py
+++ b/megatron/core/extensions/transformer_engine.py
@@ -404,6 +404,7 @@ class TELinear(te.pytorch.Linear):
         )
 
         for param in self.parameters():
+            setattr(param, "parallel_mode", parallel_mode)
             if is_expert:
                 # Reduce the gradient on the expert_data_parallel group for expert linear layers
                 setattr(param, "allreduce", not self.expert_parallel)
diff --git a/megatron/core/models/gpt/gpt_layer_specs.py b/megatron/core/models/gpt/gpt_layer_specs.py
index 002edb92..f7273488 100755
--- a/megatron/core/models/gpt/gpt_layer_specs.py
+++ b/megatron/core/models/gpt/gpt_layer_specs.py
@@ -80,6 +80,8 @@ def get_gpt_layer_with_transformer_engine_spec(
     use_te_op_fuser: Optional[bool] = False,
     use_kitchen: bool = False,
     use_te_activation_func: bool = False,
+    post_self_attn_layernorm: bool = False,
+    post_mlp_layernorm: bool = False,
 ) -> ModuleSpec:
     """Use this spec to use lower-level Transformer Engine modules (required for fp8 training).
 
@@ -182,9 +184,11 @@ def get_gpt_layer_with_transformer_engine_spec(
                     ),
                 ),
                 self_attn_bda=get_bias_dropout_add,
+                post_self_attn_layernorm=TENorm if post_self_attn_layernorm else IdentityOp,
                 pre_mlp_layernorm=backend.layer_norm() if num_experts else IdentityOp,
                 mlp=mlp,
                 mlp_bda=get_bias_dropout_add,
+                post_mlp_layernorm=TENorm if post_mlp_layernorm else IdentityOp,
                 sharded_state_dict_keys_map={
                     "mlp.0.weight": "mlp.linear_fc1.layer_norm_weight",
                     "mlp.0.bias": "mlp.linear_fc1.layer_norm_bias",
diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py
index df9adc3e..2f4f544a 100644
--- a/megatron/core/models/gpt/gpt_model.py
+++ b/megatron/core/models/gpt/gpt_model.py
@@ -443,7 +443,7 @@ class GPTModel(LanguageModule):
         if self.share_embeddings_and_output_weights:
             output_weight = self.shared_embedding_or_output_weight()
 
-        if mtp_in_postprocess:
+        if mtp_in_postprocess and labels is not None:
             hidden_states = self.mtp(
                 input_ids=input_ids,
                 position_ids=position_ids,
diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py
index 57332ac3..f3abd642 100644
--- a/megatron/core/parallel_state.py
+++ b/megatron/core/parallel_state.py
@@ -9,6 +9,7 @@ from typing import Callable, List, Optional
 
 import numpy as np
 import torch
+import torch.distributed as dist
 
 from .utils import GlobalMemoryBuffer, is_torch_min_version
 
@@ -163,6 +164,213 @@ def get_nccl_options(pg_name, nccl_comm_cfgs):
         return None
 
 
+old_new_group = None
+
+
+def monkey_patch_torch_dist():
+    print("Applying monkey patch to torch.distributed", flush=True)
+    global old_new_group
+    if old_new_group is not None:
+        return
+
+    old_new_group = dist.new_group
+
+    def new_group(*args, **kwargs):
+        group = old_new_group(*args, **kwargs)
+        # skip none nccl group.
+        if (
+            len(args) >= 3 and args[2] == "gloo" or
+            "backend" in kwargs and kwargs["backend"] == "gloo"
+        ):
+            return group
+        
+        # Get ranks from arguments
+        if len(args) >= 1 and args[0] is not None:
+            ranks = args[0]
+        elif "ranks" in kwargs and kwargs["ranks"] is not None:
+            ranks = kwargs["ranks"]
+        else:
+            # If no ranks specified, use all ranks in world
+            ranks = list(range(dist.get_world_size()))
+
+        if len(ranks) == 1:
+            return group
+
+        group = ReloadableProcessGroup(group, ranks)
+        return group
+
+    dist.new_group = new_group
+
+    def get_new_function(func):
+        def new_function(*args, **kwargs):
+            args = (
+                arg.group if isinstance(arg, ReloadableProcessGroup) else arg
+                for arg in args
+            )
+            kwargs = {
+                k: (v.group if isinstance(v, ReloadableProcessGroup) else v)
+                for k, v in kwargs.items()
+            }
+            return func(*args, **kwargs)
+        return new_function
+
+    dist.get_rank = get_new_function(dist.get_rank)
+    dist.get_world_size = get_new_function(dist.get_world_size)
+    dist.get_backend = get_new_function(dist.get_backend)
+    dist.get_global_rank = get_new_function(dist.get_global_rank)
+    dist.get_group_rank = get_new_function(dist.get_group_rank)
+    dist.get_process_group_ranks = get_new_function(dist.get_process_group_ranks)
+
+    dist.all_reduce = get_new_function(dist.all_reduce)
+    dist.all_gather = get_new_function(dist.all_gather)
+    dist.all_gather_into_tensor = get_new_function(dist.all_gather_into_tensor)
+    dist.all_gather_object = get_new_function(dist.all_gather_object)
+    dist.all_to_all = get_new_function(dist.all_to_all)
+    dist.all_to_all_single = get_new_function(dist.all_to_all_single)
+    dist.broadcast = get_new_function(dist.broadcast)
+    dist.reduce = get_new_function(dist.reduce)
+    dist.reduce_scatter = get_new_function(dist.reduce_scatter)
+    dist.reduce_scatter_tensor = get_new_function(dist.reduce_scatter_tensor)
+    dist.scatter = get_new_function(dist.scatter)
+    dist.gather = get_new_function(dist.gather)
+    dist.barrier = get_new_function(dist.barrier)
+    dist.send = get_new_function(dist.send)
+    dist.recv = get_new_function(dist.recv)
+    dist._coalescing_manager = get_new_function(dist._coalescing_manager)
+
+    # p2p
+    old_isend = dist.isend
+    old_irecv = dist.irecv
+
+    dist.isend = get_new_function(dist.isend)
+    dist.irecv = get_new_function(dist.irecv)
+
+    def get_new_p2pop_function(func):
+        def new_function(*args, **kwargs):
+            def convert(arg):
+                if isinstance(arg, ReloadableProcessGroup):
+                    return arg.group
+                elif arg == dist.isend:
+                    arg = old_isend
+                elif arg == dist.irecv:
+                    arg = old_irecv
+                return arg
+
+            args = (convert(arg) for arg in args)
+            kwargs = {
+                k: convert(v)
+                for k, v in kwargs.items()
+            }
+            return func(*args, **kwargs)
+        return new_function
+    
+    dist.P2POp.__new__ = get_new_p2pop_function(dist.P2POp.__new__)
+    dist.P2POp.__init__ = get_new_p2pop_function(dist.P2POp.__init__)
+
+
+
+class ReloadableProcessGroup(torch.distributed.ProcessGroup):
+    GROUPS = []
+
+    def __init__(self, group, ranks):
+        super().__init__(
+            rank=dist.get_rank(group),
+            size=dist.get_world_size(group),
+        )
+        #print(f"Creating ReloadableProcessGroup with ranks: {ranks}", flush=True)
+        self.group = group
+        self.group_info = {
+            "ranks": ranks,
+        }
+        ReloadableProcessGroup.GROUPS.append(self)
+
+    def __getattr__(self, name):
+        return getattr(self.group, name)
+
+    @staticmethod
+    def destroy_process_groups():
+        for reloadable_group in ReloadableProcessGroup.GROUPS:
+            if reloadable_group.group is None:
+                continue
+            #print(f"Destroying process group: {reloadable_group.group_info['ranks']}")
+            dist.destroy_process_group(reloadable_group.group)
+            del reloadable_group.group
+            reloadable_group.group = None
+
+    @staticmethod
+    def reload_process_groups():
+        for reloadable_group in ReloadableProcessGroup.GROUPS:
+            if reloadable_group.group is not None:
+                continue
+            #print(f"Reloading process group: {reloadable_group.group_info['ranks']}")
+            group = old_new_group(
+                ranks=reloadable_group.group_info["ranks"],
+                backend="nccl"
+            )
+            reloadable_group.group = group
+
+    def rank(self) -> int: return self.group.rank()
+    def size(self) -> int: return self.group.size()
+    def name(self) -> str: return self.group.name()
+
+    def shutdown(self) -> None:
+        if self.group is not None:
+            self.group.shutdown()
+
+    def abort(self) -> None:
+        if self.group is not None:
+            self.group.abort()
+
+    def _fwd(self, method, *args, **kwargs):
+        inner = self.group
+        if inner is None:
+            raise RuntimeError("ReloadableProcessGroup: inner PG is None, call reload() first.")
+        return getattr(inner, method)(*args, **kwargs)
+
+    def barrier(self, *a, **kw): return self._fwd("barrier", *a, **kw)
+    def broadcast(self, *a, **kw): return self._fwd("broadcast", *a, **kw)
+    def allreduce(self, *a, **kw): return self._fwd("allreduce", *a, **kw)
+    def allreduce_coalesced(self, *a, **kw): return self._fwd("allreduce_coalesced", *a, **kw)
+    def reduce(self, *a, **kw): return self._fwd("reduce", *a, **kw)
+    def allgather(self, *a, **kw): return self._fwd("allgather", *a, **kw)
+    def _allgather_base(self, *a, **kw): return self._fwd("_allgather_base", *a, **kw)
+    def allgather_coalesced(self, *a, **kw): return self._fwd("allgather_coalesced", *a, **kw)
+    def allgather_into_tensor_coalesced(self, *a, **kw): return self._fwd("allgather_into_tensor_coalesced", *a, **kw)
+    def gather(self, *a, **kw): return self._fwd("gather", *a, **kw)
+    def scatter(self, *a, **kw): return self._fwd("scatter", *a, **kw)
+    def reduce_scatter(self, *a, **kw): return self._fwd("reduce_scatter", *a, **kw)
+    def _reduce_scatter_base(self, *a, **kw): return self._fwd("_reduce_scatter_base", *a, **kw)
+    def reduce_scatter_tensor_coalesced(self, *a, **kw): return self._fwd("reduce_scatter_tensor_coalesced", *a, **kw)
+    def alltoall_base(self, *a, **kw): return self._fwd("alltoall_base", *a, **kw)
+    def alltoall(self, *a, **kw): return self._fwd("alltoall", *a, **kw)
+    def send(self, *a, **kw): return self._fwd("send", *a, **kw)
+    def recv(self, *a, **kw): return self._fwd("recv", *a, **kw)
+    def recv_anysource(self, *a, **kw): return self._fwd("recv_anysource", *a, **kw)
+
+    def _start_coalescing(self, *a, **kw): return self._fwd("_start_coalescing", *a, **kw)
+    def _end_coalescing(self, *a, **kw): return self._fwd("_end_coalescing", *a, **kw)
+    def _get_backend_name(self): return self._fwd("_get_backend_name")
+    def _get_backend(self, *a, **kw): return self._fwd("_get_backend", *a, **kw)
+    def _set_default_backend(self, *a, **kw): return self._fwd("_set_default_backend", *a, **kw)
+    @property
+    def bound_device_id(self): return self.group.bound_device_id
+    @bound_device_id.setter
+    def bound_device_id(self, dev): self.group.bound_device_id = dev
+
+
+def destroy_process_groups():
+    """Destroy all reloadable process groups."""
+    ReloadableProcessGroup.destroy_process_groups()
+
+
+def reload_process_groups():
+    """Reload all reloadable process groups."""
+    ReloadableProcessGroup.reload_process_groups()
+
+
+monkey_patch_torch_dist()
+
+
 def create_group(
     ranks=None,
     timeout=None,
diff --git a/megatron/core/pipeline_parallel/p2p_communication.py b/megatron/core/pipeline_parallel/p2p_communication.py
index 63ee9d1f..b90b744c 100644
--- a/megatron/core/pipeline_parallel/p2p_communication.py
+++ b/megatron/core/pipeline_parallel/p2p_communication.py
@@ -26,22 +26,22 @@ def _batched_p2p_ops(
     ops = []
     if tensor_send_prev is not None:
         send_prev_op = torch.distributed.P2POp(
-            torch.distributed.isend, tensor_send_prev, prev_pipeline_rank, group
+            torch.distributed.isend, tensor_send_prev, prev_pipeline_rank,
         )
         ops.append(send_prev_op)
     if tensor_recv_prev is not None:
         recv_prev_op = torch.distributed.P2POp(
-            torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank, group
+            torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank,
         )
         ops.append(recv_prev_op)
     if tensor_send_next is not None:
         send_next_op = torch.distributed.P2POp(
-            torch.distributed.isend, tensor_send_next, next_pipeline_rank, group
+            torch.distributed.isend, tensor_send_next, next_pipeline_rank,
         )
         ops.append(send_next_op)
     if tensor_recv_next is not None:
         recv_next_op = torch.distributed.P2POp(
-            torch.distributed.irecv, tensor_recv_next, next_pipeline_rank, group
+            torch.distributed.irecv, tensor_recv_next, next_pipeline_rank,
         )
         ops.append(recv_next_op)
     if len(ops) > 0:
diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py
index 6f557e1f..b295fd35 100644
--- a/megatron/core/transformer/transformer_config.py
+++ b/megatron/core/transformer/transformer_config.py
@@ -173,6 +173,9 @@ class TransformerConfig(ModelParallelConfig):
     qk_layernorm: bool = False
     """Whether to apply `normalization` type of normalization to the query and key embeddings."""
 
+    post_self_attn_layernorm: bool = False
+    post_mlp_layernorm: bool = False
+
     test_mode: bool = False
     """Whether to run real-time tests."""
 
diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py
index 84f22bde..b4807d26 100644
--- a/megatron/core/transformer/transformer_layer.py
+++ b/megatron/core/transformer/transformer_layer.py
@@ -224,6 +224,7 @@ class TransformerLayerSubmodules:
     input_layernorm: Union[ModuleSpec, type] = IdentityOp
     self_attention: Union[ModuleSpec, type] = IdentityOp
     self_attn_bda: Union[ModuleSpec, type] = IdentityFuncOp
+    post_self_attn_layernorm: Union[ModuleSpec, type] = IdentityOp
 
     pre_cross_attn_layernorm: Union[ModuleSpec, type] = IdentityOp
     cross_attention: Union[ModuleSpec, type] = IdentityOp
@@ -232,6 +233,7 @@ class TransformerLayerSubmodules:
     pre_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp
     mlp: Union[ModuleSpec, type] = IdentityOp
     mlp_bda: Union[ModuleSpec, type] = IdentityFuncOp
+    post_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp
 
     # Mapping for sharded tensor keys to be applied in `sharded_state_dict` method
     sharded_state_dict_keys_map: Dict[str, str] = field(default_factory=dict)
@@ -336,6 +338,14 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer):
         # [Module 3: BiasDropoutFusion]
         self.self_attn_bda = build_module(submodules.self_attn_bda)
 
+        self.post_self_attn_layernorm = build_module(
+            submodules.post_self_attn_layernorm,
+            config=self.config,
+            hidden_size=self.config.hidden_size,
+            eps=self.config.layernorm_epsilon,
+        )
+
+
         # [Module 4: Post SelfAttention] Optional Layernorm after self-attn
         self.pre_cross_attn_layernorm = build_module(
             submodules.pre_cross_attn_layernorm,
@@ -399,6 +409,13 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer):
         # [Module 9: BiasDropoutFusion]
         self.mlp_bda = build_module(submodules.mlp_bda)
 
+        self.post_mlp_layernorm = build_module(
+            submodules.post_mlp_layernorm,
+            config=self.config,
+            hidden_size=self.config.hidden_size,
+            eps=self.config.layernorm_epsilon
+        )
+
         self.recompute_input_layernorm = False
         self.recompute_pre_mlp_layernorm = False
         self.recompute_mlp = False
@@ -535,6 +552,11 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer):
                 attention_output_with_bias[0]
             )
 
+        attention_output, attention_output_bias = attention_output_with_bias
+        attention_output = self.post_self_attn_layernorm(attention_output)
+        attention_output_with_bias = (attention_output, attention_output_bias)
+
+
         # TODO: could we move `bias_dropout_add_exec_handler` itself
         # inside the module provided in the `bias_dropout_add_spec` module?
         nvtx_range_push(suffix="self_attn_bda")
@@ -635,6 +657,10 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer):
         else:
             mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output)
 
+        mlp_output, mlp_output_bias = mlp_output_with_bias
+        mlp_output = self.post_mlp_layernorm(mlp_output)
+        mlp_output_with_bias = (mlp_output, mlp_output_bias)
+
         if self.recompute_pre_mlp_layernorm:
             # discard the output of the pre-mlp layernorm and register the recompute
             # as a gradient hook of mlp_output_with_bias[0]
diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py
index 24ba8926..4f039fd4 100644
--- a/megatron/training/arguments.py
+++ b/megatron/training/arguments.py
@@ -1191,6 +1191,9 @@ def core_transformer_config_from_args(args, config_class=None):
     if args.is_hybrid_model:
         kw_args['is_hybrid_model'] = args.is_hybrid_model
 
+    kw_args['post_self_attn_layernorm'] = args.post_self_attn_layernorm
+    kw_args['post_mlp_layernorm'] = args.post_mlp_layernorm
+
     # handle quantization config
     # NOTE: Kitchen arguments are only added to the namespace when
     # Kitchen library is available.
@@ -1481,6 +1484,10 @@ def _add_network_size_args(parser):
                        action='store_true',
                        help='If set, use original BERT residula connection '
                        'ordering.')
+    group.add_argument('--post-self-attn-layernorm', action='store_true',
+                       help='If set, use post self attention layernorm.')
+    group.add_argument('--post-mlp-layernorm', action='store_true',
+                       help='If set, use post MLP layernorm.')
     group.add_argument('--openai-gelu', action='store_true',
                        help='Use OpenAIs GeLU implementation. This option'
                        'should not be used unless for backward compatibility'