forked from phoenix-oss/llama-stack-mirror
feat: adds test suite to verify provider's OAI compat endpoints (#1901)
# What does this PR do? ## Test Plan pytest verifications/openai/test_chat_completion.py --provider together
This commit is contained in:
parent
7d9adf22ad
commit
bcbc56baa2
14 changed files with 9404 additions and 0 deletions
5
tests/verifications/openai/__init__.py
Normal file
5
tests/verifications/openai/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
5
tests/verifications/openai/fixtures/__init__.py
Normal file
5
tests/verifications/openai/fixtures/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
97
tests/verifications/openai/fixtures/fixtures.py
Normal file
97
tests/verifications/openai/fixtures/fixtures.py
Normal file
|
@ -0,0 +1,97 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from openai import OpenAI
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def providers_model_mapping():
|
||||
"""
|
||||
Mapping from model names used in test cases to provider's model names.
|
||||
"""
|
||||
return {
|
||||
"fireworks": {
|
||||
"Llama-3.3-70B-Instruct": "accounts/fireworks/models/llama-v3p1-70b-instruct",
|
||||
"Llama-3.2-11B-Vision-Instruct": "accounts/fireworks/models/llama-v3p2-11b-vision-instruct",
|
||||
"Llama-4-Scout-17B-16E-Instruct": "accounts/fireworks/models/llama4-scout-instruct-basic",
|
||||
"Llama-4-Maverick-17B-128E-Instruct": "accounts/fireworks/models/llama4-maverick-instruct-basic",
|
||||
},
|
||||
"together": {
|
||||
"Llama-3.3-70B-Instruct": "meta-llama/Llama-3.3-70B-Instruct-Turbo",
|
||||
"Llama-3.2-11B-Vision-Instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo",
|
||||
"Llama-4-Scout-17B-16E-Instruct": "meta-llama/Llama-4-Scout-17B-16E-Instruct",
|
||||
"Llama-4-Maverick-17B-128E-Instruct": "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
|
||||
},
|
||||
"groq": {
|
||||
"Llama-3.3-70B-Instruct": "llama-3.3-70b-versatile",
|
||||
"Llama-3.2-11B-Vision-Instruct": "llama-3.2-11b-vision-preview",
|
||||
"Llama-4-Scout-17B-16E-Instruct": "llama-4-scout-17b-16e-instruct",
|
||||
"Llama-4-Maverick-17B-128E-Instruct": "llama-4-maverick-17b-128e-instruct",
|
||||
},
|
||||
"cerebras": {
|
||||
"Llama-3.3-70B-Instruct": "llama-3.3-70b",
|
||||
},
|
||||
"openai": {
|
||||
"gpt-4o": "gpt-4o",
|
||||
"gpt-4o-mini": "gpt-4o-mini",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def provider_metadata():
|
||||
return {
|
||||
"fireworks": ("https://api.fireworks.ai/inference/v1", "FIREWORKS_API_KEY"),
|
||||
"together": ("https://api.together.xyz/v1", "TOGETHER_API_KEY"),
|
||||
"groq": ("https://api.groq.com/openai/v1", "GROQ_API_KEY"),
|
||||
"cerebras": ("https://api.cerebras.ai/v1", "CEREBRAS_API_KEY"),
|
||||
"openai": ("https://api.openai.com/v1", "OPENAI_API_KEY"),
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def provider(request, provider_metadata):
|
||||
provider = request.config.getoption("--provider")
|
||||
base_url = request.config.getoption("--base-url")
|
||||
|
||||
if provider and base_url and provider_metadata[provider][0] != base_url:
|
||||
raise ValueError(f"Provider {provider} is not supported for base URL {base_url}")
|
||||
|
||||
if not provider:
|
||||
if not base_url:
|
||||
raise ValueError("Provider and base URL are not provided")
|
||||
for provider, metadata in provider_metadata.items():
|
||||
if metadata[0] == base_url:
|
||||
provider = provider
|
||||
break
|
||||
|
||||
return provider
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def base_url(request, provider, provider_metadata):
|
||||
return request.config.getoption("--base-url") or provider_metadata[provider][0]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def api_key(request, provider, provider_metadata):
|
||||
return request.config.getoption("--api-key") or os.getenv(provider_metadata[provider][1])
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model_mapping(provider, providers_model_mapping):
|
||||
return providers_model_mapping[provider]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def openai_client(base_url, api_key):
|
||||
return OpenAI(
|
||||
base_url=base_url,
|
||||
api_key=api_key,
|
||||
)
|
16
tests/verifications/openai/fixtures/load.py
Normal file
16
tests/verifications/openai/fixtures/load.py
Normal file
|
@ -0,0 +1,16 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
def load_test_cases(name: str):
|
||||
fixture_dir = Path(__file__).parent / "test_cases"
|
||||
yaml_path = fixture_dir / f"{name}.yaml"
|
||||
with open(yaml_path, "r") as f:
|
||||
return yaml.safe_load(f)
|
|
@ -0,0 +1,162 @@
|
|||
test_chat_basic:
|
||||
test_name: test_chat_basic
|
||||
test_params:
|
||||
input_output:
|
||||
- input:
|
||||
messages:
|
||||
- content: Which planet do humans live on?
|
||||
role: user
|
||||
output: Earth
|
||||
- input:
|
||||
messages:
|
||||
- content: Which planet has rings around it with a name starting with letter
|
||||
S?
|
||||
role: user
|
||||
output: Saturn
|
||||
model:
|
||||
- Llama-3.3-8B-Instruct
|
||||
- Llama-3.3-70B-Instruct
|
||||
- Llama-4-Scout-17B-16E
|
||||
- Llama-4-Scout-17B-16E-Instruct
|
||||
- Llama-4-Maverick-17B-128E
|
||||
- Llama-4-Maverick-17B-128E-Instruct
|
||||
- gpt-4o
|
||||
- gpt-4o-mini
|
||||
test_chat_image:
|
||||
test_name: test_chat_image
|
||||
test_params:
|
||||
input_output:
|
||||
- input:
|
||||
messages:
|
||||
- content:
|
||||
- text: What is in this image?
|
||||
type: text
|
||||
- image_url:
|
||||
url: https://upload.wikimedia.org/wikipedia/commons/f/f7/Llamas%2C_Vernagt-Stausee%2C_Italy.jpg
|
||||
type: image_url
|
||||
role: user
|
||||
output: llama
|
||||
model:
|
||||
- Llama-4-Scout-17B-16E
|
||||
- Llama-4-Scout-17B-16E-Instruct
|
||||
- Llama-4-Maverick-17B-128E
|
||||
- Llama-4-Maverick-17B-128E-Instruct
|
||||
- gpt-4o
|
||||
- gpt-4o-mini
|
||||
test_chat_structured_output:
|
||||
test_name: test_chat_structured_output
|
||||
test_params:
|
||||
input_output:
|
||||
- input:
|
||||
messages:
|
||||
- content: Extract the event information.
|
||||
role: system
|
||||
- content: Alice and Bob are going to a science fair on Friday.
|
||||
role: user
|
||||
response_format:
|
||||
json_schema:
|
||||
name: calendar_event
|
||||
schema:
|
||||
properties:
|
||||
date:
|
||||
title: Date
|
||||
type: string
|
||||
name:
|
||||
title: Name
|
||||
type: string
|
||||
participants:
|
||||
items:
|
||||
type: string
|
||||
title: Participants
|
||||
type: array
|
||||
required:
|
||||
- name
|
||||
- date
|
||||
- participants
|
||||
title: CalendarEvent
|
||||
type: object
|
||||
type: json_schema
|
||||
output: valid_calendar_event
|
||||
- input:
|
||||
messages:
|
||||
- content: You are a helpful math tutor. Guide the user through the solution
|
||||
step by step.
|
||||
role: system
|
||||
- content: how can I solve 8x + 7 = -23
|
||||
role: user
|
||||
response_format:
|
||||
json_schema:
|
||||
name: math_reasoning
|
||||
schema:
|
||||
$defs:
|
||||
Step:
|
||||
properties:
|
||||
explanation:
|
||||
title: Explanation
|
||||
type: string
|
||||
output:
|
||||
title: Output
|
||||
type: string
|
||||
required:
|
||||
- explanation
|
||||
- output
|
||||
title: Step
|
||||
type: object
|
||||
properties:
|
||||
final_answer:
|
||||
title: Final Answer
|
||||
type: string
|
||||
steps:
|
||||
items:
|
||||
$ref: '#/$defs/Step'
|
||||
title: Steps
|
||||
type: array
|
||||
required:
|
||||
- steps
|
||||
- final_answer
|
||||
title: MathReasoning
|
||||
type: object
|
||||
type: json_schema
|
||||
output: valid_math_reasoning
|
||||
model:
|
||||
- Llama-3.3-8B-Instruct
|
||||
- Llama-3.3-70B-Instruct
|
||||
- Llama-4-Scout-17B-16E
|
||||
- Llama-4-Scout-17B-16E-Instruct
|
||||
- Llama-4-Maverick-17B-128E
|
||||
- Llama-4-Maverick-17B-128E-Instruct
|
||||
- gpt-4o
|
||||
- gpt-4o-mini
|
||||
test_tool_calling:
|
||||
test_name: test_tool_calling
|
||||
test_params:
|
||||
input_output:
|
||||
- input:
|
||||
messages:
|
||||
- content: You are a helpful assistant that can use tools to get information.
|
||||
role: system
|
||||
- content: What's the weather like in San Francisco?
|
||||
role: user
|
||||
tools:
|
||||
- function:
|
||||
description: Get current temperature for a given location.
|
||||
name: get_weather
|
||||
parameters:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
location:
|
||||
description: "City and country e.g. Bogot\xE1, Colombia"
|
||||
type: string
|
||||
required:
|
||||
- location
|
||||
type: object
|
||||
type: function
|
||||
output: get_weather_tool_call
|
||||
model:
|
||||
- Llama-3.3-70B-Instruct
|
||||
- Llama-4-Scout-17B-16E
|
||||
- Llama-4-Scout-17B-16E-Instruct
|
||||
- Llama-4-Maverick-17B-128E
|
||||
- Llama-4-Maverick-17B-128E-Instruct
|
||||
- gpt-4o
|
||||
- gpt-4o-mini
|
202
tests/verifications/openai/test_chat_completion.py
Normal file
202
tests/verifications/openai/test_chat_completion.py
Normal file
|
@ -0,0 +1,202 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from tests.verifications.openai.fixtures.load import load_test_cases
|
||||
|
||||
chat_completion_test_cases = load_test_cases("chat_completion")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def correct_model_name(model, provider, providers_model_mapping):
|
||||
"""Return the provider-specific model name based on the generic model name."""
|
||||
mapping = providers_model_mapping[provider]
|
||||
if model not in mapping:
|
||||
pytest.skip(f"Provider {provider} does not support model {model}")
|
||||
return mapping[model]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", chat_completion_test_cases["test_chat_basic"]["test_params"]["model"])
|
||||
@pytest.mark.parametrize(
|
||||
"input_output",
|
||||
chat_completion_test_cases["test_chat_basic"]["test_params"]["input_output"],
|
||||
)
|
||||
def test_chat_non_streaming_basic(openai_client, input_output, correct_model_name):
|
||||
response = openai_client.chat.completions.create(
|
||||
model=correct_model_name,
|
||||
messages=input_output["input"]["messages"],
|
||||
stream=False,
|
||||
)
|
||||
assert response.choices[0].message.role == "assistant"
|
||||
assert input_output["output"].lower() in response.choices[0].message.content.lower()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", chat_completion_test_cases["test_chat_basic"]["test_params"]["model"])
|
||||
@pytest.mark.parametrize(
|
||||
"input_output",
|
||||
chat_completion_test_cases["test_chat_basic"]["test_params"]["input_output"],
|
||||
)
|
||||
def test_chat_streaming_basic(openai_client, input_output, correct_model_name):
|
||||
response = openai_client.chat.completions.create(
|
||||
model=correct_model_name,
|
||||
messages=input_output["input"]["messages"],
|
||||
stream=True,
|
||||
)
|
||||
content = ""
|
||||
for chunk in response:
|
||||
content += chunk.choices[0].delta.content or ""
|
||||
|
||||
# TODO: add detailed type validation
|
||||
|
||||
assert input_output["output"].lower() in content.lower()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", chat_completion_test_cases["test_chat_image"]["test_params"]["model"])
|
||||
@pytest.mark.parametrize(
|
||||
"input_output",
|
||||
chat_completion_test_cases["test_chat_image"]["test_params"]["input_output"],
|
||||
)
|
||||
def test_chat_non_streaming_image(openai_client, input_output, correct_model_name):
|
||||
response = openai_client.chat.completions.create(
|
||||
model=correct_model_name,
|
||||
messages=input_output["input"]["messages"],
|
||||
stream=False,
|
||||
)
|
||||
assert response.choices[0].message.role == "assistant"
|
||||
assert input_output["output"].lower() in response.choices[0].message.content.lower()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", chat_completion_test_cases["test_chat_image"]["test_params"]["model"])
|
||||
@pytest.mark.parametrize(
|
||||
"input_output",
|
||||
chat_completion_test_cases["test_chat_image"]["test_params"]["input_output"],
|
||||
)
|
||||
def test_chat_streaming_image(openai_client, input_output, correct_model_name):
|
||||
response = openai_client.chat.completions.create(
|
||||
model=correct_model_name,
|
||||
messages=input_output["input"]["messages"],
|
||||
stream=True,
|
||||
)
|
||||
content = ""
|
||||
for chunk in response:
|
||||
content += chunk.choices[0].delta.content or ""
|
||||
|
||||
# TODO: add detailed type validation
|
||||
|
||||
assert input_output["output"].lower() in content.lower()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
chat_completion_test_cases["test_chat_structured_output"]["test_params"]["model"],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"input_output",
|
||||
chat_completion_test_cases["test_chat_structured_output"]["test_params"]["input_output"],
|
||||
)
|
||||
def test_chat_non_streaming_structured_output(openai_client, input_output, correct_model_name):
|
||||
response = openai_client.chat.completions.create(
|
||||
model=correct_model_name,
|
||||
messages=input_output["input"]["messages"],
|
||||
response_format=input_output["input"]["response_format"],
|
||||
stream=False,
|
||||
)
|
||||
|
||||
assert response.choices[0].message.role == "assistant"
|
||||
maybe_json_content = response.choices[0].message.content
|
||||
|
||||
validate_structured_output(maybe_json_content, input_output["output"])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
chat_completion_test_cases["test_chat_structured_output"]["test_params"]["model"],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"input_output",
|
||||
chat_completion_test_cases["test_chat_structured_output"]["test_params"]["input_output"],
|
||||
)
|
||||
def test_chat_streaming_structured_output(openai_client, input_output, correct_model_name):
|
||||
response = openai_client.chat.completions.create(
|
||||
model=correct_model_name,
|
||||
messages=input_output["input"]["messages"],
|
||||
response_format=input_output["input"]["response_format"],
|
||||
stream=True,
|
||||
)
|
||||
maybe_json_content = ""
|
||||
for chunk in response:
|
||||
maybe_json_content += chunk.choices[0].delta.content or ""
|
||||
validate_structured_output(maybe_json_content, input_output["output"])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
chat_completion_test_cases["test_tool_calling"]["test_params"]["model"],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"input_output",
|
||||
chat_completion_test_cases["test_tool_calling"]["test_params"]["input_output"],
|
||||
)
|
||||
def test_chat_non_streaming_tool_calling(openai_client, input_output, correct_model_name):
|
||||
response = openai_client.chat.completions.create(
|
||||
model=correct_model_name,
|
||||
messages=input_output["input"]["messages"],
|
||||
tools=input_output["input"]["tools"],
|
||||
stream=False,
|
||||
)
|
||||
|
||||
assert response.choices[0].message.role == "assistant"
|
||||
assert len(response.choices[0].message.tool_calls) > 0
|
||||
assert input_output["output"] == "get_weather_tool_call"
|
||||
assert response.choices[0].message.tool_calls[0].function.name == "get_weather"
|
||||
# TODO: add detailed type validation
|
||||
|
||||
|
||||
def get_structured_output(maybe_json_content: str, schema_name: str) -> Any | None:
|
||||
if schema_name == "valid_calendar_event":
|
||||
|
||||
class CalendarEvent(BaseModel):
|
||||
name: str
|
||||
date: str
|
||||
participants: list[str]
|
||||
|
||||
try:
|
||||
calendar_event = CalendarEvent.model_validate_json(maybe_json_content)
|
||||
return calendar_event
|
||||
except Exception:
|
||||
return None
|
||||
elif schema_name == "valid_math_reasoning":
|
||||
|
||||
class Step(BaseModel):
|
||||
explanation: str
|
||||
output: str
|
||||
|
||||
class MathReasoning(BaseModel):
|
||||
steps: list[Step]
|
||||
final_answer: str
|
||||
|
||||
try:
|
||||
math_reasoning = MathReasoning.model_validate_json(maybe_json_content)
|
||||
return math_reasoning
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def validate_structured_output(maybe_json_content: str, schema_name: str) -> None:
|
||||
structured_output = get_structured_output(maybe_json_content, schema_name)
|
||||
assert structured_output is not None
|
||||
if schema_name == "valid_calendar_event":
|
||||
assert structured_output.name is not None
|
||||
assert structured_output.date is not None
|
||||
assert len(structured_output.participants) == 2
|
||||
elif schema_name == "valid_math_reasoning":
|
||||
assert len(structured_output.final_answer) > 0
|
Loading…
Add table
Add a link
Reference in a new issue