import torch
from torch import nn
from utils_pt import outputActivation
from torch_npu.npu.amp import autocast
from torchinfo import summary
def build_grid_adj_matrix(grid_size, device):
    """Build the symmetrically normalized adjacency matrix of a 2-D grid graph.

    Every grid cell is a node (row-major index ``i * W + j``) connected to its
    up/down/left/right neighbours; there are no self loops.  The result is
    ``D^-1/2 @ A @ D^-1/2`` where ``D`` is the diagonal degree matrix.
    For example, a 25x5 grid yields a 125x125 matrix.

    Args:
        grid_size: Pair ``(H, W)`` giving the number of grid rows and columns.
        device: Torch device on which the matrix is created.

    Returns:
        A contiguous float tensor of shape ``(H * W, H * W)``.
    """
    H, W = grid_size
    num_nodes = H * W
    node_ids = torch.arange(num_nodes, device=device).view(H, W)
    A = torch.zeros((num_nodes, num_nodes), device=device)

    # Vertical edges: each cell with the cell directly below it.
    upper = node_ids[:-1, :].reshape(-1)
    lower = node_ids[1:, :].reshape(-1)
    A[upper, lower] = 1
    A[lower, upper] = 1

    # Horizontal edges: each cell with the cell directly to its right.
    left = node_ids[:, :-1].reshape(-1)
    right = node_ids[:, 1:].reshape(-1)
    A[left, right] = 1
    A[right, left] = 1

    # Take the inverse square root of the degrees themselves.  Adding an
    # epsilon to the dense degree matrix and inverting it would perturb the
    # off-diagonal zeros and leak weight between non-adjacent nodes.
    degree = A.sum(dim=1)
    inv_sqrt_degree = torch.where(
        degree > 0,
        degree.clamp(min=1.0).rsqrt(),
        torch.zeros_like(degree),  # isolated node (e.g. 1x1 grid): no scaling
    )
    A_norm = inv_sqrt_degree.unsqueeze(1) * A * inv_sqrt_degree.unsqueeze(0)
    return A_norm.contiguous()
