mirror of
				https://github.com/meta-llama/llama-stack.git
				synced 2025-10-25 01:01:13 +00:00 
			
		
		
		
Renames `inference_recorder.py` to `api_recorder.py` and extends it to support recording and replaying tool invocations in addition to inference calls. This lets us record web-search (and other) tool calls and then replay those recordings for `tests/integration/responses`.

## Test Plan

```
export OPENAI_API_KEY=...
export TAVILY_SEARCH_API_KEY=...
./scripts/integration-tests.sh --stack-config ci-tests \
  --suite responses --inference-mode record-if-missing
```
		
			
				
	
	
		
			620 lines
		
	
	
	
		
			24 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			620 lines
		
	
	
	
		
			24 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright (c) Meta Platforms, Inc. and affiliates.
 | |
| # All rights reserved.
 | |
| #
 | |
| # This source code is licensed under the terms described in the LICENSE file in
 | |
| # the root directory of this source tree.
 | |
| 
 | |
| import json
 | |
| import os
 | |
| 
 | |
| import httpx
 | |
| import openai
 | |
| import pytest
 | |
| 
 | |
| from llama_stack import LlamaStackAsLibraryClient
 | |
| from llama_stack.core.datatypes import AuthenticationRequiredError
 | |
| from tests.common.mcp import dependency_tools, make_mcp_server
 | |
| 
 | |
| from .fixtures.test_cases import (
 | |
|     custom_tool_test_cases,
 | |
|     file_search_test_cases,
 | |
|     mcp_tool_test_cases,
 | |
|     multi_turn_tool_execution_streaming_test_cases,
 | |
|     multi_turn_tool_execution_test_cases,
 | |
|     web_search_test_cases,
 | |
| )
 | |
| from .helpers import new_vector_store, setup_mcp_tools, upload_file, wait_for_file_attachment
 | |
| from .streaming_assertions import StreamingValidator
 | |
| 
 | |
| 
 | |
@pytest.mark.parametrize("case", web_search_test_cases)
def test_response_non_streaming_web_search(compat_client, text_model_id, case):
    """A web search call should come first, followed by a completed assistant message with the expected text."""
    result = compat_client.responses.create(
        model=text_model_id,
        input=case.input,
        tools=case.tools,
        stream=False,
    )
    assert len(result.output) > 1

    search_call, reply = result.output[0], result.output[1]

    # The tool invocation itself.
    assert search_call.type == "web_search_call"
    assert search_call.status == "completed"

    # The model's answer built from the search results.
    assert reply.type == "message"
    assert reply.status == "completed"
    assert reply.role == "assistant"
    assert len(reply.content) > 0

    assert case.expected.lower() in result.output_text.lower().strip()
 | |
| 
 | |
| 
 | |
@pytest.mark.parametrize("case", file_search_test_cases)
def test_response_non_streaming_file_search(
    compat_client, text_model_id, embedding_model_id, embedding_dimension, tmp_path, case
):
    """Non-streaming file search should retrieve the uploaded file and ground the answer in it.

    The file comes either from inline ``case.file_content`` (written under ``tmp_path``) or
    from ``case.file_path``, resolved relative to this module's ``fixtures`` directory.

    Raises:
        ValueError: if the case provides neither file content nor a file path.
    """
    if isinstance(compat_client, LlamaStackAsLibraryClient):
        pytest.skip("Responses API file search is not yet supported in library client.")

    vector_store = new_vector_store(compat_client, "test_vector_store", embedding_model_id, embedding_dimension)

    if case.file_content:
        file_name = "test_response_non_streaming_file_search.txt"
        file_path = tmp_path / file_name
        file_path.write_text(case.file_content)
    elif case.file_path:
        file_path = os.path.join(os.path.dirname(__file__), "fixtures", case.file_path)
        file_name = os.path.basename(file_path)
    else:
        raise ValueError("No file content or path provided for case")

    file_response = upload_file(compat_client, file_name, file_path)

    # Attach our file to the vector store
    compat_client.vector_stores.files.create(
        vector_store_id=vector_store.id,
        file_id=file_response.id,
    )

    # Wait for the file to be attached
    wait_for_file_attachment(compat_client, vector_store.id, file_response.id)

    # Point every file_search tool at this test's vector store. Build new dicts rather than
    # mutating case.tools: the parametrized cases are shared module-level data, so writing
    # into them would leak this run's vector store id into every later use of the case.
    # Dict unpacking keeps the existing key order, so the request payload is unchanged.
    tools = [
        {**tool, "vector_store_ids": [vector_store.id]} if tool["type"] == "file_search" else tool
        for tool in case.tools
    ]

    # Create the response request, which should query our vector store
    response = compat_client.responses.create(
        model=text_model_id,
        input=case.input,
        tools=tools,
        stream=False,
        include=["file_search_call.results"],
    )

    # Verify the file_search_tool was called
    assert len(response.output) > 1
    assert response.output[0].type == "file_search_call"
    assert response.output[0].status == "completed"
    assert response.output[0].queries  # ensure it's some non-empty list
    assert response.output[0].results
    assert case.expected.lower() in response.output[0].results[0].text.lower()
    assert response.output[0].results[0].score > 0

    # Verify the output_text generated by the response
    assert case.expected.lower() in response.output_text.lower().strip()
 | |
