"""
MCP-specific rewarding logic for task completion and record filtering.
"""
from log import logger
import random
from typing import TYPE_CHECKING, Any, Dict
from beanie.operators import In
from beanie import UpdateResponse
if TYPE_CHECKING:
from .models import MCPTask
from databases import Record
from configs import AgentTrainingConfig
async def apply_mcp_rewarding_logic(
    task: "MCPTask",
    record_class: type,
    task_class: type,
    config: Any
) -> None:
    """
    Decide which records of a finished MCP task are kept, then complete the task.

    Exactly one strategy runs, chosen in this order:

    1. Low score spread (``max(task.scores) - min(task.scores) < 0.1``):
       a bulk update sets every record matched by ``{"task.$id": task.id}``
       to ``ABANDONED``.
    2. ``config.balance_sample`` is truthy: zero-score and non-zero-score
       records are sampled against each other (see
       ``_apply_balanced_sampling``).
    3. Otherwise: all records are set to ``READY``, then ``failed_error``
       records and surplus context-error records are abandoned (see
       ``_apply_default_filtering``).

    In every case ``task.status`` ends as ``task_class.Status.COMPLETED`` and
    the task is saved.

    Args:
        task: The MCPTask instance; ``scores``, ``id``, ``status`` and
            ``save()`` are used.
        record_class: The Record class type (queried, and its ``Status``
            members are read).
        task_class: The Task class type (only ``Status.COMPLETED`` is read).
        config: The AgentTrainingConfig instance; ``balance_sample`` and
            ``reserve_context_error_count`` are read with ``getattr``
            defaults, so either attribute may be absent.

    Raises:
        ValueError: If ``task.scores`` is empty (raised by ``max``/``min``).
            This is deliberately not defaulted: silently abandoning records
            for an unscored task would be destructive.
    """
    # With (almost) identical scores across rollouts, nothing is kept.
    if max(task.scores) - min(task.scores) < 0.1:
        task.status = task_class.Status.COMPLETED
        await record_class.find_many(
            {"task.$id": task.id}, with_children=True
        ).update({"$set": {"status": record_class.Status.ABANDONED}})
        await task.save()
        return

    if getattr(config, "balance_sample", False):
        await _apply_balanced_sampling(task, record_class)
    else:
        await _apply_default_filtering(task, record_class, config)

    task.status = task_class.Status.COMPLETED
    await task.save()


async def _load_records_by_score(task, record_class, score_filter):
    """Return the task's records whose ``score`` matches *score_filter*.

    *score_filter* is placed verbatim into the ``$match`` stage, so it can be
    a plain value (``0``) or a query operator (``{"$ne": 0}``). Results are
    sorted by trajectory length, shortest first, and each raw aggregation row
    is validated into a ``record_class`` instance.
    """
    pipeline = [
        {"$match": {"task.$id": task.id, "score": score_filter}},
        {"$addFields": {"traj_length": {"$size": "$traj"}}},
        {"$sort": {"traj_length": 1}},
    ]
    raw_rows = await record_class.aggregate(pipeline).to_list()
    return [record_class.model_validate(row) for row in raw_rows]


async def _keep_first_n(records, keep, record_class):
    """Mark the first *keep* records READY and the rest ABANDONED, saving each."""
    for position, rec in enumerate(records):
        if position < keep:
            rec.status = record_class.Status.READY
        else:
            rec.status = record_class.Status.ABANDONED
        await rec.save()


async def _apply_balanced_sampling(task, record_class):
    """Keep a roughly balanced number of zero-score and non-zero-score records.

    Zero-score records whose ``final_answer`` is ``"failed_error"`` are
    abandoned outright and do not count towards the balance.
    """
    wrong_records = await _load_records_by_score(task, record_class, 0)
    right_records = await _load_records_by_score(task, record_class, {"$ne": 0})

    valid_wrong_records = []
    for rec in wrong_records:
        if rec.meta_infos.get("final_answer", "") == "failed_error":
            rec.status = record_class.Status.ABANDONED
            await rec.save()
        else:
            valid_wrong_records.append(rec)

    min_len = min(len(valid_wrong_records), len(right_records))

    # NOTE(review): shuffling discards the traj_length ordering produced by
    # the aggregation, so the kept records are a random subset -- confirm
    # that is intended rather than "keep the shortest trajectories".
    random.shuffle(valid_wrong_records)
    random.shuffle(right_records)

    # NOTE(review): the +1 / +2 slack lets each side exceed the smaller
    # side's size slightly; values kept exactly as they were -- confirm intent.
    await _keep_first_n(valid_wrong_records, min_len + 1, record_class)
    await _keep_first_n(right_records, min_len + 2, record_class)


async def _apply_default_filtering(task, record_class, config):
    """Mark all records READY, then abandon failed and surplus context-error ones.

    At most ``config.reserve_context_error_count`` (default 0, ``None``
    treated as 0) randomly chosen records whose ``final_answer`` is
    ``"length_limit_error"`` or ``"turns_limit_error"`` stay READY.
    """
    await record_class.find_many(
        {"task.$id": task.id}, with_children=True
    ).update({"$set": {"status": record_class.Status.READY}})

    # Re-read after the bulk update so the in-memory objects that get saved
    # below are loaded with the status just written.
    fetched = await record_class.find_many(
        {"task.$id": task.id}, with_children=True
    ).to_list()
    # Presumably already record instances; validation kept as in the
    # original in case the query yields raw documents -- TODO confirm.
    all_records = [record_class.model_validate(rec) for rec in fetched]

    for rec in all_records:
        if rec.meta_infos.get("final_answer", "") == "failed_error":
            rec.status = record_class.Status.ABANDONED
            await rec.save()

    context_error_types = ("length_limit_error", "turns_limit_error")
    context_error_records = [
        rec for rec in all_records
        if rec.meta_infos.get("final_answer", "") in context_error_types
    ]
    if not context_error_records:
        return

    random.shuffle(context_error_records)
    # `or 0` guards against the attribute being present but set to None,
    # which would otherwise make min() raise TypeError.
    reserve_count = getattr(config, "reserve_context_error_count", 0) or 0
    reserve_count = max(0, min(reserve_count, len(context_error_records)))
    for position, rec in enumerate(context_error_records):
        # Reserved records keep the READY status set by the bulk update.
        if position >= reserve_count:
            rec.status = record_class.Status.ABANDONED
            await rec.save()
    logger.info(f"Reserved {reserve_count} out of {len(context_error_records)} context error records.")