import argparse
import codecs
import json
import os
from string import punctuation
from utils import get_datas,get_datas2, extract_top_k_index
def get_sim_datas(dataset, mode):
    """Load the similarity records for one split of a dataset.

    Reads ``./rep/<dataset>/sim/<mode>_sim.json`` as UTF-8, parsing one JSON
    document per line, and prints the number of records loaded.

    Args:
        dataset: Name used for the ``./rep/<dataset>`` directory.
        mode: Split name used as the file-name prefix (e.g. "train").

    Returns:
        A list with one decoded JSON value per non-blank line.
    """
    path = "./rep/{}/sim/{}_sim.json".format(dataset, mode)
    # The context manager guarantees the handle is closed even if a line
    # fails to parse.
    with codecs.open(path, "r", encoding="utf-8") as f:
        # Skip whitespace-only lines (e.g. a trailing newline at end of file)
        # instead of letting json.loads fail on them.
        datas = [json.loads(line) for line in f if line.strip()]
    print("Load {} finished, Data size:{}".format(mode, len(datas)))
    return datas
def get_pure_dialogue_content(words, losses, speakers):
    """Filter out speaker-name and ":" tokens that carry a marker loss.

    A token is dropped only when it is a speaker name or the literal ":" AND
    its paired loss equals 100 or -100. Every other (word, loss) pair is kept
    in its original order.

    Args:
        words: Sequence of tokens.
        losses: Sequence of per-token losses, paired with ``words`` by position.
        speakers: Collection of speaker names to match tokens against.

    Returns:
        A ``(clean_words, clean_losses)`` tuple of two lists.
    """
    kept_pairs = [
        (token, value)
        for token, value in zip(words, losses)
        if not ((token == ":" or token in speakers) and value in (100, -100))
    ]
    clean_words = [token for token, _ in kept_pairs]
    clean_losses = [value for _, value in kept_pairs]
    return clean_words, clean_losses
def get_speakers(words):
    """
    Collect the distinct speaker names found in a tokenised dialogue.

    The tokens are joined with spaces and split on "<|endoftext|>"; the text
    after the last "<|endoftext|>" is ignored. For each remaining segment the
    text before the first ":" (stripped) is taken as the speaker, and empty
    names are skipped. The order of the returned list is not defined.
    """
    joined = " ".join(words)
    segments = joined.split("<|endoftext|>")[:-1]
    candidates = (segment.split(":")[0].strip() for segment in segments)
    unique_names = {name for name in candidates if name}
    return list(unique_names)
def remove_speaker_and_first_u(words, losses, speakers):
    """
    Strip marked speaker tokens and split off the words whose loss is 100.

    Speaker names and ":" tokens whose loss is 100 or -100 are discarded.
    Of the remaining tokens, those with loss exactly 100 are collected
    separately (the original code treats them as the first utterance); the
    rest are returned together with their losses.

    Returns:
        ``(clean_words, clean_losses, first_u)`` as three lists.
    """
    first_u, clean_words, clean_losses = [], [], []
    for token, value in zip(words, losses):
        has_marker_loss = value == 100 or value == -100
        if has_marker_loss and (token == ":" or token in speakers):
            continue
        if value == 100:
            first_u.append(token)
        else:
            clean_words.append(token)
            clean_losses.append(value)
    return clean_words, clean_losses, first_u
def remove_first_word(words, losses):
    """Separate the token that directly follows each "<|endoftext|>".

    All "<|endoftext|>" tokens are dropped. A token located immediately after
    an "<|endoftext|>" goes to ``first_words``; every other token is kept
    together with its loss. The very first token of the sequence is not
    preceded by "<|endoftext|>" and is therefore kept.

    Returns:
        ``(clean_words, clean_losses, first_words)`` as three lists.
    """
    positions_after_boundary = {
        position + 1
        for position, token in enumerate(words)
        if token == "<|endoftext|>"
    }
    first_words, clean_words, clean_losses = [], [], []
    for position, (token, value) in enumerate(zip(words, losses)):
        if token == "<|endoftext|>":
            continue
        if position in positions_after_boundary:
            first_words.append(token)
        else:
            clean_words.append(token)
            clean_losses.append(value)
    return clean_words, clean_losses, first_words
def get_keywords_num(words, ratio):
    """Return ``len(words)`` scaled by ``ratio``, truncated to an int."""
    scaled_count = ratio * len(words)
    return int(scaled_count)
def get_topics_num(utterances, ratio):
    """Return ``len(utterances)`` scaled by ``ratio``, truncated to an int."""
    scaled_count = ratio * len(utterances)
    return int(scaled_count)