| 
 | |
| 
 | |
def test_response_non_streaming_file_search_empty_vector_store(
    compat_client, text_model_id, embedding_model_id, embedding_dimension
):
    """Searching an empty vector store should still invoke the tool, return no hits, and yield an answer."""
    if isinstance(compat_client, LlamaStackAsLibraryClient):
        pytest.skip("Responses API file search is not yet supported in library client.")

    empty_store = new_vector_store(compat_client, "test_vector_store", embedding_model_id, embedding_dimension)
    search_tool = {"type": "file_search", "vector_store_ids": [empty_store.id]}

    result = compat_client.responses.create(
        model=text_model_id,
        input="How many experts does the Llama 4 Maverick model have?",
        tools=[search_tool],
        stream=False,
        include=["file_search_call.results"],
    )

    # The search must have run with at least one query...
    assert len(result.output) > 1
    search_call = result.output[0]
    assert search_call.type == "file_search_call"
    assert search_call.status == "completed"
    assert search_call.queries
    # ...but nothing can match in an empty store.
    assert not search_call.results

    # The model is still expected to reply with some text.
    assert result.output_text
 | |
| 
 | |
| 
 | |
def test_response_sequential_file_search(
    compat_client, text_model_id, embedding_model_id, embedding_dimension, tmp_path
):
    """Run a file-search turn, then a follow-up chained to it via previous_response_id."""
    if isinstance(compat_client, LlamaStackAsLibraryClient):
        pytest.skip("Responses API file search is not yet supported in library client.")

    store = new_vector_store(compat_client, "test_vector_store", embedding_model_id, embedding_dimension)

    # Seed the store with one small document and wait until it is indexed.
    doc_name = "test_sequential_file_search.txt"
    doc_path = tmp_path / doc_name
    doc_path.write_text("The Llama 4 Maverick model has 128 experts in its mixture of experts architecture.")

    uploaded = upload_file(compat_client, doc_name, doc_path)
    compat_client.vector_stores.files.create(vector_store_id=store.id, file_id=uploaded.id)
    wait_for_file_attachment(compat_client, store.id, uploaded.id)

    search_tools = [{"type": "file_search", "vector_store_ids": [store.id]}]

    # Turn 1: the question should trigger a search that finds the document.
    first = compat_client.responses.create(
        model=text_model_id,
        input="How many experts does the Llama 4 Maverick model have?",
        tools=search_tools,
        stream=False,
        include=["file_search_call.results"],
    )

    assert len(first.output) > 1
    search_call = first.output[0]
    assert search_call.type == "file_search_call"
    assert search_call.status == "completed"
    assert search_call.queries
    assert search_call.results
    first_text = first.output_text
    assert "128" in first_text or "experts" in first_text.lower()

    # Turn 2: a follow-up that relies on the context carried over from turn 1.
    second = compat_client.responses.create(
        model=text_model_id,
        input="Can you tell me more about the architecture?",
        tools=search_tools,
        stream=False,
        previous_response_id=first.id,
        include=["file_search_call.results"],
    )

    assert len(second.output) >= 1
    assert second.output_text

    messages = [item for item in second.output if item.type == "message"]
    assert len(messages) >= 1
    last_message = messages[-1]
    assert last_message.role == "assistant"
    assert last_message.status == "completed"
 | |
