diff --git a/projects/configs/bevformer/bevformer_base.py b/projects/configs/bevformer/bevformer_base.py
index fda635c..0029547 100644
--- a/projects/configs/bevformer/bevformer_base.py
+++ b/projects/configs/bevformer/bevformer_base.py
@@ -32,6 +32,7 @@ _dim_ = 256
 _pos_dim_ = _dim_//2
 _ffn_dim_ = _dim_*2
 _num_levels_ = 4
+bs_ = 1
 bev_h_ = 200
 bev_w_ = 200
 queue_length = 4 # each sequence contains `queue_length` frames.
@@ -61,6 +62,7 @@ model = dict(
         relu_before_extra_convs=True),
     pts_bbox_head=dict(
         type='BEVFormerHead',
+        bs=bs_,
         bev_h=bev_h_,
         bev_w=bev_w_,
         num_query=900,
@@ -83,6 +85,8 @@ model = dict(
                 return_intermediate=False,
                 transformerlayers=dict(
                     type='BEVFormerLayer',
+                    bev_h=bev_h_,
+                    bev_w=bev_w_,
                     attn_cfgs=[
                         dict(
                             type='TemporalSelfAttention',
@@ -226,7 +230,8 @@ data = dict(
 )
 
 optimizer = dict(
-    type='AdamW',
+    # type='AdamW',
+    type='NpuFusedAdamW',
     lr=2e-4,
     paramwise_cfg=dict(
         custom_keys={
diff --git a/projects/configs/bevformer_fp16/bevformer_base_fp16.py b/projects/configs/bevformer_fp16/bevformer_base_fp16.py
new file mode 100644
index 0000000..7f39f01
--- /dev/null
+++ b/projects/configs/bevformer_fp16/bevformer_base_fp16.py
@@ -0,0 +1,261 @@
+_base_ = [
+    '../datasets/custom_nus-3d.py',
+    '../_base_/default_runtime.py'
+]
+#
+plugin = True
+plugin_dir = 'projects/mmdet3d_plugin/'
+
+# If point cloud range is changed, the models should also change their point
+# cloud range accordingly
+point_cloud_range = [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0]
+voxel_size = [0.2, 0.2, 8]
+
+
+
+img_norm_cfg = dict(
+    mean=[103.530, 116.280, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False)
+# For nuScenes we usually do 10-class detection
+class_names = [
+    'car', 'truck', 'construction_vehicle', 'bus', 'trailer', 'barrier',
+    'motorcycle', 'bicycle', 'pedestrian', 'traffic_cone'
+]
+
+input_modality = dict(
+    use_lidar=False,
+    use_camera=True,
+    use_radar=False,
+    use_map=False,
+    use_external=True)
+
+_dim_ = 256
+_pos_dim_ = _dim_//2
+_ffn_dim_ = _dim_*2
+_num_levels_ = 4
+bev_h_ = 200
+bev_w_ = 200
+queue_length = 4 # each sequence contains `queue_length` frames.
+
+model = dict(
+    type='BEVFormer_fp16',
+    use_grid_mask=True,
+    video_test_mode=True,
+    img_backbone=dict(
+        type='ResNet',
+        depth=101,
+        num_stages=4,
+        out_indices=(1, 2, 3),
+        frozen_stages=1,
+        norm_cfg=dict(type='BN2d', requires_grad=False),
+        norm_eval=True,
+        style='caffe',
+        dcn=dict(type='DCNv2', deform_groups=1, fallback_on_stride=False), # original DCNv2 will print log when perform load_state_dict
+        stage_with_dcn=(False, False, True, True)),
+    img_neck=dict(
+        type='FPN',
+        in_channels=[512, 1024, 2048],
+        out_channels=_dim_,
+        start_level=0,
+        add_extra_convs='on_output',
+        num_outs=4,
+        relu_before_extra_convs=True),
+    pts_bbox_head=dict(
+        type='BEVFormerHead',
+        bev_h=bev_h_,
+        bev_w=bev_w_,
+        num_query=900,
+        num_classes=10,
+        in_channels=_dim_,
+        sync_cls_avg_factor=True,
+        with_box_refine=True,
+        as_two_stage=False,
+        transformer=dict(
+            type='PerceptionTransformer',
+            rotate_prev_bev=True,
+            use_shift=True,
+            use_can_bus=True,
+            embed_dims=_dim_,
+            encoder=dict(
+                type='BEVFormerEncoder',
+                num_layers=6,
+                pc_range=point_cloud_range,
+                num_points_in_pillar=4,
+                return_intermediate=False,
+                transformerlayers=dict(
+                    type='BEVFormerLayer',
+                    bev_h=bev_h_,
+                    bev_w=bev_w_,
+                    attn_cfgs=[
+                        dict(
+                            type='TemporalSelfAttention',
+                            embed_dims=_dim_,
+                            num_levels=1),
+                        dict(
+                            type='SpatialCrossAttention',
+                            pc_range=point_cloud_range,
+                            deformable_attention=dict(
+                                type='MSDeformableAttention3D',
+                                embed_dims=_dim_,
+                                num_points=8,
+                                num_levels=_num_levels_),
+                            embed_dims=_dim_,
+                        )
+                    ],
+                    feedforward_channels=_ffn_dim_,
+                    ffn_dropout=0.1,
+                    operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
+                                     'ffn', 'norm'))),
+            decoder=dict(
+                type='DetectionTransformerDecoder',
+                num_layers=6,
+                return_intermediate=True,
+                transformerlayers=dict(
+                    type='DetrTransformerDecoderLayer',
+                    attn_cfgs=[
+                        dict(
+                            type='MultiheadAttention',
+                            embed_dims=_dim_,
+                            num_heads=8,
+                            dropout=0.1),
+                         dict(
+                            type='CustomMSDeformableAttention',
+                            embed_dims=_dim_,
+                            num_levels=1),
+                    ],
+
+                    feedforward_channels=_ffn_dim_,
+                    ffn_dropout=0.1,
+                    operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
+                                     'ffn', 'norm')))),
+        bbox_coder=dict(
+            type='NMSFreeCoder',
+            post_center_range=[-61.2, -61.2, -10.0, 61.2, 61.2, 10.0],
+            pc_range=point_cloud_range,
+            max_num=300,
+            voxel_size=voxel_size,
+            num_classes=10),
+        positional_encoding=dict(
+            type='LearnedPositionalEncoding',
+            num_feats=_pos_dim_,
+            row_num_embed=bev_h_,
+            col_num_embed=bev_w_,
+            ),
+        loss_cls=dict(
+            type='FocalLoss',
+            use_sigmoid=True,
+            gamma=2.0,
+            alpha=0.25,
+            loss_weight=2.0),
+        loss_bbox=dict(type='L1Loss', loss_weight=0.25),
+        loss_iou=dict(type='GIoULoss', loss_weight=0.0)),
+    # model training and testing settings
+    train_cfg=dict(pts=dict(
+        grid_size=[512, 512, 1],
+        voxel_size=voxel_size,
+        point_cloud_range=point_cloud_range,
+        out_size_factor=4,
+        assigner=dict(
+            type='HungarianAssigner3D',
+            cls_cost=dict(type='FocalLossCost', weight=2.0),
+            reg_cost=dict(type='BBox3DL1Cost', weight=0.25),
+            iou_cost=dict(type='IoUCost', weight=0.0), # Fake cost. This is just to make it compatible with DETR head.
+            pc_range=point_cloud_range))))
+
+dataset_type = 'CustomNuScenesDataset'
+data_root = 'data/nuscenes/'
+file_client_args = dict(backend='disk')
+
+
+train_pipeline = [
+    dict(type='LoadMultiViewImageFromFiles', to_float32=True),
+    dict(type='PhotoMetricDistortionMultiViewImage'),
+    dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True, with_attr_label=False),
+    dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
+    dict(type='ObjectNameFilter', classes=class_names),
+    dict(type='NormalizeMultiviewImage', **img_norm_cfg),
+    dict(type='PadMultiViewImage', size_divisor=32),
+    dict(type='DefaultFormatBundle3D', class_names=class_names),
+    dict(type='CustomCollect3D', keys=['gt_bboxes_3d', 'gt_labels_3d', 'img'])
+]
+
+test_pipeline = [
+    dict(type='LoadMultiViewImageFromFiles', to_float32=True),
+    dict(type='NormalizeMultiviewImage', **img_norm_cfg),
+    dict(type='PadMultiViewImage', size_divisor=32),
+    dict(
+        type='MultiScaleFlipAug3D',
+        img_scale=(1600, 900),
+        pts_scale_ratio=1,
+        flip=False,
+        transforms=[
+            dict(
+                type='DefaultFormatBundle3D',
+                class_names=class_names,
+                with_label=False),
+            dict(type='CustomCollect3D', keys=['img'])
+        ])
+]
+
+data = dict(
+    samples_per_gpu=1,
+    workers_per_gpu=4,
+    train=dict(
+        type=dataset_type,
+        data_root=data_root,
+        ann_file=data_root + 'nuscenes_infos_temporal_train.pkl',
+        pipeline=train_pipeline,
+        classes=class_names,
+        modality=input_modality,
+        test_mode=False,
+        use_valid_flag=True,
+        bev_size=(bev_h_, bev_w_),
+        queue_length=queue_length,
+        # we use box_type_3d='LiDAR' in kitti and nuscenes dataset
+        # and box_type_3d='Depth' in sunrgbd and scannet dataset.
+        box_type_3d='LiDAR'),
+    val=dict(type=dataset_type,
+             data_root=data_root,
+             ann_file=data_root + 'nuscenes_infos_temporal_val.pkl',
+             pipeline=test_pipeline,  bev_size=(bev_h_, bev_w_),
+             classes=class_names, modality=input_modality, samples_per_gpu=1),
+    test=dict(type=dataset_type,
+              data_root=data_root,
+              ann_file=data_root + 'nuscenes_infos_temporal_val.pkl',
+              pipeline=test_pipeline, bev_size=(bev_h_, bev_w_),
+              classes=class_names, modality=input_modality),
+    shuffler_sampler=dict(type='DistributedGroupSampler'),
+    nonshuffler_sampler=dict(type='DistributedSampler')
+)
+
+optimizer = dict(
+    type='AdamW',
+    lr=2e-4,
+    paramwise_cfg=dict(
+        custom_keys={
+            'img_backbone': dict(lr_mult=0.1),
+        }),
+    weight_decay=0.01)
+
+optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
+# learning policy
+lr_config = dict(
+    policy='CosineAnnealing',
+    warmup='linear',
+    warmup_iters=500,
+    warmup_ratio=1.0 / 3,
+    min_lr_ratio=1e-3)
+total_epochs = 24
+evaluation = dict(interval=1, pipeline=test_pipeline)
+
+runner = dict(type='EpochBasedRunner_video', max_epochs=total_epochs)
+load_from = 'ckpts/r101_dcn_fcos3d_pretrain.pth'
+log_config = dict(
+    interval=50,
+    hooks=[
+        dict(type='TextLoggerHook'),
+        dict(type='TensorboardLoggerHook')
+    ])
+
+fp16 = dict(loss_scale=512.)
+checkpoint_config = dict(interval=1)
+custom_hooks = [dict(type='TransferWeight',priority='LOWEST')]
\ No newline at end of file
diff --git a/projects/mmdet3d_plugin/bevformer/apis/mmdet_train.py b/projects/mmdet3d_plugin/bevformer/apis/mmdet_train.py
index e57bd22..03c3589 100644
--- a/projects/mmdet3d_plugin/bevformer/apis/mmdet_train.py
+++ b/projects/mmdet3d_plugin/bevformer/apis/mmdet_train.py
@@ -9,7 +9,7 @@ import warnings
 import numpy as np
 import torch
 import torch.distributed as dist
-from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
+from mmcv.device.npu import NPUDataParallel, NPUDistributedDataParallel
 from mmcv.runner import (HOOKS, DistSamplerSeedHook, EpochBasedRunner,
                          Fp16OptimizerHook, OptimizerHook, build_optimizer,
                          build_runner, get_dist_info)
@@ -72,22 +72,22 @@ def custom_train_detector(model,
         find_unused_parameters = cfg.get('find_unused_parameters', False)
         # Sets the `find_unused_parameters` parameter in
         # torch.nn.parallel.DistributedDataParallel
-        model = MMDistributedDataParallel(
+        model = NPUDistributedDataParallel(
             model.cuda(),
             device_ids=[torch.cuda.current_device()],
             broadcast_buffers=False,
             find_unused_parameters=find_unused_parameters)
         if eval_model is not None:
-            eval_model = MMDistributedDataParallel(
+            eval_model = NPUDistributedDataParallel(
                 eval_model.cuda(),
                 device_ids=[torch.cuda.current_device()],
                 broadcast_buffers=False,
                 find_unused_parameters=find_unused_parameters)
     else:
-        model = MMDataParallel(
+        model = NPUDataParallel(
             model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)
         if eval_model is not None:
-            eval_model = MMDataParallel(
+            eval_model = NPUDataParallel(
                 eval_model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)
 
 
diff --git a/projects/mmdet3d_plugin/bevformer/dense_heads/bevformer_head.py b/projects/mmdet3d_plugin/bevformer/dense_heads/bevformer_head.py
index 93c7cd7..4ca3e63 100644
--- a/projects/mmdet3d_plugin/bevformer/dense_heads/bevformer_head.py
+++ b/projects/mmdet3d_plugin/bevformer/dense_heads/bevformer_head.py
@@ -34,6 +34,7 @@ class BEVFormerHead(DETRHead):
                  bbox_coder=None,
                  num_cls_fcs=2,
                  code_weights=None,
+                 bs=1,
                  bev_h=30,
                  bev_w=30,
                  **kwargs):
@@ -65,6 +66,7 @@ class BEVFormerHead(DETRHead):
             *args, transformer=transformer, **kwargs)
         self.code_weights = nn.Parameter(torch.tensor(
             self.code_weights, requires_grad=False), requires_grad=False)
+        self.bev_mask = torch.zeros((bs, self.bev_h, self.bev_w)).npu()
 
     def _init_layers(self):
         """Initialize classification branch and regression branch of head."""
@@ -131,14 +133,17 @@ class BEVFormerHead(DETRHead):
                 head with normalized coordinate format (cx, cy, w, l, cz, h, theta, vx, vy). \
                 Shape [nb_dec, bs, num_query, 9].
         """
-        bs, num_cam, _, _, _ = mlvl_feats[0].shape
         dtype = mlvl_feats[0].dtype
-        object_query_embeds = self.query_embedding.weight.to(dtype)
-        bev_queries = self.bev_embedding.weight.to(dtype)
+        object_query_embeds = self.query_embedding.weight
+        bev_queries = self.bev_embedding.weight
 
-        bev_mask = torch.zeros((bs, self.bev_h, self.bev_w),
-                               device=bev_queries.device).to(dtype)
-        bev_pos = self.positional_encoding(bev_mask).to(dtype)
+        bev_mask = self.bev_mask
+        bev_pos = self.positional_encoding(bev_mask)
+
+        if dtype == torch.float16:
+            object_query_embeds = object_query_embeds.to(dtype)
+            bev_queries = bev_queries.to(dtype)
+            bev_pos = bev_pos.to(dtype)
 
         if only_bev:  # only use encoder to obtain BEV features, TODO: refine the workaround
             return self.transformer.get_bev_features(
diff --git a/projects/mmdet3d_plugin/bevformer/modules/decoder.py b/projects/mmdet3d_plugin/bevformer/modules/decoder.py
index 33024f8..755883b 100644
--- a/projects/mmdet3d_plugin/bevformer/modules/decoder.py
+++ b/projects/mmdet3d_plugin/bevformer/modules/decoder.py
@@ -29,6 +29,7 @@ from .multi_scale_deformable_attn_function import MultiScaleDeformableAttnFuncti
 
 ext_module = ext_loader.load_ext(
     '_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward'])
+import mx_driving
 
 
 def inverse_sigmoid(x, eps=1e-5):
@@ -323,15 +324,8 @@ class CustomMSDeformableAttention(BaseModule):
                 f'Last dim of reference_points must be'
                 f' 2 or 4, but get {reference_points.shape[-1]} instead.')
         if torch.cuda.is_available() and value.is_cuda:
-
-            # using fp16 deformable attention is unstable because it performs many sum operations
-            if value.dtype == torch.float16:
-                MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
-            else:
-                MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
-            output = MultiScaleDeformableAttnFunction.apply(
-                value, spatial_shapes, level_start_index, sampling_locations,
-                attention_weights, self.im2col_step)
+            output = mx_driving.multi_scale_deformable_attn(value, spatial_shapes, level_start_index,
+                                                            sampling_locations, attention_weights)
         else:
             output = multi_scale_deformable_attn_pytorch(
                 value, spatial_shapes, sampling_locations, attention_weights)
diff --git a/projects/mmdet3d_plugin/bevformer/modules/encoder.py b/projects/mmdet3d_plugin/bevformer/modules/encoder.py
index 6758847..022bdd1 100644
--- a/projects/mmdet3d_plugin/bevformer/modules/encoder.py
+++ b/projects/mmdet3d_plugin/bevformer/modules/encoder.py
@@ -119,8 +119,8 @@ class BEVFormerEncoder(TransformerLayerSequence):
         lidar2img = lidar2img.view(
             1, B, num_cam, 1, 4, 4).repeat(D, 1, 1, num_query, 1, 1)
 
-        reference_points_cam = torch.matmul(lidar2img.to(torch.float32),
-                                            reference_points.to(torch.float32)).squeeze(-1)
+        reference_points_cam = torch.mul(lidar2img.to(torch.float32),
+                                         reference_points.to(torch.float32).transpose(-1, -2)).sum(-1, keepdim=True).squeeze(-1)
         eps = 1e-5
 
         bev_mask = (reference_points_cam[..., 2:3] > eps)
@@ -262,6 +262,8 @@ class BEVFormerLayer(MyCustomBaseTransformerLayer):
     """
 
     def __init__(self,
+                 bev_h,
+                 bev_w,
                  attn_cfgs,
                  feedforward_channels,
                  ffn_dropout=0.0,
@@ -280,6 +282,8 @@ class BEVFormerLayer(MyCustomBaseTransformerLayer):
             ffn_num_fcs=ffn_num_fcs,
             **kwargs)
         self.fp16_enabled = False
+        self.level_start_index = torch.tensor([0]).npu()
+        self.spatial_shapes = torch.tensor([[bev_h, bev_w]]).npu()
         assert len(operation_order) == 6
         assert set(operation_order) == set(
             ['self_attn', 'norm', 'cross_attn', 'ffn'])
@@ -367,9 +371,8 @@ class BEVFormerLayer(MyCustomBaseTransformerLayer):
                     attn_mask=attn_masks[attn_index],
                     key_padding_mask=query_key_padding_mask,
                     reference_points=ref_2d,
-                    spatial_shapes=torch.tensor(
-                        [[bev_h, bev_w]], device=query.device),
-                    level_start_index=torch.tensor([0], device=query.device),
+                    spatial_shapes=self.spatial_shapes,
+                    level_start_index=self.level_start_index,
                     **kwargs)
                 attn_index += 1
                 identity = query
diff --git a/projects/mmdet3d_plugin/bevformer/modules/spatial_cross_attention.py b/projects/mmdet3d_plugin/bevformer/modules/spatial_cross_attention.py
index 100d94f..2b01580 100644
--- a/projects/mmdet3d_plugin/bevformer/modules/spatial_cross_attention.py
+++ b/projects/mmdet3d_plugin/bevformer/modules/spatial_cross_attention.py
@@ -26,7 +26,7 @@ from .multi_scale_deformable_attn_function import MultiScaleDeformableAttnFuncti
 from projects.mmdet3d_plugin.models.utils.bricks import run_time
 ext_module = ext_loader.load_ext(
     '_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward'])
-
+import mx_driving
 
 @ATTENTION.register_module()
 class SpatialCrossAttention(BaseModule):
@@ -383,17 +383,12 @@ class MSDeformableAttention3D(BaseModule):
         #
 
         if torch.cuda.is_available() and value.is_cuda:
-            if value.dtype == torch.float16:
-                MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
-            else:
-                MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
-            output = MultiScaleDeformableAttnFunction.apply(
-                value, spatial_shapes, level_start_index, sampling_locations,
-                attention_weights, self.im2col_step)
+            output = mx_driving.multi_scale_deformable_attn(value, spatial_shapes, level_start_index,
+                                                            sampling_locations, attention_weights)
         else:
             output = multi_scale_deformable_attn_pytorch(
                 value, spatial_shapes, sampling_locations, attention_weights)
         if not self.batch_first:
             output = output.permute(1, 0, 2)
 
-        return output
+        return output
\ No newline at end of file
diff --git a/projects/mmdet3d_plugin/bevformer/modules/temporal_self_attention.py b/projects/mmdet3d_plugin/bevformer/modules/temporal_self_attention.py
index 78fb9f5..208dcce 100644
--- a/projects/mmdet3d_plugin/bevformer/modules/temporal_self_attention.py
+++ b/projects/mmdet3d_plugin/bevformer/modules/temporal_self_attention.py
@@ -10,6 +10,7 @@ from mmcv.ops.multi_scale_deform_attn import multi_scale_deformable_attn_pytorch
 import warnings
 import torch
 import torch.nn as nn
+import mx_driving
 from mmcv.cnn import xavier_init, constant_init
 from mmcv.cnn.bricks.registry import ATTENTION
 import math
@@ -238,15 +239,8 @@ class TemporalSelfAttention(BaseModule):
                 f'Last dim of reference_points must be'
                 f' 2 or 4, but get {reference_points.shape[-1]} instead.')
         if torch.cuda.is_available() and value.is_cuda:
-
-            # using fp16 deformable attention is unstable because it performs many sum operations
-            if value.dtype == torch.float16:
-                MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
-            else:
-                MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
-            output = MultiScaleDeformableAttnFunction.apply(
-                value, spatial_shapes, level_start_index, sampling_locations,
-                attention_weights, self.im2col_step)
+            output = mx_driving.multi_scale_deformable_attn(value, spatial_shapes, level_start_index,
+                                                            sampling_locations, attention_weights)
         else:
 
             output = multi_scale_deformable_attn_pytorch(
diff --git a/projects/mmdet3d_plugin/bevformer/modules/transformer.py b/projects/mmdet3d_plugin/bevformer/modules/transformer.py
index b740fcc..bb9c981 100644
--- a/projects/mmdet3d_plugin/bevformer/modules/transformer.py
+++ b/projects/mmdet3d_plugin/bevformer/modules/transformer.py
@@ -15,7 +15,7 @@ from mmdet.models.utils.builder import TRANSFORMER
 from torch.nn.init import normal_
 from projects.mmdet3d_plugin.models.utils.visual import save_tensor
 from mmcv.runner.base_module import BaseModule
-from torchvision.transforms.functional import rotate
+from torchvision.transforms.functional import InterpolationMode, rotate
 from .temporal_self_attention import TemporalSelfAttention
 from .spatial_cross_attention import MSDeformableAttention3D
 from .decoder import CustomMSDeformableAttention
@@ -150,6 +150,7 @@ class PerceptionTransformer(BaseModule):
                     tmp_prev_bev = prev_bev[:, i].reshape(
                         bev_h, bev_w, -1).permute(2, 0, 1)
                     tmp_prev_bev = rotate(tmp_prev_bev, rotation_angle,
+                                          interpolation=InterpolationMode.BILINEAR,
                                           center=self.rotate_center)
                     tmp_prev_bev = tmp_prev_bev.permute(1, 2, 0).reshape(
                         bev_h * bev_w, 1, -1)
diff --git a/projects/mmdet3d_plugin/datasets/builder.py b/projects/mmdet3d_plugin/datasets/builder.py
index f9bf5be..9586277 100644
--- a/projects/mmdet3d_plugin/datasets/builder.py
+++ b/projects/mmdet3d_plugin/datasets/builder.py
@@ -1,10 +1,12 @@
 
 # Copyright (c) OpenMMLab. All rights reserved.
+# Copyright 2024 Huawei Technologies Co., Ltd
 import copy
 import platform
 import random
 from functools import partial
 
+import torch
 import numpy as np
 from mmcv.parallel import collate
 from mmcv.runner import get_dist_info
@@ -16,6 +18,7 @@ from projects.mmdet3d_plugin.datasets.samplers.group_sampler import DistributedG
 from projects.mmdet3d_plugin.datasets.samplers.distributed_sampler import DistributedSampler
 from projects.mmdet3d_plugin.datasets.samplers.sampler import build_sampler
 
+
 def build_dataloader(dataset,
                      samples_per_gpu,
                      workers_per_gpu,
@@ -80,13 +83,14 @@ def build_dataloader(dataset,
         worker_init_fn, num_workers=num_workers, rank=rank,
         seed=seed) if seed is not None else None
 
+    kwargs = {"pin_memory_device":"npu"} if torch.__version__ >= "2.0" else {}
     data_loader = DataLoader(
         dataset,
         batch_size=batch_size,
         sampler=sampler,
         num_workers=num_workers,
         collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
-        pin_memory=False,
+        pin_memory=True,
         worker_init_fn=init_fn,
         persistent_workers=(num_workers > 0),
         **kwargs)
@@ -103,7 +107,6 @@ def worker_init_fn(worker_id, num_workers, rank, seed):
 
 
 # Copyright (c) OpenMMLab. All rights reserved.
-import platform
 from mmcv.utils import Registry, build_from_cfg
 
 from mmdet.datasets import DATASETS
diff --git a/tools/dist_test.sh b/tools/dist_test.sh
index 3e2ec30..931aa0f 100755
--- a/tools/dist_test.sh
+++ b/tools/dist_test.sh
@@ -6,5 +6,5 @@ GPUS=$3
 PORT=${PORT:-29503}
 
 PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
-python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \
+torchrun --nproc_per_node=$GPUS --master_port=$PORT \
     $(dirname "$0")/test.py $CONFIG $CHECKPOINT --launcher pytorch ${@:4} --eval bbox
diff --git a/tools/dist_train.sh b/tools/dist_train.sh
index 141b284..b132237 100755
--- a/tools/dist_train.sh
+++ b/tools/dist_train.sh
@@ -5,5 +5,5 @@ GPUS=$2
 PORT=${PORT:-28509}
 
 PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
-python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \
+torchrun --nproc_per_node=$GPUS --master_port=$PORT \
     $(dirname "$0")/train.py $CONFIG --launcher pytorch ${@:3} --deterministic
diff --git a/tools/fp16/dist_train.sh b/tools/fp16/dist_train.sh
index 4ac9a15..faa970b 100755
--- a/tools/fp16/dist_train.sh
+++ b/tools/fp16/dist_train.sh
@@ -2,8 +2,8 @@
 
 CONFIG=$1
 GPUS=$2
-PORT=${PORT:-28508}
+PORT=${PORT:-28509}
 
-PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
-python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \
+PYTHONPATH="$(dirname $0)/../..":$PYTHONPATH \
+torchrun --nproc_per_node=$GPUS --master_port=$PORT \
     $(dirname "$0")/train.py $CONFIG --launcher pytorch ${@:3} --deterministic
diff --git a/tools/fp16/train.py b/tools/fp16/train.py
index eddc349..77f0244 100644
--- a/tools/fp16/train.py
+++ b/tools/fp16/train.py
@@ -23,6 +23,10 @@ from mmdet.apis import set_random_seed
 from mmseg import __version__ as mmseg_version
 
 from mmcv.utils import TORCH_VERSION, digit_version
+import torch_npu
+from torch_npu.contrib import transfer_to_npu
+
+torch.npu.config.allow_internal_format = True
 
 def parse_args():
     parser = argparse.ArgumentParser(description='Train a detector')
diff --git a/tools/test.py b/tools/test.py
index acce20b..12ed94e 100755
--- a/tools/test.py
+++ b/tools/test.py
@@ -23,6 +23,9 @@ from projects.mmdet3d_plugin.bevformer.apis.test import custom_multi_gpu_test
 from mmdet.datasets import replace_ImageToTensor
 import time
 import os.path as osp
+import torch_npu
+from torch_npu.contrib import transfer_to_npu
+from mmcv.device.npu import NPUDataParallel, NPUDistributedDataParallel
 
 
 def parse_args():
@@ -230,7 +233,7 @@ def main():
         # model = MMDataParallel(model, device_ids=[0])
         # outputs = single_gpu_test(model, data_loader, args.show, args.show_dir)
     else:
-        model = MMDistributedDataParallel(
+        model = NPUDistributedDataParallel(
             model.cuda(),
             device_ids=[torch.cuda.current_device()],
             broadcast_buffers=False)
diff --git a/tools/train.py b/tools/train.py
index d6e65c1..889ef74 100755
--- a/tools/train.py
+++ b/tools/train.py
@@ -28,6 +28,8 @@ from mmdet.apis import set_random_seed
 from mmseg import __version__ as mmseg_version
 
 from mmcv.utils import TORCH_VERSION, digit_version
+import torch_npu
+from torch_npu.contrib import transfer_to_npu
 
 
 def parse_args():