import os
import unittest
from typing import Generator
from unittest.mock import patch
import httpx
from camel.models import ModelFactory
from camel.types import ModelPlatformType, ModelType
from openai import OpenAI, Stream
from openai.types import CompletionUsage
from openai.types.chat import ChatCompletionChunk
from openai.types.chat.chat_completion_chunk import Choice, ChoiceDelta
from deepinsight.core.agent.stream_chat_agent import StreamChatAgent
class TestStreamChatAgent(unittest.TestCase):
    """Unit tests for StreamChatAgent class."""

    def setUp(self):
        """Patch the token counter and the OpenAI API key before each test.

        Each patch is registered with ``addCleanup`` immediately after it is
        started, so it is undone even if a later step of ``setUp`` raises
        (``tearDown`` is skipped in that case; cleanups are not).
        """
        # NOTE(review): presumably patched so no real tokenizer is needed;
        # calling the mock returns the whitespace-separated word count.
        self.patcher1 = patch(
            'camel.models.openai_model.OpenAIModel.token_counter'
        )
        self.mock_token_counter = self.patcher1.start()
        self.addCleanup(self.patcher1.stop)

        def side_effect(text):
            return len(text.split())

        self.mock_token_counter.side_effect = side_effect

        # patch.dict restores os.environ when stopped, so an OPENAI_API_KEY
        # that was already set is preserved rather than overwritten and then
        # deleted.
        self.env_patcher = patch.dict(
            os.environ, {"OPENAI_API_KEY": "sk-test"}
        )
        self.env_patcher.start()
        self.addCleanup(self.env_patcher.stop)

        self.model_config = {
            "stream": True,
        }

    @staticmethod
    def _make_chunk(content, **extra):
        """Build one ``ChatCompletionChunk`` whose single delta is ``content``.

        Args:
            content: Text carried by the chunk's delta.
            **extra: Additional chunk fields (e.g. ``usage``); forwarded to
                the constructor only when supplied.
        """
        return ChatCompletionChunk(
            id="1",
            choices=[Choice(index=0, delta=ChoiceDelta(content=content))],
            created=123,
            model="gpt-4",
            object="chat.completion.chunk",
            **extra,
        )

    def _mock_stream_response(self):
        """Return an openai ``Stream`` that yields chunks "a", "b", "c".

        Only the last chunk carries ``usage`` (1 prompt + 1 completion token).
        """
        chunks = [
            self._make_chunk("a"),
            self._make_chunk("b"),
            self._make_chunk(
                "c",
                usage=CompletionUsage(
                    completion_tokens=1,
                    prompt_tokens=1,
                    total_tokens=2,
                ),
            ),
        ]
        stream = Stream(
            cast_to=None,
            response=httpx.Response(status_code=200),
            client=OpenAI(),
        )
        # NOTE(review): overrides the private ``_iterator`` attribute of the
        # openai SDK's Stream so that iterating it yields the canned chunks;
        # this may break on SDK upgrades.
        stream._iterator = iter(chunks)
        return stream

    def test_stream_chat_agent_with_mock_response(self):
        """Streamed payloads join to "abc"; usage matches the final chunk."""
        with patch(
            "camel.models.model_manager.ModelManager.run"
        ) as mock_model_run:
            # The agent's model call returns the canned stream.
            mock_model_run.return_value = self._mock_stream_response()
            stream_chat_agent = StreamChatAgent(
                system_message="",
                model=ModelFactory.create(
                    model_platform=ModelPlatformType.DEFAULT,
                    model_type=ModelType.DEFAULT,
                    model_config_dict=self.model_config,
                ),
            )

            step = stream_chat_agent.stream_step("test input")
            parts = []
            # Consumed manually rather than with a for loop: the final
            # response is the generator's return value, which is only
            # reachable through StopIteration.value.
            while True:
                try:
                    item = next(step)
                except StopIteration as stop:
                    response = stop.value
                    break
                # Only items exposing a payload contribute streamed text.
                if hasattr(item, "payload"):
                    parts.append(item.payload)

            self.assertEqual("".join(parts), "abc")
            self.assertEqual(
                response.info["usage"],
                {"completion_tokens": 1, "prompt_tokens": 1, "total_tokens": 2},
            )
            mock_model_run.assert_called_once()