import re
import torch


def _torch_version_at_least(major, minor):
    """Return True if the installed torch release is at least ``major.minor``.

    The leading ``<major>.<minor>`` of ``torch.__version__`` is compared
    numerically. A plain string comparison would order "1.10" before "1.8".
    """
    version = str(torch.__version__)
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        # Unparseable version string: fall back to the original string comparison.
        return version >= "%d.%d" % (major, minor)
    return (int(match.group(1)), int(match.group(2))) >= (major, minor)


# torch_npu is only imported for torch >= 1.8 (presumably older Ascend torch
# builds register the NPU backend themselves -- TODO confirm).
if _torch_version_at_least(1, 8):
    import torch_npu
print(torch.__version__)
import sys
from dataset import VideoDataSet
import os
import json
import torch.nn.parallel
import torch.optim as optim
import numpy as np
import opts
from models import BMN
import pandas as pd
from post_processing import BMN_post_processing
from eval import evaluation_proposal
from apex import amp
import torch.distributed as dist
import time
import torch.npu
from loss_function import bmn_loss_func, get_mask
# Only affects modules imported after this line (e.g. lazily imported ones).
sys.dont_write_bytecode = True
def train_BMN(data_loader, model, optimizer, epoch, bm_mask):
    """Train BMN for one epoch and log average losses (and FPS) on rank 0.

    Reads the module-level ``opt`` dict (``local_rank``, ``world_size``,
    ``batch_size``) that is created in this script's ``__main__`` section.

    Args:
        data_loader: yields ``(input_data, label_confidence, label_start,
            label_end)`` batches. Every tensor is moved to the NPU here.
        model: BMN model whose forward returns ``(confidence_map, start, end)``.
        optimizer: optimizer returned by ``amp.initialize``.
        epoch: zero-based epoch index. Used in the log lines and to pick the
            single step that gets profiled.
        bm_mask: mask from ``get_mask``, forwarded to ``bmn_loss_func``.
    """
    import contextlib

    model.train()
    is_main_rank = opt["local_rank"] == 0
    # Iterations with index < warmup_iters are excluded from the FPS measurement.
    warmup_iters = 5
    epoch_pemreg_loss = 0.0
    epoch_pemclr_loss = 0.0
    epoch_tem_loss = 0.0
    epoch_loss = 0.0
    time_warm = None
    num_batches = 0
    for n_iter, (input_data, label_confidence, label_start, label_end) in enumerate(data_loader):
        if is_main_rank and n_iter == warmup_iters:
            time_warm = time.time()
        input_data = input_data.npu()
        label_start = label_start.npu()
        label_end = label_end.npu()
        label_confidence = label_confidence.npu()
        confidence_map, start, end = model(input_data)
        loss = bmn_loss_func(confidence_map, start, end, label_confidence, label_start, label_end, bm_mask.npu())

        # On a single-device run, profile exactly one optimizer step (index 6
        # of the first epoch) and dump it as a Chrome trace. Every other step
        # runs under a no-op context so the update code is written only once.
        profile_step = opt["world_size"] == 1 and epoch == 0 and n_iter == 6
        step_context = (torch.autograd.profiler.profile(use_npu=True)
                        if profile_step else contextlib.nullcontext())
        with step_context as prof:
            optimizer.zero_grad()
            with amp.scale_loss(loss[0], optimizer) as scaled_loss:
                scaled_loss.backward()
            optimizer.step()
        if profile_step:
            prof.export_chrome_trace("910A_1p.prof")

        # loss = (total, tem, pem_reg, pem_clr), as indexed by the log below.
        epoch_loss += loss[0].item()
        epoch_tem_loss += loss[1].item()
        epoch_pemreg_loss += loss[2].item()
        epoch_pemclr_loss += loss[3].item()
        num_batches += 1

    if not is_main_rank or num_batches == 0:
        return
    if time_warm is not None:
        # Only the iterations actually inside the timed window count.
        # Counting the warm-up batches too would overstate FPS.
        elapsed = time.time() - time_warm
        timed_samples = opt["batch_size"] * (num_batches - warmup_iters)
        print("Epoch: %d,FPS: %.2f,time: %.2f" % (epoch, timed_samples / elapsed, elapsed))
    n_iter = num_batches - 1
    print(
        "BMN training loss(epoch %d, n_iter %d): tem_loss: %.03f, pem class_loss: %.03f, pem reg_loss: %.03f, total_loss: %.03f" % (
            epoch, n_iter, epoch_tem_loss / num_batches,
            epoch_pemclr_loss / num_batches,
            epoch_pemreg_loss / num_batches,
            epoch_loss / num_batches))
