feat: adds test suite to verify provider's OAI compat endpoints (#1901)

# What does this PR do?

Adds a verification test suite that exercises providers' OpenAI-compatible chat completion endpoints. The suite covers basic chat, image input, structured output (JSON schema), and tool calling, with per-provider model name mappings for Fireworks, Together, Groq, Cerebras, and OpenAI. Test cases are defined in YAML, and the suite can be pointed at any compatible endpoint via the --provider, --base-url, and --api-key options.

## Test Plan

pytest verifications/openai/test_chat_completion.py --provider together
ehhuang 2025-04-08 21:21:38 -07:00 committed by GitHub
parent 7d9adf22ad
commit bcbc56baa2
14 changed files with 9404 additions and 0 deletions

View file

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@@ -0,0 +1,97 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os

import pytest
from openai import OpenAI


@pytest.fixture
def providers_model_mapping():
    """
    Mapping from model names used in test cases to provider's model names.
    """
    return {
        "fireworks": {
            "Llama-3.3-70B-Instruct": "accounts/fireworks/models/llama-v3p1-70b-instruct",
            "Llama-3.2-11B-Vision-Instruct": "accounts/fireworks/models/llama-v3p2-11b-vision-instruct",
            "Llama-4-Scout-17B-16E-Instruct": "accounts/fireworks/models/llama4-scout-instruct-basic",
            "Llama-4-Maverick-17B-128E-Instruct": "accounts/fireworks/models/llama4-maverick-instruct-basic",
        },
        "together": {
            "Llama-3.3-70B-Instruct": "meta-llama/Llama-3.3-70B-Instruct-Turbo",
            "Llama-3.2-11B-Vision-Instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo",
            "Llama-4-Scout-17B-16E-Instruct": "meta-llama/Llama-4-Scout-17B-16E-Instruct",
            "Llama-4-Maverick-17B-128E-Instruct": "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
        },
        "groq": {
            "Llama-3.3-70B-Instruct": "llama-3.3-70b-versatile",
            "Llama-3.2-11B-Vision-Instruct": "llama-3.2-11b-vision-preview",
            "Llama-4-Scout-17B-16E-Instruct": "llama-4-scout-17b-16e-instruct",
            "Llama-4-Maverick-17B-128E-Instruct": "llama-4-maverick-17b-128e-instruct",
        },
        "cerebras": {
            "Llama-3.3-70B-Instruct": "llama-3.3-70b",
        },
        "openai": {
            "gpt-4o": "gpt-4o",
            "gpt-4o-mini": "gpt-4o-mini",
        },
    }


@pytest.fixture
def provider_metadata():
    return {
        "fireworks": ("https://api.fireworks.ai/inference/v1", "FIREWORKS_API_KEY"),
        "together": ("https://api.together.xyz/v1", "TOGETHER_API_KEY"),
        "groq": ("https://api.groq.com/openai/v1", "GROQ_API_KEY"),
        "cerebras": ("https://api.cerebras.ai/v1", "CEREBRAS_API_KEY"),
        "openai": ("https://api.openai.com/v1", "OPENAI_API_KEY"),
    }


@pytest.fixture
def provider(request, provider_metadata):
    provider = request.config.getoption("--provider")
    base_url = request.config.getoption("--base-url")

    if provider and base_url and provider_metadata[provider][0] != base_url:
        raise ValueError(f"Provider {provider} is not supported for base URL {base_url}")

    if not provider:
        if not base_url:
            raise ValueError("Provider and base URL are not provided")
        for provider, metadata in provider_metadata.items():
            if metadata[0] == base_url:
                provider = provider
                break

    return provider


@pytest.fixture
def base_url(request, provider, provider_metadata):
    return request.config.getoption("--base-url") or provider_metadata[provider][0]


@pytest.fixture
def api_key(request, provider, provider_metadata):
    return request.config.getoption("--api-key") or os.getenv(provider_metadata[provider][1])


@pytest.fixture
def model_mapping(provider, providers_model_mapping):
    return providers_model_mapping[provider]


@pytest.fixture
def openai_client(base_url, api_key):
    return OpenAI(
        base_url=base_url,
        api_key=api_key,
    )
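
Note: the fixtures above read `--provider`, `--base-url`, and `--api-key` via `request.config.getoption`, so those options must be registered in a `conftest.py` elsewhere in this PR (not shown in this excerpt). A minimal sketch of what that registration presumably looks like, under that assumption:

```python
# Hypothetical sketch of the option registration these fixtures rely on;
# the actual conftest.py in this PR may differ.
def pytest_addoption(parser):
    parser.addoption(
        "--base-url",
        action="store",
        help="Base URL of the OpenAI-compatible endpoint under test",
    )
    parser.addoption(
        "--api-key",
        action="store",
        help="API key (falls back to the provider's environment variable)",
    )
    parser.addoption(
        "--provider",
        action="store",
        help="Provider name: fireworks, together, groq, cerebras, or openai",
    )
```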

View file

@@ -0,0 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pathlib import Path

import yaml


def load_test_cases(name: str):
    fixture_dir = Path(__file__).parent / "test_cases"
    yaml_path = fixture_dir / f"{name}.yaml"
    with open(yaml_path, "r") as f:
        return yaml.safe_load(f)
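
For reference, a quick sketch of how this loader is consumed, mirroring the access pattern used by the test module further below (it assumes the YAML file shown next lives at `fixtures/test_cases/chat_completion.yaml`):

```python
# Hypothetical standalone usage of the loader; not part of the PR.
cases = load_test_cases("chat_completion")

# Each top-level key is a test group carrying its parametrization data.
basic = cases["test_chat_basic"]["test_params"]
print(basic["model"])                       # e.g. ["Llama-3.3-8B-Instruct", ...]
print(basic["input_output"][0]["output"])   # e.g. "Earth"
```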

View file

@@ -0,0 +1,162 @@
test_chat_basic:
  test_name: test_chat_basic
  test_params:
    input_output:
    - input:
        messages:
        - content: Which planet do humans live on?
          role: user
      output: Earth
    - input:
        messages:
        - content: Which planet has rings around it with a name starting with letter
            S?
          role: user
      output: Saturn
    model:
    - Llama-3.3-8B-Instruct
    - Llama-3.3-70B-Instruct
    - Llama-4-Scout-17B-16E
    - Llama-4-Scout-17B-16E-Instruct
    - Llama-4-Maverick-17B-128E
    - Llama-4-Maverick-17B-128E-Instruct
    - gpt-4o
    - gpt-4o-mini
test_chat_image:
  test_name: test_chat_image
  test_params:
    input_output:
    - input:
        messages:
        - content:
          - text: What is in this image?
            type: text
          - image_url:
              url: https://upload.wikimedia.org/wikipedia/commons/f/f7/Llamas%2C_Vernagt-Stausee%2C_Italy.jpg
            type: image_url
          role: user
      output: llama
    model:
    - Llama-4-Scout-17B-16E
    - Llama-4-Scout-17B-16E-Instruct
    - Llama-4-Maverick-17B-128E
    - Llama-4-Maverick-17B-128E-Instruct
    - gpt-4o
    - gpt-4o-mini
test_chat_structured_output:
  test_name: test_chat_structured_output
  test_params:
    input_output:
    - input:
        messages:
        - content: Extract the event information.
          role: system
        - content: Alice and Bob are going to a science fair on Friday.
          role: user
        response_format:
          json_schema:
            name: calendar_event
            schema:
              properties:
                date:
                  title: Date
                  type: string
                name:
                  title: Name
                  type: string
                participants:
                  items:
                    type: string
                  title: Participants
                  type: array
              required:
              - name
              - date
              - participants
              title: CalendarEvent
              type: object
          type: json_schema
      output: valid_calendar_event
    - input:
        messages:
        - content: You are a helpful math tutor. Guide the user through the solution
            step by step.
          role: system
        - content: how can I solve 8x + 7 = -23
          role: user
        response_format:
          json_schema:
            name: math_reasoning
            schema:
              $defs:
                Step:
                  properties:
                    explanation:
                      title: Explanation
                      type: string
                    output:
                      title: Output
                      type: string
                  required:
                  - explanation
                  - output
                  title: Step
                  type: object
              properties:
                final_answer:
                  title: Final Answer
                  type: string
                steps:
                  items:
                    $ref: '#/$defs/Step'
                  title: Steps
                  type: array
              required:
              - steps
              - final_answer
              title: MathReasoning
              type: object
          type: json_schema
      output: valid_math_reasoning
    model:
    - Llama-3.3-8B-Instruct
    - Llama-3.3-70B-Instruct
    - Llama-4-Scout-17B-16E
    - Llama-4-Scout-17B-16E-Instruct
    - Llama-4-Maverick-17B-128E
    - Llama-4-Maverick-17B-128E-Instruct
    - gpt-4o
    - gpt-4o-mini
test_tool_calling:
  test_name: test_tool_calling
  test_params:
    input_output:
    - input:
        messages:
        - content: You are a helpful assistant that can use tools to get information.
          role: system
        - content: What's the weather like in San Francisco?
          role: user
        tools:
        - function:
            description: Get current temperature for a given location.
            name: get_weather
            parameters:
              additionalProperties: false
              properties:
                location:
                  description: "City and country e.g. Bogot\xE1, Colombia"
                  type: string
              required:
              - location
              type: object
          type: function
      output: get_weather_tool_call
    model:
    - Llama-3.3-70B-Instruct
    - Llama-4-Scout-17B-16E
    - Llama-4-Scout-17B-16E-Instruct
    - Llama-4-Maverick-17B-128E
    - Llama-4-Maverick-17B-128E-Instruct
    - gpt-4o
    - gpt-4o-mini

View file

@@ -0,0 +1,202 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any

import pytest
from pydantic import BaseModel

from tests.verifications.openai.fixtures.load import load_test_cases

chat_completion_test_cases = load_test_cases("chat_completion")


@pytest.fixture
def correct_model_name(model, provider, providers_model_mapping):
    """Return the provider-specific model name based on the generic model name."""
    mapping = providers_model_mapping[provider]
    if model not in mapping:
        pytest.skip(f"Provider {provider} does not support model {model}")
    return mapping[model]


@pytest.mark.parametrize("model", chat_completion_test_cases["test_chat_basic"]["test_params"]["model"])
@pytest.mark.parametrize(
    "input_output",
    chat_completion_test_cases["test_chat_basic"]["test_params"]["input_output"],
)
def test_chat_non_streaming_basic(openai_client, input_output, correct_model_name):
    response = openai_client.chat.completions.create(
        model=correct_model_name,
        messages=input_output["input"]["messages"],
        stream=False,
    )
    assert response.choices[0].message.role == "assistant"
    assert input_output["output"].lower() in response.choices[0].message.content.lower()


@pytest.mark.parametrize("model", chat_completion_test_cases["test_chat_basic"]["test_params"]["model"])
@pytest.mark.parametrize(
    "input_output",
    chat_completion_test_cases["test_chat_basic"]["test_params"]["input_output"],
)
def test_chat_streaming_basic(openai_client, input_output, correct_model_name):
    response = openai_client.chat.completions.create(
        model=correct_model_name,
        messages=input_output["input"]["messages"],
        stream=True,
    )
    content = ""
    for chunk in response:
        content += chunk.choices[0].delta.content or ""

    # TODO: add detailed type validation
    assert input_output["output"].lower() in content.lower()


@pytest.mark.parametrize("model", chat_completion_test_cases["test_chat_image"]["test_params"]["model"])
@pytest.mark.parametrize(
    "input_output",
    chat_completion_test_cases["test_chat_image"]["test_params"]["input_output"],
)
def test_chat_non_streaming_image(openai_client, input_output, correct_model_name):
    response = openai_client.chat.completions.create(
        model=correct_model_name,
        messages=input_output["input"]["messages"],
        stream=False,
    )
    assert response.choices[0].message.role == "assistant"
    assert input_output["output"].lower() in response.choices[0].message.content.lower()


@pytest.mark.parametrize("model", chat_completion_test_cases["test_chat_image"]["test_params"]["model"])
@pytest.mark.parametrize(
    "input_output",
    chat_completion_test_cases["test_chat_image"]["test_params"]["input_output"],
)
def test_chat_streaming_image(openai_client, input_output, correct_model_name):
    response = openai_client.chat.completions.create(
        model=correct_model_name,
        messages=input_output["input"]["messages"],
        stream=True,
    )
    content = ""
    for chunk in response:
        content += chunk.choices[0].delta.content or ""

    # TODO: add detailed type validation
    assert input_output["output"].lower() in content.lower()


@pytest.mark.parametrize(
    "model",
    chat_completion_test_cases["test_chat_structured_output"]["test_params"]["model"],
)
@pytest.mark.parametrize(
    "input_output",
    chat_completion_test_cases["test_chat_structured_output"]["test_params"]["input_output"],
)
def test_chat_non_streaming_structured_output(openai_client, input_output, correct_model_name):
    response = openai_client.chat.completions.create(
        model=correct_model_name,
        messages=input_output["input"]["messages"],
        response_format=input_output["input"]["response_format"],
        stream=False,
    )
    assert response.choices[0].message.role == "assistant"
    maybe_json_content = response.choices[0].message.content

    validate_structured_output(maybe_json_content, input_output["output"])


@pytest.mark.parametrize(
    "model",
    chat_completion_test_cases["test_chat_structured_output"]["test_params"]["model"],
)
@pytest.mark.parametrize(
    "input_output",
    chat_completion_test_cases["test_chat_structured_output"]["test_params"]["input_output"],
)
def test_chat_streaming_structured_output(openai_client, input_output, correct_model_name):
    response = openai_client.chat.completions.create(
        model=correct_model_name,
        messages=input_output["input"]["messages"],
        response_format=input_output["input"]["response_format"],
        stream=True,
    )
    maybe_json_content = ""
    for chunk in response:
        maybe_json_content += chunk.choices[0].delta.content or ""
    validate_structured_output(maybe_json_content, input_output["output"])


@pytest.mark.parametrize(
    "model",
    chat_completion_test_cases["test_tool_calling"]["test_params"]["model"],
)
@pytest.mark.parametrize(
    "input_output",
    chat_completion_test_cases["test_tool_calling"]["test_params"]["input_output"],
)
def test_chat_non_streaming_tool_calling(openai_client, input_output, correct_model_name):
    response = openai_client.chat.completions.create(
        model=correct_model_name,
        messages=input_output["input"]["messages"],
        tools=input_output["input"]["tools"],
        stream=False,
    )
    assert response.choices[0].message.role == "assistant"
    assert len(response.choices[0].message.tool_calls) > 0
    assert input_output["output"] == "get_weather_tool_call"
    assert response.choices[0].message.tool_calls[0].function.name == "get_weather"


# TODO: add detailed type validation
def get_structured_output(maybe_json_content: str, schema_name: str) -> Any | None:
    if schema_name == "valid_calendar_event":

        class CalendarEvent(BaseModel):
            name: str
            date: str
            participants: list[str]

        try:
            calendar_event = CalendarEvent.model_validate_json(maybe_json_content)
            return calendar_event
        except Exception:
            return None
    elif schema_name == "valid_math_reasoning":

        class Step(BaseModel):
            explanation: str
            output: str

        class MathReasoning(BaseModel):
            steps: list[Step]
            final_answer: str

        try:
            math_reasoning = MathReasoning.model_validate_json(maybe_json_content)
            return math_reasoning
        except Exception:
            return None

    return None


def validate_structured_output(maybe_json_content: str, schema_name: str) -> None:
    structured_output = get_structured_output(maybe_json_content, schema_name)
    assert structured_output is not None
    if schema_name == "valid_calendar_event":
        assert structured_output.name is not None
        assert structured_output.date is not None
        assert len(structured_output.participants) == 2
    elif schema_name == "valid_math_reasoning":
        assert len(structured_output.final_answer) > 0
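
As a sanity check of the structured-output validation path, here is a small sketch (not part of the PR) showing how `get_structured_output` and `validate_structured_output` behave on hand-written JSON; it assumes the two helpers above are in scope:

```python
# Hypothetical quick check of the validators above; not part of the PR.
good_event = '{"name": "Science Fair", "date": "Friday", "participants": ["Alice", "Bob"]}'
bad_event = '{"name": "Science Fair"}'  # missing required date/participants fields

# Valid JSON matching CalendarEvent parses into a model instance ...
assert get_structured_output(good_event, "valid_calendar_event") is not None
# ... while incomplete JSON fails Pydantic validation and returns None.
assert get_structured_output(bad_event, "valid_calendar_event") is None

# Passes silently: name and date are set and there are exactly two participants.
validate_structured_output(good_event, "valid_calendar_event")
```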