import json
import os
import logging
import re
import sys
import subprocess
from typing import Iterable, Dict
import pandas as pd
import tqdm
import torch
from megatron.core import mpu
from megatron.training import get_args
from torch import distributed as dist
from mindspeed_llm.tasks.evaluation.eval_impl.template import CODE_TEST_LOG_DIR
from mindspeed_llm.tasks.evaluation.eval_api.dataset_eval import DatasetEval
from mindspeed_llm.tasks.evaluation.eval_api.chat import Chat
from mindspeed_llm.tasks.utils.error_utils import check_divisible_by_zero
from mindspeed_llm.training.utils import WRITE_FILE_DEFAULT_FLAGS, WRITE_FILE_DEFAULT_MODES
from mindspeed_llm.tasks.evaluation.eval_utils.human_utils import humaneval_postprocess, get_score
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
def extract_answer_code(answer, task: dict):
    """
    Build a standalone test script for one task and return its path.

    The script is written under ``CODE_TEST_LOG_DIR`` and consists of:
    two fixed header imports, the filtered lines of ``answer``, the task's
    test code, and a trailing ``check(<entry_point>)`` call.

    Filtering of ``answer`` lines:
      * a first line equal to ``python`` (case-insensitive) and blank lines
        are dropped;
      * indented lines are copied verbatim;
      * top-level ``import``/``from`` lines and lines starting with ``def``
        are copied;
      * any other top-level line is dropped, and once the entry-point
        function has been seen such a line ends the copy.

    :param answer: source text to extract code from (str)
    :param task: dict providing the keys 'task_id', 'entry_point' and 'test'
    :return: path of the generated test script
    """
    task_id = task['task_id']
    target_func = task['entry_point']
    test_case = task['test']
    save_file = f"{task_id}.py".replace("/", "-")

    # exist_ok avoids the check-then-create race when several processes
    # prepare the directory at the same time.
    os.makedirs(CODE_TEST_LOG_DIR, exist_ok=True)
    test_code_path = "{}/{}".format(CODE_TEST_LOG_DIR, save_file)

    target_func_flag = False
    with os.fdopen(os.open(test_code_path, WRITE_FILE_DEFAULT_FLAGS, WRITE_FILE_DEFAULT_MODES), 'w') as f:
        f.write("from typing import List\n")
        f.write("import math\n")
        for i, line in enumerate(answer.split("\n")):
            if i == 0 and line.lower() == "python":
                continue
            if not line.strip():
                continue

            if line[0].isspace():
                # Indented line: part of some body, copy as-is.
                f.write(line)
                f.write('\n')
            elif line.startswith(("from", "import")):
                f.write(line)
                f.write('\n')
            elif line.startswith("def"):
                # Compare only the identifier after `def`. Splitting on
                # whitespace would yield e.g. "foo(x:" which never equals the
                # entry point, and would raise IndexError on a line without
                # any whitespace.
                def_match = re.match(r"def\s+(\w+)", line)
                if def_match and def_match.group(1) == target_func:
                    target_func_flag = True
                f.write(line)
                f.write('\n')
            elif target_func_flag:
                # First other top-level statement after the target function:
                # everything that follows is not part of the solution.
                break

        f.write(test_case)
        f.write('\n')
        f.write(f'check({target_func})')
    return test_code_path
