diff --git a/rtdetrv2_pytorch/src/misc/dist_utils.py b/rtdetrv2_pytorch/src/misc/dist_utils.py
index 79f7944..a9ef2d4 100644
--- a/rtdetrv2_pytorch/src/misc/dist_utils.py
+++ b/rtdetrv2_pytorch/src/misc/dist_utils.py
@@ -40,17 +40,26 @@ def setup_distributed(print_rank: int=0, print_method: str='builtin', seed: int=
         WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
         
-        # torch.distributed.init_process_group(backend=backend, init_method='env://')
-        torch.distributed.init_process_group(init_method='env://')
+        # Determine backend and set device BEFORE init_process_group
+        try:
+            import torch_npu
+            backend = 'hccl'
+            # CRITICAL: Set device before HCCL initialization
+            torch.npu.set_device(LOCAL_RANK)
+            print(f'Using HCCL backend for NPU, device: {LOCAL_RANK}')
+        except:
+            backend = 'nccl'
+            torch.cuda.set_device(LOCAL_RANK)
+            print(f'Using NCCL backend for CUDA, device: {LOCAL_RANK}')
+
+        torch.distributed.init_process_group(backend=backend, init_method='env://')
         torch.distributed.barrier()
 
         rank = torch.distributed.get_rank()
-        torch.cuda.set_device(rank)
-        torch.cuda.empty_cache()
         enabled_dist = True
         print('Initialized distributed mode...')
 
-    except:
+    except Exception as e:
         enabled_dist = False
-        print('Not init distributed mode.')
+        print(f'Not init distributed mode. Error: {e}')
 
     setup_print(get_rank() == print_rank, method=print_method)
@@ -134,9 +143,26 @@ def warp_model(
     if is_dist_available_and_initialized():
         rank = get_rank()
-        model = nn.SyncBatchNorm.convert_sync_batchnorm(model) if sync_bn else model 
+
+        # Check if using NPU
+        try:
+            import torch_npu
+            use_npu = True
+        except:
+            use_npu = False
+
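+        # SyncBatchNorm conversion is skipped on NPU, so BatchNorm layers stay unsynchronized there (NOTE(review): confirm the installed torch_npu really lacks SyncBatchNorm support)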
+        if sync_bn and not use_npu:
+            model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
+
         if dist_mode == 'dp':
-            model = DP(model, device_ids=[rank], output_device=rank)
+            if use_npu:
+                model = DP(model)
+            else:
+                model = DP(model, device_ids=[rank], output_device=rank)
         elif dist_mode == 'ddp':
-            model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=find_unused_parameters)
+            if use_npu:
+                model = DDP(model, find_unused_parameters=find_unused_parameters)
+            else:
+                model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=find_unused_parameters)
         else:
             raise AttributeError('')
 
diff --git a/rtdetrv2_pytorch/src/misc/logger.py b/rtdetrv2_pytorch/src/misc/logger.py
index 2ef0c27..ba5d7c8 100644
--- a/rtdetrv2_pytorch/src/misc/logger.py
+++ b/rtdetrv2_pytorch/src/misc/logger.py
@@ -40,7 +40,17 @@ class SmoothedValue(object):
         """
         if not is_dist_available_and_initialized():
             return
-        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
+
+        # Detect device type: NPU or CUDA
+        try:
+            import torch_npu
+            device = 'npu'
+        except:
+            device = 'cuda'
+
+        # HCCL doesn't support float64, use float32 for NPU
+        dtype = torch.float32 if device == 'npu' else torch.float64
+        t = torch.tensor([self.count, self.total], dtype=dtype, device=device)
         tdist.barrier()
         tdist.all_reduce(t)
         t = t.tolist()
diff --git a/rtdetrv2_pytorch/src/solver/_solver.py b/rtdetrv2_pytorch/src/solver/_solver.py
index 51e9bef..681265e 100644
--- a/rtdetrv2_pytorch/src/solver/_solver.py
+++ b/rtdetrv2_pytorch/src/solver/_solver.py
@@ -1,14 +1,20 @@
 """Copyright(c) 2023 lyuwenyu. All Rights Reserved.
 """
 
 import torch 
 import torch.nn as nn 
 
 from datetime import datetime
 from pathlib import Path 
 from typing import Dict
 import atexit
 
+# Import torch_npu if available to register NPU device type
+try:
+    import torch_npu
+except ImportError:
+    pass
+
 from ..misc import dist_utils
 from ..core import BaseConfig
 
@@ -24,11 +30,29 @@ class BaseSolver(object):
         self.cfg = cfg 
 
     def _setup(self, ):
         """Avoid instantiating unnecessary classes 
         """
         cfg = self.cfg
-        if cfg.device:
-            device = torch.device(cfg.device)
+
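+        # Determine device: distributed mode indexes it by global rank (NOTE(review): setup_distributed selects the device by LOCAL_RANK; the two differ on multi-node runs), otherwise use cfg.device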
+        if dist_utils.is_dist_available_and_initialized():
+            rank = dist_utils.get_rank()
+            try:
+                import torch_npu
+                device = torch.device(f'npu:{rank}')
+            except:
+                device = torch.device(f'cuda:{rank}')
+        elif cfg.device:
+            # Handle NPU device string
+            if 'npu' in cfg.device:
+                try:
+                    import torch_npu
+                    device_id = int(cfg.device.split(':')[1]) if ':' in cfg.device else 0
+                    device = torch.device(f'npu:{device_id}')
+                except:
+                    device = torch.device('cpu')
+            else:
+                device = torch.device(cfg.device)
         else:
             device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
diff --git a/rtdetrv2_pytorch/src/solver/det_engine.py b/rtdetrv2_pytorch/src/solver/det_engine.py
index 441ef39..c64065c 100644
--- a/rtdetrv2_pytorch/src/solver/det_engine.py
+++ b/rtdetrv2_pytorch/src/solver/det_engine.py
@@ -18,6 +18,58 @@ from ..optim import ModelEMA, Warmup
 from ..data import CocoEvaluator
 from ..misc import MetricLogger, SmoothedValue, dist_utils
 
+# Monkey-patch faster_coco_eval to support NPU
+try:
+    import torch_npu
+    from faster_coco_eval.utils.pytorch import coco_eval as faster_coco_eval_module
+
+    original_all_gather = faster_coco_eval_module.all_gather
+
+    def patched_all_gather(data, world_size=None):
+        """Gather picklable ``data`` from every rank through NPU tensors.
+
+        Each rank pickles ``data``, pads the bytes to the longest payload and
+        exchanges them with ``dist.all_gather``. ``world_size`` falls back to
+        ``dist.get_world_size()`` when ``None``. Returns a list holding one
+        unpickled object per gathered payload.
+        """
+        import pickle
+        import torch.distributed as dist
+
+        if world_size is None:
+            world_size = dist.get_world_size()
+        if world_size == 1:
+            # Nothing to exchange; skip the collective and the device copies.
+            return [data]
+
+        # frombuffer avoids building a Python list with one int per byte.
+        payload = bytearray(pickle.dumps(data))
+        local_size = len(payload)
+        buffer = torch.frombuffer(payload, dtype=torch.uint8).to("npu")
+
+        # Exchange payload lengths first so every rank can pad to the maximum.
+        size_tensor = torch.tensor([local_size], dtype=torch.long, device="npu")
+        size_list = [torch.zeros_like(size_tensor) for _ in range(world_size)]
+        dist.all_gather(size_list, size_tensor)
+        sizes = [int(size.item()) for size in size_list]
+        max_size = max(sizes)
+        # Pad with zeros so all ranks pass equally sized tensors to all_gather.
+        if local_size < max_size:
+            padding = torch.zeros(max_size - local_size, dtype=torch.uint8, device="npu")
+            buffer = torch.cat([buffer, padding])
+        gathered = [torch.empty_like(buffer) for _ in range(world_size)]
+        dist.all_gather(gathered, buffer)
+
+        # NOTE: pickle.loads is only acceptable because peers are trusted ranks.
+        return [
+            pickle.loads(chunk[:size].cpu().numpy().tobytes())
+            for size, chunk in zip(sizes, gathered)
+        ]
+
+    faster_coco_eval_module.all_gather = patched_all_gather
+except ImportError:
+    pass
+
 
 def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                     data_loader: Iterable, optimizer: torch.optim.Optimizer,
@@ -42,10 +94,17 @@ def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
         metas = dict(epoch=epoch, step=i, global_step=global_step)
 
         if scaler is not None:
-            with torch.autocast(device_type=str(device), cache_enabled=True):
+            # Handle NPU device type for autocast compatibility
+            autocast_device_type = str(device)
+            if 'npu' in autocast_device_type:
+                # Pass 'cuda' instead of the NPU device string to torch.autocast.
+                # NOTE(review): this presumably relies on transfer_to_npu (imported in tools/train.py) redirecting 'cuda' to the NPU - confirm.
+                autocast_device_type = 'cuda'
+
+            with torch.autocast(device_type=autocast_device_type, cache_enabled=True):
                 outputs = model(samples, targets=targets)
             
-            with torch.autocast(device_type=str(device), enabled=False):
+            with torch.autocast(device_type=autocast_device_type, enabled=False):
                 loss_dict = criterion(outputs, targets, **metas)
 
             loss = sum(loss_dict.values())
@@ -136,7 +195,10 @@ def evaluate(model: torch.nn.Module, criterion: torch.nn.Module, postprocessor,
     metric_logger.synchronize_between_processes()
     print("Averaged stats:", metric_logger)
     if coco_evaluator is not None:
-        coco_evaluator.synchronize_between_processes()
+        # Check if distributed training is available and initialized to avoid RuntimeError
+        from ..misc import dist_utils
+        if dist_utils.is_dist_available_and_initialized():
+            coco_evaluator.synchronize_between_processes()
 
     # accumulate predictions from all images
     if coco_evaluator is not None:
diff --git a/rtdetrv2_pytorch/tools/train.py b/rtdetrv2_pytorch/tools/train.py
index 280caa8..95b488a 100644
--- a/rtdetrv2_pytorch/tools/train.py
+++ b/rtdetrv2_pytorch/tools/train.py
@@ -7,6 +7,13 @@ sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '..'
 
 import argparse
 
+# Import torch_npu if available to register NPU device type
+try:
+    import torch_npu
+    from torch_npu.contrib import transfer_to_npu
+except ImportError:
+    pass
+
 from src.misc import dist_utils
 from src.core import YAMLConfig, yaml_utils
 from src.solver import TASKS
diff --git a/rtdetrv2_pytorch/tools/export_onnx.py b/rtdetrv2_pytorch/tools/export_onnx.py
index 15863192f..5f186f787 100644
--- a/rtdetrv2_pytorch/tools/export_onnx.py
+++ b/rtdetrv2_pytorch/tools/export_onnx.py
@@ -65,6 +65,7 @@ def main(args, ):
         opset_version=16,
         verbose=False,
         do_constant_folding=True,
+        dynamo=False,
     )
 
     if args.check: