import math
import logging
import numpy as np
import torch
from scipy.optimize import curve_fit
def weights_init(m):
    """Initialise a module in place; intended for use with ``nn.Module.apply``.

    Only ``torch.nn.Conv2d`` layers are touched: their weight gets Xavier-uniform
    initialisation and their bias (if present) is set to the constant 0.1.
    Every other module type is left unchanged.
    """
    if not isinstance(m, torch.nn.Conv2d):
        return
    torch.nn.init.xavier_uniform_(m.weight)
    if m.bias is None:
        return
    torch.nn.init.constant_(m.bias, 0.1)
def initLogging(log_file: str, level: str = "INFO"):
    """Configure the root logger to append to *log_file* and mirror records to the console.

    Args:
        log_file: path of the log file (opened in append mode by ``basicConfig``).
        level: name of a ``logging`` level such as ``"INFO"`` or ``"debug"``;
            matched case-insensitively. Unknown names leave the root level unchanged.

    Note: ``logging.basicConfig`` is a no-op when the root logger already has
    handlers, so only the first effective call sets up the file handler.
    """
    fmt = '[%(levelname)s %(asctime)s] %(message)s'
    datefmt = '%m-%d %H:%M:%S'
    # Accept lower/mixed-case names; getattr(logging, "debug") would be None and
    # silently leave the logger at its default level.
    level_value = getattr(logging, str(level).upper(), None)
    if not isinstance(level_value, int):
        level_value = None  # basicConfig ignores level=None
    logging.basicConfig(filename=log_file, filemode='a',
                        level=level_value,
                        format=fmt,
                        datefmt=datefmt)
    root = logging.getLogger()
    # Add a console handler only once (exact type check: FileHandler subclasses
    # StreamHandler), and give it the same format as the file output.
    if not any(type(h) is logging.StreamHandler for h in root.handlers):
        console = logging.StreamHandler()
        console.setFormatter(logging.Formatter(fmt, datefmt))
        root.addHandler(console)
def quintic_spline(x, z, a, b, c, d, e):
    """Evaluate the quintic polynomial ``z + a*x + b*x**2 + c*x**3 + d*x**4 + e*x**5``.

    Works element-wise on scalars, NumPy arrays or tensors; used as the model
    function for ``curve_fit``.
    """
    total = z + a * x
    # Accumulate higher-order terms left to right, same summation order as the closed form.
    for power, coeff in enumerate((b, c, d, e), start=2):
        total = total + coeff * x ** power
    return total
def fitting_traj_by_qs(x, y):
    """Least-squares fit of a quintic polynomial to the points ``(x, y)``.

    The constant term ``z`` is constrained to ``[y[0], y[0] + 1e-6]``, so the
    fitted curve effectively passes through ``(0, y[0])``; the other five
    coefficients are unbounded.

    Returns:
        ndarray ``(z, a, b, c, d, e)`` of coefficients for ``quintic_spline``.
    """
    # Work in Python float (float64): with float32 data and NumPy 2 promotion
    # rules, y[0] + 1e-6 would stay float32 and round back to y[0].
    y0 = float(y[0])
    upper0 = y0 + 1e-6
    if upper0 <= y0:
        # For very large |y0| the 1e-6 step is below float resolution; curve_fit
        # requires lower < upper strictly, so step to the next representable float.
        upper0 = np.nextafter(y0, np.inf)
    lower = [y0] + [-np.inf] * 5
    upper = [upper0] + [np.inf] * 5
    param, _cov = curve_fit(quintic_spline, x, y, bounds=(lower, upper))
    return param
def outputActivation(x, displacement=True):
    """Map raw network outputs to bivariate-Gaussian parameters.

    ``x`` is indexed as ``x[:, :, c]`` with at least 5 channels; dim 0 is the
    axis accumulated over (presumably time — confirm with callers).

    Output channels (in order): ``muX, muY, exp(x[..., 2]), exp(x[..., 3]), tanh(x[..., 4])``.
    When ``displacement`` is True, channels 0:2 are treated as per-step
    displacements and turned into running sums along dim 0.

    The input tensor is NOT modified (a previous version overwrote ``x`` in
    place, which clobbered the caller's tensor and failed on leaf tensors that
    require grad).
    """
    mu = x[:, :, 0:2]
    if displacement:
        # Prefix sum along dim 0: mu[t] = sum_{i<=t} x[i, :, 0:2].
        mu = torch.cumsum(mu, dim=0)
    sig = torch.exp(x[:, :, 2:4])   # keeps sigX, sigY strictly positive
    rho = torch.tanh(x[:, :, 4:5])  # keeps the correlation in (-1, 1)
    return torch.cat([mu, sig, rho], dim=2)
