Add another test case

Ashwin Bharambe 2024-09-19 13:02:57 -07:00
parent abb43936ab
commit 8b3ffa33de


@@ -14,7 +14,11 @@ from llama_stack.apis.memory import *  # noqa: F403
 from llama_stack.apis.safety import *  # noqa: F403
 from llama_stack.apis.agents import *  # noqa: F403
-from ..agent_instance import ChatAgent
+from ..agents import (
+    AGENT_INSTANCES_BY_ID,
+    MetaReferenceAgentsImpl,
+    MetaReferenceImplConfig,
+)


 class MockInferenceAPI:
@@ -42,10 +46,22 @@ class MockInferenceAPI:
         yield ChatCompletionResponseStreamChunk(
             event=ChatCompletionResponseEvent(
                 event_type="progress",
-                delta="Mock response",
+                delta="AI is a fascinating field...",
+            )
+        )
+        yield ChatCompletionResponseStreamChunk(
+            event=ChatCompletionResponseEvent(
+                event_type="progress",
+                delta=ToolCallDelta(
+                    content=ToolCall(
+                        call_id="123",
+                        tool_name=BuiltinTool.brave_search.value,
+                        arguments={"query": "AI history"},
+                    ),
+                    parse_status="success",
+                ),
             )
         )
         yield ChatCompletionResponseStreamChunk(
             event=ChatCompletionResponseEvent(
                 event_type="complete",
@@ -149,24 +165,32 @@ def mock_memory_api():
 @pytest.fixture
-def chat_agent(mock_inference_api, mock_safety_api, mock_memory_api):
-    # You'll need to adjust this based on the actual ChatAgent constructor
+async def chat_agent(mock_inference_api, mock_safety_api, mock_memory_api):
+    impl = MetaReferenceAgentsImpl(
+        config=MetaReferenceImplConfig(),
+        inference_api=mock_inference_api,
+        safety_api=mock_safety_api,
+        memory_api=mock_memory_api,
+    )
+    await impl.initialize()
+
     agent_config = AgentConfig(
         model="test_model",
         instructions="You are a helpful assistant.",
         sampling_params=SamplingParams(),
-        tools=[],
+        tools=[
+            SearchToolDefinition(
+                name="brave_search",
+                api_key="test_key",
+            ),
+        ],
         tool_choice=ToolChoice.auto,
         input_shields=[],
         output_shields=[],
     )
-    return ChatAgent(
-        agent_config=agent_config,
-        inference_api=mock_inference_api,
-        memory_api=mock_memory_api,
-        safety_api=mock_safety_api,
-        builtin_tools=[],
-    )
+    response = await impl.create_agent(agent_config)
+    agent = AGENT_INSTANCES_BY_ID[response.agent_id]
+    return agent


 @pytest.mark.asyncio
@@ -214,3 +238,68 @@ async def test_run_shields_wrapper(chat_agent):
     assert len(responses) == 2  # StepStart, StepComplete
     assert responses[0].event.payload.step_type.value == "shield_call"
     assert not responses[1].event.payload.step_details.response.is_violation
+
+
+@pytest.mark.asyncio
+@pytest.mark.skip(reason="Not yet implemented; need to mock out tool execution easily")
+async def test_chat_agent_complex_turn(chat_agent):
+    # Setup
+    session = chat_agent.create_session("Test Session")
+    request = AgentTurnCreateRequest(
+        agent_id="random",
+        session_id=session.session_id,
+        messages=[UserMessage(content="Tell me about AI and then use a tool.")],
+        stream=True,
+    )
+
+    # Execute the turn
+    responses = []
+    async for response in chat_agent.create_and_execute_turn(request):
+        responses.append(response)
+
+    # Assertions
+    assert len(responses) > 0
+
+    # Check for the presence of different step types
+    step_types = [
+        response.event.payload.step_type
+        for response in responses
+        if hasattr(response.event.payload, "step_type")
+    ]
+    assert "shield_call" in step_types, "Shield call step is missing"
+    assert "inference" in step_types, "Inference step is missing"
+    assert "tool_execution" in step_types, "Tool execution step is missing"
+
+    # Check for the presence of start and complete events
+    event_types = [
+        response.event.payload.event_type
+        for response in responses
+        if hasattr(response.event.payload, "event_type")
+    ]
+    assert "start" in event_types, "Start event is missing"
+    assert "complete" in event_types, "Complete event is missing"
+
+    # Check for the presence of tool call
+    tool_calls = [
+        response.event.payload.tool_call
+        for response in responses
+        if hasattr(response.event.payload, "tool_call")
+    ]
+    assert any(
+        tool_call
+        for tool_call in tool_calls
+        if tool_call and tool_call.content.get("name") == "memory"
+    ), "Memory tool call is missing"
+
+    # Check for the final turn complete event
+    assert any(
+        isinstance(response.event.payload, AgentTurnResponseTurnCompletePayload)
+        for response in responses
+    ), "Turn complete event is missing"
+
+    # Verify the turn was added to the session
+    assert len(session.turns) == 1, "Turn was not added to the session"
+    assert (
+        session.turns[0].input_messages == request.messages
+    ), "Input messages do not match"