| 
 | |
| 
 | |
@pytest.mark.parametrize("case", mcp_tool_test_cases)
def test_response_non_streaming_mcp_tool(compat_client, text_model_id, case):
    """Exercise an MCP tool call, first against an open server and then against a token-protected one.

    Against the protected server, a request without credentials must raise. The same request must
    succeed once a bearer token header is attached to the MCP tools.
    """
    if not isinstance(compat_client, LlamaStackAsLibraryClient):
        pytest.skip("in-process MCP server is only supported in library client")

    # Phase 1: server without authentication.
    with make_mcp_server() as server_info:
        mcp_tools = setup_mcp_tools(case.tools, server_info)

        result = compat_client.responses.create(
            model=text_model_id,
            input=case.input,
            tools=mcp_tools,
            stream=False,
        )

        assert len(result.output) >= 3
        listing, tool_call = result.output[0], result.output[1]

        assert listing.type == "mcp_list_tools"
        assert listing.server_label == "localmcp"
        assert len(listing.tools) == 2
        assert {t.name for t in listing.tools} == {"get_boiling_point", "greet_everyone"}

        assert tool_call.type == "mcp_call"
        assert tool_call.name == "get_boiling_point"
        assert json.loads(tool_call.arguments) == {"liquid_name": "myawesomeliquid", "celsius": True}
        assert tool_call.error is None
        assert "-100" in tool_call.output

        # The model may call the tool more than once, so the answer is the last output item.
        final_text = result.output[-1].content[0].text
        assert "boiling point" in final_text.lower()

    # Phase 2: server that requires a bearer token.
    with make_mcp_server(required_auth_token="test-token") as server_info:
        mcp_tools = setup_mcp_tools(case.tools, server_info)

        if isinstance(compat_client, LlamaStackAsLibraryClient):
            expected_error = AuthenticationRequiredError
        else:
            expected_error = (httpx.HTTPStatusError, openai.AuthenticationError)

        with pytest.raises(expected_error):
            compat_client.responses.create(
                model=text_model_id,
                input=case.input,
                tools=mcp_tools,
                stream=False,
            )

        # Supply the credentials and retry.
        for entry in mcp_tools:
            if entry["type"] == "mcp":
                entry["headers"] = {"Authorization": "Bearer test-token"}

        result = compat_client.responses.create(
            model=text_model_id,
            input=case.input,
            tools=mcp_tools,
            stream=False,
        )
        assert len(result.output) >= 3
 | |
| 
 | |
| 
 | |
@pytest.mark.parametrize("case", mcp_tool_test_cases)
def test_response_sequential_mcp_tool(compat_client, text_model_id, case):
    """Run an MCP tool turn, then repeat the question chained to it via previous_response_id."""
    if not isinstance(compat_client, LlamaStackAsLibraryClient):
        pytest.skip("in-process MCP server is only supported in library client")

    with make_mcp_server() as server_info:
        mcp_tools = setup_mcp_tools(case.tools, server_info)

        first = compat_client.responses.create(
            model=text_model_id,
            input=case.input,
            tools=mcp_tools,
            stream=False,
        )

        assert len(first.output) >= 3
        listing, tool_call = first.output[0], first.output[1]

        assert listing.type == "mcp_list_tools"
        assert listing.server_label == "localmcp"
        assert len(listing.tools) == 2
        assert {t.name for t in listing.tools} == {"get_boiling_point", "greet_everyone"}

        assert tool_call.type == "mcp_call"
        assert tool_call.name == "get_boiling_point"
        assert json.loads(tool_call.arguments) == {"liquid_name": "myawesomeliquid", "celsius": True}
        assert tool_call.error is None
        assert "-100" in tool_call.output

        # The model may call the tool more than once, so the answer is the last output item.
        assert "boiling point" in first.output[-1].content[0].text.lower()

        # The follow-up turn continues the same conversation.
        second = compat_client.responses.create(
            model=text_model_id, input=case.input, tools=mcp_tools, stream=False, previous_response_id=first.id
        )

        assert len(second.output) >= 1
        assert "boiling point" in second.output[-1].content[0].text.lower()
 | |
