import os
import copy
import json
import logging
import re
import torch
import tqdm
import pandas as pd
from torch import distributed as dist
from megatron.core import mpu
from megatron.training import get_args
from mindspeed_llm.tasks.utils.error_utils import check_divisible_by_zero
from mindspeed_llm.tasks.preprocess.templates import Role
from mindspeed_llm.tasks.evaluation.eval_api.chat import Chat
from mindspeed_llm.tasks.evaluation.eval_api.dataset_eval import DatasetEval
from mindspeed_llm.tasks.evaluation.eval_impl.template import BBH_TEMPLATE_DIR, BBH_COT_TEMPLATE_DIR, get_eval_template
from mindspeed_llm.tasks.evaluation.eval_utils.bbh_utils import bbh_mcq_postprocess, bbh_freeform_postprocess, bbh_true_or_false_questions
from mindspeed_llm.tasks.evaluation.utils import get_final_list_dataset
logger = logging.getLogger(__name__)

# BBH subjects handled as lettered multiple-choice tasks. Code in this module
# tests membership in this list to decide whether "(A)"-style option letters
# should be parsed from the prompt and from the model output.
bbh_multiple_choice_sets = [
    'temporal_sequences',
    'disambiguation_qa',
    'date_understanding',
    'tracking_shuffled_objects_three_objects',
    'penguins_in_a_table',
    'geometric_shapes',
    'snarks',
    'ruin_names',
    'tracking_shuffled_objects_seven_objects',
    'tracking_shuffled_objects_five_objects',
    'logical_deduction_three_objects',
    'hyperbaton',
    'logical_deduction_five_objects',
    'logical_deduction_seven_objects',
    'movie_recommendation',
    'salient_translation_error_detection',
    'reasoning_about_colored_objects',
]
# BBH subjects that are not lettered multiple-choice tasks.
# NOTE(review): nothing in the visible part of this module reads this list;
# presumably it is kept for importers or for symmetry with the list above.
bbh_free_form_sets = [
    'multistep_arithmetic_two',
    'navigate',
    'dyck_languages',
    'word_sorting',
    'sports_understanding',
    'boolean_expressions',
    'object_counting',
    'formal_fallacies',
    'causal_judgement',
    'web_of_lies',
]
class BBHEval(DatasetEval):
    """Evaluate a chat model on the BBH task files found in a directory.

    Every file in ``test_dir`` is read as JSON and must contain an
    ``'examples'`` list whose items have ``'input'`` and ``'target'`` keys.
    The few-shot prompt for each subject is looked up by subject name in the
    BBH template file (the chain-of-thought variant when
    ``args.chain_of_thought`` is set).
    """

    def __init__(self, test_dir, eval_args,
                 instruction_template="{fewshot_template}Q: {question}\nA:"):
        """Store the evaluation configuration.

        Args:
            test_dir: Directory whose files are the per-subject BBH datasets.
            eval_args: Object providing ``evaluation_batch_size``,
                ``prompt_type``, ``eval_language`` and ``max_eval_samples``.
            instruction_template: ``str.format`` template receiving
                ``fewshot_template`` and ``question``.
        """
        self.test_dir = test_dir
        self.instruction_template = instruction_template
        self.batch_size = eval_args.evaluation_batch_size
        self.rank = dist.get_rank()
        self.file_pbar = None
        self.task_pbar = None
        self.eval_template = None
        self.prompt_type = None
        # Rank group used when args.broadcast is set: only rank 0 scores.
        self.broadcast_rank = [[0]]
        if eval_args.prompt_type is not None:
            self.prompt_type = eval_args.prompt_type.strip()
            self.eval_template = get_eval_template(eval_args.eval_language)
        self.max_eval_samples = eval_args.max_eval_samples

    def eval(self, chat: Chat) -> (dict, pd.DataFrame):
        """Run every BBH subject through ``chat`` and score the answers.

        Args:
            chat: Chat wrapper; ``chat.chat(instruction=..., history=[])`` must
                return ``(results, rank)`` where each result is indexable and
                element ``0`` is the generated text.

        Returns:
            A tuple ``(answer_result, score_df)``. ``answer_result`` maps a
            subject name to ``{sample index (str): extracted answer}`` and
            ``score_df`` has columns ``subject``, ``question_n`` and ``acc``
            with one row per subject plus a final ``"total"`` row. Both are
            only filled in on global rank 0.

        Raises:
            FileExistsError: If a listed file no longer exists when opened.
            ValueError: If a dataset file has no ``'examples'`` key.
        """
        answer_result = {}
        total_acc_n = 0
        total_n = 0
        score_datas = []
        args = get_args()

        # Remember which template file was loaded so that the "missing
        # template" message below names the right file.
        template_path = BBH_COT_TEMPLATE_DIR if args.chain_of_thought else BBH_TEMPLATE_DIR
        with open(template_path, encoding='utf-8') as f:
            bbh_template = json.load(f)

        # List the directory once so the progress-bar total and the files that
        # are actually iterated cannot disagree.
        file_names = os.listdir(self.test_dir)
        if self.rank == 0:
            self.file_pbar = tqdm.tqdm(total=len(file_names), desc="total")

        for file in file_names:
            file_path = os.path.join(self.test_dir, file)
            if not os.path.exists(file_path):
                # NOTE(review): FileNotFoundError would describe this better;
                # the type is kept so existing callers see the same exception.
                raise FileExistsError("The file ({}) does not exist !".format(file_path))
            with open(file_path, encoding='utf-8') as f:
                bbh_dataset = json.load(f)
            # Strip the extension and an optional _test/_val/_dev suffix.
            subject_name = re.sub(r'(?:_test|_val|_dev)?\.\w+$', "", file)
            subject_result = {}
            if 'examples' not in bbh_dataset:
                raise ValueError("key 'examples' not found in the bbh_dataset")

            if args.chain_of_thought:
                sorted_dataset = bbh_dataset['examples']
            else:
                # Shortest inputs first.
                sorted_dataset = sorted(bbh_dataset['examples'], key=lambda x: len(x['input']))
            instructions = []
            targets = []
            acc_n = 0
            result_mapping, choices = None, None

            if self.max_eval_samples is not None:
                origin_len = len(sorted_dataset)
                sorted_dataset = sorted_dataset[0:min(self.max_eval_samples, origin_len)]
                logger.info("%s length from %s to %s !!!", subject_name, str(origin_len), str(len(sorted_dataset)))

            if subject_name not in bbh_template:
                logger.error("missing '%s' instruction_template in %s", subject_name, template_path)
                if self.file_pbar is not None:
                    self.file_pbar.update()
                continue

            if args.broadcast:
                group = self.broadcast_rank
                # NOTE(review): with 0 here the trimming condition below holds
                # for every data-parallel rank >= 0, so the last result is
                # dropped in broadcast mode as well - confirm this is intended.
                align_start_dp_rank = 0
            else:
                # Presumably shards the samples across data-parallel ranks and
                # pads some ranks with one extra sample - verify against
                # get_final_list_dataset.
                sorted_dataset, group, align_start_dp_rank = get_final_list_dataset(
                    sorted_dataset, dist.get_world_size(),
                    args.tensor_model_parallel_size, args.pipeline_model_parallel_size)

            if dist.get_rank() == 0:
                self.task_pbar = tqdm.tqdm(total=len(sorted_dataset), desc=file, leave=False)

            for idx, item in enumerate(sorted_dataset):
                instruction = self.instruction_template.format(
                    fewshot_template=bbh_template[subject_name], question=item['input'])
                if self.prompt_type is not None:
                    # Chat-style prompt: turn the few-shot text into messages.
                    instructions.append([])
                    instruction = instruction.split("\n\nQ: ")
                    choices, answer_idx = self.format_instructions(instruction, instructions)
                    if subject_name in bbh_multiple_choice_sets:
                        # Map option text -> option letter for the question.
                        result_mapping = {
                            value.strip(): key
                            for key, value in re.findall(
                                r'\(([A-Z])\)\s*([^\(\)]+)', instruction[-1][:answer_idx])
                        }
                elif args.chain_of_thought:
                    instruction = bbh_template.get(subject_name)
                    target_question = "Q: " + item['input']
                    instruction += target_question
                    instruction += "\nA: Let's think step by step."
                    instructions.append(instruction)
                else:
                    instructions.append(instruction)
                targets.append(item['target'].lstrip('(').rstrip(')'))

                # Flush on a full batch or on the last sample that is actually
                # iterated. The length must come from sorted_dataset, not from
                # the raw 'examples' list: after truncation or sharding the two
                # differ and the final partial batch would never be evaluated.
                is_last_item = len(sorted_dataset) == idx + 1
                if len(instructions) == self.batch_size or is_last_item:
                    chat_results, _ = chat.chat(instruction=instructions, history=[])
                    if (align_start_dp_rank >= 0 and is_last_item
                            and mpu.get_data_parallel_rank() >= align_start_dp_rank):
                        # Drop the trailing result on these ranks; the same
                        # sample is excluded from question_num further down.
                        chat_results = chat_results[:-1]
                    if chat_results:
                        for index, chat_result in enumerate(chat_results):
                            answer = chat_result[0].lstrip()
                            result_key = str(idx - len(chat_results) + index + 1)
                            try:
                                if dist.get_rank() in group[0]:
                                    if args.chain_of_thought:
                                        if subject_name in bbh_multiple_choice_sets:
                                            answer = bbh_mcq_postprocess(chat_result[0])
                                        else:
                                            answer = bbh_freeform_postprocess(chat_result[0], subject_name)
                                    else:
                                        answer = self.extract_answer(
                                            index, chat, answer, subject_name,
                                            instructions, result_mapping, choices)
                                    logger.info(f"correct: {targets[index]}, AI: {answer.splitlines()[0]}, with rank {dist.get_rank()}")
                                    subject_result[result_key] = answer.splitlines()[0]
                                    if subject_result[result_key].lower() == targets[index].lower():
                                        acc_n += 1
                            except Exception as e:
                                # Best effort: record the failure as this
                                # sample's answer instead of aborting the run.
                                subject_result[result_key] = str(e) + f". AI answer: {answer}"
                    instructions = []
                    targets = []

                if self.task_pbar is not None:
                    self.task_pbar.update()

            if dist.get_rank() in group[0]:
                question_num = len(sorted_dataset)
                if align_start_dp_rank >= 0 and mpu.get_data_parallel_rank() >= align_start_dp_rank:
                    # Matches the result trimmed from the last batch above.
                    question_num -= 1
                if not args.broadcast:
                    # Sum correct/total counts over the data-parallel group.
                    local_tensor = torch.tensor([acc_n, question_num], device=torch.cuda.current_device())
                    dist.all_reduce(local_tensor, op=dist.ReduceOp.SUM, group=mpu.get_data_parallel_group())
                    acc_n, total_questions = local_tensor.tolist()
                else:
                    total_questions = question_num
                if dist.get_rank() == 0:
                    # NOTE(review): raises ZeroDivisionError when a subject ends
                    # up with zero counted questions - unchanged from before.
                    subject_acc = acc_n / total_questions
                    logger.info(f'{subject_name} has {acc_n} corrects in {total_questions} questions, with accuracy {subject_acc}')
                    total_n += total_questions
                    total_acc_n += acc_n
                    answer_result[subject_name] = subject_result
                    score_datas.append([subject_name, total_questions, subject_acc])

            if self.task_pbar is not None:
                self.task_pbar.close()
            if self.file_pbar is not None:
                self.file_pbar.update()

        if dist.get_rank() == 0:
            logger.info(f"bbh acc = {total_acc_n}/{total_n}={check_divisible_by_zero(total_acc_n, total_n)}")
            score_datas.append(["total", total_n, check_divisible_by_zero(total_acc_n, total_n)])
        score_df = pd.DataFrame(columns=['subject', 'question_n', 'acc'], data=score_datas)
        return answer_result, score_df

    def top_k_eval(self, ) -> (dict, pd.DataFrame):
        """Not implemented for BBH; returns ``None``."""
        pass

    def format_instructions(self, instruction_set, instruction_list) -> (list, int):
        """Convert a split few-shot prompt into chat messages.

        ``instruction_set`` is the prompt split on the question separator:
        element ``0`` is the leading text, the middle elements are solved
        examples and the last element is the question to answer. Messages are
        appended to ``instruction_list[-1]`` in place.

        Args:
            instruction_set: List of prompt segments as described above.
            instruction_list: List whose last element is the message list to
                extend.

        Returns:
            ``(options, final_answer_index)``: the option letters found in
            parentheses in the final question, and the index of the last
            ``'A'`` in the final segment.
        """
        for idx in range(1, len(instruction_set) - 1):
            if idx == 1:
                # The leading text is folded into the first example.
                prompt = instruction_set[0] + " Question: " + instruction_set[1]
            else:
                prompt = instruction_set[idx]
            # The answer marker is located as the last capital 'A'. When the
            # example ends with "(A)", that trailing 'A' is skipped so the
            # search lands on an earlier one.
            if prompt[-3:] != '(A)':
                answer_index = prompt.rfind('A')
            else:
                answer_index = prompt[:-2].rfind('A')
            instruction_list[-1].extend([
                {'role': Role.USER.value, 'content': prompt[:answer_index] + 'Answer: '},
                # Skip three characters from the marker before the answer text.
                {'role': Role.ASSISTANT.value, 'content': prompt[answer_index + 3:].strip()}
            ])
        final_answer_index = instruction_set[-1].rfind('A')
        instruction_list[-1].append({
            'role': Role.USER.value,
            'content': instruction_set[-1][:final_answer_index] + 'Answer: '
        })
        options = re.findall(r'\(([A-Z])\)', instruction_set[-1][:final_answer_index])
        return options, final_answer_index

    def get_best_choice(self, idx, model, instruction_set, options) -> str:
        """Pick the option whose appended prompt gets the lowest loss.

        For every option, the option text is appended to the content of the
        last message of ``instruction_set[idx]`` (on a deep copy), the result
        is tokenized and run through ``model.forward``, and a next-token
        cross-entropy loss is computed over the sequence.

        Args:
            idx: Index of the conversation inside ``instruction_set``.
            model: Object exposing ``infer_model._init_tokenizer``,
                ``infer_model._tokenize`` and ``forward``.
            instruction_set: List of conversations (lists of message dicts).
            options: Candidate option strings.

        Returns:
            The element of ``options`` with the smallest recorded loss.
        """
        loss_records = []
        for option in options:
            # Deep copy so the caller's conversations are left untouched.
            modified_instruction = copy.deepcopy(instruction_set)
            modified_instruction[idx][-1]['content'] += option
            model.infer_model._init_tokenizer(get_args())
            token_ids, _ = model.infer_model._tokenize(modified_instruction)
            token_ids = torch.tensor(token_ids).to(torch.cuda.current_device())
            # NOTE(review): casting arange to bool yields False at position 0
            # and True everywhere else - confirm this is the intended mask.
            attention_mask = torch.arange(token_ids.size(1)).unsqueeze(0).to(torch.bool).to(torch.cuda.current_device())
            with torch.no_grad():
                logits = model.forward(token_ids, None, attention_mask)
            # Align the logits at position t with the token at position t + 1.
            shifted_logits = logits[..., :-1, :].float()
            shifted_labels = token_ids[..., 1:]
            # Token id 0 is excluded from the loss.
            loss_fn = torch.nn.CrossEntropyLoss(reduction='none', ignore_index=0)
            loss_values = loss_fn(
                shifted_logits.view(-1, shifted_logits.size(-1)), shifted_labels.view(-1)
            ).view(shifted_labels.size())
            # Summed loss divided by the sequence length.
            avg_loss = loss_values.sum(-1).cpu().numpy() / max(token_ids.size(1), 1)
            loss_records.append(avg_loss)
        return options[loss_records.index(min(loss_records))]

    def extract_answer(self, idx, chat_instance, response, subject, instruction_set, result_map=None, options=None) -> str:
        """Reduce a raw model response to a comparable answer string.

        Args:
            idx: Index of this sample inside ``instruction_set``.
            chat_instance: Chat wrapper; its ``model`` attribute is passed to
                ``get_best_choice`` when a fallback is needed.
            response: Generated text for the sample.
            subject: BBH subject name.
            instruction_set: Conversations of the current batch.
            result_map: Optional mapping from option text to option letter.
            options: Optional list of valid option letters.

        Returns:
            The extracted answer.
        """
        if response:
            # Keep the first line, drop everything up to the last '(' and cut
            # at the first '.', newline or ')'.
            response = re.split(r'\.|\n\n|\n|\)', re.sub(r'.*\(', '', response.splitlines()[0]))[0]
            response = response.strip('()')
        if result_map and subject in bbh_multiple_choice_sets and response in result_map:
            # The model replied with the option text; convert it to the letter.
            response = result_map[response]
        if options and subject in bbh_multiple_choice_sets and (not response or response not in options):
            # No usable letter was produced: fall back to loss-based selection.
            response = self.get_best_choice(idx, chat_instance.model, instruction_set, options)
        if subject in bbh_true_or_false_questions:
            match = re.search(r'\b(yes|no|true|false|invalid|valid)\b', response, re.IGNORECASE)
            response = match.group() if match else response
        return response