# f3c6888b - created on 2024-12-16 (commit history)
import datasets
from transformers import BartTokenizer, BartForConditionalGeneration, AdamW, BartPretrainedModel
from torch.utils.data import DataLoader
import torch
import config_dialogsum
from tqdm import tqdm
import os
from utils.rouge_with_pyrouge import rouge_with_pyrouge
import time
import json
import random
from bert_score import BERTScorer
from model.bart import BartForConditionalGeneration_TripleDecoder
from transformers import AutoTokenizer
from enhance_dataset_dialogsum.dialoguedataset import DialogSumDataset_total
from transformers import Seq2SeqTrainingArguments
from trainer import TripleDecoderTrainer
import numpy as np
import nltk
from rouge_score import rouge_scorer, scoring
from datasets import load_metric


class Experiment(object):
    """Fine-tuning and test-set evaluation driver for the triple-decoder BART model.

    Construction loads the tokenizer, selects the device, builds the
    train/validation/test splits from ``DialogSumDataset_total`` and prepares
    the ``Seq2SeqTrainingArguments``. All paths and hyper-parameters are read
    from the ``config_dialogsum`` module.
    """

    def __init__(self):
        """Load the tokenizer and datasets and build the training arguments."""
        # load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(config_dialogsum.finetune_weight_path_tokenizer)

        # set GPU
        cuda_available = torch.cuda.is_available()
        self.device = torch.device('cuda' if cuda_available else 'cpu')
        print('######################################################################')
        print('Device:', self.device)
        if cuda_available:
            # These queries raise on a CPU-only host, so they are only issued
            # when a CUDA device is actually present.
            print('Current cuda device:', torch.cuda.current_device())
            print('Count of using GPUs:', torch.cuda.device_count())
            print(torch.cuda.get_device_name())
        print('######################################################################')

        # load dataset
        self.total_dataset = DialogSumDataset_total(
            encoder_max_len=1024,
            decoder_max_len=100,
            tokenizer=self.tokenizer,
            extra_supervision=True,
            paracomet=True,
            relation="xReason",
            supervision_relation="xIntent",
            roberta=False,
            sentence_transformer=False,
        )
        self.train_dataset = self.total_dataset.getTrainData()
        self.eval_dataset = self.total_dataset.getEvalData()
        self.test_dataset = self.total_dataset.getTestData()
        print('######################################################################')
        print('Training Dataset Size is : ')
        print(len(self.train_dataset))
        print('Validation Dataset Size is : ')
        print(len(self.eval_dataset))
        print('Test Dataset Size is : ')
        print(len(self.test_dataset))
        print('######################################################################')

        # set Training Arguments
        self.finetune_args = Seq2SeqTrainingArguments(
            output_dir=config_dialogsum.finetune_weight_path,
            overwrite_output_dir=True,
            do_train=True,
            do_eval=True,
            do_predict=True,
            evaluation_strategy='epoch',
            per_device_train_batch_size=config_dialogsum.train_batch_size,
            per_device_eval_batch_size=config_dialogsum.val_batch_size,
            learning_rate=config_dialogsum.init_lr,
            weight_decay=config_dialogsum.weight_decay,
            adam_beta1=config_dialogsum.adam_beta1,
            adam_beta2=config_dialogsum.adam_beta2,
            adam_epsilon=config_dialogsum.adam_eps,
            num_train_epochs=config_dialogsum.epoch,
            max_grad_norm=0.1,
            gradient_accumulation_steps=2,
            gradient_checkpointing=True,
            lr_scheduler_type='polynomial',
            warmup_steps=config_dialogsum.warm_up,
            logging_strategy="epoch",
            save_strategy="epoch",
            save_total_limit=1,
            # Mixed precision is only requested when a CUDA device exists, so
            # the CPU fallback chosen above stays usable.
            fp16=cuda_available,
            seed=216,
            load_best_model_at_end=True,
            # With load_best_model_at_end and no explicit metric the Trainer
            # tracks "loss". A loss must be minimised, so greater_is_better
            # has to be False; True would keep the worst checkpoint as "best".
            metric_for_best_model="loss",
            greater_is_better=False,
            predict_with_generate=True,
            prediction_loss_only=False,
            generation_max_length=100,
            generation_num_beams=20,
        )

    def load_model(self, path, device):
        """Load the triple-decoder BART weights from ``path`` onto ``device``.

        Sets ``self.model`` (with gradient checkpointing enabled) and
        ``self.optimizer`` (an AdamW over the model parameters with default
        settings).
        """
        print("-----------load the best model----------")
        self.model = BartForConditionalGeneration_TripleDecoder.from_pretrained(path)
        self.model.to(device)
        self.model.gradient_checkpointing_enable()
        # Not read by any method of this class; kept so that external code
        # relying on the attribute keeps working.
        self.optimizer = torch.optim.AdamW(self.model.parameters())

    def clear(self, folder_path):
        """Delete every regular file directly inside ``folder_path``.

        The folder is created first when it does not exist, so callers can
        write into it afterwards. Sub-directories are left untouched. A file
        that cannot be removed is reported on stdout and skipped.
        """
        os.makedirs(folder_path, exist_ok=True)
        for entry in os.listdir(folder_path):
            file_path = os.path.join(folder_path, entry)
            if not os.path.isfile(file_path):
                continue
            try:
                os.remove(file_path)
            except OSError as e:
                print(f"Error while deleting {file_path}: {e}")

    def train(self):
        """Fine-tune the model and save it to ``best_finetune_weight_path``."""
        self.load_model(config_dialogsum.model_path, self.device)
        finetune_trainer = TripleDecoderTrainer(
            model=self.model,
            args=self.finetune_args,
            train_dataset=self.train_dataset,
            eval_dataset=self.eval_dataset,
            tokenizer=self.tokenizer,
        )
        finetune_trainer.train()
        finetune_trainer.save_model(config_dialogsum.best_finetune_weight_path)

    def inference(self):
        """Generate a summary per test sample and print running ROUGE / BERTScore.

        For every sample (batch size 1) the decoded prediction and reference
        are written to ``<predition_path>1.dec`` and ``<label_path>1.ref``
        (overwritten each iteration), scored with ``rouge_with_pyrouge`` and
        ``BERTScorer``, and the running averages over the samples seen so far
        are printed.
        """
        # NOTE(review): train() saves to best_finetune_weight_path, but the
        # weights are loaded from finetune_weight_path here - confirm intent.
        self.load_model(config_dialogsum.finetune_weight_path, self.device)
        test_dataloader = DataLoader(dataset=self.test_dataset, batch_size=1, shuffle=False)
        self.clear(config_dialogsum.label_path)
        self.clear(config_dialogsum.predition_path)

        # Build the scorer once: constructing it per sample reloads the
        # scoring model on every iteration.
        scorer = BERTScorer(model_type='please paste your bertscore path in here', num_layers=6)

        rouge_1 = 0
        rouge_2 = 0
        rouge_L = 0
        bert_s = 0
        item = 1  # number of samples included in the running averages
        with torch.no_grad():
            for data in tqdm(test_dataloader):
                x = data['input_ids'].to(self.device, dtype=torch.long)
                mask = data['attention_mask'].to(self.device, dtype=torch.long)
                y = data['labels'].to(self.device, dtype=torch.long)

                generated_ids = self.model.generate(
                    input_ids=x,
                    attention_mask=mask,
                    max_length=100,
                    min_length=5,
                    num_beams=config_dialogsum.num_beams
                )

                generated_ids = generated_ids.cpu()
                y = y.cpu()

                decoded_preds = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
                # -100 marks label positions to be ignored; swap in the pad id
                # so the tokenizer can decode the sequence.
                y = np.where(y != -100, y, self.tokenizer.pad_token_id)
                decoded_labels = self.tokenizer.batch_decode(y, skip_special_tokens=True, clean_up_tokenization_spaces=True)
                decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
                print("decoded_preds:", decoded_preds)
                decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]
                print("decoded_labels:", decoded_labels)

                # NOTE(review): the newline between sentences is removed
                # without inserting a space, so adjacent sentences are fused
                # ("a.b"). Left as-is to keep scores comparable - confirm.
                pred_text = ''.join(decoded_preds).replace('\n', '')
                ref_text = ''.join(decoded_labels).replace('\n', '')
                with open(config_dialogsum.predition_path+"1.dec", 'w', encoding="utf-8") as file:
                    file.write(pred_text + "\n")
                with open(config_dialogsum.label_path+"1.ref", 'w', encoding="utf-8") as file:
                    file.write(ref_text + "\n")

                rouge1, rouge2, rougel = rouge_with_pyrouge(preds=config_dialogsum.predition_path, refs=config_dialogsum.label_path, item=1)
                print("rouge1:", rouge1, "rouge2:", rouge2, "rougel:", rougel)

                P, R, F = scorer.score([pred_text], [ref_text])
                bert_f = F.item()
                print("BERTS:", bert_f)

                rouge_1 += rouge1
                rouge_2 += rouge2
                rouge_L += rougel
                bert_s += bert_f

                print('final_ROUGE-1:',rouge_1/item,'final_ROUGE-2:',rouge_2/item,'final_ROUGE-L:',rouge_L/item,'平均BERTS:',bert_s/item)
                item += 1