| 
 | |
| 
 | |
@pytest.mark.parametrize("case", mcp_tool_test_cases)
@pytest.mark.parametrize("approve", [True, False])
def test_response_mcp_tool_approval(compat_client, text_model_id, case, approve):
    """With require_approval="always", the MCP call waits for an mcp_approval_response.

    Approving should run the tool and produce an answer. Denying must not produce any mcp_call.
    """
    if not isinstance(compat_client, LlamaStackAsLibraryClient):
        pytest.skip("in-process MCP server is only supported in library client")

    def check_tool_listing(item):
        # Each turn starts by listing the two tools exposed by the local MCP server.
        assert item.type == "mcp_list_tools"
        assert item.server_label == "localmcp"
        assert len(item.tools) == 2
        assert {t.name for t in item.tools} == {"get_boiling_point", "greet_everyone"}

    expected_args = {"liquid_name": "myawesomeliquid", "celsius": True}

    with make_mcp_server() as server_info:
        gated_tools = setup_mcp_tools(case.tools, server_info)
        for entry in gated_tools:
            entry["require_approval"] = "always"

        # Turn 1: the model wants to call the tool but must ask first.
        pending = compat_client.responses.create(
            model=text_model_id,
            input=case.input,
            tools=gated_tools,
            stream=False,
        )

        assert len(pending.output) >= 2
        check_tool_listing(pending.output[0])

        approval_request = pending.output[1]
        assert approval_request.type == "mcp_approval_request"
        assert approval_request.name == "get_boiling_point"
        assert json.loads(approval_request.arguments) == expected_args

        # Turn 2: answer the approval request.
        decision = {"type": "mcp_approval_response", "approval_request_id": approval_request.id, "approve": approve}
        followup = compat_client.responses.create(
            previous_response_id=pending.id,
            model=text_model_id,
            input=[decision],
            tools=gated_tools,
            stream=False,
        )

        if not approve:
            # A denied request must never turn into an actual tool invocation.
            assert len(followup.output) >= 1
            for item in followup.output:
                assert item.type != "mcp_call"
            return

        assert len(followup.output) >= 3
        check_tool_listing(followup.output[0])

        tool_call = followup.output[1]
        assert tool_call.type == "mcp_call"
        assert tool_call.name == "get_boiling_point"
        assert json.loads(tool_call.arguments) == expected_args
        assert tool_call.error is None
        assert "-100" in tool_call.output

        # The model may call the tool more than once, so the answer is the last output item.
        assert "boiling point" in followup.output[-1].content[0].text.lower()
 | |
| 
 | |
| 
 | |
@pytest.mark.parametrize("case", custom_tool_test_cases)
def test_response_non_streaming_custom_tool(compat_client, text_model_id, case):
    """A client-side function tool should come back as exactly one completed get_weather call."""
    result = compat_client.responses.create(
        model=text_model_id,
        input=case.input,
        tools=case.tools,
        stream=False,
    )

    assert len(result.output) == 1
    (fn_call,) = result.output
    assert fn_call.type == "function_call"
    assert fn_call.status == "completed"
    assert fn_call.name == "get_weather"
 | |
| 
 | |
| 
 | |