def test_BMN(data_loader, model, epoch, bm_mask):
    """Evaluate BMN on the validation loader and save checkpoints on rank 0.

    Reads the module-level ``opt`` dict (``local_rank``, ``checkpoint_path``).
    Rank 0 always writes ``BMN_checkpoint.pth.tar``. It additionally writes
    ``BMN_best.pth.tar`` when this call's mean total loss is lower than the
    best value recorded by any earlier call in the same process.

    Args:
        data_loader: yields ``(input_data, label_confidence, label_start,
            label_end)`` batches.
        model: BMN model whose forward returns ``(confidence_map, start, end)``.
        epoch: zero-based epoch index. The checkpoint stores ``epoch + 1``.
        bm_mask: mask from ``get_mask``, forwarded to ``bmn_loss_func``.
    """
    model.eval()
    epoch_pemreg_loss = 0.0
    epoch_pemclr_loss = 0.0
    epoch_tem_loss = 0.0
    epoch_loss = 0.0
    num_batches = 0
    # No backward pass happens here, so do not build an autograd graph.
    with torch.no_grad():
        for input_data, label_confidence, label_start, label_end in data_loader:
            input_data = input_data.npu()
            label_start = label_start.npu()
            label_end = label_end.npu()
            label_confidence = label_confidence.npu()
            confidence_map, start, end = model(input_data)
            loss = bmn_loss_func(confidence_map, start, end, label_confidence, label_start, label_end, bm_mask.npu())
            # loss = (total, tem, pem_reg, pem_clr), as indexed by the log below.
            epoch_loss += loss[0].item()
            epoch_tem_loss += loss[1].item()
            epoch_pemreg_loss += loss[2].item()
            epoch_pemclr_loss += loss[3].item()
            num_batches += 1

    if opt["local_rank"] != 0:
        return
    if num_batches:
        n_iter = num_batches - 1
        print(
            "BMN test loss(epoch %d, n_iter %d): tem_loss: %.03f, pem class_loss: %.03f, pem reg_loss: %.03f, total_loss: %.03f" % (
                epoch, n_iter, epoch_tem_loss / num_batches,
                epoch_pemclr_loss / num_batches,
                epoch_pemreg_loss / num_batches,
                epoch_loss / num_batches))
    state = {'epoch': epoch + 1,
             'state_dict': model.state_dict()}
    torch.save(state, opt["checkpoint_path"] + "/BMN_checkpoint.pth.tar")
    # The best loss must survive across calls (one call per epoch), so it is
    # kept on the function object. A fresh local each call made every epoch
    # look like the best.
    if num_batches:
        mean_loss = epoch_loss / num_batches
        if mean_loss < getattr(test_BMN, "_best_loss", float("inf")):
            test_BMN._best_loss = mean_loss
            torch.save(state, opt["checkpoint_path"] + "/BMN_best.pth.tar")
def _build_distributed_loader(opt, subset, batch_size):
    """Return ``(loader, sampler)`` over ``VideoDataSet(opt, subset)`` sharded by a DistributedSampler.

    The dataset is constructed once and shared by the sampler and the loader.
    """
    dataset = VideoDataSet(opt, subset=subset)
    sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,  # DataLoader does not allow shuffle=True together with a sampler
        num_workers=8,
        pin_memory=False,
        drop_last=True,
        sampler=sampler,
    )
    return loader, sampler


