diff --git a/extensions/chamfer_dist/__init__.py b/extensions/chamfer_dist/__init__.py
index 8b4f53c..87b8fef 100644
--- a/extensions/chamfer_dist/__init__.py
+++ b/extensions/chamfer_dist/__init__.py
@@ -28,6 +28,7 @@ class ChamferFunction(torch.autograd.Function):
 class ChamferDistanceL2(torch.nn.Module):
     f''' Chamder Distance L2
     '''
+
     def __init__(self, ignore_zeros=False):
         super().__init__()
         self.ignore_zeros = ignore_zeros
@@ -43,9 +44,11 @@ class ChamferDistanceL2(torch.nn.Module):
         dist1, dist2 = ChamferFunction.apply(xyz1, xyz2)
         return torch.mean(dist1) + torch.mean(dist2)
 
+
 class ChamferDistanceL2_split(torch.nn.Module):
     f''' Chamder Distance L2
     '''
+
     def __init__(self, ignore_zeros=False):
         super().__init__()
         self.ignore_zeros = ignore_zeros
@@ -61,9 +64,11 @@ class ChamferDistanceL2_split(torch.nn.Module):
         dist1, dist2 = ChamferFunction.apply(xyz1, xyz2)
         return torch.mean(dist1), torch.mean(dist2)
 
+
 class ChamferDistanceL1(torch.nn.Module):
     f''' Chamder Distance L1
     '''
+
     def __init__(self, ignore_zeros=False):
         super().__init__()
         self.ignore_zeros = ignore_zeros
@@ -77,9 +82,7 @@ class ChamferDistanceL1(torch.nn.Module):
             xyz2 = xyz2[non_zeros2].unsqueeze(dim=0)
 
         dist1, dist2 = ChamferFunction.apply(xyz1, xyz2)
-        # import pdb
-        # pdb.set_trace()
         dist1 = torch.sqrt(dist1)
         dist2 = torch.sqrt(dist2)
-        return (torch.mean(dist1) + torch.mean(dist2))/2
+        return (torch.mean(dist1) + torch.mean(dist2)) / 2
 
