diff --git a/nuscenes_need/mot.py b/nuscenes_need/mot.py
new file mode 100644
index 0000000..db4bd0b
--- /dev/null
+++ b/nuscenes_need/mot.py
@@ -0,0 +1,139 @@
+"""
+nuScenes dev-kit.
+Code written by Holger Caesar, Caglayan Dicle and Oscar Beijbom, 2019.
+
+This code is based on:
+
+py-motmetrics at:
+"""
+from collections import OrderedDict
+from itertools import count
+
+import motmetrics
+import numpy as np
+import pandas as pd
+
+
+class MOTAccumulatorCustom(motmetrics.mot.MOTAccumulator):
+    def __init__(self):
+        super().__init__()
+
+    @staticmethod
+    def new_event_dataframe_with_data(indices, events):
+        """
+        Create a new DataFrame filled with data.
+        This version overwrites the original in MOTAccumulator achieves about 2x speedups.
+
+        Params
+        ------
+        indices: list
+            list of tuples (frameid, eventid)
+        events: list
+            list of events where each event is a list containing
+            'Type', 'OId', HId', 'D'
+        """
+        idx = pd.MultiIndex.from_tuples(indices, names=['FrameId', 'Event'])
+        df = pd.DataFrame(events, index=idx, columns=['Type', 'OId', 'HId', 'D'])
+        return df
+
+    @staticmethod
+    def new_event_dataframe():
+        """ Create a new DataFrame for event tracking. """
+        idx = pd.MultiIndex(levels=[[], []], codes=[[], []], names=['FrameId', 'Event'])
+        cats = pd.Categorical([], categories=['RAW', 'FP', 'MISS', 'SWITCH', 'MATCH'])
+        df = pd.DataFrame(
+            OrderedDict([
+                ('Type', pd.Series(cats)),  # Type of event. One of FP (false positive), MISS, SWITCH, MATCH
+                ('OId', pd.Series(dtype=object)),
+                # Object ID or -1 if FP. Using float as missing values will be converted to NaN anyways.
+                ('HId', pd.Series(dtype=object)),
+                # Hypothesis ID or NaN if MISS. Using float as missing values will be converted to NaN anyways.
+                ('D', pd.Series(dtype=float)),  # Distance or NaN when FP or MISS
+            ]),
+            index=idx
+        )
+        return df
+
+    @property
+    def events(self):
+        if self.dirty_events:
+            self.cached_events_df = MOTAccumulatorCustom.new_event_dataframe_with_data(self._indices, self._events)
+            self.dirty_events = False
+        return self.cached_events_df
+
+    @staticmethod
+    def merge_event_dataframes(dfs, update_frame_indices=True, update_oids=True, update_hids=True,
+                               return_mappings=False):
+        """Merge dataframes.
+
+        Params
+        ------
+        dfs : list of pandas.DataFrame or MotAccumulator
+            A list of event containers to merge
+
+        Kwargs
+        ------
+        update_frame_indices : boolean, optional
+            Ensure that frame indices are unique in the merged container
+        update_oids : boolean, unique
+            Ensure that object ids are unique in the merged container
+        update_hids : boolean, unique
+            Ensure that hypothesis ids are unique in the merged container
+        return_mappings : boolean, unique
+            Whether or not to return mapping information
+
+        Returns
+        -------
+        df : pandas.DataFrame
+            Merged event data frame
+        """
+
+        mapping_infos = []
+        new_oid = count()
+        new_hid = count()
+
+        r = MOTAccumulatorCustom.new_event_dataframe()
+        for df in dfs:
+
+            if isinstance(df, MOTAccumulatorCustom):
+                df = df.events
+
+            copy = df.copy()
+            infos = {}
+
+            # Update index
+            if update_frame_indices:
+                if r.index.get_level_values(0).size > 0 and isinstance(r.index.get_level_values(0)[0], tuple):
+                    index_temp = []
+                    for item in r.index.get_level_values(0):
+                        index_temp += list(item)
+                    index_temp = np.array(index_temp)
+                    index_temp_unique = np.unique(index_temp)
+                    next_frame_id = max(np.max(index_temp)+1, index_temp_unique.shape[0])
+                else:
+                    next_frame_id = max(r.index.get_level_values(0).max() + 1,
+                                        r.index.get_level_values(0).unique().shape[0])
+
+                if np.isnan(next_frame_id):
+                    next_frame_id = 0
+                copy.index = copy.index.map(lambda x: (x[0] + next_frame_id, x[1]))
+                infos['frame_offset'] = next_frame_id
+
+            # Update object / hypothesis ids
+            if update_oids:
+                oid_map = dict([oid, str(next(new_oid))] for oid in copy['OId'].dropna().unique())
+                copy['OId'] = copy['OId'].map(lambda x: oid_map[x], na_action='ignore')
+                infos['oid_map'] = oid_map
+
+            if update_hids:
+                hid_map = dict([hid, str(next(new_hid))] for hid in copy['HId'].dropna().unique())
+                copy['HId'] = copy['HId'].map(lambda x: hid_map[x], na_action='ignore')
+                infos['hid_map'] = hid_map
+
+            r = pd.concat((r, copy))
+            mapping_infos.append(infos)
+
+        if return_mappings:
+            return r, mapping_infos
+        else:
+            return r
diff --git a/projects/configs/stage1_track_map/base_track_map.py b/projects/configs/stage1_track_map/base_track_map.py
index 0f056d4..9dbca22 100644
--- a/projects/configs/stage1_track_map/base_track_map.py
+++ b/projects/configs/stage1_track_map/base_track_map.py
@@ -555,7 +555,7 @@ data = dict(
     nonshuffler_sampler=dict(type="DistributedSampler"),
 )
 optimizer = dict(
-    type="AdamW",
+    type="NpuFusedAdamW",
     lr=2e-4,
     paramwise_cfg=dict(
         custom_keys={
@@ -564,7 +564,7 @@ optimizer = dict(
     ),
     weight_decay=0.01,
 )
-optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
+optimizer_config = dict(type='GradientCumulativeOptimizerHook', cumulative_iters=2, grad_clip=dict(max_norm=70, norm_type=2))
 # learning policy
 lr_config = dict(
     policy="CosineAnnealing",
@@ -573,15 +573,11 @@ lr_config = dict(
     warmup_ratio=1.0 / 3,
     min_lr_ratio=1e-3,
 )
-total_epochs = 6
-evaluation = dict(
-    interval=6,
-    pipeline=test_pipeline,
-    planning_evaluation_strategy=planning_evaluation_strategy,
-)
+total_epochs = 4
+evaluation = dict(interval=4, pipeline=test_pipeline)
 runner = dict(type="EpochBasedRunner", max_epochs=total_epochs)
 log_config = dict(
-    interval=10, hooks=[dict(type="TextLoggerHook"), dict(type="TensorboardLoggerHook")]
+    interval=1, hooks=[dict(type="TextLoggerHook"), dict(type="TensorboardLoggerHook")]
 )
 checkpoint_config = dict(interval=1)
 load_from = "ckpts/bevformer_r101_dcn_24ep.pth"
diff --git a/projects/configs/stage2_e2e/base_e2e.py b/projects/configs/stage2_e2e/base_e2e.py
index 9903440..6300f15 100644
--- a/projects/configs/stage2_e2e/base_e2e.py
+++ b/projects/configs/stage2_e2e/base_e2e.py
@@ -670,7 +670,7 @@ data = dict(
     nonshuffler_sampler=dict(type="DistributedSampler"),
 )
 optimizer = dict(
-    type="AdamW",
+    type="NpuFusedAdamW",
     lr=2e-4,
     paramwise_cfg=dict(
         custom_keys={
@@ -679,7 +679,7 @@ optimizer = dict(
     ),
     weight_decay=0.01,
 )
-optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
+optimizer_config = dict(type='GradientCumulativeOptimizerHook', cumulative_iters=2, grad_clip=dict(max_norm=70, norm_type=2))
 # learning policy
 lr_config = dict(
     policy="CosineAnnealing",
@@ -688,15 +688,11 @@ lr_config = dict(
     warmup_ratio=1.0 / 3,
     min_lr_ratio=1e-3,
 )
-total_epochs = 20
-evaluation = dict(
-    interval=4,
-    pipeline=test_pipeline,
-    planning_evaluation_strategy=planning_evaluation_strategy,
-)
+total_epochs = 4
+evaluation = dict(interval=4, pipeline=test_pipeline)
 runner = dict(type="EpochBasedRunner", max_epochs=total_epochs)
 log_config = dict(
-    interval=10, hooks=[dict(type="TextLoggerHook"), dict(type="TensorboardLoggerHook")]
+    interval=1, hooks=[dict(type="TextLoggerHook"), dict(type="TensorboardLoggerHook")]
 )
 checkpoint_config = dict(interval=1)
 load_from = "ckpts/uniad_base_track_map.pth"
diff --git a/projects/mmdet3d_plugin/core/bbox/match_costs/match_cost.py b/projects/mmdet3d_plugin/core/bbox/match_costs/match_cost.py
index d73b7e0..515ac1c 100755
--- a/projects/mmdet3d_plugin/core/bbox/match_costs/match_cost.py
+++ b/projects/mmdet3d_plugin/core/bbox/match_costs/match_cost.py
@@ -28,7 +28,7 @@ class BBox3DL1Cost(object):
         return bbox_cost * self.weight
 
 
-@MATCH_COST.register_module()
+@MATCH_COST.register_module(force=True)
 class DiceCost(object):
     """IoUCost.
 
diff --git a/projects/mmdet3d_plugin/core/bbox/util.py b/projects/mmdet3d_plugin/core/bbox/util.py
index c54bd75..16f4fd2 100755
--- a/projects/mmdet3d_plugin/core/bbox/util.py
+++ b/projects/mmdet3d_plugin/core/bbox/util.py
@@ -24,30 +24,15 @@ def normalize_bbox(bboxes, pc_range):
     return normalized_bboxes
 
 def denormalize_bbox(normalized_bboxes, pc_range):
-    # rotation 
-    rot_sine = normalized_bboxes[..., 6:7]
+    # rotation
+    cx, cy, w, l, cz, h, rot_sine, rot_cosine, vx, vy = normalized_bboxes.split(1, dim=-1)
 
-    rot_cosine = normalized_bboxes[..., 7:8]
     rot = torch.atan2(rot_sine, rot_cosine)
 
-    # center in the bev
-    cx = normalized_bboxes[..., 0:1]
-    cy = normalized_bboxes[..., 1:2]
-    cz = normalized_bboxes[..., 4:5]
-   
-    # size
-    w = normalized_bboxes[..., 2:3]
-    l = normalized_bboxes[..., 3:4]
-    h = normalized_bboxes[..., 5:6]
-
-    w = w.exp() 
-    l = l.exp() 
-    h = h.exp() 
-    if normalized_bboxes.size(-1) > 8:
-         # velocity 
-        vx = normalized_bboxes[:, 8:9]
-        vy = normalized_bboxes[:, 9:10]
-        denormalized_bboxes = torch.cat([cx, cy, cz, w, l, h, rot, vx, vy], dim=-1)
-    else:
-        denormalized_bboxes = torch.cat([cx, cy, cz, w, l, h, rot], dim=-1)
+    w = w.exp()
+    l = l.exp()
+    h = h.exp()
+
+    denormalized_bboxes = torch.cat([cx, cy, cz, w, l, h, rot, vx, vy], dim=-1)
+
     return denormalized_bboxes
\ No newline at end of file
diff --git a/projects/mmdet3d_plugin/datasets/nuscenes_e2e_dataset.py b/projects/mmdet3d_plugin/datasets/nuscenes_e2e_dataset.py
index c1b9392..88de6c3 100644
--- a/projects/mmdet3d_plugin/datasets/nuscenes_e2e_dataset.py
+++ b/projects/mmdet3d_plugin/datasets/nuscenes_e2e_dataset.py
@@ -1025,7 +1025,6 @@ class NuScenesE2EDataset(NuScenesDataset):
             if 'planning_results_computed' in results.keys():
                 planning_results_computed = results['planning_results_computed']
                 planning_tab = PrettyTable()
-                planning_tab.title = f"{planning_evaluation_strategy}'s definition planning metrics"
                 planning_tab.field_names = [
                     "metrics", "0.5s", "1.0s", "1.5s", "2.0s", "2.5s", "3.0s"]
                 for key in planning_results_computed.keys():
@@ -1033,14 +1032,7 @@ class NuScenesE2EDataset(NuScenesDataset):
                     row_value = []
                     row_value.append(key)
                     for i in range(len(value)):
-                        if planning_evaluation_strategy == "stp3":
-                            row_value.append("%.4f" % float(value[: i + 1].mean()))
-                        elif planning_evaluation_strategy == "uniad":
                             row_value.append("%.4f" % float(value[i]))
-                        else:
-                            raise ValueError(
-                                "planning_evaluation_strategy should be uniad or spt3"
-                            )
                     planning_tab.add_row(row_value)
                 print(planning_tab)
             results = results['bbox_results']  # get bbox_results
diff --git a/projects/mmdet3d_plugin/losses/dice_loss.py b/projects/mmdet3d_plugin/losses/dice_loss.py
index 3cb635f..50f6a13 100644
--- a/projects/mmdet3d_plugin/losses/dice_loss.py
+++ b/projects/mmdet3d_plugin/losses/dice_loss.py
@@ -1,5 +1,4 @@
 import torch
-import torch
 import torch.nn as nn
 
 from mmdet.models.losses.utils import weighted_loss
@@ -21,7 +20,7 @@ def dice_loss(input, target,mask=None,eps=0.001):
     d = (2 * a) / (b + c)
     return 1 - d
 
-@LOSSES.register_module()
+@LOSSES.register_module(force=True)
 class DiceLoss(nn.Module):
 
     def __init__(self, eps=1e-6, reduction='mean', loss_weight=1.0):
diff --git a/projects/mmdet3d_plugin/losses/planning_loss.py b/projects/mmdet3d_plugin/losses/planning_loss.py
index 6d47070..bf70b26 100644
--- a/projects/mmdet3d_plugin/losses/planning_loss.py
+++ b/projects/mmdet3d_plugin/losses/planning_loss.py
@@ -40,20 +40,41 @@ class CollisionLoss(nn.Module):
         # sdc_planning_gt_mask (1, 6)
         # future_gt_bbox 6x[lidarboxinstance]
         n_futures = len(future_gt_bbox)
-        inter_sum = sdc_traj_all.new_zeros(1, )
-        dump_sdc = []
+        inter_sum = torch.tensor([0.], device=sdc_traj_all.device, dtype=sdc_traj_all.dtype)
         for i in range(n_futures):
             if len(future_gt_bbox[i].tensor) > 0:
-                future_gt_bbox_corners = future_gt_bbox[i].corners[:, [0,3,4,7], :2] # (N, 8, 3) -> (N, 4, 2) only bev 
+                future_gt_corners = future_gt_bbox[i].corners[:, [0,3,4,7], :2].npu() # (N, 8, 3) -> (N, 4, 2) only bev
                 # sdc_yaw = -sdc_planning_gt[0, i, 2].to(sdc_traj_all.dtype) - 1.5708
                 sdc_yaw = sdc_planning_gt[0, i, 2].to(sdc_traj_all.dtype)
                 sdc_bev_box = self.to_corners([sdc_traj_all[0, i, 0], sdc_traj_all[0, i, 1], self.w, self.h, sdc_yaw])
-                dump_sdc.append(sdc_bev_box.cpu().detach().numpy())
-                for j in range(future_gt_bbox_corners.shape[0]):
-                    inter_sum += self.inter_bbox(sdc_bev_box, future_gt_bbox_corners[j].to(sdc_traj_all.device))
+                sdc_min = sdc_bev_box.min(dim=0)[0] # (2,)
+                sdc_max = sdc_bev_box.max(dim=0)[0] # (2,)
+
+                # Compute min/max for all target boxes (N, 2)
+                target_min = future_gt_corners.min(dim=1)[0]  # (N, 2)
+                target_max = future_gt_corners.max(dim=1)[0]  # (N, 2)
+
+                # Compute intersection for all boxes (vectorized)
+                intersect_min = torch.maximum(sdc_min, target_min)  # (N, 2)
+                intersect_max = torch.minimum(sdc_max, target_max)  # (N, 2)
+                intersect_dims = intersect_max - intersect_min  # (N, 2)
+
+                # Clamp negative values to 0 and compute area
+                intersect_area = torch.prod(
+                    torch.clamp_min(intersect_dims, 0),
+                    dim=1
+                )  # (N,)
+
+                # Sum all intersections for this frame
+                inter_sum += intersect_area.sum()
+
         return inter_sum * self.weight
         
     def inter_bbox(self, corners_a, corners_b):
+        device_stack = corners_a.device
+        corners_a = corners_a.cpu()
+        corners_b = corners_b.cpu()
+
         xa1, ya1 = torch.max(corners_a[:, 0]), torch.max(corners_a[:, 1])
         xa2, ya2 = torch.min(corners_a[:, 0]), torch.min(corners_a[:, 1])
         xb1, yb1 = torch.max(corners_b[:, 0]), torch.max(corners_b[:, 1])
@@ -61,7 +82,9 @@ class CollisionLoss(nn.Module):
         
         xi1, yi1 = min(xa1, xb1), min(ya1, yb1)
         xi2, yi2 = max(xa2, xb2), max(ya2, yb2)
-        intersect = max((xi1 - xi2), xi1.new_zeros(1, ).to(xi1.device)) * max((yi1 - yi2), xi1.new_zeros(1,).to(xi1.device))
+
+        intersect = max((xi1 - xi2), xi1.new_zeros(1, )) * max((yi1 - yi2), xi1.new_zeros(1,))
+        intersect = intersect.to(device_stack)
         return intersect
 
     def to_corners(self, bbox):
diff --git a/projects/mmdet3d_plugin/losses/track_loss.py b/projects/mmdet3d_plugin/losses/track_loss.py
index 549c5ae..16e80cf 100644
--- a/projects/mmdet3d_plugin/losses/track_loss.py
+++ b/projects/mmdet3d_plugin/losses/track_loss.py
@@ -244,33 +244,26 @@ class ClipMatcher(nn.Module):
         filtered_idx = []
         for src_per_img, tgt_per_img in indices:
             keep = tgt_per_img != -1
-            filtered_idx.append((src_per_img[keep], tgt_per_img[keep]))
-        indices = filtered_idx
+            if not keep.all().item():
+                filtered_idx.append((src_per_img[keep], tgt_per_img[keep]))
+                indices = filtered_idx
         idx = self._get_src_permutation_idx(indices)
         src_boxes = outputs["pred_boxes"][idx]
         sdc_boxes = outputs["pred_sdc_boxes"][0, -1:]
         target_sdc_boxes = gt_instances[0].sdc_boxes[:1]
-        target_boxes = torch.cat(
-            [
-                gt_per_img.boxes[i]
-                for gt_per_img, (_, i) in zip(gt_instances, indices)
-            ],
-            dim=0,
-        )
+
+        target_boxes_list = []
+        target_obj_ids_list = []
+        for gt_per_img, (_, i) in zip(gt_instances, indices):
+            target_boxes_list.append(gt_per_img.boxes[i])
+            target_obj_ids_list.append(gt_per_img.obj_ids[i])
+
+        target_boxes = torch.cat(target_boxes_list, dim=0)
+        target_obj_ids = torch.cat(target_obj_ids_list, dim=0)
         
         src_boxes = torch.cat([src_boxes, sdc_boxes], dim=0)
         target_boxes = torch.cat([target_boxes, target_sdc_boxes], dim=0)
 
-        # for pad target, don't calculate regression loss, judged by whether obj_id=-1
-        target_obj_ids = torch.cat(
-            [
-                gt_per_img.obj_ids[i]
-                for gt_per_img, (_, i) in zip(gt_instances, indices)
-            ],
-            dim=0,
-        )
-        # [num_matched]
-
         target_obj_ids = torch.cat([target_obj_ids, torch.zeros(1).to(target_obj_ids.device)], dim=0)
         mask = target_obj_ids != -1
         bbox_weights = torch.ones_like(target_boxes) * self.code_weights
@@ -372,11 +365,8 @@ class ClipMatcher(nn.Module):
         pred_past_trajs_i = track_instances.pred_past_trajs  # predicted past trajs of i-th image.
 
         obj_idxes = gt_instances_i.obj_ids
-        obj_idxes_list = obj_idxes.detach().cpu().numpy().tolist()
-        obj_idx_to_gt_idx = {
-            obj_idx: gt_idx
-            for gt_idx, obj_idx in enumerate(obj_idxes_list)
-        }
+        obj_idxes_npu = obj_idxes.clone()
+
         outputs_i = {
             "pred_logits": pred_logits_i.unsqueeze(0),
             "pred_sdc_logits": pred_sdc_logits_i,
@@ -386,19 +376,30 @@ class ClipMatcher(nn.Module):
         }
         # step1. inherit and update the previous tracks.
         num_disappear_track = 0
-        for j in range(len(track_instances)):
-            obj_id = track_instances.obj_idxes[j].item()
-            # set new target idx.
-            if obj_id >= 0:
-                if obj_id in obj_idx_to_gt_idx:
-                    track_instances.matched_gt_idxes[j] = obj_idx_to_gt_idx[
-                        obj_id]
-                else:
-                    num_disappear_track += 1
-                    track_instances.matched_gt_idxes[
-                        j] = -1  # track-disappear case.
-            else:
-                track_instances.matched_gt_idxes[j] = -1
+
+        obj_idxes = track_instances.obj_idxes
+        mask_valid = obj_idxes >= 0
+
+        current_max = obj_idxes.max().item() if mask_valid.any() else 0
+
+        lookup_size = max(current_max + 1, 1)
+        lookup_table = torch.full((lookup_size,), -1, dtype=torch.long, device=track_instances.obj_idxes.device)
+
+        if obj_idxes_npu.any():
+            valid_mask = obj_idxes_npu < lookup_size
+            valid_obj_idxes = obj_idxes_npu[valid_mask]
+            valid_gt_indices = torch.arange(len(obj_idxes_npu), device=obj_idxes_npu.device)[valid_mask]
+            lookup_table[valid_obj_idxes] = valid_gt_indices
+
+        matched_gt_idxes = torch.where(
+            mask_valid,
+            lookup_table[torch.clamp(obj_idxes, max=lookup_size-1)],
+            torch.tensor(-1, dtype=torch.long, device=track_instances.obj_idxes.device)
+        )
+
+        num_disappear_track = (mask_valid & (matched_gt_idxes == -1)).sum().item()
+        track_instances.matched_gt_idxes = matched_gt_idxes
+
 
         full_track_idxes = torch.arange(
             len(track_instances), dtype=torch.long).to(pred_logits_i.device)
@@ -438,15 +439,20 @@ class ClipMatcher(nn.Module):
             bs, num_querys = bbox_preds.shape[:2]
             # Also concat the target labels and boxes
             targets = [untracked_gt_instances]
+            gt_labels_list = []
+            gt_bboxes_list = []
             if isinstance(targets[0], Instances):
-                # [num_box], [num_box, 9] (un-normalized bboxes)
-                gt_labels = torch.cat(
-                    [gt_per_img.labels for gt_per_img in targets])
-                gt_bboxes = torch.cat(
-                    [gt_per_img.boxes for gt_per_img in targets])
+                for gt_per_img in targets:
+                    gt_labels_list.append(gt_per_img.labels)
+                    gt_bboxes_list.append(gt_per_img.boxes)
+                gt_labels = torch.cat(gt_labels_list)
+                gt_bboxes = torch.cat(gt_bboxes_list)
             else:
-                gt_labels = torch.cat([v["labels"] for v in targets])
-                gt_bboxes = torch.cat([v["boxes"] for v in targets])
+                for v in targets:
+                    gt_labels_list.append(v["labels"])
+                    gt_bboxes_list.append(v["boxes"])
+                gt_labels = torch.cat(gt_labels_list)
+                gt_bboxes = torch.cat(gt_bboxes_list)
 
             bbox_pred = bbox_preds[0]
             cls_pred = cls_preds[0]
diff --git a/projects/mmdet3d_plugin/losses/traj_loss.py b/projects/mmdet3d_plugin/losses/traj_loss.py
index 87b26ca..5650b49 100644
--- a/projects/mmdet3d_plugin/losses/traj_loss.py
+++ b/projects/mmdet3d_plugin/losses/traj_loss.py
@@ -116,6 +116,7 @@ def min_ade(traj: torch.Tensor, traj_gt: torch.Tensor,
     err = torch.pow(err, exponent=0.5)
     err = torch.sum(err * (1 - masks_rpt), dim=2) / \
         torch.clip(torch.sum((1 - masks_rpt), dim=2), min=1)
+    err = err.float()
     err, inds = torch.min(err, dim=1)
 
     return err, inds
@@ -195,6 +196,7 @@ def min_fde(traj: torch.Tensor, traj_gt: torch.Tensor,
     err = torch.pow(err, exponent=2)
     err = torch.sum(err, dim=2)
     err = torch.pow(err, exponent=0.5)
+    err = err.float()
     err, inds = torch.min(err, dim=1)
 
     return err, inds
@@ -226,6 +228,7 @@ def miss_rate(
     dist = torch.sum(dist, dim=3)
     dist = torch.pow(dist, exponent=0.5)
     dist[masks_rpt.bool()] = -math.inf
+    dist = dist.float()
     dist, _ = torch.max(dist, dim=2)
     dist, _ = torch.min(dist, dim=1)
     m_r = torch.sum(torch.as_tensor(dist > dist_thresh)) / len(dist)
diff --git a/projects/mmdet3d_plugin/models/utils/functional.py b/projects/mmdet3d_plugin/models/utils/functional.py
index b4ae933..fb557f9 100644
--- a/projects/mmdet3d_plugin/models/utils/functional.py
+++ b/projects/mmdet3d_plugin/models/utils/functional.py
@@ -100,6 +100,14 @@ def anchor_coordinate_transform(anchors, bbox_results, with_translation_transfor
             rot_yaw = rot_2d(angle) # num_agents, 2, 2
             rot_yaw = rot_yaw[:, None, None,:, :] # num_agents, 1, 1, 2, 2
             transformed_anchors = rearrange(transformed_anchors, 'b g m t c -> b g m c t')  # 1, num_groups, num_modes, 12, 2 -> 1, num_groups, num_modes, 2, 12
+
+            num_agents, _, _, _, _ = rot_yaw.shape
+            _, num_groups, num_modes, _, _ = transformed_anchors.shape
+            broadcast_shape1 = (num_agents, num_groups, num_modes, 2, 2)
+            broadcast_shape2 = (num_agents, num_groups, num_modes, 2, 12)
+            rot_yaw = rot_yaw.expand(broadcast_shape1)
+            transformed_anchors = transformed_anchors.expand(broadcast_shape2)
+
             transformed_anchors = torch.matmul(rot_yaw, transformed_anchors)# -> num_agents, num_groups, num_modes, 12, 2
             transformed_anchors = rearrange(transformed_anchors, 'b g m c t -> b g m t c')
         if with_translation_transform:
diff --git a/projects/mmdet3d_plugin/uniad/apis/mmdet_train.py b/projects/mmdet3d_plugin/uniad/apis/mmdet_train.py
index 3439cde..4beadc1 100644
--- a/projects/mmdet3d_plugin/uniad/apis/mmdet_train.py
+++ b/projects/mmdet3d_plugin/uniad/apis/mmdet_train.py
@@ -4,7 +4,7 @@ import os
 import numpy as np
 import torch
 import torch.distributed as dist
-from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
+from mmcv.device.npu import NPUDataParallel, NPUDistributedDataParallel
 from mmcv.runner import (HOOKS, DistSamplerSeedHook, EpochBasedRunner,
                          Fp16OptimizerHook, OptimizerHook, build_optimizer,
                          build_runner, get_dist_info)
@@ -67,22 +67,22 @@ def custom_train_detector(model,
         find_unused_parameters = cfg.get('find_unused_parameters', False)
         # Sets the `find_unused_parameters` parameter in
         # torch.nn.parallel.DistributedDataParallel
-        model = MMDistributedDataParallel(
+        model = NPUDistributedDataParallel(
             model.cuda(),
             device_ids=[torch.cuda.current_device()],
             broadcast_buffers=False,
             find_unused_parameters=find_unused_parameters)
         if eval_model is not None:
-            eval_model = MMDistributedDataParallel(
+            eval_model = NPUDistributedDataParallel(
                 eval_model.cuda(),
                 device_ids=[torch.cuda.current_device()],
                 broadcast_buffers=False,
                 find_unused_parameters=find_unused_parameters)
     else:
-        model = MMDataParallel(
+        model = NPUDataParallel(
             model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)
         if eval_model is not None:
-            eval_model = MMDataParallel(
+            eval_model = NPUDataParallel(
                 eval_model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)
 
 
diff --git a/projects/mmdet3d_plugin/uniad/dense_heads/motion_head.py b/projects/mmdet3d_plugin/uniad/dense_heads/motion_head.py
index dd3612a..7215ab8 100644
--- a/projects/mmdet3d_plugin/uniad/dense_heads/motion_head.py
+++ b/projects/mmdet3d_plugin/uniad/dense_heads/motion_head.py
@@ -467,6 +467,7 @@ class MotionHead(BaseMotionHead):
         gt_fut_traj_mask_all = []
         for i in range(num_imgs):
             matched_gt_idx = all_matched_idxes[i]
+            matched_gt_idx = matched_gt_idx.detach().cpu()
             valid_traj_masks = matched_gt_idx >= 0
             matched_gt_fut_traj = gt_fut_traj[i][matched_gt_idx][valid_traj_masks]
             matched_gt_fut_traj_mask = gt_fut_traj_mask[i][matched_gt_idx][valid_traj_masks]
diff --git a/projects/mmdet3d_plugin/uniad/dense_heads/motion_head_plugin/motion_deformable_attn.py b/projects/mmdet3d_plugin/uniad/dense_heads/motion_head_plugin/motion_deformable_attn.py
index 9aaee7e..45e8dae 100644
--- a/projects/mmdet3d_plugin/uniad/dense_heads/motion_head_plugin/motion_deformable_attn.py
+++ b/projects/mmdet3d_plugin/uniad/dense_heads/motion_head_plugin/motion_deformable_attn.py
@@ -19,6 +19,7 @@ from mmcv.cnn.bricks.drop import build_dropout
 from mmcv.runner.base_module import BaseModule, ModuleList, Sequential
 from mmcv.utils import ConfigDict, deprecated_api_warning
 from projects.mmdet3d_plugin.uniad.modules.multi_scale_deformable_attn_function import MultiScaleDeformableAttnFunction_fp32
+import mx_driving.fused
 
 
 @TRANSFORMER_LAYER.register_module()
@@ -453,14 +454,8 @@ class MotionDeformableAttention(BaseModule):
                 f' 2 or 4, but get {reference_trajs.shape[-1]} instead.')
         if torch.cuda.is_available() and value.is_cuda:
 
-            # using fp16 deformable attention is unstable because it performs many sum operations
-            if value.dtype == torch.float16:
-                MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
-            else:
-                MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
-            output = MultiScaleDeformableAttnFunction.apply(
-                value, spatial_shapes, level_start_index, sampling_locations,
-                attention_weights, self.im2col_step)
+            output = mx_driving.fused.multi_scale_deformable_attn(value, spatial_shapes, level_start_index,
+                                                                         sampling_locations, attention_weights)
         else:
             output = multi_scale_deformable_attn_pytorch(
                 value, spatial_shapes, sampling_locations, attention_weights)
diff --git a/projects/mmdet3d_plugin/uniad/dense_heads/occ_head_plugin/metrics.py b/projects/mmdet3d_plugin/uniad/dense_heads/occ_head_plugin/metrics.py
index aa61597..ae762f8 100644
--- a/projects/mmdet3d_plugin/uniad/dense_heads/occ_head_plugin/metrics.py
+++ b/projects/mmdet3d_plugin/uniad/dense_heads/occ_head_plugin/metrics.py
@@ -21,7 +21,7 @@ class IntersectionOverUnion(Metric):
         reduction: str = 'none',
         compute_on_step: bool = False,
     ):
-        super().__init__(compute_on_step=compute_on_step)
+        super().__init__()
 
         self.n_classes = n_classes
         self.ignore_index = ignore_index
diff --git a/projects/mmdet3d_plugin/uniad/dense_heads/planning_head_plugin/planning_metrics.py b/projects/mmdet3d_plugin/uniad/dense_heads/planning_head_plugin/planning_metrics.py
index fdcf079..d4dc144 100644
--- a/projects/mmdet3d_plugin/uniad/dense_heads/planning_head_plugin/planning_metrics.py
+++ b/projects/mmdet3d_plugin/uniad/dense_heads/planning_head_plugin/planning_metrics.py
@@ -108,9 +108,16 @@ class PlanningMetric(Metric):
             m1 = torch.logical_and(m1, torch.logical_not(gt_box_coll))
 
             ti = torch.arange(n_future)
+
+            yi = yi.detach().cpu()
+            xi = xi.detach().cpu()
+            m1 = m1.detach().cpu()
+
             obj_coll_sum[ti[m1]] += segmentation[i, ti[m1], yi[m1], xi[m1]].long()
 
             m2 = torch.logical_not(gt_box_coll)
+            m2 = m2.detach().cpu()
+
             box_coll = self.evaluate_single_coll(trajs[i], segmentation[i])
             obj_box_coll_sum[ti[m2]] += (box_coll[ti[m2]]).long()
 
diff --git a/projects/mmdet3d_plugin/uniad/dense_heads/seg_head_plugin/seg_assigner.py b/projects/mmdet3d_plugin/uniad/dense_heads/seg_head_plugin/seg_assigner.py
index ebfd3b8..192482a 100644
--- a/projects/mmdet3d_plugin/uniad/dense_heads/seg_head_plugin/seg_assigner.py
+++ b/projects/mmdet3d_plugin/uniad/dense_heads/seg_head_plugin/seg_assigner.py
@@ -291,6 +291,9 @@ class HungarianAssigner_filter(BaseAssigner):
             assigned_labels[matched_row_inds] = gt_labels[matched_col_inds]
             if i == 0:
                 result = AssignResult(num_gts, assigned_gt_inds.clone(), None, labels=assigned_labels.clone())
+
+            matched_row_inds = matched_row_inds.detach().cpu()
+            matched_col_inds = matched_col_inds.detach().cpu()
             if cost[matched_row_inds,matched_col_inds].max()>=INF:
                 break
         pos_ind = assigned_gt_inds.gt(0).nonzero().squeeze(1)
diff --git a/projects/mmdet3d_plugin/uniad/detectors/uniad_e2e.py b/projects/mmdet3d_plugin/uniad/detectors/uniad_e2e.py
index 78a4a87..d87eb85 100644
--- a/projects/mmdet3d_plugin/uniad/detectors/uniad_e2e.py
+++ b/projects/mmdet3d_plugin/uniad/detectors/uniad_e2e.py
@@ -224,6 +224,7 @@ class UniAD(UniADTrack):
             losses.update(losses_planning)
         
         for k,v in losses.items():
+            v = v.float()
             losses[k] = torch.nan_to_num(v)
         return losses
     
diff --git a/projects/mmdet3d_plugin/uniad/modules/decoder.py b/projects/mmdet3d_plugin/uniad/modules/decoder.py
index 33024f8..b18d4f1 100644
--- a/projects/mmdet3d_plugin/uniad/modules/decoder.py
+++ b/projects/mmdet3d_plugin/uniad/modules/decoder.py
@@ -26,6 +26,7 @@ from mmcv.utils import (ConfigDict, build_from_cfg, deprecated_api_warning,
 from mmcv.utils import ext_loader
 from .multi_scale_deformable_attn_function import MultiScaleDeformableAttnFunction_fp32, \
     MultiScaleDeformableAttnFunction_fp16
+import mx_driving.fused
 
 ext_module = ext_loader.load_ext(
     '_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward'])
@@ -324,14 +325,8 @@ class CustomMSDeformableAttention(BaseModule):
                 f' 2 or 4, but get {reference_points.shape[-1]} instead.')
         if torch.cuda.is_available() and value.is_cuda:
 
-            # using fp16 deformable attention is unstable because it performs many sum operations
-            if value.dtype == torch.float16:
-                MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
-            else:
-                MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
-            output = MultiScaleDeformableAttnFunction.apply(
-                value, spatial_shapes, level_start_index, sampling_locations,
-                attention_weights, self.im2col_step)
+            output = mx_driving.fused.multi_scale_deformable_attn(value, spatial_shapes, level_start_index,
+                                                                         sampling_locations, attention_weights)
         else:
             output = multi_scale_deformable_attn_pytorch(
                 value, spatial_shapes, sampling_locations, attention_weights)
diff --git a/projects/mmdet3d_plugin/uniad/modules/encoder.py b/projects/mmdet3d_plugin/uniad/modules/encoder.py
index 6875233..5e84f20 100644
--- a/projects/mmdet3d_plugin/uniad/modules/encoder.py
+++ b/projects/mmdet3d_plugin/uniad/modules/encoder.py
@@ -117,8 +117,8 @@ class BEVFormerEncoder(TransformerLayerSequence):
         lidar2img = lidar2img.view(
             1, B, num_cam, 1, 4, 4).repeat(D, 1, 1, num_query, 1, 1)
 
-        reference_points_cam = torch.matmul(lidar2img.to(torch.float32),
-                                            reference_points.to(torch.float32)).squeeze(-1)
+        reference_points_cam = torch.mul(lidar2img.to(torch.float32),
+                                         reference_points.to(torch.float32).transpose(-1, -2)).sum(-1, keepdim=True).squeeze(-1)
         eps = 1e-5
 
         bev_mask = (reference_points_cam[..., 2:3] > eps)
diff --git a/projects/mmdet3d_plugin/uniad/modules/spatial_cross_attention.py b/projects/mmdet3d_plugin/uniad/modules/spatial_cross_attention.py
index 77dfa91..8b32eed 100644
--- a/projects/mmdet3d_plugin/uniad/modules/spatial_cross_attention.py
+++ b/projects/mmdet3d_plugin/uniad/modules/spatial_cross_attention.py
@@ -23,6 +23,13 @@ from mmcv.runner.base_module import BaseModule, ModuleList, Sequential
 from mmcv.utils import ext_loader
 from .multi_scale_deformable_attn_function import MultiScaleDeformableAttnFunction_fp32, \
     MultiScaleDeformableAttnFunction_fp16
+import mx_driving.fused
+
+indexes_global = None
+max_len_global = None
+bev_mask_id_global = -1
+count_global = None
+
 ext_module = ext_loader.load_ext(
     '_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward'])
 
@@ -134,10 +141,28 @@ class SpatialCrossAttention(BaseModule):
 
         D = reference_points_cam.size(3)
         indexes = []
-        for i, mask_per_img in enumerate(bev_mask):
-            index_query_per_img = mask_per_img[0].sum(-1).nonzero().squeeze(-1)
-            indexes.append(index_query_per_img)
-        max_len = max([len(each) for each in indexes])
+
+        global indexes_global, max_len_global, bev_mask_id_global, count_global
+        bev_mask_id = id(bev_mask)
+        if bev_mask_id == bev_mask_id_global:
+            indexes = indexes_global
+            max_len = max_len_global
+            count = count_global
+        else:
+            count = torch.any(bev_mask, 3)
+            bev_mask_ = count.squeeze()
+            for i, mask_per_img in enumerate(bev_mask_):
+                index_query_per_img = mask_per_img.nonzero().squeeze(-1)
+                indexes.append(index_query_per_img)
+
+            max_len = max([len(each) for each in indexes])
+            count = count.permute(1, 2, 0).sum(-1)
+            count = torch.clamp(count, min=1.0)
+            count = count[..., None]
+            count_global = count
+            indexes_global = indexes
+            max_len_global = max_len
+            bev_mask_id_global = bev_mask_id
 
         # each camera only interacts with its corresponding BEV queries. This step can  greatly save GPU memory.
         queries_rebatch = query.new_zeros(
@@ -145,9 +170,9 @@ class SpatialCrossAttention(BaseModule):
         reference_points_rebatch = reference_points_cam.new_zeros(
             [bs, self.num_cams, max_len, D, 2])
         
-        for j in range(bs):
-            for i, reference_points_per_img in enumerate(reference_points_cam):   
-                index_query_per_img = indexes[i]
+        for i, reference_points_per_img in enumerate(reference_points_cam):
+            index_query_per_img = indexes[i]
+            for j in range(bs):
                 queries_rebatch[j, i, :len(index_query_per_img)] = query[j, index_query_per_img]
                 reference_points_rebatch[j, i, :len(index_query_per_img)] = reference_points_per_img[j, index_query_per_img]
 
@@ -165,10 +190,7 @@ class SpatialCrossAttention(BaseModule):
             for i, index_query_per_img in enumerate(indexes):
                 slots[j, index_query_per_img] += queries[j, i, :len(index_query_per_img)]
 
-        count = bev_mask.sum(-1) > 0
-        count = count.permute(1, 2, 0).sum(-1)
-        count = torch.clamp(count, min=1.0)
-        slots = slots / count[..., None]
+        slots = slots / count
         slots = self.output_proj(slots)
 
         return self.dropout(slots) + inp_residual
@@ -382,13 +404,8 @@ class MSDeformableAttention3D(BaseModule):
         #
 
         if torch.cuda.is_available() and value.is_cuda:
-            if value.dtype == torch.float16:
-                MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
-            else:
-                MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
-            output = MultiScaleDeformableAttnFunction.apply(
-                value, spatial_shapes, level_start_index, sampling_locations,
-                attention_weights, self.im2col_step)
+            output = mx_driving.fused.multi_scale_deformable_attn(value, spatial_shapes, level_start_index,
+                                                                         sampling_locations, attention_weights)
         else:
             output = multi_scale_deformable_attn_pytorch(
                 value, spatial_shapes, sampling_locations, attention_weights)
diff --git a/projects/mmdet3d_plugin/uniad/modules/temporal_self_attention.py b/projects/mmdet3d_plugin/uniad/modules/temporal_self_attention.py
index f846b4b..5f65b63 100644
--- a/projects/mmdet3d_plugin/uniad/modules/temporal_self_attention.py
+++ b/projects/mmdet3d_plugin/uniad/modules/temporal_self_attention.py
@@ -17,6 +17,8 @@ from mmcv.utils import (ConfigDict, build_from_cfg, deprecated_api_warning,
                         to_2tuple)
 
 from mmcv.utils import ext_loader
+import mx_driving.fused
+
 ext_module = ext_loader.load_ext(
     '_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward'])
 
@@ -236,14 +238,8 @@ class TemporalSelfAttention(BaseModule):
                 f' 2 or 4, but get {reference_points.shape[-1]} instead.')
         if torch.cuda.is_available() and value.is_cuda:
 
-            # using fp16 deformable attention is unstable because it performs many sum operations
-            if value.dtype == torch.float16:
-                MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
-            else:
-                MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
-            output = MultiScaleDeformableAttnFunction.apply(
-                value, spatial_shapes, level_start_index, sampling_locations,
-                attention_weights, self.im2col_step)
+            output = mx_driving.fused.multi_scale_deformable_attn(value, spatial_shapes, level_start_index,
+                                                                         sampling_locations, attention_weights)
         else:
 
             output = multi_scale_deformable_attn_pytorch(
diff --git a/projects/mmdet3d_plugin/uniad/modules/transformer.py b/projects/mmdet3d_plugin/uniad/modules/transformer.py
index bb5fae8..adaf13e 100644
--- a/projects/mmdet3d_plugin/uniad/modules/transformer.py
+++ b/projects/mmdet3d_plugin/uniad/modules/transformer.py
@@ -14,7 +14,7 @@ from mmcv.runner.base_module import BaseModule
 from mmdet.models.utils.builder import TRANSFORMER
 from torch.nn.init import normal_
 from mmcv.runner.base_module import BaseModule
-from torchvision.transforms.functional import rotate
+from torchvision.transforms.functional import InterpolationMode, rotate
 from .temporal_self_attention import TemporalSelfAttention
 from .spatial_cross_attention import MSDeformableAttention3D
 from .decoder import CustomMSDeformableAttention
@@ -142,7 +142,7 @@ class PerceptionTransformer(BaseModule):
                     rotation_angle = img_metas[i]['can_bus'][-1]
                     tmp_prev_bev = prev_bev[:, i].reshape(
                         bev_h, bev_w, -1).permute(2, 0, 1)
-                    tmp_prev_bev = rotate(tmp_prev_bev, rotation_angle,
+                    tmp_prev_bev = rotate(tmp_prev_bev, rotation_angle, InterpolationMode.BILINEAR,
                                           center=self.rotate_center)
                     tmp_prev_bev = tmp_prev_bev.permute(1, 2, 0).reshape(
                         bev_h * bev_w, 1, -1)
diff --git a/requirements.txt b/requirements.txt
index a005cc2..d02f784 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,11 @@
+mmdet==2.28.0
+mmsegmentation==0.30.0
 google-cloud-bigquery
 motmetrics==1.1.3
 einops==0.4.1
-numpy==1.20.0
+numpy==1.23.0
 casadi==3.5.5
-pytorch-lightning==1.2.5
\ No newline at end of file
+pytorch-lightning==1.2.5
+torchmetrics==0.11.4
+numba==0.58.1
+IPython
\ No newline at end of file
diff --git a/tools/perf.py b/tools/perf.py
new file mode 100644
index 0000000..ee2ddcc
--- /dev/null
+++ b/tools/perf.py
@@ -0,0 +1,260 @@
+from __future__ import division
+
+import argparse
+import cv2
+import torch
+import sklearn
+import mmcv
+import copy
+import os
+import time
+import warnings
+from mx_driving.patcher import default_patcher_builder
+from mmcv import Config, DictAction
+from mmcv.runner import get_dist_info, init_dist
+from os import path as osp
+
+from mmdet import __version__ as mmdet_version
+from mmdet3d import __version__ as mmdet3d_version
+
+from mmdet3d.datasets import build_dataset
+from mmdet3d.models import build_model
+from mmdet3d.utils import collect_env, get_root_logger
+from mmdet.apis import set_random_seed
+from mmseg import __version__ as mmseg_version
+
+warnings.filterwarnings("ignore")
+
+from mmcv.utils import TORCH_VERSION, digit_version
+
+import torch_npu
+from torch_npu.contrib import transfer_to_npu
+
+torch.npu.config.allow_internal_format = False
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Train a detector')
+    parser.add_argument('config', help='train config file path')
+    parser.add_argument('--work-dir', help='the dir to save logs and models')
+    parser.add_argument(
+        '--resume-from', help='the checkpoint file to resume from')
+    parser.add_argument(
+        '--no-validate',
+        action='store_true',
+        help='whether not to evaluate the checkpoint during training')
+    group_gpus = parser.add_mutually_exclusive_group()
+    group_gpus.add_argument(
+        '--gpus',
+        type=int,
+        help='number of gpus to use '
+        '(only applicable to non-distributed training)')
+    group_gpus.add_argument(
+        '--gpu-ids',
+        type=int,
+        nargs='+',
+        help='ids of gpus to use '
+        '(only applicable to non-distributed training)')
+    parser.add_argument('--seed', type=int, default=0, help='random seed')
+    parser.add_argument(
+        '--deterministic',
+        action='store_true',
+        help='whether to set deterministic options for CUDNN backend.')
+    parser.add_argument(
+        '--options',
+        nargs='+',
+        action=DictAction,
+        help='override some settings in the used config, the key-value pair '
+        'in xxx=yyy format will be merged into config file (deprecate), '
+        'change to --cfg-options instead.')
+    parser.add_argument(
+        '--cfg-options',
+        nargs='+',
+        action=DictAction,
+        help='override some settings in the used config, the key-value pair '
+        'in xxx=yyy format will be merged into config file. If the value to '
+        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
+        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
+        'Note that the quotation marks are necessary and that no white space '
+        'is allowed.')
+    parser.add_argument(
+        '--launcher',
+        choices=['none', 'pytorch', 'slurm', 'mpi'],
+        default='none',
+        help='job launcher')
+    parser.add_argument('--local_rank', type=int, default=0)
+    parser.add_argument(
+        '--autoscale-lr',
+        action='store_true',
+        help='automatically scale lr with the number of gpus')
+    args = parser.parse_args()
+    if 'LOCAL_RANK' not in os.environ:
+        os.environ['LOCAL_RANK'] = str(args.local_rank)
+
+    if args.options and args.cfg_options:
+        raise ValueError(
+            '--options and --cfg-options cannot be both specified, '
+            '--options is deprecated in favor of --cfg-options')
+    if args.options:
+        warnings.warn('--options is deprecated in favor of --cfg-options')
+        args.cfg_options = args.options
+
+    return args
+
+
+def main():
+    args = parse_args()
+
+    cfg = Config.fromfile(args.config)
+    if args.cfg_options is not None:
+        cfg.merge_from_dict(args.cfg_options)
+    # import modules from string list.
+    if cfg.get('custom_imports', None):
+        from mmcv.utils import import_modules_from_strings
+        import_modules_from_strings(**cfg['custom_imports'])
+
+    # import modules from plguin/xx, registry will be updated
+    if hasattr(cfg, 'plugin'):
+        if cfg.plugin:
+            import importlib
+            if hasattr(cfg, 'plugin_dir'):
+                plugin_dir = cfg.plugin_dir
+                _module_dir = os.path.dirname(plugin_dir)
+                _module_dir = _module_dir.split('/')
+                _module_path = _module_dir[0]
+
+                for m in _module_dir[1:]:
+                    _module_path = _module_path + '.' + m
+                print(_module_path)
+                plg_lib = importlib.import_module(_module_path)
+            else:
+                # import dir is the dirpath for the config file
+                _module_dir = os.path.dirname(args.config)
+                _module_dir = _module_dir.split('/')
+                _module_path = _module_dir[0]
+                for m in _module_dir[1:]:
+                    _module_path = _module_path + '.' + m
+                print(_module_path)
+                plg_lib = importlib.import_module(_module_path)
+
+            from projects.mmdet3d_plugin.uniad.apis.train import custom_train_model
+    # set cudnn_benchmark
+    if cfg.get('cudnn_benchmark', False):
+        torch.backends.cudnn.benchmark = True
+
+    # work_dir is determined in this priority: CLI > segment in file > filename
+    if args.work_dir is not None:
+        # update configs according to CLI args if args.work_dir is not None
+        cfg.work_dir = args.work_dir
+    elif cfg.get('work_dir', None) is None:
+        # use config filename as default work_dir if cfg.work_dir is None
+        cfg.work_dir = osp.join('./work_dirs',
+                                osp.splitext(osp.basename(args.config))[0])
+    # if args.resume_from is not None:
+    if args.resume_from is not None and osp.isfile(args.resume_from):
+        cfg.resume_from = args.resume_from
+    if args.gpu_ids is not None:
+        cfg.gpu_ids = args.gpu_ids
+    else:
+        cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)
+    if digit_version(TORCH_VERSION) == digit_version('1.8.1') and cfg.optimizer['type'] == 'AdamW':
+        cfg.optimizer['type'] = 'AdamW2' # fix bug in Adamw
+    if args.autoscale_lr:
+        cfg.optimizer['lr'] = cfg.optimizer['lr'] * len(cfg.gpu_ids) / 8
+
+    # init distributed env first, since logger depends on the dist info.
+    if args.launcher == 'none':
+        distributed = False
+    else:
+        distributed = True
+        init_dist(args.launcher, **cfg.dist_params)
+        # re-set gpu_ids with distributed training mode
+        _, world_size = get_dist_info()
+        cfg.gpu_ids = range(world_size)
+
+    # create work_dir
+    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
+    # dump config
+    cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
+    # init the logger before other steps
+    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
+    log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
+    # specify logger name, if we still use 'mmdet', the output info will be
+    # filtered and won't be saved in the log_file
+    # TODO: ugly workaround to judge whether we are training det or seg model
+    if cfg.model.type in ['EncoderDecoder3D']:
+        logger_name = 'mmseg'
+    else:
+        logger_name = 'mmdet'
+    logger = get_root_logger(
+        log_file=log_file, log_level=cfg.log_level, name=logger_name)
+
+    # init the meta dict to record some important information such as
+    # environment info and seed, which will be logged
+    meta = dict()
+    # log env info
+    env_info_dict = collect_env()
+    env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
+    dash_line = '-' * 60 + '\n'
+    logger.info('Environment info:\n' + dash_line + env_info + '\n' +
+                dash_line)
+    meta['env_info'] = env_info
+    meta['config'] = cfg.pretty_text
+
+    # log some basic info
+    logger.info(f'Distributed training: {distributed}')
+    logger.info(f'Config:\n{cfg.pretty_text}')
+
+    # set random seeds
+    if args.seed is not None:
+        logger.info(f'Set random seed to {args.seed}, '
+                    f'deterministic: {args.deterministic}')
+        set_random_seed(args.seed, deterministic=args.deterministic)
+    cfg.seed = args.seed
+    meta['seed'] = args.seed
+    meta['exp_name'] = osp.basename(args.config)
+
+    model = build_model(
+        cfg.model,
+        train_cfg=cfg.get('train_cfg'),
+        test_cfg=cfg.get('test_cfg'))
+    model.init_weights()
+
+    logger.info(f'Model:\n{model}')
+    datasets = [build_dataset(cfg.data.train)]
+    if len(cfg.workflow) == 2:
+        val_dataset = copy.deepcopy(cfg.data.val)
+        # in case we use a dataset wrapper
+        if 'dataset' in cfg.data.train:
+            val_dataset.pipeline = cfg.data.train.dataset.pipeline
+        else:
+            val_dataset.pipeline = cfg.data.train.pipeline
+        # set test_mode=False here in deep copied config
+        # which do not affect AP/AR calculation later
+        val_dataset.test_mode = False
+        datasets.append(build_dataset(val_dataset))
+    if cfg.checkpoint_config is not None:
+        # save mmdet version, config file content and class names in
+        # checkpoints as meta data
+        cfg.checkpoint_config.meta = dict(
+            mmdet_version=mmdet_version,
+            mmseg_version=mmseg_version,
+            mmdet3d_version=mmdet3d_version,
+            config=cfg.pretty_text,
+            CLASSES=datasets[0].CLASSES,
+            PALETTE=datasets[0].PALETTE  # for segmentors
+            if hasattr(datasets[0], 'PALETTE') else None)
+    # add an attribute for visualization convenience
+    model.CLASSES = datasets[0].CLASSES
+    custom_train_model(
+        model,
+        datasets,
+        cfg,
+        distributed=distributed,
+        validate=(not args.no_validate),
+        timestamp=timestamp,
+        meta=meta)
+
+
+if __name__ == '__main__':
+    with default_patcher_builder.disable_patches("index").brake_at(500).build():
+        main()
diff --git a/tools/test.py b/tools/test.py
index d4bf51d..e83fec3 100755
--- a/tools/test.py
+++ b/tools/test.py
@@ -7,7 +7,7 @@ import os
 import warnings
 from mmcv import Config, DictAction
 from mmcv.cnn import fuse_conv_bn
-from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
+from mmcv.device.npu import NPUDataParallel, NPUDistributedDataParallel
 from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
                          wrap_fp16_model)
 
@@ -224,7 +224,7 @@ def main():
         # model = MMDataParallel(model, device_ids=[0])
         # outputs = single_gpu_test(model, data_loader, args.show, args.show_dir)
     else:
-        model = MMDistributedDataParallel(
+        model = NPUDistributedDataParallel(
             model.cuda(),
             device_ids=[torch.cuda.current_device()],
             broadcast_buffers=False)
diff --git a/tools/train.py b/tools/train.py
index f240c5a..24fe7da 100755
--- a/tools/train.py
+++ b/tools/train.py
@@ -9,6 +9,7 @@ import copy
 import os
 import time
 import warnings
+from mx_driving.patcher import default_patcher_builder
 from mmcv import Config, DictAction
 from mmcv.runner import get_dist_info, init_dist
 from os import path as osp
@@ -26,6 +27,10 @@ warnings.filterwarnings("ignore")
 
 from mmcv.utils import TORCH_VERSION, digit_version
 
+import torch_npu
+from torch_npu.contrib import transfer_to_npu
+
+torch.npu.config.allow_internal_format = False
 
 def parse_args():
     parser = argparse.ArgumentParser(description='Train a detector')
@@ -253,4 +258,5 @@ def main():
 
 
 if __name__ == '__main__':
-    main()
+    with default_patcher_builder.disable_patches("index").build():
+        main()
diff --git a/tools/uniad_dist_perf.sh b/tools/uniad_dist_perf.sh
new file mode 100644
index 0000000..e71741a
--- /dev/null
+++ b/tools/uniad_dist_perf.sh
@@ -0,0 +1,36 @@
+#!/usr/bin/env bash
+
+T=`date +%m%d%H%M`
+
+# -------------------------------------------------- #
+# Usually you only need to customize these variables #
+CFG=$1                                               #
+GPUS=$2                                              #
+# -------------------------------------------------- #
+GPUS_PER_NODE=$(($GPUS<8?$GPUS:8))
+NNODES=`expr $GPUS / $GPUS_PER_NODE`
+
+MASTER_PORT=${MASTER_PORT:-28596}
+MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
+RANK=${RANK:-0}
+
+WORK_DIR=$(echo ${CFG%.*} | sed -e "s/configs/work_dirs/g")/
+# Intermediate files and logs will be saved to UniAD/projects/work_dirs/
+
+if [ ! -d ${WORK_DIR}logs ]; then
+    mkdir -p ${WORK_DIR}logs
+fi
+
+PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
+torchrun \
+    --nproc_per_node=${GPUS_PER_NODE} \
+    --master_addr=${MASTER_ADDR} \
+    --master_port=${MASTER_PORT} \
+    --nnodes=${NNODES} \
+    --node_rank=${RANK} \
+    $(dirname "$0")/perf.py \
+    $CFG \
+    --launcher pytorch ${@:3} \
+    --deterministic \
+    --work-dir ${WORK_DIR} \
+    2>&1 | tee ${WORK_DIR}logs/train.$T
\ No newline at end of file
diff --git a/tools/uniad_dist_train.sh b/tools/uniad_dist_train.sh
index 2febf37..9a4431b 100755
--- a/tools/uniad_dist_train.sh
+++ b/tools/uniad_dist_train.sh
@@ -22,7 +22,7 @@ if [ ! -d ${WORK_DIR}logs ]; then
 fi
 
 PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
-python -m torch.distributed.launch \
+torchrun \
     --nproc_per_node=${GPUS_PER_NODE} \
     --master_addr=${MASTER_ADDR} \
     --master_port=${MASTER_PORT} \