def BMN_Train(opt):
    """Build BMN, its optimizer and data loaders, then train for ``opt["train_epochs"]`` epochs.

    The model is placed on ``npu:<local_rank>``, wrapped with apex amp (O1)
    and DistributedDataParallel. Each epoch runs ``train_BMN`` then
    ``test_BMN`` and finally steps a StepLR scheduler. ``opt["batch_size"]``
    is the global batch size; each rank loads ``batch_size / world_size``.
    """
    device = f'npu:{opt["local_rank"]}'
    model = BMN(opt)
    model = model.to(device)
    print(model)
    if opt["finetune"] == 1:
        # Map onto this rank's own device (a fixed 'npu:0' put every rank's copy on device 0).
        checkpoint = torch.load(opt["pth_path"], map_location=device)
        # Drop the first key component. Checkpoints saved by test_BMN come from
        # the DDP wrapper and are presumably prefixed with "module.".
        base_dict = {'.'.join(k.split('.')[1:]): v for k, v in checkpoint['state_dict'].items()}
        model.load_state_dict(base_dict)
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                           lr=opt["training_lr"], weight_decay=opt["weight_decay"])
    model, optimizer = amp.initialize(model, optimizer, opt_level='O1', loss_scale=128, combine_grad=True)
    if not isinstance(model, torch.nn.parallel.DistributedDataParallel):
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[opt["local_rank"]], broadcast_buffers=False)
    batch_size_per_rank = int(opt["batch_size"] / int(opt["world_size"]))
    train_loader, train_sampler = _build_distributed_loader(opt, "train", batch_size_per_rank)
    test_loader, test_sampler = _build_distributed_loader(opt, "validation", batch_size_per_rank)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=opt["step_size"], gamma=opt["step_gamma"])
    bm_mask = get_mask(opt["temporal_scale"])
    for epoch in range(opt["train_epochs"]):
        # Reseed the samplers so each epoch sees a different shard ordering.
        train_sampler.set_epoch(epoch)
        test_sampler.set_epoch(epoch)
        train_BMN(train_loader, model, optimizer, epoch, bm_mask)
        test_BMN(test_loader, model, epoch, bm_mask)
        scheduler.step()
def _bmn_proposals(start_scores, end_scores, clr_confidence, reg_confidence, tscale):
    """Build the scored proposal table for one video.

    Cell ``(row, col)`` of the confidence maps becomes the proposal
    ``[row / tscale, (col + 1) / tscale]`` whenever ``row < col + 1`` and
    ``col + 1 < tscale``. Rows are ordered row-major (start index outer,
    column index inner), i.e. the order of a nested ``for row: for col:``
    loop.

    Returns:
        float array of shape ``(num_proposals, 7)`` with columns xmin, xmax,
        xmin_score, xmax_score, clr_score, reg_score, score. The result is
        empty (0 rows) when no cell qualifies.
    """
    start_grid, col_grid = np.meshgrid(np.arange(tscale), np.arange(tscale), indexing="ij")
    end_grid = col_grid + 1
    valid = (start_grid < end_grid) & (end_grid < tscale)
    # Boolean-mask selection flattens in C (row-major) order.
    start_index = start_grid[valid]
    col_index = col_grid[valid]
    end_index = end_grid[valid]
    xmin_score = start_scores[start_index]
    xmax_score = end_scores[end_index]
    clr_score = clr_confidence[start_index, col_index]
    reg_score = reg_confidence[start_index, col_index]
    score = xmin_score * xmax_score * clr_score * reg_score
    return np.column_stack([start_index / tscale, end_index / tscale,
                            xmin_score, xmax_score, clr_score, reg_score, score])


