Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-10-18 07:18:53 +00:00)
feat(responses)!: improve responses + conversations implementations (#3810)
This PR updates the Conversation item related types and improves a couple of critical parts of the implementation:

- It creates a streaming output item for the final assistant message output by the model. Until now we only added content parts and included that message in the final response.
- It rewrites the conversation update code completely to account for items other than messages (tool calls, outputs, etc.).

## Test Plan

Used the test script from https://github.com/llamastack/llama-stack-client-python/pull/281 for this:

```
TEST_API_BASE_URL=http://localhost:8321/v1 \
pytest tests/integration/test_agent_turn_step_events.py::test_client_side_function_tool -xvs
```
Parent: add8cd801b
Commit: e9b4278a51

129 changed files with 86266 additions and 903 deletions
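For orientation before the diff, here is a minimal sketch of the streaming turn-consumption pattern the updated tests exercise. The `base_url` and model id are illustrative placeholders, not part of this commit; the event types and chunk handling mirror the test code below.

```python
# Minimal sketch (not part of this commit) of the new streaming pattern.
# base_url and model below are placeholders; the event handling mirrors
# the updated tests in this diff.
from llama_stack_client import LlamaStackClient
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.turn_events import StepCompleted, StepProgress, ToolCallIssuedDelta

client = LlamaStackClient(base_url="http://localhost:8321")  # placeholder URL
agent = Agent(client=client, model="<text_model_id>", instructions="You are a helpful assistant.")
session_id = agent.create_session("demo-session")

# With stream=True, create_turn yields chunks: each chunk carries an event,
# and the final chunk carries the completed response.
chunks = list(
    agent.create_turn(
        session_id=session_id,
        messages=[
            {
                "type": "message",
                "role": "user",
                "content": [{"type": "input_text", "text": "Say hi to the world."}],
            }
        ],
        stream=True,
    )
)
events = [chunk.event for chunk in chunks]
final_response = next((chunk.response for chunk in reversed(chunks) if chunk.response), None)

# Tool activity surfaces as StepProgress deltas (a tool call was issued)
# and StepCompleted events (a tool execution finished).
issued = [e for e in events if isinstance(e, StepProgress) and isinstance(e.delta, ToolCallIssuedDelta)]
done = [e for e in events if isinstance(e, StepCompleted) and e.step_type == "tool_execution"]
print(final_response.output_text if final_response else "no final response")
```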
```diff
@@ -7,7 +7,8 @@
 import json
 
 import pytest
-from llama_stack_client import Agent
+from llama_stack_client.lib.agents.agent import Agent
+from llama_stack_client.lib.agents.turn_events import StepCompleted, StepProgress, ToolCallIssuedDelta
 
 from llama_stack import LlamaStackAsLibraryClient
 from llama_stack.core.datatypes import AuthenticationRequiredError
```
```diff
@@ -56,12 +57,12 @@ def test_mcp_invocation(llama_stack_client, text_model_id, mcp_server):
     with pytest.raises(Exception, match="Unauthorized"):
         llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id)
 
-    response = llama_stack_client.tools.list(
+    tools_list = llama_stack_client.tools.list(
         toolgroup_id=test_toolgroup_id,
         extra_headers=auth_headers,
     )
-    assert len(response) == 2
-    assert {t.name for t in response} == {"greet_everyone", "get_boiling_point"}
+    assert len(tools_list) == 2
+    assert {t.name for t in tools_list} == {"greet_everyone", "get_boiling_point"}
 
     response = llama_stack_client.tool_runtime.invoke_tool(
         tool_name="greet_everyone",
```
```diff
@@ -74,59 +75,83 @@ def test_mcp_invocation(llama_stack_client, text_model_id, mcp_server):
     assert content[0].text == "Hello, world!"
 
     print(f"Using model: {text_model_id}")
+    tool_defs = [
+        {
+            "type": "mcp",
+            "server_url": uri,
+            "server_label": test_toolgroup_id,
+            "require_approval": "never",
+            "allowed_tools": [tool.name for tool in tools_list],
+        }
+    ]
+
     agent = Agent(
         client=llama_stack_client,
         model=text_model_id,
         instructions="You are a helpful assistant.",
-        tools=[test_toolgroup_id],
+        tools=tool_defs,
     )
     session_id = agent.create_session("test-session")
-    response = agent.create_turn(
-        session_id=session_id,
-        messages=[
-            {
-                "role": "user",
-                "content": "Say hi to the world. Use tools to do so.",
-            }
-        ],
-        stream=False,
-        extra_headers=auth_headers,
+    chunks = list(
+        agent.create_turn(
+            session_id=session_id,
+            messages=[
+                {
+                    "type": "message",
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "input_text",
+                            "text": "Say hi to the world. Use tools to do so.",
+                        }
+                    ],
+                }
+            ],
+            stream=True,
+            extra_headers=auth_headers,
+        )
     )
-    steps = response.steps
-    first = steps[0]
-    assert first.step_type == "inference"
-    assert len(first.api_model_response.tool_calls) == 1
-    tool_call = first.api_model_response.tool_calls[0]
-    assert tool_call.tool_name == "greet_everyone"
+    events = [chunk.event for chunk in chunks]
+    final_response = next((chunk.response for chunk in reversed(chunks) if chunk.response), None)
+    assert final_response is not None
 
-    second = steps[1]
-    assert second.step_type == "tool_execution"
-    tool_response_content = second.tool_responses[0].content
-    assert len(tool_response_content) == 1
-    assert tool_response_content[0].type == "text"
-    assert tool_response_content[0].text == "Hello, world!"
-
-    third = steps[2]
-    assert third.step_type == "inference"
+    issued_calls = [
+        event for event in events if isinstance(event, StepProgress) and isinstance(event.delta, ToolCallIssuedDelta)
+    ]
+    assert issued_calls and issued_calls[0].delta.tool_name == "greet_everyone"
+
+    tool_events = [
+        event for event in events if isinstance(event, StepCompleted) and event.step_type == "tool_execution"
+    ]
+    assert tool_events and tool_events[0].result.tool_calls[0].tool_name == "greet_everyone"
+
+    assert "hello" in final_response.output_text.lower()
 
     # when streaming, we currently don't check auth headers upfront and fail the request
     # early. but we should at least be generating a 401 later in the process.
-    response = agent.create_turn(
+    response_stream = agent.create_turn(
         session_id=session_id,
         messages=[
             {
+                "type": "message",
                 "role": "user",
-                "content": "What is the boiling point of polyjuice? Use tools to answer.",
+                "content": [
+                    {
+                        "type": "input_text",
+                        "text": "What is the boiling point of polyjuice? Use tools to answer.",
+                    }
+                ],
             }
         ],
         stream=True,
     )
     if isinstance(llama_stack_client, LlamaStackAsLibraryClient):
         with pytest.raises(AuthenticationRequiredError):
-            for _ in response:
+            for _ in response_stream:
                 pass
     else:
-        error_chunks = [chunk for chunk in response if "error" in chunk.model_dump()]
+        error_chunks = [chunk for chunk in response_stream if "error" in chunk.model_dump()]
         assert len(error_chunks) == 1
         chunk = error_chunks[0].model_dump()
         assert "Unauthorized" in chunk["error"]["message"]
```
```diff
@@ -348,7 +348,8 @@ class TestAgentWithMCPTools:
         if not isinstance(llama_stack_client, LlamaStackAsLibraryClient):
             pytest.skip("Library client required for local MCP server")
 
-        from llama_stack_client import Agent
+        from llama_stack_client.lib.agents.agent import Agent
+        from llama_stack_client.lib.agents.turn_events import StepCompleted
 
         test_toolgroup_id = "mcp::complex_agent"
         uri = mcp_server_with_complex_schemas["server_url"]
```
```diff
@@ -369,36 +370,56 @@ class TestAgentWithMCPTools:
             "X-LlamaStack-Provider-Data": json.dumps(provider_data),
         }
 
         # Create agent with MCP tools
+        tools_list = llama_stack_client.tools.list(
+            toolgroup_id=test_toolgroup_id,
+            extra_headers=auth_headers,
+        )
+        tool_defs = [
+            {
+                "type": "mcp",
+                "server_url": uri,
+                "server_label": test_toolgroup_id,
+                "require_approval": "never",
+                "allowed_tools": [tool.name for tool in tools_list],
+            }
+        ]
+
         agent = Agent(
             client=llama_stack_client,
             model=text_model_id,
             instructions="You are a helpful assistant that can process orders and book flights.",
-            tools=[test_toolgroup_id],
+            tools=tool_defs,
             extra_headers=auth_headers,
         )
 
         session_id = agent.create_session("test-session-complex")
 
         # Ask agent to use a tool with complex schema
-        response = agent.create_turn(
-            session_id=session_id,
-            messages=[
-                {"role": "user", "content": "Process an order with 2 widgets going to 123 Main St, San Francisco"}
-            ],
-            stream=False,
-            extra_headers=auth_headers,
+        chunks = list(
+            agent.create_turn(
+                session_id=session_id,
+                messages=[
+                    {
+                        "type": "message",
+                        "role": "user",
+                        "content": [
+                            {
+                                "type": "input_text",
+                                "text": "Process an order with 2 widgets going to 123 Main St, San Francisco",
+                            }
+                        ],
+                    }
+                ],
+                stream=True,
+                extra_headers=auth_headers,
+            )
         )
 
-        steps = response.steps
+        events = [chunk.event for chunk in chunks]
+        tool_execution_steps = [
+            event for event in events if isinstance(event, StepCompleted) and event.step_type == "tool_execution"
+        ]
 
         # Verify agent was able to call the tool
         # (The LLM should have been able to understand the schema and formulate a valid call)
-        tool_execution_steps = [s for s in steps if s.step_type == "tool_execution"]
-
         # Agent might or might not call the tool depending on the model
         # But if it does, there should be no errors
         for step in tool_execution_steps:
-            if step.tool_responses:
-                for tool_response in step.tool_responses:
-                    assert tool_response.content is not None
+            for tool_response in step.result.tool_responses:
+                assert tool_response.get("content") is not None
```