def get_loss_for_each_utterance(words, losses):
    """Group tokens into utterances and average the loss of each one.

    Tokens are accumulated until an "<|endoftext|>" token closes the current
    utterance; the "<|endoftext|>" token and its loss are not part of any
    utterance. An utterance with no tokens gets a mean loss of 0. Tokens that
    come after the last "<|endoftext|>" are not returned.

    Args:
        words: Sequence of tokens.
        losses: Per-token losses; must have the same length as ``words``.

    Returns:
        ``(utterances, loss_for_each_u)``: a list of token lists and a list
        with the mean loss of each utterance, aligned by position.
    """
    assert len(words) == len(losses)
    utterances, grouped_losses = [], []
    current_words, current_losses = [], []
    for token, value in zip(words, losses):
        if token != "<|endoftext|>":
            current_words.append(token)
            current_losses.append(value)
            continue
        # Boundary reached: store the utterance, using [0] as the loss list
        # of an empty utterance so that the mean below is well defined.
        utterances.append(current_words)
        grouped_losses.append(current_losses if current_losses else [0])
        current_words, current_losses = [], []
    assert len(utterances) == len(grouped_losses)
    loss_for_each_u = [sum(group) / len(group) for group in grouped_losses]
    return utterances, loss_for_each_u
def topic_segmentation(utterances, loss_for_each_utterance, topic_ratio):
    """Pick the utterance indices with the highest loss as segment starts.

    Utterances are ranked by loss in descending order (ties keep their
    original relative order) and indices are collected until
    ``get_topics_num(utterances, topic_ratio)`` of them have been chosen.
    Indices 0 and 1 are never selected.

    Args:
        utterances: List of utterances; only its length and pairing with the
            losses are used.
        loss_for_each_utterance: One loss value per utterance.
        topic_ratio: Fraction of the utterance count to select.

    Returns:
        The selected indices, ordered from highest to lowest loss.
    """
    topics_num = get_topics_num(utterances, topic_ratio)
    # zip() keeps the original pairing, so ranking stops at the shorter input.
    ranked = sorted(
        enumerate(zip(utterances, loss_for_each_utterance)),
        key=lambda item: item[1][1],
        reverse=True,
    )
    seg_ids = []
    for index, _ in ranked:
        if len(seg_ids) == topics_num:
            break
        # The first two utterances are never used as a segment boundary.
        if index == 0 or index == 1:
            continue
        seg_ids.append(index)
    return seg_ids
def get_topic_segment_indices(data, topic_ratio):
    """Compute segment indices for one record.

    ``data[0]`` holds the tokens and ``data[1]`` the per-token losses. Marked
    speaker tokens are stripped, the losses are averaged per utterance, and
    the result is handed to ``topic_segmentation``.

    Returns:
        The list of indices returned by ``topic_segmentation``.
    """
    raw_words, raw_losses = data[0], data[1]
    content_words, content_losses = get_pure_dialogue_content(
        raw_words, raw_losses, get_speakers(raw_words)
    )
    utterances, mean_losses = get_loss_for_each_utterance(content_words, content_losses)
    return topic_segmentation(utterances, mean_losses, topic_ratio)
def rdd(data, threshold):
    """Find utterance indices flagged as redundant by a similarity matrix.

    ``data`` is indexed as ``(utterances, similarity_matrix, summary)``.
    Row ``i`` of the matrix is associated with utterance index
    ``len(utterances) - 1 - i``. Each row is scanned from its last entry
    backwards: while the similarity is at least ``threshold``, the index
    ``q_index - offset`` is flagged; the scan of that row stops at the first
    similarity below the threshold.

    If there are more utterances than ``len(similarity_matrix) + 1``, the
    extra trailing utterances are dropped before processing.

    Args:
        data: Record holding utterances, similarity rows and the summary.
        threshold: Minimum similarity for an utterance to be flagged.

    Returns:
        ``(utterances, redundancy_indices, summary)`` where
        ``redundancy_indices`` is a sorted list without duplicates.

    Raises:
        AssertionError: If the matrix has more rows than
            ``len(utterances) - 1``.
    """
    utterances = data[0]
    similarity_matrix = data[1]
    summary = data[2]
    expected_count = len(similarity_matrix) + 1
    if expected_count < len(utterances):
        utterances = utterances[:expected_count]
    assert len(similarity_matrix) + 1 == len(utterances)

    last_index = len(utterances) - 1
    # A set gives O(1) de-duplication; the result is sorted once at the end.
    flagged = set()
    for row_number, sims in enumerate(similarity_matrix):
        q_index = last_index - row_number
        for offset, sim in enumerate(reversed(sims)):
            if sim >= threshold:
                flagged.add(q_index - offset)
            else:
                break
    redundancy_indices = sorted(flagged)
    return utterances, redundancy_indices, summary