def BMN_inference(opt):
    """Write one proposal CSV per validation video using the best checkpoint.

    Loads ``<checkpoint_path>/BMN_best.pth.tar`` onto ``npu:0`` and writes
    ``./output/BMN_results/<video_name>.csv`` for every validation video.
    The output directory is not created here.
    """
    model = BMN(opt)
    model = model.to('npu:0')
    checkpoint = torch.load(opt["checkpoint_path"] + "/BMN_best.pth.tar", map_location='npu:0')
    # Drop the first key component (the DDP "module." prefix of saved checkpoints).
    base_dict = {'.'.join(k.split('.')[1:]): v for k, v in checkpoint['state_dict'].items()}
    model.load_state_dict(base_dict)
    model.eval()
    test_loader = torch.utils.data.DataLoader(VideoDataSet(opt, subset="validation"),
                                              batch_size=1, shuffle=False,
                                              num_workers=8, pin_memory=True, drop_last=False)
    tscale = opt["temporal_scale"]
    video_list = test_loader.dataset.video_list
    # NOTE: "reg_socre" (sic) is the column name written to disk. It is kept
    # unchanged in case downstream readers of these CSVs rely on it.
    col_name = ["xmin", "xmax", "xmin_score", "xmax_score", "clr_score", "reg_socre", "score"]
    with torch.no_grad():
        for video_idx, input_data in test_loader:
            video_name = video_list[video_idx[0]]
            confidence_map, start, end = model(input_data.npu())
            start_scores = start[0].detach().cpu().numpy()
            end_scores = end[0].detach().cpu().numpy()
            clr_confidence = (confidence_map[0][1]).detach().cpu().numpy()
            reg_confidence = (confidence_map[0][0]).detach().cpu().numpy()
            new_props = _bmn_proposals(start_scores, end_scores, clr_confidence, reg_confidence, tscale)
            new_df = pd.DataFrame(new_props, columns=col_name)
            new_df.to_csv("./output/BMN_results/" + video_name + ".csv", index=False)
def _run_inference_and_eval(opt):
    """Run BMN inference, post-processing and proposal evaluation on the validation set."""
    os.makedirs("output/BMN_results", exist_ok=True)
    BMN_inference(opt)
    print("Post processing start")
    BMN_post_processing(opt)
    print("Post processing finished")
    evaluation_proposal(opt)


def main(opt):
    """Set up the NPU (and the HCCL process group when distributed) and dispatch on ``opt["mode"]``.

    Modes:
        "train": train only.
        "inference": inference, post-processing and evaluation.
        "full": train, then inference/evaluation on rank 0 only.

    Raises:
        ValueError: if ``opt["mode"]`` is none of the above.
    """
    if opt["is_distributed"] == 0:
        torch.npu.set_device(0)
    else:
        # Single-node defaults. Values already present in the environment
        # (e.g. set by a launcher to avoid a port clash) take precedence.
        os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
        os.environ.setdefault('MASTER_PORT', '29681')
        dist.init_process_group(backend='hccl', world_size=opt["world_size"], rank=opt["local_rank"])
        local_device = torch.device(f'npu:{opt["local_rank"]}')
        torch.npu.set_device(local_device)
    if opt["local_rank"] == 0:
        print("using npu :{}".format(opt["DeviceID"]))
    opt["feature_path"] = opt["data_path"]
    mode = opt["mode"]
    if mode == "train":
        BMN_Train(opt)
    elif mode == "inference":
        _run_inference_and_eval(opt)
    elif mode == "full":
        opt["mode"] = "train"
        BMN_Train(opt)
        if opt["local_rank"] == 0:
            opt["mode"] = "inference"
            _run_inference_and_eval(opt)
    else:
        raise ValueError(f"unknown mode {mode!r}; expected 'train', 'inference' or 'full'")
if __name__ == '__main__':
    # `opt` is deliberately module-global: train_BMN and test_BMN read it directly.
    opt = vars(opts.parse_opt())
    os.makedirs(opt["checkpoint_path"], exist_ok=True)
    # Record the exact options of this run next to its checkpoints.
    with open(opt["checkpoint_path"] + "/opts.json", "w") as opt_file:
        json.dump(opt, opt_file)
    print(opt)
    main(opt)