From 8ed69978f9d08b146ba3d3d57cc5458fdd48bc54 Mon Sep 17 00:00:00 2001
From: ashwinb
Date: Fri, 15 Aug 2025 00:05:36 +0000
Subject: [PATCH] refactor(tests): make the responses tests nicer (#3161)

# What does this PR do?

A _bunch_ of cleanup for the Responses tests.

- Got rid of the YAML test cases and moved them to simple Pydantic models
- Split the large monolithic test file into multiple focused test files:
  - `test_basic_responses.py` for basic and image response tests
  - `test_tool_responses.py` for tool-related tests
  - `test_file_search.py` for file-search-specific tests
- Added a `StreamingValidator` helper class to standardize streaming response validation

## Test Plan

Run the tests:

```
pytest -s -v tests/integration/non_ci/responses/ \
  --stack-config=starter \
  --text-model openai/gpt-4o \
  --embedding-model=sentence-transformers/all-MiniLM-L6-v2 \
  -k "client_with_models"
```
---
 .../non_ci/responses/fixtures/fixtures.py | 14 -
 .../non_ci/responses/fixtures/load.py | 16 -
 .../non_ci/responses/fixtures/test_cases.py | 262 ++
 .../fixtures/test_cases/chat_completion.yaml | 397 ------
 .../fixtures/test_cases/responses.yaml | 166 ---
 tests/integration/non_ci/responses/helpers.py | 64 +
 .../non_ci/responses/streaming_assertions.py | 145 +++
 .../non_ci/responses/test_basic_responses.py | 188 +++
 .../non_ci/responses/test_file_search.py | 318 +++++
 .../non_ci/responses/test_responses.py | 1143 -----------------
 .../non_ci/responses/test_tool_responses.py | 335 +++++
 11 files changed, 1312 insertions(+), 1736 deletions(-)
 delete mode 100644 tests/integration/non_ci/responses/fixtures/load.py
 create mode 100644 tests/integration/non_ci/responses/fixtures/test_cases.py
 delete mode 100644 tests/integration/non_ci/responses/fixtures/test_cases/chat_completion.yaml
 delete mode 100644 tests/integration/non_ci/responses/fixtures/test_cases/responses.yaml
 create mode 100644 tests/integration/non_ci/responses/helpers.py
 create mode 100644 tests/integration/non_ci/responses/streaming_assertions.py
 create mode 100644 tests/integration/non_ci/responses/test_basic_responses.py
 create mode 100644 tests/integration/non_ci/responses/test_file_search.py
 delete mode 100644 tests/integration/non_ci/responses/test_responses.py
 create mode 100644 tests/integration/non_ci/responses/test_tool_responses.py

diff --git a/tests/integration/non_ci/responses/fixtures/fixtures.py b/tests/integration/non_ci/responses/fixtures/fixtures.py
index 2069010ad..62c4ae086 100644
--- a/tests/integration/non_ci/responses/fixtures/fixtures.py
+++ b/tests/integration/non_ci/responses/fixtures/fixtures.py
@@ -5,7 +5,6 @@
 # the root directory of this source tree.
 
 import os
-import re
 from pathlib import Path
 
 import pytest
@@ -48,19 +47,6 @@ def _load_all_verification_configs():
     return {"providers": all_provider_configs}
 
 
-def case_id_generator(case):
-    """Generate a test ID from the case's 'case_id' field, or use a default."""
-    case_id = case.get("case_id")
-    if isinstance(case_id, str | int):
-        return re.sub(r"\W|^(?=\d)", "_", str(case_id))
-    return None
-
-
-# Helper to get the base test name from the request object
-def get_base_test_name(request):
-    return request.node.originalname
-
-
 # --- End Helper Functions ---
 
 
diff --git a/tests/integration/non_ci/responses/fixtures/load.py b/tests/integration/non_ci/responses/fixtures/load.py
deleted file mode 100644
index 0184ee146..000000000
--- a/tests/integration/non_ci/responses/fixtures/load.py
+++ /dev/null
@@ -1,16 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. 
and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from pathlib import Path - -import yaml - - -def load_test_cases(name: str): - fixture_dir = Path(__file__).parent / "test_cases" - yaml_path = fixture_dir / f"{name}.yaml" - with open(yaml_path) as f: - return yaml.safe_load(f) diff --git a/tests/integration/non_ci/responses/fixtures/test_cases.py b/tests/integration/non_ci/responses/fixtures/test_cases.py new file mode 100644 index 000000000..bdd1a5d81 --- /dev/null +++ b/tests/integration/non_ci/responses/fixtures/test_cases.py @@ -0,0 +1,262 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any + +import pytest +from pydantic import BaseModel + + +class ResponsesTestCase(BaseModel): + # Input can be a simple string or complex message structure + input: str | list[dict[str, Any]] + expected: str + # Tools as flexible dict structure (gets validated at runtime by the API) + tools: list[dict[str, Any]] | None = None + # Multi-turn conversations with input/output pairs + turns: list[tuple[str | list[dict[str, Any]], str]] | None = None + # File search specific fields + file_content: str | None = None + file_path: str | None = None + # Streaming flag + stream: bool | None = None + + +# Basic response test cases +basic_test_cases = [ + pytest.param( + ResponsesTestCase( + input="Which planet do humans live on?", + expected="earth", + ), + id="earth", + ), + pytest.param( + ResponsesTestCase( + input="Which planet has rings around it with a name starting with letter S?", + expected="saturn", + ), + id="saturn", + ), + pytest.param( + ResponsesTestCase( + input=[ + { + "role": "user", + "content": [ + { + "type": "input_text", + "text": "what teams are playing in this image?", + } + ], + }, + { + "role": "user", + "content": [ + { + "type": "input_image", + "image_url": "https://upload.wikimedia.org/wikipedia/commons/3/3b/LeBron_James_Layup_%28Cleveland_vs_Brooklyn_2018%29.jpg", + } + ], + }, + ], + expected="brooklyn nets", + ), + id="image_input", + ), +] + +# Multi-turn test cases +multi_turn_test_cases = [ + pytest.param( + ResponsesTestCase( + input="", # Not used for multi-turn + expected="", # Not used for multi-turn + turns=[ + ("Which planet do humans live on?", "earth"), + ("What is the name of the planet from your previous response?", "earth"), + ], + ), + id="earth", + ), +] + +# Web search test cases +web_search_test_cases = [ + pytest.param( + ResponsesTestCase( + input="How many experts does the Llama 4 Maverick model have?", + tools=[{"type": "web_search", "search_context_size": "low"}], + expected="128", + ), + id="llama_experts", + ), +] + +# File search test cases +file_search_test_cases = [ + pytest.param( + ResponsesTestCase( + input="How many experts does the Llama 4 Maverick model have?", + tools=[{"type": "file_search"}], + expected="128", + file_content="Llama 4 Maverick has 128 experts", + ), + id="llama_experts", + ), + pytest.param( + ResponsesTestCase( + input="How many experts does the Llama 4 Maverick model have?", + tools=[{"type": "file_search"}], + expected="128", + file_path="pdfs/llama_stack_and_models.pdf", + ), + id="llama_experts_pdf", + ), +] + +# MCP tool test cases +mcp_tool_test_cases = [ + pytest.param( + ResponsesTestCase( + input="What is the boiling point 
of myawesomeliquid in Celsius?", + tools=[{"type": "mcp", "server_label": "localmcp", "server_url": ""}], + expected="Hello, world!", + ), + id="boiling_point_tool", + ), +] + +# Custom tool test cases +custom_tool_test_cases = [ + pytest.param( + ResponsesTestCase( + input="What's the weather like in San Francisco?", + tools=[ + { + "type": "function", + "name": "get_weather", + "description": "Get current temperature for a given location.", + "parameters": { + "additionalProperties": False, + "properties": { + "location": { + "description": "City and country e.g. Bogotá, Colombia", + "type": "string", + } + }, + "required": ["location"], + "type": "object", + }, + } + ], + expected="", # No specific expected output for custom tools + ), + id="sf_weather", + ), +] + +# Image test cases +image_test_cases = [ + pytest.param( + ResponsesTestCase( + input=[ + { + "role": "user", + "content": [ + { + "type": "input_text", + "text": "Identify the type of animal in this image.", + }, + { + "type": "input_image", + "image_url": "https://upload.wikimedia.org/wikipedia/commons/f/f7/Llamas%2C_Vernagt-Stausee%2C_Italy.jpg", + }, + ], + }, + ], + expected="llama", + ), + id="llama_image", + ), +] + +# Multi-turn image test cases +multi_turn_image_test_cases = [ + pytest.param( + ResponsesTestCase( + input="", # Not used for multi-turn + expected="", # Not used for multi-turn + turns=[ + ( + [ + { + "role": "user", + "content": [ + { + "type": "input_text", + "text": "What type of animal is in this image? Please respond with a single word that starts with the letter 'L'.", + }, + { + "type": "input_image", + "image_url": "https://upload.wikimedia.org/wikipedia/commons/f/f7/Llamas%2C_Vernagt-Stausee%2C_Italy.jpg", + }, + ], + }, + ], + "llama", + ), + ( + "What country do you find this animal primarily in? What continent?", + "peru", + ), + ], + ), + id="llama_image_understanding", + ), +] + +# Multi-turn tool execution test cases +multi_turn_tool_execution_test_cases = [ + pytest.param( + ResponsesTestCase( + input="I need to check if user 'alice' can access the file 'document.txt'. First, get alice's user ID, then check if that user ID can access the file 'document.txt'. Do this as a series of steps, where each step is a separate message. Return only one tool call per step. Summarize the final result with a single 'yes' or 'no' response.", + tools=[{"type": "mcp", "server_label": "localmcp", "server_url": ""}], + expected="yes", + ), + id="user_file_access_check", + ), + pytest.param( + ResponsesTestCase( + input="I need to get the results for the 'boiling_point' experiment. First, get the experiment ID for 'boiling_point', then use that ID to get the experiment results. Tell me the boiling point in Celsius.", + tools=[{"type": "mcp", "server_label": "localmcp", "server_url": ""}], + expected="100°C", + ), + id="experiment_results_lookup", + ), +] + +# Multi-turn tool execution streaming test cases +multi_turn_tool_execution_streaming_test_cases = [ + pytest.param( + ResponsesTestCase( + input="Help me with this security check: First, get the user ID for 'charlie', then get the permissions for that user ID, and finally check if that user can access 'secret_file.txt'. Stream your progress as you work through each step. Return only one tool call per step. 
Summarize the final result with a single 'yes' or 'no' response.", + tools=[{"type": "mcp", "server_label": "localmcp", "server_url": ""}], + expected="no", + stream=True, + ), + id="user_permissions_workflow", + ), + pytest.param( + ResponsesTestCase( + input="I need a complete analysis: First, get the experiment ID for 'chemical_reaction', then get the results for that experiment, and tell me if the yield was above 80%. Return only one tool call per step. Please stream your analysis process.", + tools=[{"type": "mcp", "server_label": "localmcp", "server_url": ""}], + expected="85%", + stream=True, + ), + id="experiment_analysis_streaming", + ), +] diff --git a/tests/integration/non_ci/responses/fixtures/test_cases/chat_completion.yaml b/tests/integration/non_ci/responses/fixtures/test_cases/chat_completion.yaml deleted file mode 100644 index 0c9f1fe9e..000000000 --- a/tests/integration/non_ci/responses/fixtures/test_cases/chat_completion.yaml +++ /dev/null @@ -1,397 +0,0 @@ -test_chat_basic: - test_name: test_chat_basic - test_params: - case: - - case_id: "earth" - input: - messages: - - content: Which planet do humans live on? - role: user - output: Earth - - case_id: "saturn" - input: - messages: - - content: Which planet has rings around it with a name starting with letter - S? - role: user - output: Saturn -test_chat_input_validation: - test_name: test_chat_input_validation - test_params: - case: - - case_id: "messages_missing" - input: - messages: [] - output: - error: - status_code: 400 - - case_id: "messages_role_invalid" - input: - messages: - - content: Which planet do humans live on? - role: fake_role - output: - error: - status_code: 400 - - case_id: "tool_choice_invalid" - input: - messages: - - content: Which planet do humans live on? - role: user - tool_choice: invalid - output: - error: - status_code: 400 - - case_id: "tool_choice_no_tools" - input: - messages: - - content: Which planet do humans live on? - role: user - tool_choice: required - output: - error: - status_code: 400 - - case_id: "tools_type_invalid" - input: - messages: - - content: Which planet do humans live on? - role: user - tools: - - type: invalid - output: - error: - status_code: 400 -test_chat_image: - test_name: test_chat_image - test_params: - case: - - input: - messages: - - content: - - text: What is in this image? - type: text - - image_url: - url: https://upload.wikimedia.org/wikipedia/commons/f/f7/Llamas%2C_Vernagt-Stausee%2C_Italy.jpg - type: image_url - role: user - output: llama -test_chat_structured_output: - test_name: test_chat_structured_output - test_params: - case: - - case_id: "calendar" - input: - messages: - - content: Extract the event information. - role: system - - content: Alice and Bob are going to a science fair on Friday. - role: user - response_format: - json_schema: - name: calendar_event - schema: - properties: - date: - title: Date - type: string - name: - title: Name - type: string - participants: - items: - type: string - title: Participants - type: array - required: - - name - - date - - participants - title: CalendarEvent - type: object - type: json_schema - output: valid_calendar_event - - case_id: "math" - input: - messages: - - content: You are a helpful math tutor. Guide the user through the solution - step by step. 
- role: system - - content: how can I solve 8x + 7 = -23 - role: user - response_format: - json_schema: - name: math_reasoning - schema: - $defs: - Step: - properties: - explanation: - title: Explanation - type: string - output: - title: Output - type: string - required: - - explanation - - output - title: Step - type: object - properties: - final_answer: - title: Final Answer - type: string - steps: - items: - $ref: '#/$defs/Step' - title: Steps - type: array - required: - - steps - - final_answer - title: MathReasoning - type: object - type: json_schema - output: valid_math_reasoning -test_tool_calling: - test_name: test_tool_calling - test_params: - case: - - input: - messages: - - content: You are a helpful assistant that can use tools to get information. - role: system - - content: What's the weather like in San Francisco? - role: user - tools: - - function: - description: Get current temperature for a given location. - name: get_weather - parameters: - additionalProperties: false - properties: - location: - description: "City and country e.g. Bogot\xE1, Colombia" - type: string - required: - - location - type: object - type: function - output: get_weather_tool_call - -test_chat_multi_turn_tool_calling: - test_name: test_chat_multi_turn_tool_calling - test_params: - case: - - case_id: "text_then_weather_tool" - input: - messages: - - - role: user - content: "What's the name of the Sun in latin?" - - - role: user - content: "What's the weather like in San Francisco?" - tools: - - function: - description: Get the current weather - name: get_weather - parameters: - type: object - properties: - location: - description: "The city and state (both required), e.g. San Francisco, CA." - type: string - required: ["location"] - type: function - tool_responses: - - response: "{'response': '70 degrees and foggy'}" - expected: - - num_tool_calls: 0 - answer: ["sol"] - - num_tool_calls: 1 - tool_name: get_weather - tool_arguments: - location: "San Francisco, CA" - - num_tool_calls: 0 - answer: ["foggy", "70 degrees"] - - case_id: "weather_tool_then_text" - input: - messages: - - - role: user - content: "What's the weather like in San Francisco?" - tools: - - function: - description: Get the current weather - name: get_weather - parameters: - type: object - properties: - location: - description: "The city and state (both required), e.g. San Francisco, CA." - type: string - required: ["location"] - type: function - tool_responses: - - response: "{'response': '70 degrees and foggy'}" - expected: - - num_tool_calls: 1 - tool_name: get_weather - tool_arguments: - location: "San Francisco, CA" - - num_tool_calls: 0 - answer: ["foggy", "70 degrees"] - - case_id: "add_product_tool" - input: - messages: - - - role: user - content: "Please add a new product with name 'Widget', price 19.99, in stock, and tags ['new', 'sale'] and give me the product id." - tools: - - function: - description: Add a new product - name: addProduct - parameters: - type: object - properties: - name: - description: "Name of the product" - type: string - price: - description: "Price of the product" - type: number - inStock: - description: "Availability status of the product." 
- type: boolean - tags: - description: "List of product tags" - type: array - items: - type: string - required: ["name", "price", "inStock"] - type: function - tool_responses: - - response: "{'response': 'Successfully added product with id: 123'}" - expected: - - num_tool_calls: 1 - tool_name: addProduct - tool_arguments: - name: "Widget" - price: 19.99 - inStock: true - tags: - - "new" - - "sale" - - num_tool_calls: 0 - answer: ["123", "product id: 123"] - - case_id: "get_then_create_event_tool" - input: - messages: - - - role: system - content: "Todays date is 2025-03-01." - - role: user - content: "Do i have any meetings on March 3rd at 10 am? Yes or no?" - - - role: user - content: "Alright then, Create an event named 'Team Building', scheduled for that time same time, in the 'Main Conference Room' and add Alice, Bob, Charlie to it. Give me the created event id." - tools: - - function: - description: Create a new event - name: create_event - parameters: - type: object - properties: - name: - description: "Name of the event" - type: string - date: - description: "Date of the event in ISO format" - type: string - time: - description: "Event Time (HH:MM)" - type: string - location: - description: "Location of the event" - type: string - participants: - description: "List of participant names" - type: array - items: - type: string - required: ["name", "date", "time", "location", "participants"] - type: function - - function: - description: Get an event by date and time - name: get_event - parameters: - type: object - properties: - date: - description: "Date of the event in ISO format" - type: string - time: - description: "Event Time (HH:MM)" - type: string - required: ["date", "time"] - type: function - tool_responses: - - response: "{'response': 'No events found for 2025-03-03 at 10:00'}" - - response: "{'response': 'Successfully created new event with id: e_123'}" - expected: - - num_tool_calls: 1 - tool_name: get_event - tool_arguments: - date: "2025-03-03" - time: "10:00" - - num_tool_calls: 0 - answer: ["no", "no events found", "no meetings"] - - num_tool_calls: 1 - tool_name: create_event - tool_arguments: - name: "Team Building" - date: "2025-03-03" - time: "10:00" - location: "Main Conference Room" - participants: - - "Alice" - - "Bob" - - "Charlie" - - num_tool_calls: 0 - answer: ["e_123", "event id: e_123"] - - case_id: "compare_monthly_expense_tool" - input: - messages: - - - role: system - content: "Todays date is 2025-03-01." - - role: user - content: "what was my monthly expense in Jan of this year?" - - - role: user - content: "Was it less than Feb of last year? Only answer with yes or no." 
- tools: - - function: - description: Get monthly expense summary - name: getMonthlyExpenseSummary - parameters: - type: object - properties: - month: - description: "Month of the year (1-12)" - type: integer - year: - description: "Year" - type: integer - required: ["month", "year"] - type: function - tool_responses: - - response: "{'response': 'Total expenses for January 2025: $1000'}" - - response: "{'response': 'Total expenses for February 2024: $2000'}" - expected: - - num_tool_calls: 1 - tool_name: getMonthlyExpenseSummary - tool_arguments: - month: 1 - year: 2025 - - num_tool_calls: 0 - answer: ["1000", "$1,000", "1,000"] - - num_tool_calls: 1 - tool_name: getMonthlyExpenseSummary - tool_arguments: - month: 2 - year: 2024 - - num_tool_calls: 0 - answer: ["yes"] diff --git a/tests/integration/non_ci/responses/fixtures/test_cases/responses.yaml b/tests/integration/non_ci/responses/fixtures/test_cases/responses.yaml deleted file mode 100644 index 353a64291..000000000 --- a/tests/integration/non_ci/responses/fixtures/test_cases/responses.yaml +++ /dev/null @@ -1,166 +0,0 @@ -test_response_basic: - test_name: test_response_basic - test_params: - case: - - case_id: "earth" - input: "Which planet do humans live on?" - output: "earth" - - case_id: "saturn" - input: "Which planet has rings around it with a name starting with letter S?" - output: "saturn" - - case_id: "image_input" - input: - - role: user - content: - - type: input_text - text: "what teams are playing in this image?" - - role: user - content: - - type: input_image - image_url: "https://upload.wikimedia.org/wikipedia/commons/3/3b/LeBron_James_Layup_%28Cleveland_vs_Brooklyn_2018%29.jpg" - output: "brooklyn nets" - -test_response_multi_turn: - test_name: test_response_multi_turn - test_params: - case: - - case_id: "earth" - turns: - - input: "Which planet do humans live on?" - output: "earth" - - input: "What is the name of the planet from your previous response?" - output: "earth" - -test_response_web_search: - test_name: test_response_web_search - test_params: - case: - - case_id: "llama_experts" - input: "How many experts does the Llama 4 Maverick model have?" - tools: - - type: web_search - search_context_size: "low" - output: "128" - -test_response_file_search: - test_name: test_response_file_search - test_params: - case: - - case_id: "llama_experts" - input: "How many experts does the Llama 4 Maverick model have?" - tools: - - type: file_search - # vector_store_ids param for file_search tool gets added by the test runner - file_content: "Llama 4 Maverick has 128 experts" - output: "128" - - case_id: "llama_experts_pdf" - input: "How many experts does the Llama 4 Maverick model have?" - tools: - - type: file_search - # vector_store_ids param for file_search toolgets added by the test runner - file_path: "pdfs/llama_stack_and_models.pdf" - output: "128" - -test_response_mcp_tool: - test_name: test_response_mcp_tool - test_params: - case: - - case_id: "boiling_point_tool" - input: "What is the boiling point of myawesomeliquid in Celsius?" - tools: - - type: mcp - server_label: "localmcp" - server_url: "" - output: "Hello, world!" - -test_response_custom_tool: - test_name: test_response_custom_tool - test_params: - case: - - case_id: "sf_weather" - input: "What's the weather like in San Francisco?" - tools: - - type: function - name: get_weather - description: Get current temperature for a given location. - parameters: - additionalProperties: false - properties: - location: - description: "City and country e.g. 
Bogot\xE1, Colombia" - type: string - required: - - location - type: object - -test_response_image: - test_name: test_response_image - test_params: - case: - - case_id: "llama_image" - input: - - role: user - content: - - type: input_text - text: "Identify the type of animal in this image." - - type: input_image - image_url: "https://upload.wikimedia.org/wikipedia/commons/f/f7/Llamas%2C_Vernagt-Stausee%2C_Italy.jpg" - output: "llama" - -# the models are really poor at tool calling after seeing images :/ -test_response_multi_turn_image: - test_name: test_response_multi_turn_image - test_params: - case: - - case_id: "llama_image_understanding" - turns: - - input: - - role: user - content: - - type: input_text - text: "What type of animal is in this image? Please respond with a single word that starts with the letter 'L'." - - type: input_image - image_url: "https://upload.wikimedia.org/wikipedia/commons/f/f7/Llamas%2C_Vernagt-Stausee%2C_Italy.jpg" - output: "llama" - - input: "What country do you find this animal primarily in? What continent?" - output: "peru" - -test_response_multi_turn_tool_execution: - test_name: test_response_multi_turn_tool_execution - test_params: - case: - - case_id: "user_file_access_check" - input: "I need to check if user 'alice' can access the file 'document.txt'. First, get alice's user ID, then check if that user ID can access the file 'document.txt'. Do this as a series of steps, where each step is a separate message. Return only one tool call per step. Summarize the final result with a single 'yes' or 'no' response." - tools: - - type: mcp - server_label: "localmcp" - server_url: "" - output: "yes" - - case_id: "experiment_results_lookup" - input: "I need to get the results for the 'boiling_point' experiment. First, get the experiment ID for 'boiling_point', then use that ID to get the experiment results. Tell me the boiling point in Celsius." - tools: - - type: mcp - server_label: "localmcp" - server_url: "" - output: "100°C" - -test_response_multi_turn_tool_execution_streaming: - test_name: test_response_multi_turn_tool_execution_streaming - test_params: - case: - - case_id: "user_permissions_workflow" - input: "Help me with this security check: First, get the user ID for 'charlie', then get the permissions for that user ID, and finally check if that user can access 'secret_file.txt'. Stream your progress as you work through each step. Return only one tool call per step. Summarize the final result with a single 'yes' or 'no' response." - tools: - - type: mcp - server_label: "localmcp" - server_url: "" - stream: true - output: "no" - - case_id: "experiment_analysis_streaming" - input: "I need a complete analysis: First, get the experiment ID for 'chemical_reaction', then get the results for that experiment, and tell me if the yield was above 80%. Return only one tool call per step. Please stream your analysis process." - tools: - - type: mcp - server_label: "localmcp" - server_url: "" - stream: true - output: "85%" diff --git a/tests/integration/non_ci/responses/helpers.py b/tests/integration/non_ci/responses/helpers.py new file mode 100644 index 000000000..7c988402f --- /dev/null +++ b/tests/integration/non_ci/responses/helpers.py @@ -0,0 +1,64 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +import time + + +def new_vector_store(openai_client, name): + """Create a new vector store, cleaning up any existing one with the same name.""" + # Ensure we don't reuse an existing vector store + vector_stores = openai_client.vector_stores.list() + for vector_store in vector_stores: + if vector_store.name == name: + openai_client.vector_stores.delete(vector_store_id=vector_store.id) + + # Create a new vector store + vector_store = openai_client.vector_stores.create(name=name) + return vector_store + + +def upload_file(openai_client, name, file_path): + """Upload a file, cleaning up any existing file with the same name.""" + # Ensure we don't reuse an existing file + files = openai_client.files.list() + for file in files: + if file.filename == name: + openai_client.files.delete(file_id=file.id) + + # Upload a text file with our document content + return openai_client.files.create(file=open(file_path, "rb"), purpose="assistants") + + +def wait_for_file_attachment(compat_client, vector_store_id, file_id): + """Wait for a file to be attached to a vector store.""" + file_attach_response = compat_client.vector_stores.files.retrieve( + vector_store_id=vector_store_id, + file_id=file_id, + ) + + while file_attach_response.status == "in_progress": + time.sleep(0.1) + file_attach_response = compat_client.vector_stores.files.retrieve( + vector_store_id=vector_store_id, + file_id=file_id, + ) + + assert file_attach_response.status == "completed", f"Expected file to be attached, got {file_attach_response}" + assert not file_attach_response.last_error + return file_attach_response + + +def setup_mcp_tools(tools, mcp_server_info): + """Replace placeholder MCP server URLs with actual server info.""" + # Create a deep copy to avoid modifying the original test case + import copy + + tools_copy = copy.deepcopy(tools) + + for tool in tools_copy: + if tool["type"] == "mcp" and tool["server_url"] == "": + tool["server_url"] = mcp_server_info["server_url"] + return tools_copy diff --git a/tests/integration/non_ci/responses/streaming_assertions.py b/tests/integration/non_ci/responses/streaming_assertions.py new file mode 100644 index 000000000..4279ffbab --- /dev/null +++ b/tests/integration/non_ci/responses/streaming_assertions.py @@ -0,0 +1,145 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from typing import Any + + +class StreamingValidator: + """Helper class for validating streaming response events.""" + + def __init__(self, chunks: list[Any]): + self.chunks = chunks + self.event_types = [chunk.type for chunk in chunks] + + def assert_basic_event_sequence(self): + """Verify basic created -> completed event sequence.""" + assert len(self.chunks) >= 2, f"Expected at least 2 chunks (created + completed), got {len(self.chunks)}" + assert self.chunks[0].type == "response.created", ( + f"First chunk should be response.created, got {self.chunks[0].type}" + ) + assert self.chunks[-1].type == "response.completed", ( + f"Last chunk should be response.completed, got {self.chunks[-1].type}" + ) + + # Verify event order + created_index = self.event_types.index("response.created") + completed_index = self.event_types.index("response.completed") + assert created_index < completed_index, "response.created should come before response.completed" + + def assert_response_consistency(self): + """Verify response ID consistency across events.""" + response_ids = set() + for chunk in self.chunks: + if hasattr(chunk, "response_id"): + response_ids.add(chunk.response_id) + elif hasattr(chunk, "response") and hasattr(chunk.response, "id"): + response_ids.add(chunk.response.id) + + assert len(response_ids) == 1, f"All events should reference the same response_id, found: {response_ids}" + + def assert_has_incremental_content(self): + """Verify that content is delivered incrementally via delta events.""" + delta_events = [ + i for i, event_type in enumerate(self.event_types) if event_type == "response.output_text.delta" + ] + assert len(delta_events) > 0, "Expected delta events for true incremental streaming, but found none" + + # Verify delta events have content + non_empty_deltas = 0 + delta_content_total = "" + + for delta_idx in delta_events: + chunk = self.chunks[delta_idx] + if hasattr(chunk, "delta") and chunk.delta: + delta_content_total += chunk.delta + non_empty_deltas += 1 + + assert non_empty_deltas > 0, "Delta events found but none contain content" + assert len(delta_content_total) > 0, "Delta events found but total delta content is empty" + + return delta_content_total + + def assert_content_quality(self, expected_content: str): + """Verify the final response contains expected content.""" + final_chunk = self.chunks[-1] + if hasattr(final_chunk, "response"): + output_text = final_chunk.response.output_text.lower().strip() + assert len(output_text) > 0, "Response should have content" + assert expected_content.lower() in output_text, f"Expected '{expected_content}' in response" + + def assert_has_tool_calls(self): + """Verify tool call streaming events are present.""" + # Check for tool call events + delta_events = [ + chunk + for chunk in self.chunks + if chunk.type in ["response.function_call_arguments.delta", "response.mcp_call.arguments.delta"] + ] + done_events = [ + chunk + for chunk in self.chunks + if chunk.type in ["response.function_call_arguments.done", "response.mcp_call.arguments.done"] + ] + + assert len(delta_events) > 0, f"Expected tool call delta events, got chunk types: {self.event_types}" + assert len(done_events) > 0, f"Expected tool call done events, got chunk types: {self.event_types}" + + # Verify output item events + item_added_events = [chunk for chunk in self.chunks if chunk.type == "response.output_item.added"] + item_done_events = [chunk for chunk in self.chunks if chunk.type == "response.output_item.done"] + + assert len(item_added_events) > 0, ( + f"Expected 
response.output_item.added events, got chunk types: {self.event_types}" + ) + assert len(item_done_events) > 0, ( + f"Expected response.output_item.done events, got chunk types: {self.event_types}" + ) + + def assert_has_mcp_events(self): + """Verify MCP-specific streaming events are present.""" + # Tool execution progress events + mcp_in_progress_events = [chunk for chunk in self.chunks if chunk.type == "response.mcp_call.in_progress"] + mcp_completed_events = [chunk for chunk in self.chunks if chunk.type == "response.mcp_call.completed"] + + assert len(mcp_in_progress_events) > 0, ( + f"Expected response.mcp_call.in_progress events, got chunk types: {self.event_types}" + ) + assert len(mcp_completed_events) > 0, ( + f"Expected response.mcp_call.completed events, got chunk types: {self.event_types}" + ) + + # MCP list tools events + mcp_list_tools_in_progress_events = [ + chunk for chunk in self.chunks if chunk.type == "response.mcp_list_tools.in_progress" + ] + mcp_list_tools_completed_events = [ + chunk for chunk in self.chunks if chunk.type == "response.mcp_list_tools.completed" + ] + + assert len(mcp_list_tools_in_progress_events) > 0, ( + f"Expected response.mcp_list_tools.in_progress events, got chunk types: {self.event_types}" + ) + assert len(mcp_list_tools_completed_events) > 0, ( + f"Expected response.mcp_list_tools.completed events, got chunk types: {self.event_types}" + ) + + def assert_rich_streaming(self, min_chunks: int = 10): + """Verify we have substantial streaming activity.""" + assert len(self.chunks) > min_chunks, ( + f"Expected rich streaming with many events, got only {len(self.chunks)} chunks" + ) + + def validate_event_structure(self): + """Validate the structure of various event types.""" + for chunk in self.chunks: + if chunk.type == "response.created": + assert chunk.response.status == "in_progress" + elif chunk.type == "response.completed": + assert chunk.response.status == "completed" + elif hasattr(chunk, "item_id"): + assert chunk.item_id, "Events with item_id should have non-empty item_id" + elif hasattr(chunk, "sequence_number"): + assert isinstance(chunk.sequence_number, int), "sequence_number should be an integer" diff --git a/tests/integration/non_ci/responses/test_basic_responses.py b/tests/integration/non_ci/responses/test_basic_responses.py new file mode 100644 index 000000000..a8106e593 --- /dev/null +++ b/tests/integration/non_ci/responses/test_basic_responses.py @@ -0,0 +1,188 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +import time + +import pytest +from fixtures.test_cases import basic_test_cases, image_test_cases, multi_turn_image_test_cases, multi_turn_test_cases +from streaming_assertions import StreamingValidator + + +@pytest.mark.parametrize("case", basic_test_cases) +def test_response_non_streaming_basic(compat_client, text_model_id, case): + response = compat_client.responses.create( + model=text_model_id, + input=case.input, + stream=False, + ) + output_text = response.output_text.lower().strip() + assert len(output_text) > 0 + assert case.expected.lower() in output_text + + retrieved_response = compat_client.responses.retrieve(response_id=response.id) + assert retrieved_response.output_text == response.output_text + + next_response = compat_client.responses.create( + model=text_model_id, + input="Repeat your previous response in all caps.", + previous_response_id=response.id, + ) + next_output_text = next_response.output_text.strip() + assert case.expected.upper() in next_output_text + + +@pytest.mark.parametrize("case", basic_test_cases) +def test_response_streaming_basic(compat_client, text_model_id, case): + response = compat_client.responses.create( + model=text_model_id, + input=case.input, + stream=True, + ) + + # Track events and timing to verify proper streaming + events = [] + event_times = [] + response_id = "" + + start_time = time.time() + + for chunk in response: + current_time = time.time() + event_times.append(current_time - start_time) + events.append(chunk) + + if chunk.type == "response.created": + # Verify response.created is emitted first and immediately + assert len(events) == 1, "response.created should be the first event" + assert event_times[0] < 0.1, "response.created should be emitted immediately" + assert chunk.response.status == "in_progress" + response_id = chunk.response.id + + elif chunk.type == "response.completed": + # Verify response.completed comes after response.created + assert len(events) >= 2, "response.completed should come after response.created" + assert chunk.response.status == "completed" + assert chunk.response.id == response_id, "Response ID should be consistent" + + # Verify content quality + output_text = chunk.response.output_text.lower().strip() + assert len(output_text) > 0, "Response should have content" + assert case.expected.lower() in output_text, f"Expected '{case.expected}' in response" + + # Use validator for common checks + validator = StreamingValidator(events) + validator.assert_basic_event_sequence() + validator.assert_response_consistency() + + # Verify stored response matches streamed response + retrieved_response = compat_client.responses.retrieve(response_id=response_id) + final_event = events[-1] + assert retrieved_response.output_text == final_event.response.output_text + + +@pytest.mark.parametrize("case", basic_test_cases) +def test_response_streaming_incremental_content(compat_client, text_model_id, case): + """Test that streaming actually delivers content incrementally, not just at the end.""" + response = compat_client.responses.create( + model=text_model_id, + input=case.input, + stream=True, + ) + + # Track all events and their content to verify incremental streaming + events = [] + content_snapshots = [] + event_times = [] + + start_time = time.time() + + for chunk in response: + current_time = time.time() + event_times.append(current_time - start_time) + events.append(chunk) + + # Track content at each event based on event type + if chunk.type == "response.output_text.delta": + # For delta events, track the delta 
content + content_snapshots.append(chunk.delta) + elif hasattr(chunk, "response") and hasattr(chunk.response, "output_text"): + # For response.created/completed events, track the full output_text + content_snapshots.append(chunk.response.output_text) + else: + content_snapshots.append("") + + validator = StreamingValidator(events) + validator.assert_basic_event_sequence() + + # Check if we have incremental content updates + event_types = [event.type for event in events] + created_index = event_types.index("response.created") + completed_index = event_types.index("response.completed") + + # The key test: verify content progression + created_content = content_snapshots[created_index] + completed_content = content_snapshots[completed_index] + + # Verify that response.created has empty or minimal content + assert len(created_content) == 0, f"response.created should have empty content, got: {repr(created_content[:100])}" + + # Verify that response.completed has the full content + assert len(completed_content) > 0, "response.completed should have content" + assert case.expected.lower() in completed_content.lower(), f"Expected '{case.expected}' in final content" + + # Use validator for incremental content checks + delta_content_total = validator.assert_has_incremental_content() + + # Verify that the accumulated delta content matches the final content + assert delta_content_total.strip() == completed_content.strip(), ( + f"Delta content '{delta_content_total}' should match final content '{completed_content}'" + ) + + # Verify timing: delta events should come between created and completed + delta_events = [i for i, event_type in enumerate(event_types) if event_type == "response.output_text.delta"] + for delta_idx in delta_events: + assert created_index < delta_idx < completed_index, ( + f"Delta event at index {delta_idx} should be between created ({created_index}) and completed ({completed_index})" + ) + + +@pytest.mark.parametrize("case", multi_turn_test_cases) +def test_response_non_streaming_multi_turn(compat_client, text_model_id, case): + previous_response_id = None + for turn_input, turn_expected in case.turns: + response = compat_client.responses.create( + model=text_model_id, + input=turn_input, + previous_response_id=previous_response_id, + ) + previous_response_id = response.id + output_text = response.output_text.lower() + assert turn_expected.lower() in output_text + + +@pytest.mark.parametrize("case", image_test_cases) +def test_response_non_streaming_image(compat_client, text_model_id, case): + response = compat_client.responses.create( + model=text_model_id, + input=case.input, + stream=False, + ) + output_text = response.output_text.lower() + assert case.expected.lower() in output_text + + +@pytest.mark.parametrize("case", multi_turn_image_test_cases) +def test_response_non_streaming_multi_turn_image(compat_client, text_model_id, case): + previous_response_id = None + for turn_input, turn_expected in case.turns: + response = compat_client.responses.create( + model=text_model_id, + input=turn_input, + previous_response_id=previous_response_id, + ) + previous_response_id = response.id + output_text = response.output_text.lower() + assert turn_expected.lower() in output_text diff --git a/tests/integration/non_ci/responses/test_file_search.py b/tests/integration/non_ci/responses/test_file_search.py new file mode 100644 index 000000000..ba7775a0b --- /dev/null +++ b/tests/integration/non_ci/responses/test_file_search.py @@ -0,0 +1,318 @@ +# Copyright (c) Meta Platforms, Inc. 
and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import json +import time + +import pytest + +from llama_stack import LlamaStackAsLibraryClient + +from .helpers import new_vector_store, upload_file + + +@pytest.mark.parametrize( + "text_format", + # Not testing json_object because most providers don't actually support it. + [ + {"type": "text"}, + { + "type": "json_schema", + "name": "capitals", + "description": "A schema for the capital of each country", + "schema": {"type": "object", "properties": {"capital": {"type": "string"}}}, + "strict": True, + }, + ], +) +def test_response_text_format(compat_client, text_model_id, text_format): + if isinstance(compat_client, LlamaStackAsLibraryClient): + pytest.skip("Responses API text format is not yet supported in library client.") + + stream = False + response = compat_client.responses.create( + model=text_model_id, + input="What is the capital of France?", + stream=stream, + text={"format": text_format}, + ) + # by_alias=True is needed because otherwise Pydantic renames our "schema" field + assert response.text.format.model_dump(exclude_none=True, by_alias=True) == text_format + assert "paris" in response.output_text.lower() + if text_format["type"] == "json_schema": + assert "paris" in json.loads(response.output_text)["capital"].lower() + + +@pytest.fixture +def vector_store_with_filtered_files(compat_client, text_model_id, tmp_path_factory): + """Create a vector store with multiple files that have different attributes for filtering tests.""" + if isinstance(compat_client, LlamaStackAsLibraryClient): + pytest.skip("Responses API file search is not yet supported in library client.") + + vector_store = new_vector_store(compat_client, "test_vector_store_with_filters") + tmp_path = tmp_path_factory.mktemp("filter_test_files") + + # Create multiple files with different attributes + files_data = [ + { + "name": "us_marketing_q1.txt", + "content": "US promotional campaigns for Q1 2023. Revenue increased by 15% in the US region.", + "attributes": { + "region": "us", + "category": "marketing", + "date": 1672531200, # Jan 1, 2023 + }, + }, + { + "name": "us_engineering_q2.txt", + "content": "US technical updates for Q2 2023. New features deployed in the US region.", + "attributes": { + "region": "us", + "category": "engineering", + "date": 1680307200, # Apr 1, 2023 + }, + }, + { + "name": "eu_marketing_q1.txt", + "content": "European advertising campaign results for Q1 2023. Strong growth in EU markets.", + "attributes": { + "region": "eu", + "category": "marketing", + "date": 1672531200, # Jan 1, 2023 + }, + }, + { + "name": "asia_sales_q3.txt", + "content": "Asia Pacific revenue figures for Q3 2023. 
Record breaking quarter in Asia.", + "attributes": { + "region": "asia", + "category": "sales", + "date": 1688169600, # Jul 1, 2023 + }, + }, + ] + + file_ids = [] + for file_data in files_data: + # Create file + file_path = tmp_path / file_data["name"] + file_path.write_text(file_data["content"]) + + # Upload file + file_response = upload_file(compat_client, file_data["name"], str(file_path)) + file_ids.append(file_response.id) + + # Attach file to vector store with attributes + file_attach_response = compat_client.vector_stores.files.create( + vector_store_id=vector_store.id, + file_id=file_response.id, + attributes=file_data["attributes"], + ) + + # Wait for attachment + while file_attach_response.status == "in_progress": + time.sleep(0.1) + file_attach_response = compat_client.vector_stores.files.retrieve( + vector_store_id=vector_store.id, + file_id=file_response.id, + ) + assert file_attach_response.status == "completed" + + yield vector_store + + # Cleanup: delete vector store and files + try: + compat_client.vector_stores.delete(vector_store_id=vector_store.id) + for file_id in file_ids: + try: + compat_client.files.delete(file_id=file_id) + except Exception: + pass # File might already be deleted + except Exception: + pass # Best effort cleanup + + +def test_response_file_search_filter_by_region(compat_client, text_model_id, vector_store_with_filtered_files): + """Test file search with region equality filter.""" + tools = [ + { + "type": "file_search", + "vector_store_ids": [vector_store_with_filtered_files.id], + "filters": {"type": "eq", "key": "region", "value": "us"}, + } + ] + + response = compat_client.responses.create( + model=text_model_id, + input="What are the updates from the US region?", + tools=tools, + stream=False, + include=["file_search_call.results"], + ) + + # Verify file search was called with US filter + assert len(response.output) > 1 + assert response.output[0].type == "file_search_call" + assert response.output[0].status == "completed" + assert response.output[0].results + # Should only return US files (not EU or Asia files) + for result in response.output[0].results: + assert "us" in result.text.lower() or "US" in result.text + # Ensure non-US regions are NOT returned + assert "european" not in result.text.lower() + assert "asia" not in result.text.lower() + + +def test_response_file_search_filter_by_category(compat_client, text_model_id, vector_store_with_filtered_files): + """Test file search with category equality filter.""" + tools = [ + { + "type": "file_search", + "vector_store_ids": [vector_store_with_filtered_files.id], + "filters": {"type": "eq", "key": "category", "value": "marketing"}, + } + ] + + response = compat_client.responses.create( + model=text_model_id, + input="Show me all marketing reports", + tools=tools, + stream=False, + include=["file_search_call.results"], + ) + + assert response.output[0].type == "file_search_call" + assert response.output[0].status == "completed" + assert response.output[0].results + # Should only return marketing files (not engineering or sales) + for result in response.output[0].results: + # Marketing files should have promotional/advertising content + assert "promotional" in result.text.lower() or "advertising" in result.text.lower() + # Ensure non-marketing categories are NOT returned + assert "technical" not in result.text.lower() + assert "revenue figures" not in result.text.lower() + + +def test_response_file_search_filter_by_date_range(compat_client, text_model_id, vector_store_with_filtered_files): + 
"""Test file search with date range filter using compound AND.""" + tools = [ + { + "type": "file_search", + "vector_store_ids": [vector_store_with_filtered_files.id], + "filters": { + "type": "and", + "filters": [ + { + "type": "gte", + "key": "date", + "value": 1672531200, # Jan 1, 2023 + }, + { + "type": "lt", + "key": "date", + "value": 1680307200, # Apr 1, 2023 + }, + ], + }, + } + ] + + response = compat_client.responses.create( + model=text_model_id, + input="What happened in Q1 2023?", + tools=tools, + stream=False, + include=["file_search_call.results"], + ) + + assert response.output[0].type == "file_search_call" + assert response.output[0].status == "completed" + assert response.output[0].results + # Should only return Q1 files (not Q2 or Q3) + for result in response.output[0].results: + assert "q1" in result.text.lower() + # Ensure non-Q1 quarters are NOT returned + assert "q2" not in result.text.lower() + assert "q3" not in result.text.lower() + + +def test_response_file_search_filter_compound_and(compat_client, text_model_id, vector_store_with_filtered_files): + """Test file search with compound AND filter (region AND category).""" + tools = [ + { + "type": "file_search", + "vector_store_ids": [vector_store_with_filtered_files.id], + "filters": { + "type": "and", + "filters": [ + {"type": "eq", "key": "region", "value": "us"}, + {"type": "eq", "key": "category", "value": "engineering"}, + ], + }, + } + ] + + response = compat_client.responses.create( + model=text_model_id, + input="What are the engineering updates from the US?", + tools=tools, + stream=False, + include=["file_search_call.results"], + ) + + assert response.output[0].type == "file_search_call" + assert response.output[0].status == "completed" + assert response.output[0].results + # Should only return US engineering files + assert len(response.output[0].results) >= 1 + for result in response.output[0].results: + assert "us" in result.text.lower() and "technical" in result.text.lower() + # Ensure it's not from other regions or categories + assert "european" not in result.text.lower() and "asia" not in result.text.lower() + assert "promotional" not in result.text.lower() and "revenue" not in result.text.lower() + + +def test_response_file_search_filter_compound_or(compat_client, text_model_id, vector_store_with_filtered_files): + """Test file search with compound OR filter (marketing OR sales).""" + tools = [ + { + "type": "file_search", + "vector_store_ids": [vector_store_with_filtered_files.id], + "filters": { + "type": "or", + "filters": [ + {"type": "eq", "key": "category", "value": "marketing"}, + {"type": "eq", "key": "category", "value": "sales"}, + ], + }, + } + ] + + response = compat_client.responses.create( + model=text_model_id, + input="Show me marketing and sales documents", + tools=tools, + stream=False, + include=["file_search_call.results"], + ) + + assert response.output[0].type == "file_search_call" + assert response.output[0].status == "completed" + assert response.output[0].results + # Should return marketing and sales files, but NOT engineering + categories_found = set() + for result in response.output[0].results: + text_lower = result.text.lower() + if "promotional" in text_lower or "advertising" in text_lower: + categories_found.add("marketing") + if "revenue figures" in text_lower: + categories_found.add("sales") + # Ensure engineering files are NOT returned + assert "technical" not in text_lower, f"Engineering file should not be returned, but got: {result.text}" + + # Verify we got at 
least one of the expected categories + assert len(categories_found) > 0, "Should have found at least one marketing or sales file" + assert categories_found.issubset({"marketing", "sales"}), f"Found unexpected categories: {categories_found}" diff --git a/tests/integration/non_ci/responses/test_responses.py b/tests/integration/non_ci/responses/test_responses.py deleted file mode 100644 index 954f009c2..000000000 --- a/tests/integration/non_ci/responses/test_responses.py +++ /dev/null @@ -1,1143 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -import json -import os -import time - -import httpx -import openai -import pytest - -from llama_stack import LlamaStackAsLibraryClient -from llama_stack.core.datatypes import AuthenticationRequiredError -from tests.common.mcp import dependency_tools, make_mcp_server - -from .fixtures.fixtures import case_id_generator -from .fixtures.load import load_test_cases - -responses_test_cases = load_test_cases("responses") - - -def _new_vector_store(openai_client, name): - # Ensure we don't reuse an existing vector store - vector_stores = openai_client.vector_stores.list() - for vector_store in vector_stores: - if vector_store.name == name: - openai_client.vector_stores.delete(vector_store_id=vector_store.id) - - # Create a new vector store - vector_store = openai_client.vector_stores.create( - name=name, - ) - return vector_store - - -def _upload_file(openai_client, name, file_path): - # Ensure we don't reuse an existing file - files = openai_client.files.list() - for file in files: - if file.filename == name: - openai_client.files.delete(file_id=file.id) - - # Upload a text file with our document content - return openai_client.files.create(file=open(file_path, "rb"), purpose="assistants") - - -@pytest.mark.parametrize( - "case", - responses_test_cases["test_response_basic"]["test_params"]["case"], - ids=case_id_generator, -) -def test_response_non_streaming_basic(request, compat_client, text_model_id, case): - response = compat_client.responses.create( - model=text_model_id, - input=case["input"], - stream=False, - ) - output_text = response.output_text.lower().strip() - assert len(output_text) > 0 - assert case["output"].lower() in output_text - - retrieved_response = compat_client.responses.retrieve(response_id=response.id) - assert retrieved_response.output_text == response.output_text - - next_response = compat_client.responses.create( - model=text_model_id, - input="Repeat your previous response in all caps.", - previous_response_id=response.id, - ) - next_output_text = next_response.output_text.strip() - assert case["output"].upper() in next_output_text - - -@pytest.mark.parametrize( - "case", - responses_test_cases["test_response_basic"]["test_params"]["case"], - ids=case_id_generator, -) -def test_response_streaming_basic(request, compat_client, text_model_id, case): - import time - - response = compat_client.responses.create( - model=text_model_id, - input=case["input"], - stream=True, - ) - - # Track events and timing to verify proper streaming - events = [] - event_times = [] - response_id = "" - - start_time = time.time() - - for chunk in response: - current_time = time.time() - event_times.append(current_time - start_time) - events.append(chunk) - - if chunk.type == "response.created": - # Verify response.created is emitted first and immediately - assert len(events) == 1, 
"response.created should be the first event" - assert event_times[0] < 0.1, "response.created should be emitted immediately" - assert chunk.response.status == "in_progress" - response_id = chunk.response.id - - elif chunk.type == "response.completed": - # Verify response.completed comes after response.created - assert len(events) >= 2, "response.completed should come after response.created" - assert chunk.response.status == "completed" - assert chunk.response.id == response_id, "Response ID should be consistent" - - # Verify content quality - output_text = chunk.response.output_text.lower().strip() - assert len(output_text) > 0, "Response should have content" - assert case["output"].lower() in output_text, f"Expected '{case['output']}' in response" - - # Verify we got both required events - event_types = [event.type for event in events] - assert "response.created" in event_types, "Missing response.created event" - assert "response.completed" in event_types, "Missing response.completed event" - - # Verify event order - created_index = event_types.index("response.created") - completed_index = event_types.index("response.completed") - assert created_index < completed_index, "response.created should come before response.completed" - - # Verify stored response matches streamed response - retrieved_response = compat_client.responses.retrieve(response_id=response_id) - final_event = events[-1] - assert retrieved_response.output_text == final_event.response.output_text - - -@pytest.mark.parametrize( - "case", - responses_test_cases["test_response_basic"]["test_params"]["case"], - ids=case_id_generator, -) -def test_response_streaming_incremental_content(request, compat_client, text_model_id, case): - """Test that streaming actually delivers content incrementally, not just at the end.""" - import time - - response = compat_client.responses.create( - model=text_model_id, - input=case["input"], - stream=True, - ) - - # Track all events and their content to verify incremental streaming - events = [] - content_snapshots = [] - event_times = [] - - start_time = time.time() - - for chunk in response: - current_time = time.time() - event_times.append(current_time - start_time) - events.append(chunk) - - # Track content at each event based on event type - if chunk.type == "response.output_text.delta": - # For delta events, track the delta content - content_snapshots.append(chunk.delta) - elif hasattr(chunk, "response") and hasattr(chunk.response, "output_text"): - # For response.created/completed events, track the full output_text - content_snapshots.append(chunk.response.output_text) - else: - content_snapshots.append("") - - # Verify we have the expected events - event_types = [event.type for event in events] - assert "response.created" in event_types, "Missing response.created event" - assert "response.completed" in event_types, "Missing response.completed event" - - # Check if we have incremental content updates - created_index = event_types.index("response.created") - completed_index = event_types.index("response.completed") - - # The key test: verify content progression - created_content = content_snapshots[created_index] - completed_content = content_snapshots[completed_index] - - # Verify that response.created has empty or minimal content - assert len(created_content) == 0, f"response.created should have empty content, got: {repr(created_content[:100])}" - - # Verify that response.completed has the full content - assert len(completed_content) > 0, "response.completed should have content" - assert 
case["output"].lower() in completed_content.lower(), f"Expected '{case['output']}' in final content" - - # Check for true incremental streaming by looking for delta events - delta_events = [i for i, event_type in enumerate(event_types) if event_type == "response.output_text.delta"] - - # Assert that we have delta events (true incremental streaming) - assert len(delta_events) > 0, "Expected delta events for true incremental streaming, but found none" - - # Verify delta events have content and accumulate to final content - delta_content_total = "" - non_empty_deltas = 0 - - for delta_idx in delta_events: - delta_content = content_snapshots[delta_idx] - if delta_content: - delta_content_total += delta_content - non_empty_deltas += 1 - - # Assert that we have meaningful delta content - assert non_empty_deltas > 0, "Delta events found but none contain content" - assert len(delta_content_total) > 0, "Delta events found but total delta content is empty" - - # Verify that the accumulated delta content matches the final content - assert delta_content_total.strip() == completed_content.strip(), ( - f"Delta content '{delta_content_total}' should match final content '{completed_content}'" - ) - - # Verify timing: delta events should come between created and completed - for delta_idx in delta_events: - assert created_index < delta_idx < completed_index, ( - f"Delta event at index {delta_idx} should be between created ({created_index}) and completed ({completed_index})" - ) - - -@pytest.mark.parametrize( - "case", - responses_test_cases["test_response_multi_turn"]["test_params"]["case"], - ids=case_id_generator, -) -def test_response_non_streaming_multi_turn(request, compat_client, text_model_id, case): - previous_response_id = None - for turn in case["turns"]: - response = compat_client.responses.create( - model=text_model_id, - input=turn["input"], - previous_response_id=previous_response_id, - tools=turn["tools"] if "tools" in turn else None, - ) - previous_response_id = response.id - output_text = response.output_text.lower() - assert turn["output"].lower() in output_text - - -@pytest.mark.parametrize( - "case", - responses_test_cases["test_response_web_search"]["test_params"]["case"], - ids=case_id_generator, -) -def test_response_non_streaming_web_search(request, compat_client, text_model_id, case): - response = compat_client.responses.create( - model=text_model_id, - input=case["input"], - tools=case["tools"], - stream=False, - ) - assert len(response.output) > 1 - assert response.output[0].type == "web_search_call" - assert response.output[0].status == "completed" - assert response.output[1].type == "message" - assert response.output[1].status == "completed" - assert response.output[1].role == "assistant" - assert len(response.output[1].content) > 0 - assert case["output"].lower() in response.output_text.lower().strip() - - -@pytest.mark.parametrize( - "case", - responses_test_cases["test_response_file_search"]["test_params"]["case"], - ids=case_id_generator, -) -def test_response_non_streaming_file_search(request, compat_client, text_model_id, tmp_path, case): - if isinstance(compat_client, LlamaStackAsLibraryClient): - pytest.skip("Responses API file search is not yet supported in library client.") - - vector_store = _new_vector_store(compat_client, "test_vector_store") - - if "file_content" in case: - file_name = "test_response_non_streaming_file_search.txt" - file_path = tmp_path / file_name - file_path.write_text(case["file_content"]) - elif "file_path" in case: - file_path = 
os.path.join(os.path.dirname(__file__), "fixtures", case["file_path"]) - file_name = os.path.basename(file_path) - else: - raise ValueError(f"No file content or path provided for case {case['case_id']}") - - file_response = _upload_file(compat_client, file_name, file_path) - - # Attach our file to the vector store - file_attach_response = compat_client.vector_stores.files.create( - vector_store_id=vector_store.id, - file_id=file_response.id, - ) - - # Wait for the file to be attached - while file_attach_response.status == "in_progress": - time.sleep(0.1) - file_attach_response = compat_client.vector_stores.files.retrieve( - vector_store_id=vector_store.id, - file_id=file_response.id, - ) - assert file_attach_response.status == "completed", f"Expected file to be attached, got {file_attach_response}" - assert not file_attach_response.last_error - - # Update our tools with the right vector store id - tools = case["tools"] - for tool in tools: - if tool["type"] == "file_search": - tool["vector_store_ids"] = [vector_store.id] - - # Create the response request, which should query our vector store - response = compat_client.responses.create( - model=text_model_id, - input=case["input"], - tools=tools, - stream=False, - include=["file_search_call.results"], - ) - - # Verify the file_search_tool was called - assert len(response.output) > 1 - assert response.output[0].type == "file_search_call" - assert response.output[0].status == "completed" - assert response.output[0].queries # ensure it's some non-empty list - assert response.output[0].results - assert case["output"].lower() in response.output[0].results[0].text.lower() - assert response.output[0].results[0].score > 0 - - # Verify the output_text generated by the response - assert case["output"].lower() in response.output_text.lower().strip() - - -def test_response_non_streaming_file_search_empty_vector_store(request, compat_client, text_model_id): - if isinstance(compat_client, LlamaStackAsLibraryClient): - pytest.skip("Responses API file search is not yet supported in library client.") - - vector_store = _new_vector_store(compat_client, "test_vector_store") - - # Create the response request, which should query our vector store - response = compat_client.responses.create( - model=text_model_id, - input="How many experts does the Llama 4 Maverick model have?", - tools=[{"type": "file_search", "vector_store_ids": [vector_store.id]}], - stream=False, - include=["file_search_call.results"], - ) - - # Verify the file_search_tool was called - assert len(response.output) > 1 - assert response.output[0].type == "file_search_call" - assert response.output[0].status == "completed" - assert response.output[0].queries # ensure it's some non-empty list - assert not response.output[0].results # ensure we don't get any results - - # Verify some output_text was generated by the response - assert response.output_text - - -@pytest.mark.parametrize( - "case", - responses_test_cases["test_response_mcp_tool"]["test_params"]["case"], - ids=case_id_generator, -) -def test_response_non_streaming_mcp_tool(request, compat_client, text_model_id, case): - if not isinstance(compat_client, LlamaStackAsLibraryClient): - pytest.skip("in-process MCP server is only supported in library client") - - with make_mcp_server() as mcp_server_info: - tools = case["tools"] - for tool in tools: - if tool["type"] == "mcp": - tool["server_url"] = mcp_server_info["server_url"] - - response = compat_client.responses.create( - model=text_model_id, - input=case["input"], - tools=tools, - 
stream=False, - ) - - assert len(response.output) >= 3 - list_tools = response.output[0] - assert list_tools.type == "mcp_list_tools" - assert list_tools.server_label == "localmcp" - assert len(list_tools.tools) == 2 - assert {t.name for t in list_tools.tools} == { - "get_boiling_point", - "greet_everyone", - } - - call = response.output[1] - assert call.type == "mcp_call" - assert call.name == "get_boiling_point" - assert json.loads(call.arguments) == { - "liquid_name": "myawesomeliquid", - "celsius": True, - } - assert call.error is None - assert "-100" in call.output - - # sometimes the model will call the tool again, so we need to get the last message - message = response.output[-1] - text_content = message.content[0].text - assert "boiling point" in text_content.lower() - - with make_mcp_server(required_auth_token="test-token") as mcp_server_info: - tools = case["tools"] - for tool in tools: - if tool["type"] == "mcp": - tool["server_url"] = mcp_server_info["server_url"] - - exc_type = ( - AuthenticationRequiredError - if isinstance(compat_client, LlamaStackAsLibraryClient) - else (httpx.HTTPStatusError, openai.AuthenticationError) - ) - with pytest.raises(exc_type): - compat_client.responses.create( - model=text_model_id, - input=case["input"], - tools=tools, - stream=False, - ) - - for tool in tools: - if tool["type"] == "mcp": - tool["server_url"] = mcp_server_info["server_url"] - tool["headers"] = {"Authorization": "Bearer test-token"} - - response = compat_client.responses.create( - model=text_model_id, - input=case["input"], - tools=tools, - stream=False, - ) - assert len(response.output) >= 3 - - -@pytest.mark.parametrize( - "case", - responses_test_cases["test_response_custom_tool"]["test_params"]["case"], - ids=case_id_generator, -) -def test_response_non_streaming_custom_tool(request, compat_client, text_model_id, case): - response = compat_client.responses.create( - model=text_model_id, - input=case["input"], - tools=case["tools"], - stream=False, - ) - assert len(response.output) == 1 - assert response.output[0].type == "function_call" - assert response.output[0].status == "completed" - assert response.output[0].name == "get_weather" - - -@pytest.mark.parametrize( - "case", - responses_test_cases["test_response_image"]["test_params"]["case"], - ids=case_id_generator, -) -def test_response_non_streaming_image(request, compat_client, text_model_id, case): - response = compat_client.responses.create( - model=text_model_id, - input=case["input"], - stream=False, - ) - output_text = response.output_text.lower() - assert case["output"].lower() in output_text - - -@pytest.mark.parametrize( - "case", - responses_test_cases["test_response_multi_turn_image"]["test_params"]["case"], - ids=case_id_generator, -) -def test_response_non_streaming_multi_turn_image(request, compat_client, text_model_id, case): - previous_response_id = None - for turn in case["turns"]: - response = compat_client.responses.create( - model=text_model_id, - input=turn["input"], - previous_response_id=previous_response_id, - tools=turn["tools"] if "tools" in turn else None, - ) - previous_response_id = response.id - output_text = response.output_text.lower() - assert turn["output"].lower() in output_text - - -@pytest.mark.parametrize( - "case", - responses_test_cases["test_response_multi_turn_tool_execution"]["test_params"]["case"], - ids=case_id_generator, -) -def test_response_non_streaming_multi_turn_tool_execution(compat_client, text_model_id, case): - """Test multi-turn tool execution where multiple MCP 
tool calls are performed in sequence.""" - if not isinstance(compat_client, LlamaStackAsLibraryClient): - pytest.skip("in-process MCP server is only supported in library client") - - with make_mcp_server(tools=dependency_tools()) as mcp_server_info: - tools = case["tools"] - # Replace the placeholder URL with the actual server URL - for tool in tools: - if tool["type"] == "mcp" and tool["server_url"] == "": - tool["server_url"] = mcp_server_info["server_url"] - - response = compat_client.responses.create( - input=case["input"], - model=text_model_id, - tools=tools, - ) - - # Verify we have MCP tool calls in the output - mcp_list_tools = [output for output in response.output if output.type == "mcp_list_tools"] - - mcp_calls = [output for output in response.output if output.type == "mcp_call"] - message_outputs = [output for output in response.output if output.type == "message"] - - # Should have exactly 1 MCP list tools message (at the beginning) - assert len(mcp_list_tools) == 1, f"Expected exactly 1 mcp_list_tools, got {len(mcp_list_tools)}" - assert mcp_list_tools[0].server_label == "localmcp" - assert len(mcp_list_tools[0].tools) == 5 # Updated for dependency tools - expected_tool_names = { - "get_user_id", - "get_user_permissions", - "check_file_access", - "get_experiment_id", - "get_experiment_results", - } - assert {t.name for t in mcp_list_tools[0].tools} == expected_tool_names - - assert len(mcp_calls) >= 1, f"Expected at least 1 mcp_call, got {len(mcp_calls)}" - for mcp_call in mcp_calls: - assert mcp_call.error is None, f"MCP call should not have errors, got: {mcp_call.error}" - - assert len(message_outputs) >= 1, f"Expected at least 1 message output, got {len(message_outputs)}" - - final_message = message_outputs[-1] - assert final_message.role == "assistant", f"Final message should be from assistant, got {final_message.role}" - assert final_message.status == "completed", f"Final message should be completed, got {final_message.status}" - assert len(final_message.content) > 0, "Final message should have content" - - expected_output = case["output"] - assert expected_output.lower() in response.output_text.lower(), ( - f"Expected '{expected_output}' to appear in response: {response.output_text}" - ) - - -@pytest.mark.parametrize( - "case", - responses_test_cases["test_response_multi_turn_tool_execution_streaming"]["test_params"]["case"], - ids=case_id_generator, -) -def test_response_streaming_multi_turn_tool_execution(compat_client, text_model_id, case): - """Test streaming multi-turn tool execution where multiple MCP tool calls are performed in sequence.""" - if not isinstance(compat_client, LlamaStackAsLibraryClient): - pytest.skip("in-process MCP server is only supported in library client") - - with make_mcp_server(tools=dependency_tools()) as mcp_server_info: - tools = case["tools"] - # Replace the placeholder URL with the actual server URL - for tool in tools: - if tool["type"] == "mcp" and tool["server_url"] == "": - tool["server_url"] = mcp_server_info["server_url"] - - stream = compat_client.responses.create( - input=case["input"], - model=text_model_id, - tools=tools, - stream=True, - ) - - chunks = [] - for chunk in stream: - chunks.append(chunk) - - # Should have at least response.created and response.completed - assert len(chunks) >= 2, f"Expected at least 2 chunks (created + completed), got {len(chunks)}" - - # First chunk should be response.created - assert chunks[0].type == "response.created", f"First chunk should be response.created, got {chunks[0].type}" - - # 
Last chunk should be response.completed - assert chunks[-1].type == "response.completed", ( - f"Last chunk should be response.completed, got {chunks[-1].type}" - ) - - # Verify tool call streaming events are present - chunk_types = [chunk.type for chunk in chunks] - - # Should have function call or MCP arguments delta/done events for tool calls - delta_events = [ - chunk - for chunk in chunks - if chunk.type in ["response.function_call_arguments.delta", "response.mcp_call.arguments.delta"] - ] - done_events = [ - chunk - for chunk in chunks - if chunk.type in ["response.function_call_arguments.done", "response.mcp_call.arguments.done"] - ] - - # Should have output item events for tool calls - item_added_events = [chunk for chunk in chunks if chunk.type == "response.output_item.added"] - item_done_events = [chunk for chunk in chunks if chunk.type == "response.output_item.done"] - - # Should have tool execution progress events - mcp_in_progress_events = [chunk for chunk in chunks if chunk.type == "response.mcp_call.in_progress"] - mcp_completed_events = [chunk for chunk in chunks if chunk.type == "response.mcp_call.completed"] - - # Should have MCP list tools streaming events - mcp_list_tools_in_progress_events = [ - chunk for chunk in chunks if chunk.type == "response.mcp_list_tools.in_progress" - ] - mcp_list_tools_completed_events = [ - chunk for chunk in chunks if chunk.type == "response.mcp_list_tools.completed" - ] - - # Verify we have substantial streaming activity (not just batch events) - assert len(chunks) > 10, f"Expected rich streaming with many events, got only {len(chunks)} chunks" - - # Since this test involves MCP tool calls, we should see streaming events - assert len(delta_events) > 0, ( - f"Expected function_call_arguments.delta or mcp_call.arguments.delta events, got chunk types: {chunk_types}" - ) - assert len(done_events) > 0, ( - f"Expected function_call_arguments.done or mcp_call.arguments.done events, got chunk types: {chunk_types}" - ) - - # Should have output item events for function calls - assert len(item_added_events) > 0, f"Expected response.output_item.added events, got chunk types: {chunk_types}" - assert len(item_done_events) > 0, f"Expected response.output_item.done events, got chunk types: {chunk_types}" - - # Should have tool execution progress events - assert len(mcp_in_progress_events) > 0, ( - f"Expected response.mcp_call.in_progress events, got chunk types: {chunk_types}" - ) - assert len(mcp_completed_events) > 0, ( - f"Expected response.mcp_call.completed events, got chunk types: {chunk_types}" - ) - - # Should have MCP list tools streaming events - assert len(mcp_list_tools_in_progress_events) > 0, ( - f"Expected response.mcp_list_tools.in_progress events, got chunk types: {chunk_types}" - ) - assert len(mcp_list_tools_completed_events) > 0, ( - f"Expected response.mcp_list_tools.completed events, got chunk types: {chunk_types}" - ) - # MCP failed events are optional (only if errors occur) - - # Verify progress events have proper structure - for progress_event in mcp_in_progress_events: - assert hasattr(progress_event, "item_id"), "Progress event should have 'item_id' field" - assert hasattr(progress_event, "output_index"), "Progress event should have 'output_index' field" - assert hasattr(progress_event, "sequence_number"), "Progress event should have 'sequence_number' field" - - for completed_event in mcp_completed_events: - assert hasattr(completed_event, "sequence_number"), "Completed event should have 'sequence_number' field" - - # Verify MCP 
list tools events have proper structure - for list_tools_progress_event in mcp_list_tools_in_progress_events: - assert hasattr(list_tools_progress_event, "sequence_number"), ( - "MCP list tools progress event should have 'sequence_number' field" - ) - - for list_tools_completed_event in mcp_list_tools_completed_events: - assert hasattr(list_tools_completed_event, "sequence_number"), ( - "MCP list tools completed event should have 'sequence_number' field" - ) - - # Verify delta events have proper structure - for delta_event in delta_events: - assert hasattr(delta_event, "delta"), "Delta event should have 'delta' field" - assert hasattr(delta_event, "item_id"), "Delta event should have 'item_id' field" - assert hasattr(delta_event, "sequence_number"), "Delta event should have 'sequence_number' field" - assert delta_event.delta, "Delta should not be empty" - - # Verify done events have proper structure - for done_event in done_events: - assert hasattr(done_event, "arguments"), "Done event should have 'arguments' field" - assert hasattr(done_event, "item_id"), "Done event should have 'item_id' field" - assert done_event.arguments, "Final arguments should not be empty" - - # Verify output item added events have proper structure - for added_event in item_added_events: - assert hasattr(added_event, "item"), "Added event should have 'item' field" - assert hasattr(added_event, "output_index"), "Added event should have 'output_index' field" - assert hasattr(added_event, "sequence_number"), "Added event should have 'sequence_number' field" - assert hasattr(added_event, "response_id"), "Added event should have 'response_id' field" - assert added_event.item.type in ["function_call", "mcp_call", "mcp_list_tools"], ( - "Added item should be a tool call or MCP list tools" - ) - if added_event.item.type in ["function_call", "mcp_call"]: - assert added_event.item.status == "in_progress", "Added tool call should be in progress" - # Note: mcp_list_tools doesn't have a status field, it's implicitly completed when added - assert added_event.response_id, "Response ID should not be empty" - assert isinstance(added_event.output_index, int), "Output index should be integer" - assert added_event.output_index >= 0, "Output index should be non-negative" - - # Verify output item done events have proper structure - for done_event in item_done_events: - assert hasattr(done_event, "item"), "Done event should have 'item' field" - assert hasattr(done_event, "output_index"), "Done event should have 'output_index' field" - assert hasattr(done_event, "sequence_number"), "Done event should have 'sequence_number' field" - assert hasattr(done_event, "response_id"), "Done event should have 'response_id' field" - assert done_event.item.type in ["function_call", "mcp_call", "mcp_list_tools"], ( - "Done item should be a tool call or MCP list tools" - ) - # Note: MCP calls and mcp_list_tools don't have a status field, only function calls do - if done_event.item.type == "function_call": - assert done_event.item.status == "completed", "Function call should be completed" - # Note: mcp_call and mcp_list_tools don't have status fields - assert done_event.response_id, "Response ID should not be empty" - assert isinstance(done_event.output_index, int), "Output index should be integer" - assert done_event.output_index >= 0, "Output index should be non-negative" - - # Group function call and MCP argument events by item_id (these should have proper tracking) - argument_events_by_item_id = {} - for chunk in chunks: - if hasattr(chunk, "item_id") 
and chunk.type in [ - "response.function_call_arguments.delta", - "response.function_call_arguments.done", - "response.mcp_call.arguments.delta", - "response.mcp_call.arguments.done", - ]: - item_id = chunk.item_id - if item_id not in argument_events_by_item_id: - argument_events_by_item_id[item_id] = [] - argument_events_by_item_id[item_id].append(chunk) - - for item_id, related_events in argument_events_by_item_id.items(): - # Should have at least one delta and one done event for a complete tool call - delta_events = [ - e - for e in related_events - if e.type in ["response.function_call_arguments.delta", "response.mcp_call.arguments.delta"] - ] - done_events = [ - e - for e in related_events - if e.type in ["response.function_call_arguments.done", "response.mcp_call.arguments.done"] - ] - - assert len(delta_events) > 0, f"Item {item_id} should have at least one delta event" - assert len(done_events) == 1, f"Item {item_id} should have exactly one done event" - - # Verify all events have the same item_id - for event in related_events: - assert event.item_id == item_id, f"Event should have consistent item_id {item_id}, got {event.item_id}" - - # Verify content part events if they exist (for text streaming) - content_part_added_events = [chunk for chunk in chunks if chunk.type == "response.content_part.added"] - content_part_done_events = [chunk for chunk in chunks if chunk.type == "response.content_part.done"] - - # Content part events should be paired (if any exist) - if len(content_part_added_events) > 0: - assert len(content_part_done_events) > 0, ( - "Should have content_part.done events if content_part.added events exist" - ) - - # Verify content part event structure - for added_event in content_part_added_events: - assert hasattr(added_event, "response_id"), "Content part added event should have response_id" - assert hasattr(added_event, "item_id"), "Content part added event should have item_id" - assert hasattr(added_event, "part"), "Content part added event should have part" - - # TODO: enable this after the client types are updated - # assert added_event.part.type == "output_text", "Content part should be an output_text" - - for done_event in content_part_done_events: - assert hasattr(done_event, "response_id"), "Content part done event should have response_id" - assert hasattr(done_event, "item_id"), "Content part done event should have item_id" - assert hasattr(done_event, "part"), "Content part done event should have part" - - # TODO: enable this after the client types are updated - # assert len(done_event.part.text) > 0, "Content part should have text when done" - - # Basic pairing check: each output_item.added should be followed by some activity - # (but we can't enforce strict 1:1 pairing due to the complexity of multi-turn scenarios) - assert len(item_added_events) > 0, "Should have at least one output_item.added event" - - # Verify response_id consistency across all events - response_ids = set() - for chunk in chunks: - if hasattr(chunk, "response_id"): - response_ids.add(chunk.response_id) - elif hasattr(chunk, "response") and hasattr(chunk.response, "id"): - response_ids.add(chunk.response.id) - - assert len(response_ids) == 1, f"All events should reference the same response_id, found: {response_ids}" - - # Get the final response from the last chunk - final_chunk = chunks[-1] - if hasattr(final_chunk, "response"): - final_response = final_chunk.response - - # Verify multi-turn MCP tool execution results - mcp_list_tools = [output for output in final_response.output if 
output.type == "mcp_list_tools"] - mcp_calls = [output for output in final_response.output if output.type == "mcp_call"] - message_outputs = [output for output in final_response.output if output.type == "message"] - - # Should have exactly 1 MCP list tools message (at the beginning) - assert len(mcp_list_tools) == 1, f"Expected exactly 1 mcp_list_tools, got {len(mcp_list_tools)}" - assert mcp_list_tools[0].server_label == "localmcp" - assert len(mcp_list_tools[0].tools) == 5 # Updated for dependency tools - expected_tool_names = { - "get_user_id", - "get_user_permissions", - "check_file_access", - "get_experiment_id", - "get_experiment_results", - } - assert {t.name for t in mcp_list_tools[0].tools} == expected_tool_names - - # Should have at least 1 MCP call (the model should call at least one tool) - assert len(mcp_calls) >= 1, f"Expected at least 1 mcp_call, got {len(mcp_calls)}" - - # All MCP calls should be completed (verifies our tool execution works) - for mcp_call in mcp_calls: - assert mcp_call.error is None, f"MCP call should not have errors, got: {mcp_call.error}" - - # Should have at least one final message response - assert len(message_outputs) >= 1, f"Expected at least 1 message output, got {len(message_outputs)}" - - # Final message should be from assistant and completed - final_message = message_outputs[-1] - assert final_message.role == "assistant", ( - f"Final message should be from assistant, got {final_message.role}" - ) - assert final_message.status == "completed", f"Final message should be completed, got {final_message.status}" - assert len(final_message.content) > 0, "Final message should have content" - - # Check that the expected output appears in the response - expected_output = case["output"] - assert expected_output.lower() in final_response.output_text.lower(), ( - f"Expected '{expected_output}' to appear in response: {final_response.output_text}" - ) - - -@pytest.mark.parametrize( - "text_format", - # Not testing json_object because most providers don't actually support it. 
- [ - {"type": "text"}, - { - "type": "json_schema", - "name": "capitals", - "description": "A schema for the capital of each country", - "schema": {"type": "object", "properties": {"capital": {"type": "string"}}}, - "strict": True, - }, - ], -) -def test_response_text_format(compat_client, text_model_id, text_format): - if isinstance(compat_client, LlamaStackAsLibraryClient): - pytest.skip("Responses API text format is not yet supported in library client.") - - stream = False - response = compat_client.responses.create( - model=text_model_id, - input="What is the capital of France?", - stream=stream, - text={"format": text_format}, - ) - # by_alias=True is needed because otherwise Pydantic renames our "schema" field - assert response.text.format.model_dump(exclude_none=True, by_alias=True) == text_format - assert "paris" in response.output_text.lower() - if text_format["type"] == "json_schema": - assert "paris" in json.loads(response.output_text)["capital"].lower() - - -@pytest.fixture -def vector_store_with_filtered_files(compat_client, text_model_id, tmp_path_factory): - """Create a vector store with multiple files that have different attributes for filtering tests.""" - if isinstance(compat_client, LlamaStackAsLibraryClient): - pytest.skip("Responses API file search is not yet supported in library client.") - - vector_store = _new_vector_store(compat_client, "test_vector_store_with_filters") - tmp_path = tmp_path_factory.mktemp("filter_test_files") - - # Create multiple files with different attributes - files_data = [ - { - "name": "us_marketing_q1.txt", - "content": "US promotional campaigns for Q1 2023. Revenue increased by 15% in the US region.", - "attributes": { - "region": "us", - "category": "marketing", - "date": 1672531200, # Jan 1, 2023 - }, - }, - { - "name": "us_engineering_q2.txt", - "content": "US technical updates for Q2 2023. New features deployed in the US region.", - "attributes": { - "region": "us", - "category": "engineering", - "date": 1680307200, # Apr 1, 2023 - }, - }, - { - "name": "eu_marketing_q1.txt", - "content": "European advertising campaign results for Q1 2023. Strong growth in EU markets.", - "attributes": { - "region": "eu", - "category": "marketing", - "date": 1672531200, # Jan 1, 2023 - }, - }, - { - "name": "asia_sales_q3.txt", - "content": "Asia Pacific revenue figures for Q3 2023. 
Record breaking quarter in Asia.", - "attributes": { - "region": "asia", - "category": "sales", - "date": 1688169600, # Jul 1, 2023 - }, - }, - ] - - file_ids = [] - for file_data in files_data: - # Create file - file_path = tmp_path / file_data["name"] - file_path.write_text(file_data["content"]) - - # Upload file - file_response = _upload_file(compat_client, file_data["name"], str(file_path)) - file_ids.append(file_response.id) - - # Attach file to vector store with attributes - file_attach_response = compat_client.vector_stores.files.create( - vector_store_id=vector_store.id, - file_id=file_response.id, - attributes=file_data["attributes"], - ) - - # Wait for attachment - while file_attach_response.status == "in_progress": - time.sleep(0.1) - file_attach_response = compat_client.vector_stores.files.retrieve( - vector_store_id=vector_store.id, - file_id=file_response.id, - ) - assert file_attach_response.status == "completed" - - yield vector_store - - # Cleanup: delete vector store and files - try: - compat_client.vector_stores.delete(vector_store_id=vector_store.id) - for file_id in file_ids: - try: - compat_client.files.delete(file_id=file_id) - except Exception: - pass # File might already be deleted - except Exception: - pass # Best effort cleanup - - -def test_response_file_search_filter_by_region(compat_client, text_model_id, vector_store_with_filtered_files): - """Test file search with region equality filter.""" - tools = [ - { - "type": "file_search", - "vector_store_ids": [vector_store_with_filtered_files.id], - "filters": {"type": "eq", "key": "region", "value": "us"}, - } - ] - - response = compat_client.responses.create( - model=text_model_id, - input="What are the updates from the US region?", - tools=tools, - stream=False, - include=["file_search_call.results"], - ) - - # Verify file search was called with US filter - assert len(response.output) > 1 - assert response.output[0].type == "file_search_call" - assert response.output[0].status == "completed" - assert response.output[0].results - # Should only return US files (not EU or Asia files) - for result in response.output[0].results: - assert "us" in result.text.lower() or "US" in result.text - # Ensure non-US regions are NOT returned - assert "european" not in result.text.lower() - assert "asia" not in result.text.lower() - - -def test_response_file_search_filter_by_category(compat_client, text_model_id, vector_store_with_filtered_files): - """Test file search with category equality filter.""" - tools = [ - { - "type": "file_search", - "vector_store_ids": [vector_store_with_filtered_files.id], - "filters": {"type": "eq", "key": "category", "value": "marketing"}, - } - ] - - response = compat_client.responses.create( - model=text_model_id, - input="Show me all marketing reports", - tools=tools, - stream=False, - include=["file_search_call.results"], - ) - - assert response.output[0].type == "file_search_call" - assert response.output[0].status == "completed" - assert response.output[0].results - # Should only return marketing files (not engineering or sales) - for result in response.output[0].results: - # Marketing files should have promotional/advertising content - assert "promotional" in result.text.lower() or "advertising" in result.text.lower() - # Ensure non-marketing categories are NOT returned - assert "technical" not in result.text.lower() - assert "revenue figures" not in result.text.lower() - - -def test_response_file_search_filter_by_date_range(compat_client, text_model_id, vector_store_with_filtered_files): - 
"""Test file search with date range filter using compound AND.""" - tools = [ - { - "type": "file_search", - "vector_store_ids": [vector_store_with_filtered_files.id], - "filters": { - "type": "and", - "filters": [ - { - "type": "gte", - "key": "date", - "value": 1672531200, # Jan 1, 2023 - }, - { - "type": "lt", - "key": "date", - "value": 1680307200, # Apr 1, 2023 - }, - ], - }, - } - ] - - response = compat_client.responses.create( - model=text_model_id, - input="What happened in Q1 2023?", - tools=tools, - stream=False, - include=["file_search_call.results"], - ) - - assert response.output[0].type == "file_search_call" - assert response.output[0].status == "completed" - assert response.output[0].results - # Should only return Q1 files (not Q2 or Q3) - for result in response.output[0].results: - assert "q1" in result.text.lower() - # Ensure non-Q1 quarters are NOT returned - assert "q2" not in result.text.lower() - assert "q3" not in result.text.lower() - - -def test_response_file_search_filter_compound_and(compat_client, text_model_id, vector_store_with_filtered_files): - """Test file search with compound AND filter (region AND category).""" - tools = [ - { - "type": "file_search", - "vector_store_ids": [vector_store_with_filtered_files.id], - "filters": { - "type": "and", - "filters": [ - {"type": "eq", "key": "region", "value": "us"}, - {"type": "eq", "key": "category", "value": "engineering"}, - ], - }, - } - ] - - response = compat_client.responses.create( - model=text_model_id, - input="What are the engineering updates from the US?", - tools=tools, - stream=False, - include=["file_search_call.results"], - ) - - assert response.output[0].type == "file_search_call" - assert response.output[0].status == "completed" - assert response.output[0].results - # Should only return US engineering files - assert len(response.output[0].results) >= 1 - for result in response.output[0].results: - assert "us" in result.text.lower() and "technical" in result.text.lower() - # Ensure it's not from other regions or categories - assert "european" not in result.text.lower() and "asia" not in result.text.lower() - assert "promotional" not in result.text.lower() and "revenue" not in result.text.lower() - - -def test_response_file_search_filter_compound_or(compat_client, text_model_id, vector_store_with_filtered_files): - """Test file search with compound OR filter (marketing OR sales).""" - tools = [ - { - "type": "file_search", - "vector_store_ids": [vector_store_with_filtered_files.id], - "filters": { - "type": "or", - "filters": [ - {"type": "eq", "key": "category", "value": "marketing"}, - {"type": "eq", "key": "category", "value": "sales"}, - ], - }, - } - ] - - response = compat_client.responses.create( - model=text_model_id, - input="Show me marketing and sales documents", - tools=tools, - stream=False, - include=["file_search_call.results"], - ) - - assert response.output[0].type == "file_search_call" - assert response.output[0].status == "completed" - assert response.output[0].results - # Should return marketing and sales files, but NOT engineering - categories_found = set() - for result in response.output[0].results: - text_lower = result.text.lower() - if "promotional" in text_lower or "advertising" in text_lower: - categories_found.add("marketing") - if "revenue figures" in text_lower: - categories_found.add("sales") - # Ensure engineering files are NOT returned - assert "technical" not in text_lower, f"Engineering file should not be returned, but got: {result.text}" - - # Verify we got at 
least one of the expected categories - assert len(categories_found) > 0, "Should have found at least one marketing or sales file" - assert categories_found.issubset({"marketing", "sales"}), f"Found unexpected categories: {categories_found}" diff --git a/tests/integration/non_ci/responses/test_tool_responses.py b/tests/integration/non_ci/responses/test_tool_responses.py new file mode 100644 index 000000000..33d109863 --- /dev/null +++ b/tests/integration/non_ci/responses/test_tool_responses.py @@ -0,0 +1,335 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import json +import os + +import httpx +import openai +import pytest +from fixtures.test_cases import ( + custom_tool_test_cases, + file_search_test_cases, + mcp_tool_test_cases, + multi_turn_tool_execution_streaming_test_cases, + multi_turn_tool_execution_test_cases, + web_search_test_cases, +) +from helpers import new_vector_store, setup_mcp_tools, upload_file, wait_for_file_attachment +from streaming_assertions import StreamingValidator + +from llama_stack import LlamaStackAsLibraryClient +from llama_stack.core.datatypes import AuthenticationRequiredError +from tests.common.mcp import dependency_tools, make_mcp_server + + +@pytest.mark.parametrize("case", web_search_test_cases) +def test_response_non_streaming_web_search(compat_client, text_model_id, case): + response = compat_client.responses.create( + model=text_model_id, + input=case.input, + tools=case.tools, + stream=False, + ) + assert len(response.output) > 1 + assert response.output[0].type == "web_search_call" + assert response.output[0].status == "completed" + assert response.output[1].type == "message" + assert response.output[1].status == "completed" + assert response.output[1].role == "assistant" + assert len(response.output[1].content) > 0 + assert case.expected.lower() in response.output_text.lower().strip() + + +@pytest.mark.parametrize("case", file_search_test_cases) +def test_response_non_streaming_file_search(compat_client, text_model_id, tmp_path, case): + if isinstance(compat_client, LlamaStackAsLibraryClient): + pytest.skip("Responses API file search is not yet supported in library client.") + + vector_store = new_vector_store(compat_client, "test_vector_store") + + if case.file_content: + file_name = "test_response_non_streaming_file_search.txt" + file_path = tmp_path / file_name + file_path.write_text(case.file_content) + elif case.file_path: + file_path = os.path.join(os.path.dirname(__file__), "fixtures", case.file_path) + file_name = os.path.basename(file_path) + else: + raise ValueError("No file content or path provided for case") + + file_response = upload_file(compat_client, file_name, file_path) + + # Attach our file to the vector store + compat_client.vector_stores.files.create( + vector_store_id=vector_store.id, + file_id=file_response.id, + ) + + # Wait for the file to be attached + wait_for_file_attachment(compat_client, vector_store.id, file_response.id) + + # Update our tools with the right vector store id + tools = case.tools + for tool in tools: + if tool["type"] == "file_search": + tool["vector_store_ids"] = [vector_store.id] + + # Create the response request, which should query our vector store + response = compat_client.responses.create( + model=text_model_id, + input=case.input, + tools=tools, + stream=False, + include=["file_search_call.results"], + ) + + # Verify the 
file_search_tool was called + assert len(response.output) > 1 + assert response.output[0].type == "file_search_call" + assert response.output[0].status == "completed" + assert response.output[0].queries # ensure it's some non-empty list + assert response.output[0].results + assert case.expected.lower() in response.output[0].results[0].text.lower() + assert response.output[0].results[0].score > 0 + + # Verify the output_text generated by the response + assert case.expected.lower() in response.output_text.lower().strip() + + +def test_response_non_streaming_file_search_empty_vector_store(compat_client, text_model_id): + if isinstance(compat_client, LlamaStackAsLibraryClient): + pytest.skip("Responses API file search is not yet supported in library client.") + + vector_store = new_vector_store(compat_client, "test_vector_store") + + # Create the response request, which should query our vector store + response = compat_client.responses.create( + model=text_model_id, + input="How many experts does the Llama 4 Maverick model have?", + tools=[{"type": "file_search", "vector_store_ids": [vector_store.id]}], + stream=False, + include=["file_search_call.results"], + ) + + # Verify the file_search_tool was called + assert len(response.output) > 1 + assert response.output[0].type == "file_search_call" + assert response.output[0].status == "completed" + assert response.output[0].queries # ensure it's some non-empty list + assert not response.output[0].results # ensure we don't get any results + + # Verify some output_text was generated by the response + assert response.output_text + + +@pytest.mark.parametrize("case", mcp_tool_test_cases) +def test_response_non_streaming_mcp_tool(compat_client, text_model_id, case): + if not isinstance(compat_client, LlamaStackAsLibraryClient): + pytest.skip("in-process MCP server is only supported in library client") + + with make_mcp_server() as mcp_server_info: + tools = setup_mcp_tools(case.tools, mcp_server_info) + + response = compat_client.responses.create( + model=text_model_id, + input=case.input, + tools=tools, + stream=False, + ) + + assert len(response.output) >= 3 + list_tools = response.output[0] + assert list_tools.type == "mcp_list_tools" + assert list_tools.server_label == "localmcp" + assert len(list_tools.tools) == 2 + assert {t.name for t in list_tools.tools} == { + "get_boiling_point", + "greet_everyone", + } + + call = response.output[1] + assert call.type == "mcp_call" + assert call.name == "get_boiling_point" + assert json.loads(call.arguments) == { + "liquid_name": "myawesomeliquid", + "celsius": True, + } + assert call.error is None + assert "-100" in call.output + + # sometimes the model will call the tool again, so we need to get the last message + message = response.output[-1] + text_content = message.content[0].text + assert "boiling point" in text_content.lower() + + with make_mcp_server(required_auth_token="test-token") as mcp_server_info: + tools = setup_mcp_tools(case.tools, mcp_server_info) + + exc_type = ( + AuthenticationRequiredError + if isinstance(compat_client, LlamaStackAsLibraryClient) + else (httpx.HTTPStatusError, openai.AuthenticationError) + ) + with pytest.raises(exc_type): + compat_client.responses.create( + model=text_model_id, + input=case.input, + tools=tools, + stream=False, + ) + + for tool in tools: + if tool["type"] == "mcp": + tool["headers"] = {"Authorization": "Bearer test-token"} + + response = compat_client.responses.create( + model=text_model_id, + input=case.input, + tools=tools, + stream=False, + ) + 
assert len(response.output) >= 3 + + +@pytest.mark.parametrize("case", custom_tool_test_cases) +def test_response_non_streaming_custom_tool(compat_client, text_model_id, case): + response = compat_client.responses.create( + model=text_model_id, + input=case.input, + tools=case.tools, + stream=False, + ) + assert len(response.output) == 1 + assert response.output[0].type == "function_call" + assert response.output[0].status == "completed" + assert response.output[0].name == "get_weather" + + +@pytest.mark.parametrize("case", multi_turn_tool_execution_test_cases) +def test_response_non_streaming_multi_turn_tool_execution(compat_client, text_model_id, case): + """Test multi-turn tool execution where multiple MCP tool calls are performed in sequence.""" + if not isinstance(compat_client, LlamaStackAsLibraryClient): + pytest.skip("in-process MCP server is only supported in library client") + + with make_mcp_server(tools=dependency_tools()) as mcp_server_info: + tools = setup_mcp_tools(case.tools, mcp_server_info) + + response = compat_client.responses.create( + input=case.input, + model=text_model_id, + tools=tools, + ) + + # Verify we have MCP tool calls in the output + mcp_list_tools = [output for output in response.output if output.type == "mcp_list_tools"] + mcp_calls = [output for output in response.output if output.type == "mcp_call"] + message_outputs = [output for output in response.output if output.type == "message"] + + # Should have exactly 1 MCP list tools message (at the beginning) + assert len(mcp_list_tools) == 1, f"Expected exactly 1 mcp_list_tools, got {len(mcp_list_tools)}" + assert mcp_list_tools[0].server_label == "localmcp" + assert len(mcp_list_tools[0].tools) == 5 # Updated for dependency tools + expected_tool_names = { + "get_user_id", + "get_user_permissions", + "check_file_access", + "get_experiment_id", + "get_experiment_results", + } + assert {t.name for t in mcp_list_tools[0].tools} == expected_tool_names + + assert len(mcp_calls) >= 1, f"Expected at least 1 mcp_call, got {len(mcp_calls)}" + for mcp_call in mcp_calls: + assert mcp_call.error is None, f"MCP call should not have errors, got: {mcp_call.error}" + + assert len(message_outputs) >= 1, f"Expected at least 1 message output, got {len(message_outputs)}" + + final_message = message_outputs[-1] + assert final_message.role == "assistant", f"Final message should be from assistant, got {final_message.role}" + assert final_message.status == "completed", f"Final message should be completed, got {final_message.status}" + assert len(final_message.content) > 0, "Final message should have content" + + expected_output = case.expected + assert expected_output.lower() in response.output_text.lower(), ( + f"Expected '{expected_output}' to appear in response: {response.output_text}" + ) + + +@pytest.mark.parametrize("case", multi_turn_tool_execution_streaming_test_cases) +def test_response_streaming_multi_turn_tool_execution(compat_client, text_model_id, case): + """Test streaming multi-turn tool execution where multiple MCP tool calls are performed in sequence.""" + if not isinstance(compat_client, LlamaStackAsLibraryClient): + pytest.skip("in-process MCP server is only supported in library client") + + with make_mcp_server(tools=dependency_tools()) as mcp_server_info: + tools = setup_mcp_tools(case.tools, mcp_server_info) + + stream = compat_client.responses.create( + input=case.input, + model=text_model_id, + tools=tools, + stream=True, + ) + + chunks = [] + for chunk in stream: + chunks.append(chunk) + + # Use validator 
for common streaming checks + validator = StreamingValidator(chunks) + validator.assert_basic_event_sequence() + validator.assert_response_consistency() + validator.assert_has_tool_calls() + validator.assert_has_mcp_events() + validator.assert_rich_streaming() + + # Get the final response from the last chunk + final_chunk = chunks[-1] + if hasattr(final_chunk, "response"): + final_response = final_chunk.response + + # Verify multi-turn MCP tool execution results + mcp_list_tools = [output for output in final_response.output if output.type == "mcp_list_tools"] + mcp_calls = [output for output in final_response.output if output.type == "mcp_call"] + message_outputs = [output for output in final_response.output if output.type == "message"] + + # Should have exactly 1 MCP list tools message (at the beginning) + assert len(mcp_list_tools) == 1, f"Expected exactly 1 mcp_list_tools, got {len(mcp_list_tools)}" + assert mcp_list_tools[0].server_label == "localmcp" + assert len(mcp_list_tools[0].tools) == 5 # Updated for dependency tools + expected_tool_names = { + "get_user_id", + "get_user_permissions", + "check_file_access", + "get_experiment_id", + "get_experiment_results", + } + assert {t.name for t in mcp_list_tools[0].tools} == expected_tool_names + + # Should have at least 1 MCP call (the model should call at least one tool) + assert len(mcp_calls) >= 1, f"Expected at least 1 mcp_call, got {len(mcp_calls)}" + + # All MCP calls should be completed (verifies our tool execution works) + for mcp_call in mcp_calls: + assert mcp_call.error is None, f"MCP call should not have errors, got: {mcp_call.error}" + + # Should have at least one final message response + assert len(message_outputs) >= 1, f"Expected at least 1 message output, got {len(message_outputs)}" + + # Final message should be from assistant and completed + final_message = message_outputs[-1] + assert final_message.role == "assistant", ( + f"Final message should be from assistant, got {final_message.role}" + ) + assert final_message.status == "completed", f"Final message should be completed, got {final_message.status}" + assert len(final_message.content) > 0, "Final message should have content" + + # Check that the expected output appears in the response + expected_output = case.expected + assert expected_output.lower() in final_response.output_text.lower(), ( + f"Expected '{expected_output}' to appear in response: {final_response.output_text}" + )
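
For reviewers who don't want to cross-reference `helpers.py` while reading the new test files: the tests above lean on a small set of shared helpers (`new_vector_store`, `upload_file`, `wait_for_file_attachment`, `setup_mcp_tools`). Below is a minimal sketch of what those helpers amount to, reconstructed from the inline logic removed from `test_responses.py`; the actual `helpers.py` added by this PR may differ in its details.

```python
# Sketch only -- reconstructed from the inline logic deleted from test_responses.py;
# the real tests/integration/non_ci/responses/helpers.py may differ.
import time


def new_vector_store(openai_client, name):
    """Delete any existing vector store with this name, then create a fresh one."""
    for vector_store in openai_client.vector_stores.list():
        if vector_store.name == name:
            openai_client.vector_stores.delete(vector_store_id=vector_store.id)
    return openai_client.vector_stores.create(name=name)


def upload_file(openai_client, name, file_path):
    """Delete any existing file with this filename, then upload it for assistants use."""
    for file in openai_client.files.list():
        if file.filename == name:
            openai_client.files.delete(file_id=file.id)
    with open(file_path, "rb") as f:
        return openai_client.files.create(file=f, purpose="assistants")


def wait_for_file_attachment(compat_client, vector_store_id, file_id):
    """Poll the vector store file until it is no longer 'in_progress'."""
    file_status = compat_client.vector_stores.files.retrieve(
        vector_store_id=vector_store_id, file_id=file_id
    )
    while file_status.status == "in_progress":
        time.sleep(0.1)
        file_status = compat_client.vector_stores.files.retrieve(
            vector_store_id=vector_store_id, file_id=file_id
        )
    assert file_status.status == "completed", f"Expected file to be attached, got {file_status}"
    assert not file_status.last_error
    return file_status


def setup_mcp_tools(tools, mcp_server_info):
    """Point MCP tool definitions at the in-process MCP server started by the test."""
    for tool in tools:
        if tool["type"] == "mcp":
            tool["server_url"] = mcp_server_info["server_url"]
    return tools
```

`StreamingValidator` plays the same role for the streaming tests: it bundles the checks the old monolithic test performed by hand (response.created first, response.completed last, consistent response IDs, argument delta/done pairs, and MCP progress and list-tools events) behind the `assert_*` methods called above.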