def create_dialogue(utterances, segment_ids, redundancy_ids):
    """Join utterances into one string, inserting "[TS]" and "[RD]" markers.

    A standalone "[TS]" entry is emitted before every utterance whose index
    is in ``segment_ids``. For an utterance whose index is in
    ``redundancy_ids``, "[RD]" is inserted as the third whitespace-separated
    token; such an utterance must have ":" as its second token. All entries
    are joined with " <|endoftext|> ".

    Raises:
        AssertionError: If a redundant utterance's second token is not ":".
    """
    pieces = []
    for position, text in enumerate(utterances):
        if position in segment_ids:
            pieces.append("[TS]")
        if position not in redundancy_ids:
            pieces.append(text)
            continue
        tokens = text.split()
        assert tokens[1] == ":"
        pieces.append(" ".join(tokens[:2] + ["[RD]"] + tokens[2:]))
    return " <|endoftext|> ".join(pieces)
def process(datas, sim_datas, ratio, threshold, topic_ratio, dataset, mode):
    """Annotate every dialogue of one split and write the result as JSON.

    For each paired record the segment indices and the redundant utterance
    indices are computed, the markers are inserted with ``create_dialogue``,
    and every "<|endoftext|>" in the resulting text is replaced by "|". The
    ``[text, data[2]]`` pairs are dumped to
    ``./data/<dataset>/final/<mode>.json``.

    Args:
        datas: Records indexed as (words, losses, output).
        sim_datas: Records passed to ``rdd``, paired with ``datas`` by position.
        ratio: Accepted for interface compatibility; not used by this function.
        threshold: Similarity threshold forwarded to ``rdd``.
        topic_ratio: Ratio forwarded to ``get_topic_segment_indices``.
        dataset: Dataset name used in the output path.
        mode: Split name used as the output file name.
    """
    res = []
    for data, sim_data in zip(datas, sim_datas):
        output = data[2]
        # Topic segmentation
        segment_ids = get_topic_segment_indices(data, topic_ratio)
        # Redundancy detection
        utterances, redundancy_indices, _ = rdd(sim_data, threshold=threshold)
        # Combine both annotations into the final text
        dialogue = create_dialogue(utterances, segment_ids, redundancy_indices)
        annotated = dialogue.replace("<|endoftext|>", "|")
        res.append([annotated, output])
    # exist_ok avoids the race between checking for the directory and
    # creating it.
    os.makedirs("./data/{}/final".format(dataset), exist_ok=True)
    with open("./data/{}/final/{}.json".format(dataset, mode), 'w') as file_obj:
        json.dump(res, file_obj)
if __name__ == "__main__":
    # Hyper-parameters forwarded to process() for each dataset.
    # NOTE(review): "samsum" is the default dataset but no settings are
    # defined for it here — confirm the intended values and add an entry.
    dataset_settings = {
        "dialogsum": {"ratio": 0.15, "threshold": 0.99, "topic_ratio": 0.15},
        "ami": {"ratio": 0.04, "threshold": 0.95, "topic_ratio": 0.05},
    }

    parser = argparse.ArgumentParser("")
    # `choices` validates the dataset even when asserts are disabled (-O).
    parser.add_argument('-d', type=str, default="samsum",
                        choices=["samsum", "ami", "dialogsum"], help='dataset')
    args = parser.parse_args()
    dataset = args.d

    settings = dataset_settings.get(dataset)
    if settings is None:
        # Previously this case loaded every file and then did nothing without
        # any message; report it instead.
        print("No processing settings are configured for dataset '{}'; "
              "nothing was written.".format(dataset))
    else:
        train_datas, valid_datas, test_datas = get_datas("./loss/{}/word/train_loss.json".format(dataset),
                                                         "./loss/{}/word/valid_loss.json".format(dataset),
                                                         "./loss/{}/word/test_loss.json".format(dataset))
        train_sim_datas = get_sim_datas(dataset, "train")
        valid_sim_datas = get_sim_datas(dataset, "valid")
        test_sim_datas = get_sim_datas(dataset, "test")

        splits = (
            ("train", train_datas, train_sim_datas),
            ("valid", valid_datas, valid_sim_datas),
            ("test", test_datas, test_sim_datas),
        )
        for mode, split_datas, split_sim_datas in splits:
            process(split_datas, split_sim_datas, dataset=dataset, mode=mode, **settings)