def maskedNLL(y_pred, y_gt, mask):
    """Masked mean negative log-likelihood of a bivariate Gaussian prediction.

    Args:
        y_pred: rank-3 tensor; channels ``[muX, muY, sigX, sigY, rho]``. The formula
            multiplies ``(x - muX)**2`` by ``sigX**2``, so sigX/sigY act as inverse
            standard deviations.
        y_gt: rank-3 tensor whose channels 0 and 1 are the ground-truth x and y.
        mask: rank-3 tensor with at least 2 channels; channels 0 and 1 weight the
            per-element NLL. Bool/integer masks are accepted and converted to the
            NLL's dtype (previously they truncated the loss itself to bool/int).

    Returns:
        Scalar ``sum(nll * mask) / (sum(mask) + 1e-8)``, where each element's NLL
        is counted once per masked channel.
    """
    mu_x = y_pred[:, :, 0]
    mu_y = y_pred[:, :, 1]
    sig_x = y_pred[:, :, 2]
    sig_y = y_pred[:, :, 3]
    rho = y_pred[:, :, 4]
    ohr = torch.pow(1 - torch.pow(rho, 2), -0.5)  # 1 / sqrt(1 - rho^2)
    dx = y_gt[:, :, 0] - mu_x
    dy = y_gt[:, :, 1] - mu_y
    out = 0.5 * torch.pow(ohr, 2) * \
        (torch.pow(sig_x, 2) * torch.pow(dx, 2) + torch.pow(sig_y, 2) * torch.pow(dy, 2)
         - 2 * rho * sig_x * sig_y * dx * dy) - torch.log(sig_x * sig_y * ohr) \
        + torch.log(torch.tensor(2 * math.pi))
    if not mask.is_floating_point():
        mask = mask.to(out.dtype)
    # Same NLL written into both channels so the normaliser sum(mask) matches.
    acc = torch.zeros_like(mask)
    acc[:, :, 0] = out
    acc[:, :, 1] = out
    acc = acc * mask
    return torch.sum(acc) / (torch.sum(mask) + 1e-8)
def _bivariate_gaussian_nll(y_pred, y_gt):
    """Element-wise bivariate-Gaussian NLL of ``y_gt`` under ``y_pred`` (last channel dim indexed).

    ``y_pred`` channels are ``[muX, muY, sigX, sigY, rho]`` with sigX/sigY used as
    inverse standard deviations (they multiply the squared residuals).
    Returns a tensor shaped like ``y_pred[:, :, 0]``.
    """
    mu_x = y_pred[:, :, 0]
    mu_y = y_pred[:, :, 1]
    sig_x = y_pred[:, :, 2]
    sig_y = y_pred[:, :, 3]
    rho = y_pred[:, :, 4]
    ohr = torch.pow(1 - torch.pow(rho, 2), -0.5)  # 1 / sqrt(1 - rho^2)
    dx = y_gt[:, :, 0] - mu_x
    dy = y_gt[:, :, 1] - mu_y
    quad = (torch.pow(sig_x, 2) * torch.pow(dx, 2)
            + torch.pow(sig_y, 2) * torch.pow(dy, 2)
            - 2 * rho * sig_x * sig_y * dx * dy)
    return (0.5 * torch.pow(ohr, 2) * quad
            - torch.log(sig_x * sig_y * ohr)
            + torch.log(torch.tensor(2 * math.pi)))


