import os
from abc import ABC, abstractmethod
from typing import Any, List, Optional, Literal, Type, Callable, Coroutine
import openai.types.chat
import datetime
from beanie import UpdateResponse
import beanie
from configs import AgentTrainingConfig
from log import logger
from databases import InferenceService, DispatchedSamplingTask, Task, Record, NoTaskAvailableError, DistributedCounter, DistributedLock
from pymongo import ReturnDocument
from tenacity import retry, stop_after_attempt, wait_fixed
from packaging import version
class AsyncSampler(ABC):
"""
An abstract base class for running a asynchronous sampler.
"""
def __init__(
self,
config: AgentTrainingConfig,
*args,
task_class: Type[Task] = Task,
record_class: Type[Record] = Record,
split: Literal["train", "valid", "test", "eval"] = "train",
**kwargs
):
super().__init__()
self.config = config
self.split = split
self.num_generations = config.num_generations if self.split == "train" else kwargs.get("override_num_generations", 1)
self.task_class = task_class
self.record_class = record_class
assert issubclass(self.task_class, Task), "task_class must be a subclass of Task"
assert issubclass(self.record_class, Record), "record_class must be a subclass of Record"
self.record = None
async def __aenter__(self):
"""Enter the asynchronous context manager."""
self.task = await self.task_class.find_one(
self.task_class.num_samples < self.num_generations,
self.task_class.split == self.split,
self.task_class.status == self.task_class.Status.RUNNING,
with_children=True
).inc(
{self.task_class.num_samples: 1},
response_type = UpdateResponse.NEW_DOCUMENT
)
if self.task is None:
self.task = await self.task_class.find_one(
self.task_class.num_samples < self.num_generations,
self.task_class.split == self.split,
self.task_class.status == self.task_class.Status.CREATED,
with_children=True
).update(
{"$inc": {self.task_class.num_samples: 1}, "$set": {self.task_class.status: self.task_class.Status.RUNNING}},
response_type = UpdateResponse.NEW_DOCUMENT
)
if self.task is None:
raise NoTaskAvailableError("No task available for sampling. Please create a task first.")
self.record = self.record_class(
task=self.task,
traj_id=self.task.num_samples - 1,
split=self.split,
status=self.record_class.Status.RUNNING,
)
await self.record.save()
self.running_sampler = await DistributedCounter.create(name=f"running-sampler-{self.split}")
await self.running_sampler.inc()
self.infer_count = await DistributedCounter.create(name=f"infer")
self.split_lock = await DistributedLock.create(name=self.split)
self.global_step_counter = await DistributedCounter.create(name="global_step")
return self
async def __aexit__(self, exc_type, exc_value, traceback):
"""Exit the asynchronous context manager."""
if self.record is not None:
if self.record.status == self.record_class.Status.RUNNING:
if exc_type is None:
self.record.status = self.record_class.Status.COMPLETED
else:
self.record.status = self.record_class.Status.FAILED
logger.warning(f"Record {self.record.traj_id} for task {self.record.task.id} failed with error: {exc_value}")
if self.config.resample_error_records:
await self.task_class.find_one(
self.task_class.id == self.task.id, with_children=True
).update({
"$inc": {"num_samples": -1},
})
await self.record.save()
if self.record.status == self.record_class.Status.COMPLETED:
self.record.status = self.record_class.Status.SCORING
await self.record.save()
try:
score = await self.evaluate_record(self.record)
if not isinstance(score, float):
score = float(score)
except Exception as e:
import traceback
logger.error("Error when calculate score for record "+ str(self.record.id)+"\n"+traceback.format_exc())
score = 0.0
self.record.score = score
self.record.status = self.record_class.Status.SCORED
await self.record.save()
self.task: Task = await self.task_class.find_one(
self.task_class.id == self.task.id, with_children=True
).update({
"$push": {"scores": score}
}, response_type=UpdateResponse.NEW_DOCUMENT)
status_to_count = [self.record_class.Status.SCORED]
if not self.config.resample_error_records:
status_to_count.append(self.record_class.Status.FAILED)
valid_records_count = await self.record_class.find_many(
{
"task.$id": self.task.id,
"status": {
"$in": status_to_count
}
}, with_children=True
).count()
if valid_records_count >= self.num_generations:
await self.record_class.find_many(
{
"task.$id":self.task.id,
"status": self.record_class.Status.SCORED
}, with_children=True
).update({
"$set": {
"status": self.record_class.Status.READY
}
})
await self.task.sync()
self.task.status = self.task_class.Status.COMPLETED
await self.task.save()
self.record = None
if self.task is not None:
self.task = None
await self.running_sampler.dec()
@retry(stop=stop_after_attempt(7), wait=wait_fixed(5))
async def create_chat_completions(
self,
messages: List[dict],
model: Optional[str] = None,
tools: Optional[List[dict]] = None,
priority: int = 0,
timeout: Optional[int] = None,
repeat_penalty: Optional[float] = 1.0,
finish_status: Optional[DispatchedSamplingTask.Status] = DispatchedSamplingTask.Status.COMPLETED,
**kwargs
) -> openai.types.chat.ChatCompletion:
"""
Asynchronously create chat completions and submit to scheduler.
Args:
messages (List[dict]): The messages to send to the model.
sampling_params (dict): Parameters for sampling.
model (Optional[str]): The model to use for completion.
Returns:
ChatCompletion: The response from the OpenAI API.
"""
if not self.config.sync_sampling:
await self.split_lock.wait()
await self.infer_count.wait_for(0, option="gt")
task = DispatchedSamplingTask(
task=self.task,
traj_id=self.record.traj_id,
req_type="chatcompletions",
request={
"messages": messages,
"model": model if model is not None else self.config.model_name,
"tools": tools,
**kwargs
},
priority=priority,
)
if version.parse(beanie.__version__) >= version.parse("2.0.0"):
service = await InferenceService.get_pymongo_collection().find_one_and_update(
{"status": "UP"},
{"$inc": {"running_req_count": 1}},
sort=[("running_req_count", 1)],
return_document=ReturnDocument.AFTER
)
else:
service = await InferenceService.get_motor_collection().find_one_and_update(
{"status": "UP"},
{"$inc": {"running_req_count": 1}},
sort=[("running_req_count", 1)],
return_document=ReturnDocument.AFTER
)
if service is None:
raise RuntimeError("No available inference service found.")
service = InferenceService.model_validate(service)
task.sampled_from = service
task.status = DispatchedSamplingTask.Status.RUNNING
await task.save()
self.record.traj.append(DispatchedSamplingTask.link_from_id(task.id))
await self.record.save()
try:
service = await InferenceService.find_one(InferenceService.id == service.id)
if service.status != "UP":
raise RuntimeError(f"Inference service {service.id} is not UP.")
await self.global_step_counter.sync()
task.created_at_step = self.global_step_counter.n
match service.connection_type:
case "openai":
import openai
async with openai.AsyncOpenAI(
api_key=service.configs.api_key,
base_url=service.configs.base_url,
) as client:
presence_penalty = self.config.presence_penalty if self.config.presence_penalty is not None else 0
frequency_penalty = self.config.frequency_penalty if self.config.frequency_penalty is not None else 0
response = await client.chat.completions.create(
**task.request,
logprobs=True,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
max_completion_tokens=self.config.max_new_tokens,
timeout=timeout,
)
assert response.choices[0].finish_reason not in ["abort"], f"Sampling for record {self.record.id} not finished properly with reason: {response.choices[0].finish_reason}"
logger.debug(f"Finish Chat Completions in service {service.id} for record {self.record.id}")
case _:
raise NotImplementedError(f"Connection type {service.connection_type} is not supported.")
except Exception as e:
task.status = DispatchedSamplingTask.Status.FAILED
task.response = str(e)
task.finish_time = datetime.datetime.now()
await task.save()
if not isinstance(e, AssertionError):
import traceback
logger.error(f"Error occurred while creating chat completions for task {task.id} in service {service.id}:\n{traceback.format_exc()}")
raise e
finally:
await InferenceService.find_one(InferenceService.id == service.id).inc({InferenceService.running_req_count: -1})
if response.choices[0].finish_reason in ["length", "content_filter"]:
logger.warning(f"Sampling for task {task.id} stopped due to {response.choices[0].finish_reason}")
task.status = DispatchedSamplingTask.Status.FAILED
else:
task.status = finish_status if finish_status is not None else DispatchedSamplingTask.Status.COMPLETED
task.response = response
task.finish_time = datetime.datetime.now()
await task.save()
return response
@abstractmethod
async def run(self, task: Optional[Task] = None):
"""
Run the sampler.
This method should be implemented by subclasses to define the specific behavior of the sampler.
"""
raise NotImplementedError("Subclasses must implement this method.")
@abstractmethod
async def evaluate_record(self, record: Record) -> float:
"""
Calculate the score for the given record using the provided score function.
Ones could also assign the score for each trajectory step in the record here.
Args:
record (Record): The record to calculate the score for.
Returns:
float: The calculated score.
"""
class StatefulSampler(AsyncSampler):
async def __aenter__(self):
"""Enter the asynchronous context manager."""
self.task = await self.task_class.find_one(
self.task_class.num_samples < self.config.num_generations if self.split == "train" else self.task_class.num_samples < 1,
self.task_class.split == self.split,
self.task_class.status == self.task_class.Status.WAITING,
with_children=True
).update(
{"$inc": {self.task_class.num_samples: 1}, "$set": {self.task_class.status: self.task_class.Status.RUNNING}},
response_type = UpdateResponse.NEW_DOCUMENT
)
if self.task is None:
self.task = await self.task_class.find_one(
self.task_class.num_samples < self.config.num_generations if self.split == "train" else self.task_class.num_samples < 1,
self.task_class.split == self.split,
self.task_class.status == self.task_class.Status.CREATED,
with_children=True
).update(
{"$inc": {self.task_class.num_samples: 1}, "$set": {self.task_class.status: self.task_class.Status.RUNNING}},
response_type = UpdateResponse.NEW_DOCUMENT
)
if self.task is None:
raise NoTaskAvailableError("No task available for sampling. Please create a task first.")
self.record = self.record_class(
task=self.task,
traj_id=self.task.num_samples - 1,
split=self.split,
)
await self.record.save()
self.infer_count = await DistributedCounter.create(name=f"infer")
return self
async def __aexit__(self, exc_type, exc_value, traceback):
"""Exit the asynchronous context manager."""
await self.task.sync()
self.task.status = self.task_class.Status.WAITING if self.task.num_samples < (self.config.num_generations if self.split == "train" else 1) else self.task_class.Status.COMPLETED
await self.task.save()
return await super().__aexit__(exc_type, exc_value, traceback)