class ConvBNReLU(nn.Sequential):
    """2-D convolution followed by batch normalization and an in-place ReLU.

    The three layers are registered positionally, in that order, as
    sub-modules ``0``, ``1`` and ``2`` of the sequential container.

    Args:
        in_channels: Number of input channels of the convolution.
        out_channels: Number of output channels (also the BatchNorm width).
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Convolution zero-padding.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        norm = nn.BatchNorm2d(out_channels)
        activation = nn.ReLU(inplace=True)
        super().__init__(conv, norm, activation)
class AttentionLayer(nn.Module):
    """Multi-head scaled dot-product attention over the second-to-last axis.

    The last axis is the feature axis (``model_dim``); attention is computed
    along axis ``-2``, so permute the input beforehand if another axis should
    be attended over.  For an input of shape
    ``(batch_size, in_steps, num_nodes, model_dim)`` attention runs across
    the nodes.

    Query (target) and key/value (source) lengths may differ, but key and
    value must share the same length.

    Heads are formed by slicing the feature axis into chunks of
    ``model_dim // num_heads`` features and stacking the chunks along the
    batch axis.

    Args:
        model_dim: Size of the feature axis.
        num_heads: Number of heads used to derive the per-head width.
        mask: If true, apply a lower-triangular (causal) mask to the scores.
    """

    def __init__(self, model_dim, num_heads=8, mask=False):
        super().__init__()

        self.model_dim = model_dim
        self.num_heads = num_heads
        self.mask = mask
        self.head_dim = model_dim // num_heads

        self.FC_Q = nn.Linear(model_dim, model_dim)
        self.FC_K = nn.Linear(model_dim, model_dim)
        self.FC_V = nn.Linear(model_dim, model_dim)
        self.out_proj = nn.Linear(model_dim, model_dim)

    def _split_heads(self, x):
        """Slice the feature axis into heads and stack them on the batch axis."""
        return torch.cat(torch.split(x, self.head_dim, dim=-1), dim=0)

    def forward(self, query, key, value):
        """Attend ``query`` over ``key``/``value``.

        Returns a tensor with the shape of ``query``.
        """
        batch_size = query.shape[0]
        tgt_length, src_length = query.shape[-2], key.shape[-2]

        q = self._split_heads(self.FC_Q(query))
        k = self._split_heads(self.FC_K(key))
        v = self._split_heads(self.FC_V(value))

        # (..., tgt_length, head_dim) @ (..., head_dim, src_length)
        scores = (q @ k.transpose(-1, -2)) / self.head_dim**0.5

        if self.mask:
            # Position t may only attend to source positions <= t.
            causal = torch.ones(
                tgt_length, src_length, dtype=torch.bool, device=q.device
            ).tril()
            scores.masked_fill_(~causal, -torch.inf)

        weights = torch.softmax(scores, dim=-1)
        per_head = weights @ v

        # Undo the head stacking: batch-axis chunks go back onto the feature axis.
        merged = torch.cat(torch.split(per_head, batch_size, dim=0), dim=-1)
        return self.out_proj(merged)
class SelfAttentionLayer(nn.Module):
    """Self-attention block: attention and feed-forward, each with add & norm.

    Both sub-layers are followed by dropout, a residual addition and a
    LayerNorm (post-norm arrangement).  The LayerNorm input is cast to
    float32 and the result is cast back to the dtype of the residual.

    Args:
        model_dim: Size of the feature (last) axis.
        feed_forward_dim: Hidden width of the two-layer feed-forward network.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied after each sub-layer.
        mask: Forwarded to ``AttentionLayer`` to enable causal masking.
    """

    def __init__(
        self, model_dim, feed_forward_dim=2048, num_heads=8, dropout=0, mask=False
    ):
        super().__init__()

        self.attn = AttentionLayer(model_dim, num_heads, mask)
        self.feed_forward = nn.Sequential(
            nn.Linear(model_dim, feed_forward_dim),
            nn.ReLU(inplace=True),
            nn.Linear(feed_forward_dim, model_dim),
        )
        self.ln1 = nn.LayerNorm(model_dim)
        self.ln2 = nn.LayerNorm(model_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    @staticmethod
    def _add_and_norm(residual, update, norm):
        """Return ``norm(residual + update)`` computed in fp32, in residual's dtype."""
        summed = (residual + update).to(torch.float32)
        return norm(summed).to(residual.dtype)

    def forward(self, x, dim=-2):
        """Apply the block, attending along axis ``dim`` of ``x``.

        ``dim`` is swapped with axis ``-2`` for the computation and swapped
        back before returning, so the output has the same shape as ``x``.
        """
        x = x.transpose(dim, -2)

        attended = self.dropout1(self.attn(x, x, x))
        x = self._add_and_norm(x, attended, self.ln1)

        transformed = self.dropout2(self.feed_forward(x))
        x = self._add_and_norm(x, transformed, self.ln2)

        return x.transpose(dim, -2)