class HumanEval(DatasetEval):
    """
    HumanEval evaluation task.

    Reads problems from the jsonl files in ``test_dir``, asks the chat model
    to complete each prompt and scores the completions, either by running
    the generated code against the task's tests (default) or by delegating
    to ``get_score`` when ``args.alternative_prompt`` is set.
    """

    def __init__(self, test_dir, eval_args):
        """
        :param test_dir: directory containing the jsonl problem files
        :param eval_args: object exposing ``instruction_template``; a falsy
            value selects the built-in default template
        """
        self.test_dir = test_dir
        instruction_template = eval_args.instruction_template
        if instruction_template:
            self.instruction_template = instruction_template
        else:
            self.instruction_template = "The definition and function description of the python function are as follows. " \
                                        "Please complete the implementation of the python function.\n{prompt}"
        self.rank = dist.get_rank()
        self.file_pbar = None
        self.task_pbar = None
        # Template used when args.alternative_prompt is enabled.
        self.prompt = 'Complete the following python code:\n{prompt}'

    def read_problems(self) -> Dict[str, Dict]:
        """Return every problem found under ``self.test_dir`` keyed by its 'task_id'."""
        return {task["task_id"]: task for task in self.stream_jsonl(self.test_dir)}

    def stream_jsonl(self, test_dir: str) -> Iterable[Dict]:
        """
        Parses each jsonl line and yields it as a dictionary

        :param test_dir: directory whose entries are read as jsonl files
        :raises FileNotFoundError: if a listed entry no longer exists
        """
        for file in os.listdir(test_dir):
            # Join with the directory that was actually listed, not with
            # self.test_dir, so the method honours its own argument.
            test_code_path = os.path.join(test_dir, file)
            if not os.path.exists(test_code_path):
                raise FileNotFoundError(f"Error: {test_code_path} does not exist.")
            # Explicit encoding: do not depend on the platform locale.
            with open(test_code_path, 'r', encoding='utf-8') as fp:
                for line in fp:
                    # Skip empty / whitespace-only lines.
                    if line.strip():
                        yield json.loads(line)

    def eval(self, chat: Chat) -> (dict, pd.DataFrame):
        """
        Run the evaluation over all problems.

        :param chat: chat backend; ``beam_search_chat`` is expected to return
            a ``(results, rank)`` pair
        :return: ``(answer_result, None)`` where ``answer_result`` maps task
            id to 'passed' or to the captured stderr of the failed run (only
            filled in the default, non-alternative mode)
        """
        problems = self.read_problems()
        success_n = 0
        rank = None
        answer_result = {}
        args = get_args()
        predictions, references = [], []

        if self.rank == 0:
            self.task_pbar = tqdm.tqdm(total=len(problems), leave=False)

        if not args.alternative_prompt:
            for task_id, task in problems.items():
                instruction = self.instruction_template.format(prompt=task['prompt'])
                chat_result, rank = chat.beam_search_chat(instruction=instruction, history=[])
                answer = chat_result[0].lstrip() if chat_result else None
                try:
                    if rank == 0:
                        if answer is None:
                            raise ValueError("model returned no completion")
                        answer = task['prompt'] + ' ' + answer
                        logger.info(f'answer: {answer}')
                        test_file = extract_answer_code(answer, task)
                        try:
                            result = subprocess.run([sys.executable, test_file], capture_output=True, timeout=10)
                        finally:
                            # Always delete the generated script, including
                            # when the run times out or cannot be started.
                            os.remove(test_file)
                        if result.returncode != 0:
                            error_msg = result.stderr.decode("utf-8", errors="replace")
                            logger.info(error_msg)
                            answer_result[task_id] = error_msg
                        else:
                            answer_result[task_id] = 'passed'
                            success_n += 1
                            logger.info("%s : passed , acc : %s", task_id,
                                        check_divisible_by_zero(success_n, len(problems)))
                except Exception as e:
                    # Deliberate best-effort: one failing task (timeout, bad
                    # output, I/O error) must not abort the whole evaluation.
                    if rank == 0:
                        logger.info("%s failed. %s", task_id, e)

                if self.task_pbar is not None:
                    self.task_pbar.update()
        else:
            for task_id, task in problems.items():
                instruction = self.prompt.format(prompt=task['prompt'])
                chat_result, rank = chat.beam_search_chat(instruction=instruction, history=[])
                # Guard against an empty result, as the default mode does.
                # NOTE(review): assumes humaneval_postprocess accepts an empty
                # string - confirm against its implementation.
                raw_answer = chat_result[0] if chat_result else ""
                logger.info(raw_answer)
                if rank == 0:
                    answer = humaneval_postprocess(raw_answer)
                    predictions.append(answer)
                    references.append(task_id)

                if self.task_pbar is not None:
                    self.task_pbar.update()

        if self.task_pbar is not None:
            self.task_pbar.close()

        if args.alternative_prompt:
            logger.info('Evaluating Human Eval Prediction...')
            result = get_score(predictions, references, problems.values(), problems)
            if rank == 0:
                logger.info(f"Human Eval accuracy = {result.get('humaneval_pass@1', None)}")
        else:
            if rank == 0:
                # Log the value itself; wrapping it in {...} built a set and
                # printed e.g. "{0.5}".
                logger.info("acc = %s", check_divisible_by_zero(success_n, len(problems)))
        return answer_result, None

    def top_k_eval(self, ) -> (dict, pd.DataFrame):
        """Not implemented for this task; returns None."""
        pass