diff --git a/extensions/chamfer_dist/chamfer_cuda.cpp b/extensions/chamfer_dist/chamfer_cuda.cpp
index 9fca161..b3bb729 100644
--- a/extensions/chamfer_dist/chamfer_cuda.cpp
+++ b/extensions/chamfer_dist/chamfer_cuda.cpp
@@ -21,7 +21,7 @@ std::vector<torch::Tensor> chamfer_cuda_backward(torch::Tensor xyz1,
 
 std::vector<torch::Tensor> chamfer_forward(torch::Tensor xyz1,
                                            torch::Tensor xyz2) {
-  return chamfer_cuda_forward(xyz1, xyz2);
+    return chamfer_cuda_forward(xyz1, xyz2);
 }
 
 std::vector<torch::Tensor> chamfer_backward(torch::Tensor xyz1,
@@ -30,10 +30,10 @@ std::vector<torch::Tensor> chamfer_backward(torch::Tensor xyz1,
                                             torch::Tensor idx2,
                                             torch::Tensor grad_dist1,
                                             torch::Tensor grad_dist2) {
-  return chamfer_cuda_backward(xyz1, xyz2, idx1, idx2, grad_dist1, grad_dist2);
+    return chamfer_cuda_backward(xyz1, xyz2, idx1, idx2, grad_dist1, grad_dist2);
 }
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("forward", &chamfer_forward, "Chamfer forward (CUDA)");
-  m.def("backward", &chamfer_backward, "Chamfer backward (CUDA)");
+    m.def("forward", &chamfer_forward, "Chamfer forward (CUDA)");
+    m.def("backward", &chamfer_backward, "Chamfer backward (CUDA)");
 }
diff --git a/extensions/chamfer_dist/test.py b/extensions/chamfer_dist/test.py
index 0ece5d2..0ef2499 100644
--- a/extensions/chamfer_dist/test.py
+++ b/extensions/chamfer_dist/test.py
@@ -31,8 +31,5 @@ class ChamferDistanceTestCase(unittest.TestCase):
 
 
 if __name__ == '__main__':
-    # unittest.main()
-    import pdb
-    x = torch.rand(32,128,3)
-    y = torch.rand(32,128,3)
-    pdb.set_trace()
+    x = torch.rand(32, 128, 3)
+    y = torch.rand(32, 128, 3)
diff --git a/projects/configs/surroundocc/surroundocc.py b/projects/configs/surroundocc/surroundocc.py
index f921741..7efaabc 100644
--- a/projects/configs/surroundocc/surroundocc.py
+++ b/projects/configs/surroundocc/surroundocc.py
@@ -16,7 +16,7 @@ use_semantic = True
 img_norm_cfg = dict(
     mean=[103.530, 116.280, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False)
 
-class_names =  ['barrier','bicycle', 'bus', 'car', 'construction_vehicle', 'motorcycle',
+class_names =  ['barrier', 'bicycle', 'bus', 'car', 'construction_vehicle', 'motorcycle',
                 'pedestrian', 'traffic_cone', 'trailer', 'truck', 'driveable_surface',
                 'other_flat', 'sidewalk', 'terrain', 'manmade','vegetation']
 
@@ -43,7 +43,7 @@ model = dict(
        type='ResNet',
        depth=101,
        num_stages=4,
-       out_indices=(1,2,3),
+       out_indices=(1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type='BN2d', requires_grad=False),
        norm_eval=True,
@@ -69,7 +69,7 @@ model = dict(
         conv_input=[_dim_[2], 256, _dim_[1], 128, _dim_[0], 64, 64],
         conv_output=[256, _dim_[1], 128, _dim_[0], 64, 64, 32],
         out_indices=[0, 2, 4, 6],
-        upsample_strides=[1,2,1,2,1,2,1],
+        upsample_strides=[1, 2, 1, 2, 1, 2, 1],
         embed_dims=_dim_,
         img_channels=[512, 512, 512],
         use_semantic=use_semantic,
@@ -125,7 +125,7 @@ test_pipeline = [
     dict(type='NormalizeMultiviewImage', **img_norm_cfg),
     dict(type='PadMultiViewImage', size_divisor=32),
     dict(type='DefaultFormatBundle3D', class_names=class_names, with_label=False),
-    dict(type='CustomCollect3D', keys=['img','gt_occ'])
+    dict(type='CustomCollect3D', keys=['img', 'gt_occ'])
 ]
 
 find_unused_parameters = True
@@ -168,7 +168,7 @@ data = dict(
 )
 
 optimizer = dict(
-    type='AdamW',
+    type='NpuFusedAdamW',
     lr=2e-4,
     paramwise_cfg=dict(
         custom_keys={
@@ -176,7 +176,8 @@ optimizer = dict(
         }),
     weight_decay=0.01)
 
-optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
+optimizer_config = dict(type='GradientCumulativeOptimizerHook', cumulative_iters=2, grad_clip=dict(max_norm=35, norm_type=2))
+
 # learning policy
 lr_config = dict(
     policy='CosineAnnealing',
@@ -184,13 +185,13 @@ lr_config = dict(
     warmup_iters=500,
     warmup_ratio=1.0 / 3,
     min_lr_ratio=1e-3)
-total_epochs = 24
+total_epochs = 15
 evaluation = dict(interval=1, pipeline=test_pipeline)
 
 runner = dict(type='EpochBasedRunner', max_epochs=total_epochs)
 load_from = 'ckpts/r101_dcn_fcos3d_pretrain.pth'
 log_config = dict(
-    interval=50,
+    interval=1,
     hooks=[
         dict(type='TextLoggerHook'),
         dict(type='TensorboardLoggerHook')
diff --git a/projects/configs/surroundocc/surroundocc_inference.py b/projects/configs/surroundocc/surroundocc_inference.py
index 2d514ed..2533799 100644
--- a/projects/configs/surroundocc/surroundocc_inference.py
+++ b/projects/configs/surroundocc/surroundocc_inference.py
@@ -16,9 +16,9 @@ use_semantic = True
 img_norm_cfg = dict(
     mean=[103.530, 116.280, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False)
 
-class_names =  ['barrier','bicycle', 'bus', 'car', 'construction_vehicle', 'motorcycle',
+class_names =  ['barrier', 'bicycle', 'bus', 'car', 'construction_vehicle', 'motorcycle',
                 'pedestrian', 'traffic_cone', 'trailer', 'truck', 'driveable_surface',
-                'other_flat', 'sidewalk', 'terrain', 'manmade','vegetation']
+                'other_flat', 'sidewalk', 'terrain', 'manmade', 'vegetation']
 
 input_modality = dict(
     use_lidar=False,
@@ -44,7 +44,7 @@ model = dict(
        type='ResNet',
        depth=101,
        num_stages=4,
-       out_indices=(1,2,3),
+       out_indices=(1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type='BN2d', requires_grad=False),
        norm_eval=True,
@@ -70,7 +70,7 @@ model = dict(
         conv_input=[_dim_[2], 256, _dim_[1], 128, _dim_[0], 64, 64],
         conv_output=[256, _dim_[1], 128, _dim_[0], 64, 64, 32],
         out_indices=[0, 2, 4, 6],
-        upsample_strides=[1,2,1,2,1,2,1],
+        upsample_strides=[1, 2, 1, 2, 1, 2, 1],
         embed_dims=_dim_,
         img_channels=[512, 512, 512],
         use_semantic=use_semantic,
diff --git a/projects/configs/surroundocc/surroundocc_nosemantic.py b/projects/configs/surroundocc/surroundocc_nosemantic.py
index 2f3b790..46d19e8 100644
--- a/projects/configs/surroundocc/surroundocc_nosemantic.py
+++ b/projects/configs/surroundocc/surroundocc_nosemantic.py
@@ -16,9 +16,9 @@ use_semantic = False
 img_norm_cfg = dict(
     mean=[103.530, 116.280, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False)
 
-class_names =  ['barrier','bicycle', 'bus', 'car', 'construction_vehicle', 'motorcycle',
+class_names =  ['barrier', 'bicycle', 'bus', 'car', 'construction_vehicle', 'motorcycle',
                 'pedestrian', 'traffic_cone', 'trailer', 'truck', 'driveable_surface',
-                'other_flat', 'sidewalk', 'terrain', 'manmade','vegetation']
+                'other_flat', 'sidewalk', 'terrain', 'manmade', 'vegetation']
 
 input_modality = dict(
     use_lidar=False,
@@ -43,7 +43,7 @@ model = dict(
        type='ResNet',
        depth=101,
        num_stages=4,
-       out_indices=(1,2,3),
+       out_indices=(1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type='BN2d', requires_grad=False),
        norm_eval=True,
@@ -69,7 +69,7 @@ model = dict(
         conv_input=[_dim_[2], 256, _dim_[1], 128, _dim_[0], 64, 64],
         conv_output=[256, _dim_[1], 128, _dim_[0], 64, 64, 32],
         out_indices=[0, 2, 4, 6],
-        upsample_strides=[1,2,1,2,1,2,1],
+        upsample_strides=[1, 2, 1, 2, 1, 2, 1],
         embed_dims=_dim_,
         img_channels=[512, 512, 512],
         use_semantic=use_semantic,
@@ -125,7 +125,7 @@ test_pipeline = [
     dict(type='NormalizeMultiviewImage', **img_norm_cfg),
     dict(type='PadMultiViewImage', size_divisor=32),
     dict(type='DefaultFormatBundle3D', class_names=class_names, with_label=False),
-    dict(type='CustomCollect3D', keys=['img','gt_occ'])
+    dict(type='CustomCollect3D', keys=['img', 'gt_occ'])
 ]
 
 find_unused_parameters = True
diff --git a/projects/configs/surroundocc/surroundocc_performance.py b/projects/configs/surroundocc/surroundocc_performance.py
new file mode 100644
index 0000000..31bb909
--- /dev/null
+++ b/projects/configs/surroundocc/surroundocc_performance.py
@@ -0,0 +1,200 @@
+_base_ = [
+    '../datasets/custom_nus-3d.py',
+    '../_base_/default_runtime.py'
+]
+#
+plugin = True
+plugin_dir = 'projects/mmdet3d_plugin/'
+
+# If point cloud range is changed, the models should also change their point
+# cloud range accordingly
+point_cloud_range = [-50, -50, -5.0, 50, 50, 3.0]
+occ_size = [200, 200, 16]
+use_semantic = True
+
+
+img_norm_cfg = dict(
+    mean=[103.530, 116.280, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False)
+
+class_names =  ['barrier', 'bicycle', 'bus', 'car', 'construction_vehicle', 'motorcycle',
+                'pedestrian', 'traffic_cone', 'trailer', 'truck', 'driveable_surface',
+                'other_flat', 'sidewalk', 'terrain', 'manmade', 'vegetation']
+
+input_modality = dict(
+    use_lidar=False,
+    use_camera=True,
+    use_radar=False,
+    use_map=False,
+    use_external=True)
+
+_dim_ = [128, 256, 512]
+_ffn_dim_ = [256, 512, 1024]
+volume_h_ = [100, 50, 25]
+volume_w_ = [100, 50, 25]
+volume_z_ = [8, 4, 2]
+_num_points_ = [2, 4, 8]
+_num_layers_ = [1, 3, 6]
+
+model = dict(
+    type='SurroundOcc',
+    use_grid_mask=True,
+    use_semantic=use_semantic,
+    img_backbone=dict(
+       type='ResNet',
+       depth=101,
+       num_stages=4,
+       out_indices=(1, 2, 3),
+       frozen_stages=1,
+       norm_cfg=dict(type='BN2d', requires_grad=False),
+       norm_eval=True,
+       style='caffe',
+       #with_cp=True, # using checkpoint to save GPU memory
+       dcn=dict(type='DCNv2', deform_groups=1, fallback_on_stride=False), # original DCNv2 will print log when perform load_state_dict
+       stage_with_dcn=(False, False, True, True)),
+    img_neck=dict(
+        type='FPN',
+        in_channels=[512, 1024, 2048],
+        out_channels=512,
+        start_level=0,
+        add_extra_convs='on_output',
+        num_outs=3,
+        relu_before_extra_convs=True),
+    pts_bbox_head=dict(
+        type='OccHead',
+        volume_h=volume_h_,
+        volume_w=volume_w_,
+        volume_z=volume_z_,
+        num_query=900,
+        num_classes=17,
+        conv_input=[_dim_[2], 256, _dim_[1], 128, _dim_[0], 64, 64],
+        conv_output=[256, _dim_[1], 128, _dim_[0], 64, 64, 32],
+        out_indices=[0, 2, 4, 6],
+        upsample_strides=[1, 2, 1, 2, 1, 2, 1],
+        embed_dims=_dim_,
+        img_channels=[512, 512, 512],
+        use_semantic=use_semantic,
+        transformer_template=dict(
+            type='PerceptionTransformer',
+            embed_dims=_dim_,
+            encoder=dict(
+                type='OccEncoder',
+                num_layers=_num_layers_,
+                pc_range=point_cloud_range,
+                return_intermediate=False,
+                transformerlayers=dict(
+                    type='OccLayer',
+                    attn_cfgs=[
+                        dict(
+                            type='SpatialCrossAttention',
+                            pc_range=point_cloud_range,
+                            deformable_attention=dict(
+                                type='MSDeformableAttention3D',
+                                embed_dims=_dim_,
+                                num_points=_num_points_,
+                                num_levels=1),
+                            embed_dims=_dim_,
+                        )
+                    ],
+                    feedforward_channels=_ffn_dim_,
+                    ffn_dropout=0.1,
+                    embed_dims=_dim_,
+                    conv_num=2,
+                    operation_order=('cross_attn', 'norm',
+                                     'ffn', 'norm', 'conv')))),
+),
+)
+
+dataset_type = 'CustomNuScenesOccDataset'
+data_root = 'data/nuscenes/'
+file_client_args = dict(backend='disk')
+
+
+train_pipeline = [
+    dict(type='LoadMultiViewImageFromFiles', to_float32=True),
+    dict(type='PhotoMetricDistortionMultiViewImage'),
+    dict(type='LoadOccupancy', use_semantic=use_semantic),
+    dict(type='NormalizeMultiviewImage', **img_norm_cfg),
+    dict(type='PadMultiViewImage', size_divisor=32),
+    dict(type='DefaultFormatBundle3D', class_names=class_names, with_label=False),
+    dict(type='CustomCollect3D', keys=['img', 'gt_occ'])
+]
+
+test_pipeline = [
+    dict(type='LoadMultiViewImageFromFiles', to_float32=True),
+    dict(type='LoadOccupancy', use_semantic=use_semantic),
+    dict(type='NormalizeMultiviewImage', **img_norm_cfg),
+    dict(type='PadMultiViewImage', size_divisor=32),
+    dict(type='DefaultFormatBundle3D', class_names=class_names, with_label=False),
+    dict(type='CustomCollect3D', keys=['img', 'gt_occ'])
+]
+
+find_unused_parameters = True
+data = dict(
+    samples_per_gpu=1,
+    workers_per_gpu=4,
+    train=dict(
+        type=dataset_type,
+        data_root=data_root,
+        ann_file='data/nuscenes_infos_train.pkl',
+        pipeline=train_pipeline,
+        modality=input_modality,
+        test_mode=False,
+        use_valid_flag=True,
+        occ_size=occ_size,
+        pc_range=point_cloud_range,
+        use_semantic=use_semantic,
+        classes=class_names,
+        box_type_3d='LiDAR'),
+    val=dict(type=dataset_type,
+             data_root=data_root,
+             ann_file='data/nuscenes_infos_val.pkl',
+             pipeline=test_pipeline,
+             occ_size=occ_size,
+             pc_range=point_cloud_range,
+             use_semantic=use_semantic,
+             classes=class_names,
+             modality=input_modality),
+    test=dict(type=dataset_type,
+              data_root=data_root,
+              ann_file='data/nuscenes_infos_val.pkl',
+              pipeline=test_pipeline,
+              occ_size=occ_size,
+              pc_range=point_cloud_range,
+              use_semantic=use_semantic,
+              classes=class_names,
+              modality=input_modality),
+    shuffler_sampler=dict(type='DistributedGroupSampler'),
+    nonshuffler_sampler=dict(type='DistributedSampler')
+)
+
+optimizer = dict(
+    type='NpuFusedAdamW',
+    lr=2e-4,
+    paramwise_cfg=dict(
+        custom_keys={
+            'img_backbone': dict(lr_mult=0.1),
+        }),
+    weight_decay=0.01)
+
+optimizer_config = dict(type='GradientCumulativeOptimizerHook', cumulative_iters=2, grad_clip=dict(max_norm=35, norm_type=2))
+
+# learning policy
+lr_config = dict(
+    policy='CosineAnnealing',
+    warmup='linear',
+    warmup_iters=500,
+    warmup_ratio=1.0 / 3,
+    min_lr_ratio=1e-3)
+total_epochs = 1
+evaluation = dict(interval=1, pipeline=test_pipeline)
+
+runner = dict(type='EpochBasedRunner', max_epochs=total_epochs)
+load_from = 'ckpts/r101_dcn_fcos3d_pretrain.pth'
+log_config = dict(
+    interval=1,
+    hooks=[
+        dict(type='TextLoggerHook'),
+        dict(type='TensorboardLoggerHook')
+    ])
+
+checkpoint_config = dict(interval=1)
diff --git a/projects/mmdet3d_plugin/datasets/evaluation_metrics.py b/projects/mmdet3d_plugin/datasets/evaluation_metrics.py
index 1b0f356..e37a119 100644
--- a/projects/mmdet3d_plugin/datasets/evaluation_metrics.py
+++ b/projects/mmdet3d_plugin/datasets/evaluation_metrics.py
@@ -1,6 +1,20 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import numpy as np
 import torch
-import chamfer
+
 
 def voxel_to_vertices(voxel, img_metas, thresh=0.5):
     x = torch.linspace(0, voxel.shape[0] - 1, voxel.shape[0])
@@ -10,46 +24,51 @@ def voxel_to_vertices(voxel, img_metas, thresh=0.5):
     vv = torch.stack([X, Y, Z], dim=-1).to(voxel.device)
 
     vertices = vv[voxel > thresh]
-    vertices[:, 0] = (vertices[:, 0] + 0.5) * (img_metas['pc_range'][3] - img_metas['pc_range'][0]) /  img_metas['occ_size'][0]  + img_metas['pc_range'][0]
-    vertices[:, 1] = (vertices[:, 1] + 0.5) * (img_metas['pc_range'][4] - img_metas['pc_range'][1]) /  img_metas['occ_size'][1]  + img_metas['pc_range'][1]
-    vertices[:, 2] = (vertices[:, 2] + 0.5) * (img_metas['pc_range'][5] - img_metas['pc_range'][2]) /  img_metas['occ_size'][2]  + img_metas['pc_range'][2]
+    vertices[:, 0] = (vertices[:, 0] + 0.5) * (img_metas['pc_range'][3] - img_metas['pc_range'][0]) / img_metas['occ_size'][0] + img_metas['pc_range'][0]
+    vertices[:, 1] = (vertices[:, 1] + 0.5) * (img_metas['pc_range'][4] - img_metas['pc_range'][1]) / img_metas['occ_size'][1] + img_metas['pc_range'][1]
+    vertices[:, 2] = (vertices[:, 2] + 0.5) * (img_metas['pc_range'][5] - img_metas['pc_range'][2]) / img_metas['occ_size'][2] + img_metas['pc_range'][2]
 
     return vertices
 
+
 def gt_to_vertices(gt, img_metas):
-    gt[:, 0] = (gt[:, 0] + 0.5) * (img_metas['pc_range'][3] - img_metas['pc_range'][0]) /  img_metas['occ_size'][0]  + img_metas['pc_range'][0]
-    gt[:, 1] = (gt[:, 1] + 0.5) * (img_metas['pc_range'][4] - img_metas['pc_range'][1]) /  img_metas['occ_size'][1]  + img_metas['pc_range'][1]
-    gt[:, 2] = (gt[:, 2] + 0.5) * (img_metas['pc_range'][5] - img_metas['pc_range'][2]) /  img_metas['occ_size'][2]  + img_metas['pc_range'][2]
+    gt[:, 0] = (gt[:, 0] + 0.5) * (img_metas['pc_range'][3] - img_metas['pc_range'][0]) / img_metas['occ_size'][0] + img_metas['pc_range'][0]
+    gt[:, 1] = (gt[:, 1] + 0.5) * (img_metas['pc_range'][4] - img_metas['pc_range'][1]) / img_metas['occ_size'][1] + img_metas['pc_range'][1]
+    gt[:, 2] = (gt[:, 2] + 0.5) * (img_metas['pc_range'][5] - img_metas['pc_range'][2]) / img_metas['occ_size'][2] + img_metas['pc_range'][2]
     return gt
 
+
 def gt_to_voxel(gt, img_metas):
     voxel = np.zeros(img_metas['occ_size'])
     voxel[gt[:, 0].astype(np.int), gt[:, 1].astype(np.int), gt[:, 2].astype(np.int)] = gt[:, 3]
 
     return voxel
 
+
 def eval_3d(verts_pred, verts_trgt, threshold=.5):
     d1, d2, idx1, idx2 = chamfer.forward(verts_pred.unsqueeze(0).type(torch.float), verts_trgt.unsqueeze(0).type(torch.float))
     dist1 = torch.sqrt(d1).cpu().numpy()
     dist2 = torch.sqrt(d2).cpu().numpy()
     cd = dist1.mean() + dist2.mean()
-    precision = np.mean((dist1<threshold).astype('float'))
+    precision = np.mean((dist1 < threshold).astype('float'))
     recal = np.mean((dist2<threshold).astype('float'))
     fscore = 2 * precision * recal / (precision + recal)
-    metrics = np.array([np.mean(dist1),np.mean(dist2),cd, precision,recal,fscore])
+    metrics = np.array([np.mean(dist1), np.mean(dist2), cd, precision, recal, fscore])
     return metrics
 
+
 def evaluation_reconstruction(pred_occ, gt_occ, img_metas):
     results = []
     for i in range(pred_occ.shape[0]):
-        
-        vertices_pred = voxel_to_vertices(pred_occ[i], img_metas, thresh=0.25) #set low thresh for class imbalance problem
+        #set low thresh for class imbalance problem
+        vertices_pred = voxel_to_vertices(pred_occ[i], img_metas, thresh=0.25)
         vertices_gt = gt_to_vertices(gt_occ[i][..., :3], img_metas)
-        
-        metrics = eval_3d(vertices_pred.type(torch.double), vertices_gt.type(torch.double)) #must convert to double, a bug in chamfer
+        #must convert to double, a bug in chamfer
+        metrics = eval_3d(vertices_pred.type(torch.double), vertices_gt.type(torch.double))
         results.append(metrics)
     return np.stack(results, axis=0)
 
+
 def evaluation_semantic(pred_occ, gt_occ, img_metas, class_num):
     results = []
 
@@ -59,7 +78,8 @@ def evaluation_semantic(pred_occ, gt_occ, img_metas, class_num):
         mask = (gt_i != 255)
         score = np.zeros((class_num, 3))
         for j in range(class_num):
-            if j == 0: #class 0 for geometry IoU
+            #class 0 for geometry IoU
+            if j == 0:
                 score[j][0] += ((gt_i[mask] != 0) * (pred_i[mask] != 0)).sum()
                 score[j][1] += (gt_i[mask] != 0).sum()
                 score[j][2] += (pred_i[mask] != 0).sum()
diff --git a/projects/mmdet3d_plugin/datasets/nuscenes_occupancy_dataset.py b/projects/mmdet3d_plugin/datasets/nuscenes_occupancy_dataset.py
index ab976f4..9328afe 100644
--- a/projects/mmdet3d_plugin/datasets/nuscenes_occupancy_dataset.py
+++ b/projects/mmdet3d_plugin/datasets/nuscenes_occupancy_dataset.py
@@ -12,7 +12,6 @@ from nuscenes.eval.common.utils import quaternion_yaw, Quaternion
 from projects.mmdet3d_plugin.models.utils.visual import save_tensor
 from mmcv.parallel import DataContainer as DC
 import random
-import pdb, os
 
 
 @DATASETS.register_module()
@@ -70,8 +69,8 @@ class CustomNuScenesOccDataset(NuScenesDataset):
         # standard protocal modified from SECOND.Pytorch
         input_dict = dict(
             occ_path=info['occ_path'],
-            occ_size = np.array(self.occ_size),
-            pc_range = np.array(self.pc_range)
+            occ_size=np.array(self.occ_size),
+            pc_range=np.array(self.pc_range)
         )
 
         if self.modality['use_camera']:
@@ -217,7 +216,7 @@ class CustomNuScenesOccDataset(NuScenesDataset):
 
         else:
             results = np.stack(results, axis=0).mean(0)
-            results_dict={'Acc':results[0],
+            results_dict = {'Acc':results[0],
                           'Comp':results[1],
                           'CD':results[2],
                           'Prec':results[3],
diff --git a/projects/mmdet3d_plugin/datasets/pipelines/transform_3d.py b/projects/mmdet3d_plugin/datasets/pipelines/transform_3d.py
index b6ff218..783b0e3 100644
--- a/projects/mmdet3d_plugin/datasets/pipelines/transform_3d.py
+++ b/projects/mmdet3d_plugin/datasets/pipelines/transform_3d.py
@@ -344,7 +344,7 @@ class Augmentation(object):
         self.data_config = data_config
 
 
-    def get_rot(self,h):
+    def get_rot(self, h):
         return np.array([
             [np.cos(h), np.sin(h)],
             [-np.sin(h), np.cos(h)],
@@ -382,11 +382,11 @@ class Augmentation(object):
         
         return img
 
-    def sample_augmentation(self, H , W, flip=None, scale=None, decay_aug=False):
+    def sample_augmentation(self, H, W, flip=None, scale=None, decay_aug=False):
         fH, fW = self.data_config['input_size']
         
         if self.is_train and (not decay_aug):
-            resize = float(fW)/float(W)
+            resize = float(fW) / float(W)
             resize += np.random.uniform(*self.data_config['resize'])
             resize_dims = (int(W * resize), int(H * resize))
             newW, newH = resize_dims
@@ -398,7 +398,7 @@ class Augmentation(object):
             rotate = np.random.uniform(*self.data_config['rot'])
         
         else:
-            resize = float(fW)/float(W)
+            resize = float(fW) / float(W)
             resize += self.data_config.get('resize_test', 0.0)
             if scale is not None:
                 resize = scale
@@ -432,7 +432,7 @@ class Augmentation(object):
             resize, resize_dims, crop, flip, rotate = img_augs
             img, post_rot2, post_tran2 = \
                 self.img_transform(img, post_rot, post_tran, resize=resize, 
-                    resize_dims=resize_dims, crop=crop,flip=flip, rotate=rotate)
+                    resize_dims=resize_dims, crop=crop, flip=flip, rotate=rotate)
     
             # for convenience, make augmentation matrices 3x3
             post_rot = np.eye(4)
diff --git a/projects/mmdet3d_plugin/models/backbones/efficientnet.py b/projects/mmdet3d_plugin/models/backbones/efficientnet.py
index 5b0b47a..fb69b36 100644
--- a/projects/mmdet3d_plugin/models/backbones/efficientnet.py
+++ b/projects/mmdet3d_plugin/models/backbones/efficientnet.py
@@ -12,6 +12,7 @@ from mmcv.runner import BaseModule, Sequential
 from mmdet3d.models.builder import BACKBONES
 from mmdet.models.utils import SELayer, make_divisible
 
+
 class EdgeResidual(BaseModule):
     """Edge Residual Block.
     Args:
@@ -109,6 +110,7 @@ class EdgeResidual(BaseModule):
 
         return out
 
+
 class InvertedResidual(BaseModule):
     """Inverted Residual Block.
     Args:
@@ -228,9 +230,11 @@ class InvertedResidual(BaseModule):
 
         return out
 
+
 def model_scaling(layer_setting, arch_setting):
     """Scaling operation to the layer's parameters according to the
-    arch_setting."""
+    arch_setting.
+    """
     # scale width
     new_layer_setting = copy.deepcopy(layer_setting)
     for layer_cfg in new_layer_setting:
diff --git a/projects/mmdet3d_plugin/models/utils/position_embedding.py b/projects/mmdet3d_plugin/models/utils/position_embedding.py
index dccb4f2..dbb20b4 100644
--- a/projects/mmdet3d_plugin/models/utils/position_embedding.py
+++ b/projects/mmdet3d_plugin/models/utils/position_embedding.py
@@ -36,6 +36,7 @@ class RelPositionEmbedding(nn.Module):
 from mmcv.cnn.bricks.transformer import POSITIONAL_ENCODING
 from mmcv.runner import BaseModule
 
+
 @POSITIONAL_ENCODING.register_module()
 class LearnedPositionalEncoding3D(BaseModule):
     """Position embedding with learnable embedding weights.
diff --git a/projects/mmdet3d_plugin/surroundocc/apis/mmdet_train.py b/projects/mmdet3d_plugin/surroundocc/apis/mmdet_train.py
index 108d6e4..c124f8e 100644
--- a/projects/mmdet3d_plugin/surroundocc/apis/mmdet_train.py
+++ b/projects/mmdet3d_plugin/surroundocc/apis/mmdet_train.py
@@ -1,3 +1,17 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 # ---------------------------------------------
 # Copyright (c) OpenMMLab. All rights reserved.
 # ---------------------------------------------
@@ -9,7 +23,7 @@ import warnings
 import numpy as np
 import torch
 import torch.distributed as dist
-from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
+from mmcv.device.npu import NPUDataParallel, NPUDistributedDataParallel
 from mmcv.runner import (HOOKS, DistSamplerSeedHook, EpochBasedRunner,
                          Fp16OptimizerHook, OptimizerHook, build_optimizer,
                          build_runner, get_dist_info)
@@ -73,23 +87,23 @@ def custom_train_detector(model,
         # Sets the `find_unused_parameters` parameter in
         # torch.nn.parallel.DistributedDataParallel
         print(torch.cuda.current_device(), cfg.gpu_ids)
-        model = MMDistributedDataParallel(
+        model = NPUDistributedDataParallel(
             model.cuda(),
             device_ids=[torch.cuda.current_device()],
             #device_ids=[4,5,7],
             broadcast_buffers=False,
             find_unused_parameters=find_unused_parameters)
         if eval_model is not None:
-            eval_model = MMDistributedDataParallel(
+            eval_model = NPUDistributedDataParallel(
                 eval_model.cuda(),
                 device_ids=[torch.cuda.current_device()],
                 broadcast_buffers=False,
                 find_unused_parameters=find_unused_parameters)
     else:
-        model = MMDataParallel(
+        model = NPUDataParallel(
             model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)
         if eval_model is not None:
-            eval_model = MMDataParallel(
+            eval_model = NPUDataParallel(
                 eval_model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)
 
 
diff --git a/projects/mmdet3d_plugin/surroundocc/apis/test.py b/projects/mmdet3d_plugin/surroundocc/apis/test.py
index 551bb6a..61e5338 100644
--- a/projects/mmdet3d_plugin/surroundocc/apis/test.py
+++ b/projects/mmdet3d_plugin/surroundocc/apis/test.py
@@ -22,7 +22,7 @@ import mmcv
 import numpy as np
 import pycocotools.mask as mask_util
 #import open3d as o3d
-import pdb
+
 
 def custom_encode_mask_results(mask_results):
     """Encode bitmap mask to RLE code. Semantic Masks only
@@ -44,6 +44,7 @@ def custom_encode_mask_results(mask_results):
                         dtype='uint8'))[0])  # encoded with RLE
     return [encoded_mask_results]
 
+
 def custom_multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False, is_vis=False):
     """Test model with multiple gpus.
     This method tests model with multiple gpus and collects the results
diff --git a/projects/mmdet3d_plugin/surroundocc/dense_heads/occ_head.py b/projects/mmdet3d_plugin/surroundocc/dense_heads/occ_head.py
index a69fb4a..46293b8 100644
--- a/projects/mmdet3d_plugin/surroundocc/dense_heads/occ_head.py
+++ b/projects/mmdet3d_plugin/surroundocc/dense_heads/occ_head.py
@@ -22,9 +22,10 @@ from mmcv.cnn.utils.weight_init import constant_init
 import os
 from torch.autograd import Variable
 try:
-    from itertools import  ifilterfalse
+    from itertools import ifilterfalse
 except ImportError: # py3k
-    from itertools import  filterfalse as ifilterfalse
+    from itertools import filterfalse as ifilterfalse
+
 
 @HEADS.register_module()
 class OccHead(nn.Module): 
@@ -100,9 +101,9 @@ class OccHead(nn.Module):
         out_channels = self.conv_output
         in_channels = self.conv_input
 
-        norm_cfg=dict(type='GN', num_groups=16, requires_grad=True)
-        upsample_cfg=dict(type='deconv3d', bias=False)
-        conv_cfg=dict(type='Conv3d', bias=False)
+        norm_cfg = dict(type='GN', num_groups=16, requires_grad=True)
+        upsample_cfg = dict(type='deconv3d', bias=False)
+        conv_cfg = dict(type='Conv3d', bias=False)
 
         for i, out_channel in enumerate(out_channels):
             stride = upsample_strides[i]
@@ -159,8 +160,8 @@ class OccHead(nn.Module):
 
 
         self.transfer_conv = nn.ModuleList()
-        norm_cfg=dict(type='GN', num_groups=16, requires_grad=True)
-        conv_cfg=dict(type='Conv2d', bias=True)
+        norm_cfg = dict(type='GN', num_groups=16, requires_grad=True)
+        conv_cfg = dict(type='Conv2d', bias=True)
         for i in range(self.fpn_level):
             transfer_layer = build_conv_layer(
                     conv_cfg,
@@ -199,7 +200,7 @@ class OccHead(nn.Module):
             volume_z = self.volume_z[i]
 
             _, _, C, H, W = mlvl_feats[i].shape
-            view_features = self.transfer_conv[i](mlvl_feats[i].reshape(bs*num_cam, C, H, W)).reshape(bs, num_cam, -1, H, W)
+            view_features = self.transfer_conv[i](mlvl_feats[i].reshape(bs * num_cam, C, H, W)).reshape(bs, num_cam, -1, H, W)
 
             volume_embed_i = self.transformer[i](
                 [view_features],
@@ -268,8 +269,8 @@ class OccHead(nn.Module):
                 #gt = torch.mode(gt, dim=-1)[0].float()
                     
                 loss_occ_i = (F.binary_cross_entropy_with_logits(pred, gt) + geo_scal_loss(pred, gt.long(), semantic=False))
-                    
-                loss_occ_i =  loss_occ_i * ((0.5)**(len(preds_dicts['occ_preds']) - 1 -i)) #* focal_weight
+                #* focal_weight
+                loss_occ_i = loss_occ_i * ((0.5) ** (len(preds_dicts['occ_preds']) - 1 - i))
 
                 loss_dict['loss_occ_{}'.format(i)] = loss_occ_i
     
@@ -291,7 +292,7 @@ class OccHead(nn.Module):
 
                 loss_occ_i = (criterion(pred, gt.long()) + sem_scal_loss(pred, gt.long()) + geo_scal_loss(pred, gt.long()))
 
-                loss_occ_i = loss_occ_i * ((0.5)**(len(preds_dicts['occ_preds']) - 1 -i))
+                loss_occ_i = loss_occ_i * ((0.5) ** (len(preds_dicts['occ_preds']) - 1 - i))
 
                 loss_dict['loss_occ_{}'.format(i)] = loss_occ_i
 
diff --git a/projects/mmdet3d_plugin/surroundocc/detectors/surroundocc.py b/projects/mmdet3d_plugin/surroundocc/detectors/surroundocc.py
index 7145d1e..aba0994 100644
--- a/projects/mmdet3d_plugin/surroundocc/detectors/surroundocc.py
+++ b/projects/mmdet3d_plugin/surroundocc/detectors/surroundocc.py
@@ -21,7 +21,6 @@ from projects.mmdet3d_plugin.datasets.evaluation_metrics import evaluation_recon
 from sklearn.metrics import confusion_matrix as CM
 import time, yaml, os
 import torch.nn as nn
-import pdb
 
 
 @DETECTORS.register_module()
@@ -92,7 +91,7 @@ class SurroundOcc(MVXTwoStageDetector):
         for img_feat in img_feats:
             BN, C, H, W = img_feat.size()
             if len_queue is not None:
-                img_feats_reshaped.append(img_feat.view(int(B/len_queue), len_queue, int(BN / B), C, H, W))
+                img_feats_reshaped.append(img_feat.view(int(B / len_queue), len_queue, int(BN / B), C, H, W))
             else:
                 img_feats_reshaped.append(img_feat.view(B, int(BN / B), C, H, W))
         return img_feats_reshaped
@@ -232,9 +231,9 @@ class SurroundOcc(MVXTwoStageDetector):
             vv = torch.stack([X, Y, Z], dim=-1).to(voxel.device)
         
             vertices = vv[voxel[i] > 0.5]
-            vertices[:, 0] = (vertices[:, 0] + 0.5) * (img_metas[i]['pc_range'][3] - img_metas[i]['pc_range'][0]) /  img_metas[i]['occ_size'][0]  + img_metas[i]['pc_range'][0]
-            vertices[:, 1] = (vertices[:, 1] + 0.5) * (img_metas[i]['pc_range'][4] - img_metas[i]['pc_range'][1]) /  img_metas[i]['occ_size'][1]  + img_metas[i]['pc_range'][1]
-            vertices[:, 2] = (vertices[:, 2] + 0.5) * (img_metas[i]['pc_range'][5] - img_metas[i]['pc_range'][2]) /  img_metas[i]['occ_size'][2]  + img_metas[i]['pc_range'][2]
+            vertices[:, 0] = (vertices[:, 0] + 0.5) * (img_metas[i]['pc_range'][3] - img_metas[i]['pc_range'][0]) / img_metas[i]['occ_size'][0] + img_metas[i]['pc_range'][0]
+            vertices[:, 1] = (vertices[:, 1] + 0.5) * (img_metas[i]['pc_range'][4] - img_metas[i]['pc_range'][1]) / img_metas[i]['occ_size'][1] + img_metas[i]['pc_range'][1]
+            vertices[:, 2] = (vertices[:, 2] + 0.5) * (img_metas[i]['pc_range'][5] - img_metas[i]['pc_range'][2]) / img_metas[i]['occ_size'][2] + img_metas[i]['pc_range'][2]
             
             vertices = vertices.cpu().numpy()
     
@@ -252,8 +251,6 @@ class SurroundOcc(MVXTwoStageDetector):
 
             o3d.io.write_point_cloud(os.path.join(save_dir, 'pred.ply'), pcd)
             np.save(os.path.join(save_dir, 'pred.npy'), vertices)
-            for cam_id, cam_path in enumerate(img_metas[i]['filename']):
-                os.system('cp {} {}/{}.jpg'.format(cam_path, save_dir, cam_id))
 
 
     
diff --git a/projects/mmdet3d_plugin/surroundocc/loss/loss_utils.py b/projects/mmdet3d_plugin/surroundocc/loss/loss_utils.py
index dd74441..843a4c0 100644
--- a/projects/mmdet3d_plugin/surroundocc/loss/loss_utils.py
+++ b/projects/mmdet3d_plugin/surroundocc/loss/loss_utils.py
@@ -1,7 +1,21 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import pdb
+
 
 def multiscale_supervision(gt_occ, ratio, gt_shape):
     '''
@@ -11,10 +25,11 @@ def multiscale_supervision(gt_occ, ratio, gt_shape):
     gt = torch.zeros([gt_shape[0], gt_shape[2], gt_shape[3], gt_shape[4]]).to(gt_occ.device).type(torch.float) 
     for i in range(gt.shape[0]):
         coords = gt_occ[i][:, :3].type(torch.long) // ratio
-        gt[i, coords[:, 0], coords[:, 1], coords[:, 2]] =  gt_occ[i][:, 3]
+        gt[i, coords[:, 0], coords[:, 1], coords[:, 2]] = gt_occ[i][:, 3]
     
     return gt
 
+
 def geo_scal_loss(pred, ssc_target, semantic=True):
 
     # Get softmax probabilities
@@ -30,9 +45,9 @@ def geo_scal_loss(pred, ssc_target, semantic=True):
     # Remove unknown voxels
     mask = ssc_target != 255
     nonempty_target = ssc_target != 0
-    nonempty_target = nonempty_target[mask].float()
-    nonempty_probs = nonempty_probs[mask]
-    empty_probs = empty_probs[mask]
+    nonempty_target = torch.where(mask, nonempty_target, 0).float()
+    nonempty_probs = torch.where(mask, nonempty_probs, 0)
+    empty_probs = torch.where(mask, empty_probs, 0)
 
     intersection = (nonempty_target * nonempty_probs).sum()
     precision = intersection / nonempty_probs.sum()
@@ -59,13 +74,14 @@ def sem_scal_loss(pred, ssc_target):
 
         # Remove unknown voxels
         target_ori = ssc_target
-        p = p[mask]
-        target = ssc_target[mask]
 
+        p = torch.where(mask, p, 0)
+        target = torch.where(mask, ssc_target, i + 1)
         completion_target = torch.ones_like(target)
-        completion_target[target != i] = 0
-        completion_target_ori = torch.ones_like(target_ori).float()
-        completion_target_ori[target_ori != i] = 0
+        completion_target *= ~(target != i)
+        completion_target_ori = torch.ones_like(target_ori.to(torch.float))
+        completion_target_ori *= ~(target_ori != i)
+
         if torch.sum(completion_target) > 0:
             count += 1.0
             nominator = torch.sum(p * completion_target)
diff --git a/projects/mmdet3d_plugin/surroundocc/modules/encoder.py b/projects/mmdet3d_plugin/surroundocc/modules/encoder.py
index 24d8e53..ccdc80a 100644
--- a/projects/mmdet3d_plugin/surroundocc/modules/encoder.py
+++ b/projects/mmdet3d_plugin/surroundocc/modules/encoder.py
@@ -1,3 +1,16 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 
 # ---------------------------------------------
 # Copyright (c) OpenMMLab. All rights reserved.
@@ -21,7 +34,6 @@ import cv2 as cv
 import mmcv
 from mmcv.utils import TORCH_VERSION, digit_version
 from mmcv.utils import ext_loader
-import pdb
 import torch.nn.functional as F
 import torch.nn as nn
 from mmcv.cnn import build_conv_layer, build_norm_layer, build_upsample_layer
@@ -73,7 +85,6 @@ class OccEncoder(TransformerLayerSequence):
         ref_3d = ref_3d[None, None].repeat(bs, 1, 1, 1)
         return ref_3d
 
-
     # This function must use fp32!!!
     @force_fp32(apply_to=('reference_points', 'img_metas'))
     def point_sampling(self, reference_points, pc_range,  img_metas):
@@ -105,8 +116,9 @@ class OccEncoder(TransformerLayerSequence):
         lidar2img = lidar2img.view(
             1, B, num_cam, 1, 4, 4).repeat(D, 1, 1, num_query, 1, 1)
 
-        reference_points_cam = torch.matmul(lidar2img.to(torch.float32),
-                                            reference_points.to(torch.float32)).squeeze(-1)
+        reference_points_cam = torch.mul(lidar2img.to(torch.float32),
+                                            reference_points.to(torch.float32).transpose(-1, -2)).sum(-1, keepdim=True).squeeze(-1)
+
         eps = 1e-5
 
         volume_mask = (reference_points_cam[..., 2:3] > eps)
@@ -164,7 +176,7 @@ class OccEncoder(TransformerLayerSequence):
         intermediate = []
 
         ref_3d = self.get_reference_points(
-                    volume_h, volume_w, volume_z, bs=volume_query.size(1),  device=volume_query.device, dtype=volume_query.dtype)
+                    volume_h, volume_w, volume_z, bs=volume_query.size(1), device=volume_query.device, dtype=volume_query.dtype)
 
         reference_points_cam, volume_mask = self.point_sampling(
             ref_3d, self.pc_range, kwargs['img_metas'])
@@ -245,8 +257,8 @@ class OccLayer(MyCustomBaseTransformerLayer):
         self.fp16_enabled = False
 
         self.deblock = nn.ModuleList()
-        conv_cfg=dict(type='Conv3d', bias=False)
-        norm_cfg=dict(type='GN', num_groups=16, requires_grad=True)
+        conv_cfg = dict(type='Conv3d', bias=False)
+        norm_cfg = dict(type='GN', num_groups=16, requires_grad=True)
         for i in range(conv_num):
             conv_layer = build_conv_layer(
                     conv_cfg,
@@ -338,7 +350,7 @@ class OccLayer(MyCustomBaseTransformerLayer):
                 query = query.reshape(bs, volume_z, volume_h, volume_w, -1).permute(0, 4, 3, 2, 1)
                 for i in range(len(self.deblock)):
                     query = self.deblock[i](query)
-                query = query.permute(0, 4, 3, 2, 1).reshape(bs, volume_z*volume_h*volume_w, -1)
+                query = query.permute(0, 4, 3, 2, 1).reshape(bs, volume_z * volume_h * volume_w, -1)
                 query = query + identity
     
             elif layer == 'norm':
diff --git a/projects/mmdet3d_plugin/surroundocc/modules/spatial_cross_attention.py b/projects/mmdet3d_plugin/surroundocc/modules/spatial_cross_attention.py
index 2ad500c..c402c11 100644
--- a/projects/mmdet3d_plugin/surroundocc/modules/spatial_cross_attention.py
+++ b/projects/mmdet3d_plugin/surroundocc/modules/spatial_cross_attention.py
@@ -1,3 +1,16 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 
 # ---------------------------------------------
 # Copyright (c) OpenMMLab. All rights reserved.
@@ -5,28 +18,31 @@
 #  Modified by Zhiqi Li
 # ---------------------------------------------
 
-from mmcv.ops.multi_scale_deform_attn import multi_scale_deformable_attn_pytorch
+import math
 import warnings
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from mmcv.cnn import xavier_init, constant_init
-from mmcv.cnn.bricks.registry import (ATTENTION,
-                                      TRANSFORMER_LAYER,
-                                      TRANSFORMER_LAYER_SEQUENCE)
+from mmcv.cnn import constant_init, xavier_init
+from mmcv.cnn.bricks.registry import ATTENTION, TRANSFORMER_LAYER, TRANSFORMER_LAYER_SEQUENCE
 from mmcv.cnn.bricks.transformer import build_attention
-import math
-from mmcv.runner import force_fp32, auto_fp16
-
+from mmcv.ops.multi_scale_deform_attn import multi_scale_deformable_attn_pytorch
+from mmcv.runner import auto_fp16, force_fp32
 from mmcv.runner.base_module import BaseModule, ModuleList, Sequential
-
 from mmcv.utils import ext_loader
-from .multi_scale_deformable_attn_function import MultiScaleDeformableAttnFunction_fp32, \
-    MultiScaleDeformableAttnFunction_fp16
 from projects.mmdet3d_plugin.models.utils.bricks import run_time
+
+from mx_driving.fused import multi_scale_deformable_attn
+
+from .multi_scale_deformable_attn_function import (
+    MultiScaleDeformableAttnFunction_fp16,
+    MultiScaleDeformableAttnFunction_fp32,
+)
+
+
 ext_module = ext_loader.load_ext(
     '_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward'])
-import pdb
 
 
 @ATTENTION.register_module()
@@ -383,17 +399,7 @@ class MSDeformableAttention3D(BaseModule):
         #  attention_weights.shape: bs, num_query, num_heads, num_levels, num_all_points
         #
 
-        if torch.cuda.is_available() and value.is_cuda:
-            if value.dtype == torch.float16:
-                MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
-            else:
-                MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
-            output = MultiScaleDeformableAttnFunction.apply(
-                value, spatial_shapes, level_start_index, sampling_locations,
-                attention_weights, self.im2col_step)
-        else:
-            output = multi_scale_deformable_attn_pytorch(
-                value, spatial_shapes, sampling_locations, attention_weights)
+        output = multi_scale_deformable_attn(value, spatial_shapes, level_start_index, sampling_locations, attention_weights)
         if not self.batch_first:
             output = output.permute(1, 0, 2)
 
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..13a4b63
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,11 @@
+torchaudio==0.11.0
+torchvision==0.12.0
+mmsegmentation==0.30.0
+timm==0.9.16
+open3d-python==0.3.0.0
+numba==0.58.1
+numpy==1.23.0
+ipython
+sympy
+psutil
+attrs
diff --git a/tools/data_converter/nuimage_converter.py b/tools/data_converter/nuimage_converter.py
deleted file mode 100644
index 92be1de..0000000
--- a/tools/data_converter/nuimage_converter.py
+++ /dev/null
@@ -1,225 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import argparse
-import base64
-import mmcv
-import numpy as np
-from nuimages import NuImages
-from nuimages.utils.utils import mask_decode, name_to_index_mapping
-from os import path as osp
-
-nus_categories = ('car', 'truck', 'trailer', 'bus', 'construction_vehicle',
-                  'bicycle', 'motorcycle', 'pedestrian', 'traffic_cone',
-                  'barrier')
-
-NAME_MAPPING = {
-    'movable_object.barrier': 'barrier',
-    'vehicle.bicycle': 'bicycle',
-    'vehicle.bus.bendy': 'bus',
-    'vehicle.bus.rigid': 'bus',
-    'vehicle.car': 'car',
-    'vehicle.construction': 'construction_vehicle',
-    'vehicle.motorcycle': 'motorcycle',
-    'human.pedestrian.adult': 'pedestrian',
-    'human.pedestrian.child': 'pedestrian',
-    'human.pedestrian.construction_worker': 'pedestrian',
-    'human.pedestrian.police_officer': 'pedestrian',
-    'movable_object.trafficcone': 'traffic_cone',
-    'vehicle.trailer': 'trailer',
-    'vehicle.truck': 'truck',
-}
-
-
-def parse_args():
-    parser = argparse.ArgumentParser(description='Data converter arg parser')
-    parser.add_argument(
-        '--data-root',
-        type=str,
-        default='./data/nuimages',
-        help='specify the root path of dataset')
-    parser.add_argument(
-        '--version',
-        type=str,
-        nargs='+',
-        default=['v1.0-mini'],
-        required=False,
-        help='specify the dataset version')
-    parser.add_argument(
-        '--out-dir',
-        type=str,
-        default='./data/nuimages/annotations/',
-        required=False,
-        help='path to save the exported json')
-    parser.add_argument(
-        '--nproc',
-        type=int,
-        default=4,
-        required=False,
-        help='workers to process semantic masks')
-    parser.add_argument('--extra-tag', type=str, default='nuimages')
-    args = parser.parse_args()
-    return args
-
-
-def get_img_annos(nuim, img_info, cat2id, out_dir, data_root, seg_root):
-    """Get semantic segmentation map for an image.
-
-    Args:
-        nuim (obj:`NuImages`): NuImages dataset object
-        img_info (dict): Meta information of img
-
-    Returns:
-        np.ndarray: Semantic segmentation map of the image
-    """
-    sd_token = img_info['token']
-    image_id = img_info['id']
-    name_to_index = name_to_index_mapping(nuim.category)
-
-    # Get image data.
-    width, height = img_info['width'], img_info['height']
-    semseg_mask = np.zeros((height, width)).astype('uint8')
-
-    # Load stuff / surface regions.
-    surface_anns = [
-        o for o in nuim.surface_ann if o['sample_data_token'] == sd_token
-    ]
-
-    # Draw stuff / surface regions.
-    for ann in surface_anns:
-        # Get color and mask.
-        category_token = ann['category_token']
-        category_name = nuim.get('category', category_token)['name']
-        if ann['mask'] is None:
-            continue
-        mask = mask_decode(ann['mask'])
-
-        # Draw mask for semantic segmentation.
-        semseg_mask[mask == 1] = name_to_index[category_name]
-
-    # Load object instances.
-    object_anns = [
-        o for o in nuim.object_ann if o['sample_data_token'] == sd_token
-    ]
-
-    # Sort by token to ensure that objects always appear in the
-    # instance mask in the same order.
-    object_anns = sorted(object_anns, key=lambda k: k['token'])
-
-    # Draw object instances.
-    # The 0 index is reserved for background; thus, the instances
-    # should start from index 1.
-    annotations = []
-    for i, ann in enumerate(object_anns, start=1):
-        # Get color, box, mask and name.
-        category_token = ann['category_token']
-        category_name = nuim.get('category', category_token)['name']
-        if ann['mask'] is None:
-            continue
-        mask = mask_decode(ann['mask'])
-
-        # Draw masks for semantic segmentation and instance segmentation.
-        semseg_mask[mask == 1] = name_to_index[category_name]
-
-        if category_name in NAME_MAPPING:
-            cat_name = NAME_MAPPING[category_name]
-            cat_id = cat2id[cat_name]
-
-            x_min, y_min, x_max, y_max = ann['bbox']
-            # encode calibrated instance mask
-            mask_anno = dict()
-            mask_anno['counts'] = base64.b64decode(
-                ann['mask']['counts']).decode()
-            mask_anno['size'] = ann['mask']['size']
-
-            data_anno = dict(
-                image_id=image_id,
-                category_id=cat_id,
-                bbox=[x_min, y_min, x_max - x_min, y_max - y_min],
-                area=(x_max - x_min) * (y_max - y_min),
-                segmentation=mask_anno,
-                iscrowd=0)
-            annotations.append(data_anno)
-
-    # after process, save semantic masks
-    img_filename = img_info['file_name']
-    seg_filename = img_filename.replace('jpg', 'png')
-    seg_filename = osp.join(seg_root, seg_filename)
-    mmcv.imwrite(semseg_mask, seg_filename)
-    return annotations, np.max(semseg_mask)
-
-
-def export_nuim_to_coco(nuim, data_root, out_dir, extra_tag, version, nproc):
-    print('Process category information')
-    categories = []
-    categories = [
-        dict(id=nus_categories.index(cat_name), name=cat_name)
-        for cat_name in nus_categories
-    ]
-    cat2id = {k_v['name']: k_v['id'] for k_v in categories}
-
-    images = []
-    print('Process image meta information...')
-    for sample_info in mmcv.track_iter_progress(nuim.sample_data):
-        if sample_info['is_key_frame']:
-            img_idx = len(images)
-            images.append(
-                dict(
-                    id=img_idx,
-                    token=sample_info['token'],
-                    file_name=sample_info['filename'],
-                    width=sample_info['width'],
-                    height=sample_info['height']))
-
-    seg_root = f'{out_dir}semantic_masks'
-    mmcv.mkdir_or_exist(seg_root)
-    mmcv.mkdir_or_exist(osp.join(data_root, 'calibrated'))
-
-    global process_img_anno
-
-    def process_img_anno(img_info):
-        single_img_annos, max_cls_id = get_img_annos(nuim, img_info, cat2id,
-                                                     out_dir, data_root,
-                                                     seg_root)
-        return single_img_annos, max_cls_id
-
-    print('Process img annotations...')
-    if nproc > 1:
-        outputs = mmcv.track_parallel_progress(
-            process_img_anno, images, nproc=nproc)
-    else:
-        outputs = []
-        for img_info in mmcv.track_iter_progress(images):
-            outputs.append(process_img_anno(img_info))
-
-    # Determine the index of object annotation
-    print('Process annotation information...')
-    annotations = []
-    max_cls_ids = []
-    for single_img_annos, max_cls_id in outputs:
-        max_cls_ids.append(max_cls_id)
-        for img_anno in single_img_annos:
-            img_anno.update(id=len(annotations))
-            annotations.append(img_anno)
-
-    max_cls_id = max(max_cls_ids)
-    print(f'Max ID of class in the semantic map: {max_cls_id}')
-
-    coco_format_json = dict(
-        images=images, annotations=annotations, categories=categories)
-
-    mmcv.mkdir_or_exist(out_dir)
-    out_file = osp.join(out_dir, f'{extra_tag}_{version}.json')
-    print(f'Annotation dumped to {out_file}')
-    mmcv.dump(coco_format_json, out_file)
-
-
-def main():
-    args = parse_args()
-    for version in args.version:
-        nuim = NuImages(
-            dataroot=args.data_root, version=version, verbose=True, lazy=True)
-        export_nuim_to_coco(nuim, args.data_root, args.out_dir, args.extra_tag,
-                            version, args.nproc)
-
-
-if __name__ == '__main__':
-    main()
diff --git a/tools/generate_occupancy_nuscenes/generate_occupancy_nuscenes.py b/tools/generate_occupancy_nuscenes/generate_occupancy_nuscenes.py
index e04ba5f..9ace9ba 100644
--- a/tools/generate_occupancy_nuscenes/generate_occupancy_nuscenes.py
+++ b/tools/generate_occupancy_nuscenes/generate_occupancy_nuscenes.py
@@ -1,6 +1,5 @@
 import os
 import sys
-import pdb
 import time
 import yaml
 import torch
@@ -36,7 +35,8 @@ def run_poisson(pcd, depth, n_threads, min_density=None):
 
     return mesh, densities
 
-def create_mesh_from_map(buffer, depth, n_threads, min_density=None, point_cloud_original= None):
+
+def create_mesh_from_map(buffer, depth, n_threads, min_density=None, point_cloud_original=None):
 
     if point_cloud_original is None:
         pcd = buffer_to_pointcloud(buffer)
@@ -45,6 +45,7 @@ def create_mesh_from_map(buffer, depth, n_threads, min_density=None, point_cloud
 
     return run_poisson(pcd, depth, n_threads, min_density)
 
+
 def buffer_to_pointcloud(buffer, compute_normals=False):
     pcd = o3d.geometry.PointCloud()
     for cloud in buffer:
@@ -77,6 +78,7 @@ def preprocess(pcd, config):
         normals=True
     )
 
+
 def nn_correspondance(verts1, verts2):
     """ for each vertex in verts2 find the nearest vertex in verts1
 
@@ -105,8 +107,7 @@ def nn_correspondance(verts1, verts2):
 
 
 
-
-def lidar_to_world_to_lidar(pc,lidar_calibrated_sensor,lidar_ego_pose,
+def lidar_to_world_to_lidar(pc, lidar_calibrated_sensor, lidar_ego_pose,
     cam_calibrated_sensor,
     cam_ego_pose):
 
@@ -205,7 +206,7 @@ def main(nusc, val_list, indice, nuscenesyaml, args, config):
         object_points_list = []
         j = 0
         while j < points_in_boxes.shape[-1]:
-            object_points_mask = points_in_boxes[0][:,j].bool()
+            object_points_mask = points_in_boxes[0][:, j].bool()
             object_points = pc0[object_points_mask]
             object_points_list.append(object_points)
             j = j + 1
@@ -274,7 +275,7 @@ def main(nusc, val_list, indice, nuscenesyaml, args, config):
     object_token_zoo = []
     object_semantic = []
     for dict in dict_list:
-        for i,object_token in enumerate(dict['object_tokens']):
+        for i, object_token in enumerate(dict['object_tokens']):
             if object_token not in object_token_zoo:
                 if (dict['object_points_list'][i].shape[0] > 0):
                     object_token_zoo.append(object_token)
@@ -291,7 +292,7 @@ def main(nusc, val_list, indice, nuscenesyaml, args, config):
                 if query_object_token == object_token:
                     object_points = dict['object_points_list'][i]
                     if object_points.shape[0] > 0:
-                        object_points = object_points[:,:3] - dict['gt_bbox_3d'][i][:3]
+                        object_points = object_points[:, :3] - dict['gt_bbox_3d'][i][:3]
                         rots = dict['gt_bbox_3d'][i][6]
                         Rot = Rotation.from_euler('z', -rots, degrees=False)
                         rotated_object_points = Rot.apply(object_points)
@@ -305,7 +306,7 @@ def main(nusc, val_list, indice, nuscenesyaml, args, config):
     object_points_vertice = []
     for key in object_points_dict.keys():
         point_cloud = object_points_dict[key]
-        object_points_vertice.append(point_cloud[:,:3])
+        object_points_vertice.append(point_cloud[:, :3])
     # print('object finish')
 
 
@@ -333,7 +334,7 @@ def main(nusc, val_list, indice, nuscenesyaml, args, config):
                                                       lidar_ego_pose0.copy(),
                                                       lidar_calibrated_sensor,
                                                       lidar_ego_pose)
-        point_cloud = lidar_pc_i.points.T[:,:3]
+        point_cloud = lidar_pc_i.points.T[:, :3]
         point_cloud_with_semantic = lidar_pc_i_semantic.points.T
 
         ################## load bbox of target frame ##############
@@ -347,26 +348,26 @@ def main(nusc, val_list, indice, nuscenesyaml, args, config):
         gt_bbox_3d[:, 2] -= dims[:, 2] / 2.
         gt_bbox_3d[:, 2] = gt_bbox_3d[:, 2] - 0.1
         gt_bbox_3d[:, 3:6] = gt_bbox_3d[:, 3:6] * 1.1
-        rots = gt_bbox_3d[:,6:7]
-        locs = gt_bbox_3d[:,0:3]
+        rots = gt_bbox_3d[:, 6:7]
+        locs = gt_bbox_3d[:, 0:3]
 
         ################## bbox placement ##############
         object_points_list = []
         object_semantic_list = []
         for j, object_token in enumerate(dict['object_tokens']):
             for k, object_token_in_zoo in enumerate(object_token_zoo):
-                if object_token==object_token_in_zoo:
+                if object_token == object_token_in_zoo:
                     points = object_points_vertice[k]
                     Rot = Rotation.from_euler('z', rots[j], degrees=False)
                     rotated_object_points = Rot.apply(points)
                     points = rotated_object_points + locs[j]
                     if points.shape[0] >= 5:
                         points_in_boxes = points_in_boxes_cpu(torch.from_numpy(points[:, :3][np.newaxis, :, :]),
-                                                              torch.from_numpy(gt_bbox_3d[j:j+1][np.newaxis, :]))
-                        points = points[points_in_boxes[0,:,0].bool()]
+                                                              torch.from_numpy(gt_bbox_3d[j:j + 1][np.newaxis, :]))
+                        points = points[points_in_boxes[0, :, 0].bool()]
 
                     object_points_list.append(points)
-                    semantics = np.ones_like(points[:,0:1]) * object_semantic[k]
+                    semantics = np.ones_like(points[:, 0:1]) * object_semantic[k]
                     object_semantic_list.append(np.concatenate([points[:, :3], semantics], axis=1))
 
         try: # avoid concatenate an empty array
@@ -433,8 +434,8 @@ def main(nusc, val_list, indice, nuscenesyaml, args, config):
         sparse_voxels_semantic = scene_semantic_points
 
         x = torch.from_numpy(dense_voxels).cuda().unsqueeze(0).float()
-        y = torch.from_numpy(sparse_voxels_semantic[:,:3]).cuda().unsqueeze(0).float()
-        d1, d2, idx1, idx2 = chamfer.forward(x,y)
+        y = torch.from_numpy(sparse_voxels_semantic[:, :3]).cuda().unsqueeze(0).float()
+        d1, d2, idx1, idx2 = chamfer.forward(x, y)
         indices = idx1[0].cpu().numpy()
 
 
@@ -459,7 +460,7 @@ def main(nusc, val_list, indice, nuscenesyaml, args, config):
 
 def save_ply(points, name):
     point_cloud_original = o3d.geometry.PointCloud()
-    point_cloud_original.points = o3d.utility.Vector3dVector(points[:,:3])
+    point_cloud_original.points = o3d.utility.Vector3dVector(points[:, :3])
     o3d.io.write_point_cloud("{}.ply".format(name), point_cloud_original)
 
 
@@ -476,10 +477,10 @@ if __name__ == '__main__':
     parse.add_argument('--dataroot', type=str, default='./data/nuScenes/')
     parse.add_argument('--nusc_val_list', type=str, default='./nuscenes_val_list.txt')
     parse.add_argument('--label_mapping', type=str, default='nuscenes.yaml')
-    args=parse.parse_args()
+    args = parse.parse_args()
 
 
-    if args.dataset=='nuscenes':
+    if args.dataset == 'nuscenes':
         val_list = []
         with open(args.nusc_val_list, 'r') as file:
             for item in file:
@@ -504,7 +505,7 @@ if __name__ == '__main__':
         nuscenesyaml = yaml.safe_load(stream)
 
 
-    for i in range(args.start,args.end):
+    for i in range(args.start, args.end):
         print('processing sequecne:', i)
         main(nusc, val_list, indice=i,
              nuscenesyaml=nuscenesyaml, args=args, config=config)
diff --git a/tools/generate_occupancy_with_own_data/process_your_own_data.py b/tools/generate_occupancy_with_own_data/process_your_own_data.py
index 2395da0..bc2af65 100644
--- a/tools/generate_occupancy_with_own_data/process_your_own_data.py
+++ b/tools/generate_occupancy_with_own_data/process_your_own_data.py
@@ -26,7 +26,8 @@ def run_poisson(pcd, depth, n_threads, min_density=None):
 
     return mesh, densities
 
-def create_mesh_from_map(buffer, depth, n_threads, min_density=None, point_cloud_original= None):
+
+def create_mesh_from_map(buffer, depth, n_threads, min_density=None, point_cloud_original=None):
 
     if point_cloud_original is None:
         pcd = buffer_to_pointcloud(buffer)
@@ -35,6 +36,7 @@ def create_mesh_from_map(buffer, depth, n_threads, min_density=None, point_cloud
 
     return run_poisson(pcd, depth, n_threads, min_density)
 
+
 def buffer_to_pointcloud(buffer, compute_normals=False):
     pcd = o3d.geometry.PointCloud()
     for cloud in buffer:
@@ -67,6 +69,7 @@ def preprocess(pcd, config):
         normals=True
     )
 
+
 def nn_correspondance(verts1, verts2):
     """ for each vertex in verts2 find the nearest vertex in verts1
 
@@ -94,9 +97,7 @@ def nn_correspondance(verts1, verts2):
     return indices, distances
 
 
-
-
-def lidar_to_world_to_lidar(pc,lidar_calibrated_sensor,lidar_ego_pose,
+def lidar_to_world_to_lidar(pc, lidar_calibrated_sensor, lidar_ego_pose,
     cam_calibrated_sensor,
     cam_ego_pose):
 
@@ -128,7 +129,7 @@ if __name__ == '__main__':
     parse.add_argument('--whole_scene_to_mesh', action='store_true', default=False)
 
 
-    args=parse.parse_args()
+    args = parse.parse_args()
 
     # load config
     with open(args.config_path, 'r') as stream:
@@ -140,11 +141,11 @@ if __name__ == '__main__':
 
 
     path = args.data_path
-    pc_path = os.path.join(path,'pc/')
-    pc_seman_path = os.path.join(path,'pc_seman/')
-    bbox_path = os.path.join(path,'bbox/')
-    calib_path = os.path.join(path,'calib/')
-    pose_path = os.path.join(path,'pose/')
+    pc_path = os.path.join(path, 'pc/')
+    pc_seman_path = os.path.join(path, 'pc_seman/')
+    bbox_path = os.path.join(path, 'bbox/')
+    calib_path = os.path.join(path, 'calib/')
+    pose_path = os.path.join(path, 'pose/')
 
     lidar_ego_pose0 = np.load(os.path.join(pose_path, 'lidar_ego_pose0.npy'), allow_pickle=True).item()
     lidar_calibrated_sensor0 = np.load(os.path.join(calib_path, 'lidar_calibrated_sensor0.npy'), allow_pickle=True).item()
@@ -167,7 +168,7 @@ if __name__ == '__main__':
         object_points_list = []
         j = 0
         while j < points_in_boxes.shape[-1]:
-            object_points_mask = points_in_boxes[0][:,j].bool()
+            object_points_mask = points_in_boxes[0][:, j].bool()
             object_points = pc0[object_points_mask]
             object_points_list.append(object_points)
             j = j + 1
diff --git a/tools/model_converters/convert_votenet_checkpoints.py b/tools/model_converters/convert_votenet_checkpoints.py
deleted file mode 100644
index 33792b0..0000000
--- a/tools/model_converters/convert_votenet_checkpoints.py
+++ /dev/null
@@ -1,152 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import argparse
-import tempfile
-import torch
-from mmcv import Config
-from mmcv.runner import load_state_dict
-
-from mmdet3d.models import build_detector
-
-
-def parse_args():
-    parser = argparse.ArgumentParser(
-        description='MMDet3D upgrade model version(before v0.6.0) of VoteNet')
-    parser.add_argument('checkpoint', help='checkpoint file')
-    parser.add_argument('--out', help='path of the output checkpoint file')
-    args = parser.parse_args()
-    return args
-
-
-def parse_config(config_strings):
-    """Parse config from strings.
-
-    Args:
-        config_strings (string): strings of model config.
-
-    Returns:
-        Config: model config
-    """
-    temp_file = tempfile.NamedTemporaryFile()
-    config_path = f'{temp_file.name}.py'
-    with open(config_path, 'w') as f:
-        f.write(config_strings)
-
-    config = Config.fromfile(config_path)
-
-    # Update backbone config
-    if 'pool_mod' in config.model.backbone:
-        config.model.backbone.pop('pool_mod')
-
-    if 'sa_cfg' not in config.model.backbone:
-        config.model.backbone['sa_cfg'] = dict(
-            type='PointSAModule',
-            pool_mod='max',
-            use_xyz=True,
-            normalize_xyz=True)
-
-    if 'type' not in config.model.bbox_head.vote_aggregation_cfg:
-        config.model.bbox_head.vote_aggregation_cfg['type'] = 'PointSAModule'
-
-    # Update bbox_head config
-    if 'pred_layer_cfg' not in config.model.bbox_head:
-        config.model.bbox_head['pred_layer_cfg'] = dict(
-            in_channels=128, shared_conv_channels=(128, 128), bias=True)
-
-    if 'feat_channels' in config.model.bbox_head:
-        config.model.bbox_head.pop('feat_channels')
-
-    if 'vote_moudule_cfg' in config.model.bbox_head:
-        config.model.bbox_head['vote_module_cfg'] = config.model.bbox_head.pop(
-            'vote_moudule_cfg')
-
-    if config.model.bbox_head.vote_aggregation_cfg.use_xyz:
-        config.model.bbox_head.vote_aggregation_cfg.mlp_channels[0] -= 3
-
-    temp_file.close()
-
-    return config
-
-
-def main():
-    """Convert keys in checkpoints for VoteNet.
-
-    There can be some breaking changes during the development of mmdetection3d,
-    and this tool is used for upgrading checkpoints trained with old versions
-    (before v0.6.0) to the latest one.
-    """
-    args = parse_args()
-    checkpoint = torch.load(args.checkpoint)
-    cfg = parse_config(checkpoint['meta']['config'])
-    # Build the model and load checkpoint
-    model = build_detector(
-        cfg.model,
-        train_cfg=cfg.get('train_cfg'),
-        test_cfg=cfg.get('test_cfg'))
-    orig_ckpt = checkpoint['state_dict']
-    converted_ckpt = orig_ckpt.copy()
-
-    if cfg['dataset_type'] == 'ScanNetDataset':
-        NUM_CLASSES = 18
-    elif cfg['dataset_type'] == 'SUNRGBDDataset':
-        NUM_CLASSES = 10
-    else:
-        raise NotImplementedError
-
-    RENAME_PREFIX = {
-        'bbox_head.conv_pred.0': 'bbox_head.conv_pred.shared_convs.layer0',
-        'bbox_head.conv_pred.1': 'bbox_head.conv_pred.shared_convs.layer1'
-    }
-
-    DEL_KEYS = [
-        'bbox_head.conv_pred.0.bn.num_batches_tracked',
-        'bbox_head.conv_pred.1.bn.num_batches_tracked'
-    ]
-
-    EXTRACT_KEYS = {
-        'bbox_head.conv_pred.conv_cls.weight':
-        ('bbox_head.conv_pred.conv_out.weight', [(0, 2), (-NUM_CLASSES, -1)]),
-        'bbox_head.conv_pred.conv_cls.bias':
-        ('bbox_head.conv_pred.conv_out.bias', [(0, 2), (-NUM_CLASSES, -1)]),
-        'bbox_head.conv_pred.conv_reg.weight':
-        ('bbox_head.conv_pred.conv_out.weight', [(2, -NUM_CLASSES)]),
-        'bbox_head.conv_pred.conv_reg.bias':
-        ('bbox_head.conv_pred.conv_out.bias', [(2, -NUM_CLASSES)])
-    }
-
-    # Delete some useless keys
-    for key in DEL_KEYS:
-        converted_ckpt.pop(key)
-
-    # Rename keys with specific prefix
-    RENAME_KEYS = dict()
-    for old_key in converted_ckpt.keys():
-        for rename_prefix in RENAME_PREFIX.keys():
-            if rename_prefix in old_key:
-                new_key = old_key.replace(rename_prefix,
-                                          RENAME_PREFIX[rename_prefix])
-                RENAME_KEYS[new_key] = old_key
-    for new_key, old_key in RENAME_KEYS.items():
-        converted_ckpt[new_key] = converted_ckpt.pop(old_key)
-
-    # Extract weights and rename the keys
-    for new_key, (old_key, indices) in EXTRACT_KEYS.items():
-        cur_layers = orig_ckpt[old_key]
-        converted_layers = []
-        for (start, end) in indices:
-            if end != -1:
-                converted_layers.append(cur_layers[start:end])
-            else:
-                converted_layers.append(cur_layers[start:])
-        converted_layers = torch.cat(converted_layers, 0)
-        converted_ckpt[new_key] = converted_layers
-        if old_key in converted_ckpt.keys():
-            converted_ckpt.pop(old_key)
-
-    # Check the converted checkpoint by loading to the model
-    load_state_dict(model, converted_ckpt, strict=True)
-    checkpoint['state_dict'] = converted_ckpt
-    torch.save(checkpoint, args.out)
-
-
-if __name__ == '__main__':
-    main()
diff --git a/tools/model_converters/publish_model.py b/tools/model_converters/publish_model.py
deleted file mode 100644
index 318fd46..0000000
--- a/tools/model_converters/publish_model.py
+++ /dev/null
@@ -1,35 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import argparse
-import subprocess
-import torch
-
-
-def parse_args():
-    parser = argparse.ArgumentParser(
-        description='Process a checkpoint to be published')
-    parser.add_argument('in_file', help='input checkpoint filename')
-    parser.add_argument('out_file', help='output checkpoint filename')
-    args = parser.parse_args()
-    return args
-
-
-def process_checkpoint(in_file, out_file):
-    checkpoint = torch.load(in_file, map_location='cpu')
-    # remove optimizer for smaller file size
-    if 'optimizer' in checkpoint:
-        del checkpoint['optimizer']
-    # if it is necessary to remove some sensitive data in checkpoint['meta'],
-    # add the code here.
-    torch.save(checkpoint, out_file)
-    sha = subprocess.check_output(['sha256sum', out_file]).decode()
-    final_file = out_file.rstrip('.pth') + '-{}.pth'.format(sha[:8])
-    subprocess.Popen(['mv', out_file, final_file])
-
-
-def main():
-    args = parse_args()
-    process_checkpoint(args.in_file, args.out_file)
-
-
-if __name__ == '__main__':
-    main()
diff --git a/tools/model_converters/regnet2mmdet.py b/tools/model_converters/regnet2mmdet.py
deleted file mode 100644
index 9dee3c8..0000000
--- a/tools/model_converters/regnet2mmdet.py
+++ /dev/null
@@ -1,89 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import argparse
-import torch
-from collections import OrderedDict
-
-
-def convert_stem(model_key, model_weight, state_dict, converted_names):
-    new_key = model_key.replace('stem.conv', 'conv1')
-    new_key = new_key.replace('stem.bn', 'bn1')
-    state_dict[new_key] = model_weight
-    converted_names.add(model_key)
-    print(f'Convert {model_key} to {new_key}')
-
-
-def convert_head(model_key, model_weight, state_dict, converted_names):
-    new_key = model_key.replace('head.fc', 'fc')
-    state_dict[new_key] = model_weight
-    converted_names.add(model_key)
-    print(f'Convert {model_key} to {new_key}')
-
-
-def convert_reslayer(model_key, model_weight, state_dict, converted_names):
-    split_keys = model_key.split('.')
-    layer, block, module = split_keys[:3]
-    block_id = int(block[1:])
-    layer_name = f'layer{int(layer[1:])}'
-    block_name = f'{block_id - 1}'
-
-    if block_id == 1 and module == 'bn':
-        new_key = f'{layer_name}.{block_name}.downsample.1.{split_keys[-1]}'
-    elif block_id == 1 and module == 'proj':
-        new_key = f'{layer_name}.{block_name}.downsample.0.{split_keys[-1]}'
-    elif module == 'f':
-        if split_keys[3] == 'a_bn':
-            module_name = 'bn1'
-        elif split_keys[3] == 'b_bn':
-            module_name = 'bn2'
-        elif split_keys[3] == 'c_bn':
-            module_name = 'bn3'
-        elif split_keys[3] == 'a':
-            module_name = 'conv1'
-        elif split_keys[3] == 'b':
-            module_name = 'conv2'
-        elif split_keys[3] == 'c':
-            module_name = 'conv3'
-        new_key = f'{layer_name}.{block_name}.{module_name}.{split_keys[-1]}'
-    else:
-        raise ValueError(f'Unsupported conversion of key {model_key}')
-    print(f'Convert {model_key} to {new_key}')
-    state_dict[new_key] = model_weight
-    converted_names.add(model_key)
-
-
-def convert(src, dst):
-    """Convert keys in pycls pretrained RegNet models to mmdet style."""
-    # load caffe model
-    regnet_model = torch.load(src)
-    blobs = regnet_model['model_state']
-    # convert to pytorch style
-    state_dict = OrderedDict()
-    converted_names = set()
-    for key, weight in blobs.items():
-        if 'stem' in key:
-            convert_stem(key, weight, state_dict, converted_names)
-        elif 'head' in key:
-            convert_head(key, weight, state_dict, converted_names)
-        elif key.startswith('s'):
-            convert_reslayer(key, weight, state_dict, converted_names)
-
-    # check if all layers are converted
-    for key in blobs:
-        if key not in converted_names:
-            print(f'not converted: {key}')
-    # save checkpoint
-    checkpoint = dict()
-    checkpoint['state_dict'] = state_dict
-    torch.save(checkpoint, dst)
-
-
-def main():
-    parser = argparse.ArgumentParser(description='Convert model keys')
-    parser.add_argument('src', help='src detectron model path')
-    parser.add_argument('dst', help='save path')
-    args = parser.parse_args()
-    convert(args.src, args.dst)
-
-
-if __name__ == '__main__':
-    main()
diff --git a/tools/train.py b/tools/train.py
index c172243..b24347c 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -1,3 +1,17 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 # ---------------------------------------------
 # Copyright (c) OpenMMLab. All rights reserved.
 # ---------------------------------------------
@@ -28,8 +42,10 @@ from mmdet.apis import set_random_seed
 from mmseg import __version__ as mmseg_version
 
 from mmcv.utils import TORCH_VERSION, digit_version
+import torch_npu
+from torch_npu.contrib import transfer_to_npu
 
-
+torch.npu.config.allow_internal_format = False
 
 def parse_args():
     parser = argparse.ArgumentParser(description='Train a detector')
@@ -81,13 +97,14 @@ def parse_args():
         default='none',
         help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
+    parser.add_argument('--local-rank', type=int, default=0)
     parser.add_argument(
         '--autoscale-lr',
         action='store_true',
         help='automatically scale lr with the number of gpus')
     args = parser.parse_args()
     if 'LOCAL_RANK' not in os.environ:
-        os.environ['LOCAL_RANK'] = str(args.local_rank)
+        os.environ['LOCAL_RANK'] = str(args.local-rank)
 
     if args.options and args.cfg_options:
         raise ValueError(
diff --git a/tools/visual.py b/tools/visual.py
index 34b513a..f67f743 100644
--- a/tools/visual.py
+++ b/tools/visual.py
@@ -34,7 +34,7 @@ colors = np.array(
 #mlab.options.offscreen = True
 
 voxel_size = 0.5
-pc_range = [-50, -50,  -5, 50, 50, 3]
+pc_range = [-50, -50, -5, 50, 50, 3]
 
 visual_path = sys.argv[1]
 fov_voxels = np.load(visual_path)
@@ -48,14 +48,13 @@ fov_voxels[:, 2] += pc_range[2]
 
 #figure = mlab.figure(size=(600, 600), bgcolor=(1, 1, 1))
 figure = mlab.figure(size=(2560, 1440), bgcolor=(1, 1, 1))
-# pdb.set_trace()
 plt_plot_fov = mlab.points3d(
     fov_voxels[:, 0],
     fov_voxels[:, 1],
     fov_voxels[:, 2],
     fov_voxels[:, 3],
     colormap="viridis",
-    scale_factor=voxel_size - 0.05*voxel_size,
+    scale_factor=voxel_size - 0.05 * voxel_size,
     mode="cube",
     opacity=1.0,
     vmin=0,