# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. import copy import json import re from typing import Any import pytest from pydantic import BaseModel from tests.verifications.openai_api.fixtures.fixtures import ( _load_all_verification_configs, ) from tests.verifications.openai_api.fixtures.load import load_test_cases chat_completion_test_cases = load_test_cases("chat_completion") def case_id_generator(case): """Generate a test ID from the case's 'case_id' field, or use a default.""" case_id = case.get("case_id") if isinstance(case_id, (str, int)): return re.sub(r"\\W|^(?=\\d)", "_", str(case_id)) return None def pytest_generate_tests(metafunc): """Dynamically parametrize tests based on the selected provider and config.""" if "model" in metafunc.fixturenames: provider = metafunc.config.getoption("provider") if not provider: print("Warning: --provider not specified. Skipping model parametrization.") metafunc.parametrize("model", []) return try: config_data = _load_all_verification_configs() except (FileNotFoundError, IOError) as e: print(f"ERROR loading verification configs: {e}") config_data = {"providers": {}} provider_config = config_data.get("providers", {}).get(provider) if provider_config: models = provider_config.get("models", []) if models: metafunc.parametrize("model", models) else: print(f"Warning: No models found for provider '{provider}' in config.") metafunc.parametrize("model", []) # Parametrize empty if no models found else: print(f"Warning: Provider '{provider}' not found in config. No models parametrized.") metafunc.parametrize("model", []) # Parametrize empty if provider not found def should_skip_test(verification_config, provider, model, test_name_base): """Check if a test should be skipped based on config exclusions.""" provider_config = verification_config.get("providers", {}).get(provider) if not provider_config: return False # No config for provider, don't skip exclusions = provider_config.get("test_exclusions", {}).get(model, []) return test_name_base in exclusions # Helper to get the base test name from the request object def get_base_test_name(request): return request.node.originalname # --- Test Functions --- @pytest.mark.parametrize( "case", chat_completion_test_cases["test_chat_basic"]["test_params"]["case"], ids=case_id_generator, ) def test_chat_non_streaming_basic(request, openai_client, model, provider, verification_config, case): test_name_base = get_base_test_name(request) if should_skip_test(verification_config, provider, model, test_name_base): pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.") response = openai_client.chat.completions.create( model=model, messages=case["input"]["messages"], stream=False, ) assert response.choices[0].message.role == "assistant" assert case["output"].lower() in response.choices[0].message.content.lower() @pytest.mark.parametrize( "case", chat_completion_test_cases["test_chat_basic"]["test_params"]["case"], ids=case_id_generator, ) def test_chat_streaming_basic(request, openai_client, model, provider, verification_config, case): test_name_base = get_base_test_name(request) if should_skip_test(verification_config, provider, model, test_name_base): pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.") response = openai_client.chat.completions.create( model=model, messages=case["input"]["messages"], stream=True, ) content = "" for chunk in response: content += chunk.choices[0].delta.content or "" # TODO: add detailed type validation assert case["output"].lower() in content.lower() @pytest.mark.parametrize( "case", chat_completion_test_cases["test_chat_image"]["test_params"]["case"], ids=case_id_generator, ) def test_chat_non_streaming_image(request, openai_client, model, provider, verification_config, case): test_name_base = get_base_test_name(request) if should_skip_test(verification_config, provider, model, test_name_base): pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.") response = openai_client.chat.completions.create( model=model, messages=case["input"]["messages"], stream=False, ) assert response.choices[0].message.role == "assistant" assert case["output"].lower() in response.choices[0].message.content.lower() @pytest.mark.parametrize( "case", chat_completion_test_cases["test_chat_image"]["test_params"]["case"], ids=case_id_generator, ) def test_chat_streaming_image(request, openai_client, model, provider, verification_config, case): test_name_base = get_base_test_name(request) if should_skip_test(verification_config, provider, model, test_name_base): pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.") response = openai_client.chat.completions.create( model=model, messages=case["input"]["messages"], stream=True, ) content = "" for chunk in response: content += chunk.choices[0].delta.content or "" # TODO: add detailed type validation assert case["output"].lower() in content.lower() @pytest.mark.parametrize( "case", chat_completion_test_cases["test_chat_structured_output"]["test_params"]["case"], ids=case_id_generator, ) def test_chat_non_streaming_structured_output(request, openai_client, model, provider, verification_config, case): test_name_base = get_base_test_name(request) if should_skip_test(verification_config, provider, model, test_name_base): pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.") response = openai_client.chat.completions.create( model=model, messages=case["input"]["messages"], response_format=case["input"]["response_format"], stream=False, ) assert response.choices[0].message.role == "assistant" maybe_json_content = response.choices[0].message.content validate_structured_output(maybe_json_content, case["output"]) @pytest.mark.parametrize( "case", chat_completion_test_cases["test_chat_structured_output"]["test_params"]["case"], ids=case_id_generator, ) def test_chat_streaming_structured_output(request, openai_client, model, provider, verification_config, case): test_name_base = get_base_test_name(request) if should_skip_test(verification_config, provider, model, test_name_base): pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.") response = openai_client.chat.completions.create( model=model, messages=case["input"]["messages"], response_format=case["input"]["response_format"], stream=True, ) maybe_json_content = "" for chunk in response: maybe_json_content += chunk.choices[0].delta.content or "" validate_structured_output(maybe_json_content, case["output"]) @pytest.mark.parametrize( "case", chat_completion_test_cases["test_tool_calling"]["test_params"]["case"], ids=case_id_generator, ) def test_chat_non_streaming_tool_calling(request, openai_client, model, provider, verification_config, case): test_name_base = get_base_test_name(request) if should_skip_test(verification_config, provider, model, test_name_base): pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.") response = openai_client.chat.completions.create( model=model, messages=case["input"]["messages"], tools=case["input"]["tools"], stream=False, ) assert response.choices[0].message.role == "assistant" assert len(response.choices[0].message.tool_calls) > 0 assert case["output"] == "get_weather_tool_call" assert response.choices[0].message.tool_calls[0].function.name == "get_weather" # TODO: add detailed type validation @pytest.mark.parametrize( "case", chat_completion_test_cases["test_tool_calling"]["test_params"]["case"], ids=case_id_generator, ) def test_chat_streaming_tool_calling(request, openai_client, model, provider, verification_config, case): test_name_base = get_base_test_name(request) if should_skip_test(verification_config, provider, model, test_name_base): pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.") stream = openai_client.chat.completions.create( model=model, messages=case["input"]["messages"], tools=case["input"]["tools"], stream=True, ) _, tool_calls_buffer = _accumulate_streaming_tool_calls(stream) assert len(tool_calls_buffer) == 1 for call in tool_calls_buffer: assert len(call["id"]) > 0 function = call["function"] assert function["name"] == "get_weather" args_dict = json.loads(function["arguments"]) assert "san francisco" in args_dict["location"].lower() @pytest.mark.parametrize( "case", chat_completion_test_cases["test_tool_calling"]["test_params"]["case"], # Reusing existing case for now ids=case_id_generator, ) def test_chat_non_streaming_tool_choice_required(request, openai_client, model, provider, verification_config, case): test_name_base = get_base_test_name(request) if should_skip_test(verification_config, provider, model, test_name_base): pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.") response = openai_client.chat.completions.create( model=model, messages=case["input"]["messages"], tools=case["input"]["tools"], tool_choice="required", # Force tool call stream=False, ) assert response.choices[0].message.role == "assistant" assert len(response.choices[0].message.tool_calls) > 0, "Expected tool call when tool_choice='required'" expected_tool_name = case["input"]["tools"][0]["function"]["name"] assert response.choices[0].message.tool_calls[0].function.name == expected_tool_name @pytest.mark.parametrize( "case", chat_completion_test_cases["test_tool_calling"]["test_params"]["case"], # Reusing existing case for now ids=case_id_generator, ) def test_chat_streaming_tool_choice_required(request, openai_client, model, provider, verification_config, case): test_name_base = get_base_test_name(request) if should_skip_test(verification_config, provider, model, test_name_base): pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.") stream = openai_client.chat.completions.create( model=model, messages=case["input"]["messages"], tools=case["input"]["tools"], tool_choice="required", # Force tool call stream=True, ) _, tool_calls_buffer = _accumulate_streaming_tool_calls(stream) assert len(tool_calls_buffer) > 0, "Expected tool call when tool_choice='required'" expected_tool_name = case["input"]["tools"][0]["function"]["name"] assert any(call["function"]["name"] == expected_tool_name for call in tool_calls_buffer), ( f"Expected tool call '{expected_tool_name}' not found in stream" ) @pytest.mark.parametrize( "case", chat_completion_test_cases["test_tool_calling"]["test_params"]["case"], # Reusing existing case for now ids=case_id_generator, ) def test_chat_non_streaming_tool_choice_none(request, openai_client, model, provider, verification_config, case): test_name_base = get_base_test_name(request) if should_skip_test(verification_config, provider, model, test_name_base): pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.") response = openai_client.chat.completions.create( model=model, messages=case["input"]["messages"], tools=case["input"]["tools"], tool_choice="none", stream=False, ) assert response.choices[0].message.role == "assistant" assert response.choices[0].message.tool_calls is None, "Expected no tool calls when tool_choice='none'" assert response.choices[0].message.content is not None, "Expected content when tool_choice='none'" @pytest.mark.parametrize( "case", chat_completion_test_cases["test_tool_calling"]["test_params"]["case"], # Reusing existing case for now ids=case_id_generator, ) def test_chat_streaming_tool_choice_none(request, openai_client, model, provider, verification_config, case): test_name_base = get_base_test_name(request) if should_skip_test(verification_config, provider, model, test_name_base): pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.") stream = openai_client.chat.completions.create( model=model, messages=case["input"]["messages"], tools=case["input"]["tools"], tool_choice="none", stream=True, ) content = "" for chunk in stream: delta = chunk.choices[0].delta if delta.content: content += delta.content assert not delta.tool_calls, "Expected no tool call chunks when tool_choice='none'" assert len(content) > 0, "Expected content when tool_choice='none'" @pytest.mark.parametrize( "case", chat_completion_test_cases.get("test_chat_multi_turn_tool_calling", {}).get("test_params", {}).get("case", []), ids=case_id_generator, ) def test_chat_non_streaming_multi_turn_tool_calling(request, openai_client, model, provider, verification_config, case): """ Test cases for multi-turn tool calling. Tool calls are asserted. Tool responses are provided in the test case. Final response is asserted. """ test_name_base = get_base_test_name(request) if should_skip_test(verification_config, provider, model, test_name_base): pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.") # Create a copy of the messages list to avoid modifying the original messages = [] tools = case["input"]["tools"] # Use deepcopy to prevent modification across runs/parametrization expected_results = copy.deepcopy(case["expected"]) tool_responses = copy.deepcopy(case.get("tool_responses", [])) input_messages_turns = copy.deepcopy(case["input"]["messages"]) # keep going until either # 1. we have messages to test in multi-turn # 2. no messages but last message is tool response while len(input_messages_turns) > 0 or (len(messages) > 0 and messages[-1]["role"] == "tool"): # do not take new messages if last message is tool response if len(messages) == 0 or messages[-1]["role"] != "tool": new_messages = input_messages_turns.pop(0) # Ensure new_messages is a list of message objects if isinstance(new_messages, list): messages.extend(new_messages) else: # If it's a single message object, add it directly messages.append(new_messages) # --- API Call --- response = openai_client.chat.completions.create( model=model, messages=messages, tools=tools, stream=False, ) # --- Process Response --- assistant_message = response.choices[0].message messages.append(assistant_message.model_dump(exclude_unset=True)) assert assistant_message.role == "assistant" # Get the expected result data expected = expected_results.pop(0) num_tool_calls = expected["num_tool_calls"] # --- Assertions based on expected result --- assert len(assistant_message.tool_calls or []) == num_tool_calls, ( f"Expected {num_tool_calls} tool calls, but got {len(assistant_message.tool_calls or [])}" ) if num_tool_calls > 0: tool_call = assistant_message.tool_calls[0] assert tool_call.function.name == expected["tool_name"], ( f"Expected tool '{expected['tool_name']}', got '{tool_call.function.name}'" ) # Parse the JSON string arguments before comparing actual_arguments = json.loads(tool_call.function.arguments) assert actual_arguments == expected["tool_arguments"], ( f"Expected arguments '{expected['tool_arguments']}', got '{actual_arguments}'" ) # Prepare and append the tool response for the next turn tool_response = tool_responses.pop(0) messages.append( { "role": "tool", "tool_call_id": tool_call.id, "content": tool_response["response"], } ) else: assert assistant_message.content is not None, "Expected content, but none received." expected_answers = expected["answer"] # This is now a list content_lower = assistant_message.content.lower() assert any(ans.lower() in content_lower for ans in expected_answers), ( f"Expected one of {expected_answers} in content, but got: '{assistant_message.content}'" ) @pytest.mark.parametrize( "case", chat_completion_test_cases.get("test_chat_multi_turn_tool_calling", {}).get("test_params", {}).get("case", []), ids=case_id_generator, ) def test_chat_streaming_multi_turn_tool_calling(request, openai_client, model, provider, verification_config, case): """ """ test_name_base = get_base_test_name(request) if should_skip_test(verification_config, provider, model, test_name_base): pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.") messages = [] tools = case["input"]["tools"] expected_results = copy.deepcopy(case["expected"]) tool_responses = copy.deepcopy(case.get("tool_responses", [])) input_messages_turns = copy.deepcopy(case["input"]["messages"]) while len(input_messages_turns) > 0 or (len(messages) > 0 and messages[-1]["role"] == "tool"): if len(messages) == 0 or messages[-1]["role"] != "tool": new_messages = input_messages_turns.pop(0) if isinstance(new_messages, list): messages.extend(new_messages) else: messages.append(new_messages) # --- API Call (Streaming) --- stream = openai_client.chat.completions.create( model=model, messages=messages, tools=tools, stream=True, ) # --- Process Stream --- accumulated_content, accumulated_tool_calls = _accumulate_streaming_tool_calls(stream) # --- Construct Assistant Message for History --- assistant_message_dict = {"role": "assistant"} if accumulated_content: assistant_message_dict["content"] = accumulated_content if accumulated_tool_calls: assistant_message_dict["tool_calls"] = accumulated_tool_calls messages.append(assistant_message_dict) # --- Assertions --- expected = expected_results.pop(0) num_tool_calls = expected["num_tool_calls"] assert len(accumulated_tool_calls or []) == num_tool_calls, ( f"Expected {num_tool_calls} tool calls, but got {len(accumulated_tool_calls or [])}" ) if num_tool_calls > 0: # Use the first accumulated tool call for assertion tool_call = accumulated_tool_calls[0] assert tool_call["function"]["name"] == expected["tool_name"], ( f"Expected tool '{expected['tool_name']}', got '{tool_call['function']['name']}'" ) # Parse the accumulated arguments string for comparison actual_arguments = json.loads(tool_call["function"]["arguments"]) assert actual_arguments == expected["tool_arguments"], ( f"Expected arguments '{expected['tool_arguments']}', got '{actual_arguments}'" ) # Prepare and append the tool response for the next turn tool_response = tool_responses.pop(0) messages.append( { "role": "tool", "tool_call_id": tool_call["id"], "content": tool_response["response"], } ) else: assert accumulated_content is not None and accumulated_content != "", "Expected content, but none received." expected_answers = expected["answer"] content_lower = accumulated_content.lower() assert any(ans.lower() in content_lower for ans in expected_answers), ( f"Expected one of {expected_answers} in content, but got: '{accumulated_content}'" ) # --- Helper functions (structured output validation) --- def get_structured_output(maybe_json_content: str, schema_name: str) -> Any | None: if schema_name == "valid_calendar_event": class CalendarEvent(BaseModel): name: str date: str participants: list[str] try: calendar_event = CalendarEvent.model_validate_json(maybe_json_content) return calendar_event except Exception: return None elif schema_name == "valid_math_reasoning": class Step(BaseModel): explanation: str output: str class MathReasoning(BaseModel): steps: list[Step] final_answer: str try: math_reasoning = MathReasoning.model_validate_json(maybe_json_content) return math_reasoning except Exception: return None return None def validate_structured_output(maybe_json_content: str, schema_name: str) -> None: structured_output = get_structured_output(maybe_json_content, schema_name) assert structured_output is not None if schema_name == "valid_calendar_event": assert structured_output.name is not None assert structured_output.date is not None assert len(structured_output.participants) == 2 elif schema_name == "valid_math_reasoning": assert len(structured_output.final_answer) > 0 def _accumulate_streaming_tool_calls(stream): """Accumulates tool calls and content from a streaming ChatCompletion response.""" tool_calls_buffer = {} current_id = None full_content = "" # Initialize content accumulator # Process streaming chunks for chunk in stream: choice = chunk.choices[0] delta = choice.delta # Accumulate content if delta.content: full_content += delta.content if delta.tool_calls is None: continue for tool_call_delta in delta.tool_calls: if tool_call_delta.id: current_id = tool_call_delta.id call_id = current_id # Skip if no ID seen yet for this tool call delta if not call_id: continue func_delta = tool_call_delta.function if call_id not in tool_calls_buffer: tool_calls_buffer[call_id] = { "id": call_id, "type": "function", # Assume function type "function": {"name": None, "arguments": ""}, # Nested structure } # Accumulate name and arguments into the nested function dict if func_delta: if func_delta.name: tool_calls_buffer[call_id]["function"]["name"] = func_delta.name if func_delta.arguments: tool_calls_buffer[call_id]["function"]["arguments"] += func_delta.arguments # Return content and tool calls as a list return full_content, list(tool_calls_buffer.values())