class pipNet(nn.Module):
    """Planning-informed trajectory prediction network.

    Pipeline implemented by ``forward``:

    1. Target histories: Conv1d embedding -> temporal self-attention -> LSTM
       -> linear "dynamics" encoding.
    2. Neighbour histories: Conv1d embedding -> LSTM, scattered onto a grid
       through ``nbsMask``; grid cells are mixed through a normalized grid
       adjacency (blend weight ``sigmoid(graph_alpha)``) and passed through
       the convolutional social-pooling stack.
    3. Optionally (``use_planning``) the planned trajectory is encoded the
       same way and merged with the neighbour grid.
    4. Optionally (``use_fusion``) a small fully convolutional network fuses
       the per-target encodings placed on the grid by ``targsEncMask``.
    5. Lateral/longitudinal maneuver classification, then an LSTM decoder
       producing 5 output values per future step (post-processed by
       ``outputActivation``).

    Args:
        args: Configuration object.  The attributes read are ``use_cuda``,
            ``train_output_flag``, ``use_planning``, ``use_fusion``,
            ``grid_size``, ``in_length``, ``out_length``,
            ``num_lat_classes``, ``num_lon_classes``,
            ``temporal_embedding_size``, ``encoder_size``, ``decoder_size``,
            ``soc_conv_depth``, ``soc_conv2_depth``,
            ``dynamics_encoding_size``, ``social_context_size``,
            ``fuse_enc_size``, ``num_layers``, ``feed_forward_dim``,
            ``num_heads`` and ``dropout``.

    NOTE(review): the extent of each ``with autocast(...)`` scope below was
    reconstructed; confirm against the training setup that every statement
    sits in the intended precision scope.
    """

    def __init__(self, args):
        super(pipNet, self).__init__()
        self.args = args
        self.use_cuda = args.use_cuda

        # Flags controlling the output form and the optional sub-networks.
        self.train_output_flag = args.train_output_flag
        self.use_planning = args.use_planning
        self.use_fusion = args.use_fusion

        # Input / output dimensions.
        self.grid_size = args.grid_size
        self.in_length = args.in_length
        self.out_length = args.out_length
        self.num_lat_classes = args.num_lat_classes
        self.num_lon_classes = args.num_lon_classes

        # Layer sizes.
        self.temporal_embedding_size = args.temporal_embedding_size
        self.encoder_size = args.encoder_size
        self.decoder_size = args.decoder_size
        self.soc_conv_depth = args.soc_conv_depth
        self.soc_conv2_depth = args.soc_conv2_depth
        self.dynamics_encoding_size = args.dynamics_encoding_size
        self.social_context_size = args.social_context_size
        self.targ_enc_size = self.social_context_size + self.dynamics_encoding_size
        self.fuse_enc_size = args.fuse_enc_size
        self.fuse_conv1_size = 2 * self.fuse_enc_size
        self.fuse_conv2_size = 4 * self.fuse_enc_size

        # Learnable blend between raw and adjacency-mixed neighbour grids;
        # forward() squashes it with a sigmoid (sigmoid(-0.847) is about 0.3).
        self.graph_alpha = nn.Parameter(torch.tensor(-0.847))

        # Self-attention hyper-parameters.
        self.num_layers = args.num_layers
        self.feed_forward_dim = args.feed_forward_dim
        self.num_heads = args.num_heads
        self.dropout = args.dropout

        # Activations.
        self.leaky_relu = nn.LeakyReLU(0.1)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)

        # Convert (x, y) trajectories to a temporal embedding.
        self.temporalConv = nn.Conv1d(
            in_channels=2,
            out_channels=self.temporal_embedding_size,
            kernel_size=3,
            padding=1,
        )

        # Encode the temporal embedding (shared by targets and neighbours).
        self.nbh_lstm = nn.LSTM(
            input_size=self.temporal_embedding_size,
            hidden_size=self.encoder_size,
            num_layers=1,
        )
        if self.use_planning:
            self.plan_lstm = nn.LSTM(
                input_size=self.temporal_embedding_size,
                hidden_size=self.encoder_size,
                num_layers=1,
            )

        self.attn_layers_t = nn.ModuleList(
            [
                SelfAttentionLayer(
                    self.temporal_embedding_size,
                    self.feed_forward_dim,
                    self.num_heads,
                    self.dropout,
                )
                for _ in range(self.num_layers)
            ]
        )
        # Not used by forward(); kept so existing checkpoints still load.
        self.attn_layers_s = nn.ModuleList(
            [
                SelfAttentionLayer(
                    self.temporal_embedding_size,
                    self.feed_forward_dim,
                    self.num_heads,
                    self.dropout,
                )
                for _ in range(self.num_layers)
            ]
        )

        # Project the encoded dynamics to dynamics_encoding_size.
        self.dyn_emb = nn.Linear(self.encoder_size, self.dynamics_encoding_size)

        # Convolutional social pooling on the planned vehicle and all neighbours.
        self.nbrs_conv_social = nn.Sequential(
            nn.Conv2d(self.encoder_size, self.soc_conv_depth, 3),
            self.leaky_relu,
            nn.MaxPool2d((3, 3), stride=2),
            nn.Conv2d(self.soc_conv_depth, self.soc_conv2_depth, (3, 1)),
            self.leaky_relu,
        )
        if self.use_planning:
            self.plan_conv_social = nn.Sequential(
                nn.Conv2d(self.encoder_size, self.soc_conv_depth, 3),
                self.leaky_relu,
                nn.MaxPool2d((3, 3), stride=2),
                nn.Conv2d(self.soc_conv_depth, self.soc_conv2_depth, (3, 1)),
                self.leaky_relu,
            )
            self.pool_after_merge = nn.MaxPool2d((2, 2), padding=(1, 0))
        else:
            self.pool_after_merge = nn.MaxPool2d((2, 1), padding=(1, 0))

        # Target fusion module.
        if self.use_fusion:
            self.fcn_conv1 = ConvBNReLU(
                self.targ_enc_size, self.fuse_conv1_size, kernel_size=3, stride=1, padding=1
            )
            self.fcn_pool1 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)
            self.fcn_conv2 = ConvBNReLU(
                self.fuse_conv1_size, self.fuse_conv2_size, kernel_size=3, stride=1, padding=1
            )
            self.fcn_pool2 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)
            self.fcn_convTrans1 = nn.ConvTranspose2d(
                self.fuse_conv2_size, self.fuse_conv1_size, kernel_size=3, stride=2, padding=1
            )
            self.back_bn1 = nn.BatchNorm2d(self.fuse_conv1_size)
            self.fcn_convTrans2 = nn.ConvTranspose2d(
                self.fuse_conv1_size, self.fuse_enc_size, kernel_size=3, stride=2, padding=1
            )
            self.back_bn2 = nn.BatchNorm2d(self.fuse_enc_size)
        else:
            # No fused features are appended to the target encoding.
            self.fuse_enc_size = 0

        # Maneuver heads and decoder LSTM.
        enc_width = self.targ_enc_size + self.fuse_enc_size
        self.op_lat = nn.Linear(enc_width, self.num_lat_classes)
        self.op_lon = nn.Linear(enc_width, self.num_lon_classes)
        self.dec_lstm = nn.LSTM(
            input_size=enc_width + self.num_lat_classes + self.num_lon_classes,
            hidden_size=self.decoder_size,
        )

        # Output layer: 5 values per future time step.
        self.op = nn.Linear(self.decoder_size, 5)

        # forward() uses the adjacency unconditionally, so it is always built.
        # A non-persistent buffer follows .to()/.npu() with the rest of the
        # module and is left out of state_dict, so checkpoints are unaffected.
        adjacency = build_grid_adj_matrix(self.grid_size, torch.device('cpu')).half()
        self.register_buffer('A_norm', adjacency, persistent=False)

    def forward(self, nbsHist, nbsMask, planFut, planMask, targsHist, targsEncMask,
                lat_enc, lon_enc, idx, space_h=None, dv=None, v_pre=None):
        """Predict future trajectories and maneuver probabilities.

        Args:
            nbsHist: Neighbour histories; permuted with ``(1, 2, 0)`` before
                the 2-channel Conv1d.
            nbsMask: Mask used to scatter neighbour encodings onto the grid.
            planFut: Planned trajectory (read only when ``use_planning``).
            planMask: Mask used to scatter the plan encoding onto the grid.
            targsHist: Target histories, permuted like ``nbsHist``.
            targsEncMask: Mask used to scatter target encodings onto the grid.
            lat_enc: Lateral maneuver encoding appended to the decoder input.
            lon_enc: Longitudinal maneuver encoding appended to the decoder input.
            idx, space_h, dv, v_pre: Accepted but not used here.

        Returns:
            ``(fut_pred, lat_pred, lon_pred)``.  With ``train_output_flag``
            set, ``fut_pred`` is a single decoded tensor; otherwise it is a
            list with one decoded tensor per (longitudinal, lateral) class
            pair, longitudinal index varying slowest.
        """
        # ---- Target vehicles' dynamics (full precision) ----
        with autocast(enabled=False):
            dyn_enc = self.leaky_relu(self.temporalConv(targsHist.permute(1, 2, 0)))
            dyn_enc = dyn_enc.permute(0, 2, 1)
            for attn in self.attn_layers_t:
                dyn_enc = attn(dyn_enc, dim=1)  # attend along axis 1
            dyn_enc = dyn_enc.half()
        with autocast(enabled=False):
            _, (dyn_enc, _) = self.nbh_lstm(dyn_enc.permute(1, 0, 2).float())
            # Single-layer LSTM: drop the leading num_layers axis.
            dyn_enc = self.leaky_relu(
                self.dyn_emb(dyn_enc.view(dyn_enc.shape[1], dyn_enc.shape[2]))
            )
        dyn_enc = dyn_enc.to(dtype=torch.float16, device=dyn_enc.device)

        # ---- Neighbour vehicles ----
        with autocast():
            nbrs_enc = self.leaky_relu(self.temporalConv(nbsHist.permute(1, 2, 0)))
        with autocast(enabled=False):
            _, (nbrs_enc, _) = self.nbh_lstm(nbrs_enc.permute(2, 0, 1).float())
        nbrs_enc = nbrs_enc.view(nbrs_enc.shape[1], nbrs_enc.shape[2])
        nbrs_enc = nbrs_enc.to(dtype=torch.float16, device=nbrs_enc.device)

        # Scatter the neighbour encodings onto the masked grid.
        nbrs_grid_static = torch.zeros_like(nbsMask, dtype=torch.float16, device=nbsMask.device)
        nbrs_grid_static = nbrs_grid_static.masked_scatter(nbsMask.bool(), nbrs_enc)
        nbrs_grid = nbrs_grid_static.permute(0, 3, 2, 1)

        # Mix grid cells through the normalized adjacency (one propagation step).
        BS, C_enc, H, W = nbrs_grid.shape
        num_nodes = H * W
        nbrs_node_feat = nbrs_grid.reshape(BS, C_enc, num_nodes).permute(0, 2, 1)
        nbrs_node_feat = nbrs_node_feat.half()
        with autocast(enabled=False):
            A_norm_batch = self.A_norm.expand(BS, -1, -1)
            nbrs_node_feat_fused = torch.bmm(
                A_norm_batch.float(),
                nbrs_node_feat.float(),
            )
            nbrs_grid_fused = nbrs_node_feat_fused.permute(0, 2, 1).reshape(BS, C_enc, H, W)
            # Learned interpolation between the raw and the mixed grid.
            alpha = torch.sigmoid(self.graph_alpha)
            nbrs_grid_fused = nbrs_grid + alpha * (nbrs_grid_fused - nbrs_grid)
        with autocast():
            nbrs_grid_fused = self.nbrs_conv_social[0](nbrs_grid_fused)
            nbrs_grid = self.nbrs_conv_social[1:](nbrs_grid_fused)

        if self.use_planning:
            # ---- Planned vehicle ----
            with autocast():
                plan_enc = self.leaky_relu(self.temporalConv(planFut.permute(1, 2, 0)))
            with autocast(enabled=False):
                _, (plan_enc, _) = self.plan_lstm(plan_enc.permute(2, 0, 1).float())
            plan_enc = plan_enc.view(plan_enc.shape[1], plan_enc.shape[2])
            plan_enc = plan_enc.to(dtype=torch.float16, device=plan_enc.device)

            # Scatter the plan encoding onto the masked grid.
            plan_grid = torch.zeros_like(planMask, dtype=torch.float16, device=planMask.device)
            plan_grid = plan_grid.masked_scatter(planMask.bool(), plan_enc)
            plan_grid = plan_grid.permute(0, 3, 2, 1)
            with autocast():
                plan_grid = self.plan_conv_social(plan_grid)

            # Merge neighbour and planned-vehicle grids.
            merge_grid = torch.cat((nbrs_grid, plan_grid), dim=3)
            merge_grid = self.pool_after_merge(merge_grid)
        else:
            merge_grid = self.pool_after_merge(nbrs_grid)
        social_context = merge_grid.view(-1, self.social_context_size)

        # Concatenate the social context (neighbours + ego planning) with the
        # dynamics encoding, then place the result into the targsEncMask grid.
        target_enc = torch.cat((social_context, dyn_enc), 1)
        target_grid = torch.zeros_like(
            targsEncMask, dtype=torch.float16, device=targsEncMask.device
        )
        target_grid = target_grid.masked_scatter(targsEncMask.bool(), target_enc)

        if self.use_fusion:
            # Fully convolutional network producing the grid to be fused.
            with autocast():
                fuse_conv1 = self.fcn_conv1(target_grid.permute(0, 3, 2, 1))
                fuse_conv1 = self.fcn_pool1(fuse_conv1)
                fuse_conv2 = self.fcn_conv2(fuse_conv1)
                fuse_conv2 = self.fcn_pool2(fuse_conv2)
                fuse_trans1 = self.relu(self.fcn_convTrans1(fuse_conv2))
                fuse_trans1 = self.back_bn1(fuse_trans1)
                fuse_trans2 = self.relu(self.fcn_convTrans2(fuse_trans1))
                fuse_trans2 = self.back_bn2(fuse_trans2)

            fuse_grid_mask = targsEncMask[:, :, :, 0:self.fuse_enc_size]
            fuse_grid = torch.zeros_like(
                fuse_grid_mask, dtype=torch.float16, device=fuse_grid_mask.device
            )
            fuse_grid = fuse_grid.masked_scatter(
                fuse_grid_mask.bool(), fuse_trans2.permute(0, 3, 2, 1)
            )

            # Keep only grid cells flagged by the first mask channel.  A
            # boolean mask selects the same rows, in the same order, as a
            # Python loop over the mask, without per-element device reads.
            occupied = targsEncMask[:, :, :, 0].reshape(-1).bool()
            enc = torch.cat([target_grid, fuse_grid], dim=3)
            enc = enc.view(-1, self.fuse_enc_size + self.targ_enc_size)
            enc = enc[occupied]
        else:
            enc = target_enc

        # ---- Maneuver recognition ----
        with autocast():
            lat_pred = self.softmax(self.op_lat(enc))
            lon_pred = self.softmax(self.op_lon(enc))

        if self.train_output_flag:
            # Decode once, conditioned on the supplied maneuver encodings.
            enc = torch.cat((enc, lat_enc, lon_enc), 1)
            fut_pred = self.decode(enc)
            return fut_pred, lat_pred, lon_pred

        # Otherwise decode once per (longitudinal, lateral) maneuver pair,
        # each conditioned on the corresponding one-hot encodings.
        fut_pred = []
        for lon_idx in range(self.num_lon_classes):
            for lat_idx in range(self.num_lat_classes):
                lat_enc_tmp = torch.zeros_like(lat_enc)
                lon_enc_tmp = torch.zeros_like(lon_enc)
                lat_enc_tmp[:, lat_idx] = 1
                lon_enc_tmp[:, lon_idx] = 1
                enc_tmp = torch.cat((enc, lat_enc_tmp, lon_enc_tmp), 1)
                fut_pred.append(self.decode(enc_tmp))
        return fut_pred, lat_pred, lon_pred

    def decode(self, enc):
        """Decode an encoding into ``out_length`` future steps.

        The encoding is repeated for every output step, run through the
        decoder LSTM in float32, projected to 5 values per step and passed
        through ``outputActivation``.
        """
        with autocast(enabled=False):
            enc = enc.repeat(self.out_length, 1, 1).float()
            h_dec, _ = self.dec_lstm(enc)
        with autocast():
            h_dec = h_dec.permute(1, 0, 2)
            fut_pred = self.op(h_dec)
            fut_pred = fut_pred.permute(1, 0, 2)
            # NOTE(review): assumed to run inside the autocast scope — confirm.
            fut_pred = outputActivation(fut_pred)
        return fut_pred