def maskedNLLTest(fut_pred, lat_pred, lon_pred, fut, op_mask,
                  num_lat_classes=3, num_lon_classes=2,
                  use_maneuvers=True, avg_along_time=False, separately=False):
    """Masked NLL for evaluation, optionally as a maneuver-weighted mixture.

    Args:
        fut_pred: if ``use_maneuvers``, an indexable of
            ``num_lon_classes * num_lat_classes`` predictions (index
            ``k * num_lat_classes + l`` for lon class k, lat class l); otherwise a
            single prediction. Each is rank-3 with channels ``[muX, muY, sigX, sigY, rho]``.
        lat_pred, lon_pred: 2-D class probabilities indexed ``[:, class]``; their
            product weights each mixture component.
        fut: rank-3 ground truth, channels 0/1 = x/y.
        op_mask: rank-3 mask; only channel 0 is used.
        avg_along_time: return one scalar ``sum(nll * mask) / (sum(mask) + 1e-8)``.
        separately: return the per-element masked NLL and the mask channel.

    Returns:
        scalar if ``avg_along_time``; else ``(nll, counts)`` summed over dim 1
        (or element-wise if ``separately``).
    """
    valid = op_mask[:, :, 0]
    if use_maneuvers:
        # log p(y | maneuver) + log P(maneuver), one entry per maneuver pair.
        log_joint = []
        for lon in range(num_lon_classes):
            for lat in range(num_lat_classes):
                wts = lat_pred[:, lat] * lon_pred[:, lon]  # broadcasts over dim 0 below
                log_lik = -_bivariate_gaussian_nll(fut_pred[lon * num_lat_classes + lat], fut)
                log_joint.append(log_lik + torch.log(wts))
        # Mixture NLL = -log sum_m exp(log p(y|m) + log P(m)); torch.logsumexp
        # stays finite-safe when every component has weight 0 (-inf).
        acc = -torch.logsumexp(torch.stack(log_joint, dim=2), dim=2)
    else:
        acc = _bivariate_gaussian_nll(fut_pred, fut)
    acc = acc * valid
    if avg_along_time:
        # Epsilon in both branches: an all-zero mask yields 0 instead of NaN.
        return torch.sum(acc) / (torch.sum(valid) + 1e-8)
    if separately:
        return acc, valid
    return torch.sum(acc, dim=1), torch.sum(valid, dim=1)
def maskedMSE(y_pred, y_gt, mask):
    """Masked mean squared Euclidean error between predicted and true (x, y).

    Args:
        y_pred: rank-3 tensor; channels 0/1 are the predicted x/y.
        y_gt: rank-3 tensor; channels 0/1 are the true x/y.
        mask: rank-3 tensor with at least 2 channels (channels 0 and 1 weight the
            error). Bool/integer masks are converted to the error dtype; before,
            they truncated the error itself to bool/int.

    Returns:
        Scalar ``sum(err * mask) / (sum(mask) + 1e-8)``.
    """
    out = torch.pow(y_gt[:, :, 0] - y_pred[:, :, 0], 2) + torch.pow(y_gt[:, :, 1] - y_pred[:, :, 1], 2)
    if not mask.is_floating_point():
        mask = mask.to(out.dtype)
    # Same error in both channels so the normaliser sum(mask) matches.
    acc = torch.zeros_like(mask)
    acc[:, :, 0] = out
    acc[:, :, 1] = out
    acc = acc * mask
    return torch.sum(acc) / (torch.sum(mask) + 1e-8)
def maskedMSETest(y_pred, y_gt, mask, separately=False):
    """Masked squared Euclidean error for evaluation.

    Args:
        y_pred: rank-3 tensor; channels 0/1 are the predicted x/y.
        y_gt: rank-3 tensor; channels 0/1 are the true x/y.
        mask: rank-3 tensor; only channel 0 is used. Bool/integer masks are
            converted to the error dtype.
        separately: if True return the element-wise masked error and mask.

    Returns:
        ``(err, mask[:, :, 0])`` element-wise when ``separately``; otherwise the
        error and the valid count summed over dim 1, with counts clamped to at
        least 1e-8 so callers can divide safely.
    """
    out = torch.pow(y_gt[:, :, 0] - y_pred[:, :, 0], 2) + torch.pow(y_gt[:, :, 1] - y_pred[:, :, 1], 2)
    valid = mask[:, :, 0]
    if not valid.is_floating_point():
        valid = valid.to(out.dtype)
    acc = out * valid
    if separately:
        return acc, valid
    lossVal = torch.sum(acc, dim=1)
    counts = torch.clamp(torch.sum(valid, dim=1), min=1e-8)
    return lossVal, counts
def MAPE(y_true, y_pred, null_val=0):
    """Masked mean absolute percentage error of NumPy arrays, in percent.

    Entries where ``y_true == null_val`` (or where ``y_true`` is NaN when
    ``null_val`` is NaN) are excluded. The 0/1 mask is rescaled to mean 1, so
    the final ``np.mean`` averages over valid entries only. NaNs produced by
    masked divisions are zeroed by ``np.nan_to_num``; divide/invalid floating
    point warnings are suppressed.
    """
    with np.errstate(divide="ignore", invalid="ignore"):
        valid = ~np.isnan(y_true) if np.isnan(null_val) else np.not_equal(y_true, null_val)
        weight = valid.astype("float32")
        weight /= np.mean(weight)
        abs_pct = np.abs(np.divide((y_pred - y_true).astype("float32"), y_true))
        return np.mean(np.nan_to_num(weight * abs_pct)) * 100
