import datasets
from transformers import BartTokenizer, BartForConditionalGeneration, AdamW, BartPretrainedModel
from torch.utils.data import DataLoader
import torch
import config_dialogsum
from tqdm import tqdm
import os
from utils.rouge_with_pyrouge import rouge_with_pyrouge
import time
import json
import random
from bert_score import BERTScorer
from model.bart import BartForConditionalGeneration_TripleDecoder
from transformers import AutoTokenizer
from enhance_dataset_dialogsum.dialoguedataset import DialogSumDataset_total
from transformers import Seq2SeqTrainingArguments
from trainer import TripleDecoderTrainer
import numpy as np
import nltk
from rouge_score import rouge_scorer, scoring
from datasets import load_metric
class Experiment(object):
    """Fine-tune and evaluate ``BartForConditionalGeneration_TripleDecoder`` on DialogSum.

    ``__init__`` loads the tokenizer, builds the train/validation/test splits via
    ``DialogSumDataset_total`` and prepares the ``Seq2SeqTrainingArguments``.
    ``train`` fine-tunes with ``TripleDecoderTrainer``; ``inference`` generates a
    summary for every test example and prints running ROUGE (via pyrouge) and
    BERTScore averages.
    """

    def __init__(self):
        self.tokenizer = AutoTokenizer.from_pretrained(config_dialogsum.finetune_weight_path_tokenizer)
        cuda_available = torch.cuda.is_available()
        self.device = torch.device('cuda' if cuda_available else 'cpu')
        print('######################################################################')
        print('Device:', self.device)
        # These CUDA queries raise on CPU-only machines, so only report them when a GPU exists.
        if cuda_available:
            print('Current cuda device:', torch.cuda.current_device())
            print('Count of using GPUs:', torch.cuda.device_count())
            print(torch.cuda.get_device_name())
        print('######################################################################')

        self.total_dataset = DialogSumDataset_total(
            encoder_max_len=1024,
            decoder_max_len=100,
            tokenizer=self.tokenizer,
            extra_supervision=True,
            paracomet=True,
            relation="xReason",
            supervision_relation="xIntent",
            roberta=False,
            sentence_transformer=False,
        )
        self.train_dataset = self.total_dataset.getTrainData()
        self.eval_dataset = self.total_dataset.getEvalData()
        self.test_dataset = self.total_dataset.getTestData()
        print('######################################################################')
        print('Training Dataset Size is : ')
        print(len(self.train_dataset))
        print('Validation Dataset Size is : ')
        print(len(self.eval_dataset))
        print('Test Dataset Size is : ')
        print(len(self.test_dataset))
        print('######################################################################')

        self.finetune_args = Seq2SeqTrainingArguments(
            output_dir=config_dialogsum.finetune_weight_path,
            overwrite_output_dir=True,
            do_train=True,
            do_eval=True,
            do_predict=True,
            evaluation_strategy='epoch',
            per_device_train_batch_size=config_dialogsum.train_batch_size,
            per_device_eval_batch_size=config_dialogsum.val_batch_size,
            learning_rate=config_dialogsum.init_lr,
            weight_decay=config_dialogsum.weight_decay,
            adam_beta1=config_dialogsum.adam_beta1,
            adam_beta2=config_dialogsum.adam_beta2,
            adam_epsilon=config_dialogsum.adam_eps,
            num_train_epochs=config_dialogsum.epoch,
            max_grad_norm=0.1,
            gradient_accumulation_steps=2,
            gradient_checkpointing=True,
            lr_scheduler_type='polynomial',
            warmup_steps=config_dialogsum.warm_up,
            logging_strategy="epoch",
            save_strategy="epoch",
            save_total_limit=1,
            # fp16 mixed precision requires CUDA; fall back to fp32 on CPU.
            fp16=cuda_available,
            seed=216,
            load_best_model_at_end=True,
            # With load_best_model_at_end the "best" checkpoint is chosen by eval loss,
            # where lower is better. greater_is_better=True would restore the
            # highest-loss checkpoint instead.
            metric_for_best_model="loss",
            greater_is_better=False,
            predict_with_generate=True,
            prediction_loss_only=False,
            generation_max_length=100,
            generation_num_beams=20,
        )

    def load_model(self, path, device):
        """Load model weights from ``path`` and move the model to ``device``.

        Sets ``self.model`` (with gradient checkpointing enabled) and
        ``self.optimizer``, a default ``torch.optim.AdamW`` over the model's parameters.
        """
        print("-----------load the best model----------")
        self.model = BartForConditionalGeneration_TripleDecoder.from_pretrained(path)
        self.model.to(device)
        self.model.gradient_checkpointing_enable()
        # NOTE(review): train() does not pass this optimizer to the trainer, which builds
        # its own from finetune_args; kept in case other code reads self.optimizer.
        self.optimizer = torch.optim.AdamW(self.model.parameters())

    def clear(self, folder_path):
        """Delete every regular file directly inside ``folder_path``.

        Sub-directories are left untouched. The folder is created if it does not
        exist yet, because inference writes its output files into it afterwards.
        Deletion is best-effort: a file that cannot be removed is reported and skipped.
        """
        os.makedirs(folder_path, exist_ok=True)
        for name in os.listdir(folder_path):
            file_path = os.path.join(folder_path, name)
            try:
                if os.path.isfile(file_path):
                    os.remove(file_path)
            except OSError as e:
                print(f"Error while deleting {file_path}: {e}")

    def train(self):
        """Fine-tune the model from ``config_dialogsum.model_path`` and save it.

        The final (best, per ``load_best_model_at_end``) model is written to
        ``config_dialogsum.best_finetune_weight_path``.
        """
        self.load_model(config_dialogsum.model_path, self.device)
        finetune_trainer = TripleDecoderTrainer(
            model=self.model,
            args=self.finetune_args,
            train_dataset=self.train_dataset,
            eval_dataset=self.eval_dataset,
            tokenizer=self.tokenizer,
        )
        finetune_trainer.train()
        finetune_trainer.save_model(config_dialogsum.best_finetune_weight_path)

    @staticmethod
    def _flatten_for_scoring(lines):
        """Join decoded texts into one line, turning sentence-break newlines into spaces.

        Replacing the newlines with a space, rather than deleting them, keeps
        "A.\\nB." from collapsing into "A.B.", which would change how ROUGE and
        BERTScore tokenize the sentence boundary.
        """
        return " ".join(line.replace("\n", " ") for line in lines)

    def inference(self):
        """Generate a summary for each test dialogue and print running metrics.

        Each test example is processed with batch size 1. Its prediction and reference
        are written to ``<predition_path>1.dec`` and ``<label_path>1.ref``. They are
        scored with ``rouge_with_pyrouge`` and BERTScore, and the running means of
        ROUGE-1/2/L and BERTScore F1 are printed.
        """
        # NOTE(review): train() saves the final model to best_finetune_weight_path, but this
        # loads finetune_weight_path (the trainer's output_dir); confirm this is intended.
        self.load_model(config_dialogsum.finetune_weight_path, self.device)
        self.model.eval()
        test_dataloader = DataLoader(dataset=self.test_dataset, batch_size=1, shuffle=False)
        self.clear(config_dialogsum.label_path)
        self.clear(config_dialogsum.predition_path)

        # Constructing a BERTScorer loads a BERT model, so build it once, not per sample.
        # NOTE: model_type is a placeholder and must be set to a real BERTScore model/path.
        scorer = BERTScorer(model_type='please paste your bertscore path in here', num_layers=6)

        rouge_1 = 0
        rouge_2 = 0
        rouge_L = 0
        bert_s = 0
        item = 1  # 1-based count of processed samples; divisor for the running means
        with torch.no_grad():
            for data in tqdm(test_dataloader):
                x = data['input_ids'].to(self.device, dtype=torch.long)
                mask = data['attention_mask'].to(self.device, dtype=torch.long)
                # Labels are only decoded, never fed to the model, so keep them on the CPU.
                labels = data['labels'].to('cpu', dtype=torch.long)
                generated_ids = self.model.generate(
                    input_ids=x,
                    attention_mask=mask,
                    max_length=100,
                    min_length=5,
                    num_beams=config_dialogsum.num_beams,
                ).cpu()

                decoded_preds = self.tokenizer.batch_decode(
                    generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
                )
                # Presumably -100 is the usual HF ignore index for masked label positions;
                # swap in the pad id so those positions can be decoded (and skipped).
                labels = np.where(labels != -100, labels, self.tokenizer.pad_token_id)
                decoded_labels = self.tokenizer.batch_decode(
                    labels, skip_special_tokens=True, clean_up_tokenization_spaces=True
                )

                decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
                print("decoded_preds:", decoded_preds)
                decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]
                print("decoded_labels:", decoded_labels)

                pred_text = self._flatten_for_scoring(decoded_preds)
                label_text = self._flatten_for_scoring(decoded_labels)
                with open(config_dialogsum.predition_path + "1.dec", 'w', encoding="utf-8") as fh:
                    fh.write(pred_text + "\n")
                with open(config_dialogsum.label_path + "1.ref", 'w', encoding="utf-8") as fh:
                    fh.write(label_text + "\n")

                rouge1, rouge2, rougel = rouge_with_pyrouge(
                    preds=config_dialogsum.predition_path, refs=config_dialogsum.label_path, item=1
                )
                print("rouge1:", rouge1, "rouge2:", rouge2, "rougel:", rougel)
                P, R, F = scorer.score([pred_text], [label_text])
                bert_f1 = F.item()
                print("BERTS:", bert_f1)

                rouge_1 += rouge1
                rouge_2 += rouge2
                rouge_L += rougel
                bert_s += bert_f1
                # The label '平均BERTS:' reads "average BERTS".
                print('final_ROUGE-1:', rouge_1 / item, 'final_ROUGE-2:', rouge_2 / item,
                      'final_ROUGE-L:', rouge_L / item, '平均BERTS:', bert_s / item)
                item += 1