import re
import json
import gzip
import torch
import numpy
import random
import logging
import itertools
import numpy as np
from typing import *
LANGUAGE_TAG = {
"c" : "// language: C",
"c++" : "// language: C++",
"cpp" : "// language: C++",
"c#" : "// language: C#",
"csharp" : "// language: C#",
"c-sharp" : "// language: C#",
"css" : "/* language: CSS */",
"cuda" : "// language: Cuda",
"dart" : "// language: Dart",
"lua" : "// language: Lua",
"objectivec" : "// language: Objective-C",
"objective-c" : "// language: Objective-C",
"objective-c++": "// language: Objective-C++",
"python" : "# language: Python",
"perl" : "# language: Perl",
"prolog" : f"% language: Prolog",
"swift" : "// language: swift",
"lisp" : "; language: Lisp",
"java" : "// language: Java",
"scala" : "// language: Scala",
"tex" : f"% language: TeX",
"vue" : "<!--language: Vue-->",
"markdown" : "<!--language: Markdown-->",
"html" : "<!--language: HTML-->",
"php" : "// language: PHP",
"js" : "// language: JavaScript",
"javascript" : "// language: JavaScript",
"typescript" : "// language: TypeScript",
"go" : "// language: Go",
"shell" : "# language: Shell",
"rust" : "// language: Rust",
"sql" : "-- language: SQL",
"kotlin" : "// language: Kotlin",
"vb" : "' language: Visual Basic",
"ruby" : "# language: Ruby",
"pascal" : "// language: Pascal",
"r" : "# language: R",
"fortran" : "!language: Fortran",
"lean" : "-- language: Lean",
"matlab" : f"% language: Matlab",
"delphi" : "{language: Delphi}",
"scheme" : "; language: Scheme",
"basic" : "' language: Basic",
"assembly" : "; language: Assembly",
"groovy" : "// language: Groovy",
"abap" : "* language: Abap",
"gdscript" : "# language: GDScript",
"haskell" : "-- language: Haskell",
"julia" : "# language: Julia",
"elixir" : "# language: Elixir",
"excel" : "' language: Excel",
"clojure" : "; language: Clojure",
"actionscript" : "// language: ActionScript",
"solidity" : "// language: Solidity",
"powershell" : "# language: PowerShell",
"erlang" : f"% language: Erlang",
"cobol" : "// language: Cobol",
"alloy" : "/* language: Alloy */",
"awk" : "// language: AWK",
"thrift" : "/* language: Thrift */",
"sparql" : "# language: SPARQL",
"augeas" : "// language: Augeas",
"cmake" : "# language: CMake",
"f-sharp" : "// language: F#",
"stan" : "// language: Stan",
"isabelle" : "(*language: Isabelle*)",
"dockerfile" : "# language: Dockerfile",
"rmarkdown" : "# language: RMarkdown",
"literate-agda": "-- language: Literate Agda",
"tcl" : "// language: Augeas",
"glsl" : "// language: GLSL",
"antlr" : "// language: ANTLR",
"verilog" : "// language: Verilog",
"racket" : "; language: Racket",
"standard-ml" : "(*language:Standard ML*)",
"elm" : "-- language: Elm",
"yaml" : "# language: YAML",
"smalltalk" : "'' language: Smalltalk",
"ocaml" : "(*language: OCaml*)",
"idris" : "-- language: Idris",
"visual-basic" : "' language: Visual Basic",
"protocol-buffer": "// language: Protocol Buffer",
"bluespec" : "// language: Bluespec",
"applescript" : "-- language: AppleScript",
"makefile" : "# language: Makefile",
"tcsh" : "# language: TCSH",
"maple" : "# language: Maple",
"systemverilog": "// language: SystemVerilog",
"literate-coffeescript": "# language: Literate CoffeeScript",
"vhdl" : "-- language: VHDL",
"restructuredtext": ".. language: reStructuredText",
"sas" : "* language: SAS",
"literate-haskell": "> language: Literate Haskell",
"java-server-pages": "// language: Java Server Pages",
"coffeescript" : "# language: CoffeeScript",
"emacs-lisp" : "; language: Emacs Lisp",
"mathematica" : "// language: Mathematica",
"xslt" : "<!--language: XSLT-->",
"zig" : "// language: Zig",
"common-lisp" : "; language: Common Lisp",
"stata" : "* language: Stata",
"agda" : "-- language: Agda",
"ada" : "-- language: Ada",
}
LANGUAGE_COMMENT_SIGN = {}
for lang in LANGUAGE_TAG:
LANGUAGE_COMMENT_SIGN[lang] = LANGUAGE_TAG[lang].split("language:")[0].strip()
IMPORT_HELPER = {
"python": [
"import math",
"import re",
"import sys",
"import copy",
"import datetime",
"import itertools",
"import collections",
"import heapq",
"import statistics",
"import functools",
"import hashlib",
"import numpy",
"import numpy as np",
"import string",
"from typing import *",
"from collections import *",
],
"go" : [
"math",
"strings",
"fmt",
"strconv",
"time",
"bytes",
"regexp",
"sort",
"math/rand",
"crypto/md5",
],
"cpp" : [
"#include<stdlib.h>",
"#include<algorithm>",
"#include<math.h>",
"#include<stdio.h>",
"#include<vector>",
"#include<string>",
"#include<climits>",
"#include<cstring>",
"#include<iostream>",
],
}
def set_random_seed(seed):
"""Set random seed for reproducability."""
random.seed(seed)
numpy.random.seed(seed)
torch.manual_seed(seed)
def stream_jsonl(filename: str) -> Iterable[Dict]:
"""
Parses each jsonl line and yields it as a dictionary
"""
if filename.endswith(".gz"):
with open(filename, "rb") as gzfp:
with gzip.open(gzfp, "rt") as fp:
for line in fp:
if any(not x.isspace() for x in line):
yield json.loads(line)
else:
with open(filename, "r") as fp:
for line in fp:
if any(not x.isspace() for x in line):
yield json.loads(line)
def stream_jsonl_all(filename: str) -> Iterable[Dict]:
results = []
if filename.endswith(".gz"):
fp = gzip.open(open(filename, "rb"), "rt")
else:
fp = open(filename, "r")
for line in fp:
if any(not x.isspace() for x in line):
results.append(json.loads(line))
fp.close()
return results
def read_dataset(
data_file: str = None,
dataset_type: str = "humanevalx",
) -> Dict:
if "humanevalx" in dataset_type.lower():
dataset = {task["task_id"]: task for task in stream_jsonl(data_file)}
elif "mbpp" in dataset_type.lower():
problems = {task["task_id"]: task for task in stream_jsonl(data_file)}
task_ids = sorted(problems.keys())[10:510]
dataset = {}
for task_id in task_ids:
sample = problems[task_id]
description = sample["text"]
test_example = sample["test_list"][0]
prompt = f'"""\n{description}\n{test_example}\n"""\n'
sample["prompt"] = prompt
dataset[task_id] = sample
elif "ds1000" in dataset_type.lower():
from ds1000 import DS1000Dataset
ds1000 = DS1000Dataset(source_dir=data_file, libs="all", mode="Completion")
for lib in ds1000.libs:
for problem_id in range(len(ds1000[lib])):
prefix = ""
suffix = ""
insert_flag = False
first_line_flag = True
for line in ds1000[lib][problem_id]["prompt"].split("\n"):
if "[insert]" in line:
insert_flag = True
continue
if first_line_flag:
first_line_flag = False
else:
line = "\n" + line
if not insert_flag:
prefix += line
else:
suffix += line
else:
raise f"Dataset: {dataset_type} not supported."
return dataset
def read_translation_dataset(
data_file_src: str = None,
data_file_tgt: str = None,
lang_src: str = None,
lang_tgt: str = None,
dataset_type: str = "humanevalx",
) -> Dict:
if "humanevalx" in dataset_type.lower():
dataset_src = {task["task_id"]: task for task in stream_jsonl(data_file_src)}
dataset_tgt = {task["task_id"].split("/")[-1]: task for task in stream_jsonl(data_file_tgt)}
for k, sample in dataset_src.items():
prompt = "code translation\n"
if lang_src == "cpp":
prompt += "C++:\n"
elif lang_src == "js":
prompt += "JavaScript:\n"
else:
prompt += f"{lang_src}:\n".capitalize()
prompt += dataset_src[k]["declaration"] + "\n" + dataset_src[k]["canonical_solution"].rstrip() + "\n"
if lang_tgt == "cpp":
prompt += "C++:\n"
elif lang_tgt == "js":
prompt += "JavaScript:\n"
else:
prompt += f"{lang_tgt}:\n".capitalize()
prompt += dataset_tgt[k.split("/")[-1]]["declaration"]
dataset_src[k]["prompt"] = prompt
else:
raise f"Dataset: {dataset_type} not supported."
return dataset_src
def process_extra_prompt(
prompt: str,
language_type: str = "python",
dataset_type: str = None,
generation_mode: str = "completion",
) -> str:
"""
Processes the extra prompt.
"""
language = language_type.lower()
if dataset_type == "humanevalx":
extra_prompt = ""
prompt = prompt.strip()
if generation_mode == "instruction":
return "问:" + extra_prompt + prompt + "\n答:"
return extra_prompt + prompt
elif dataset_type == "mbpp":
extra_prompt = ""
prompt = prompt.strip()
return extra_prompt + prompt
else:
return prompt
def is_code_generation_finished(
code: str,
dataset_type: str = None,
language_type: str = None,
):
"""
Checks whether the generated code is finished.
"""
if dataset_type == "mbpp":
end_words = ["\ndef", "\nassert"]
for w in end_words:
if w == "\ndef":
if code.count(w) > 1:
return True
else:
if w in code:
return True
else:
if language_type.lower() == "python":
for line in code.split("\n"):
if len(line.strip()) > 0 and line[0] != ' ' and line[0] != '\t':
return True
end_words = ["\ndef", "\nclass", "\nif", "\n#", "\nprint"]
for w in end_words:
if w in code:
return True
elif language_type.lower() == "java":
if code.count("{") + 1 == code.count("}"):
return True
elif language_type.lower() == "go":
if "\nfunc main(" in code:
return True
if code.count("{") + 1 == code.count("}"):
return True
elif language_type.lower() == "js":
if code.count("{") + 1 == code.count("}"):
return True
elif language_type.lower() == "cpp":
if "\nint main()" in code:
return True
if code.count("{") + 1 == code.count("}"):
return True
elif language_type.lower() == "rust":
if "\nfn main()" in code:
return True
if code.count("{") + 1 == code.count("}"):
return True
return False
stop_words=["\nclass", "\nassert", '\n"""', "\nprint", "\nif"]
def first_block(string, stop_words):
"""Split off first block of code by scanning for class, def etc. on newlines."""
return re.split("|".join(stop_words), string)[0].rstrip()
def cleanup_code(
code: str,
dataset_type: str = None,
language_type: str = None,
):
"""
Cleans up the generated code.
"""
if dataset_type == "mbpp":
end_words = ["\nassert", "\ndef"]
for w in end_words:
if w == "\ndef":
if code.count(w) > 1:
code = code[:code.rfind(w)]
else:
code = code[:code.rfind(w)]
code = first_block(code, stop_words)
elif dataset_type == "humanevalx":
if language_type.lower() == "python":
code_splits = code.split("\n")
is_empty_line = False
ind_empty_line = None
for i, line in enumerate(code_splits):
if len(line.strip()) > 0 and line[0] != ' ' and line[0] != '\t':
is_empty_line = True
ind_empty_line = i
break
if is_empty_line:
code = "\n".join(code_splits[:ind_empty_line])
else:
end_words = ["\ndef", "\nclass", "\n#", "\nassert", '\n"""', "\nprint", "\nif", "\n\n\n"]
for w in end_words:
if w in code:
code = code[:code.rfind(w)]
elif language_type.lower() == "java":
main_pos = code.find("public static void main")
if main_pos != -1:
code = code[:main_pos] + '}'
if '}' in code:
code = code[:code.rfind('}')] + '}'
if code.count('{') + 1 == code.count('}'):
code += "\n}"
elif language_type.lower() == "go":
if "\nfunc main(" in code:
code = code[:code.rfind("func main(")]
if '}' in code:
code = code[:code.rfind('}')] + '}'
elif language_type.lower() == "cpp":
if "\nint main()" in code:
code = code[:code.rfind("int main()")]
if '}' in code:
code = code[:code.rfind('}')] + '}'
elif language_type.lower() == "js":
if '}' in code:
code = code[:code.rfind('}')] + '}'
elif language_type.lower() == "rust":
if '}' in code:
code = code[:code.rfind('}')] + '}'
return code
def estimate_pass_at_k(
num_samples: Union[int, List[int], np.ndarray],
num_correct: Union[List[int], np.ndarray],
k: int
) -> np.ndarray:
"""
Estimates pass@k of each problem and returns them in an array.
"""
def estimator(n: int, c: int, k: int) -> float:
"""
Calculates 1 - comb(n - c, k) / comb(n, k).
"""
if n - c < k:
return 1.0
return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))
if isinstance(num_samples, int):
num_samples_it = itertools.repeat(num_samples, len(num_correct))
else:
assert len(num_samples) == len(num_correct)
num_samples_it = iter(num_samples)
return np.array([estimator(int(n), int(c), k) for n, c in zip(num_samples_it, num_correct)])
class Logger:
def __init__(self, name, log_level=logging.INFO, log_file=None, log_mode="both", disable_formatter=False):
self.logger = logging.getLogger(name)
self.logger.setLevel(log_level)
self.formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
if log_mode == "both" or log_mode == "terminal":
console_handler = logging.StreamHandler()
if not disable_formatter:
console_handler.setFormatter(self.formatter)
self.logger.addHandler(console_handler)
if log_file is not None:
if log_mode == "both" or log_mode == "file":
file_handler = logging.FileHandler(log_file, mode='w')
if not disable_formatter:
file_handler.setFormatter(self.formatter)
self.logger.addHandler(file_handler)
def add_file_handler(self, file_name):
file_handler = logging.FileHandler(file_name, mode='w')
file_handler.setFormatter(self.formatter)
self.logger.addHandler(file_handler)
def debug(self, message):
self.logger.debug(message)
def info(self, message):
self.logger.info(message)
def warning(self, message):
self.logger.warning(message)
def error(self, message):
self.logger.error(message)
def critical(self, message):
self.logger.critical(message)