def maskedMAPETest(y_pred, y_gt, mask, separately=False, outlier_threshold=0.75):
    """Masked absolute percentage error (fraction, not percent) for evaluation.

    Per element the error is ``|(x - muX) / x| + |(y - muY) / y|`` (1e-8 added to
    the denominators). NaN/inf results are set to 0, and values above
    ``outlier_threshold`` are also set to 0, i.e. discarded as outliers while
    still counted in ``counts``.

    Args:
        y_pred: rank-3 tensor; channels 0/1 are the predicted x/y.
        y_gt: rank-3 tensor; channels 0/1 are the true x/y.
        mask: rank-3 tensor; only channel 0 is used. Bool/integer masks are
            converted to the error dtype (previously they truncated the error).
        separately: if True return element-wise error and mask.
        outlier_threshold: cut-off above which an element's error is zeroed
            (default 0.75, the previously hard-coded value).

    Returns:
        ``(err, mask[:, :, 0])`` element-wise when ``separately``; otherwise both
        summed over dim 1.
    """
    eps = 1e-8
    x = y_gt[:, :, 0]
    y = y_gt[:, :, 1]
    ape = torch.abs((x - y_pred[:, :, 0]) / (x + eps)) + torch.abs((y - y_pred[:, :, 1]) / (y + eps))
    ape = torch.nan_to_num(ape, nan=0.0, posinf=0.0, neginf=0.0)
    ape = torch.where(ape > outlier_threshold, torch.zeros_like(ape), ape)
    valid = mask[:, :, 0]
    if not valid.is_floating_point():
        valid = valid.to(ape.dtype)
    acc = ape * valid
    if separately:
        return acc, valid
    return torch.sum(acc, dim=1), torch.sum(valid, dim=1)
def TTC_test(plan_veh_real, tar_veh_pred, tar_veh_real, tar_count):
    """Share of vehicle pairs flagged by ``TTC_judge`` across all scenes.

    For each scene ``num``, the plan vehicle is compared with each of the scene's
    ``tar_count[num]`` targets, and every unordered pair of targets in the scene
    is compared too (``TTC_judge`` with threshold 3; ``delta_t`` = 0.2).

    Positions are taken at time index 1; speeds are the change in coordinate 1
    between time indices 0 and 1 divided by ``delta_t``. The plan vehicle uses
    ``plan_veh_real``; targets use ``tar_veh_pred[1] + tar_veh_real[0]`` for
    position and ``tar_veh_pred`` for speed. Targets of all scenes are stacked
    along axis 1, scene ``num`` owning the next ``tar_count[num]`` entries.

    Returns:
        ``(ttc_rate, speed_rate)``: fraction of compared pairs judged 1, and
        fraction judged 1 or 2. Both are 0.0 if no pair was compared.
    """
    ttc_threshold = 3
    delta_t = 0.2
    verdicts = []
    offset = 0
    for num, n_targets in enumerate(tar_count):
        n_targets = int(n_targets)
        plan = (
            plan_veh_real[1, num, 0],
            plan_veh_real[1, num, 1],
            (plan_veh_real[1, num, 1] - plan_veh_real[0, num, 1]) / delta_t,
        )
        # (x, y, speed) of each target in this scene, computed once.
        targets = [
            (
                tar_veh_pred[1, idx, 0] + tar_veh_real[0, idx, 0],
                tar_veh_pred[1, idx, 1] + tar_veh_real[0, idx, 1],
                (tar_veh_pred[1, idx, 1] - tar_veh_pred[0, idx, 1]) / delta_t,
            )
            for idx in range(offset, offset + n_targets)
        ]
        # Plan vehicle vs. every target, including the last one in the scene.
        for target in targets:
            verdicts.append(TTC_judge(*plan, *target, ttc_threshold))
        # Every unordered pair of targets within the scene.
        for i, first in enumerate(targets):
            for second in targets[i + 1:]:
                verdicts.append(TTC_judge(*first, *second, ttc_threshold))
        offset += n_targets
    if not verdicts:
        return 0.0, 0.0
    sum_count = len(verdicts)
    ttc_count = sum(1 for verdict in verdicts if verdict == 1)
    speed_count = sum(1 for verdict in verdicts if verdict != 0)
    return ttc_count / sum_count, speed_count / sum_count
