import os
import time
import math
import random
import numpy as np
import torch
import torch.nn as nn
import torch_npu
import argparse
import logging
from torch_npu.npu.amp import autocast, GradScaler
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from model import pipNet
from data import highwayTrajDataset
from utils_pt import initLogging, maskedNLL, maskedMSE, maskedNLLTest
from torch.optim.lr_scheduler import ReduceLROnPlateau
torch.npu.set_device(0)
parser = argparse.ArgumentParser(description='Training: Human_like Trajectory Prediction for Autonomous Driving')
parser.add_argument('--use_cuda', action="store_true", help='if use cuda (default: True)', default = True)
parser.add_argument('--use_planning', action="store_false", help='if use planning coupled module (default: True)',default = True)
parser.add_argument('--use_fusion', action="store_false", help='if use targets fusion module (default: True)',default = True)
parser.add_argument('--train_output_flag', action="store_false", help='if concatenate with true maneuver label (default: True)', default = True)
parser.add_argument('--batch_size', type=int, help='batch size to use (default: 64)', default=64)
parser.add_argument('--learning_rate', type=float, help='learning rate (default: 1e-3)', default=0.001)
parser.add_argument('--tensorboard', action="store_true", help='if use tensorboard (default: True)', default = True)
parser.add_argument('--grid_size', type=int, help='default: (25,5)', nargs=2, default = [25, 5])
parser.add_argument('--in_length', type=int, help='History sequence (default: 10)',default = 16)
parser.add_argument('--out_length', type=int, help='Predict sequence (default: 15)',default = 25)
parser.add_argument('--num_lat_classes', type=int, help='Classes of lateral behaviors', default = 3)
parser.add_argument('--num_lon_classes', type=int, help='Classes of longitute behaviors', default = 2)
parser.add_argument('--temporal_embedding_size', type=int, help='Embedding size of the input traj', default = 32)
parser.add_argument('--encoder_size', type=int, help='lstm encoder size', default = 64)
parser.add_argument('--decoder_size', type=int, help='lstm decoder size', default = 128)
parser.add_argument('--soc_conv_depth', type=int, help='The 1st social conv depth', default = 64)
parser.add_argument('--soc_conv2_depth', type=int, help='The 2nd social conv depth', default = 16)
parser.add_argument('--dynamics_encoding_size', type=int, help='Embedding size of the vehicle dynamic', default = 32)
parser.add_argument('--social_context_size', type=int, help='Embedding size of the social context tensor', default = 80)
parser.add_argument('--fuse_enc_size', type=int, help='Feature size to be fused', default = 112)
parser.add_argument('--num_layers', type=int, help='Num of Transformer Layers', default = 3)
parser.add_argument('--num_heads', type=int, help='Number of attention heads', default = 4)
parser.add_argument('--feed_forward_dim', type=int, help='Dimension of feed forward layer', default = 64)
parser.add_argument('--dropout', type=float, help='Dropout probability', default = 0.1)
parser.add_argument('--name', type=str, help='log name (default: "1")', default="npu_train")
parser.add_argument('--train_set', type=str, help='Path to train datasets', default='../autodl-tmp/Train_stop_and_go.mat')
parser.add_argument('--val_set', type=str, help='Path to validation datasets', default='../autodl-tmp/Val_stop_and_go.mat')
parser.add_argument("--num_workers", type=int, default=8, help="number of workers used for dataloader")
parser.add_argument('--pretrain_epochs', type=int, help='epochs of pre-training using MSE', default = 10)
parser.add_argument('--train_epochs', type=int, help='epochs of training using NLL', default = 20)
parser.add_argument('--eval_batch_num', type=int, default=20, help='Validation batches per epoch (0 means full validation set)')
parser.add_argument('--seed', type=int, default=3407, help='Global random seed for reproducibility')
parser.add_argument('--start_epoch', type=int, default=None, help='Start epoch for resuming training (optional)')
parser.add_argument('--continue_path', type=str, default="", help="Path to pretrained model checkpoint (optional)")
def set_global_seed(seed: int):
os.environ["PYTHONHASHSEED"] = str(seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.npu.manual_seed(seed)
torch.npu.manual_seed_all(seed)
torch.use_deterministic_algorithms(True, warn_only=True)
def seed_worker(worker_id: int):
worker_seed = torch.initial_seed() % (2 ** 32)
np.random.seed(worker_seed)
random.seed(worker_seed)
def train_model():
args = parser.parse_args()
set_global_seed(args.seed)
log_path = "./trained_models/{}/".format(args.name)
os.makedirs(log_path, exist_ok=True)
initLogging(log_file=log_path+'train.log')
if args.tensorboard:
logger = SummaryWriter(log_path + 'train-pre{}-nll{}'.format(args.pretrain_epochs, args.train_epochs))
logger_val = SummaryWriter(log_path + 'validation-pre{}-nll{}'.format(args.pretrain_epochs, args.train_epochs))
logging.info("------------- {} -------------".format(args.name))
logging.info("Batch size : {}".format(args.batch_size))
logging.info("Learning rate : {}".format(args.learning_rate))
logging.info("Seed : {}".format(args.seed))
logging.info("Use Planning Coupled: {}".format(args.use_planning))
logging.info("Use Target Fusion: {}".format(args.use_fusion))
PiP = pipNet(args)
PiP = PiP.npu()
start_epoch = 0
if args.continue_path and os.path.exists(args.continue_path) and args.start_epoch is not None:
PiP.load_state_dict(torch.load(args.continue_path))
start_epoch = args.start_epoch
logging.info(f"Resuming training from epoch {start_epoch}, loaded weights from {args.continue_path}")
else:
if not args.continue_path:
logging.info("No checkpoint path provided.")
elif not os.path.exists(args.continue_path):
logging.warning(f"Checkpoint path '{args.continue_path}' does not exist.")
elif args.start_epoch is None:
logging.info("Checkpoint path provided but no --start_epoch. Training from scratch.")
logging.info("Starting training from epoch 0.")
scaler = GradScaler(
init_scale=2.0,
growth_factor=1.5,
backoff_factor=0.25,
growth_interval=2000
)
optimizer = torch.optim.Adam(PiP.parameters(), lr=args.learning_rate)
crossEnt = nn.BCELoss()
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)
pretrainEpochs = args.pretrain_epochs
trainEpochs = args.train_epochs
batch_size = args.batch_size
eval_batch_num = args.eval_batch_num
logging.info("Train dataset: {}".format(args.train_set))
trSet = highwayTrajDataset(path=args.train_set,
targ_enc_size=args.social_context_size+args.dynamics_encoding_size,
grid_size=args.grid_size,
fit_plan_traj=False)
logging.info("Validation dataset: {}".format(args.val_set))
valSet = highwayTrajDataset(path=args.val_set,
targ_enc_size=args.social_context_size+args.dynamics_encoding_size,
grid_size=args.grid_size,
fit_plan_traj=True)
trDataloader = DataLoader(trSet, batch_size=batch_size, shuffle=True, num_workers=args.num_workers, collate_fn=trSet.collate_fn, prefetch_factor=2, pin_memory=True, persistent_workers=True)
valDataloader = DataLoader(valSet, batch_size=batch_size, shuffle=True, num_workers=args.num_workers, collate_fn=valSet.collate_fn, prefetch_factor=2, pin_memory=True, persistent_workers=True)
logging.info("DataSet Prepared : {} train data, {} validation data\n".format(len(trSet), len(valSet)))
logging.info("Network structure: {}\n".format(PiP))
def run_validation(epoch_num: int, max_batches: int = 0):
"""运行验证,返回平均loss和实际验证的batch数"""
total_val_loss = 0.0
val_batches_count = 0
with torch.no_grad():
PiP.eval()
PiP.train_output_flag = False
for val_i, data in enumerate(valDataloader):
nbsHist, nbsMask, planFut, planMask, targsHist, targsEncMask, targsFut, targsFutMask, lat_enc, lon_enc, idx, space_h, dv, v_pre = data
if args.use_cuda:
nbsHist = nbsHist.npu()
nbsMask = nbsMask.npu()
planFut = planFut.npu()
planMask = planMask.npu()
targsHist = targsHist.npu()
targsEncMask = targsEncMask.npu()
lat_enc = lat_enc.npu()
lon_enc = lon_enc.npu()
targsFut = targsFut.npu()
targsFutMask = targsFutMask.npu()
space_h = space_h.npu()
dv = dv.npu()
v_pre = v_pre.npu()
if epoch_num < pretrainEpochs:
PiP.train_output_flag = True
fut_pred, _, _ = PiP(nbsHist, nbsMask, planFut, planMask, targsHist, targsEncMask,
lat_enc, lon_enc, idx)
batch_loss = maskedMSE(fut_pred, targsFut, targsFutMask)
else:
fut_pred, lat_pred, lon_pred = PiP(nbsHist, nbsMask, planFut, planMask, targsHist,
targsEncMask, lat_enc, lon_enc, idx, space_h, dv, v_pre)
batch_loss = maskedNLLTest(fut_pred, lat_pred, lon_pred, targsFut, targsFutMask, avg_along_time=True)
total_val_loss += batch_loss.item()
val_batches_count += 1
if max_batches > 0 and val_i + 1 >= max_batches:
break
PiP.train_output_flag = True
PiP.train()
avg_val = total_val_loss / max(val_batches_count, 1)
return avg_val, val_batches_count
data_stream = torch.npu.Stream()
for epoch_num in range(start_epoch, pretrainEpochs + trainEpochs):
epoch_start_time = time.time()
batch_checkpoint_time = epoch_start_time
if epoch_num == 0:
logging.info('Pretrain with MSE loss')
elif epoch_num == pretrainEpochs:
logging.info('Train with NLL loss')
avg_time_tr, avg_loss_tr = 0, 0
PiP.train()
PiP.train_output_flag = True
total_batches = len(trDataloader)
for i, data in enumerate(trDataloader):
st_time = time.time()
(nbsHist, nbsMask, planFut, planMask, targsHist, targsEncMask,
targsFut, targsFutMask, lat_enc, lon_enc, _, space_h, dv, v_pre) = data
if args.use_cuda:
with torch.npu.stream(data_stream):
nbsHist = nbsHist.contiguous().npu(non_blocking=True)
nbsMask = nbsMask.contiguous().npu(non_blocking=True)
planFut = planFut.contiguous().npu(non_blocking=True)
planMask = planMask.contiguous().npu(non_blocking=True)
targsHist = targsHist.contiguous().npu(non_blocking=True)
targsEncMask = targsEncMask.contiguous().npu(non_blocking=True)
lat_enc = lat_enc.contiguous().npu(non_blocking=True)
lon_enc = lon_enc.contiguous().npu(non_blocking=True)
targsFut = targsFut.contiguous().npu(non_blocking=True)
targsFutMask = targsFutMask.contiguous().npu(non_blocking=True)
space_h = space_h.contiguous().npu(non_blocking=True)
dv = dv.contiguous().npu(non_blocking=True)
v_pre = v_pre.contiguous().npu(non_blocking=True)
torch.npu.current_stream().wait_stream(data_stream)
with autocast():
fut_pred, lat_pred, lon_pred = PiP(nbsHist, nbsMask, planFut, planMask,
targsHist, targsEncMask, lat_enc, lon_enc,
_, space_h, dv, v_pre)
if epoch_num < pretrainEpochs:
l = maskedMSE(fut_pred, targsFut, targsFutMask)
else:
l = maskedNLL(fut_pred, targsFut, targsFutMask) + crossEnt(lat_pred, lat_enc) + crossEnt(lon_pred, lon_enc)
optimizer.zero_grad(set_to_none=True)
scaler.scale(l).backward()
torch.nn.utils.clip_grad_norm_(PiP.parameters(), max_norm=10.0)
scaler.step(optimizer)
scaler.update()
batch_time = time.time() - st_time
avg_loss_tr += l.item()
avg_time_tr += batch_time
if i % 100 == 99:
torch.npu.synchronize()
logging.info(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] "
f"Epoch {epoch_num+1} | Batch {i+1} | "
f"100-batch time: {time.time() - batch_checkpoint_time:.2f}s | "
f"Loss: {avg_loss_tr/100:.4f}")
batch_checkpoint_time = time.time()
if args.tensorboard:
global_step = epoch_num * len(trDataloader) + i + 1
logger.add_scalar("Train/RMSE" if epoch_num < pretrainEpochs else "Train/NLL",
avg_loss_tr / 100, global_step)
avg_time_tr, avg_loss_tr = 0, 0
avg_loss_val, val_batches_count = run_validation(epoch_num, eval_batch_num)
logging.info(f"Epoch {epoch_num+1} validation loss: {avg_loss_val:.4f}")
if args.tensorboard:
logger_val.add_scalar("RMSE" if epoch_num < pretrainEpochs else "NLL",
avg_loss_val, epoch_num+1)
epoCount = epoch_num + 1
if epoCount < pretrainEpochs:
torch.save(PiP.state_dict(), log_path + "{}-pre{}-nll{}.tar".format(args.name, epoCount, 0))
else:
torch.save(PiP.state_dict(), log_path + "{}-pre{}-nll{}.tar".format(args.name, pretrainEpochs, epoCount - pretrainEpochs))
if torch.isfinite(torch.tensor(avg_loss_val)):
scheduler.step(avg_loss_val)
else:
logging.warning(f"Epoch {epoch_num+1}: Validation loss is NaN, not updating scheduler")
if args.tensorboard:
lr_now = optimizer.param_groups[0]['lr']
logger.add_scalar("LearningRate", lr_now, epoch_num)
torch.npu.synchronize()
epoch_total_time = time.time() - epoch_start_time
logging.info(f"Epoch {epoch_num+1} 耗时: {epoch_total_time:.2f} 秒")
torch.save(PiP.state_dict(), log_path+"{}.tar".format(args.name))
logging.info("Model saved in trained_models/{}/{}.tar\n".format(args.name, args.name))
if __name__ == '__main__':
train_model()