Merge branch 'main' into register_custom_model

commit afb792b9c1
Rashmi Pawar, 2025-04-16 14:35:51 +05:30, committed by GitHub
69 changed files with 8875 additions and 890 deletions

View file

@@ -115,7 +115,7 @@ def test_openai_completion_streaming(openai_client, client_with_models, text_mod
stream=True,
max_tokens=50,
)
streamed_content = [chunk.choices[0].text for chunk in response]
streamed_content = [chunk.choices[0].text or "" for chunk in response]
content_str = "".join(streamed_content).lower().strip()
assert len(content_str) > 10
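
The `or ""` guard matters because a streamed completion can emit a chunk whose `text` is `None` (for example, a final chunk that only carries a finish reason), and `"".join(...)` would then raise a `TypeError`. A minimal sketch of the failure mode, using hypothetical `SimpleNamespace` chunks instead of a live API call:

```python
from types import SimpleNamespace

# Hypothetical stream: the last chunk carries only a finish reason, so its text is None.
chunks = [
    SimpleNamespace(choices=[SimpleNamespace(text="Hello", finish_reason=None)]),
    SimpleNamespace(choices=[SimpleNamespace(text=" world", finish_reason=None)]),
    SimpleNamespace(choices=[SimpleNamespace(text=None, finish_reason="stop")]),
]

# Without the `or ""` guard, join() raises TypeError on the None entry.
streamed_content = [chunk.choices[0].text or "" for chunk in chunks]
assert "".join(streamed_content) == "Hello world"
```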

View file

@@ -26,7 +26,12 @@ from openai.types.chat.chat_completion_chunk import (
)
from openai.types.model import Model as OpenAIModel
from llama_stack.apis.inference import ToolChoice, ToolConfig
from llama_stack.apis.inference import (
ChatCompletionRequest,
ToolChoice,
ToolConfig,
UserMessage,
)
from llama_stack.apis.models import Model
from llama_stack.models.llama.datatypes import StopReason
from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig
@@ -232,3 +237,14 @@ def test_chat_completion_doesnt_block_event_loop(caplog):
# above.
asyncio_warnings = [record.message for record in caplog.records if record.name == "asyncio"]
assert not asyncio_warnings
@pytest.mark.asyncio
async def test_get_params_empty_tools(vllm_inference_adapter):
request = ChatCompletionRequest(
tools=[],
model="test_model",
messages=[UserMessage(content="test")],
)
params = await vllm_inference_adapter._get_params(request)
assert "tools" not in params

View file

@@ -0,0 +1,120 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import pytest
from llama_stack.providers.utils.scheduler import JobStatus, Scheduler
@pytest.mark.asyncio
async def test_scheduler_unknown_backend():
with pytest.raises(ValueError):
Scheduler(backend="unknown")
@pytest.mark.asyncio
async def test_scheduler_naive():
sched = Scheduler()
# make sure the scheduler starts empty
with pytest.raises(ValueError):
sched.get_job("unknown")
assert sched.get_jobs() == []
called = False
# schedule a job that will exercise the handlers
async def job_handler(on_log, on_status, on_artifact):
nonlocal called
called = True
# exercise the handlers
on_log("test log1")
on_log("test log2")
on_artifact({"type": "type1", "path": "path1"})
on_artifact({"type": "type2", "path": "path2"})
on_status(JobStatus.completed)
job_id = "test_job_id"
job_type = "test_job_type"
sched.schedule(job_type, job_id, job_handler)
# make sure the job was properly registered
with pytest.raises(ValueError):
sched.get_job("unknown")
assert sched.get_job(job_id) is not None
assert sched.get_jobs() == [sched.get_job(job_id)]
assert sched.get_jobs("unknown") == []
assert sched.get_jobs(job_type) == [sched.get_job(job_id)]
# now shut the scheduler down and make sure the job ran
await sched.shutdown()
assert called
job = sched.get_job(job_id)
assert job is not None
assert job.status == JobStatus.completed
assert job.scheduled_at is not None
assert job.started_at is not None
assert job.completed_at is not None
assert job.scheduled_at < job.started_at < job.completed_at
assert job.artifacts == [
{"type": "type1", "path": "path1"},
{"type": "type2", "path": "path2"},
]
assert [msg[1] for msg in job.logs] == ["test log1", "test log2"]
assert job.logs[0][0] < job.logs[1][0]
@pytest.mark.asyncio
async def test_scheduler_naive_handler_raises():
sched = Scheduler()
async def failing_job_handler(on_log, on_status, on_artifact):
on_status(JobStatus.running)
raise ValueError("test error")
job_id = "test_job_id1"
job_type = "test_job_type"
sched.schedule(job_type, job_id, failing_job_handler)
job = sched.get_job(job_id)
assert job is not None
# confirm the exception made the job transition to failed state, even
# though it was set to `running` before the error
for _ in range(10):
if job.status == JobStatus.failed:
break
await asyncio.sleep(0.1)
assert job.status == JobStatus.failed
# confirm that the raised error got registered in the log
assert job.logs[0][1] == "test error"
# even after a failed job, we can schedule another one
called = False
async def successful_job_handler(on_log, on_status, on_artifact):
nonlocal called
called = True
on_status(JobStatus.completed)
job_id = "test_job_id2"
sched.schedule(job_type, job_id, successful_job_handler)
await sched.shutdown()
assert called
job = sched.get_job(job_id)
assert job is not None
assert job.status == JobStatus.completed
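
Taken together, these tests describe the Scheduler contract: a handler coroutine receives `on_log`, `on_status`, and `on_artifact` callbacks, jobs are registered under a type and id, looked up with `get_job`/`get_jobs`, and `shutdown()` waits for scheduled work to finish. A minimal usage sketch built only from the calls exercised above; the job type, id, and artifact payload are illustrative:

```python
import asyncio

from llama_stack.providers.utils.scheduler import JobStatus, Scheduler


async def main() -> None:
    sched = Scheduler()  # default backend (the "naive" one the tests exercise)

    async def handler(on_log, on_status, on_artifact):
        on_log("starting work")
        on_artifact({"type": "checkpoint", "path": "/tmp/ckpt-0"})
        on_status(JobStatus.completed)

    sched.schedule("example_job_type", "example_job_id", handler)

    job = sched.get_job("example_job_id")
    print(job.status)  # may not be completed yet; the handler runs asynchronously

    await sched.shutdown()  # waits for scheduled handlers to run
    print(sched.get_job("example_job_id").status)  # JobStatus.completed


asyncio.run(main())
```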

View file

@@ -1,6 +1,6 @@
# Test Results Report
*Generated on: 2025-04-10 16:48:18*
*Generated on: 2025-04-14 18:11:37*
*This report was generated by running `python tests/verifications/generate_report.py`*
@@ -15,15 +15,15 @@
| Provider | Pass Rate | Tests Passed | Total Tests |
| --- | --- | --- | --- |
| Together | 64.7% | 22 | 34 |
| Fireworks | 82.4% | 28 | 34 |
| Openai | 100.0% | 24 | 24 |
| Together | 48.7% | 37 | 76 |
| Fireworks | 47.4% | 36 | 76 |
| Openai | 100.0% | 52 | 52 |
## Together
*Tests run on: 2025-04-10 16:46:35*
*Tests run on: 2025-04-14 18:08:14*
```bash
# Run all tests for this provider:
@@ -48,19 +48,33 @@ pytest tests/verifications/openai_api/test_chat_completion.py --provider=togethe
| test_chat_non_streaming_basic (earth) | ✅ | ✅ | ✅ |
| test_chat_non_streaming_basic (saturn) | ✅ | ✅ | ✅ |
| test_chat_non_streaming_image | ⚪ | ✅ | ✅ |
| test_chat_non_streaming_multi_turn_tool_calling (add_product_tool) | ✅ | ✅ | ✅ |
| test_chat_non_streaming_multi_turn_tool_calling (compare_monthly_expense_tool) | ❌ | ✅ | ✅ |
| test_chat_non_streaming_multi_turn_tool_calling (get_then_create_event_tool) | ✅ | ❌ | ✅ |
| test_chat_non_streaming_multi_turn_tool_calling (text_then_weather_tool) | ❌ | ❌ | ❌ |
| test_chat_non_streaming_multi_turn_tool_calling (weather_tool_then_text) | ✅ | ✅ | ✅ |
| test_chat_non_streaming_structured_output (calendar) | ✅ | ✅ | ✅ |
| test_chat_non_streaming_structured_output (math) | ✅ | ✅ | ✅ |
| test_chat_non_streaming_tool_calling | ✅ | ✅ | ✅ |
| test_chat_non_streaming_tool_choice_none | ❌ | ❌ | ❌ |
| test_chat_non_streaming_tool_choice_required | ✅ | ✅ | ✅ |
| test_chat_streaming_basic (earth) | ✅ | ❌ | ❌ |
| test_chat_streaming_basic (saturn) | ✅ | ❌ | ❌ |
| test_chat_streaming_image | ⚪ | ❌ | ❌ |
| test_chat_streaming_multi_turn_tool_calling (add_product_tool) | ✅ | ❌ | ❌ |
| test_chat_streaming_multi_turn_tool_calling (compare_monthly_expense_tool) | ❌ | ❌ | ❌ |
| test_chat_streaming_multi_turn_tool_calling (get_then_create_event_tool) | ❌ | ❌ | ❌ |
| test_chat_streaming_multi_turn_tool_calling (text_then_weather_tool) | ❌ | ❌ | ❌ |
| test_chat_streaming_multi_turn_tool_calling (weather_tool_then_text) | ❌ | ❌ | ❌ |
| test_chat_streaming_structured_output (calendar) | ✅ | ❌ | ❌ |
| test_chat_streaming_structured_output (math) | ✅ | ❌ | ❌ |
| test_chat_streaming_tool_calling | ✅ | ❌ | ❌ |
| test_chat_streaming_tool_choice_none | ❌ | ❌ | ❌ |
| test_chat_streaming_tool_choice_required | ✅ | ❌ | ❌ |
## Fireworks
*Tests run on: 2025-04-10 16:44:44*
*Tests run on: 2025-04-14 18:04:06*
```bash
# Run all tests for this provider:
@@ -85,19 +99,33 @@ pytest tests/verifications/openai_api/test_chat_completion.py --provider=firewor
| test_chat_non_streaming_basic (earth) | ✅ | ✅ | ✅ |
| test_chat_non_streaming_basic (saturn) | ✅ | ✅ | ✅ |
| test_chat_non_streaming_image | ⚪ | ✅ | ✅ |
| test_chat_non_streaming_multi_turn_tool_calling (add_product_tool) | ❌ | ❌ | ❌ |
| test_chat_non_streaming_multi_turn_tool_calling (compare_monthly_expense_tool) | ❌ | ❌ | ❌ |
| test_chat_non_streaming_multi_turn_tool_calling (get_then_create_event_tool) | ❌ | ❌ | ❌ |
| test_chat_non_streaming_multi_turn_tool_calling (text_then_weather_tool) | ❌ | ❌ | ❌ |
| test_chat_non_streaming_multi_turn_tool_calling (weather_tool_then_text) | ❌ | ❌ | ❌ |
| test_chat_non_streaming_structured_output (calendar) | ✅ | ✅ | ✅ |
| test_chat_non_streaming_structured_output (math) | ✅ | ✅ | ✅ |
| test_chat_non_streaming_tool_calling | ❌ | ❌ | ❌ |
| test_chat_non_streaming_tool_choice_none | ✅ | ✅ | ✅ |
| test_chat_non_streaming_tool_choice_required | ✅ | ❌ | ❌ |
| test_chat_streaming_basic (earth) | ✅ | ✅ | ✅ |
| test_chat_streaming_basic (saturn) | ✅ | ✅ | ✅ |
| test_chat_streaming_image | ⚪ | ✅ | ✅ |
| test_chat_streaming_multi_turn_tool_calling (add_product_tool) | ❌ | ❌ | ❌ |
| test_chat_streaming_multi_turn_tool_calling (compare_monthly_expense_tool) | ❌ | ❌ | ❌ |
| test_chat_streaming_multi_turn_tool_calling (get_then_create_event_tool) | ❌ | ❌ | ❌ |
| test_chat_streaming_multi_turn_tool_calling (text_then_weather_tool) | ❌ | ❌ | ❌ |
| test_chat_streaming_multi_turn_tool_calling (weather_tool_then_text) | ❌ | ❌ | ❌ |
| test_chat_streaming_structured_output (calendar) | ✅ | ✅ | ✅ |
| test_chat_streaming_structured_output (math) | ✅ | ✅ | ✅ |
| test_chat_streaming_tool_calling | ❌ | ❌ | ❌ |
| test_chat_streaming_tool_choice_none | ✅ | ✅ | ✅ |
| test_chat_streaming_tool_choice_required | ✅ | ❌ | ❌ |
## Openai
*Tests run on: 2025-04-10 16:47:28*
*Tests run on: 2025-04-14 18:09:51*
```bash
# Run all tests for this provider:
@@ -121,12 +149,26 @@ pytest tests/verifications/openai_api/test_chat_completion.py --provider=openai
| test_chat_non_streaming_basic (earth) | ✅ | ✅ |
| test_chat_non_streaming_basic (saturn) | ✅ | ✅ |
| test_chat_non_streaming_image | ✅ | ✅ |
| test_chat_non_streaming_multi_turn_tool_calling (add_product_tool) | ✅ | ✅ |
| test_chat_non_streaming_multi_turn_tool_calling (compare_monthly_expense_tool) | ✅ | ✅ |
| test_chat_non_streaming_multi_turn_tool_calling (get_then_create_event_tool) | ✅ | ✅ |
| test_chat_non_streaming_multi_turn_tool_calling (text_then_weather_tool) | ✅ | ✅ |
| test_chat_non_streaming_multi_turn_tool_calling (weather_tool_then_text) | ✅ | ✅ |
| test_chat_non_streaming_structured_output (calendar) | ✅ | ✅ |
| test_chat_non_streaming_structured_output (math) | ✅ | ✅ |
| test_chat_non_streaming_tool_calling | ✅ | ✅ |
| test_chat_non_streaming_tool_choice_none | ✅ | ✅ |
| test_chat_non_streaming_tool_choice_required | ✅ | ✅ |
| test_chat_streaming_basic (earth) | ✅ | ✅ |
| test_chat_streaming_basic (saturn) | ✅ | ✅ |
| test_chat_streaming_image | ✅ | ✅ |
| test_chat_streaming_multi_turn_tool_calling (add_product_tool) | ✅ | ✅ |
| test_chat_streaming_multi_turn_tool_calling (compare_monthly_expense_tool) | ✅ | ✅ |
| test_chat_streaming_multi_turn_tool_calling (get_then_create_event_tool) | ✅ | ✅ |
| test_chat_streaming_multi_turn_tool_calling (text_then_weather_tool) | ✅ | ✅ |
| test_chat_streaming_multi_turn_tool_calling (weather_tool_then_text) | ✅ | ✅ |
| test_chat_streaming_structured_output (calendar) | ✅ | ✅ |
| test_chat_streaming_structured_output (math) | ✅ | ✅ |
| test_chat_streaming_tool_calling | ✅ | ✅ |
| test_chat_streaming_tool_choice_none | ✅ | ✅ |
| test_chat_streaming_tool_choice_required | ✅ | ✅ |

View file

@@ -0,0 +1,14 @@
base_url: http://localhost:8321/v1/openai/v1
api_key_var: FIREWORKS_API_KEY
models:
- fireworks/llama-v3p3-70b-instruct
- fireworks/llama4-scout-instruct-basic
- fireworks/llama4-maverick-instruct-basic
model_display_names:
fireworks/llama-v3p3-70b-instruct: Llama-3.3-70B-Instruct
fireworks/llama4-scout-instruct-basic: Llama-4-Scout-Instruct
fireworks/llama4-maverick-instruct-basic: Llama-4-Maverick-Instruct
test_exclusions:
fireworks/llama-v3p3-70b-instruct:
- test_chat_non_streaming_image
- test_chat_streaming_image

View file

@@ -0,0 +1,14 @@
base_url: http://localhost:8321/v1/openai/v1
api_key_var: GROQ_API_KEY
models:
- groq/llama-3.3-70b-versatile
- groq/llama-4-scout-17b-16e-instruct
- groq/llama-4-maverick-17b-128e-instruct
model_display_names:
groq/llama-3.3-70b-versatile: Llama-3.3-70B-Instruct
groq/llama-4-scout-17b-16e-instruct: Llama-4-Scout-Instruct
groq/llama-4-maverick-17b-128e-instruct: Llama-4-Maverick-Instruct
test_exclusions:
groq/llama-3.3-70b-versatile:
- test_chat_non_streaming_image
- test_chat_streaming_image

View file

@@ -2,12 +2,12 @@ base_url: https://api.groq.com/openai/v1
api_key_var: GROQ_API_KEY
models:
- llama-3.3-70b-versatile
- llama-4-scout-17b-16e-instruct
- llama-4-maverick-17b-128e-instruct
- meta-llama/llama-4-scout-17b-16e-instruct
- meta-llama/llama-4-maverick-17b-128e-instruct
model_display_names:
llama-3.3-70b-versatile: Llama-3.3-70B-Instruct
llama-4-scout-17b-16e-instruct: Llama-4-Scout-Instruct
llama-4-maverick-17b-128e-instruct: Llama-4-Maverick-Instruct
meta-llama/llama-4-scout-17b-16e-instruct: Llama-4-Scout-Instruct
meta-llama/llama-4-maverick-17b-128e-instruct: Llama-4-Maverick-Instruct
test_exclusions:
llama-3.3-70b-versatile:
- test_chat_non_streaming_image

View file

@@ -0,0 +1,9 @@
base_url: http://localhost:8321/v1/openai/v1
api_key_var: OPENAI_API_KEY
models:
- openai/gpt-4o
- openai/gpt-4o-mini
model_display_names:
openai/gpt-4o: gpt-4o
openai/gpt-4o-mini: gpt-4o-mini
test_exclusions: {}

View file

@@ -0,0 +1,14 @@
base_url: http://localhost:8321/v1/openai/v1
api_key_var: TOGETHER_API_KEY
models:
- together/meta-llama/Llama-3.3-70B-Instruct-Turbo
- together/meta-llama/Llama-4-Scout-17B-16E-Instruct
- together/meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8
model_display_names:
together/meta-llama/Llama-3.3-70B-Instruct-Turbo: Llama-3.3-70B-Instruct
together/meta-llama/Llama-4-Scout-17B-16E-Instruct: Llama-4-Scout-Instruct
together/meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8: Llama-4-Maverick-Instruct
test_exclusions:
together/meta-llama/Llama-3.3-70B-Instruct-Turbo:
- test_chat_non_streaming_image
- test_chat_streaming_image
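
Each of the new `*-llama-stack` configs follows the same shape: a `base_url` pointing at the local stack's OpenAI-compatible endpoint, an `api_key_var`, a `models` list, `model_display_names`, and per-model `test_exclusions`. A small sketch of how such a file can drive skip decisions; the function name and the config path are illustrative, not the suite's actual `should_skip_test` implementation:

```python
import yaml  # assumes PyYAML is available


def is_excluded(config_path: str, model: str, base_test_name: str) -> bool:
    """Illustrative helper: is `base_test_name` excluded for `model` in one config file?"""
    with open(config_path) as f:
        config = yaml.safe_load(f)
    exclusions = config.get("test_exclusions", {}).get(model, [])
    return base_test_name in exclusions


# For the Together config above, image tests are excluded for the Llama 3.3 model:
# is_excluded("together-llama-stack.yaml",
#             "together/meta-llama/Llama-3.3-70B-Instruct-Turbo",
#             "test_chat_non_streaming_image")  # -> True
```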

View file

@@ -67,7 +67,17 @@ RESULTS_DIR.mkdir(exist_ok=True)
# Maximum number of test result files to keep per provider
MAX_RESULTS_PER_PROVIDER = 1
PROVIDER_ORDER = ["together", "fireworks", "groq", "cerebras", "openai"]
PROVIDER_ORDER = [
"together",
"fireworks",
"groq",
"cerebras",
"openai",
"together-llama-stack",
"fireworks-llama-stack",
"groq-llama-stack",
"openai-llama-stack",
]
VERIFICATION_CONFIG = _load_all_verification_configs()

View file

@@ -0,0 +1,146 @@
version: '2'
image_name: openai-api-verification
apis:
- inference
- telemetry
- tool_runtime
- vector_io
providers:
inference:
- provider_id: together
provider_type: remote::together
config:
url: https://api.together.xyz/v1
api_key: ${env.TOGETHER_API_KEY:}
- provider_id: fireworks
provider_type: remote::fireworks
config:
url: https://api.fireworks.ai/inference/v1
api_key: ${env.FIREWORKS_API_KEY}
- provider_id: groq
provider_type: remote::groq
config:
url: https://api.groq.com
api_key: ${env.GROQ_API_KEY}
- provider_id: openai
provider_type: remote::openai
config:
url: https://api.openai.com/v1
api_key: ${env.OPENAI_API_KEY:}
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
config: {}
vector_io:
- provider_id: faiss
provider_type: inline::faiss
config:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/openai}/faiss_store.db
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/openai/trace_store.db}
tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search
config:
api_key: ${env.BRAVE_SEARCH_API_KEY:}
max_results: 3
- provider_id: tavily-search
provider_type: remote::tavily-search
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:}
max_results: 3
- provider_id: code-interpreter
provider_type: inline::code-interpreter
config: {}
- provider_id: rag-runtime
provider_type: inline::rag-runtime
config: {}
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
config: {}
- provider_id: wolfram-alpha
provider_type: remote::wolfram-alpha
config:
api_key: ${env.WOLFRAM_ALPHA_API_KEY:}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/openai}/registry.db
models:
- metadata: {}
model_id: together/meta-llama/Llama-3.3-70B-Instruct-Turbo
provider_id: together
provider_model_id: meta-llama/Llama-3.3-70B-Instruct-Turbo
model_type: llm
- metadata: {}
model_id: together/meta-llama/Llama-4-Scout-17B-16E-Instruct
provider_id: together
provider_model_id: meta-llama/Llama-4-Scout-17B-16E-Instruct
model_type: llm
- metadata: {}
model_id: together/meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8
provider_id: together
provider_model_id: meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8
model_type: llm
- metadata: {}
model_id: fireworks/llama-v3p3-70b-instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p3-70b-instruct
model_type: llm
- metadata: {}
model_id: fireworks/llama4-scout-instruct-basic
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama4-scout-instruct-basic
model_type: llm
- metadata: {}
model_id: fireworks/llama4-maverick-instruct-basic
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama4-maverick-instruct-basic
model_type: llm
- metadata: {}
model_id: groq/llama-3.3-70b-versatile
provider_id: groq
provider_model_id: groq/llama-3.3-70b-versatile
model_type: llm
- metadata: {}
model_id: groq/llama-4-scout-17b-16e-instruct
provider_id: groq
provider_model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct
model_type: llm
- metadata: {}
model_id: groq/llama-4-maverick-17b-128e-instruct
provider_id: groq
provider_model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct
model_type: llm
- metadata: {}
model_id: openai/gpt-4o
provider_id: openai
provider_model_id: openai/gpt-4o
model_type: llm
- metadata: {}
model_id: openai/gpt-4o-mini
provider_id: openai
provider_model_id: openai/gpt-4o-mini
model_type: llm
shields: []
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
- toolgroup_id: builtin::code_interpreter
provider_id: code-interpreter
- toolgroup_id: builtin::wolfram_alpha
provider_id: wolfram-alpha
server:
port: 8321
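
This run config leans on Llama Stack's `${env.VAR}` / `${env.VAR:default}` substitution: `api_key: ${env.TOGETHER_API_KEY:}` appears to fall back to an empty string when the variable is unset, while `${env.FIREWORKS_API_KEY}` has no fallback. A rough standalone sketch of that substitution rule, written as an illustration rather than the stack's own resolver:

```python
import os
import re

_ENV_PATTERN = re.compile(r"\$\{env\.([A-Z0-9_]+)(?::([^}]*))?\}")


def substitute_env(value: str) -> str:
    """Rough sketch of ${env.VAR} / ${env.VAR:default} expansion (not the stack's resolver)."""
    def _replace(match: re.Match) -> str:
        name, default = match.group(1), match.group(2)
        if name in os.environ:
            return os.environ[name]
        if default is not None:  # a trailing ':' means "default to the empty string"
            return default
        raise ValueError(f"environment variable {name} is required but not set")

    return _ENV_PATTERN.sub(_replace, value)


# With SQLITE_STORE_DIR unset:
# substitute_env("${env.SQLITE_STORE_DIR:~/.llama/distributions/openai}/registry.db")
# -> "~/.llama/distributions/openai/registry.db"
```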

View file

@@ -99,6 +99,9 @@ def model_mapping(provider, providers_model_mapping):
@pytest.fixture
def openai_client(base_url, api_key):
# Simplify running against a local Llama Stack
if "localhost" in base_url and not api_key:
api_key = "empty"
return OpenAI(
base_url=base_url,
api_key=api_key,
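
With this change, pointing the verification suite at a local Llama Stack no longer requires exporting a real key: the fixture falls back to the placeholder `"empty"`, which a locally running stack is expected to accept. Roughly equivalent standalone usage, with the base URL taken from the new `*-llama-stack` configs:

```python
from openai import OpenAI

# Placeholder key: the OpenAI client requires some value, and a local
# Llama Stack endpoint is not expected to validate it.
client = OpenAI(
    base_url="http://localhost:8321/v1/openai/v1",
    api_key="empty",
)
```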

View file

@@ -131,3 +131,221 @@ test_tool_calling:
type: object
type: function
output: get_weather_tool_call
test_chat_multi_turn_tool_calling:
test_name: test_chat_multi_turn_tool_calling
test_params:
case:
- case_id: "text_then_weather_tool"
input:
messages:
- - role: user
content: "What's the name of the Sun in latin?"
- - role: user
content: "What's the weather like in San Francisco?"
tools:
- function:
description: Get the current weather
name: get_weather
parameters:
type: object
properties:
location:
description: "The city and state (both required), e.g. San Francisco, CA."
type: string
required: ["location"]
type: function
tool_responses:
- response: "{'response': '70 degrees and foggy'}"
expected:
- num_tool_calls: 0
answer: ["sol"]
- num_tool_calls: 1
tool_name: get_weather
tool_arguments:
location: "San Francisco, CA"
- num_tool_calls: 0
answer: ["foggy", "70 degrees"]
- case_id: "weather_tool_then_text"
input:
messages:
- - role: user
content: "What's the weather like in San Francisco?"
tools:
- function:
description: Get the current weather
name: get_weather
parameters:
type: object
properties:
location:
description: "The city and state (both required), e.g. San Francisco, CA."
type: string
required: ["location"]
type: function
tool_responses:
- response: "{'response': '70 degrees and foggy'}"
expected:
- num_tool_calls: 1
tool_name: get_weather
tool_arguments:
location: "San Francisco, CA"
- num_tool_calls: 0
answer: ["foggy", "70 degrees"]
- case_id: "add_product_tool"
input:
messages:
- - role: user
content: "Please add a new product with name 'Widget', price 19.99, in stock, and tags ['new', 'sale'] and give me the product id."
tools:
- function:
description: Add a new product
name: addProduct
parameters:
type: object
properties:
name:
description: "Name of the product"
type: string
price:
description: "Price of the product"
type: number
inStock:
description: "Availability status of the product."
type: boolean
tags:
description: "List of product tags"
type: array
items:
type: string
required: ["name", "price", "inStock"]
type: function
tool_responses:
- response: "{'response': 'Successfully added product with id: 123'}"
expected:
- num_tool_calls: 1
tool_name: addProduct
tool_arguments:
name: "Widget"
price: 19.99
inStock: true
tags:
- "new"
- "sale"
- num_tool_calls: 0
answer: ["123", "product id: 123"]
- case_id: "get_then_create_event_tool"
input:
messages:
- - role: system
content: "Todays date is 2025-03-01."
- role: user
content: "Do i have any meetings on March 3rd at 10 am? Yes or no?"
- - role: user
content: "Alright then, Create an event named 'Team Building', scheduled for that time same time, in the 'Main Conference Room' and add Alice, Bob, Charlie to it. Give me the created event id."
tools:
- function:
description: Create a new event
name: create_event
parameters:
type: object
properties:
name:
description: "Name of the event"
type: string
date:
description: "Date of the event in ISO format"
type: string
time:
description: "Event Time (HH:MM)"
type: string
location:
description: "Location of the event"
type: string
participants:
description: "List of participant names"
type: array
items:
type: string
required: ["name", "date", "time", "location", "participants"]
type: function
- function:
description: Get an event by date and time
name: get_event
parameters:
type: object
properties:
date:
description: "Date of the event in ISO format"
type: string
time:
description: "Event Time (HH:MM)"
type: string
required: ["date", "time"]
type: function
tool_responses:
- response: "{'response': 'No events found for 2025-03-03 at 10:00'}"
- response: "{'response': 'Successfully created new event with id: e_123'}"
expected:
- num_tool_calls: 1
tool_name: get_event
tool_arguments:
date: "2025-03-03"
time: "10:00"
- num_tool_calls: 0
answer: ["no", "no events found", "no meetings"]
- num_tool_calls: 1
tool_name: create_event
tool_arguments:
name: "Team Building"
date: "2025-03-03"
time: "10:00"
location: "Main Conference Room"
participants:
- "Alice"
- "Bob"
- "Charlie"
- num_tool_calls: 0
answer: ["e_123", "event id: e_123"]
- case_id: "compare_monthly_expense_tool"
input:
messages:
- - role: system
content: "Todays date is 2025-03-01."
- role: user
content: "what was my monthly expense in Jan of this year?"
- - role: user
content: "Was it less than Feb of last year? Only answer with yes or no."
tools:
- function:
description: Get monthly expense summary
name: getMonthlyExpenseSummary
parameters:
type: object
properties:
month:
description: "Month of the year (1-12)"
type: integer
year:
description: "Year"
type: integer
required: ["month", "year"]
type: function
tool_responses:
- response: "{'response': 'Total expenses for January 2025: $1000'}"
- response: "{'response': 'Total expenses for February 2024: $2000'}"
expected:
- num_tool_calls: 1
tool_name: getMonthlyExpenseSummary
tool_arguments:
month: 1
year: 2025
- num_tool_calls: 0
answer: ["1000", "$1,000", "1,000"]
- num_tool_calls: 1
tool_name: getMonthlyExpenseSummary
tool_arguments:
month: 2
year: 2024
- num_tool_calls: 0
answer: ["yes"]

View file

@@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import copy
import json
import re
from typing import Any
@@ -243,43 +244,294 @@ def test_chat_streaming_tool_calling(request, openai_client, model, provider, ve
stream=True,
)
# Accumulate partial tool_calls here
tool_calls_buffer = {}
current_id = None
# Process streaming chunks
for chunk in stream:
choice = chunk.choices[0]
delta = choice.delta
if delta.tool_calls is None:
continue
for tool_call_delta in delta.tool_calls:
if tool_call_delta.id:
current_id = tool_call_delta.id
call_id = current_id
func_delta = tool_call_delta.function
if call_id not in tool_calls_buffer:
tool_calls_buffer[call_id] = {
"id": call_id,
"type": tool_call_delta.type,
"name": func_delta.name,
"arguments": "",
}
if func_delta.arguments:
tool_calls_buffer[call_id]["arguments"] += func_delta.arguments
_, tool_calls_buffer = _accumulate_streaming_tool_calls(stream)
assert len(tool_calls_buffer) == 1
for call in tool_calls_buffer.values():
for call in tool_calls_buffer:
assert len(call["id"]) > 0
assert call["name"] == "get_weather"
function = call["function"]
assert function["name"] == "get_weather"
args_dict = json.loads(call["arguments"])
args_dict = json.loads(function["arguments"])
assert "san francisco" in args_dict["location"].lower()
@pytest.mark.parametrize(
"case",
chat_completion_test_cases["test_tool_calling"]["test_params"]["case"], # Reusing existing case for now
ids=case_id_generator,
)
def test_chat_non_streaming_tool_choice_required(request, openai_client, model, provider, verification_config, case):
test_name_base = get_base_test_name(request)
if should_skip_test(verification_config, provider, model, test_name_base):
pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
response = openai_client.chat.completions.create(
model=model,
messages=case["input"]["messages"],
tools=case["input"]["tools"],
tool_choice="required", # Force tool call
stream=False,
)
print(response)
assert response.choices[0].message.role == "assistant"
assert len(response.choices[0].message.tool_calls) > 0, "Expected tool call when tool_choice='required'"
expected_tool_name = case["input"]["tools"][0]["function"]["name"]
assert response.choices[0].message.tool_calls[0].function.name == expected_tool_name
@pytest.mark.parametrize(
"case",
chat_completion_test_cases["test_tool_calling"]["test_params"]["case"], # Reusing existing case for now
ids=case_id_generator,
)
def test_chat_streaming_tool_choice_required(request, openai_client, model, provider, verification_config, case):
test_name_base = get_base_test_name(request)
if should_skip_test(verification_config, provider, model, test_name_base):
pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
stream = openai_client.chat.completions.create(
model=model,
messages=case["input"]["messages"],
tools=case["input"]["tools"],
tool_choice="required", # Force tool call
stream=True,
)
_, tool_calls_buffer = _accumulate_streaming_tool_calls(stream)
assert len(tool_calls_buffer) > 0, "Expected tool call when tool_choice='required'"
expected_tool_name = case["input"]["tools"][0]["function"]["name"]
assert any(call["function"]["name"] == expected_tool_name for call in tool_calls_buffer), (
f"Expected tool call '{expected_tool_name}' not found in stream"
)
@pytest.mark.parametrize(
"case",
chat_completion_test_cases["test_tool_calling"]["test_params"]["case"], # Reusing existing case for now
ids=case_id_generator,
)
def test_chat_non_streaming_tool_choice_none(request, openai_client, model, provider, verification_config, case):
test_name_base = get_base_test_name(request)
if should_skip_test(verification_config, provider, model, test_name_base):
pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
response = openai_client.chat.completions.create(
model=model,
messages=case["input"]["messages"],
tools=case["input"]["tools"],
tool_choice="none",
stream=False,
)
assert response.choices[0].message.role == "assistant"
assert response.choices[0].message.tool_calls is None, "Expected no tool calls when tool_choice='none'"
assert response.choices[0].message.content is not None, "Expected content when tool_choice='none'"
@pytest.mark.parametrize(
"case",
chat_completion_test_cases["test_tool_calling"]["test_params"]["case"], # Reusing existing case for now
ids=case_id_generator,
)
def test_chat_streaming_tool_choice_none(request, openai_client, model, provider, verification_config, case):
test_name_base = get_base_test_name(request)
if should_skip_test(verification_config, provider, model, test_name_base):
pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
stream = openai_client.chat.completions.create(
model=model,
messages=case["input"]["messages"],
tools=case["input"]["tools"],
tool_choice="none",
stream=True,
)
content = ""
for chunk in stream:
delta = chunk.choices[0].delta
if delta.content:
content += delta.content
assert not delta.tool_calls, "Expected no tool call chunks when tool_choice='none'"
assert len(content) > 0, "Expected content when tool_choice='none'"
@pytest.mark.parametrize(
"case",
chat_completion_test_cases.get("test_chat_multi_turn_tool_calling", {}).get("test_params", {}).get("case", []),
ids=case_id_generator,
)
def test_chat_non_streaming_multi_turn_tool_calling(request, openai_client, model, provider, verification_config, case):
"""
Multi-turn tool-calling cases: each turn's tool calls are asserted, the canned
tool responses from the test case are fed back, and the final assistant
response is checked against the expected answers.
"""
test_name_base = get_base_test_name(request)
if should_skip_test(verification_config, provider, model, test_name_base):
pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
# Build the conversation history in a fresh list so the original case data is not mutated
messages = []
tools = case["input"]["tools"]
# Use deepcopy to prevent modification across runs/parametrization
expected_results = copy.deepcopy(case["expected"])
tool_responses = copy.deepcopy(case.get("tool_responses", []))
input_messages_turns = copy.deepcopy(case["input"]["messages"])
# Keep looping while either:
# 1. there are input message turns left to send, or
# 2. the last message is a tool response that still needs a model reply
while len(input_messages_turns) > 0 or (len(messages) > 0 and messages[-1]["role"] == "tool"):
# do not take new messages if last message is tool response
if len(messages) == 0 or messages[-1]["role"] != "tool":
new_messages = input_messages_turns.pop(0)
# Ensure new_messages is a list of message objects
if isinstance(new_messages, list):
messages.extend(new_messages)
else:
# If it's a single message object, add it directly
messages.append(new_messages)
# --- API Call ---
response = openai_client.chat.completions.create(
model=model,
messages=messages,
tools=tools,
stream=False,
)
# --- Process Response ---
assistant_message = response.choices[0].message
messages.append(assistant_message.model_dump(exclude_unset=True))
assert assistant_message.role == "assistant"
# Get the expected result data
expected = expected_results.pop(0)
num_tool_calls = expected["num_tool_calls"]
# --- Assertions based on expected result ---
assert len(assistant_message.tool_calls or []) == num_tool_calls, (
f"Expected {num_tool_calls} tool calls, but got {len(assistant_message.tool_calls or [])}"
)
if num_tool_calls > 0:
tool_call = assistant_message.tool_calls[0]
assert tool_call.function.name == expected["tool_name"], (
f"Expected tool '{expected['tool_name']}', got '{tool_call.function.name}'"
)
# Parse the JSON string arguments before comparing
actual_arguments = json.loads(tool_call.function.arguments)
assert actual_arguments == expected["tool_arguments"], (
f"Expected arguments '{expected['tool_arguments']}', got '{actual_arguments}'"
)
# Prepare and append the tool response for the next turn
tool_response = tool_responses.pop(0)
messages.append(
{
"role": "tool",
"tool_call_id": tool_call.id,
"content": tool_response["response"],
}
)
else:
assert assistant_message.content is not None, "Expected content, but none received."
expected_answers = expected["answer"] # This is now a list
content_lower = assistant_message.content.lower()
assert any(ans.lower() in content_lower for ans in expected_answers), (
f"Expected one of {expected_answers} in content, but got: '{assistant_message.content}'"
)
@pytest.mark.parametrize(
"case",
chat_completion_test_cases.get("test_chat_multi_turn_tool_calling", {}).get("test_params", {}).get("case", []),
ids=case_id_generator,
)
def test_chat_streaming_multi_turn_tool_calling(request, openai_client, model, provider, verification_config, case):
""" """
test_name_base = get_base_test_name(request)
if should_skip_test(verification_config, provider, model, test_name_base):
pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
messages = []
tools = case["input"]["tools"]
expected_results = copy.deepcopy(case["expected"])
tool_responses = copy.deepcopy(case.get("tool_responses", []))
input_messages_turns = copy.deepcopy(case["input"]["messages"])
while len(input_messages_turns) > 0 or (len(messages) > 0 and messages[-1]["role"] == "tool"):
if len(messages) == 0 or messages[-1]["role"] != "tool":
new_messages = input_messages_turns.pop(0)
if isinstance(new_messages, list):
messages.extend(new_messages)
else:
messages.append(new_messages)
# --- API Call (Streaming) ---
stream = openai_client.chat.completions.create(
model=model,
messages=messages,
tools=tools,
stream=True,
)
# --- Process Stream ---
accumulated_content, accumulated_tool_calls = _accumulate_streaming_tool_calls(stream)
# --- Construct Assistant Message for History ---
assistant_message_dict = {"role": "assistant"}
if accumulated_content:
assistant_message_dict["content"] = accumulated_content
if accumulated_tool_calls:
assistant_message_dict["tool_calls"] = accumulated_tool_calls
messages.append(assistant_message_dict)
# --- Assertions ---
expected = expected_results.pop(0)
num_tool_calls = expected["num_tool_calls"]
assert len(accumulated_tool_calls or []) == num_tool_calls, (
f"Expected {num_tool_calls} tool calls, but got {len(accumulated_tool_calls or [])}"
)
if num_tool_calls > 0:
# Use the first accumulated tool call for assertion
tool_call = accumulated_tool_calls[0]
assert tool_call["function"]["name"] == expected["tool_name"], (
f"Expected tool '{expected['tool_name']}', got '{tool_call['function']['name']}'"
)
# Parse the accumulated arguments string for comparison
actual_arguments = json.loads(tool_call["function"]["arguments"])
assert actual_arguments == expected["tool_arguments"], (
f"Expected arguments '{expected['tool_arguments']}', got '{actual_arguments}'"
)
# Prepare and append the tool response for the next turn
tool_response = tool_responses.pop(0)
messages.append(
{
"role": "tool",
"tool_call_id": tool_call["id"],
"content": tool_response["response"],
}
)
else:
assert accumulated_content is not None and accumulated_content != "", "Expected content, but none received."
expected_answers = expected["answer"]
content_lower = accumulated_content.lower()
assert any(ans.lower() in content_lower for ans in expected_answers), (
f"Expected one of {expected_answers} in content, but got: '{accumulated_content}'"
)
# --- Helper functions (structured output validation) ---
@@ -324,3 +576,47 @@ def validate_structured_output(maybe_json_content: str, schema_name: str) -> Non
assert len(structured_output.participants) == 2
elif schema_name == "valid_math_reasoning":
assert len(structured_output.final_answer) > 0
def _accumulate_streaming_tool_calls(stream):
"""Accumulates tool calls and content from a streaming ChatCompletion response."""
tool_calls_buffer = {}
current_id = None
full_content = "" # Initialize content accumulator
# Process streaming chunks
for chunk in stream:
choice = chunk.choices[0]
delta = choice.delta
# Accumulate content
if delta.content:
full_content += delta.content
if delta.tool_calls is None:
continue
for tool_call_delta in delta.tool_calls:
if tool_call_delta.id:
current_id = tool_call_delta.id
call_id = current_id
# Skip if no ID seen yet for this tool call delta
if not call_id:
continue
func_delta = tool_call_delta.function
if call_id not in tool_calls_buffer:
tool_calls_buffer[call_id] = {
"id": call_id,
"type": "function", # Assume function type
"function": {"name": None, "arguments": ""}, # Nested structure
}
# Accumulate name and arguments into the nested function dict
if func_delta:
if func_delta.name:
tool_calls_buffer[call_id]["function"]["name"] = func_delta.name
if func_delta.arguments:
tool_calls_buffer[call_id]["function"]["arguments"] += func_delta.arguments
# Return content and tool calls as a list
return full_content, list(tool_calls_buffer.values())
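
The helper normalizes a tool-call stream into the same shape the non-streaming assertions use: a `(content, tool_calls)` pair where each tool call is a dict with `id`, `type`, and a nested `function` dict holding the accumulated `name` and `arguments` string. A small usage sketch with hypothetical `SimpleNamespace` chunks standing in for OpenAI streaming deltas:

```python
import json
from types import SimpleNamespace as NS

# Hypothetical two-chunk stream: the first delta opens the tool call, the second
# completes the JSON arguments (real streams split arguments across many chunks).
fake_stream = [
    NS(choices=[NS(delta=NS(content=None, tool_calls=[
        NS(id="call_1", type="function",
           function=NS(name="get_weather", arguments='{"location": "San Fra'))]))]),
    NS(choices=[NS(delta=NS(content=None, tool_calls=[
        NS(id=None, type=None,
           function=NS(name=None, arguments='ncisco, CA"}'))]))]),
]

content, tool_calls = _accumulate_streaming_tool_calls(fake_stream)
assert content == ""
assert tool_calls[0]["function"]["name"] == "get_weather"
assert json.loads(tool_calls[0]["function"]["arguments"]) == {"location": "San Francisco, CA"}
```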

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large

File diff suppressed because one or more lines are too long