llama-stack/tests/verifications/openai_api/test_chat_completion.py
Ben Browning 8dfce2f596
feat: OpenAI Responses API (#1989)
# What does this PR do?

This provides an initial [OpenAI Responses
API](https://platform.openai.com/docs/api-reference/responses)
implementation. The API is not yet complete, and this is more of a
proof of concept to show how we can store responses in our key-value
stores and use them to support the Responses API concepts like
`previous_response_id`.
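
For illustration, here's a minimal sketch of the kind of
`previous_response_id` chaining this enables, using the OpenAI Python client
against a locally running stack. The base URL, model name, and API key below
are placeholder assumptions, not values taken from this PR:

```
from openai import OpenAI

# Assumed local OpenAI-compatible endpoint and model; adjust to your setup.
client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

first = client.responses.create(
    model="meta-llama/Llama-3.2-3B-Instruct",
    input="Tell me a two-sentence story about a llama.",
)

# The stored response id lets a follow-up request continue the conversation
# without resending the earlier turns.
follow_up = client.responses.create(
    model="meta-llama/Llama-3.2-3B-Instruct",
    input="Now retell that story as a haiku.",
    previous_response_id=first.id,
)
print(follow_up.output_text)
```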

## Test Plan

I've added a new
`tests/integration/openai_responses/test_openai_responses.py` as part of
test-driven development for this new API. I'm only testing this
locally with the remote-vllm provider for now, but it should work with
any of our inference providers, since the only API it requires from the
inference provider is the `openai_chat_completion` endpoint.

```
VLLM_URL="http://localhost:8000/v1" \
INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct" \
llama stack build --template remote-vllm --image-type venv --run
```

```
LLAMA_STACK_CONFIG="http://localhost:8321" \
python -m pytest -v \
  tests/integration/openai_responses/test_openai_responses.py \
  --text-model "meta-llama/Llama-3.2-3B-Instruct"
```

---------

Signed-off-by: Ben Browning <bbrownin@redhat.com>
Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
2025-04-28 14:06:00 -07:00


# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import base64
import copy
import json
from pathlib import Path
from typing import Any

import pytest
from openai import APIError
from pydantic import BaseModel

from tests.verifications.openai_api.fixtures.fixtures import (
    case_id_generator,
    get_base_test_name,
    should_skip_test,
)
from tests.verifications.openai_api.fixtures.load import load_test_cases

chat_completion_test_cases = load_test_cases("chat_completion")

THIS_DIR = Path(__file__).parent


@pytest.fixture
def multi_image_data():
    files = [
        THIS_DIR / "fixtures/images/vision_test_1.jpg",
        THIS_DIR / "fixtures/images/vision_test_2.jpg",
        THIS_DIR / "fixtures/images/vision_test_3.jpg",
    ]
    encoded_files = []
    for file in files:
        with open(file, "rb") as image_file:
            base64_data = base64.b64encode(image_file.read()).decode("utf-8")
            encoded_files.append(f"data:image/jpeg;base64,{base64_data}")
    return encoded_files


# --- Test Functions ---

@pytest.mark.parametrize(
    "case",
    chat_completion_test_cases["test_chat_basic"]["test_params"]["case"],
    ids=case_id_generator,
)
def test_chat_non_streaming_basic(request, openai_client, model, provider, verification_config, case):
    test_name_base = get_base_test_name(request)
    if should_skip_test(verification_config, provider, model, test_name_base):
        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")

    response = openai_client.chat.completions.create(
        model=model,
        messages=case["input"]["messages"],
        stream=False,
    )
    assert response.choices[0].message.role == "assistant"
    assert case["output"].lower() in response.choices[0].message.content.lower()


@pytest.mark.parametrize(
    "case",
    chat_completion_test_cases["test_chat_basic"]["test_params"]["case"],
    ids=case_id_generator,
)
def test_chat_streaming_basic(request, openai_client, model, provider, verification_config, case):
    test_name_base = get_base_test_name(request)
    if should_skip_test(verification_config, provider, model, test_name_base):
        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")

    response = openai_client.chat.completions.create(
        model=model,
        messages=case["input"]["messages"],
        stream=True,
    )
    content = ""
    for chunk in response:
        content += chunk.choices[0].delta.content or ""

    # TODO: add detailed type validation
    assert case["output"].lower() in content.lower()

@pytest.mark.parametrize(
    "case",
    chat_completion_test_cases["test_chat_input_validation"]["test_params"]["case"],
    ids=case_id_generator,
)
def test_chat_non_streaming_error_handling(request, openai_client, model, provider, verification_config, case):
    test_name_base = get_base_test_name(request)
    if should_skip_test(verification_config, provider, model, test_name_base):
        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")

    with pytest.raises(APIError) as e:
        openai_client.chat.completions.create(
            model=model,
            messages=case["input"]["messages"],
            stream=False,
            tool_choice=case["input"]["tool_choice"] if "tool_choice" in case["input"] else None,
            tools=case["input"]["tools"] if "tools" in case["input"] else None,
        )
    assert case["output"]["error"]["status_code"] == e.value.status_code


@pytest.mark.parametrize(
    "case",
    chat_completion_test_cases["test_chat_input_validation"]["test_params"]["case"],
    ids=case_id_generator,
)
def test_chat_streaming_error_handling(request, openai_client, model, provider, verification_config, case):
    test_name_base = get_base_test_name(request)
    if should_skip_test(verification_config, provider, model, test_name_base):
        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")

    with pytest.raises(APIError) as e:
        response = openai_client.chat.completions.create(
            model=model,
            messages=case["input"]["messages"],
            stream=True,
            tool_choice=case["input"]["tool_choice"] if "tool_choice" in case["input"] else None,
            tools=case["input"]["tools"] if "tools" in case["input"] else None,
        )
        for _chunk in response:
            pass
    assert str(case["output"]["error"]["status_code"]) in e.value.message

@pytest.mark.parametrize(
    "case",
    chat_completion_test_cases["test_chat_image"]["test_params"]["case"],
    ids=case_id_generator,
)
def test_chat_non_streaming_image(request, openai_client, model, provider, verification_config, case):
    test_name_base = get_base_test_name(request)
    if should_skip_test(verification_config, provider, model, test_name_base):
        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")

    response = openai_client.chat.completions.create(
        model=model,
        messages=case["input"]["messages"],
        stream=False,
    )
    assert response.choices[0].message.role == "assistant"
    assert case["output"].lower() in response.choices[0].message.content.lower()


@pytest.mark.parametrize(
    "case",
    chat_completion_test_cases["test_chat_image"]["test_params"]["case"],
    ids=case_id_generator,
)
def test_chat_streaming_image(request, openai_client, model, provider, verification_config, case):
    test_name_base = get_base_test_name(request)
    if should_skip_test(verification_config, provider, model, test_name_base):
        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")

    response = openai_client.chat.completions.create(
        model=model,
        messages=case["input"]["messages"],
        stream=True,
    )
    content = ""
    for chunk in response:
        content += chunk.choices[0].delta.content or ""

    # TODO: add detailed type validation
    assert case["output"].lower() in content.lower()

@pytest.mark.parametrize(
    "case",
    chat_completion_test_cases["test_chat_structured_output"]["test_params"]["case"],
    ids=case_id_generator,
)
def test_chat_non_streaming_structured_output(request, openai_client, model, provider, verification_config, case):
    test_name_base = get_base_test_name(request)
    if should_skip_test(verification_config, provider, model, test_name_base):
        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")

    response = openai_client.chat.completions.create(
        model=model,
        messages=case["input"]["messages"],
        response_format=case["input"]["response_format"],
        stream=False,
    )

    assert response.choices[0].message.role == "assistant"
    maybe_json_content = response.choices[0].message.content

    validate_structured_output(maybe_json_content, case["output"])


@pytest.mark.parametrize(
    "case",
    chat_completion_test_cases["test_chat_structured_output"]["test_params"]["case"],
    ids=case_id_generator,
)
def test_chat_streaming_structured_output(request, openai_client, model, provider, verification_config, case):
    test_name_base = get_base_test_name(request)
    if should_skip_test(verification_config, provider, model, test_name_base):
        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")

    response = openai_client.chat.completions.create(
        model=model,
        messages=case["input"]["messages"],
        response_format=case["input"]["response_format"],
        stream=True,
    )
    maybe_json_content = ""
    for chunk in response:
        maybe_json_content += chunk.choices[0].delta.content or ""
    validate_structured_output(maybe_json_content, case["output"])

@pytest.mark.parametrize(
    "case",
    chat_completion_test_cases["test_tool_calling"]["test_params"]["case"],
    ids=case_id_generator,
)
def test_chat_non_streaming_tool_calling(request, openai_client, model, provider, verification_config, case):
    test_name_base = get_base_test_name(request)
    if should_skip_test(verification_config, provider, model, test_name_base):
        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")

    response = openai_client.chat.completions.create(
        model=model,
        messages=case["input"]["messages"],
        tools=case["input"]["tools"],
        stream=False,
    )

    assert response.choices[0].message.role == "assistant"
    assert len(response.choices[0].message.tool_calls) > 0
    assert case["output"] == "get_weather_tool_call"
    assert response.choices[0].message.tool_calls[0].function.name == "get_weather"
    # TODO: add detailed type validation


@pytest.mark.parametrize(
    "case",
    chat_completion_test_cases["test_tool_calling"]["test_params"]["case"],
    ids=case_id_generator,
)
def test_chat_streaming_tool_calling(request, openai_client, model, provider, verification_config, case):
    test_name_base = get_base_test_name(request)
    if should_skip_test(verification_config, provider, model, test_name_base):
        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")

    stream = openai_client.chat.completions.create(
        model=model,
        messages=case["input"]["messages"],
        tools=case["input"]["tools"],
        stream=True,
    )

    _, tool_calls_buffer = _accumulate_streaming_tool_calls(stream)
    assert len(tool_calls_buffer) == 1
    for call in tool_calls_buffer:
        assert len(call["id"]) > 0
        function = call["function"]
        assert function["name"] == "get_weather"

        args_dict = json.loads(function["arguments"])
        assert "san francisco" in args_dict["location"].lower()

@pytest.mark.parametrize(
    "case",
    chat_completion_test_cases["test_tool_calling"]["test_params"]["case"],  # Reusing existing case for now
    ids=case_id_generator,
)
def test_chat_non_streaming_tool_choice_required(request, openai_client, model, provider, verification_config, case):
    test_name_base = get_base_test_name(request)
    if should_skip_test(verification_config, provider, model, test_name_base):
        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")

    response = openai_client.chat.completions.create(
        model=model,
        messages=case["input"]["messages"],
        tools=case["input"]["tools"],
        tool_choice="required",  # Force tool call
        stream=False,
    )

    assert response.choices[0].message.role == "assistant"
    assert len(response.choices[0].message.tool_calls) > 0, "Expected tool call when tool_choice='required'"
    expected_tool_name = case["input"]["tools"][0]["function"]["name"]
    assert response.choices[0].message.tool_calls[0].function.name == expected_tool_name


@pytest.mark.parametrize(
    "case",
    chat_completion_test_cases["test_tool_calling"]["test_params"]["case"],  # Reusing existing case for now
    ids=case_id_generator,
)
def test_chat_streaming_tool_choice_required(request, openai_client, model, provider, verification_config, case):
    test_name_base = get_base_test_name(request)
    if should_skip_test(verification_config, provider, model, test_name_base):
        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")

    stream = openai_client.chat.completions.create(
        model=model,
        messages=case["input"]["messages"],
        tools=case["input"]["tools"],
        tool_choice="required",  # Force tool call
        stream=True,
    )

    _, tool_calls_buffer = _accumulate_streaming_tool_calls(stream)
    assert len(tool_calls_buffer) > 0, "Expected tool call when tool_choice='required'"
    expected_tool_name = case["input"]["tools"][0]["function"]["name"]
    assert any(call["function"]["name"] == expected_tool_name for call in tool_calls_buffer), (
        f"Expected tool call '{expected_tool_name}' not found in stream"
    )

@pytest.mark.parametrize(
    "case",
    chat_completion_test_cases["test_tool_calling"]["test_params"]["case"],  # Reusing existing case for now
    ids=case_id_generator,
)
def test_chat_non_streaming_tool_choice_none(request, openai_client, model, provider, verification_config, case):
    test_name_base = get_base_test_name(request)
    if should_skip_test(verification_config, provider, model, test_name_base):
        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")

    response = openai_client.chat.completions.create(
        model=model,
        messages=case["input"]["messages"],
        tools=case["input"]["tools"],
        tool_choice="none",
        stream=False,
    )

    assert response.choices[0].message.role == "assistant"
    assert response.choices[0].message.tool_calls is None, "Expected no tool calls when tool_choice='none'"
    assert response.choices[0].message.content is not None, "Expected content when tool_choice='none'"


@pytest.mark.parametrize(
    "case",
    chat_completion_test_cases["test_tool_calling"]["test_params"]["case"],  # Reusing existing case for now
    ids=case_id_generator,
)
def test_chat_streaming_tool_choice_none(request, openai_client, model, provider, verification_config, case):
    test_name_base = get_base_test_name(request)
    if should_skip_test(verification_config, provider, model, test_name_base):
        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")

    stream = openai_client.chat.completions.create(
        model=model,
        messages=case["input"]["messages"],
        tools=case["input"]["tools"],
        tool_choice="none",
        stream=True,
    )

    content = ""
    for chunk in stream:
        delta = chunk.choices[0].delta
        if delta.content:
            content += delta.content
        assert not delta.tool_calls, "Expected no tool call chunks when tool_choice='none'"

    assert len(content) > 0, "Expected content when tool_choice='none'"

@pytest.mark.parametrize(
    "case",
    chat_completion_test_cases.get("test_chat_multi_turn_tool_calling", {}).get("test_params", {}).get("case", []),
    ids=case_id_generator,
)
def test_chat_non_streaming_multi_turn_tool_calling(request, openai_client, model, provider, verification_config, case):
    """
    Test cases for multi-turn tool calling.
    Tool calls are asserted.
    Tool responses are provided in the test case.
    Final response is asserted.
    """
    test_name_base = get_base_test_name(request)
    if should_skip_test(verification_config, provider, model, test_name_base):
        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")

    # Create a copy of the messages list to avoid modifying the original
    messages = []
    tools = case["input"]["tools"]
    # Use deepcopy to prevent modification across runs/parametrization
    expected_results = copy.deepcopy(case["expected"])
    tool_responses = copy.deepcopy(case.get("tool_responses", []))
    input_messages_turns = copy.deepcopy(case["input"]["messages"])

    # keep going until either
    # 1. we have messages to test in multi-turn
    # 2. no messages but last message is tool response
    while len(input_messages_turns) > 0 or (len(messages) > 0 and messages[-1]["role"] == "tool"):
        # do not take new messages if last message is tool response
        if len(messages) == 0 or messages[-1]["role"] != "tool":
            new_messages = input_messages_turns.pop(0)
            # Ensure new_messages is a list of message objects
            if isinstance(new_messages, list):
                messages.extend(new_messages)
            else:
                # If it's a single message object, add it directly
                messages.append(new_messages)

        # --- API Call ---
        response = openai_client.chat.completions.create(
            model=model,
            messages=messages,
            tools=tools,
            stream=False,
        )

        # --- Process Response ---
        assistant_message = response.choices[0].message
        messages.append(assistant_message.model_dump(exclude_unset=True))

        assert assistant_message.role == "assistant"

        # Get the expected result data
        expected = expected_results.pop(0)
        num_tool_calls = expected["num_tool_calls"]

        # --- Assertions based on expected result ---
        assert len(assistant_message.tool_calls or []) == num_tool_calls, (
            f"Expected {num_tool_calls} tool calls, but got {len(assistant_message.tool_calls or [])}"
        )

        if num_tool_calls > 0:
            tool_call = assistant_message.tool_calls[0]
            assert tool_call.function.name == expected["tool_name"], (
                f"Expected tool '{expected['tool_name']}', got '{tool_call.function.name}'"
            )
            # Parse the JSON string arguments before comparing
            actual_arguments = json.loads(tool_call.function.arguments)
            assert actual_arguments == expected["tool_arguments"], (
                f"Expected arguments '{expected['tool_arguments']}', got '{actual_arguments}'"
            )

            # Prepare and append the tool response for the next turn
            tool_response = tool_responses.pop(0)
            messages.append(
                {
                    "role": "tool",
                    "tool_call_id": tool_call.id,
                    "content": tool_response["response"],
                }
            )
        else:
            assert assistant_message.content is not None, "Expected content, but none received."
            expected_answers = expected["answer"]  # This is now a list
            content_lower = assistant_message.content.lower()
            assert any(ans.lower() in content_lower for ans in expected_answers), (
                f"Expected one of {expected_answers} in content, but got: '{assistant_message.content}'"
            )

@pytest.mark.parametrize(
    "case",
    chat_completion_test_cases.get("test_chat_multi_turn_tool_calling", {}).get("test_params", {}).get("case", []),
    ids=case_id_generator,
)
def test_chat_streaming_multi_turn_tool_calling(request, openai_client, model, provider, verification_config, case):
    """Streaming variant of test_chat_non_streaming_multi_turn_tool_calling."""
    test_name_base = get_base_test_name(request)
    if should_skip_test(verification_config, provider, model, test_name_base):
        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")

    messages = []
    tools = case["input"]["tools"]
    expected_results = copy.deepcopy(case["expected"])
    tool_responses = copy.deepcopy(case.get("tool_responses", []))
    input_messages_turns = copy.deepcopy(case["input"]["messages"])

    while len(input_messages_turns) > 0 or (len(messages) > 0 and messages[-1]["role"] == "tool"):
        if len(messages) == 0 or messages[-1]["role"] != "tool":
            new_messages = input_messages_turns.pop(0)
            if isinstance(new_messages, list):
                messages.extend(new_messages)
            else:
                messages.append(new_messages)

        # --- API Call (Streaming) ---
        stream = openai_client.chat.completions.create(
            model=model,
            messages=messages,
            tools=tools,
            stream=True,
        )

        # --- Process Stream ---
        accumulated_content, accumulated_tool_calls = _accumulate_streaming_tool_calls(stream)

        # --- Construct Assistant Message for History ---
        assistant_message_dict = {"role": "assistant"}
        if accumulated_content:
            assistant_message_dict["content"] = accumulated_content
        if accumulated_tool_calls:
            assistant_message_dict["tool_calls"] = accumulated_tool_calls

        messages.append(assistant_message_dict)

        # --- Assertions ---
        expected = expected_results.pop(0)
        num_tool_calls = expected["num_tool_calls"]

        assert len(accumulated_tool_calls or []) == num_tool_calls, (
            f"Expected {num_tool_calls} tool calls, but got {len(accumulated_tool_calls or [])}"
        )

        if num_tool_calls > 0:
            # Use the first accumulated tool call for assertion
            tool_call = accumulated_tool_calls[0]
            assert tool_call["function"]["name"] == expected["tool_name"], (
                f"Expected tool '{expected['tool_name']}', got '{tool_call['function']['name']}'"
            )
            # Parse the accumulated arguments string for comparison
            actual_arguments = json.loads(tool_call["function"]["arguments"])
            assert actual_arguments == expected["tool_arguments"], (
                f"Expected arguments '{expected['tool_arguments']}', got '{actual_arguments}'"
            )

            # Prepare and append the tool response for the next turn
            tool_response = tool_responses.pop(0)
            messages.append(
                {
                    "role": "tool",
                    "tool_call_id": tool_call["id"],
                    "content": tool_response["response"],
                }
            )
        else:
            assert accumulated_content is not None and accumulated_content != "", "Expected content, but none received."
            expected_answers = expected["answer"]
            content_lower = accumulated_content.lower()
            assert any(ans.lower() in content_lower for ans in expected_answers), (
                f"Expected one of {expected_answers} in content, but got: '{accumulated_content}'"
            )

@pytest.mark.parametrize("stream", [False, True], ids=["stream=False", "stream=True"])
def test_chat_multi_turn_multiple_images(
    request, openai_client, model, provider, verification_config, multi_image_data, stream
):
    test_name_base = get_base_test_name(request)
    if should_skip_test(verification_config, provider, model, test_name_base):
        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")

    messages_turn1 = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {
                        "url": multi_image_data[0],
                    },
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": multi_image_data[1],
                    },
                },
                {
                    "type": "text",
                    "text": "What furniture is in the first image that is not in the second image?",
                },
            ],
        },
    ]

    # First API call
    response1 = openai_client.chat.completions.create(
        model=model,
        messages=messages_turn1,
        stream=stream,
    )
    if stream:
        message_content1 = ""
        for chunk in response1:
            message_content1 += chunk.choices[0].delta.content or ""
    else:
        message_content1 = response1.choices[0].message.content
    assert len(message_content1) > 0
    assert any(expected in message_content1.lower().strip() for expected in {"chair", "table"}), message_content1

    # Prepare messages for the second turn
    messages_turn2 = messages_turn1 + [
        {"role": "assistant", "content": message_content1},
        {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {
                        "url": multi_image_data[2],
                    },
                },
                {"type": "text", "text": "What is in this image that is also in the first image?"},
            ],
        },
    ]

    # Second API call
    response2 = openai_client.chat.completions.create(
        model=model,
        messages=messages_turn2,
        stream=stream,
    )
    if stream:
        message_content2 = ""
        for chunk in response2:
            message_content2 += chunk.choices[0].delta.content or ""
    else:
        message_content2 = response2.choices[0].message.content
    assert len(message_content2) > 0
    assert any(expected in message_content2.lower().strip() for expected in {"bed"}), message_content2

# --- Helper functions (structured output validation) ---


def get_structured_output(maybe_json_content: str, schema_name: str) -> Any | None:
    if schema_name == "valid_calendar_event":

        class CalendarEvent(BaseModel):
            name: str
            date: str
            participants: list[str]

        try:
            calendar_event = CalendarEvent.model_validate_json(maybe_json_content)
            return calendar_event
        except Exception:
            return None
    elif schema_name == "valid_math_reasoning":

        class Step(BaseModel):
            explanation: str
            output: str

        class MathReasoning(BaseModel):
            steps: list[Step]
            final_answer: str

        try:
            math_reasoning = MathReasoning.model_validate_json(maybe_json_content)
            return math_reasoning
        except Exception:
            return None

    return None


def validate_structured_output(maybe_json_content: str, schema_name: str) -> None:
    structured_output = get_structured_output(maybe_json_content, schema_name)
    assert structured_output is not None
    if schema_name == "valid_calendar_event":
        assert structured_output.name is not None
        assert structured_output.date is not None
        assert len(structured_output.participants) == 2
    elif schema_name == "valid_math_reasoning":
        assert len(structured_output.final_answer) > 0

def _accumulate_streaming_tool_calls(stream):
    """Accumulates tool calls and content from a streaming ChatCompletion response."""
    tool_calls_buffer = {}
    current_id = None
    full_content = ""  # Initialize content accumulator
    # Process streaming chunks
    for chunk in stream:
        choice = chunk.choices[0]
        delta = choice.delta

        # Accumulate content
        if delta.content:
            full_content += delta.content

        if delta.tool_calls is None:
            continue

        for tool_call_delta in delta.tool_calls:
            if tool_call_delta.id:
                current_id = tool_call_delta.id
            call_id = current_id
            # Skip if no ID seen yet for this tool call delta
            if not call_id:
                continue
            func_delta = tool_call_delta.function

            if call_id not in tool_calls_buffer:
                tool_calls_buffer[call_id] = {
                    "id": call_id,
                    "type": "function",  # Assume function type
                    "function": {"name": None, "arguments": ""},  # Nested structure
                }

            # Accumulate name and arguments into the nested function dict
            if func_delta:
                if func_delta.name:
                    tool_calls_buffer[call_id]["function"]["name"] = func_delta.name
                if func_delta.arguments:
                    tool_calls_buffer[call_id]["function"]["arguments"] += func_delta.arguments

    # Return content and tool calls as a list
    return full_content, list(tool_calls_buffer.values())