forked from phoenix-oss/llama-stack-mirror

Add another test case

parent abb43936ab
commit 8b3ffa33de

1 changed file with 102 additions and 13 deletions
@@ -14,7 +14,11 @@ from llama_stack.apis.memory import *  # noqa: F403
 from llama_stack.apis.safety import *  # noqa: F403
 from llama_stack.apis.agents import *  # noqa: F403
 
 from ..agent_instance import ChatAgent
+from ..agents import (
+    AGENT_INSTANCES_BY_ID,
+    MetaReferenceAgentsImpl,
+    MetaReferenceImplConfig,
+)
 
 
 class MockInferenceAPI:
@@ -42,10 +46,22 @@ class MockInferenceAPI:
         yield ChatCompletionResponseStreamChunk(
             event=ChatCompletionResponseEvent(
                 event_type="progress",
-                delta="Mock response",
+                delta="AI is a fascinating field...",
             )
         )
+        yield ChatCompletionResponseStreamChunk(
+            event=ChatCompletionResponseEvent(
+                event_type="progress",
+                delta=ToolCallDelta(
+                    content=ToolCall(
+                        call_id="123",
+                        tool_name=BuiltinTool.brave_search.value,
+                        arguments={"query": "AI history"},
+                    ),
+                    parse_status="success",
+                ),
+            )
+        )
 
         yield ChatCompletionResponseStreamChunk(
             event=ChatCompletionResponseEvent(
                 event_type="complete",
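The mocked stream above now interleaves a plain text delta, a parsed tool-call delta, and a final `complete` event, which is the chunk sequence an agent loop sees when the model decides to call a tool mid-turn. A minimal sketch of folding such a stream back into text plus tool calls, assuming only the shapes used in the mock (`chunk.event.delta`, `ToolCallDelta.content`, `parse_status`); the helper name is hypothetical, not part of the commit:

    # Hypothetical helper, for illustration only: fold the mocked chunks
    # back into (text, tool_calls), the way an agent loop might.
    async def collect_stream(stream):
        text, tool_calls = "", []
        async for chunk in stream:
            delta = getattr(chunk.event, "delta", None)
            if isinstance(delta, str):
                text += delta  # plain text tokens stream as strings
            elif delta is not None and delta.parse_status == "success":
                tool_calls.append(delta.content)  # a fully parsed ToolCall
        return text, tool_calls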
@@ -149,24 +165,32 @@ def mock_memory_api():
 
 
 @pytest.fixture
-def chat_agent(mock_inference_api, mock_safety_api, mock_memory_api):
-    # You'll need to adjust this based on the actual ChatAgent constructor
+async def chat_agent(mock_inference_api, mock_safety_api, mock_memory_api):
+    impl = MetaReferenceAgentsImpl(
+        config=MetaReferenceImplConfig(),
+        inference_api=mock_inference_api,
+        safety_api=mock_safety_api,
+        memory_api=mock_memory_api,
+    )
+    await impl.initialize()
+
     agent_config = AgentConfig(
         model="test_model",
         instructions="You are a helpful assistant.",
         sampling_params=SamplingParams(),
-        tools=[],
+        tools=[
+            SearchToolDefinition(
+                name="brave_search",
+                api_key="test_key",
+            ),
+        ],
         tool_choice=ToolChoice.auto,
         input_shields=[],
         output_shields=[],
     )
-    return ChatAgent(
-        agent_config=agent_config,
-        inference_api=mock_inference_api,
-        memory_api=mock_memory_api,
-        safety_api=mock_safety_api,
-        builtin_tools=[],
-    )
+    response = await impl.create_agent(agent_config)
+    agent = AGENT_INSTANCES_BY_ID[response.agent_id]
+    return agent
 
 
 @pytest.mark.asyncio
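One consequence of the fixture change above, worth flagging (an observation, not part of the commit): `chat_agent` is now an `async def` fixture, so pytest needs an asyncio-aware fixture mechanism to resolve it. Under pytest-asyncio's auto mode the plain `@pytest.fixture` in the diff is awaited for you; under strict mode, a sketch of the usual spelling:

    # Sketch, assuming pytest-asyncio in strict mode: declare async fixtures
    # via pytest_asyncio.fixture so the test receives the agent instance,
    # not an unawaited coroutine.
    import pytest_asyncio

    @pytest_asyncio.fixture
    async def chat_agent(mock_inference_api, mock_safety_api, mock_memory_api):
        ...  # body as in the diff above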
@@ -214,3 +238,68 @@ async def test_run_shields_wrapper(chat_agent):
     assert len(responses) == 2  # StepStart, StepComplete
     assert responses[0].event.payload.step_type.value == "shield_call"
     assert not responses[1].event.payload.step_details.response.is_violation
+
+
+@pytest.mark.asyncio
+@pytest.mark.skip(reason="Not yet implemented; need to mock out tool execution easily")
+async def test_chat_agent_complex_turn(chat_agent):
+    # Setup
+    session = chat_agent.create_session("Test Session")
+    request = AgentTurnCreateRequest(
+        agent_id="random",
+        session_id=session.session_id,
+        messages=[UserMessage(content="Tell me about AI and then use a tool.")],
+        stream=True,
+    )
+
+    # Execute the turn
+    responses = []
+    async for response in chat_agent.create_and_execute_turn(request):
+        responses.append(response)
+
+    # Assertions
+    assert len(responses) > 0
+
+    # Check for the presence of different step types
+    step_types = [
+        response.event.payload.step_type
+        for response in responses
+        if hasattr(response.event.payload, "step_type")
+    ]
+
+    assert "shield_call" in step_types, "Shield call step is missing"
+    assert "inference" in step_types, "Inference step is missing"
+    assert "tool_execution" in step_types, "Tool execution step is missing"
+
+    # Check for the presence of start and complete events
+    event_types = [
+        response.event.payload.event_type
+        for response in responses
+        if hasattr(response.event.payload, "event_type")
+    ]
+    assert "start" in event_types, "Start event is missing"
+    assert "complete" in event_types, "Complete event is missing"
+
+    # Check for the presence of tool call
+    tool_calls = [
+        response.event.payload.tool_call
+        for response in responses
+        if hasattr(response.event.payload, "tool_call")
+    ]
+    assert any(
+        tool_call
+        for tool_call in tool_calls
+        if tool_call and tool_call.content.get("name") == "memory"
+    ), "Memory tool call is missing"
+
+    # Check for the final turn complete event
+    assert any(
+        isinstance(response.event.payload, AgentTurnResponseTurnCompletePayload)
+        for response in responses
+    ), "Turn complete event is missing"
+
+    # Verify the turn was added to the session
+    assert len(session.turns) == 1, "Turn was not added to the session"
+    assert (
+        session.turns[0].input_messages == request.messages
+    ), "Input messages do not match"
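The new test is skipped pending an easy way to mock out tool execution. One hedged sketch of how that gap might be closed with the standard library's unittest.mock; the patch target `execute_tool` is hypothetical and would have to match whatever method ChatAgent actually dispatches tools through:

    from unittest.mock import AsyncMock, patch

    # Hypothetical patch target: replace ChatAgent's tool dispatch with an
    # AsyncMock so the turn can complete without a real brave_search call.
    with patch.object(
        ChatAgent, "execute_tool", new=AsyncMock(return_value="mocked result")
    ):
        ...  # run create_and_execute_turn(request) as in the test above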