@pytest.mark.parametrize("case", custom_tool_test_cases)
def test_response_function_call_ordering_1(compat_client, text_model_id, case):
    """Replaying the user turn plus a function_call_output, chained to the first response, yields one output."""
    initial = compat_client.responses.create(
        model=text_model_id,
        input=case.input,
        tools=case.tools,
        stream=False,
    )

    assert len(initial.output) == 1
    fn_call = initial.output[0]
    assert fn_call.type == "function_call"
    assert fn_call.status == "completed"
    assert fn_call.name == "get_weather"

    # Re-send the user message followed by the client-computed tool result.
    follow_up_inputs = [
        {"role": "user", "content": case.input},
        {"type": "function_call_output", "output": "It is raining.", "call_id": fn_call.call_id},
    ]
    final = compat_client.responses.create(
        model=text_model_id, input=follow_up_inputs, tools=case.tools, stream=False, previous_response_id=initial.id
    )
    assert len(final.output) == 1
 | |
| 
 | |
| 
 | |
def test_response_function_call_ordering_2(compat_client, text_model_id):
    """Send all function calls first and all their outputs afterwards, then expect a single answer.

    Each get_weather call gets a canned result: "It is cloudy." when its arguments mention
    Los Angeles, otherwise "It is raining.". The final answer should therefore favour Los Angeles.
    """
    weather_tool = {
        "type": "function",
        "name": "get_weather",
        "description": "Get current temperature for a given location.",
        "parameters": {
            "additionalProperties": False,
            "properties": {
                "location": {
                    "description": "City and country e.g. Bogotá, Colombia",
                    "type": "string",
                }
            },
            "required": ["location"],
            "type": "object",
        },
    }
    tools = [weather_tool]
    conversation = [
        {
            "role": "user",
            "content": "Is the weather better in San Francisco or Los Angeles?",
        }
    ]

    first = compat_client.responses.create(
        model=text_model_id,
        input=conversation,
        tools=tools,
        stream=False,
    )

    weather_calls = [
        item
        for item in first.output
        if item.type == "function_call" and item.status == "completed" and item.name == "get_weather"
    ]

    # Ordering under test: every function call precedes every function_call_output.
    conversation.extend(weather_calls)
    for call in weather_calls:
        conversation.append(
            {
                "type": "function_call_output",
                "output": "It is cloudy." if "Los Angeles" in call.arguments else "It is raining.",
                "call_id": call.call_id,
            }
        )

    final = compat_client.responses.create(
        model=text_model_id,
        input=conversation,
        tools=tools,
        stream=False,
    )
    assert len(final.output) == 1
    assert "Los Angeles" in final.output_text
 | |
| 
 | |
| 
 | |
@pytest.mark.parametrize("case", multi_turn_tool_execution_test_cases)
def test_response_non_streaming_multi_turn_tool_execution(compat_client, text_model_id, case):
    """Chain several dependent MCP tool calls in one non-streaming response and check the final answer."""
    if not isinstance(compat_client, LlamaStackAsLibraryClient):
        pytest.skip("in-process MCP server is only supported in library client")

    with make_mcp_server(tools=dependency_tools()) as server_info:
        mcp_tools = setup_mcp_tools(case.tools, server_info)

        response = compat_client.responses.create(
            input=case.input,
            model=text_model_id,
            tools=mcp_tools,
        )

        # Group the output items by type in a single pass, preserving their order.
        buckets = {"mcp_list_tools": [], "mcp_call": [], "message": []}
        for item in response.output:
            if item.type in buckets:
                buckets[item.type].append(item)
        mcp_list_tools = buckets["mcp_list_tools"]
        mcp_calls = buckets["mcp_call"]
        message_outputs = buckets["message"]

        # Exactly one tool listing, exposing all five dependency tools.
        assert len(mcp_list_tools) == 1, f"Expected exactly 1 mcp_list_tools, got {len(mcp_list_tools)}"
        listing = mcp_list_tools[0]
        assert listing.server_label == "localmcp"
        assert len(listing.tools) == 5
        assert {t.name for t in listing.tools} == {
            "get_user_id",
            "get_user_permissions",
            "check_file_access",
            "get_experiment_id",
            "get_experiment_results",
        }

        # At least one tool call, and none of them may have failed.
        assert len(mcp_calls) >= 1, f"Expected at least 1 mcp_call, got {len(mcp_calls)}"
        for mcp_call in mcp_calls:
            assert mcp_call.error is None, f"MCP call should not have errors, got: {mcp_call.error}"

        # The conversation must end with a completed, non-empty assistant message.
        assert len(message_outputs) >= 1, f"Expected at least 1 message output, got {len(message_outputs)}"
        final_message = message_outputs[-1]
        assert final_message.role == "assistant", f"Final message should be from assistant, got {final_message.role}"
        assert final_message.status == "completed", f"Final message should be completed, got {final_message.status}"
        assert len(final_message.content) > 0, "Final message should have content"

        expected_output = case.expected
        assert expected_output.lower() in response.output_text.lower(), (
            f"Expected '{expected_output}' to appear in response: {response.output_text}"
        )
 | |
