import argparse
import csv
import json
import os
import asyncio
from openai import AsyncOpenAI
from dotenv import load_dotenv
from pathlib import Path
env_file = Path.home() / ".openviking_benchmark_env"
load_dotenv(env_file)
async def grade_answer(
llm_client, model: str, question: str, gold_answer: str, response: str
) -> tuple[bool, str]:
system_prompt = """
You are an expert grader that determines if answers to questions match a gold standard answer
"""
ACCURACY_PROMPT = f"""
Your task is to label an answer to a question as 'CORRECT' or 'WRONG'. You will be given the following data:
(1) a question (posed by one user to another user),
(2) a 'gold' (ground truth) answer,
(3) a generated answer
which you will score as CORRECT/WRONG.
The point of the question is to ask about something one user should know about the other user based on their prior conversations.
The gold answer will usually be a concise and short answer that includes the referenced topic, for example:
Question: Do you remember what I got the last time I went to Hawaii?
Gold answer: A shell necklace
The generated answer might be much longer, but you should be generous with your grading - as long as it touches on the same topic as the gold answer, it should be counted as CORRECT.
For time related questions, the gold answer will be a specific date, month, year, etc. The generated answer might be much longer or use relative time references (like "last Tuesday" or "next month"), but you should be generous with your grading - as long as it refers to the same date or time period as the gold answer, it should be counted as CORRECT. Even if the format differs (e.g., "May 7th" vs "7 May"), consider it CORRECT if it's the same date.
Now it's time for the real question:
Question: {question}
Gold answer: {gold_answer}
Generated answer: {response}
First, provide a short (one sentence) explanation of your reasoning, then finish with CORRECT or WRONG.
Do NOT include both CORRECT and WRONG in your response, or it will break the evaluation script.
Respond with JSON only: {{"is_correct": "CORRECT" or "WRONG", "reasoning": "your explanation"}}
"""
try:
resp = await llm_client.chat.completions.create(
model=model,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": ACCURACY_PROMPT},
],
temperature=0,
timeout=60,
)
content = resp.choices[0].message.content.strip()
start_idx = content.find("{")
end_idx = content.rfind("}")
if start_idx != -1 and end_idx != -1:
json_str = content[start_idx : end_idx + 1].strip()
result = json.loads(json_str)
is_correct = result.get("is_correct", "WRONG").strip().upper() == "CORRECT"
reasoning = result.get("reasoning", "")
return is_correct, reasoning
return False, f"[PARSE ERROR] Invalid response: {content}"
except Exception as e:
return False, f"[API ERROR] {str(e)}"
def load_answers(input_path: str) -> tuple[list[dict], list[str]]:
"""加载待评分的回答,返回所有行和表头"""
if not os.path.exists(input_path):
raise FileNotFoundError(f"Input file not found: {input_path}")
with open(input_path, "r", encoding="utf-8", newline="") as f:
reader = csv.DictReader(f)
fieldnames = reader.fieldnames.copy()
if "reasoning" not in fieldnames:
fieldnames.append("reasoning")
rows = list(reader)
return rows, fieldnames
async def main():
parser = argparse.ArgumentParser(
description="VikingBot QA judge script, same logic as openclaw evaluation"
)
parser.add_argument(
"--input",
default="./result/locomo_qa_result_only_sys_memory.csv",
help="Path to QA result csv file, default: ./result/locomo_qa_result.csv",
)
parser.add_argument(
"--base-url",
default=os.getenv("OPENAI_BASE_URL", "https://ark.cn-beijing.volces.com/api/coding/v3"),
help="Volcengine API base URL, default from OPENAI_BASE_URL env var or https://ark.cn-beijing.volces.com/api/coding/v3",
)
parser.add_argument(
"--token",
default=os.getenv("ARK_API_KEY", os.getenv("OPENAI_API_KEY", "")),
help="Volcengine API token, default from ARK_API_KEY or OPENAI_API_KEY env var",
)
parser.add_argument(
"--model",
default=os.getenv("JUDGE_MODEL", "doubao-seed-2.0-code"),
help="Judge model name, default from JUDGE_MODEL env var or doubao-seed-2.0-code",
)
parser.add_argument(
"--parallel", type=int, default=5, help="Parallel request count, default: 5"
)
args = parser.parse_args()
if not args.token:
print("Error: API token is required")
print("\n请通过以下方式设置 API key:")
print(" 1. 创建 ~/.openviking_benchmark_env 文件,内容如下:")
print(" ARK_API_KEY=你的key")
print(" 2. 或者通过 --token 参数传入")
print(" 3. 或者设置环境变量: export ARK_API_KEY=你的key")
exit(1)
rows, fieldnames = load_answers(args.input)
valid_rows = []
ungraded = []
for i, row in enumerate(rows):
category = row.get("category", "")
if category == "5":
continue
valid_rows.append(i)
if not row.get("result"):
ungraded.append(i)
total = len(rows)
valid_total = len(valid_rows)
print(f"Total answers: {total}, valid (category != 5): {valid_total}, ungraded: {len(ungraded)}")
if not ungraded:
print("All valid answers already graded, exit")
return
client = AsyncOpenAI(base_url=args.base_url, api_key=args.token)
semaphore = asyncio.Semaphore(args.parallel)
file_lock = asyncio.Lock()
async def save_results():
"""保存当前所有结果到CSV文件,使用临时文件+原子替换避免文件损坏"""
async with file_lock:
temp_file = f"{args.input}.tmp"
with open(temp_file, "w", encoding="utf-8", newline="") as f:
writer = csv.DictWriter(f, fieldnames=fieldnames)
writer.writeheader()
writer.writerows(rows)
os.replace(temp_file, args.input)
async def process_row(idx):
async with semaphore:
row = rows[idx]
question = row["question"]
gold = row.get("expected") or row.get("answer")
response = row["response"]
print(f"Grading {idx + 1}/{total}: {question[:60]}...")
is_correct, reasoning = await grade_answer(client, args.model, question, gold, response)
row["result"] = "CORRECT" if is_correct else "WRONG"
row["reasoning"] = reasoning
await save_results()
print(f"Saved result for {idx + 1}/{total}: {row['result']}")
return idx, row
tasks = [process_row(idx) for idx in ungraded]
await asyncio.gather(*tasks)
correct = 0
total_graded = 0
for row in rows:
category = row.get("category", "")
if category == "5":
continue
if row.get("result"):
total_graded += 1
if row.get("result") == "CORRECT":
correct += 1
accuracy = correct / total_graded if total_graded > 0 else 0.0
print(f"\nGrading completed: {correct}/{total_graded} correct, accuracy: {accuracy:.2%}")
print(f"All results saved to {args.input}")
if __name__ == "__main__":
asyncio.run(main())