import os
import unittest
import uuid
from unittest.mock import patch, MagicMock
from typing import Generator, Dict, Any
from deepinsight.core.agent.base import BaseAgent
from deepinsight.config.model import ModelConfig
from camel.agents import ChatAgent
from camel.responses import ChatAgentResponse
from deepinsight.core.agent.stream_chat_agent import StreamChatAgent
from deepinsight.core.types.messages import ChunkMessage
from camel.types import ModelPlatformType, ModelType
class MockAgent(BaseAgent):
    """Minimal concrete ``BaseAgent`` used as a test double.

    Each overridden hook returns a fixed or trivially derived value so the
    tests can exercise the base-class machinery with predictable prompts.
    """

    def build_system_prompt(self) -> str:
        """Return the constant system prompt used by the tests."""
        system_prompt = "System prompt"
        return system_prompt

    def build_user_prompt(self, *, query: str, context: Dict[str, Any] | None = None) -> str:
        """Return the query prefixed with a fixed label; ``context`` is ignored."""
        return "User prompt: {}".format(query)

    def parse_output(self, response: ChatAgentResponse) -> ChatAgentResponse:
        """Hand the response back unchanged (no parsing is performed)."""
        return response
class TestBaseAgent(unittest.TestCase):
    """Test suite for BaseAgent functionality.

    The tests instantiate ``MockAgent`` with streaming enabled or disabled and
    check which underlying agent type is created and what ``run()`` produces.
    """

    def setUp(self):
        """Install the per-test patches and register their cleanup.

        Every patcher is stopped through ``addCleanup`` so that it is undone
        even when a later step of ``setUp`` raises (``tearDown`` is not run in
        that case).
        """
        token_counter_patcher = patch(
            'camel.models.openai_model.OpenAIModel.token_counter'
        )
        self.mock_token_counter = token_counter_patcher.start()
        self.addCleanup(token_counter_patcher.stop)

        def count_words(text):
            # Whitespace-separated word count stands in for real tokenization.
            return len(text.split())

        self.mock_token_counter.side_effect = count_words

        # patch.dict restores the previous state of os.environ on stop, so a
        # real OPENAI_API_KEY set by the developer/CI is not deleted.
        env_patcher = patch.dict(os.environ, {"OPENAI_API_KEY": "sk-test"})
        env_patcher.start()
        self.addCleanup(env_patcher.stop)

    def _get_model_config(self, stream: bool) -> ModelConfig:
        """Helper method to create a ModelConfig with specified streaming setting.

        Args:
            stream: Boolean indicating whether streaming should be enabled

        Returns:
            Configured ModelConfig instance
        """
        return ModelConfig(
            model_platform=ModelPlatformType.DEFAULT,
            model_type=ModelType.DEFAULT,
            model_config_dict={"stream": stream},
            api_key="test_key",
        )

    def test_streaming_agent_initialization(self):
        """Test that agent initializes with StreamChatAgent when streaming is enabled."""
        config = self._get_model_config(stream=True)
        agent = MockAgent(
            model_config=config,
            mcp_client_timeout=30,
        )
        self.assertIsInstance(agent.agent, StreamChatAgent)

    def test_non_streaming_agent_initialization(self):
        """Test that agent initializes with regular ChatAgent when streaming is disabled."""
        config = self._get_model_config(stream=False)
        agent = MockAgent(
            model_config=config,
            mcp_client_timeout=30,
        )
        self.assertIsInstance(agent.agent, ChatAgent)
        self.assertNotIsInstance(agent.agent, StreamChatAgent)

    def test_run_method_streaming(self):
        """Test the run() method in streaming mode produces expected chunks."""
        config = self._get_model_config(stream=True)
        agent = MockAgent(model_config=config)
        mock_response = MagicMock(spec=ChatAgentResponse)
        mock_response.output_messages = ["Test response"]

        def mock_stream_step(prompt):
            yield ChunkMessage(payload="Stream chunk 1", stream_id=str(uuid.uuid4()))
            yield ChunkMessage(payload="Stream chunk 2", stream_id=str(uuid.uuid4()))
            return mock_response

        with patch.object(agent.agent, 'stream_step', new=mock_stream_step):
            result = agent.run("test query")
            self.assertIsInstance(result, Generator)
            # Generators are lazy: consume while the patch is still active so
            # the replacement stream_step is the one that gets iterated.
            chunks = list(result)

        self.assertEqual(len(chunks), 2)
        self.assertEqual(chunks[0].payload, "Stream chunk 1")
        self.assertEqual(chunks[1].payload, "Stream chunk 2")

    def test_run_method_non_streaming(self):
        """Test the run() method in non-streaming mode produces expected response."""
        config = self._get_model_config(stream=False)
        agent = MockAgent(model_config=config)
        mock_response = MagicMock(spec=ChatAgentResponse)
        mock_response.output_messages = ["Test response"]

        with patch.object(agent.agent, 'step', return_value=mock_response) as mock_step:
            result = agent.run("test query")
            self.assertIsInstance(result, Generator)
            # Drain the generator; its return value travels on StopIteration.
            try:
                while True:
                    next(result)
            except StopIteration as stop:
                final_response = stop.value
            mock_step.assert_called_once_with("User prompt: test query")

        self.assertEqual(final_response, mock_response)