| 
 | |
| 
 | |
@pytest.mark.parametrize("case", multi_turn_tool_execution_streaming_test_cases)
def test_response_streaming_multi_turn_tool_execution(compat_client, text_model_id, case):
    """Test streaming multi-turn tool execution where multiple MCP tool calls are performed in sequence.

    The stream is fully consumed and checked with StreamingValidator. The response attached to
    the final event is then checked the same way as in the non-streaming variant.
    """
    if not isinstance(compat_client, LlamaStackAsLibraryClient):
        pytest.skip("in-process MCP server is only supported in library client")

    with make_mcp_server(tools=dependency_tools()) as mcp_server_info:
        tools = setup_mcp_tools(case.tools, mcp_server_info)

        stream = compat_client.responses.create(
            input=case.input,
            model=text_model_id,
            tools=tools,
            stream=True,
        )

        chunks = list(stream)

        # Use validator for common streaming checks
        validator = StreamingValidator(chunks)
        validator.assert_basic_event_sequence()
        validator.assert_response_consistency()
        validator.assert_has_tool_calls()
        validator.assert_has_mcp_events()
        validator.assert_rich_streaming()

        # The final event must carry the completed response. Previously a missing response
        # silently skipped every check below, letting the test pass without verifying anything.
        final_chunk = chunks[-1]
        assert hasattr(final_chunk, "response"), f"Final stream event has no response: {final_chunk!r}"
        final_response = final_chunk.response

        # Verify multi-turn MCP tool execution results
        mcp_list_tools = [output for output in final_response.output if output.type == "mcp_list_tools"]
        mcp_calls = [output for output in final_response.output if output.type == "mcp_call"]
        message_outputs = [output for output in final_response.output if output.type == "message"]

        # Should have exactly 1 MCP list tools message (at the beginning)
        assert len(mcp_list_tools) == 1, f"Expected exactly 1 mcp_list_tools, got {len(mcp_list_tools)}"
        assert mcp_list_tools[0].server_label == "localmcp"
        assert len(mcp_list_tools[0].tools) == 5  # Updated for dependency tools
        expected_tool_names = {
            "get_user_id",
            "get_user_permissions",
            "check_file_access",
            "get_experiment_id",
            "get_experiment_results",
        }
        assert {t.name for t in mcp_list_tools[0].tools} == expected_tool_names

        # Should have at least 1 MCP call (the model should call at least one tool)
        assert len(mcp_calls) >= 1, f"Expected at least 1 mcp_call, got {len(mcp_calls)}"

        # All MCP calls should be completed (verifies our tool execution works)
        for mcp_call in mcp_calls:
            assert mcp_call.error is None, f"MCP call should not have errors, got: {mcp_call.error}"

        # Should have at least one final message response
        assert len(message_outputs) >= 1, f"Expected at least 1 message output, got {len(message_outputs)}"

        # Final message should be from assistant and completed
        final_message = message_outputs[-1]
        assert final_message.role == "assistant", f"Final message should be from assistant, got {final_message.role}"
        assert final_message.status == "completed", f"Final message should be completed, got {final_message.status}"
        assert len(final_message.content) > 0, "Final message should have content"

        # Check that the expected output appears in the response
        expected_output = case.expected
        assert expected_output.lower() in final_response.output_text.lower(), (
            f"Expected '{expected_output}' to appear in response: {final_response.output_text}"
        )
 |