def TTC_judge(veh1_x, veh1_y, veh1_speed, veh2_x, veh2_y, veh2_speed, ttc_threshold):
    """Classify the time-to-collision (TTC) between two vehicles.

    The vehicle with the larger y is called "front" here (on a tie, veh2).
    TTC = (front_y - rear_y) / (rear_speed - front_speed).

    Returns:
        0 -- ``|veh1_x - veh2_x| > 10``, or the front vehicle's speed is >= the rear's;
        1 -- TTC <= ``ttc_threshold``;
        2 -- TTC > ``ttc_threshold``.
    """
    if abs(veh1_x - veh2_x) > 10:
        return 0
    if veh1_y > veh2_y:
        front_y, front_speed, rear_y, rear_speed = veh1_y, veh1_speed, veh2_y, veh2_speed
    else:
        front_y, front_speed, rear_y, rear_speed = veh2_y, veh2_speed, veh1_y, veh1_speed
    if front_speed >= rear_speed:
        return 0
    ttc = (front_y - rear_y) / (rear_speed - front_speed)
    return 1 if ttc <= ttc_threshold else 2
def idm_loss_fn(pred_pos, true_pos, mask):
    """Masked mean squared Euclidean position error.

    Args:
        pred_pos: [B, T, 2] predicted positions.
        true_pos: [B, T, 2] ground-truth positions.
        mask: [B, T] or [B, T, 1]; for a 3-D mask only channel 0 is used.

    Returns:
        Scalar ``sum(err * mask) / (sum(mask) + 1e-8)``; 0 (not NaN) when the
        mask is empty, matching the other masked losses in this module.
    """
    if mask.dim() == 3:
        mask = mask[..., 0]
    error = torch.sum((pred_pos - true_pos) ** 2, dim=-1) * mask
    return torch.sum(error) / (torch.sum(mask) + 1e-8)
def idm_accel_torch(params, s, v, dv, v_pre):
    """Intelligent Driver Model (IDM) acceleration, element-wise.

    Inputs have shape [B, T]; the result has shape [B, T].

    Computes ``a_max * (1 - (v/v0)**delta - (s_star/s)**2)`` with the desired gap
    ``s_star = s0 + v*T + v*dv / (2*sqrt(a_max*b))``, then clamps to [-5.0, 3.0].

    Args:
        params: ``(v0, T, s0, a_max, b, delta)`` -- 6 Python floats or tensors.
        s: gap tensor; clamped to at least 0.1 before dividing by it.
        v: speed tensor.
        dv: tensor multiplied by ``v`` in ``s_star`` (presumably the approach rate
            v - v_leader -- confirm with callers).
        v_pre: not used by this function; kept for interface compatibility.
    """
    v0, T, s0, a_max, b, delta = params
    s = torch.clamp(s, min=0.1)  # avoid division by a zero/negative gap
    ab = a_max * b
    # torch.sqrt only accepts tensors; use math.sqrt when the params are floats.
    sqrt_ab = torch.sqrt(ab) if torch.is_tensor(ab) else math.sqrt(ab)
    # NOTE(review): the textbook IDM uses s0 + max(0, v*T + v*dv/(2*sqrt(ab)));
    # without the max, a strongly negative dv makes s_star negative and the
    # squared term below still brakes. Left unchanged to preserve calibration.
    s_star = s0 + v * T + v * dv / (2 * sqrt_ab)
    acc = a_max * (1 - (v / v0).pow(delta) - (s_star / s).pow(2))
    return torch.clamp(acc, -5.0, 3.0)
def logsumexp(inputs, dim=None, keepdim=False):
    """Numerically stable ``log(sum(exp(inputs)))`` along ``dim``.

    With ``dim=None`` the tensor is flattened and reduced over all elements
    (with ``keepdim=True`` the result then has shape ``(1,)``). Delegates to
    ``torch.logsumexp``, which returns -inf (not NaN) for an all -inf slice.
    """
    if dim is None:
        # reshape (not view) so non-contiguous inputs such as transposes work.
        inputs = inputs.reshape(-1)
        dim = 0
    return torch.logsumexp(inputs, dim=dim, keepdim=keepdim)