fix(responses): fix subtle bugs in non-function tool calling (#3817)

We were generating "FunctionToolCall" items even for MCP (and
file-search, etc.) server-side calls. ID mismatches, etc. galore.
This commit is contained in:
Ashwin Bharambe 2025-10-15 13:57:37 -07:00 committed by GitHub
parent d709eeb33f
commit 0a96a7faa5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 10660 additions and 51 deletions

View file

@ -11,7 +11,6 @@ from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.turn_events import StepCompleted, StepProgress, ToolCallIssuedDelta
from llama_stack import LlamaStackAsLibraryClient
from llama_stack.core.datatypes import AuthenticationRequiredError
AUTH_TOKEN = "test-token"
@ -82,9 +81,11 @@ def test_mcp_invocation(llama_stack_client, text_model_id, mcp_server):
"server_label": test_toolgroup_id,
"require_approval": "never",
"allowed_tools": [tool.name for tool in tools_list],
"headers": {
"Authorization": f"Bearer {AUTH_TOKEN}",
},
}
]
agent = Agent(
client=llama_stack_client,
model=text_model_id,
@ -111,47 +112,22 @@ def test_mcp_invocation(llama_stack_client, text_model_id, mcp_server):
extra_headers=auth_headers,
)
)
events = [chunk.event for chunk in chunks]
final_response = next((chunk.response for chunk in reversed(chunks) if chunk.response), None)
assert final_response is not None
issued_calls = [
event for event in events if isinstance(event, StepProgress) and isinstance(event.delta, ToolCallIssuedDelta)
]
assert issued_calls and issued_calls[0].delta.tool_name == "greet_everyone"
assert issued_calls
assert issued_calls[-1].delta.tool_name == "greet_everyone"
tool_events = [
event for event in events if isinstance(event, StepCompleted) and event.step_type == "tool_execution"
]
assert tool_events and tool_events[0].result.tool_calls[0].tool_name == "greet_everyone"
assert tool_events
assert tool_events[-1].result.tool_calls[0].tool_name == "greet_everyone"
assert "hello" in final_response.output_text.lower()
# when streaming, we currently don't check auth headers upfront and fail the request
# early. but we should at least be generating a 401 later in the process.
response_stream = agent.create_turn(
session_id=session_id,
messages=[
{
"type": "message",
"role": "user",
"content": [
{
"type": "input_text",
"text": "What is the boiling point of polyjuice? Use tools to answer.",
}
],
}
],
stream=True,
)
if isinstance(llama_stack_client, LlamaStackAsLibraryClient):
with pytest.raises(AuthenticationRequiredError):
for _ in response_stream:
pass
else:
error_chunks = [chunk for chunk in response_stream if "error" in chunk.model_dump()]
assert len(error_chunks) == 1
chunk = error_chunks[0].model_dump()
assert "Unauthorized" in chunk["error"]["message"]