new file mode 100644
@@ -0,0 +1,139 @@
+"""
+nuScenes dev-kit.
+Code written by Holger Caesar, Caglayan Dicle and Oscar Beijbom, 2019.
+
+This code is based on:
+
+py-motmetrics at: https://github.com/cheind/py-motmetrics
+"""
+from collections import OrderedDict
+from itertools import count
+
+import motmetrics
+import numpy as np
+import pandas as pd
+
+
+class MOTAccumulatorCustom(motmetrics.mot.MOTAccumulator):
+ def __init__(self):
+ super().__init__()
+
+ @staticmethod
+ def new_event_dataframe_with_data(indices, events):
+ """
+ Create a new DataFrame filled with data.
+ This version overrides the original in MOTAccumulator and achieves about a 2x speedup.
+
+ Params
+ ------
+ indices: list
+ list of tuples (frameid, eventid)
+ events: list
+ list of events where each event is a list containing
+ 'Type', 'OId', 'HId', 'D'
+ """
+ idx = pd.MultiIndex.from_tuples(indices, names=['FrameId', 'Event'])
+ df = pd.DataFrame(events, index=idx, columns=['Type', 'OId', 'HId', 'D'])
+ return df
+
+ @staticmethod
+ def new_event_dataframe():
+ """ Create a new DataFrame for event tracking. """
+ idx = pd.MultiIndex(levels=[[], []], codes=[[], []], names=['FrameId', 'Event'])
+ cats = pd.Categorical([], categories=['RAW', 'FP', 'MISS', 'SWITCH', 'MATCH'])
+ df = pd.DataFrame(
+ OrderedDict([
+ ('Type', pd.Series(cats)), # Type of event. One of FP (false positive), MISS, SWITCH, MATCH
+ ('OId', pd.Series(dtype=object)),
+ # Object ID or -1 if FP. Declared with dtype=object; missing values are expected to show up as NaN.
+ ('HId', pd.Series(dtype=object)),
+ # Hypothesis ID or NaN if MISS. Declared with dtype=object; missing values are expected to show up as NaN.
+ ('D', pd.Series(dtype=float)), # Distance or NaN when FP or MISS
+ ]),
+ index=idx
+ )
+ return df
+
+ @property
+ def events(self):
+ if self.dirty_events:
+ self.cached_events_df = MOTAccumulatorCustom.new_event_dataframe_with_data(self._indices, self._events)
+ self.dirty_events = False
+ return self.cached_events_df
+
+ @staticmethod
+ def merge_event_dataframes(dfs, update_frame_indices=True, update_oids=True, update_hids=True,
+ return_mappings=False):
+ """Merge dataframes.
+
+ Params
+ ------
+ dfs : list of pandas.DataFrame or MOTAccumulatorCustom
+ A list of event containers to merge
+
+ Kwargs
+ ------
+ update_frame_indices : boolean, optional
+ Ensure that frame indices are unique in the merged container
+ update_oids : boolean, optional
+ Ensure that object ids are unique in the merged container
+ update_hids : boolean, optional
+ Ensure that hypothesis ids are unique in the merged container
+ return_mappings : boolean, optional
+ Whether or not to return mapping information
+
+ Returns
+ -------
+ df : pandas.DataFrame
+ Merged event data frame; returned as (df, mapping_infos) when return_mappings is True
+ """
+
+ mapping_infos = []
+ new_oid = count()
+ new_hid = count()
+
+ r = MOTAccumulatorCustom.new_event_dataframe()
+ for df in dfs:
+
+ if isinstance(df, MOTAccumulatorCustom):
+ df = df.events
+
+ copy = df.copy()
+ infos = {}
+
+ # Update index
+ if update_frame_indices:
+ if r.index.get_level_values(0).size > 0 and isinstance(r.index.get_level_values(0)[0], tuple):
+ index_temp = []
+ for item in r.index.get_level_values(0):
+ index_temp += list(item)
+ index_temp = np.array(index_temp)
+ index_temp_unique = np.unique(index_temp)
+ next_frame_id = max(np.max(index_temp)+1, index_temp_unique.shape[0])
+ else:
+ next_frame_id = max(r.index.get_level_values(0).max() + 1,
+ r.index.get_level_values(0).unique().shape[0])
+
+ if np.isnan(next_frame_id):
+ next_frame_id = 0
+ copy.index = copy.index.map(lambda x: (x[0] + next_frame_id, x[1]))
+ infos['frame_offset'] = next_frame_id
+
+ # Update object / hypothesis ids
+ if update_oids:
+ oid_map = dict([oid, str(next(new_oid))] for oid in copy['OId'].dropna().unique())
+ copy['OId'] = copy['OId'].map(lambda x: oid_map[x], na_action='ignore')
+ infos['oid_map'] = oid_map
+
+ if update_hids:
+ hid_map = dict([hid, str(next(new_hid))] for hid in copy['HId'].dropna().unique())
+ copy['HId'] = copy['HId'].map(lambda x: hid_map[x], na_action='ignore')
+ infos['hid_map'] = hid_map
+
+ r = pd.concat((r, copy))
+ mapping_infos.append(infos)
+
+ if return_mappings:
+ return r, mapping_infos
+ else:
+ return r
@@ -555,7 +555,7 @@ data = dict(
nonshuffler_sampler=dict(type="DistributedSampler"),
)
optimizer = dict(
- type="AdamW",
+ type="NpuFusedAdamW",
lr=2e-4,
paramwise_cfg=dict(
custom_keys={
@@ -564,7 +564,7 @@ optimizer = dict(
),
weight_decay=0.01,
)
-optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
+optimizer_config = dict(type='GradientCumulativeOptimizerHook', cumulative_iters=2, grad_clip=dict(max_norm=70, norm_type=2))
# learning policy
lr_config = dict(
policy="CosineAnnealing",
@@ -573,15 +573,11 @@ lr_config = dict(
warmup_ratio=1.0 / 3,
min_lr_ratio=1e-3,
)
-total_epochs = 6
-evaluation = dict(
- interval=6,
- pipeline=test_pipeline,
- planning_evaluation_strategy=planning_evaluation_strategy,
-)
+total_epochs = 4
+evaluation = dict(interval=4, pipeline=test_pipeline)
runner = dict(type="EpochBasedRunner", max_epochs=total_epochs)
log_config = dict(
- interval=10, hooks=[dict(type="TextLoggerHook"), dict(type="TensorboardLoggerHook")]
+ interval=1, hooks=[dict(type="TextLoggerHook"), dict(type="TensorboardLoggerHook")]
)
checkpoint_config = dict(interval=1)
load_from = "ckpts/bevformer_r101_dcn_24ep.pth"
@@ -670,7 +670,7 @@ data = dict(
nonshuffler_sampler=dict(type="DistributedSampler"),
)
optimizer = dict(
- type="AdamW",
+ type="NpuFusedAdamW",
lr=2e-4,
paramwise_cfg=dict(
custom_keys={
@@ -679,7 +679,7 @@ optimizer = dict(
),
weight_decay=0.01,
)
-optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
+optimizer_config = dict(type='GradientCumulativeOptimizerHook', cumulative_iters=2, grad_clip=dict(max_norm=70, norm_type=2))
# learning policy
lr_config = dict(
policy="CosineAnnealing",
@@ -688,15 +688,11 @@ lr_config = dict(
warmup_ratio=1.0 / 3,
min_lr_ratio=1e-3,
)
-total_epochs = 20
-evaluation = dict(
- interval=4,
- pipeline=test_pipeline,
- planning_evaluation_strategy=planning_evaluation_strategy,
-)
+total_epochs = 4
+evaluation = dict(interval=4, pipeline=test_pipeline)
runner = dict(type="EpochBasedRunner", max_epochs=total_epochs)
log_config = dict(
- interval=10, hooks=[dict(type="TextLoggerHook"), dict(type="TensorboardLoggerHook")]
+ interval=1, hooks=[dict(type="TextLoggerHook"), dict(type="TensorboardLoggerHook")]
)
checkpoint_config = dict(interval=1)
load_from = "ckpts/uniad_base_track_map.pth"
@@ -28,7 +28,7 @@ class BBox3DL1Cost(object):
return bbox_cost * self.weight
-@MATCH_COST.register_module()
+@MATCH_COST.register_module(force=True)
class DiceCost(object):
"""IoUCost.
@@ -24,30 +24,15 @@ def normalize_bbox(bboxes, pc_range):
return normalized_bboxes
def denormalize_bbox(normalized_bboxes, pc_range):
- # rotation
- rot_sine = normalized_bboxes[..., 6:7]
+ # rotation
+ cx, cy, w, l, cz, h, rot_sine, rot_cosine, vx, vy = normalized_bboxes.split(1, dim=-1)
- rot_cosine = normalized_bboxes[..., 7:8]
rot = torch.atan2(rot_sine, rot_cosine)
- # center in the bev
- cx = normalized_bboxes[..., 0:1]
- cy = normalized_bboxes[..., 1:2]
- cz = normalized_bboxes[..., 4:5]
-
- # size
- w = normalized_bboxes[..., 2:3]
- l = normalized_bboxes[..., 3:4]
- h = normalized_bboxes[..., 5:6]
-
- w = w.exp()
- l = l.exp()
- h = h.exp()
- if normalized_bboxes.size(-1) > 8:
- # velocity
- vx = normalized_bboxes[:, 8:9]
- vy = normalized_bboxes[:, 9:10]
- denormalized_bboxes = torch.cat([cx, cy, cz, w, l, h, rot, vx, vy], dim=-1)
- else:
- denormalized_bboxes = torch.cat([cx, cy, cz, w, l, h, rot], dim=-1)
+ w = w.exp()
+ l = l.exp()
+ h = h.exp()
+
+ denormalized_bboxes = torch.cat([cx, cy, cz, w, l, h, rot, vx, vy], dim=-1)
+
return denormalized_bboxes
\ No newline at end of file
@@ -1025,7 +1025,6 @@ class NuScenesE2EDataset(NuScenesDataset):
if 'planning_results_computed' in results.keys():
planning_results_computed = results['planning_results_computed']
planning_tab = PrettyTable()
- planning_tab.title = f"{planning_evaluation_strategy}'s definition planning metrics"
planning_tab.field_names = [
"metrics", "0.5s", "1.0s", "1.5s", "2.0s", "2.5s", "3.0s"]
for key in planning_results_computed.keys():
@@ -1033,14 +1032,7 @@ class NuScenesE2EDataset(NuScenesDataset):
row_value = []
row_value.append(key)
for i in range(len(value)):
- if planning_evaluation_strategy == "stp3":
- row_value.append("%.4f" % float(value[: i + 1].mean()))
- elif planning_evaluation_strategy == "uniad":
row_value.append("%.4f" % float(value[i]))
- else:
- raise ValueError(
- "planning_evaluation_strategy should be uniad or spt3"
- )
planning_tab.add_row(row_value)
print(planning_tab)
results = results['bbox_results'] # get bbox_results
@@ -1,5 +1,4 @@
import torch
-import torch
import torch.nn as nn
from mmdet.models.losses.utils import weighted_loss
@@ -21,7 +20,7 @@ def dice_loss(input, target,mask=None,eps=0.001):
d = (2 * a) / (b + c)
return 1 - d
-@LOSSES.register_module()
+@LOSSES.register_module(force=True)
class DiceLoss(nn.Module):
def __init__(self, eps=1e-6, reduction='mean', loss_weight=1.0):
@@ -40,20 +40,41 @@ class CollisionLoss(nn.Module):
# sdc_planning_gt_mask (1, 6)
# future_gt_bbox 6x[lidarboxinstance]
n_futures = len(future_gt_bbox)
- inter_sum = sdc_traj_all.new_zeros(1, )
- dump_sdc = []
+ inter_sum = torch.tensor([0.], device=sdc_traj_all.device, dtype=sdc_traj_all.dtype)
for i in range(n_futures):
if len(future_gt_bbox[i].tensor) > 0:
- future_gt_bbox_corners = future_gt_bbox[i].corners[:, [0,3,4,7], :2] # (N, 8, 3) -> (N, 4, 2) only bev
+ future_gt_corners = future_gt_bbox[i].corners[:, [0,3,4,7], :2].npu() # (N, 8, 3) -> (N, 4, 2) only bev
# sdc_yaw = -sdc_planning_gt[0, i, 2].to(sdc_traj_all.dtype) - 1.5708
sdc_yaw = sdc_planning_gt[0, i, 2].to(sdc_traj_all.dtype)
sdc_bev_box = self.to_corners([sdc_traj_all[0, i, 0], sdc_traj_all[0, i, 1], self.w, self.h, sdc_yaw])
- dump_sdc.append(sdc_bev_box.cpu().detach().numpy())
- for j in range(future_gt_bbox_corners.shape[0]):
- inter_sum += self.inter_bbox(sdc_bev_box, future_gt_bbox_corners[j].to(sdc_traj_all.device))
+ sdc_min = sdc_bev_box.min(dim=0)[0] # (2,)
+ sdc_max = sdc_bev_box.max(dim=0)[0] # (2,)
+
+ # Compute min/max for all target boxes (N, 2)
+ target_min = future_gt_corners.min(dim=1)[0] # (N, 2)
+ target_max = future_gt_corners.max(dim=1)[0] # (N, 2)
+
+ # Compute intersection for all boxes (vectorized)
+ intersect_min = torch.maximum(sdc_min, target_min) # (N, 2)
+ intersect_max = torch.minimum(sdc_max, target_max) # (N, 2)
+ intersect_dims = intersect_max - intersect_min # (N, 2)
+
+ # Clamp negative values to 0 and compute area
+ intersect_area = torch.prod(
+ torch.clamp_min(intersect_dims, 0),
+ dim=1
+ ) # (N,)
+
+ # Sum all intersections for this frame
+ inter_sum += intersect_area.sum()
+
return inter_sum * self.weight
def inter_bbox(self, corners_a, corners_b):
+ device_stack = corners_a.device
+ corners_a = corners_a.cpu()
+ corners_b = corners_b.cpu()
+
xa1, ya1 = torch.max(corners_a[:, 0]), torch.max(corners_a[:, 1])
xa2, ya2 = torch.min(corners_a[:, 0]), torch.min(corners_a[:, 1])
xb1, yb1 = torch.max(corners_b[:, 0]), torch.max(corners_b[:, 1])
@@ -61,7 +82,9 @@ class CollisionLoss(nn.Module):
xi1, yi1 = min(xa1, xb1), min(ya1, yb1)
xi2, yi2 = max(xa2, xb2), max(ya2, yb2)
- intersect = max((xi1 - xi2), xi1.new_zeros(1, ).to(xi1.device)) * max((yi1 - yi2), xi1.new_zeros(1,).to(xi1.device))
+
+ intersect = max((xi1 - xi2), xi1.new_zeros(1, )) * max((yi1 - yi2), xi1.new_zeros(1,))
+ intersect = intersect.to(device_stack)
return intersect
def to_corners(self, bbox):
@@ -244,33 +244,26 @@ class ClipMatcher(nn.Module):
filtered_idx = []
for src_per_img, tgt_per_img in indices:
keep = tgt_per_img != -1
- filtered_idx.append((src_per_img[keep], tgt_per_img[keep]))
- indices = filtered_idx
+ if not keep.all().item():
+ filtered_idx.append((src_per_img[keep], tgt_per_img[keep]))
+ indices = filtered_idx
idx = self._get_src_permutation_idx(indices)
src_boxes = outputs["pred_boxes"][idx]
sdc_boxes = outputs["pred_sdc_boxes"][0, -1:]
target_sdc_boxes = gt_instances[0].sdc_boxes[:1]
- target_boxes = torch.cat(
- [
- gt_per_img.boxes[i]
- for gt_per_img, (_, i) in zip(gt_instances, indices)
- ],
- dim=0,
- )
+
+ target_boxes_list = []
+ target_obj_ids_list = []
+ for gt_per_img, (_, i) in zip(gt_instances, indices):
+ target_boxes_list.append(gt_per_img.boxes[i])
+ target_obj_ids_list.append(gt_per_img.obj_ids[i])
+
+ target_boxes = torch.cat(target_boxes_list, dim=0)
+ target_obj_ids = torch.cat(target_obj_ids_list, dim=0)
src_boxes = torch.cat([src_boxes, sdc_boxes], dim=0)
target_boxes = torch.cat([target_boxes, target_sdc_boxes], dim=0)
- # for pad target, don't calculate regression loss, judged by whether obj_id=-1
- target_obj_ids = torch.cat(
- [
- gt_per_img.obj_ids[i]
- for gt_per_img, (_, i) in zip(gt_instances, indices)
- ],
- dim=0,
- )
- # [num_matched]
-
target_obj_ids = torch.cat([target_obj_ids, torch.zeros(1).to(target_obj_ids.device)], dim=0)
mask = target_obj_ids != -1
bbox_weights = torch.ones_like(target_boxes) * self.code_weights
@@ -372,11 +365,8 @@ class ClipMatcher(nn.Module):
pred_past_trajs_i = track_instances.pred_past_trajs # predicted past trajs of i-th image.
obj_idxes = gt_instances_i.obj_ids
- obj_idxes_list = obj_idxes.detach().cpu().numpy().tolist()
- obj_idx_to_gt_idx = {
- obj_idx: gt_idx
- for gt_idx, obj_idx in enumerate(obj_idxes_list)
- }
+ obj_idxes_npu = obj_idxes.clone()
+
outputs_i = {
"pred_logits": pred_logits_i.unsqueeze(0),
"pred_sdc_logits": pred_sdc_logits_i,
@@ -386,19 +376,30 @@ class ClipMatcher(nn.Module):
}
# step1. inherit and update the previous tracks.
num_disappear_track = 0
- for j in range(len(track_instances)):
- obj_id = track_instances.obj_idxes[j].item()
- # set new target idx.
- if obj_id >= 0:
- if obj_id in obj_idx_to_gt_idx:
- track_instances.matched_gt_idxes[j] = obj_idx_to_gt_idx[
- obj_id]
- else:
- num_disappear_track += 1
- track_instances.matched_gt_idxes[
- j] = -1 # track-disappear case.
- else:
- track_instances.matched_gt_idxes[j] = -1
+
+ obj_idxes = track_instances.obj_idxes
+ mask_valid = obj_idxes >= 0
+
+ current_max = obj_idxes.max().item() if mask_valid.any() else 0
+
+ lookup_size = max(current_max + 1, 1)
+ lookup_table = torch.full((lookup_size,), -1, dtype=torch.long, device=track_instances.obj_idxes.device)
+
+ if obj_idxes_npu.any():
+ valid_mask = obj_idxes_npu < lookup_size
+ valid_obj_idxes = obj_idxes_npu[valid_mask]
+ valid_gt_indices = torch.arange(len(obj_idxes_npu), device=obj_idxes_npu.device)[valid_mask]
+ lookup_table[valid_obj_idxes] = valid_gt_indices
+
+ matched_gt_idxes = torch.where(
+ mask_valid,
+ lookup_table[torch.clamp(obj_idxes, max=lookup_size-1)],
+ torch.tensor(-1, dtype=torch.long, device=track_instances.obj_idxes.device)
+ )
+
+ num_disappear_track = (mask_valid & (matched_gt_idxes == -1)).sum().item()
+ track_instances.matched_gt_idxes = matched_gt_idxes
+
full_track_idxes = torch.arange(
len(track_instances), dtype=torch.long).to(pred_logits_i.device)
@@ -438,15 +439,20 @@ class ClipMatcher(nn.Module):
bs, num_querys = bbox_preds.shape[:2]
# Also concat the target labels and boxes
targets = [untracked_gt_instances]
+ gt_labels_list = []
+ gt_bboxes_list = []
if isinstance(targets[0], Instances):
- # [num_box], [num_box, 9] (un-normalized bboxes)
- gt_labels = torch.cat(
- [gt_per_img.labels for gt_per_img in targets])
- gt_bboxes = torch.cat(
- [gt_per_img.boxes for gt_per_img in targets])
+ for gt_per_img in targets:
+ gt_labels_list.append(gt_per_img.labels)
+ gt_bboxes_list.append(gt_per_img.boxes)
+ gt_labels = torch.cat(gt_labels_list)
+ gt_bboxes = torch.cat(gt_bboxes_list)
else:
- gt_labels = torch.cat([v["labels"] for v in targets])
- gt_bboxes = torch.cat([v["boxes"] for v in targets])
+ for v in targets:
+ gt_labels_list.append(v["labels"])
+ gt_bboxes_list.append(v["boxes"])
+ gt_labels = torch.cat(gt_labels_list)
+ gt_bboxes = torch.cat(gt_bboxes_list)
bbox_pred = bbox_preds[0]
cls_pred = cls_preds[0]
@@ -116,6 +116,7 @@ def min_ade(traj: torch.Tensor, traj_gt: torch.Tensor,
err = torch.pow(err, exponent=0.5)
err = torch.sum(err * (1 - masks_rpt), dim=2) / \
torch.clip(torch.sum((1 - masks_rpt), dim=2), min=1)
+ err = err.float()
err, inds = torch.min(err, dim=1)
return err, inds
@@ -195,6 +196,7 @@ def min_fde(traj: torch.Tensor, traj_gt: torch.Tensor,
err = torch.pow(err, exponent=2)
err = torch.sum(err, dim=2)
err = torch.pow(err, exponent=0.5)
+ err = err.float()
err, inds = torch.min(err, dim=1)
return err, inds
@@ -226,6 +228,7 @@ def miss_rate(
dist = torch.sum(dist, dim=3)
dist = torch.pow(dist, exponent=0.5)
dist[masks_rpt.bool()] = -math.inf
+ dist = dist.float()
dist, _ = torch.max(dist, dim=2)
dist, _ = torch.min(dist, dim=1)
m_r = torch.sum(torch.as_tensor(dist > dist_thresh)) / len(dist)
@@ -100,6 +100,14 @@ def anchor_coordinate_transform(anchors, bbox_results, with_translation_transfor
rot_yaw = rot_2d(angle) # num_agents, 2, 2
rot_yaw = rot_yaw[:, None, None,:, :] # num_agents, 1, 1, 2, 2
transformed_anchors = rearrange(transformed_anchors, 'b g m t c -> b g m c t') # 1, num_groups, num_modes, 12, 2 -> 1, num_groups, num_modes, 2, 12
+
+ num_agents, _, _, _, _ = rot_yaw.shape
+ _, num_groups, num_modes, _, _ = transformed_anchors.shape
+ broadcast_shape1 = (num_agents, num_groups, num_modes, 2, 2)
+ broadcast_shape2 = (num_agents, num_groups, num_modes, 2, 12)
+ rot_yaw = rot_yaw.expand(broadcast_shape1)
+ transformed_anchors = transformed_anchors.expand(broadcast_shape2)
+
transformed_anchors = torch.matmul(rot_yaw, transformed_anchors)# -> num_agents, num_groups, num_modes, 12, 2
transformed_anchors = rearrange(transformed_anchors, 'b g m c t -> b g m t c')
if with_translation_transform:
@@ -4,7 +4,7 @@ import os
import numpy as np
import torch
import torch.distributed as dist
-from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
+from mmcv.device.npu import NPUDataParallel, NPUDistributedDataParallel
from mmcv.runner import (HOOKS, DistSamplerSeedHook, EpochBasedRunner,
Fp16OptimizerHook, OptimizerHook, build_optimizer,
build_runner, get_dist_info)
@@ -67,22 +67,22 @@ def custom_train_detector(model,
find_unused_parameters = cfg.get('find_unused_parameters', False)
# Sets the `find_unused_parameters` parameter in
# torch.nn.parallel.DistributedDataParallel
- model = MMDistributedDataParallel(
+ model = NPUDistributedDataParallel(
model.cuda(),
device_ids=[torch.cuda.current_device()],
broadcast_buffers=False,
find_unused_parameters=find_unused_parameters)
if eval_model is not None:
- eval_model = MMDistributedDataParallel(
+ eval_model = NPUDistributedDataParallel(
eval_model.cuda(),
device_ids=[torch.cuda.current_device()],
broadcast_buffers=False,
find_unused_parameters=find_unused_parameters)
else:
- model = MMDataParallel(
+ model = NPUDataParallel(
model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)
if eval_model is not None:
- eval_model = MMDataParallel(
+ eval_model = NPUDataParallel(
eval_model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)
@@ -467,6 +467,7 @@ class MotionHead(BaseMotionHead):
gt_fut_traj_mask_all = []
for i in range(num_imgs):
matched_gt_idx = all_matched_idxes[i]
+ matched_gt_idx = matched_gt_idx.detach().cpu()
valid_traj_masks = matched_gt_idx >= 0
matched_gt_fut_traj = gt_fut_traj[i][matched_gt_idx][valid_traj_masks]
matched_gt_fut_traj_mask = gt_fut_traj_mask[i][matched_gt_idx][valid_traj_masks]
@@ -19,6 +19,7 @@ from mmcv.cnn.bricks.drop import build_dropout
from mmcv.runner.base_module import BaseModule, ModuleList, Sequential
from mmcv.utils import ConfigDict, deprecated_api_warning
from projects.mmdet3d_plugin.uniad.modules.multi_scale_deformable_attn_function import MultiScaleDeformableAttnFunction_fp32
+import mx_driving.fused
@TRANSFORMER_LAYER.register_module()
@@ -453,14 +454,8 @@ class MotionDeformableAttention(BaseModule):
f' 2 or 4, but get {reference_trajs.shape[-1]} instead.')
if torch.cuda.is_available() and value.is_cuda:
- # using fp16 deformable attention is unstable because it performs many sum operations
- if value.dtype == torch.float16:
- MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
- else:
- MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
- output = MultiScaleDeformableAttnFunction.apply(
- value, spatial_shapes, level_start_index, sampling_locations,
- attention_weights, self.im2col_step)
+ output = mx_driving.fused.multi_scale_deformable_attn(value, spatial_shapes, level_start_index,
+ sampling_locations, attention_weights)
else:
output = multi_scale_deformable_attn_pytorch(
value, spatial_shapes, sampling_locations, attention_weights)
@@ -21,7 +21,7 @@ class IntersectionOverUnion(Metric):
reduction: str = 'none',
compute_on_step: bool = False,
):
- super().__init__(compute_on_step=compute_on_step)
+ super().__init__()
self.n_classes = n_classes
self.ignore_index = ignore_index
@@ -108,9 +108,16 @@ class PlanningMetric(Metric):
m1 = torch.logical_and(m1, torch.logical_not(gt_box_coll))
ti = torch.arange(n_future)
+
+ yi = yi.detach().cpu()
+ xi = xi.detach().cpu()
+ m1 = m1.detach().cpu()
+
obj_coll_sum[ti[m1]] += segmentation[i, ti[m1], yi[m1], xi[m1]].long()
m2 = torch.logical_not(gt_box_coll)
+ m2 = m2.detach().cpu()
+
box_coll = self.evaluate_single_coll(trajs[i], segmentation[i])
obj_box_coll_sum[ti[m2]] += (box_coll[ti[m2]]).long()
@@ -291,6 +291,9 @@ class HungarianAssigner_filter(BaseAssigner):
assigned_labels[matched_row_inds] = gt_labels[matched_col_inds]
if i == 0:
result = AssignResult(num_gts, assigned_gt_inds.clone(), None, labels=assigned_labels.clone())
+
+ matched_row_inds = matched_row_inds.detach().cpu()
+ matched_col_inds = matched_col_inds.detach().cpu()
if cost[matched_row_inds,matched_col_inds].max()>=INF:
break
pos_ind = assigned_gt_inds.gt(0).nonzero().squeeze(1)
@@ -224,6 +224,7 @@ class UniAD(UniADTrack):
losses.update(losses_planning)
for k,v in losses.items():
+ v = v.float()
losses[k] = torch.nan_to_num(v)
return losses
@@ -26,6 +26,7 @@ from mmcv.utils import (ConfigDict, build_from_cfg, deprecated_api_warning,
from mmcv.utils import ext_loader
from .multi_scale_deformable_attn_function import MultiScaleDeformableAttnFunction_fp32, \
MultiScaleDeformableAttnFunction_fp16
+import mx_driving.fused
ext_module = ext_loader.load_ext(
'_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward'])
@@ -324,14 +325,8 @@ class CustomMSDeformableAttention(BaseModule):
f' 2 or 4, but get {reference_points.shape[-1]} instead.')
if torch.cuda.is_available() and value.is_cuda:
- # using fp16 deformable attention is unstable because it performs many sum operations
- if value.dtype == torch.float16:
- MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
- else:
- MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
- output = MultiScaleDeformableAttnFunction.apply(
- value, spatial_shapes, level_start_index, sampling_locations,
- attention_weights, self.im2col_step)
+ output = mx_driving.fused.multi_scale_deformable_attn(value, spatial_shapes, level_start_index,
+ sampling_locations, attention_weights)
else:
output = multi_scale_deformable_attn_pytorch(
value, spatial_shapes, sampling_locations, attention_weights)
@@ -117,8 +117,8 @@ class BEVFormerEncoder(TransformerLayerSequence):
lidar2img = lidar2img.view(
1, B, num_cam, 1, 4, 4).repeat(D, 1, 1, num_query, 1, 1)
- reference_points_cam = torch.matmul(lidar2img.to(torch.float32),
- reference_points.to(torch.float32)).squeeze(-1)
+ reference_points_cam = torch.mul(lidar2img.to(torch.float32),
+ reference_points.to(torch.float32).transpose(-1, -2)).sum(-1, keepdim=True).squeeze(-1)
eps = 1e-5
bev_mask = (reference_points_cam[..., 2:3] > eps)
@@ -23,6 +23,13 @@ from mmcv.runner.base_module import BaseModule, ModuleList, Sequential
from mmcv.utils import ext_loader
from .multi_scale_deformable_attn_function import MultiScaleDeformableAttnFunction_fp32, \
MultiScaleDeformableAttnFunction_fp16
+import mx_driving.fused
+
+indexes_global = None
+max_len_global = None
+bev_mask_id_global = -1
+count_global = None
+
ext_module = ext_loader.load_ext(
'_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward'])
@@ -134,10 +141,28 @@ class SpatialCrossAttention(BaseModule):
D = reference_points_cam.size(3)
indexes = []
- for i, mask_per_img in enumerate(bev_mask):
- index_query_per_img = mask_per_img[0].sum(-1).nonzero().squeeze(-1)
- indexes.append(index_query_per_img)
- max_len = max([len(each) for each in indexes])
+
+ global indexes_global, max_len_global, bev_mask_id_global, count_global
+ bev_mask_id = id(bev_mask)
+ if bev_mask_id == bev_mask_id_global:
+ indexes = indexes_global
+ max_len = max_len_global
+ count = count_global
+ else:
+ count = torch.any(bev_mask, 3)
+ bev_mask_ = count.squeeze()
+ for i, mask_per_img in enumerate(bev_mask_):
+ index_query_per_img = mask_per_img.nonzero().squeeze(-1)
+ indexes.append(index_query_per_img)
+
+ max_len = max([len(each) for each in indexes])
+ count = count.permute(1, 2, 0).sum(-1)
+ count = torch.clamp(count, min=1.0)
+ count = count[..., None]
+ count_global = count
+ indexes_global = indexes
+ max_len_global = max_len
+ bev_mask_id_global = bev_mask_id
# each camera only interacts with its corresponding BEV queries. This step can greatly save GPU memory.
queries_rebatch = query.new_zeros(
@@ -145,9 +170,9 @@ class SpatialCrossAttention(BaseModule):
reference_points_rebatch = reference_points_cam.new_zeros(
[bs, self.num_cams, max_len, D, 2])
- for j in range(bs):
- for i, reference_points_per_img in enumerate(reference_points_cam):
- index_query_per_img = indexes[i]
+ for i, reference_points_per_img in enumerate(reference_points_cam):
+ index_query_per_img = indexes[i]
+ for j in range(bs):
queries_rebatch[j, i, :len(index_query_per_img)] = query[j, index_query_per_img]
reference_points_rebatch[j, i, :len(index_query_per_img)] = reference_points_per_img[j, index_query_per_img]
@@ -165,10 +190,7 @@ class SpatialCrossAttention(BaseModule):
for i, index_query_per_img in enumerate(indexes):
slots[j, index_query_per_img] += queries[j, i, :len(index_query_per_img)]
- count = bev_mask.sum(-1) > 0
- count = count.permute(1, 2, 0).sum(-1)
- count = torch.clamp(count, min=1.0)
- slots = slots / count[..., None]
+ slots = slots / count
slots = self.output_proj(slots)
return self.dropout(slots) + inp_residual
@@ -382,13 +404,8 @@ class MSDeformableAttention3D(BaseModule):
#
if torch.cuda.is_available() and value.is_cuda:
- if value.dtype == torch.float16:
- MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
- else:
- MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
- output = MultiScaleDeformableAttnFunction.apply(
- value, spatial_shapes, level_start_index, sampling_locations,
- attention_weights, self.im2col_step)
+ output = mx_driving.fused.multi_scale_deformable_attn(value, spatial_shapes, level_start_index,
+ sampling_locations, attention_weights)
else:
output = multi_scale_deformable_attn_pytorch(
value, spatial_shapes, sampling_locations, attention_weights)
@@ -17,6 +17,8 @@ from mmcv.utils import (ConfigDict, build_from_cfg, deprecated_api_warning,
to_2tuple)
from mmcv.utils import ext_loader
+import mx_driving.fused
+
ext_module = ext_loader.load_ext(
'_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward'])
@@ -236,14 +238,8 @@ class TemporalSelfAttention(BaseModule):
f' 2 or 4, but get {reference_points.shape[-1]} instead.')
if torch.cuda.is_available() and value.is_cuda:
- # using fp16 deformable attention is unstable because it performs many sum operations
- if value.dtype == torch.float16:
- MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
- else:
- MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
- output = MultiScaleDeformableAttnFunction.apply(
- value, spatial_shapes, level_start_index, sampling_locations,
- attention_weights, self.im2col_step)
+ output = mx_driving.fused.multi_scale_deformable_attn(value, spatial_shapes, level_start_index,
+ sampling_locations, attention_weights)
else:
output = multi_scale_deformable_attn_pytorch(
@@ -14,7 +14,7 @@ from mmcv.runner.base_module import BaseModule
from mmdet.models.utils.builder import TRANSFORMER
from torch.nn.init import normal_
from mmcv.runner.base_module import BaseModule
-from torchvision.transforms.functional import rotate
+from torchvision.transforms.functional import InterpolationMode, rotate
from .temporal_self_attention import TemporalSelfAttention
from .spatial_cross_attention import MSDeformableAttention3D
from .decoder import CustomMSDeformableAttention
@@ -142,7 +142,7 @@ class PerceptionTransformer(BaseModule):
rotation_angle = img_metas[i]['can_bus'][-1]
tmp_prev_bev = prev_bev[:, i].reshape(
bev_h, bev_w, -1).permute(2, 0, 1)
- tmp_prev_bev = rotate(tmp_prev_bev, rotation_angle,
+ tmp_prev_bev = rotate(tmp_prev_bev, rotation_angle, InterpolationMode.BILINEAR,
center=self.rotate_center)
tmp_prev_bev = tmp_prev_bev.permute(1, 2, 0).reshape(
bev_h * bev_w, 1, -1)
@@ -1,6 +1,11 @@
+mmdet==2.28.0
+mmsegmentation==0.30.0
google-cloud-bigquery
motmetrics==1.1.3
einops==0.4.1
-numpy==1.20.0
+numpy==1.23.0
casadi==3.5.5
-pytorch-lightning==1.2.5
\ No newline at end of file
+pytorch-lightning==1.2.5
+torchmetrics==0.11.4
+numba==0.58.1
+IPython
\ No newline at end of file
new file mode 100644
@@ -0,0 +1,260 @@
+from __future__ import division
+
+import argparse
+import cv2
+import torch
+import sklearn
+import mmcv
+import copy
+import os
+import time
+import warnings
+from mx_driving.patcher import default_patcher_builder
+from mmcv import Config, DictAction
+from mmcv.runner import get_dist_info, init_dist
+from os import path as osp
+
+from mmdet import __version__ as mmdet_version
+from mmdet3d import __version__ as mmdet3d_version
+
+from mmdet3d.datasets import build_dataset
+from mmdet3d.models import build_model
+from mmdet3d.utils import collect_env, get_root_logger
+from mmdet.apis import set_random_seed
+from mmseg import __version__ as mmseg_version
+
+warnings.filterwarnings("ignore")
+
+from mmcv.utils import TORCH_VERSION, digit_version
+
+import torch_npu
+from torch_npu.contrib import transfer_to_npu
+
+torch.npu.config.allow_internal_format = False
+
+def parse_args():
+    """Parse CLI arguments, default LOCAL_RANK from --local_rank, and fold --options into --cfg-options."""
+    parser = argparse.ArgumentParser(description='Train a detector')
+    parser.add_argument('config', help='train config file path')
+    parser.add_argument('--work-dir', help='the dir to save logs and models')
+    parser.add_argument('--resume-from', help='the checkpoint file to resume from')
+    parser.add_argument(
+        '--no-validate',
+        action='store_true',
+        help='whether not to evaluate the checkpoint during training')
+    group_gpus = parser.add_mutually_exclusive_group()
+    group_gpus.add_argument(
+        '--gpus',
+        type=int,
+        help='number of gpus to use '
+        '(only applicable to non-distributed training)')
+    group_gpus.add_argument(
+        '--gpu-ids',
+        type=int,
+        nargs='+',
+        help='ids of gpus to use '
+        '(only applicable to non-distributed training)')
+    parser.add_argument('--seed', type=int, default=0, help='random seed')
+    parser.add_argument(
+        '--deterministic',
+        action='store_true',
+        help='whether to set deterministic options for CUDNN backend.')
+    parser.add_argument(
+        '--options',
+        nargs='+',
+        action=DictAction,
+        help='override some settings in the used config, the key-value pair '
+        'in xxx=yyy format will be merged into config file (deprecate), '
+        'change to --cfg-options instead.')
+    parser.add_argument(
+        '--cfg-options',
+        nargs='+',
+        action=DictAction,
+        help='override some settings in the used config, the key-value pair '
+        'in xxx=yyy format will be merged into config file. If the value to '
+        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
+        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
+        'Note that the quotation marks are necessary and that no white space '
+        'is allowed.')
+    parser.add_argument(
+        '--launcher',
+        choices=['none', 'pytorch', 'slurm', 'mpi'],
+        default='none',
+        help='job launcher')
+    parser.add_argument('--local_rank', type=int, default=0)
+    parser.add_argument(
+        '--autoscale-lr',
+        action='store_true',
+        help='automatically scale lr with the number of gpus')
+    args = parser.parse_args()
+    if 'LOCAL_RANK' not in os.environ:
+        os.environ['LOCAL_RANK'] = str(args.local_rank)
+    # --options is the deprecated spelling of --cfg-options: reject both together, otherwise alias it.
+    if args.options and args.cfg_options:
+        raise ValueError(
+            '--options and --cfg-options cannot be both specified, '
+            '--options is deprecated in favor of --cfg-options')
+    if args.options:
+        warnings.warn('--options is deprecated in favor of --cfg-options')  # NOTE(review): likely muted by the module-level filterwarnings("ignore")
+        args.cfg_options = args.options
+
+    return args
+
+
+def main():
+    """Train the configured model: load config, set up plugin/dist/logging, build model and data, train.
+
+    Takes no arguments (everything comes from parse_args()) and returns None.
+    """
+    args = parse_args()
+
+    cfg = Config.fromfile(args.config)
+    if args.cfg_options is not None:
+        cfg.merge_from_dict(args.cfg_options)
+    # import modules from string list.
+    if cfg.get('custom_imports', None):
+        from mmcv.utils import import_modules_from_strings
+        import_modules_from_strings(**cfg['custom_imports'])
+
+    # import modules from plugin/xx, registry will be updated
+    if hasattr(cfg, 'plugin') and cfg.plugin:
+        import importlib
+        # Import the directory of cfg.plugin_dir when it is configured,
+        # otherwise the directory that holds the config file.
+        plugin_root = cfg.plugin_dir if hasattr(cfg, 'plugin_dir') else args.config
+        # Turn the directory path into a dotted module path: 'a/b' -> 'a.b'.
+        _module_path = os.path.dirname(plugin_root).replace('/', '.')
+        print(_module_path)
+        importlib.import_module(_module_path)
+
+    # Bound at function level so the custom_train_model call at the end always
+    # has a target, whatever the plugin flag is.
+    from projects.mmdet3d_plugin.uniad.apis.train import custom_train_model
+    # set cudnn_benchmark
+    if cfg.get('cudnn_benchmark', False):
+        torch.backends.cudnn.benchmark = True
+
+    # work_dir is determined in this priority: CLI > segment in file > filename
+    if args.work_dir is not None:
+        # update configs according to CLI args if args.work_dir is not None
+        cfg.work_dir = args.work_dir
+    elif cfg.get('work_dir', None) is None:
+        # use config filename as default work_dir if cfg.work_dir is None
+        cfg.work_dir = osp.join('./work_dirs',
+                                osp.splitext(osp.basename(args.config))[0])
+    # Resume only from an existing checkpoint file; a missing path is skipped
+    # here and reported below once the logger exists.
+    resume_missing = args.resume_from is not None and not osp.isfile(args.resume_from)
+    if args.resume_from is not None and not resume_missing:
+        cfg.resume_from = args.resume_from
+    if args.gpu_ids is not None:
+        cfg.gpu_ids = args.gpu_ids
+    else:
+        cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)
+    if digit_version(TORCH_VERSION) == digit_version('1.8.1') and cfg.optimizer['type'] == 'AdamW':
+        cfg.optimizer['type'] = 'AdamW2'  # fix bug in Adamw
+
+    # init distributed env first, since logger depends on the dist info.
+    if args.launcher == 'none':
+        distributed = False
+    else:
+        distributed = True
+        init_dist(args.launcher, **cfg.dist_params)
+        # re-set gpu_ids with distributed training mode
+        _, world_size = get_dist_info()
+        cfg.gpu_ids = range(world_size)
+
+    # Scale lr only after gpu_ids is final: in distributed mode it is reset to
+    # the world size just above, so scaling earlier would use the pre-reset value.
+    if args.autoscale_lr:
+        cfg.optimizer['lr'] = cfg.optimizer['lr'] * len(cfg.gpu_ids) / 8
+
+    # create work_dir
+    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
+    # dump config
+    cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
+    # init the logger before other steps
+    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
+    log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
+    # specify logger name, if we still use 'mmdet', the output info will be
+    # filtered and won't be saved in the log_file
+    # TODO: ugly workaround to judge whether we are training det or seg model
+    if cfg.model.type in ['EncoderDecoder3D']:
+        logger_name = 'mmseg'
+    else:
+        logger_name = 'mmdet'
+    logger = get_root_logger(
+        log_file=log_file, log_level=cfg.log_level, name=logger_name)
+    if resume_missing:
+        logger.warning(f'--resume-from {args.resume_from} is not a file; ignoring it')
+
+    # init the meta dict to record some important information such as
+    # environment info and seed, which will be logged
+    meta = dict()
+    # log env info
+    env_info_dict = collect_env()
+    env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
+    dash_line = '-' * 60 + '\n'
+    logger.info('Environment info:\n' + dash_line + env_info + '\n' +
+                dash_line)
+    meta['env_info'] = env_info
+    meta['config'] = cfg.pretty_text
+
+    # log some basic info
+    logger.info(f'Distributed training: {distributed}')
+    logger.info(f'Config:\n{cfg.pretty_text}')
+
+    # set random seeds
+    if args.seed is not None:
+        logger.info(f'Set random seed to {args.seed}, '
+                    f'deterministic: {args.deterministic}')
+        set_random_seed(args.seed, deterministic=args.deterministic)
+    cfg.seed = args.seed
+    meta['seed'] = args.seed
+    meta['exp_name'] = osp.basename(args.config)
+
+    model = build_model(
+        cfg.model,
+        train_cfg=cfg.get('train_cfg'),
+        test_cfg=cfg.get('test_cfg'))
+    model.init_weights()
+
+    logger.info(f'Model:\n{model}')
+    datasets = [build_dataset(cfg.data.train)]
+    if len(cfg.workflow) == 2:
+        val_dataset = copy.deepcopy(cfg.data.val)
+        # in case we use a dataset wrapper
+        if 'dataset' in cfg.data.train:
+            val_dataset.pipeline = cfg.data.train.dataset.pipeline
+        else:
+            val_dataset.pipeline = cfg.data.train.pipeline
+        # set test_mode=False here in deep copied config
+        # which do not affect AP/AR calculation later
+        val_dataset.test_mode = False
+        datasets.append(build_dataset(val_dataset))
+    if cfg.checkpoint_config is not None:
+        # save mmdet version, config file content and class names in
+        # checkpoints as meta data
+        cfg.checkpoint_config.meta = dict(
+            mmdet_version=mmdet_version,
+            mmseg_version=mmseg_version,
+            mmdet3d_version=mmdet3d_version,
+            config=cfg.pretty_text,
+            CLASSES=datasets[0].CLASSES,
+            PALETTE=datasets[0].PALETTE  # for segmentors
+            if hasattr(datasets[0], 'PALETTE') else None)
+    # add an attribute for visualization convenience
+    model.CLASSES = datasets[0].CLASSES
+    custom_train_model(
+        model,
+        datasets,
+        cfg,
+        distributed=distributed,
+        validate=(not args.no_validate),
+        timestamp=timestamp,
+        meta=meta)
+
+
+if __name__ == '__main__':
+    with default_patcher_builder.disable_patches("index").brake_at(500).build():  # NOTE(review): presumably stops the run at step 500 for perf measurement - confirm against the mx_driving patcher docs
+        main()
@@ -7,7 +7,7 @@ import os
import warnings
from mmcv import Config, DictAction
from mmcv.cnn import fuse_conv_bn
-from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
+from mmcv.device.npu import NPUDataParallel, NPUDistributedDataParallel
from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
wrap_fp16_model)
@@ -224,7 +224,7 @@ def main():
# model = MMDataParallel(model, device_ids=[0])
# outputs = single_gpu_test(model, data_loader, args.show, args.show_dir)
else:
- model = MMDistributedDataParallel(
+ model = NPUDistributedDataParallel(
model.cuda(),
device_ids=[torch.cuda.current_device()],
broadcast_buffers=False)
@@ -9,6 +9,7 @@ import copy
import os
import time
import warnings
+from mx_driving.patcher import default_patcher_builder
from mmcv import Config, DictAction
from mmcv.runner import get_dist_info, init_dist
from os import path as osp
@@ -26,6 +27,10 @@ warnings.filterwarnings("ignore")
from mmcv.utils import TORCH_VERSION, digit_version
+import torch_npu
+from torch_npu.contrib import transfer_to_npu
+
+torch.npu.config.allow_internal_format = False
def parse_args():
parser = argparse.ArgumentParser(description='Train a detector')
@@ -253,4 +258,5 @@ def main():
if __name__ == '__main__':
- main()
+ with default_patcher_builder.disable_patches("index").build():
+ main()
new file mode 100644
@@ -0,0 +1,36 @@
+#!/usr/bin/env bash
+# Usage: <this script> CONFIG GPUS [extra perf.py args...]; launches perf.py through torchrun.
+if [ "$#" -lt 2 ]; then echo "Usage: $0 CONFIG GPUS [extra perf.py args...]" >&2; exit 1; fi
+T=$(date +%m%d%H%M)
+
+# -------------------------------------------------- #
+# Usually you only need to customize these variables #
+CFG=$1                                               #
+GPUS=$2                                              #
+# -------------------------------------------------- #
+GPUS_PER_NODE=$((GPUS < 8 ? GPUS : 8))
+NNODES=$((GPUS / GPUS_PER_NODE))
+
+MASTER_PORT=${MASTER_PORT:-28596}
+MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
+RANK=${RANK:-0}
+
+WORK_DIR=$(echo "${CFG%.*}" | sed -e "s/configs/work_dirs/g")/
+# Intermediate files and logs will be saved to UniAD/projects/work_dirs/
+
+# mkdir -p succeeds when the directory already exists, so no prior -d test is needed.
+mkdir -p "${WORK_DIR}logs"
+
+PYTHONPATH="$(dirname "$0")/..:${PYTHONPATH}" \
+torchrun \
+    --nproc_per_node="${GPUS_PER_NODE}" \
+    --master_addr="${MASTER_ADDR}" \
+    --master_port="${MASTER_PORT}" \
+    --nnodes="${NNODES}" \
+    --node_rank="${RANK}" \
+    "$(dirname "$0")/perf.py" \
+    "$CFG" \
+    --launcher pytorch "${@:3}" \
+    --deterministic \
+    --work-dir "${WORK_DIR}" \
+    2>&1 | tee "${WORK_DIR}logs/train.$T"
\ No newline at end of file
@@ -22,7 +22,7 @@ if [ ! -d ${WORK_DIR}logs ]; then
fi
PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
-python -m torch.distributed.launch \
+torchrun \
--nproc_per_node=${GPUS_PER_NODE} \
--master_addr=${MASTER_ADDR} \
--master_port=${MASTER_PORT} \