# What does this PR do?

When clients called the OpenAI API with invalid input that wasn't caught by our own Pydantic API validation but was instead only caught by the backend inference provider, that provider returned an HTTP 400 error. However, we were wrapping it into an HTTP 500 error, obfuscating the actual issue from calling clients and triggering OpenAI client retry logic.

This change adjusts our existing `translate_exception` method in `server.py` to wrap `openai.BadRequestError` as an HTTP 400 error, passing the string representation of the error message through to the calling user so they can see the actual input validation error and correct it. I tried changing this in a few other places, but ultimately `translate_exception` was the only real place to handle this for both streaming and non-streaming requests across all inference providers that use the OpenAI server APIs. (A hedged sketch of this mapping follows the test plan below.)

This also tightens up our validation a bit for the OpenAI chat completions API, to catch empty `messages` parameters, invalid `tool_choice` parameters, invalid `tools` items, or passing `tool_choice` when `tools` isn't given.

Lastly, this extends our OpenAI API chat completions verifications to also check for consistent input validation across providers. Providers behind Llama Stack should automatically pass all the new tests due to the input validation added here, but some providers fail these tests when not run behind Llama Stack, due to differences in how they handle input validation and errors.

(Closes #1951)

## Test Plan

To test this, start an OpenAI API verification stack:

```
llama stack run --image-type venv tests/verifications/openai-api-verification-run.yaml
```

Then, run the new verification tests with your provider(s) of choice:

```
python -m pytest -s -v \
  tests/verifications/openai_api/test_chat_completion.py \
  --provider openai-llama-stack

python -m pytest -s -v \
  tests/verifications/openai_api/test_chat_completion.py \
  --provider together-llama-stack
```

Signed-off-by: Ben Browning <bbrownin@redhat.com>
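To illustrate the exception mapping described above, here is a minimal sketch. It assumes a FastAPI-style server, and the function name `translate_exception_sketch` is hypothetical; this is not the actual `translate_exception` implementation in `server.py`, just the shape of the behavior this PR describes.

```python
# Minimal sketch only -- not the actual translate_exception in server.py.
# Assumes a FastAPI-based server; names below are illustrative.
from fastapi import HTTPException
from openai import BadRequestError
from pydantic import ValidationError


def translate_exception_sketch(exc: Exception) -> HTTPException:
    """Map exceptions onto HTTP errors that calling clients can act on."""
    if isinstance(exc, ValidationError):
        # Caught by our own Pydantic API validation -> client error.
        return HTTPException(status_code=400, detail=str(exc))
    if isinstance(exc, BadRequestError):
        # The backend inference provider rejected the input; surface its
        # message as an HTTP 400 instead of wrapping it in an HTTP 500.
        return HTTPException(status_code=400, detail=str(exc))
    # Anything else remains an internal server error.
    return HTTPException(status_code=500, detail="Internal server error: an unexpected error occurred.")
```

Returning a 400 here also typically keeps the OpenAI client's automatic retries from firing, since those target rate limits and 5xx responses rather than client errors.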
tests/verifications/openai_api/test_chat_completion.py (767 lines, 29 KiB, Python)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import base64
import copy
import json
import re
from pathlib import Path
from typing import Any

import pytest
from openai import APIError
from pydantic import BaseModel

from tests.verifications.openai_api.fixtures.fixtures import (
    _load_all_verification_configs,
)
from tests.verifications.openai_api.fixtures.load import load_test_cases

chat_completion_test_cases = load_test_cases("chat_completion")

THIS_DIR = Path(__file__).parent


def case_id_generator(case):
    """Generate a test ID from the case's 'case_id' field, or use a default."""
    case_id = case.get("case_id")
    if isinstance(case_id, (str, int)):
        return re.sub(r"\W|^(?=\d)", "_", str(case_id))
    return None


def pytest_generate_tests(metafunc):
    """Dynamically parametrize tests based on the selected provider and config."""
    if "model" in metafunc.fixturenames:
        provider = metafunc.config.getoption("provider")
        if not provider:
            print("Warning: --provider not specified. Skipping model parametrization.")
            metafunc.parametrize("model", [])
            return

        try:
            config_data = _load_all_verification_configs()
        except (FileNotFoundError, IOError) as e:
            print(f"ERROR loading verification configs: {e}")
            config_data = {"providers": {}}

        provider_config = config_data.get("providers", {}).get(provider)
        if provider_config:
            models = provider_config.get("models", [])
            if models:
                metafunc.parametrize("model", models)
            else:
                print(f"Warning: No models found for provider '{provider}' in config.")
                metafunc.parametrize("model", [])  # Parametrize empty if no models found
        else:
            print(f"Warning: Provider '{provider}' not found in config. No models parametrized.")
            metafunc.parametrize("model", [])  # Parametrize empty if provider not found


def should_skip_test(verification_config, provider, model, test_name_base):
    """Check if a test should be skipped based on config exclusions."""
    provider_config = verification_config.get("providers", {}).get(provider)
    if not provider_config:
        return False  # No config for provider, don't skip

    exclusions = provider_config.get("test_exclusions", {}).get(model, [])
    return test_name_base in exclusions


# Helper to get the base test name from the request object
def get_base_test_name(request):
    return request.node.originalname


@pytest.fixture
def multi_image_data():
    files = [
        THIS_DIR / "fixtures/images/vision_test_1.jpg",
        THIS_DIR / "fixtures/images/vision_test_2.jpg",
        THIS_DIR / "fixtures/images/vision_test_3.jpg",
    ]
    encoded_files = []
    for file in files:
        with open(file, "rb") as image_file:
            base64_data = base64.b64encode(image_file.read()).decode("utf-8")
            encoded_files.append(f"data:image/jpeg;base64,{base64_data}")
    return encoded_files


# --- Test Functions ---


@pytest.mark.parametrize(
    "case",
    chat_completion_test_cases["test_chat_basic"]["test_params"]["case"],
    ids=case_id_generator,
)
def test_chat_non_streaming_basic(request, openai_client, model, provider, verification_config, case):
    test_name_base = get_base_test_name(request)
    if should_skip_test(verification_config, provider, model, test_name_base):
        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")

    response = openai_client.chat.completions.create(
        model=model,
        messages=case["input"]["messages"],
        stream=False,
    )
    assert response.choices[0].message.role == "assistant"
    assert case["output"].lower() in response.choices[0].message.content.lower()


@pytest.mark.parametrize(
    "case",
    chat_completion_test_cases["test_chat_basic"]["test_params"]["case"],
    ids=case_id_generator,
)
def test_chat_streaming_basic(request, openai_client, model, provider, verification_config, case):
    test_name_base = get_base_test_name(request)
    if should_skip_test(verification_config, provider, model, test_name_base):
        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")

    response = openai_client.chat.completions.create(
        model=model,
        messages=case["input"]["messages"],
        stream=True,
    )
    content = ""
    for chunk in response:
        content += chunk.choices[0].delta.content or ""

    # TODO: add detailed type validation

    assert case["output"].lower() in content.lower()


@pytest.mark.parametrize(
    "case",
    chat_completion_test_cases["test_chat_input_validation"]["test_params"]["case"],
    ids=case_id_generator,
)
def test_chat_non_streaming_error_handling(request, openai_client, model, provider, verification_config, case):
    test_name_base = get_base_test_name(request)
    if should_skip_test(verification_config, provider, model, test_name_base):
        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")

    with pytest.raises(APIError) as e:
        openai_client.chat.completions.create(
            model=model,
            messages=case["input"]["messages"],
            stream=False,
            tool_choice=case["input"]["tool_choice"] if "tool_choice" in case["input"] else None,
            tools=case["input"]["tools"] if "tools" in case["input"] else None,
        )
    assert case["output"]["error"]["status_code"] == e.value.status_code


@pytest.mark.parametrize(
    "case",
    chat_completion_test_cases["test_chat_input_validation"]["test_params"]["case"],
    ids=case_id_generator,
)
def test_chat_streaming_error_handling(request, openai_client, model, provider, verification_config, case):
    test_name_base = get_base_test_name(request)
    if should_skip_test(verification_config, provider, model, test_name_base):
        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")

    with pytest.raises(APIError) as e:
        response = openai_client.chat.completions.create(
            model=model,
            messages=case["input"]["messages"],
            stream=True,
            tool_choice=case["input"]["tool_choice"] if "tool_choice" in case["input"] else None,
            tools=case["input"]["tools"] if "tools" in case["input"] else None,
        )
        for _chunk in response:
            pass
    assert str(case["output"]["error"]["status_code"]) in e.value.message


@pytest.mark.parametrize(
    "case",
    chat_completion_test_cases["test_chat_image"]["test_params"]["case"],
    ids=case_id_generator,
)
def test_chat_non_streaming_image(request, openai_client, model, provider, verification_config, case):
    test_name_base = get_base_test_name(request)
    if should_skip_test(verification_config, provider, model, test_name_base):
        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")

    response = openai_client.chat.completions.create(
        model=model,
        messages=case["input"]["messages"],
        stream=False,
    )
    assert response.choices[0].message.role == "assistant"
    assert case["output"].lower() in response.choices[0].message.content.lower()


@pytest.mark.parametrize(
    "case",
    chat_completion_test_cases["test_chat_image"]["test_params"]["case"],
    ids=case_id_generator,
)
def test_chat_streaming_image(request, openai_client, model, provider, verification_config, case):
    test_name_base = get_base_test_name(request)
    if should_skip_test(verification_config, provider, model, test_name_base):
        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")

    response = openai_client.chat.completions.create(
        model=model,
        messages=case["input"]["messages"],
        stream=True,
    )
    content = ""
    for chunk in response:
        content += chunk.choices[0].delta.content or ""

    # TODO: add detailed type validation

    assert case["output"].lower() in content.lower()


@pytest.mark.parametrize(
    "case",
    chat_completion_test_cases["test_chat_structured_output"]["test_params"]["case"],
    ids=case_id_generator,
)
def test_chat_non_streaming_structured_output(request, openai_client, model, provider, verification_config, case):
    test_name_base = get_base_test_name(request)
    if should_skip_test(verification_config, provider, model, test_name_base):
        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")

    response = openai_client.chat.completions.create(
        model=model,
        messages=case["input"]["messages"],
        response_format=case["input"]["response_format"],
        stream=False,
    )

    assert response.choices[0].message.role == "assistant"
    maybe_json_content = response.choices[0].message.content

    validate_structured_output(maybe_json_content, case["output"])


@pytest.mark.parametrize(
    "case",
    chat_completion_test_cases["test_chat_structured_output"]["test_params"]["case"],
    ids=case_id_generator,
)
def test_chat_streaming_structured_output(request, openai_client, model, provider, verification_config, case):
    test_name_base = get_base_test_name(request)
    if should_skip_test(verification_config, provider, model, test_name_base):
        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")

    response = openai_client.chat.completions.create(
        model=model,
        messages=case["input"]["messages"],
        response_format=case["input"]["response_format"],
        stream=True,
    )
    maybe_json_content = ""
    for chunk in response:
        maybe_json_content += chunk.choices[0].delta.content or ""
    validate_structured_output(maybe_json_content, case["output"])


@pytest.mark.parametrize(
    "case",
    chat_completion_test_cases["test_tool_calling"]["test_params"]["case"],
    ids=case_id_generator,
)
def test_chat_non_streaming_tool_calling(request, openai_client, model, provider, verification_config, case):
    test_name_base = get_base_test_name(request)
    if should_skip_test(verification_config, provider, model, test_name_base):
        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")

    response = openai_client.chat.completions.create(
        model=model,
        messages=case["input"]["messages"],
        tools=case["input"]["tools"],
        stream=False,
    )

    assert response.choices[0].message.role == "assistant"
    assert len(response.choices[0].message.tool_calls) > 0
    assert case["output"] == "get_weather_tool_call"
    assert response.choices[0].message.tool_calls[0].function.name == "get_weather"
    # TODO: add detailed type validation


@pytest.mark.parametrize(
    "case",
    chat_completion_test_cases["test_tool_calling"]["test_params"]["case"],
    ids=case_id_generator,
)
def test_chat_streaming_tool_calling(request, openai_client, model, provider, verification_config, case):
    test_name_base = get_base_test_name(request)
    if should_skip_test(verification_config, provider, model, test_name_base):
        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")

    stream = openai_client.chat.completions.create(
        model=model,
        messages=case["input"]["messages"],
        tools=case["input"]["tools"],
        stream=True,
    )

    _, tool_calls_buffer = _accumulate_streaming_tool_calls(stream)
    assert len(tool_calls_buffer) == 1
    for call in tool_calls_buffer:
        assert len(call["id"]) > 0
        function = call["function"]
        assert function["name"] == "get_weather"

        args_dict = json.loads(function["arguments"])
        assert "san francisco" in args_dict["location"].lower()


@pytest.mark.parametrize(
    "case",
    chat_completion_test_cases["test_tool_calling"]["test_params"]["case"],  # Reusing existing case for now
    ids=case_id_generator,
)
def test_chat_non_streaming_tool_choice_required(request, openai_client, model, provider, verification_config, case):
    test_name_base = get_base_test_name(request)
    if should_skip_test(verification_config, provider, model, test_name_base):
        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")

    response = openai_client.chat.completions.create(
        model=model,
        messages=case["input"]["messages"],
        tools=case["input"]["tools"],
        tool_choice="required",  # Force tool call
        stream=False,
    )

    assert response.choices[0].message.role == "assistant"
    assert len(response.choices[0].message.tool_calls) > 0, "Expected tool call when tool_choice='required'"
    expected_tool_name = case["input"]["tools"][0]["function"]["name"]
    assert response.choices[0].message.tool_calls[0].function.name == expected_tool_name


@pytest.mark.parametrize(
    "case",
    chat_completion_test_cases["test_tool_calling"]["test_params"]["case"],  # Reusing existing case for now
    ids=case_id_generator,
)
def test_chat_streaming_tool_choice_required(request, openai_client, model, provider, verification_config, case):
    test_name_base = get_base_test_name(request)
    if should_skip_test(verification_config, provider, model, test_name_base):
        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")

    stream = openai_client.chat.completions.create(
        model=model,
        messages=case["input"]["messages"],
        tools=case["input"]["tools"],
        tool_choice="required",  # Force tool call
        stream=True,
    )

    _, tool_calls_buffer = _accumulate_streaming_tool_calls(stream)

    assert len(tool_calls_buffer) > 0, "Expected tool call when tool_choice='required'"
    expected_tool_name = case["input"]["tools"][0]["function"]["name"]
    assert any(call["function"]["name"] == expected_tool_name for call in tool_calls_buffer), (
        f"Expected tool call '{expected_tool_name}' not found in stream"
    )


@pytest.mark.parametrize(
    "case",
    chat_completion_test_cases["test_tool_calling"]["test_params"]["case"],  # Reusing existing case for now
    ids=case_id_generator,
)
def test_chat_non_streaming_tool_choice_none(request, openai_client, model, provider, verification_config, case):
    test_name_base = get_base_test_name(request)
    if should_skip_test(verification_config, provider, model, test_name_base):
        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")

    response = openai_client.chat.completions.create(
        model=model,
        messages=case["input"]["messages"],
        tools=case["input"]["tools"],
        tool_choice="none",
        stream=False,
    )

    assert response.choices[0].message.role == "assistant"
    assert response.choices[0].message.tool_calls is None, "Expected no tool calls when tool_choice='none'"
    assert response.choices[0].message.content is not None, "Expected content when tool_choice='none'"


@pytest.mark.parametrize(
    "case",
    chat_completion_test_cases["test_tool_calling"]["test_params"]["case"],  # Reusing existing case for now
    ids=case_id_generator,
)
def test_chat_streaming_tool_choice_none(request, openai_client, model, provider, verification_config, case):
    test_name_base = get_base_test_name(request)
    if should_skip_test(verification_config, provider, model, test_name_base):
        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")

    stream = openai_client.chat.completions.create(
        model=model,
        messages=case["input"]["messages"],
        tools=case["input"]["tools"],
        tool_choice="none",
        stream=True,
    )

    content = ""
    for chunk in stream:
        delta = chunk.choices[0].delta
        if delta.content:
            content += delta.content
        assert not delta.tool_calls, "Expected no tool call chunks when tool_choice='none'"

    assert len(content) > 0, "Expected content when tool_choice='none'"


@pytest.mark.parametrize(
    "case",
    chat_completion_test_cases.get("test_chat_multi_turn_tool_calling", {}).get("test_params", {}).get("case", []),
    ids=case_id_generator,
)
def test_chat_non_streaming_multi_turn_tool_calling(request, openai_client, model, provider, verification_config, case):
    """
    Test cases for multi-turn tool calling.
    Tool calls are asserted.
    Tool responses are provided in the test case.
    Final response is asserted.
    """

    test_name_base = get_base_test_name(request)
    if should_skip_test(verification_config, provider, model, test_name_base):
        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")

    # Create a copy of the messages list to avoid modifying the original
    messages = []
    tools = case["input"]["tools"]
    # Use deepcopy to prevent modification across runs/parametrization
    expected_results = copy.deepcopy(case["expected"])
    tool_responses = copy.deepcopy(case.get("tool_responses", []))
    input_messages_turns = copy.deepcopy(case["input"]["messages"])

    # keep going until either
    # 1. we have messages to test in multi-turn
    # 2. no messages but last message is tool response
    while len(input_messages_turns) > 0 or (len(messages) > 0 and messages[-1]["role"] == "tool"):
        # do not take new messages if last message is tool response
        if len(messages) == 0 or messages[-1]["role"] != "tool":
            new_messages = input_messages_turns.pop(0)
            # Ensure new_messages is a list of message objects
            if isinstance(new_messages, list):
                messages.extend(new_messages)
            else:
                # If it's a single message object, add it directly
                messages.append(new_messages)

        # --- API Call ---
        response = openai_client.chat.completions.create(
            model=model,
            messages=messages,
            tools=tools,
            stream=False,
        )

        # --- Process Response ---
        assistant_message = response.choices[0].message
        messages.append(assistant_message.model_dump(exclude_unset=True))

        assert assistant_message.role == "assistant"

        # Get the expected result data
        expected = expected_results.pop(0)
        num_tool_calls = expected["num_tool_calls"]

        # --- Assertions based on expected result ---
        assert len(assistant_message.tool_calls or []) == num_tool_calls, (
            f"Expected {num_tool_calls} tool calls, but got {len(assistant_message.tool_calls or [])}"
        )

        if num_tool_calls > 0:
            tool_call = assistant_message.tool_calls[0]
            assert tool_call.function.name == expected["tool_name"], (
                f"Expected tool '{expected['tool_name']}', got '{tool_call.function.name}'"
            )
            # Parse the JSON string arguments before comparing
            actual_arguments = json.loads(tool_call.function.arguments)
            assert actual_arguments == expected["tool_arguments"], (
                f"Expected arguments '{expected['tool_arguments']}', got '{actual_arguments}'"
            )

            # Prepare and append the tool response for the next turn
            tool_response = tool_responses.pop(0)
            messages.append(
                {
                    "role": "tool",
                    "tool_call_id": tool_call.id,
                    "content": tool_response["response"],
                }
            )
        else:
            assert assistant_message.content is not None, "Expected content, but none received."
            expected_answers = expected["answer"]  # This is now a list
            content_lower = assistant_message.content.lower()
            assert any(ans.lower() in content_lower for ans in expected_answers), (
                f"Expected one of {expected_answers} in content, but got: '{assistant_message.content}'"
            )


@pytest.mark.parametrize(
    "case",
    chat_completion_test_cases.get("test_chat_multi_turn_tool_calling", {}).get("test_params", {}).get("case", []),
    ids=case_id_generator,
)
def test_chat_streaming_multi_turn_tool_calling(request, openai_client, model, provider, verification_config, case):
    """Streaming counterpart of the multi-turn tool calling test above."""
    test_name_base = get_base_test_name(request)
    if should_skip_test(verification_config, provider, model, test_name_base):
        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")

    messages = []
    tools = case["input"]["tools"]
    expected_results = copy.deepcopy(case["expected"])
    tool_responses = copy.deepcopy(case.get("tool_responses", []))
    input_messages_turns = copy.deepcopy(case["input"]["messages"])

    while len(input_messages_turns) > 0 or (len(messages) > 0 and messages[-1]["role"] == "tool"):
        if len(messages) == 0 or messages[-1]["role"] != "tool":
            new_messages = input_messages_turns.pop(0)
            if isinstance(new_messages, list):
                messages.extend(new_messages)
            else:
                messages.append(new_messages)

        # --- API Call (Streaming) ---
        stream = openai_client.chat.completions.create(
            model=model,
            messages=messages,
            tools=tools,
            stream=True,
        )

        # --- Process Stream ---
        accumulated_content, accumulated_tool_calls = _accumulate_streaming_tool_calls(stream)

        # --- Construct Assistant Message for History ---
        assistant_message_dict = {"role": "assistant"}
        if accumulated_content:
            assistant_message_dict["content"] = accumulated_content
        if accumulated_tool_calls:
            assistant_message_dict["tool_calls"] = accumulated_tool_calls

        messages.append(assistant_message_dict)

        # --- Assertions ---
        expected = expected_results.pop(0)
        num_tool_calls = expected["num_tool_calls"]

        assert len(accumulated_tool_calls or []) == num_tool_calls, (
            f"Expected {num_tool_calls} tool calls, but got {len(accumulated_tool_calls or [])}"
        )

        if num_tool_calls > 0:
            # Use the first accumulated tool call for assertion
            tool_call = accumulated_tool_calls[0]
            assert tool_call["function"]["name"] == expected["tool_name"], (
                f"Expected tool '{expected['tool_name']}', got '{tool_call['function']['name']}'"
            )
            # Parse the accumulated arguments string for comparison
            actual_arguments = json.loads(tool_call["function"]["arguments"])
            assert actual_arguments == expected["tool_arguments"], (
                f"Expected arguments '{expected['tool_arguments']}', got '{actual_arguments}'"
            )

            # Prepare and append the tool response for the next turn
            tool_response = tool_responses.pop(0)
            messages.append(
                {
                    "role": "tool",
                    "tool_call_id": tool_call["id"],
                    "content": tool_response["response"],
                }
            )
        else:
            assert accumulated_content is not None and accumulated_content != "", "Expected content, but none received."
            expected_answers = expected["answer"]
            content_lower = accumulated_content.lower()
            assert any(ans.lower() in content_lower for ans in expected_answers), (
                f"Expected one of {expected_answers} in content, but got: '{accumulated_content}'"
            )


@pytest.mark.parametrize("stream", [False, True], ids=["stream=False", "stream=True"])
|
|
def test_chat_multi_turn_multiple_images(
|
|
request, openai_client, model, provider, verification_config, multi_image_data, stream
|
|
):
|
|
test_name_base = get_base_test_name(request)
|
|
if should_skip_test(verification_config, provider, model, test_name_base):
|
|
pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
|
|
|
|
messages_turn1 = [
|
|
{
|
|
"role": "user",
|
|
"content": [
|
|
{
|
|
"type": "image_url",
|
|
"image_url": {
|
|
"url": multi_image_data[0],
|
|
},
|
|
},
|
|
{
|
|
"type": "image_url",
|
|
"image_url": {
|
|
"url": multi_image_data[1],
|
|
},
|
|
},
|
|
{
|
|
"type": "text",
|
|
"text": "What furniture is in the first image that is not in the second image?",
|
|
},
|
|
],
|
|
},
|
|
]
|
|
|
|
# First API call
|
|
response1 = openai_client.chat.completions.create(
|
|
model=model,
|
|
messages=messages_turn1,
|
|
stream=stream,
|
|
)
|
|
if stream:
|
|
message_content1 = ""
|
|
for chunk in response1:
|
|
message_content1 += chunk.choices[0].delta.content or ""
|
|
else:
|
|
message_content1 = response1.choices[0].message.content
|
|
assert len(message_content1) > 0
|
|
assert any(expected in message_content1.lower().strip() for expected in {"chair", "table"}), message_content1
|
|
|
|
# Prepare messages for the second turn
|
|
messages_turn2 = messages_turn1 + [
|
|
{"role": "assistant", "content": message_content1},
|
|
{
|
|
"role": "user",
|
|
"content": [
|
|
{
|
|
"type": "image_url",
|
|
"image_url": {
|
|
"url": multi_image_data[2],
|
|
},
|
|
},
|
|
{"type": "text", "text": "What is in this image that is also in the first image?"},
|
|
],
|
|
},
|
|
]
|
|
|
|
# Second API call
|
|
response2 = openai_client.chat.completions.create(
|
|
model=model,
|
|
messages=messages_turn2,
|
|
stream=stream,
|
|
)
|
|
if stream:
|
|
message_content2 = ""
|
|
for chunk in response2:
|
|
message_content2 += chunk.choices[0].delta.content or ""
|
|
else:
|
|
message_content2 = response2.choices[0].message.content
|
|
assert len(message_content2) > 0
|
|
assert any(expected in message_content2.lower().strip() for expected in {"bed"}), message_content2
|
|
|
|
|
|
# --- Helper functions (structured output validation) ---


def get_structured_output(maybe_json_content: str, schema_name: str) -> Any | None:
    if schema_name == "valid_calendar_event":

        class CalendarEvent(BaseModel):
            name: str
            date: str
            participants: list[str]

        try:
            calendar_event = CalendarEvent.model_validate_json(maybe_json_content)
            return calendar_event
        except Exception:
            return None
    elif schema_name == "valid_math_reasoning":

        class Step(BaseModel):
            explanation: str
            output: str

        class MathReasoning(BaseModel):
            steps: list[Step]
            final_answer: str

        try:
            math_reasoning = MathReasoning.model_validate_json(maybe_json_content)
            return math_reasoning
        except Exception:
            return None

    return None


def validate_structured_output(maybe_json_content: str, schema_name: str) -> None:
    structured_output = get_structured_output(maybe_json_content, schema_name)
    assert structured_output is not None
    if schema_name == "valid_calendar_event":
        assert structured_output.name is not None
        assert structured_output.date is not None
        assert len(structured_output.participants) == 2
    elif schema_name == "valid_math_reasoning":
        assert len(structured_output.final_answer) > 0


def _accumulate_streaming_tool_calls(stream):
    """Accumulates tool calls and content from a streaming ChatCompletion response."""
    tool_calls_buffer = {}
    current_id = None
    full_content = ""  # Initialize content accumulator
    # Process streaming chunks
    for chunk in stream:
        choice = chunk.choices[0]
        delta = choice.delta

        # Accumulate content
        if delta.content:
            full_content += delta.content

        if delta.tool_calls is None:
            continue

        for tool_call_delta in delta.tool_calls:
            if tool_call_delta.id:
                current_id = tool_call_delta.id
            call_id = current_id
            # Skip if no ID seen yet for this tool call delta
            if not call_id:
                continue
            func_delta = tool_call_delta.function

            if call_id not in tool_calls_buffer:
                tool_calls_buffer[call_id] = {
                    "id": call_id,
                    "type": "function",  # Assume function type
                    "function": {"name": None, "arguments": ""},  # Nested structure
                }

            # Accumulate name and arguments into the nested function dict
            if func_delta:
                if func_delta.name:
                    tool_calls_buffer[call_id]["function"]["name"] = func_delta.name
                if func_delta.arguments:
                    tool_calls_buffer[call_id]["function"]["arguments"] += func_delta.arguments

    # Return content and tool calls as a list
    return full_content, list(tool_calls_buffer.values())