Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-12-28 01:01:59 +00:00

Merge branch 'main' into feat/litellm_sambanova_usage
Commit b7f16ac7a6
535 changed files with 23539 additions and 8112 deletions
@@ -0,0 +1,9 @@
version: '2'
distribution_spec:
  description: Custom distro for CI tests
  providers:
    inference:
    - remote::custom_ollama
image_type: container
image_name: ci-test
external_providers_dir: /tmp/providers.d
@@ -1,6 +1,6 @@
 adapter:
   adapter_type: custom_ollama
-  pip_packages: ["ollama", "aiohttp"]
+  pip_packages: ["ollama", "aiohttp", "tests/external-provider/llama-stack-provider-ollama"]
   config_class: llama_stack_provider_ollama.config.OllamaImplConfig
   module: llama_stack_provider_ollama
 api_dependencies: []
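The spec above is what the external-providers mechanism consumes: `module` and `config_class` are dotted import paths resolved at startup from files under `external_providers_dir`. A minimal sketch of how such a spec could be resolved (illustrative only, not the actual llama-stack loader; the helper name is an assumption):

import importlib

import yaml


def load_external_provider(spec_path: str):
    # Parse the provider spec YAML (same shape as the diff above)
    with open(spec_path) as f:
        spec = yaml.safe_load(f)

    adapter = spec["adapter"]
    # Import the provider package named by `module`
    module = importlib.import_module(adapter["module"])

    # Resolve `config_class` from its dotted path, e.g.
    # "llama_stack_provider_ollama.config.OllamaImplConfig"
    config_module_path, config_class_name = adapter["config_class"].rsplit(".", 1)
    config_class = getattr(importlib.import_module(config_module_path), config_class_name)
    return module, config_class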
@@ -1,14 +1,10 @@
 version: '2'
 image_name: ollama
 apis:
-- agents
-- datasetio
-- eval
 - inference
-- safety
-- scoring
 - telemetry
 - tool_runtime
+- datasetio
 - vector_io
 providers:
   inference:
@@ -24,34 +20,13 @@ providers:
         type: sqlite
         namespace: null
         db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/faiss_store.db
-  safety:
-  - provider_id: llama-guard
-    provider_type: inline::llama-guard
-    config:
-      excluded_categories: []
-  agents:
-  - provider_id: meta-reference
-    provider_type: inline::meta-reference
-    config:
-      persistence_store:
-        type: sqlite
-        namespace: null
-        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/agents_store.db
   telemetry:
   - provider_id: meta-reference
     provider_type: inline::meta-reference
     config:
-      service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
+      service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
       sinks: ${env.TELEMETRY_SINKS:console,sqlite}
-      sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/ollama/trace_store.db}
-  eval:
-  - provider_id: meta-reference
-    provider_type: inline::meta-reference
-    config:
-      kvstore:
-        type: sqlite
-        namespace: null
-        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/meta_reference_eval.db
+      sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/trace_store.db
   datasetio:
   - provider_id: huggingface
     provider_type: remote::huggingface
@@ -67,17 +42,6 @@ providers:
         type: sqlite
         namespace: null
         db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/localfs_datasetio.db
-  scoring:
-  - provider_id: basic
-    provider_type: inline::basic
-    config: {}
-  - provider_id: llm-as-judge
-    provider_type: inline::llm-as-judge
-    config: {}
-  - provider_id: braintrust
-    provider_type: inline::braintrust
-    config:
-      openai_api_key: ${env.OPENAI_API_KEY:}
   tool_runtime:
   - provider_id: brave-search
     provider_type: remote::brave-search
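The `${env.VAR:default}` references in this run.yaml are resolved at stack startup; the telemetry change above swaps the hard-coded `llama-stack` default for a zero-width-space placeholder so an unset OTEL_SERVICE_NAME yields an effectively empty service name. A rough sketch of the substitution syntax itself (an illustration only, not llama-stack's actual resolver):

import os
import re

_ENV_PATTERN = re.compile(r"\$\{env\.([A-Z0-9_]+):?([^}]*)\}")


def resolve_env_refs(value: str) -> str:
    # Replace each ${env.NAME:default} with the environment value, or the default if unset
    def _sub(match: re.Match) -> str:
        name, default = match.group(1), match.group(2)
        return os.environ.get(name, default)

    return _ENV_PATTERN.sub(_sub, value)


assert resolve_env_refs("${env.SQLITE_STORE_DIR:~/.llama}/faiss_store.db").endswith("/faiss_store.db")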
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict
+from typing import Any
 from uuid import uuid4

 import pytest
@@ -37,7 +37,7 @@ def get_boiling_point(liquid_name: str, celcius: bool = True) -> int:
     return -1


-def get_boiling_point_with_metadata(liquid_name: str, celcius: bool = True) -> Dict[str, Any]:
+def get_boiling_point_with_metadata(liquid_name: str, celcius: bool = True) -> dict[str, Any]:
     """
     Returns the boiling point of a liquid in Celcius or Fahrenheit
@@ -115,6 +115,70 @@ def test_agent_simple(llama_stack_client_with_mocked_inference, agent_config):
         assert "I can't" in logs_str


+def test_agent_name(llama_stack_client, text_model_id):
+    agent_name = f"test-agent-{uuid4()}"
+
+    try:
+        agent = Agent(
+            llama_stack_client,
+            model=text_model_id,
+            instructions="You are a helpful assistant",
+            name=agent_name,
+        )
+    except TypeError:
+        agent = Agent(
+            llama_stack_client,
+            model=text_model_id,
+            instructions="You are a helpful assistant",
+        )
+        return
+
+    session_id = agent.create_session(f"test-session-{uuid4()}")
+
+    agent.create_turn(
+        messages=[
+            {
+                "role": "user",
+                "content": "Give me a sentence that contains the word: hello",
+            }
+        ],
+        session_id=session_id,
+        stream=False,
+    )
+
+    all_spans = []
+    for span in llama_stack_client.telemetry.query_spans(
+        attribute_filters=[
+            {"key": "session_id", "op": "eq", "value": session_id},
+        ],
+        attributes_to_return=["input", "output", "agent_name", "agent_id", "session_id"],
+    ):
+        all_spans.append(span.attributes)
+
+    agent_name_spans = []
+    for span in llama_stack_client.telemetry.query_spans(
+        attribute_filters=[],
+        attributes_to_return=["agent_name"],
+    ):
+        if "agent_name" in span.attributes:
+            agent_name_spans.append(span.attributes)
+
+    agent_logs = []
+    for span in llama_stack_client.telemetry.query_spans(
+        attribute_filters=[
+            {"key": "agent_name", "op": "eq", "value": agent_name},
+        ],
+        attributes_to_return=["input", "output", "agent_name"],
+    ):
+        if "output" in span.attributes and span.attributes["output"] != "no shields":
+            agent_logs.append(span.attributes)
+
+    assert len(agent_logs) == 1
+    assert agent_logs[0]["agent_name"] == agent_name
+    assert "Give me a sentence that contains the word: hello" in agent_logs[0]["input"]
+    assert "hello" in agent_logs[0]["output"].lower()
+
+
 def test_tool_config(llama_stack_client_with_mocked_inference, agent_config):
     common_params = dict(
         model="meta-llama/Llama-3.2-3B-Instruct",
@@ -231,6 +295,7 @@ def test_builtin_tool_code_execution(llama_stack_client_with_mocked_inference, agent_config):
 # This test must be run in an environment where `bwrap` is available. If you are running against a
 # server, this means the _server_ must have `bwrap` available. If you are using library client, then
 # you must have `bwrap` available in test's environment.
+@pytest.mark.skip(reason="Code interpreter is currently disabled in the Stack")
 def test_code_interpreter_for_attachments(llama_stack_client_with_mocked_inference, agent_config):
     agent_config = {
         **agent_config,
@@ -487,6 +552,7 @@ def test_rag_agent_with_attachments(llama_stack_client_with_mocked_inference, agent_config):
     assert "lora" in response.output_message.content.lower()


+@pytest.mark.skip(reason="Code interpreter is currently disabled in the Stack")
 def test_rag_and_code_agent(llama_stack_client_with_mocked_inference, agent_config):
     if "llama-4" in agent_config["model"].lower():
         pytest.xfail("Not working for llama4")
@@ -10,6 +10,7 @@ import platform
 import textwrap
 import time

+import pytest
 from dotenv import load_dotenv

 from llama_stack.log import get_logger
@@ -19,10 +20,29 @@ from .report import Report
 logger = get_logger(__name__, category="tests")


+@pytest.hookimpl(hookwrapper=True)
+def pytest_runtest_makereport(item, call):
+    outcome = yield
+    report = outcome.get_result()
+    if report.when == "call":
+        item.execution_outcome = report.outcome
+        item.was_xfail = getattr(report, "wasxfail", False)
+
+
 def pytest_runtest_teardown(item):
-    interval_seconds = os.getenv("LLAMA_STACK_TEST_INTERVAL_SECONDS")
-    if interval_seconds:
-        time.sleep(float(interval_seconds))
+    # Check if the test actually ran and passed or failed, but was not skipped or an expected failure (xfail)
+    outcome = getattr(item, "execution_outcome", None)
+    was_xfail = getattr(item, "was_xfail", False)
+
+    name = item.nodeid
+    if not any(x in name for x in ("inference/", "safety/", "agents/")):
+        return
+
+    logger.debug(f"Test '{item.nodeid}' outcome was '{outcome}' (xfail={was_xfail})")
+    if outcome in ("passed", "failed") and not was_xfail:
+        interval_seconds = os.getenv("LLAMA_STACK_TEST_INTERVAL_SECONDS")
+        if interval_seconds:
+            time.sleep(float(interval_seconds))


 def pytest_configure(config):
@@ -31,6 +31,7 @@ def data_url_from_file(file_path: str) -> str:
     return data_url


+@pytest.mark.skip(reason="flaky. Couldn't find 'llamastack/simpleqa' on the Hugging Face Hub")
 @pytest.mark.parametrize(
     "purpose, source, provider_id, limit",
     [
@@ -14,6 +14,7 @@ from pathlib import Path
 import pytest
 import yaml
 from llama_stack_client import LlamaStackClient
+from openai import OpenAI

 from llama_stack import LlamaStackAsLibraryClient
 from llama_stack.apis.datatypes import Api
@@ -207,3 +208,9 @@ def llama_stack_client(request, provider_data, text_model_id):
         raise RuntimeError("Initialization failed")

     return client
+
+
+@pytest.fixture(scope="session")
+def openai_client(client_with_models):
+    base_url = f"{client_with_models.base_url}/v1/openai/v1"
+    return OpenAI(base_url=base_url, api_key="fake")
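Since the new `openai_client` fixture targets the stack's OpenAI-compatible prefix, any stock OpenAI SDK call works against it. A usage sketch, where the port and model id are assumptions for illustration:

from openai import OpenAI

# assumed local stack address (8321 is the usual default port); the api_key is not checked here
client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="fake")

response = client.chat.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct",  # assumed registered model id
    messages=[{"role": "user", "content": "Say hello."}],
)
print(response.choices[0].message.content)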
@@ -24,7 +24,7 @@ class RecordableMock:
         # Load existing cache if available and not recording
         if self.json_path.exists():
             try:
-                with open(self.json_path, "r") as f:
+                with open(self.json_path) as f:
                     self.cache = json.load(f)
             except Exception as e:
                 print(f"Error loading cache from {self.json_path}: {e}")
@@ -75,19 +75,24 @@ def openai_client(client_with_models):
     return OpenAI(base_url=base_url, api_key="bar")


+@pytest.fixture(params=["openai_client", "llama_stack_client"])
+def compat_client(request):
+    return request.getfixturevalue(request.param)
+
+
 @pytest.mark.parametrize(
     "test_case",
     [
         "inference:completion:sanity",
     ],
 )
-def test_openai_completion_non_streaming(openai_client, client_with_models, text_model_id, test_case):
+def test_openai_completion_non_streaming(llama_stack_client, client_with_models, text_model_id, test_case):
     skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id)
     tc = TestCase(test_case)

     # ollama needs more verbose prompting for some reason here...
     prompt = "Respond to this question and explain your answer. " + tc["content"]
-    response = openai_client.completions.create(
+    response = llama_stack_client.completions.create(
         model=text_model_id,
         prompt=prompt,
         stream=False,
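The new `compat_client` fixture uses `request.getfixturevalue`, so every test parametrized over it is collected twice, once per underlying client. A self-contained sketch of the pattern (generic pytest, not tied to these fixtures):

import pytest


@pytest.fixture
def client_a():
    return "A"


@pytest.fixture
def client_b():
    return "B"


@pytest.fixture(params=["client_a", "client_b"])
def compat_client(request):
    # resolve whichever fixture the current parameter names
    return request.getfixturevalue(request.param)


def test_runs_once_per_client(compat_client):
    # collected twice: test_runs_once_per_client[client_a] and [client_b]
    assert compat_client in ("A", "B")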
@@ -103,13 +108,13 @@ def test_openai_completion_non_streaming(openai_client, client_with_models, text_model_id, test_case):
         "inference:completion:sanity",
     ],
 )
-def test_openai_completion_streaming(openai_client, client_with_models, text_model_id, test_case):
+def test_openai_completion_streaming(llama_stack_client, client_with_models, text_model_id, test_case):
     skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id)
     tc = TestCase(test_case)

     # ollama needs more verbose prompting for some reason here...
     prompt = "Respond to this question and explain your answer. " + tc["content"]
-    response = openai_client.completions.create(
+    response = llama_stack_client.completions.create(
         model=text_model_id,
         prompt=prompt,
         stream=True,
@@ -127,11 +132,11 @@ def test_openai_completion_streaming(openai_client, client_with_models, text_model_id, test_case):
         0,
     ],
 )
-def test_openai_completion_prompt_logprobs(openai_client, client_with_models, text_model_id, prompt_logprobs):
+def test_openai_completion_prompt_logprobs(llama_stack_client, client_with_models, text_model_id, prompt_logprobs):
     skip_if_provider_isnt_vllm(client_with_models, text_model_id)

     prompt = "Hello, world!"
-    response = openai_client.completions.create(
+    response = llama_stack_client.completions.create(
         model=text_model_id,
         prompt=prompt,
         stream=False,
@@ -144,11 +149,11 @@ def test_openai_completion_prompt_logprobs(openai_client, client_with_models, text_model_id, prompt_logprobs):
     assert len(choice.prompt_logprobs) > 0


-def test_openai_completion_guided_choice(openai_client, client_with_models, text_model_id):
+def test_openai_completion_guided_choice(llama_stack_client, client_with_models, text_model_id):
     skip_if_provider_isnt_vllm(client_with_models, text_model_id)

     prompt = "I am feeling really sad today."
-    response = openai_client.completions.create(
+    response = llama_stack_client.completions.create(
         model=text_model_id,
         prompt=prompt,
         stream=False,
@@ -161,6 +166,9 @@ def test_openai_completion_guided_choice(openai_client, client_with_models, text_model_id):
     assert choice.text in ["joy", "sadness"]


+# Run the chat-completion tests with both the OpenAI client and the LlamaStack client
+
+
 @pytest.mark.parametrize(
     "test_case",
     [
@@ -168,13 +176,13 @@ def test_openai_completion_guided_choice(openai_client, client_with_models, text_model_id):
         "inference:chat_completion:non_streaming_02",
     ],
 )
-def test_openai_chat_completion_non_streaming(openai_client, client_with_models, text_model_id, test_case):
+def test_openai_chat_completion_non_streaming(compat_client, client_with_models, text_model_id, test_case):
     skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
     tc = TestCase(test_case)
     question = tc["question"]
     expected = tc["expected"]

-    response = openai_client.chat.completions.create(
+    response = compat_client.chat.completions.create(
         model=text_model_id,
         messages=[
             {
@@ -196,13 +204,13 @@ def test_openai_chat_completion_non_streaming(openai_client, client_with_models, text_model_id, test_case):
         "inference:chat_completion:streaming_02",
     ],
 )
-def test_openai_chat_completion_streaming(openai_client, client_with_models, text_model_id, test_case):
+def test_openai_chat_completion_streaming(compat_client, client_with_models, text_model_id, test_case):
     skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
     tc = TestCase(test_case)
     question = tc["question"]
     expected = tc["expected"]

-    response = openai_client.chat.completions.create(
+    response = compat_client.chat.completions.create(
         model=text_model_id,
         messages=[{"role": "user", "content": question}],
         stream=True,
@@ -3,7 +3,6 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import List

 import pytest

@@ -77,7 +76,7 @@ class TestPostTraining:
     async def test_get_training_jobs(self, post_training_stack):
         post_training_impl = post_training_stack
         jobs_list = await post_training_impl.get_training_jobs()
-        assert isinstance(jobs_list, List)
+        assert isinstance(jobs_list, list)
         assert jobs_list[0].job_uuid == "1234"

     @pytest.mark.asyncio
tests/integration/providers/nvidia/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
tests/integration/providers/nvidia/conftest.py (new file, 14 lines)
@@ -0,0 +1,14 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os

import pytest

# Skip all tests in this directory when running in GitHub Actions
in_github_actions = os.environ.get("GITHUB_ACTIONS") == "true"
if in_github_actions:
    pytest.skip("Skipping NVIDIA tests in GitHub Actions environment", allow_module_level=True)
tests/integration/providers/nvidia/test_datastore.py (new file, 47 lines)
@@ -0,0 +1,47 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


import pytest

# How to run this test:
#
# LLAMA_STACK_CONFIG="nvidia" pytest -v tests/integration/providers/nvidia/test_datastore.py


# nvidia provider only
@pytest.mark.parametrize(
    "provider_id",
    [
        "nvidia",
    ],
)
def test_register_and_unregister(llama_stack_client, provider_id):
    purpose = "eval/messages-answer"
    source = {
        "type": "uri",
        "uri": "hf://datasets/llamastack/simpleqa?split=train",
    }
    dataset_id = f"test-dataset-{provider_id}"
    dataset = llama_stack_client.datasets.register(
        dataset_id=dataset_id,
        purpose=purpose,
        source=source,
        metadata={"provider_id": provider_id, "format": "json", "description": "Test dataset description"},
    )
    assert dataset.identifier is not None
    assert dataset.provider_id == provider_id
    assert dataset.identifier == dataset_id

    dataset_list = llama_stack_client.datasets.list()
    provider_datasets = [d for d in dataset_list if d.provider_id == provider_id]
    assert any(provider_datasets)
    assert any(d.identifier == dataset_id for d in provider_datasets)

    llama_stack_client.datasets.unregister(dataset.identifier)
    dataset_list = llama_stack_client.datasets.list()
    provider_datasets = [d for d in dataset_list if d.identifier == dataset.identifier]
    assert not any(provider_datasets)
tests/integration/test_cases/openai/responses.json (new file, 37 lines)
@@ -0,0 +1,37 @@
{
  "non_streaming_01": {
    "data": {
      "question": "Which planet do humans live on?",
      "expected": "Earth"
    }
  },
  "non_streaming_02": {
    "data": {
      "question": "Which planet has rings around it with a name starting with letter S?",
      "expected": "Saturn"
    }
  },
  "streaming_01": {
    "data": {
      "question": "What's the name of the Sun in latin?",
      "expected": "Sol"
    }
  },
  "streaming_02": {
    "data": {
      "question": "What is the name of the US captial?",
      "expected": "Washington"
    }
  },
  "tools_web_search_01": {
    "data": {
      "input": "How many experts does the Llama 4 Maverick model have?",
      "tools": [
        {
          "type": "web_search"
        }
      ],
      "expected": "128"
    }
  }
}
@@ -12,6 +12,7 @@ class TestCase:
     _apis = [
         "inference/chat_completion",
         "inference/completion",
+        "openai/responses",
     ]
     _jsonblob = {}

@@ -19,7 +20,7 @@ class TestCase:
         # loading all test cases
         if self._jsonblob == {}:
             for api in self._apis:
-                with open(pathlib.Path(__file__).parent / f"{api}.json", "r") as f:
+                with open(pathlib.Path(__file__).parent / f"{api}.json") as f:
                     coloned = api.replace("/", ":")
                     try:
                         loaded = json.load(f)
@@ -114,7 +114,7 @@ def test_register_and_unregister_toolgroup(llama_stack_client, mcp_server):
     llama_stack_client.toolgroups.unregister(toolgroup_id=test_toolgroup_id)

     # Verify it is unregistered
-    with pytest.raises(ValueError, match=f"Tool group '{test_toolgroup_id}' not found"):
+    with pytest.raises(Exception, match=f"Tool group '{test_toolgroup_id}' not found"):
         llama_stack_client.toolgroups.get(toolgroup_id=test_toolgroup_id)

     # Verify tools are also unregistered
@@ -18,11 +18,11 @@ from llama_stack.distribution.configure import (
 @pytest.fixture
 def up_to_date_config():
     return yaml.safe_load(
-        """
-        version: {version}
+        f"""
+        version: {LLAMA_STACK_RUN_CONFIG_VERSION}
         image_name: foo
         apis_to_serve: []
-        built_at: {built_at}
+        built_at: {datetime.now().isoformat()}
         providers:
           inference:
           - provider_id: provider1

@@ -42,16 +42,16 @@ def up_to_date_config():
           - provider_id: provider1
             provider_type: inline::meta-reference
             config: {{}}
-        """.format(version=LLAMA_STACK_RUN_CONFIG_VERSION, built_at=datetime.now().isoformat())
+        """
     )


 @pytest.fixture
 def old_config():
     return yaml.safe_load(
-        """
+        f"""
         image_name: foo
-        built_at: {built_at}
+        built_at: {datetime.now().isoformat()}
         apis_to_serve: []
         routing_table:
           inference:

@@ -82,7 +82,7 @@ def old_config():
         telemetry:
           provider_type: noop
           config: {{}}
-        """.format(built_at=datetime.now().isoformat())
+        """
     )

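The fixtures above now build their YAML with f-strings; note `config: {{}}` stays doubled because f-strings use the same brace escaping as str.format, so `{{}}` renders as a literal `{}`. A quick self-contained check:

from datetime import datetime

import yaml

version = 2
doc = yaml.safe_load(f"""
version: {version}
built_at: {datetime.now().isoformat()}
config: {{}}
""")
assert doc["version"] == 2
assert doc["config"] == {}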
tests/unit/distribution/routers/test_routing_tables.py (new file, 314 lines)
@@ -0,0 +1,314 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# Unit tests for the routing tables

from unittest.mock import AsyncMock

import pytest

from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource
from llama_stack.apis.datatypes import Api
from llama_stack.apis.models.models import Model, ModelType
from llama_stack.apis.shields.shields import Shield
from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolParameter
from llama_stack.apis.vector_dbs.vector_dbs import VectorDB
from llama_stack.distribution.routers.routing_tables import (
    BenchmarksRoutingTable,
    DatasetsRoutingTable,
    ModelsRoutingTable,
    ScoringFunctionsRoutingTable,
    ShieldsRoutingTable,
    ToolGroupsRoutingTable,
    VectorDBsRoutingTable,
)
from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl


@pytest.fixture
async def dist_registry(tmp_path):
    db_path = tmp_path / "test_kv.db"
    kvstore_config = SqliteKVStoreConfig(db_path=db_path.as_posix())
    kvstore = SqliteKVStoreImpl(kvstore_config)
    await kvstore.initialize()
    registry = CachedDiskDistributionRegistry(kvstore)
    await registry.initialize()
    yield registry


class Impl:
    def __init__(self, api: Api):
        self.api = api

    @property
    def __provider_spec__(self):
        _provider_spec = AsyncMock()
        _provider_spec.api = self.api
        return _provider_spec


class InferenceImpl(Impl):
    def __init__(self):
        super().__init__(Api.inference)

    async def register_model(self, model: Model):
        return model

    async def unregister_model(self, model_id: str):
        return model_id


class SafetyImpl(Impl):
    def __init__(self):
        super().__init__(Api.safety)

    async def register_shield(self, shield: Shield):
        return shield


class VectorDBImpl(Impl):
    def __init__(self):
        super().__init__(Api.vector_io)

    async def register_vector_db(self, vector_db: VectorDB):
        return vector_db

    async def unregister_vector_db(self, vector_db_id: str):
        return vector_db_id


class DatasetsImpl(Impl):
    def __init__(self):
        super().__init__(Api.datasetio)

    async def register_dataset(self, dataset: Dataset):
        return dataset

    async def unregister_dataset(self, dataset_id: str):
        return dataset_id


class ScoringFunctionsImpl(Impl):
    def __init__(self):
        super().__init__(Api.scoring)

    async def list_scoring_functions(self):
        return []

    async def register_scoring_function(self, scoring_fn):
        return scoring_fn


class BenchmarksImpl(Impl):
    def __init__(self):
        super().__init__(Api.eval)

    async def register_benchmark(self, benchmark):
        return benchmark


class ToolGroupsImpl(Impl):
    def __init__(self):
        super().__init__(Api.tool_runtime)

    async def register_tool(self, tool):
        return tool

    async def unregister_tool(self, tool_name: str):
        return tool_name

    async def list_runtime_tools(self, toolgroup_id, mcp_endpoint):
        return ListToolDefsResponse(
            data=[
                ToolDef(
                    name="test-tool",
                    description="Test tool",
                    parameters=[ToolParameter(name="test-param", description="Test param", parameter_type="string")],
                )
            ]
        )


@pytest.mark.asyncio
async def test_models_routing_table(dist_registry):
    table = ModelsRoutingTable({"test_provider": InferenceImpl()}, dist_registry)
    await table.initialize()

    # Register multiple models and verify listing
    await table.register_model(model_id="test-model", provider_id="test_provider")
    await table.register_model(model_id="test-model-2", provider_id="test_provider")

    models = await table.list_models()
    assert len(models.data) == 2
    model_ids = {m.identifier for m in models.data}
    assert "test-model" in model_ids
    assert "test-model-2" in model_ids

    # Test openai list models
    openai_models = await table.openai_list_models()
    assert len(openai_models.data) == 2
    openai_model_ids = {m.id for m in openai_models.data}
    assert "test-model" in openai_model_ids
    assert "test-model-2" in openai_model_ids

    # Test get_object_by_identifier
    model = await table.get_object_by_identifier("model", "test-model")
    assert model is not None
    assert model.identifier == "test-model"

    # Test get_object_by_identifier on non-existent object
    non_existent = await table.get_object_by_identifier("model", "non-existent-model")
    assert non_existent is None

    await table.unregister_model(model_id="test-model")
    await table.unregister_model(model_id="test-model-2")

    models = await table.list_models()
    assert len(models.data) == 0

    # Test openai list models
    openai_models = await table.openai_list_models()
    assert len(openai_models.data) == 0


@pytest.mark.asyncio
async def test_shields_routing_table(dist_registry):
    table = ShieldsRoutingTable({"test_provider": SafetyImpl()}, dist_registry)
    await table.initialize()

    # Register multiple shields and verify listing
    await table.register_shield(shield_id="test-shield", provider_id="test_provider")
    await table.register_shield(shield_id="test-shield-2", provider_id="test_provider")
    shields = await table.list_shields()

    assert len(shields.data) == 2
    shield_ids = {s.identifier for s in shields.data}
    assert "test-shield" in shield_ids
    assert "test-shield-2" in shield_ids


@pytest.mark.asyncio
async def test_vectordbs_routing_table(dist_registry):
    table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, dist_registry)
    await table.initialize()

    m_table = ModelsRoutingTable({"test_providere": InferenceImpl()}, dist_registry)
    await m_table.initialize()
    await m_table.register_model(
        model_id="test-model",
        provider_id="test_providere",
        metadata={"embedding_dimension": 128},
        model_type=ModelType.embedding,
    )

    # Register multiple vector databases and verify listing
    await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test-model")
    await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test-model")
    vector_dbs = await table.list_vector_dbs()

    assert len(vector_dbs.data) == 2
    vector_db_ids = {v.identifier for v in vector_dbs.data}
    assert "test-vectordb" in vector_db_ids
    assert "test-vectordb-2" in vector_db_ids

    await table.unregister_vector_db(vector_db_id="test-vectordb")
    await table.unregister_vector_db(vector_db_id="test-vectordb-2")

    vector_dbs = await table.list_vector_dbs()
    assert len(vector_dbs.data) == 0


async def test_datasets_routing_table(dist_registry):
    table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, dist_registry)
    await table.initialize()

    # Register multiple datasets and verify listing
    await table.register_dataset(
        dataset_id="test-dataset", purpose=DatasetPurpose.eval_messages_answer, source=URIDataSource(uri="test-uri")
    )
    await table.register_dataset(
        dataset_id="test-dataset-2", purpose=DatasetPurpose.eval_messages_answer, source=URIDataSource(uri="test-uri-2")
    )
    datasets = await table.list_datasets()

    assert len(datasets.data) == 2
    dataset_ids = {d.identifier for d in datasets.data}
    assert "test-dataset" in dataset_ids
    assert "test-dataset-2" in dataset_ids

    await table.unregister_dataset(dataset_id="test-dataset")
    await table.unregister_dataset(dataset_id="test-dataset-2")

    datasets = await table.list_datasets()
    assert len(datasets.data) == 0


@pytest.mark.asyncio
async def test_scoring_functions_routing_table(dist_registry):
    table = ScoringFunctionsRoutingTable({"test_provider": ScoringFunctionsImpl()}, dist_registry)
    await table.initialize()

    # Register multiple scoring functions and verify listing
    await table.register_scoring_function(
        scoring_fn_id="test-scoring-fn",
        provider_id="test_provider",
        description="Test scoring function",
        return_type=NumberType(),
    )
    await table.register_scoring_function(
        scoring_fn_id="test-scoring-fn-2",
        provider_id="test_provider",
        description="Another test scoring function",
        return_type=NumberType(),
    )
    scoring_functions = await table.list_scoring_functions()

    assert len(scoring_functions.data) == 2
    scoring_fn_ids = {fn.identifier for fn in scoring_functions.data}
    assert "test-scoring-fn" in scoring_fn_ids
    assert "test-scoring-fn-2" in scoring_fn_ids


@pytest.mark.asyncio
async def test_benchmarks_routing_table(dist_registry):
    table = BenchmarksRoutingTable({"test_provider": BenchmarksImpl()}, dist_registry)
    await table.initialize()

    # Register multiple benchmarks and verify listing
    await table.register_benchmark(
        benchmark_id="test-benchmark",
        dataset_id="test-dataset",
        scoring_functions=["test-scoring-fn", "test-scoring-fn-2"],
    )
    benchmarks = await table.list_benchmarks()

    assert len(benchmarks.data) == 1
    benchmark_ids = {b.identifier for b in benchmarks.data}
    assert "test-benchmark" in benchmark_ids


@pytest.mark.asyncio
async def test_tool_groups_routing_table(dist_registry):
    table = ToolGroupsRoutingTable({"test_provider": ToolGroupsImpl()}, dist_registry)
    await table.initialize()

    # Register multiple tool groups and verify listing
    await table.register_tool_group(
        toolgroup_id="test-toolgroup",
        provider_id="test_provider",
    )
    tool_groups = await table.list_tool_groups()

    assert len(tool_groups.data) == 1
    tool_group_ids = {tg.identifier for tg in tool_groups.data}
    assert "test-toolgroup" in tool_group_ids

    await table.unregister_toolgroup(toolgroup_id="test-toolgroup")
    tool_groups = await table.list_tool_groups()
    assert len(tool_groups.data) == 0
tests/unit/distribution/test_build_path.py (new file, 40 lines)
@@ -0,0 +1,40 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from pathlib import Path

from llama_stack.cli.stack._build import (
    _run_stack_build_command_from_build_config,
)
from llama_stack.distribution.datatypes import BuildConfig, DistributionSpec
from llama_stack.distribution.utils.image_types import LlamaStackImageType


def test_container_build_passes_path(monkeypatch, tmp_path):
    called_with = {}

    def spy_build_image(cfg, build_file_path, image_name, template_or_config, run_config=None):
        called_with["path"] = template_or_config
        called_with["run_config"] = run_config
        return 0

    monkeypatch.setattr(
        "llama_stack.cli.stack._build.build_image",
        spy_build_image,
        raising=True,
    )

    cfg = BuildConfig(
        image_type=LlamaStackImageType.CONTAINER.value,
        distribution_spec=DistributionSpec(providers={}, description=""),
    )

    _run_stack_build_command_from_build_config(cfg, image_name="dummy")

    assert "path" in called_with
    assert isinstance(called_with["path"], str)
    assert Path(called_with["path"]).exists()
    assert called_with["run_config"] is None
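The spy above is the standard monkeypatch pattern: swap the real callable for a recorder, run the code under test, then assert on what was captured. A minimal self-contained sketch of the same technique against a stand-in target:

import math


def test_spy_records_arguments(monkeypatch):
    calls = []

    def spy(x):
        # record the argument instead of doing real work
        calls.append(x)
        return 0.0

    monkeypatch.setattr(math, "sqrt", spy)
    assert math.sqrt(9) == 0.0
    assert calls == [9]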
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict
+from typing import Any
 from unittest.mock import patch

 import pytest

@@ -23,7 +23,7 @@ class SampleConfig(BaseModel):
     )

     @classmethod
-    def sample_run_config(cls, **kwargs: Any) -> Dict[str, Any]:
+    def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]:
         return {
             "foo": "baz",
         }
@@ -0,0 +1,202 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from unittest.mock import AsyncMock

import pytest

from llama_stack.apis.agents.openai_responses import (
    OpenAIResponseInputToolWebSearch,
    OpenAIResponseOutputMessage,
)
from llama_stack.apis.inference.inference import (
    OpenAIAssistantMessageParam,
    OpenAIChatCompletion,
    OpenAIChatCompletionToolCall,
    OpenAIChatCompletionToolCallFunction,
    OpenAIChoice,
    OpenAIUserMessageParam,
)
from llama_stack.apis.tools.tools import Tool, ToolGroups, ToolInvocationResult, ToolParameter, ToolRuntime
from llama_stack.providers.inline.agents.meta_reference.openai_responses import (
    OpenAIResponsesImpl,
)
from llama_stack.providers.utils.kvstore import KVStore


@pytest.fixture
def mock_kvstore():
    kvstore = AsyncMock(spec=KVStore)
    return kvstore


@pytest.fixture
def mock_inference_api():
    inference_api = AsyncMock()
    return inference_api


@pytest.fixture
def mock_tool_groups_api():
    tool_groups_api = AsyncMock(spec=ToolGroups)
    return tool_groups_api


@pytest.fixture
def mock_tool_runtime_api():
    tool_runtime_api = AsyncMock(spec=ToolRuntime)
    return tool_runtime_api


@pytest.fixture
def openai_responses_impl(mock_kvstore, mock_inference_api, mock_tool_groups_api, mock_tool_runtime_api):
    return OpenAIResponsesImpl(
        persistence_store=mock_kvstore,
        inference_api=mock_inference_api,
        tool_groups_api=mock_tool_groups_api,
        tool_runtime_api=mock_tool_runtime_api,
    )


@pytest.mark.asyncio
async def test_create_openai_response_with_string_input(openai_responses_impl, mock_inference_api):
    """Test creating an OpenAI response with a simple string input."""
    # Setup
    input_text = "Hello, world!"
    model = "meta-llama/Llama-3.1-8B-Instruct"

    mock_chat_completion = OpenAIChatCompletion(
        id="chat-completion-123",
        choices=[
            OpenAIChoice(
                message=OpenAIAssistantMessageParam(content="Hello! How can I help you?"),
                finish_reason="stop",
                index=0,
            )
        ],
        created=1234567890,
        model=model,
    )
    mock_inference_api.openai_chat_completion.return_value = mock_chat_completion

    # Execute
    result = await openai_responses_impl.create_openai_response(
        input=input_text,
        model=model,
        temperature=0.1,
    )

    # Verify
    mock_inference_api.openai_chat_completion.assert_called_once_with(
        model=model,
        messages=[OpenAIUserMessageParam(role="user", content="Hello, world!", name=None)],
        tools=None,
        stream=False,
        temperature=0.1,
    )
    openai_responses_impl.persistence_store.set.assert_called_once()
    assert result.model == model
    assert len(result.output) == 1
    assert isinstance(result.output[0], OpenAIResponseOutputMessage)
    assert result.output[0].content[0].text == "Hello! How can I help you?"


@pytest.mark.asyncio
async def test_create_openai_response_with_string_input_with_tools(openai_responses_impl, mock_inference_api):
    """Test creating an OpenAI response with a simple string input and tools."""
    # Setup
    input_text = "What was the score of todays game?"
    model = "meta-llama/Llama-3.1-8B-Instruct"

    mock_chat_completions = [
        OpenAIChatCompletion(
            id="chat-completion-123",
            choices=[
                OpenAIChoice(
                    message=OpenAIAssistantMessageParam(
                        tool_calls=[
                            OpenAIChatCompletionToolCall(
                                id="tool_call_123",
                                type="function",
                                function=OpenAIChatCompletionToolCallFunction(
                                    name="web_search", arguments='{"query":"What was the score of todays game?"}'
                                ),
                            )
                        ],
                    ),
                    finish_reason="stop",
                    index=0,
                )
            ],
            created=1234567890,
            model=model,
        ),
        OpenAIChatCompletion(
            id="chat-completion-123",
            choices=[
                OpenAIChoice(
                    message=OpenAIAssistantMessageParam(content="The score of todays game was 10-12"),
                    finish_reason="stop",
                    index=0,
                )
            ],
            created=1234567890,
            model=model,
        ),
    ]

    mock_inference_api.openai_chat_completion.side_effect = mock_chat_completions

    openai_responses_impl.tool_groups_api.get_tool.return_value = Tool(
        identifier="web_search",
        provider_id="client",
        toolgroup_id="web_search",
        tool_host="client",
        description="Search the web for information",
        parameters=[
            ToolParameter(name="query", parameter_type="string", description="The query to search for", required=True)
        ],
    )

    openai_responses_impl.tool_runtime_api.invoke_tool.return_value = ToolInvocationResult(
        status="completed",
        content="The score of todays game was 10-12",
    )

    # Execute
    result = await openai_responses_impl.create_openai_response(
        input=input_text,
        model=model,
        temperature=0.1,
        tools=[
            OpenAIResponseInputToolWebSearch(
                name="web_search",
            )
        ],
    )

    # Verify
    first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
    assert first_call.kwargs["messages"][0].content == "What was the score of todays game?"
    assert first_call.kwargs["tools"] is not None
    assert first_call.kwargs["temperature"] == 0.1

    second_call = mock_inference_api.openai_chat_completion.call_args_list[1]
    assert second_call.kwargs["messages"][-1].content == "The score of todays game was 10-12"
    assert second_call.kwargs["temperature"] == 0.1

    openai_responses_impl.tool_groups_api.get_tool.assert_called_once_with("web_search")
    openai_responses_impl.tool_runtime_api.invoke_tool.assert_called_once_with(
        tool_name="web_search",
        kwargs={"query": "What was the score of todays game?"},
    )

    openai_responses_impl.persistence_store.set.assert_called_once()

    # Check that we got the content from our mocked tool execution result
    assert len(result.output) >= 1
    assert isinstance(result.output[1], OpenAIResponseOutputMessage)
    assert result.output[1].content[0].text == "The score of todays game was 10-12"
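The mocked inference API above returns a different OpenAIChatCompletion on each call by assigning a list to `side_effect`. With unittest.mock.AsyncMock, a `side_effect` iterable yields its items on successive awaits. A minimal sketch:

import asyncio
from unittest.mock import AsyncMock

mock_api = AsyncMock(side_effect=[{"step": "tool_call"}, {"step": "final_answer"}])


async def main():
    # each await consumes the next item from the side_effect list
    first = await mock_api()
    second = await mock_api()
    assert first == {"step": "tool_call"}
    assert second == {"step": "final_answer"}


asyncio.run(main())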
@@ -10,7 +10,7 @@ import logging
 import threading
 import time
 from http.server import BaseHTTPRequestHandler, HTTPServer
-from typing import Any, Dict
+from typing import Any
 from unittest.mock import AsyncMock, patch

 import pytest
@@ -26,9 +26,17 @@ from openai.types.chat.chat_completion_chunk import (
 )
 from openai.types.model import Model as OpenAIModel

-from llama_stack.apis.inference import ToolChoice, ToolConfig
+from llama_stack.apis.inference import (
+    ChatCompletionRequest,
+    CompletionMessage,
+    SystemMessage,
+    ToolChoice,
+    ToolConfig,
+    ToolResponseMessage,
+    UserMessage,
+)
 from llama_stack.apis.models import Model
-from llama_stack.models.llama.datatypes import StopReason
+from llama_stack.models.llama.datatypes import StopReason, ToolCall
 from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig
 from llama_stack.providers.remote.inference.vllm.vllm import (
     VLLMInferenceAdapter,
@@ -47,7 +55,7 @@ from llama_stack.providers.remote.inference.vllm.vllm import (


 class MockInferenceAdapterWithSleep:
-    def __init__(self, sleep_time: int, response: Dict[str, Any]):
+    def __init__(self, sleep_time: int, response: dict[str, Any]):
         self.httpd = None

         class DelayedRequestHandler(BaseHTTPRequestHandler):
@@ -130,6 +138,49 @@ async def test_old_vllm_tool_choice(vllm_inference_adapter):
     assert request.tool_config.tool_choice == ToolChoice.none


+@pytest.mark.asyncio
+async def test_tool_call_response(vllm_inference_adapter):
+    """Verify that tool call arguments from a CompletionMessage are correctly converted
+    into the expected JSON format."""
+
+    # Patch the call to vllm so we can inspect the arguments sent were correct
+    with patch.object(
+        vllm_inference_adapter.client.chat.completions, "create", new_callable=AsyncMock
+    ) as mock_nonstream_completion:
+        messages = [
+            SystemMessage(content="You are a helpful assistant"),
+            UserMessage(content="How many?"),
+            CompletionMessage(
+                content="",
+                stop_reason=StopReason.end_of_turn,
+                tool_calls=[
+                    ToolCall(
+                        call_id="foo",
+                        tool_name="knowledge_search",
+                        arguments={"query": "How many?"},
+                        arguments_json='{"query": "How many?"}',
+                    )
+                ],
+            ),
+            ToolResponseMessage(call_id="foo", content="knowledge_search found 5...."),
+        ]
+        await vllm_inference_adapter.chat_completion(
+            "mock-model",
+            messages,
+            stream=False,
+            tools=[],
+            tool_config=ToolConfig(tool_choice=ToolChoice.auto),
+        )
+
+        assert mock_nonstream_completion.call_args.kwargs["messages"][2]["tool_calls"] == [
+            {
+                "id": "foo",
+                "type": "function",
+                "function": {"name": "knowledge_search", "arguments": '{"query": "How many?"}'},
+            }
+        ]
+
+
 @pytest.mark.asyncio
 async def test_tool_call_delta_empty_tool_call_buf():
     """
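The assertion above pins down the wire format: a llama-stack ToolCall carrying `arguments_json` must surface as an OpenAI-style function-call dict. A sketch of that mapping with a hypothetical plain-dict helper (not the adapter's actual conversion code):

def to_openai_tool_call(call_id: str, tool_name: str, arguments_json: str) -> dict:
    # hypothetical helper mirroring the shape asserted in the test above
    return {
        "id": call_id,
        "type": "function",
        "function": {"name": tool_name, "arguments": arguments_json},
    }


assert to_openai_tool_call("foo", "knowledge_search", '{"query": "How many?"}') == {
    "id": "foo",
    "type": "function",
    "function": {"name": "knowledge_search", "arguments": '{"query": "How many?"}'},
}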
@@ -232,3 +283,14 @@ def test_chat_completion_doesnt_block_event_loop(caplog):
     # above.
     asyncio_warnings = [record.message for record in caplog.records if record.name == "asyncio"]
     assert not asyncio_warnings
+
+
+@pytest.mark.asyncio
+async def test_get_params_empty_tools(vllm_inference_adapter):
+    request = ChatCompletionRequest(
+        tools=[],
+        model="test_model",
+        messages=[UserMessage(content="test")],
+    )
+    params = await vllm_inference_adapter._get_params(request)
+    assert "tools" not in params
tests/unit/providers/nvidia/test_datastore.py (new file, 138 lines)
@@ -0,0 +1,138 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os
import unittest
from unittest.mock import patch

import pytest

from llama_stack.apis.datasets import Dataset, DatasetPurpose, URIDataSource
from llama_stack.providers.remote.datasetio.nvidia.config import NvidiaDatasetIOConfig
from llama_stack.providers.remote.datasetio.nvidia.datasetio import NvidiaDatasetIOAdapter


class TestNvidiaDatastore(unittest.TestCase):
    def setUp(self):
        os.environ["NVIDIA_DATASETS_URL"] = "http://nemo.test/datasets"

        config = NvidiaDatasetIOConfig(
            datasets_url=os.environ["NVIDIA_DATASETS_URL"], dataset_namespace="default", project_id="default"
        )
        self.adapter = NvidiaDatasetIOAdapter(config)
        self.make_request_patcher = patch(
            "llama_stack.providers.remote.datasetio.nvidia.datasetio.NvidiaDatasetIOAdapter._make_request"
        )
        self.mock_make_request = self.make_request_patcher.start()

    def tearDown(self):
        self.make_request_patcher.stop()

    @pytest.fixture(autouse=True)
    def inject_fixtures(self, run_async):
        self.run_async = run_async

    def _assert_request(self, mock_call, expected_method, expected_path, expected_json=None):
        """Helper method to verify request details in mock calls."""
        call_args = mock_call.call_args

        assert call_args[0][0] == expected_method
        assert call_args[0][1] == expected_path

        if expected_json:
            for key, value in expected_json.items():
                assert call_args[1]["json"][key] == value

    def test_register_dataset(self):
        self.mock_make_request.return_value = {
            "id": "dataset-123456",
            "name": "test-dataset",
            "namespace": "default",
        }

        dataset_def = Dataset(
            identifier="test-dataset",
            type="dataset",
            provider_resource_id="",
            provider_id="",
            purpose=DatasetPurpose.post_training_messages,
            source=URIDataSource(uri="https://example.com/data.jsonl"),
            metadata={"provider_id": "nvidia", "format": "jsonl", "description": "Test dataset description"},
        )

        self.run_async(self.adapter.register_dataset(dataset_def))

        self.mock_make_request.assert_called_once()
        self._assert_request(
            self.mock_make_request,
            "POST",
            "/v1/datasets",
            expected_json={
                "name": "test-dataset",
                "namespace": "default",
                "files_url": "https://example.com/data.jsonl",
                "project": "default",
                "format": "jsonl",
                "description": "Test dataset description",
            },
        )

    def test_unregister_dataset(self):
        self.mock_make_request.return_value = {
            "message": "Resource deleted successfully.",
            "id": "dataset-81RSQp7FKX3rdBtKvF9Skn",
            "deleted_at": None,
        }
        dataset_id = "test-dataset"

        self.run_async(self.adapter.unregister_dataset(dataset_id))

        self.mock_make_request.assert_called_once()
        self._assert_request(self.mock_make_request, "DELETE", "/v1/datasets/default/test-dataset")

    def test_register_dataset_with_custom_namespace_project(self):
        custom_config = NvidiaDatasetIOConfig(
            datasets_url=os.environ["NVIDIA_DATASETS_URL"],
            dataset_namespace="custom-namespace",
            project_id="custom-project",
        )
        custom_adapter = NvidiaDatasetIOAdapter(custom_config)

        self.mock_make_request.return_value = {
            "id": "dataset-123456",
            "name": "test-dataset",
            "namespace": "custom-namespace",
        }

        dataset_def = Dataset(
            identifier="test-dataset",
            type="dataset",
            provider_resource_id="",
            provider_id="",
            purpose=DatasetPurpose.post_training_messages,
            source=URIDataSource(uri="https://example.com/data.jsonl"),
            metadata={"format": "jsonl"},
        )

        self.run_async(custom_adapter.register_dataset(dataset_def))

        self.mock_make_request.assert_called_once()
        self._assert_request(
            self.mock_make_request,
            "POST",
            "/v1/datasets",
            expected_json={
                "name": "test-dataset",
                "namespace": "custom-namespace",
                "files_url": "https://example.com/data.jsonl",
                "project": "custom-project",
                "format": "jsonl",
            },
        )


if __name__ == "__main__":
    unittest.main()
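`inject_fixtures` is how these unittest.TestCase suites borrow a pytest fixture: pytest applies autouse fixtures to unittest-style tests too, so the fixture value can be stashed on `self`. A self-contained sketch of the pattern (the `run_async` body here is an assumption; the repo defines its own fixture):

import asyncio
import unittest

import pytest


@pytest.fixture
def run_async():
    # assumed stand-in for the repo's run_async fixture
    def _run(coro):
        return asyncio.run(coro)

    return _run


class TestWithInjectedFixture(unittest.TestCase):
    @pytest.fixture(autouse=True)
    def inject_fixtures(self, run_async):
        # pytest applies autouse fixtures even to unittest.TestCase methods
        self.run_async = run_async

    def test_uses_fixture(self):
        async def coro():
            return 42

        assert self.run_async(coro()) == 42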
tests/unit/providers/nvidia/test_eval.py (new file, 201 lines)
@@ -0,0 +1,201 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os
import unittest
from unittest.mock import MagicMock, patch

import pytest

from llama_stack.apis.benchmarks import Benchmark
from llama_stack.apis.common.job_types import Job, JobStatus
from llama_stack.apis.eval.eval import BenchmarkConfig, EvaluateResponse, ModelCandidate, SamplingParams
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.remote.eval.nvidia.config import NVIDIAEvalConfig
from llama_stack.providers.remote.eval.nvidia.eval import NVIDIAEvalImpl

MOCK_DATASET_ID = "default/test-dataset"
MOCK_BENCHMARK_ID = "test-benchmark"


class TestNVIDIAEvalImpl(unittest.TestCase):
    def setUp(self):
        os.environ["NVIDIA_EVALUATOR_URL"] = "http://nemo.test"

        # Create mock APIs
        self.datasetio_api = MagicMock()
        self.datasets_api = MagicMock()
        self.scoring_api = MagicMock()
        self.inference_api = MagicMock()
        self.agents_api = MagicMock()

        self.config = NVIDIAEvalConfig(
            evaluator_url=os.environ["NVIDIA_EVALUATOR_URL"],
        )

        self.eval_impl = NVIDIAEvalImpl(
            config=self.config,
            datasetio_api=self.datasetio_api,
            datasets_api=self.datasets_api,
            scoring_api=self.scoring_api,
            inference_api=self.inference_api,
            agents_api=self.agents_api,
        )

        # Mock the HTTP request methods
        self.evaluator_get_patcher = patch(
            "llama_stack.providers.remote.eval.nvidia.eval.NVIDIAEvalImpl._evaluator_get"
        )
        self.evaluator_post_patcher = patch(
            "llama_stack.providers.remote.eval.nvidia.eval.NVIDIAEvalImpl._evaluator_post"
        )

        self.mock_evaluator_get = self.evaluator_get_patcher.start()
        self.mock_evaluator_post = self.evaluator_post_patcher.start()

    def tearDown(self):
        """Clean up after each test."""
        self.evaluator_get_patcher.stop()
        self.evaluator_post_patcher.stop()

    def _assert_request_body(self, expected_json):
        """Helper method to verify request body in Evaluator POST request is correct"""
        call_args = self.mock_evaluator_post.call_args
        actual_json = call_args[0][1]

        # Check that all expected keys contain the expected values in the actual JSON
        for key, value in expected_json.items():
            assert key in actual_json, f"Key '{key}' missing in actual JSON"

            if isinstance(value, dict):
                for nested_key, nested_value in value.items():
                    assert nested_key in actual_json[key], f"Nested key '{nested_key}' missing in actual JSON['{key}']"
                    assert actual_json[key][nested_key] == nested_value, f"Value mismatch for '{key}.{nested_key}'"
            else:
                assert actual_json[key] == value, f"Value mismatch for '{key}'"

    @pytest.fixture(autouse=True)
    def inject_fixtures(self, run_async):
        self.run_async = run_async

    def test_register_benchmark(self):
        eval_config = {
            "type": "custom",
            "params": {"parallelism": 8},
            "tasks": {
                "qa": {
                    "type": "completion",
                    "params": {"template": {"prompt": "{{prompt}}", "max_tokens": 200}},
                    "dataset": {"files_url": f"hf://datasets/{MOCK_DATASET_ID}/testing/testing.jsonl"},
                    "metrics": {"bleu": {"type": "bleu", "params": {"references": ["{{ideal_response}}"]}}},
                }
            },
        }

        benchmark = Benchmark(
            provider_id="nvidia",
            type="benchmark",
            identifier=MOCK_BENCHMARK_ID,
            dataset_id=MOCK_DATASET_ID,
            scoring_functions=["basic::equality"],
            metadata=eval_config,
        )

        # Mock Evaluator API response
        mock_evaluator_response = {"id": MOCK_BENCHMARK_ID, "status": "created"}
        self.mock_evaluator_post.return_value = mock_evaluator_response

        # Register the benchmark
        self.run_async(self.eval_impl.register_benchmark(benchmark))

        # Verify the Evaluator API was called correctly
        self.mock_evaluator_post.assert_called_once()
        self._assert_request_body({"namespace": benchmark.provider_id, "name": benchmark.identifier, **eval_config})

    def test_run_eval(self):
        benchmark_config = BenchmarkConfig(
            eval_candidate=ModelCandidate(
                type="model",
                model=CoreModelId.llama3_1_8b_instruct.value,
                sampling_params=SamplingParams(max_tokens=100, temperature=0.7),
            )
        )

        # Mock Evaluator API response
        mock_evaluator_response = {"id": "job-123", "status": "created"}
        self.mock_evaluator_post.return_value = mock_evaluator_response

        # Run the Evaluation job
        result = self.run_async(
            self.eval_impl.run_eval(benchmark_id=MOCK_BENCHMARK_ID, benchmark_config=benchmark_config)
        )

        # Verify the Evaluator API was called correctly
        self.mock_evaluator_post.assert_called_once()
        self._assert_request_body(
            {
                "config": f"nvidia/{MOCK_BENCHMARK_ID}",
                "target": {"type": "model", "model": "meta/llama-3.1-8b-instruct"},
            }
        )

        # Verify the result
        assert isinstance(result, Job)
        assert result.job_id == "job-123"
        assert result.status == JobStatus.in_progress

    def test_job_status(self):
        # Mock Evaluator API response
        mock_evaluator_response = {"id": "job-123", "status": "completed"}
        self.mock_evaluator_get.return_value = mock_evaluator_response

        # Get the Evaluation job
        result = self.run_async(self.eval_impl.job_status(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123"))

        # Verify the result
        assert isinstance(result, Job)
        assert result.job_id == "job-123"
        assert result.status == JobStatus.completed

        # Verify the API was called correctly
        self.mock_evaluator_get.assert_called_once_with(f"/v1/evaluation/jobs/{result.job_id}")

    def test_job_cancel(self):
        # Mock Evaluator API response
        mock_evaluator_response = {"id": "job-123", "status": "cancelled"}
        self.mock_evaluator_post.return_value = mock_evaluator_response

        # Cancel the Evaluation job
        self.run_async(self.eval_impl.job_cancel(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123"))

        # Verify the API was called correctly
        self.mock_evaluator_post.assert_called_once_with("/v1/evaluation/jobs/job-123/cancel", {})

    def test_job_result(self):
        # Mock Evaluator API responses
        mock_job_status_response = {"id": "job-123", "status": "completed"}
        mock_job_results_response = {
            "id": "job-123",
            "status": "completed",
            "results": {MOCK_BENCHMARK_ID: {"score": 0.85, "details": {"accuracy": 0.85, "f1": 0.84}}},
        }
        self.mock_evaluator_get.side_effect = [
            mock_job_status_response,  # First call to retrieve job
            mock_job_results_response,  # Second call to retrieve job results
        ]

        # Get the Evaluation job results
        result = self.run_async(self.eval_impl.job_result(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123"))

        # Verify the result
        assert isinstance(result, EvaluateResponse)
        assert MOCK_BENCHMARK_ID in result.scores
        assert result.scores[MOCK_BENCHMARK_ID].aggregated_results["results"][MOCK_BENCHMARK_ID]["score"] == 0.85

        # Verify the API was called correctly
        assert self.mock_evaluator_get.call_count == 2
        self.mock_evaluator_get.assert_any_call("/v1/evaluation/jobs/job-123")
        self.mock_evaluator_get.assert_any_call("/v1/evaluation/jobs/job-123/results")
|
@@ -10,14 +10,17 @@ import warnings
from unittest.mock import patch

import pytest
from llama_stack_client.types.algorithm_config_param import LoraFinetuningConfig
from llama_stack_client.types.post_training_supervised_fine_tune_params import (
    TrainingConfig,
    TrainingConfigDataConfig,
    TrainingConfigEfficiencyConfig,
    TrainingConfigOptimizerConfig,
)

from llama_stack.apis.post_training.post_training import (
    DataConfig,
    DatasetFormat,
    EfficiencyConfig,
    LoraFinetuningConfig,
    OptimizerConfig,
    OptimizerType,
    TrainingConfig,
)
from llama_stack.distribution.library_client import convert_pydantic_to_json_value
from llama_stack.providers.remote.post_training.nvidia.post_training import (
    NvidiaPostTrainingAdapter,
    NvidiaPostTrainingConfig,
@@ -66,11 +69,8 @@ class TestNvidiaParameters(unittest.TestCase):

    def test_customizer_parameters_passed(self):
        """Test scenario 1: When an optional parameter is passed and value is correctly set."""
        custom_adapter_dim = 32  # Different from default of 8
        algorithm_config = LoraFinetuningConfig(
            type="LoRA",
            adapter_dim=custom_adapter_dim,
            adapter_dropout=0.2,
            apply_lora_to_mlp=True,
            apply_lora_to_output=True,
            alpha=16,
@@ -78,8 +78,15 @@ class TestNvidiaParameters(unittest.TestCase):
            lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        )

        data_config = TrainingConfigDataConfig(dataset_id="test-dataset", batch_size=16)
        optimizer_config = TrainingConfigOptimizerConfig(lr=0.0002)
        data_config = DataConfig(
            dataset_id="test-dataset", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct
        )
        optimizer_config = OptimizerConfig(
            optimizer_type=OptimizerType.adam,
            lr=0.0002,
            weight_decay=0.01,
            num_warmup_steps=100,
        )
        training_config = TrainingConfig(
            n_epochs=3,
            data_config=data_config,
@@ -95,7 +102,7 @@ class TestNvidiaParameters(unittest.TestCase):
            model="meta-llama/Llama-3.1-8B-Instruct",
            checkpoint_dir="",
            algorithm_config=algorithm_config,
            training_config=training_config,
            training_config=convert_pydantic_to_json_value(training_config),
            logger_config={},
            hyperparam_search_config={},
        )
@@ -114,7 +121,7 @@ class TestNvidiaParameters(unittest.TestCase):
        self._assert_request_params(
            {
                "hyperparameters": {
                    "lora": {"adapter_dim": custom_adapter_dim, "adapter_dropout": 0.2, "alpha": 16},
                    "lora": {"alpha": 16},
                    "epochs": 3,
                    "learning_rate": 0.0002,
                    "batch_size": 16,
@@ -130,8 +137,6 @@ class TestNvidiaParameters(unittest.TestCase):

        algorithm_config = LoraFinetuningConfig(
            type="LoRA",
            adapter_dim=16,
            adapter_dropout=0.1,
            apply_lora_to_mlp=True,
            apply_lora_to_output=True,
            alpha=16,
@@ -139,12 +144,16 @@ class TestNvidiaParameters(unittest.TestCase):
            lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        )

        data_config = TrainingConfigDataConfig(
            dataset_id=required_dataset_id,  # Required parameter
            batch_size=8,
        data_config = DataConfig(
            dataset_id=required_dataset_id, batch_size=8, shuffle=False, data_format=DatasetFormat.instruct
        )

        optimizer_config = TrainingConfigOptimizerConfig(lr=0.0001)
        optimizer_config = OptimizerConfig(
            optimizer_type=OptimizerType.adam,
            lr=0.0001,
            weight_decay=0.01,
            num_warmup_steps=100,
        )

        training_config = TrainingConfig(
            n_epochs=1,
@@ -161,7 +170,7 @@ class TestNvidiaParameters(unittest.TestCase):
            model=required_model,  # Required parameter
            checkpoint_dir="",
            algorithm_config=algorithm_config,
            training_config=training_config,
            training_config=convert_pydantic_to_json_value(training_config),
            logger_config={},
            hyperparam_search_config={},
        )
@@ -186,24 +195,24 @@ class TestNvidiaParameters(unittest.TestCase):

    def test_unsupported_parameters_warning(self):
        """Test that warnings are raised for unsupported parameters."""
        data_config = TrainingConfigDataConfig(
        data_config = DataConfig(
            dataset_id="test-dataset",
            batch_size=8,
            # Unsupported parameters
            shuffle=True,
            data_format="instruct",
            data_format=DatasetFormat.instruct,
            validation_dataset_id="val-dataset",
        )

        optimizer_config = TrainingConfigOptimizerConfig(
        optimizer_config = OptimizerConfig(
            lr=0.0001,
            weight_decay=0.01,
            # Unsupported parameters
            optimizer_type="adam",
            optimizer_type=OptimizerType.adam,
            num_warmup_steps=100,
        )

        efficiency_config = TrainingConfigEfficiencyConfig(
        efficiency_config = EfficiencyConfig(
            enable_activation_checkpointing=True  # Unsupported parameter
        )

@@ -230,15 +239,13 @@ class TestNvidiaParameters(unittest.TestCase):
                checkpoint_dir="test-dir",  # Unsupported parameter
                algorithm_config=LoraFinetuningConfig(
                    type="LoRA",
                    adapter_dim=16,
                    adapter_dropout=0.1,
                    apply_lora_to_mlp=True,
                    apply_lora_to_output=True,
                    alpha=16,
                    rank=16,
                    lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
                ),
                training_config=training_config,
                training_config=convert_pydantic_to_json_value(training_config),
                logger_config={"test": "value"},  # Unsupported parameter
                hyperparam_search_config={"test": "value"},  # Unsupported parameter
            )

@@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
import os
import unittest
from typing import Any
@@ -139,8 +138,8 @@ class TestNVIDIASafetyAdapter(unittest.TestCase):
            data={
                "model": shield_id,
                "messages": [
                    json.loads(messages[0].model_dump_json()),
                    json.loads(messages[1].model_dump_json()),
                    {"role": "user", "content": "Hello, how are you?"},
                    {"role": "assistant", "content": "I'm doing well, thank you for asking!"},
                ],
                "temperature": 1.0,
                "top_p": 1,
@@ -193,8 +192,8 @@ class TestNVIDIASafetyAdapter(unittest.TestCase):
            data={
                "model": shield_id,
                "messages": [
                    json.loads(messages[0].model_dump_json()),
                    json.loads(messages[1].model_dump_json()),
                    {"role": "user", "content": "Hello, how are you?"},
                    {"role": "assistant", "content": "I'm doing well, thank you for asking!"},
                ],
                "temperature": 1.0,
                "top_p": 1,
@@ -269,8 +268,8 @@ class TestNVIDIASafetyAdapter(unittest.TestCase):
            data={
                "model": shield_id,
                "messages": [
                    json.loads(messages[0].model_dump_json()),
                    json.loads(messages[1].model_dump_json()),
                    {"role": "user", "content": "Hello, how are you?"},
                    {"role": "assistant", "content": "I'm doing well, thank you for asking!"},
                ],
                "temperature": 1.0,
                "top_p": 1,

@@ -10,13 +10,19 @@ import warnings
from unittest.mock import patch

import pytest
from llama_stack_client.types.algorithm_config_param import LoraFinetuningConfig, QatFinetuningConfig
from llama_stack_client.types.post_training_supervised_fine_tune_params import (
    TrainingConfig,
    TrainingConfigDataConfig,
    TrainingConfigOptimizerConfig,
)

from llama_stack.apis.models import Model, ModelType
from llama_stack.apis.post_training.post_training import (
    DataConfig,
    DatasetFormat,
    LoraFinetuningConfig,
    OptimizerConfig,
    OptimizerType,
    QATFinetuningConfig,
    TrainingConfig,
)
from llama_stack.distribution.library_client import convert_pydantic_to_json_value
from llama_stack.providers.remote.inference.nvidia.nvidia import NVIDIAConfig, NVIDIAInferenceAdapter
from llama_stack.providers.remote.post_training.nvidia.post_training import (
    ListNvidiaPostTrainingJobs,
    NvidiaPostTrainingAdapter,
@@ -40,8 +46,22 @@ class TestNvidiaPostTraining(unittest.TestCase):
        )
        self.mock_make_request = self.make_request_patcher.start()

        # Mock the inference client
        inference_config = NVIDIAConfig(base_url=os.environ["NVIDIA_BASE_URL"], api_key=None)
        self.inference_adapter = NVIDIAInferenceAdapter(inference_config)

        self.mock_client = unittest.mock.MagicMock()
        self.mock_client.chat.completions.create = unittest.mock.AsyncMock()
        self.inference_mock_make_request = self.mock_client.chat.completions.create
        self.inference_make_request_patcher = patch(
            "llama_stack.providers.remote.inference.nvidia.nvidia.NVIDIAInferenceAdapter._get_client",
            return_value=self.mock_client,
        )
        self.inference_make_request_patcher.start()

    def tearDown(self):
        self.make_request_patcher.stop()
        self.inference_make_request_patcher.stop()

    @pytest.fixture(autouse=True)
    def inject_fixtures(self, run_async):
@@ -105,7 +125,7 @@ class TestNvidiaPostTraining(unittest.TestCase):
                "batch_size": 16,
                "epochs": 2,
                "learning_rate": 0.0001,
                "lora": {"adapter_dim": 16, "adapter_dropout": 0.1},
                "lora": {"alpha": 16},
            },
            "output_model": "default/job-1234",
            "status": "created",
@@ -116,8 +136,6 @@ class TestNvidiaPostTraining(unittest.TestCase):

        algorithm_config = LoraFinetuningConfig(
            type="LoRA",
            adapter_dim=16,
            adapter_dropout=0.1,
            apply_lora_to_mlp=True,
            apply_lora_to_output=True,
            alpha=16,
@@ -125,10 +143,15 @@ class TestNvidiaPostTraining(unittest.TestCase):
            lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        )

        data_config = TrainingConfigDataConfig(dataset_id="sample-basic-test", batch_size=16)
        data_config = DataConfig(
            dataset_id="sample-basic-test", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct
        )

        optimizer_config = TrainingConfigOptimizerConfig(
        optimizer_config = OptimizerConfig(
            optimizer_type=OptimizerType.adam,
            lr=0.0001,
            weight_decay=0.01,
            num_warmup_steps=100,
        )

        training_config = TrainingConfig(
@@ -145,7 +168,7 @@ class TestNvidiaPostTraining(unittest.TestCase):
            model="meta-llama/Llama-3.1-8B-Instruct",
            checkpoint_dir="",
            algorithm_config=algorithm_config,
            training_config=training_config,
            training_config=convert_pydantic_to_json_value(training_config),
            logger_config={},
            hyperparam_search_config={},
        )
@@ -169,16 +192,22 @@ class TestNvidiaPostTraining(unittest.TestCase):
                    "epochs": 2,
                    "batch_size": 16,
                    "learning_rate": 0.0001,
                    "lora": {"alpha": 16, "adapter_dim": 16, "adapter_dropout": 0.1},
                    "weight_decay": 0.01,
                    "lora": {"alpha": 16},
                },
            },
        )

    def test_supervised_fine_tune_with_qat(self):
        algorithm_config = QatFinetuningConfig(type="QAT", quantizer_name="quantizer_name", group_size=1)
        data_config = TrainingConfigDataConfig(dataset_id="sample-basic-test", batch_size=16)
        optimizer_config = TrainingConfigOptimizerConfig(
        algorithm_config = QATFinetuningConfig(type="QAT", quantizer_name="quantizer_name", group_size=1)
        data_config = DataConfig(
            dataset_id="sample-basic-test", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct
        )
        optimizer_config = OptimizerConfig(
            optimizer_type=OptimizerType.adam,
            lr=0.0001,
            weight_decay=0.01,
            num_warmup_steps=100,
        )
        training_config = TrainingConfig(
            n_epochs=2,
@@ -193,42 +222,55 @@ class TestNvidiaPostTraining(unittest.TestCase):
                model="meta-llama/Llama-3.1-8B-Instruct",
                checkpoint_dir="",
                algorithm_config=algorithm_config,
                training_config=training_config,
                training_config=convert_pydantic_to_json_value(training_config),
                logger_config={},
                hyperparam_search_config={},
            )
        )

    def test_get_training_job_status(self):
        self.mock_make_request.return_value = {
            "created_at": "2024-12-09T04:06:28.580220",
            "updated_at": "2024-12-09T04:21:19.852832",
            "status": "completed",
            "steps_completed": 1210,
            "epochs_completed": 2,
            "percentage_done": 100.0,
            "best_epoch": 2,
            "train_loss": 1.718016266822815,
            "val_loss": 1.8661999702453613,
        }
        customizer_status_to_job_status = [
            ("running", "in_progress"),
            ("completed", "completed"),
            ("failed", "failed"),
            ("cancelled", "cancelled"),
            ("pending", "scheduled"),
            ("unknown", "scheduled"),
        ]

        job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
        for customizer_status, expected_status in customizer_status_to_job_status:
            with self.subTest(customizer_status=customizer_status, expected_status=expected_status):
                self.mock_make_request.return_value = {
                    "created_at": "2024-12-09T04:06:28.580220",
                    "updated_at": "2024-12-09T04:21:19.852832",
                    "status": customizer_status,
                    "steps_completed": 1210,
                    "epochs_completed": 2,
                    "percentage_done": 100.0,
                    "best_epoch": 2,
                    "train_loss": 1.718016266822815,
                    "val_loss": 1.8661999702453613,
                }

        status = self.run_async(self.adapter.get_training_job_status(job_uuid=job_id))
                job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"

        assert isinstance(status, NvidiaPostTrainingJobStatusResponse)
        assert status.status.value == "completed"
        assert status.steps_completed == 1210
        assert status.epochs_completed == 2
        assert status.percentage_done == 100.0
        assert status.best_epoch == 2
        assert status.train_loss == 1.718016266822815
        assert status.val_loss == 1.8661999702453613
                status = self.run_async(self.adapter.get_training_job_status(job_uuid=job_id))

        self.mock_make_request.assert_called_once()
        self._assert_request(
            self.mock_make_request, "GET", f"/v1/customization/jobs/{job_id}/status", expected_params={"job_id": job_id}
        )
                assert isinstance(status, NvidiaPostTrainingJobStatusResponse)
                assert status.status.value == expected_status
                assert status.steps_completed == 1210
                assert status.epochs_completed == 2
                assert status.percentage_done == 100.0
                assert status.best_epoch == 2
                assert status.train_loss == 1.718016266822815
                assert status.val_loss == 1.8661999702453613

                self._assert_request(
                    self.mock_make_request,
                    "GET",
                    f"/v1/customization/jobs/{job_id}/status",
                    expected_params={"job_id": job_id},
                )

    def test_get_training_jobs(self):
        job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
@@ -290,6 +332,31 @@ class TestNvidiaPostTraining(unittest.TestCase):
            expected_params={"job_id": job_id},
        )

    def test_inference_register_model(self):
        model_id = "default/job-1234"
        model_type = ModelType.llm
        model = Model(
            identifier=model_id,
            provider_id="nvidia",
            provider_model_id=model_id,
            provider_resource_id=model_id,
            model_type=model_type,
        )
        result = self.run_async(self.inference_adapter.register_model(model))
        assert result == model
        assert len(self.inference_adapter.alias_to_provider_id_map) > 1
        assert self.inference_adapter.get_provider_model_id(model.provider_model_id) == model_id

        with patch.object(self.inference_adapter, "chat_completion") as mock_chat_completion:
            self.run_async(
                self.inference_adapter.chat_completion(
                    model_id=model_id,
                    messages=[{"role": "user", "content": "Hello, model"}],
                )
            )

            mock_chat_completion.assert_called()


if __name__ == "__main__":
    unittest.main()
116 tests/unit/providers/utils/inference/test_openai_compat.py Normal file

@@ -0,0 +1,116 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import pytest

from llama_stack.apis.common.content_types import TextContentItem
from llama_stack.apis.inference.inference import (
    CompletionMessage,
    OpenAIAssistantMessageParam,
    OpenAIChatCompletionContentPartTextParam,
    OpenAISystemMessageParam,
    OpenAIUserMessageParam,
    SystemMessage,
    UserMessage,
)
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
from llama_stack.providers.utils.inference.openai_compat import (
    convert_message_to_openai_dict,
    openai_messages_to_messages,
)


@pytest.mark.asyncio
async def test_convert_message_to_openai_dict():
    message = UserMessage(content=[TextContentItem(text="Hello, world!")], role="user")
    assert await convert_message_to_openai_dict(message) == {
        "role": "user",
        "content": [{"type": "text", "text": "Hello, world!"}],
    }


# Test convert_message_to_openai_dict with a tool call
@pytest.mark.asyncio
async def test_convert_message_to_openai_dict_with_tool_call():
    message = CompletionMessage(
        content="",
        tool_calls=[
            ToolCall(call_id="123", tool_name="test_tool", arguments_json='{"foo": "bar"}', arguments={"foo": "bar"})
        ],
        stop_reason=StopReason.end_of_turn,
    )

    openai_dict = await convert_message_to_openai_dict(message)

    assert openai_dict == {
        "role": "assistant",
        "content": [{"type": "text", "text": ""}],
        "tool_calls": [
            {"id": "123", "type": "function", "function": {"name": "test_tool", "arguments": '{"foo": "bar"}'}}
        ],
    }


@pytest.mark.asyncio
async def test_convert_message_to_openai_dict_with_builtin_tool_call():
    message = CompletionMessage(
        content="",
        tool_calls=[
            ToolCall(
                call_id="123",
                tool_name=BuiltinTool.brave_search,
                arguments_json='{"foo": "bar"}',
                arguments={"foo": "bar"},
            )
        ],
        stop_reason=StopReason.end_of_turn,
    )

    openai_dict = await convert_message_to_openai_dict(message)

    assert openai_dict == {
        "role": "assistant",
        "content": [{"type": "text", "text": ""}],
        "tool_calls": [
            {"id": "123", "type": "function", "function": {"name": "brave_search", "arguments": '{"foo": "bar"}'}}
        ],
    }


@pytest.mark.asyncio
async def test_openai_messages_to_messages_with_content_str():
    openai_messages = [
        OpenAISystemMessageParam(content="system message"),
        OpenAIUserMessageParam(content="user message"),
        OpenAIAssistantMessageParam(content="assistant message"),
    ]

    llama_messages = openai_messages_to_messages(openai_messages)
    assert len(llama_messages) == 3
    assert isinstance(llama_messages[0], SystemMessage)
    assert isinstance(llama_messages[1], UserMessage)
    assert isinstance(llama_messages[2], CompletionMessage)
    assert llama_messages[0].content == "system message"
    assert llama_messages[1].content == "user message"
    assert llama_messages[2].content == "assistant message"


@pytest.mark.asyncio
async def test_openai_messages_to_messages_with_content_list():
    openai_messages = [
        OpenAISystemMessageParam(content=[OpenAIChatCompletionContentPartTextParam(text="system message")]),
        OpenAIUserMessageParam(content=[OpenAIChatCompletionContentPartTextParam(text="user message")]),
        OpenAIAssistantMessageParam(content=[OpenAIChatCompletionContentPartTextParam(text="assistant message")]),
    ]

    llama_messages = openai_messages_to_messages(openai_messages)
    assert len(llama_messages) == 3
    assert isinstance(llama_messages[0], SystemMessage)
    assert isinstance(llama_messages[1], UserMessage)
    assert isinstance(llama_messages[2], CompletionMessage)
    assert llama_messages[0].content[0].text == "system message"
    assert llama_messages[1].content[0].text == "user message"
    assert llama_messages[2].content[0].text == "assistant message"
163 tests/unit/providers/utils/test_model_registry.py Normal file

@@ -0,0 +1,163 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

#
# ModelRegistryHelper provides mixin functionality for registering and
# unregistering models. It maintains a mapping of model ID / aliases to
# provider model IDs.
#
# Test cases -
#  - Looking up an alias that does not exist should return None.
#  - Registering a model + provider ID should add the model to the registry. If
#    provider ID is known or an alias for a provider ID.
#  - Registering an existing model should return an error. Unless it's a
#    duplicate entry.
#  - Unregistering a model should remove it from the registry.
#  - Unregistering a model that does not exist should return an error.
#  - Supported model ID and their aliases are registered during initialization.
#    Only aliases are added afterwards.
#
# Questions -
#  - Should we be allowed to register models w/o provider model IDs? No.
#    According to POST /v1/models, required params are
#      - identifier
#      - provider_resource_id
#      - provider_id
#      - type
#      - metadata
#      - model_type
#
# TODO: llama_model functionality
#

import pytest

from llama_stack.apis.models.models import Model
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, ProviderModelEntry


@pytest.fixture
def known_model() -> Model:
    return Model(
        provider_id="provider",
        identifier="known-model",
        provider_resource_id="known-provider-id",
    )


@pytest.fixture
def known_model2() -> Model:
    return Model(
        provider_id="provider",
        identifier="known-model2",
        provider_resource_id="known-provider-id2",
    )


@pytest.fixture
def known_provider_model(known_model: Model) -> ProviderModelEntry:
    return ProviderModelEntry(
        provider_model_id=known_model.provider_resource_id,
        aliases=[known_model.model_id],
    )


@pytest.fixture
def known_provider_model2(known_model2: Model) -> ProviderModelEntry:
    return ProviderModelEntry(
        provider_model_id=known_model2.provider_resource_id,
        # aliases=[],
    )


@pytest.fixture
def unknown_model() -> Model:
    return Model(
        provider_id="provider",
        identifier="unknown-model",
        provider_resource_id="unknown-provider-id",
    )


@pytest.fixture
def helper(known_provider_model: ProviderModelEntry, known_provider_model2: ProviderModelEntry) -> ModelRegistryHelper:
    return ModelRegistryHelper([known_provider_model, known_provider_model2])


@pytest.mark.asyncio
async def test_lookup_unknown_model(helper: ModelRegistryHelper, unknown_model: Model) -> None:
    assert helper.get_provider_model_id(unknown_model.model_id) is None


@pytest.mark.asyncio
async def test_register_unknown_provider_model(helper: ModelRegistryHelper, unknown_model: Model) -> None:
    with pytest.raises(ValueError):
        await helper.register_model(unknown_model)


@pytest.mark.asyncio
async def test_register_model(helper: ModelRegistryHelper, known_model: Model) -> None:
    model = Model(
        provider_id=known_model.provider_id,
        identifier="new-model",
        provider_resource_id=known_model.provider_resource_id,
    )
    assert helper.get_provider_model_id(model.model_id) is None
    await helper.register_model(model)
    assert helper.get_provider_model_id(model.model_id) == model.provider_resource_id


@pytest.mark.asyncio
async def test_register_model_from_alias(helper: ModelRegistryHelper, known_model: Model) -> None:
    model = Model(
        provider_id=known_model.provider_id,
        identifier="new-model",
        provider_resource_id=known_model.model_id,  # use known model's id as an alias for the supported model id
    )
    assert helper.get_provider_model_id(model.model_id) is None
    await helper.register_model(model)
    assert helper.get_provider_model_id(model.model_id) == known_model.provider_resource_id


@pytest.mark.asyncio
async def test_register_model_existing(helper: ModelRegistryHelper, known_model: Model) -> None:
    await helper.register_model(known_model)
    assert helper.get_provider_model_id(known_model.model_id) == known_model.provider_resource_id


@pytest.mark.asyncio
async def test_register_model_existing_different(
    helper: ModelRegistryHelper, known_model: Model, known_model2: Model
) -> None:
    known_model.provider_resource_id = known_model2.provider_resource_id
    with pytest.raises(ValueError):
        await helper.register_model(known_model)


@pytest.mark.asyncio
async def test_unregister_model(helper: ModelRegistryHelper, known_model: Model) -> None:
    await helper.register_model(known_model)  # duplicate entry
    assert helper.get_provider_model_id(known_model.model_id) == known_model.provider_model_id
    await helper.unregister_model(known_model.model_id)
    assert helper.get_provider_model_id(known_model.model_id) is None


@pytest.mark.asyncio
async def test_unregister_unknown_model(helper: ModelRegistryHelper, unknown_model: Model) -> None:
    with pytest.raises(ValueError):
        await helper.unregister_model(unknown_model.model_id)


@pytest.mark.asyncio
async def test_register_model_during_init(helper: ModelRegistryHelper, known_model: Model) -> None:
    assert helper.get_provider_model_id(known_model.provider_resource_id) == known_model.provider_model_id


@pytest.mark.asyncio
async def test_unregister_model_during_init(helper: ModelRegistryHelper, known_model: Model) -> None:
    assert helper.get_provider_model_id(known_model.provider_resource_id) == known_model.provider_model_id
    await helper.unregister_model(known_model.provider_resource_id)
    assert helper.get_provider_model_id(known_model.provider_resource_id) is None

@@ -12,7 +12,7 @@ import pytest

from llama_stack.apis.models import ModelType
from llama_stack.distribution.datatypes import ModelWithACL
from llama_stack.distribution.server.auth import AccessAttributes
from llama_stack.distribution.server.auth_providers import AccessAttributes
from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl
@@ -22,7 +22,7 @@ from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl

class AsyncMock(MagicMock):
    async def __call__(self, *args, **kwargs):
        return super(AsyncMock, self).__call__(*args, **kwargs)
        return super().__call__(*args, **kwargs)


def _return_model(model):

@@ -10,7 +10,9 @@ import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient

from llama_stack.distribution.datatypes import AccessAttributes
from llama_stack.distribution.server.auth import AuthenticationMiddleware
from llama_stack.distribution.server.auth_providers import AuthProviderConfig, AuthProviderType


class MockResponse:
@@ -38,9 +40,23 @@ def invalid_api_key():


@pytest.fixture
def app(mock_auth_endpoint):
def valid_token():
    return "valid.jwt.token"


@pytest.fixture
def invalid_token():
    return "invalid.jwt.token"


@pytest.fixture
def http_app(mock_auth_endpoint):
    app = FastAPI()
    app.add_middleware(AuthenticationMiddleware, auth_endpoint=mock_auth_endpoint)
    auth_config = AuthProviderConfig(
        provider_type=AuthProviderType.CUSTOM,
        config={"endpoint": mock_auth_endpoint},
    )
    app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)

    @app.get("/test")
    def test_endpoint():
@@ -50,8 +66,29 @@ def app(mock_auth_endpoint):


@pytest.fixture
def client(app):
    return TestClient(app)
def k8s_app():
    app = FastAPI()
    auth_config = AuthProviderConfig(
        provider_type=AuthProviderType.KUBERNETES,
        config={"api_server_url": "https://kubernetes.default.svc"},
    )
    app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)

    @app.get("/test")
    def test_endpoint():
        return {"message": "Authentication successful"}

    return app


@pytest.fixture
def http_client(http_app):
    return TestClient(http_app)


@pytest.fixture
def k8s_client(k8s_app):
    return TestClient(k8s_app)


@pytest.fixture
@@ -61,7 +98,7 @@ def mock_scope():
        "path": "/models/list",
        "headers": [
            (b"content-type", b"application/json"),
            (b"authorization", b"Bearer test-api-key"),
            (b"authorization", b"Bearer test.jwt.token"),
            (b"user-agent", b"test-user-agent"),
        ],
        "query_string": b"limit=100&offset=0",
@@ -69,13 +106,38 @@


@pytest.fixture
def mock_middleware(mock_auth_endpoint):
def mock_http_middleware(mock_auth_endpoint):
    mock_app = AsyncMock()
    return AuthenticationMiddleware(mock_app, mock_auth_endpoint), mock_app
    auth_config = AuthProviderConfig(
        provider_type=AuthProviderType.CUSTOM,
        config={"endpoint": mock_auth_endpoint},
    )
    return AuthenticationMiddleware(mock_app, auth_config), mock_app


@pytest.fixture
def mock_k8s_middleware():
    mock_app = AsyncMock()
    auth_config = AuthProviderConfig(
        provider_type=AuthProviderType.KUBERNETES,
        config={"api_server_url": "https://kubernetes.default.svc"},
    )
    return AuthenticationMiddleware(mock_app, auth_config), mock_app


async def mock_post_success(*args, **kwargs):
    return MockResponse(200, {"message": "Authentication successful"})
    return MockResponse(
        200,
        {
            "message": "Authentication successful",
            "access_attributes": {
                "roles": ["admin", "user"],
                "teams": ["ml-team", "nlp-team"],
                "projects": ["llama-3", "project-x"],
                "namespaces": ["research", "production"],
            },
        },
    )


async def mock_post_failure(*args, **kwargs):
@@ -86,45 +148,46 @@ async def mock_post_exception(*args, **kwargs):
    raise Exception("Connection error")


def test_missing_auth_header(client):
    response = client.get("/test")
# HTTP Endpoint Tests
def test_missing_auth_header(http_client):
    response = http_client.get("/test")
    assert response.status_code == 401
    assert "Missing or invalid Authorization header" in response.json()["error"]["message"]


def test_invalid_auth_header_format(client):
    response = client.get("/test", headers={"Authorization": "InvalidFormat token123"})
def test_invalid_auth_header_format(http_client):
    response = http_client.get("/test", headers={"Authorization": "InvalidFormat token123"})
    assert response.status_code == 401
    assert "Missing or invalid Authorization header" in response.json()["error"]["message"]


@patch("httpx.AsyncClient.post", new=mock_post_success)
def test_valid_authentication(client, valid_api_key):
    response = client.get("/test", headers={"Authorization": f"Bearer {valid_api_key}"})
def test_valid_http_authentication(http_client, valid_api_key):
    response = http_client.get("/test", headers={"Authorization": f"Bearer {valid_api_key}"})
    assert response.status_code == 200
    assert response.json() == {"message": "Authentication successful"}


@patch("httpx.AsyncClient.post", new=mock_post_failure)
def test_invalid_authentication(client, invalid_api_key):
    response = client.get("/test", headers={"Authorization": f"Bearer {invalid_api_key}"})
def test_invalid_http_authentication(http_client, invalid_api_key):
    response = http_client.get("/test", headers={"Authorization": f"Bearer {invalid_api_key}"})
    assert response.status_code == 401
    assert "Authentication failed" in response.json()["error"]["message"]


@patch("httpx.AsyncClient.post", new=mock_post_exception)
def test_auth_service_error(client, valid_api_key):
    response = client.get("/test", headers={"Authorization": f"Bearer {valid_api_key}"})
def test_http_auth_service_error(http_client, valid_api_key):
    response = http_client.get("/test", headers={"Authorization": f"Bearer {valid_api_key}"})
    assert response.status_code == 401
    assert "Authentication service error" in response.json()["error"]["message"]


def test_auth_request_payload(client, valid_api_key, mock_auth_endpoint):
def test_http_auth_request_payload(http_client, valid_api_key, mock_auth_endpoint):
    with patch("httpx.AsyncClient.post") as mock_post:
        mock_response = MockResponse(200, {"message": "Authentication successful"})
        mock_post.return_value = mock_response

        client.get(
        http_client.get(
            "/test?param1=value1&param2=value2",
            headers={
                "Authorization": f"Bearer {valid_api_key}",
@@ -149,40 +212,43 @@ def test_auth_request_payload(client, valid_api_key, mock_auth_endpoint):


@pytest.mark.asyncio
async def test_auth_middleware_with_access_attributes(mock_middleware, mock_scope):
    middleware, mock_app = mock_middleware
async def test_http_middleware_with_access_attributes(mock_http_middleware, mock_scope):
    """Test HTTP middleware behavior with access attributes"""
    middleware, mock_app = mock_http_middleware
    mock_receive = AsyncMock()
    mock_send = AsyncMock()

    with patch("httpx.AsyncClient") as mock_client:
        mock_client_instance = AsyncMock()
        mock_client.return_value.__aenter__.return_value = mock_client_instance

        mock_client_instance.post.return_value = MockResponse(
    with patch("httpx.AsyncClient.post") as mock_post:
        mock_response = MockResponse(
            200,
            {
                "message": "Authentication successful",
                "access_attributes": {
                    "roles": ["admin", "user"],
                    "teams": ["ml-team"],
                    "projects": ["project-x", "project-y"],
                }
                    "teams": ["ml-team", "nlp-team"],
                    "projects": ["llama-3", "project-x"],
                    "namespaces": ["research", "production"],
                },
            },
        )
        mock_post.return_value = mock_response

        await middleware(mock_scope, mock_receive, mock_send)

        assert "user_attributes" in mock_scope
        assert mock_scope["user_attributes"]["roles"] == ["admin", "user"]
        assert mock_scope["user_attributes"]["teams"] == ["ml-team"]
        assert mock_scope["user_attributes"]["projects"] == ["project-x", "project-y"]
        attributes = mock_scope["user_attributes"]
        assert attributes["roles"] == ["admin", "user"]
        assert attributes["teams"] == ["ml-team", "nlp-team"]
        assert attributes["projects"] == ["llama-3", "project-x"]
        assert attributes["namespaces"] == ["research", "production"]

        mock_app.assert_called_once_with(mock_scope, mock_receive, mock_send)


@pytest.mark.asyncio
async def test_auth_middleware_no_attributes(mock_middleware, mock_scope):
async def test_http_middleware_no_attributes(mock_http_middleware, mock_scope):
    """Test middleware behavior with no access attributes"""
    middleware, mock_app = mock_middleware
    middleware, mock_app = mock_http_middleware
    mock_receive = AsyncMock()
    mock_send = AsyncMock()

@@ -203,4 +269,104 @@ async def test_auth_middleware_no_attributes(mock_middleware, mock_scope):
    assert "user_attributes" in mock_scope
    attributes = mock_scope["user_attributes"]
    assert "namespaces" in attributes
    assert attributes["namespaces"] == ["test-api-key"]
    assert attributes["namespaces"] == ["test.jwt.token"]


# Kubernetes Tests
def test_missing_auth_header_k8s(k8s_client):
    response = k8s_client.get("/test")
    assert response.status_code == 401
    assert "Missing or invalid Authorization header" in response.json()["error"]["message"]


def test_invalid_auth_header_format_k8s(k8s_client):
    response = k8s_client.get("/test", headers={"Authorization": "InvalidFormat token123"})
    assert response.status_code == 401
    assert "Missing or invalid Authorization header" in response.json()["error"]["message"]


@patch("kubernetes.client.ApiClient")
def test_valid_k8s_authentication(mock_api_client, k8s_client, valid_token):
    # Mock the Kubernetes client
    mock_client = AsyncMock()
    mock_api_client.return_value = mock_client

    # Mock successful token validation
    mock_client.set_default_header = AsyncMock()

    # Mock the token validation to return valid access attributes
    with patch("llama_stack.distribution.server.auth_providers.KubernetesAuthProvider.validate_token") as mock_validate:
        mock_validate.return_value = AccessAttributes(
            roles=["admin"], teams=["ml-team"], projects=["llama-3"], namespaces=["research"]
        )
        response = k8s_client.get("/test", headers={"Authorization": f"Bearer {valid_token}"})
        assert response.status_code == 200
        assert response.json() == {"message": "Authentication successful"}


@patch("kubernetes.client.ApiClient")
def test_invalid_k8s_authentication(mock_api_client, k8s_client, invalid_token):
    # Mock the Kubernetes client
    mock_client = AsyncMock()
    mock_api_client.return_value = mock_client

    # Mock failed token validation by raising an exception
    with patch("llama_stack.distribution.server.auth_providers.KubernetesAuthProvider.validate_token") as mock_validate:
        mock_validate.side_effect = ValueError("Invalid or expired token")
        response = k8s_client.get("/test", headers={"Authorization": f"Bearer {invalid_token}"})
        assert response.status_code == 401
        assert "Invalid or expired token" in response.json()["error"]["message"]


@pytest.mark.asyncio
async def test_k8s_middleware_with_access_attributes(mock_k8s_middleware, mock_scope):
    middleware, mock_app = mock_k8s_middleware
    mock_receive = AsyncMock()
    mock_send = AsyncMock()

    with patch("kubernetes.client.ApiClient") as mock_api_client:
        mock_client = AsyncMock()
        mock_api_client.return_value = mock_client

        # Mock successful token validation
        mock_client.set_default_header = AsyncMock()

        # Mock token payload with access attributes
        mock_token_parts = ["header", "eyJzdWIiOiJhZG1pbiIsImdyb3VwcyI6WyJtbC10ZWFtIl19", "signature"]
        mock_scope["headers"][1] = (b"authorization", f"Bearer {'.'.join(mock_token_parts)}".encode())

        await middleware(mock_scope, mock_receive, mock_send)

        assert "user_attributes" in mock_scope
        assert mock_scope["user_attributes"]["roles"] == ["admin"]
        assert mock_scope["user_attributes"]["teams"] == ["ml-team"]

        mock_app.assert_called_once_with(mock_scope, mock_receive, mock_send)


@pytest.mark.asyncio
async def test_k8s_middleware_no_attributes(mock_k8s_middleware, mock_scope):
    """Test middleware behavior with no access attributes"""
    middleware, mock_app = mock_k8s_middleware
    mock_receive = AsyncMock()
    mock_send = AsyncMock()

    with patch("kubernetes.client.ApiClient") as mock_api_client:
        mock_client = AsyncMock()
        mock_api_client.return_value = mock_client

        # Mock successful token validation
        mock_client.set_default_header = AsyncMock()

        # Mock token payload without access attributes
        mock_token_parts = ["header", "eyJzdWIiOiJhZG1pbiJ9", "signature"]
        mock_scope["headers"][1] = (b"authorization", f"Bearer {'.'.join(mock_token_parts)}".encode())

        await middleware(mock_scope, mock_receive, mock_send)

        assert "user_attributes" in mock_scope
        attributes = mock_scope["user_attributes"]
        assert "roles" in attributes
        assert attributes["roles"] == ["admin"]

        mock_app.assert_called_once_with(mock_scope, mock_receive, mock_send)

@@ -6,7 +6,7 @@

import inspect
import sys
from typing import Any, Dict, Protocol
from typing import Any, Protocol
from unittest.mock import AsyncMock, MagicMock

import pytest
@@ -48,14 +48,14 @@ class SampleConfig(BaseModel):
    )

    @classmethod
    def sample_run_config(cls, **kwargs: Any) -> Dict[str, Any]:
    def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]:
        return {
            "foo": "baz",
        }


class SampleImpl:
    def __init__(self, config: SampleConfig, deps: Dict[Api, Any], provider_spec: ProviderSpec = None):
    def __init__(self, config: SampleConfig, deps: dict[Api, Any], provider_spec: ProviderSpec = None):
        self.__provider_id__ = "test_provider"
        self.__provider_spec__ = provider_spec
        self.__provider_config__ = config
91 tests/unit/server/test_sse.py Normal file

@@ -0,0 +1,91 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio

import pytest

from llama_stack.distribution.server.server import create_sse_event, sse_generator


@pytest.mark.asyncio
async def test_sse_generator_basic():
    # An AsyncIterator wrapped in an Awaitable, just like our web methods
    async def async_event_gen():
        async def event_gen():
            yield "Test event 1"
            yield "Test event 2"

        return event_gen()

    sse_gen = sse_generator(async_event_gen())
    assert sse_gen is not None

    # Test that the events are streamed correctly
    seen_events = []
    async for event in sse_gen:
        seen_events.append(event)
    assert len(seen_events) == 2
    assert seen_events[0] == create_sse_event("Test event 1")
    assert seen_events[1] == create_sse_event("Test event 2")


@pytest.mark.asyncio
async def test_sse_generator_client_disconnected():
    # An AsyncIterator wrapped in an Awaitable, just like our web methods
    async def async_event_gen():
        async def event_gen():
            yield "Test event 1"
            # Simulate a client disconnect before emitting event 2
            raise asyncio.CancelledError()

        return event_gen()

    sse_gen = sse_generator(async_event_gen())
    assert sse_gen is not None

    seen_events = []
    async for event in sse_gen:
        seen_events.append(event)

    # We should see 1 event before the client disconnected
    assert len(seen_events) == 1
    assert seen_events[0] == create_sse_event("Test event 1")


@pytest.mark.asyncio
async def test_sse_generator_client_disconnected_before_response_starts():
    # Disconnect before the response starts
    async def async_event_gen():
        raise asyncio.CancelledError()

    sse_gen = sse_generator(async_event_gen())
    assert sse_gen is not None

    seen_events = []
    async for event in sse_gen:
        seen_events.append(event)

    # No events should be seen since the client disconnected immediately
    assert len(seen_events) == 0


@pytest.mark.asyncio
async def test_sse_generator_error_before_response_starts():
    # Raise an error before the response starts
    async def async_event_gen():
        raise Exception("Test error")

    sse_gen = sse_generator(async_event_gen())
    assert sse_gen is not None

    seen_events = []
    async for event in sse_gen:
        seen_events.append(event)

    # We should have 1 error event
    assert len(seen_events) == 1
    assert 'data: {"error":' in seen_events[0]
@@ -8,29 +8,44 @@ This framework allows you to run the same set of verification tests against diff

## Features

The verification suite currently tests:
The verification suite currently tests the following in both streaming and non-streaming modes:

- Basic chat completions (streaming and non-streaming)
- Basic chat completions
- Image input capabilities
- Structured JSON output formatting
- Tool calling functionality

## Report

The latest report can be found at [REPORT.md](REPORT.md).

To update the report, ensure you have the API keys set:
```bash
export OPENAI_API_KEY=<your_openai_api_key>
export FIREWORKS_API_KEY=<your_fireworks_api_key>
export TOGETHER_API_KEY=<your_together_api_key>
```
then run
```bash
uv run --with-editable ".[dev]" python tests/verifications/generate_report.py --run-tests
```

## Running Tests

To run the verification tests, use pytest with the following parameters:

```bash
cd llama-stack
pytest tests/verifications/openai --provider=<provider-name>
pytest tests/verifications/openai_api --provider=<provider-name>
```

Example:
```bash
# Run all tests
pytest tests/verifications/openai --provider=together
pytest tests/verifications/openai_api --provider=together

# Only run tests with Llama 4 models
pytest tests/verifications/openai --provider=together -k 'Llama-4'
pytest tests/verifications/openai_api --provider=together -k 'Llama-4'
```

### Parameters

@@ -41,23 +56,22 @@ pytest tests/verifications/openai --provider=together -k 'Llama-4'

## Supported Providers

The verification suite currently supports:
- OpenAI
- Fireworks
- Together
- Groq
- Cerebras
The verification suite supports any provider with an OpenAI compatible endpoint.

See `tests/verifications/conf/` for the list of supported providers.

To run on a new provider, simply add a new yaml file to the `conf/` directory with the provider config. See `tests/verifications/conf/together.yaml` for an example.
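For orientation, a new config could look roughly like the sketch below. This is a hedged illustration only: the file name and values are invented, and the key names (`base_url`, `api_key_var`, `models`) are assumptions modeled on the referenced `together.yaml`, so copy the real file for the authoritative schema.

```yaml
# Hypothetical conf/my_provider.yaml -- key names are assumptions, not a guaranteed schema
base_url: https://api.my-provider.example/v1  # OpenAI-compatible endpoint (assumed key)
api_key_var: MY_PROVIDER_API_KEY              # env var holding the provider API key (assumed key)
models:                                       # model IDs for the suite to exercise (assumed key)
  - my-provider/Llama-4-Scout-17B-16E-Instruct
```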
## Adding New Test Cases

To add new test cases, create appropriate JSON files in the `openai/fixtures/test_cases/` directory following the existing patterns.
To add new test cases, create appropriate JSON files in the `openai_api/fixtures/test_cases/` directory following the existing patterns.
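As a rough, hypothetical illustration of that pattern, a case file pairs named inputs with expected outputs; every field name below is an assumption, so mirror an existing file in `openai_api/fixtures/test_cases/` rather than this sketch.

```json
{
  "test_chat_basic": {
    "data": [
      {
        "case_id": "mars",
        "input": {
          "messages": [{"role": "user", "content": "Which planet is known as the Red Planet?"}]
        },
        "expected": "Mars"
      }
    ]
  }
}
```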

## Structure

- `__init__.py` - Marks the directory as a Python package
- `conftest.py` - Global pytest configuration and fixtures
- `openai/` - Tests specific to OpenAI-compatible APIs
- `conf/` - Provider-specific configuration files
- `openai_api/` - Tests specific to OpenAI-compatible APIs
- `fixtures/` - Test fixtures and utilities
  - `fixtures.py` - Provider-specific fixtures
  - `load.py` - Utilities for loading test cases

@@ -1,6 +1,6 @@
# Test Results Report

*Generated on: 2025-04-10 16:48:18*
*Generated on: 2025-04-17 12:42:33*

*This report was generated by running `python tests/verifications/generate_report.py`*
@ -15,22 +15,74 @@
|
|||
|
||||
| Provider | Pass Rate | Tests Passed | Total Tests |
|
||||
| --- | --- | --- | --- |
|
||||
| Together | 64.7% | 22 | 34 |
|
||||
| Fireworks | 82.4% | 28 | 34 |
|
||||
| Openai | 100.0% | 24 | 24 |
|
||||
| Meta_reference | 100.0% | 28 | 28 |
|
||||
| Together | 50.0% | 40 | 80 |
|
||||
| Fireworks | 50.0% | 40 | 80 |
|
||||
| Openai | 100.0% | 56 | 56 |
|
||||
|
||||
|
||||
|
||||
## Meta_reference

*Tests run on: 2025-04-17 12:37:11*

```bash
# Run all tests for this provider:
pytest tests/verifications/openai_api/test_chat_completion.py --provider=meta_reference -v

# Example: Run only the 'stream=False' case of test_chat_multi_turn_multiple_images:
pytest tests/verifications/openai_api/test_chat_completion.py --provider=meta_reference -k "test_chat_multi_turn_multiple_images and stream=False"
```

**Model Key (Meta_reference)**

| Display Name | Full Model ID |
| --- | --- |
| Llama-4-Scout-Instruct | `meta-llama/Llama-4-Scout-17B-16E-Instruct` |

| Test | Llama-4-Scout-Instruct |
| --- | --- |
| test_chat_multi_turn_multiple_images (stream=False) | ✅ |
| test_chat_multi_turn_multiple_images (stream=True) | ✅ |
| test_chat_non_streaming_basic (earth) | ✅ |
| test_chat_non_streaming_basic (saturn) | ✅ |
| test_chat_non_streaming_image | ✅ |
| test_chat_non_streaming_multi_turn_tool_calling (add_product_tool) | ✅ |
| test_chat_non_streaming_multi_turn_tool_calling (compare_monthly_expense_tool) | ✅ |
| test_chat_non_streaming_multi_turn_tool_calling (get_then_create_event_tool) | ✅ |
| test_chat_non_streaming_multi_turn_tool_calling (text_then_weather_tool) | ✅ |
| test_chat_non_streaming_multi_turn_tool_calling (weather_tool_then_text) | ✅ |
| test_chat_non_streaming_structured_output (calendar) | ✅ |
| test_chat_non_streaming_structured_output (math) | ✅ |
| test_chat_non_streaming_tool_calling | ✅ |
| test_chat_non_streaming_tool_choice_none | ✅ |
| test_chat_non_streaming_tool_choice_required | ✅ |
| test_chat_streaming_basic (earth) | ✅ |
| test_chat_streaming_basic (saturn) | ✅ |
| test_chat_streaming_image | ✅ |
| test_chat_streaming_multi_turn_tool_calling (add_product_tool) | ✅ |
| test_chat_streaming_multi_turn_tool_calling (compare_monthly_expense_tool) | ✅ |
| test_chat_streaming_multi_turn_tool_calling (get_then_create_event_tool) | ✅ |
| test_chat_streaming_multi_turn_tool_calling (text_then_weather_tool) | ✅ |
| test_chat_streaming_multi_turn_tool_calling (weather_tool_then_text) | ✅ |
| test_chat_streaming_structured_output (calendar) | ✅ |
| test_chat_streaming_structured_output (math) | ✅ |
| test_chat_streaming_tool_calling | ✅ |
| test_chat_streaming_tool_choice_none | ✅ |
| test_chat_streaming_tool_choice_required | ✅ |

## Together

*Tests run on: 2025-04-10 16:46:35*
*Tests run on: 2025-04-17 12:27:45*

```bash
# Run all tests for this provider:
pytest tests/verifications/openai_api/test_chat_completion.py --provider=together -v

# Example: Run only the 'earth' case of test_chat_non_streaming_basic:
pytest tests/verifications/openai_api/test_chat_completion.py --provider=together -k "test_chat_non_streaming_basic and earth"
# Example: Run only the 'stream=False' case of test_chat_multi_turn_multiple_images:
pytest tests/verifications/openai_api/test_chat_completion.py --provider=together -k "test_chat_multi_turn_multiple_images and stream=False"
```

@@ -45,29 +97,45 @@ pytest tests/verifications/openai_api/test_chat_completion.py --provider=together -v

| Test | Llama-3.3-70B-Instruct | Llama-4-Maverick-Instruct | Llama-4-Scout-Instruct |
| --- | --- | --- | --- |
| test_chat_multi_turn_multiple_images (stream=False) | ⚪ | ✅ | ✅ |
| test_chat_multi_turn_multiple_images (stream=True) | ⚪ | ❌ | ❌ |
| test_chat_non_streaming_basic (earth) | ✅ | ✅ | ✅ |
| test_chat_non_streaming_basic (saturn) | ✅ | ✅ | ✅ |
| test_chat_non_streaming_image | ⚪ | ✅ | ✅ |
| test_chat_non_streaming_multi_turn_tool_calling (add_product_tool) | ✅ | ✅ | ✅ |
| test_chat_non_streaming_multi_turn_tool_calling (compare_monthly_expense_tool) | ✅ | ✅ | ✅ |
| test_chat_non_streaming_multi_turn_tool_calling (get_then_create_event_tool) | ✅ | ❌ | ✅ |
| test_chat_non_streaming_multi_turn_tool_calling (text_then_weather_tool) | ❌ | ❌ | ❌ |
| test_chat_non_streaming_multi_turn_tool_calling (weather_tool_then_text) | ✅ | ✅ | ✅ |
| test_chat_non_streaming_structured_output (calendar) | ✅ | ✅ | ✅ |
| test_chat_non_streaming_structured_output (math) | ✅ | ✅ | ✅ |
| test_chat_non_streaming_tool_calling | ✅ | ✅ | ✅ |
| test_chat_non_streaming_tool_choice_none | ❌ | ❌ | ❌ |
| test_chat_non_streaming_tool_choice_required | ✅ | ✅ | ✅ |
| test_chat_streaming_basic (earth) | ✅ | ❌ | ❌ |
| test_chat_streaming_basic (saturn) | ✅ | ❌ | ❌ |
| test_chat_streaming_image | ⚪ | ❌ | ❌ |
| test_chat_streaming_multi_turn_tool_calling (add_product_tool) | ✅ | ❌ | ❌ |
| test_chat_streaming_multi_turn_tool_calling (compare_monthly_expense_tool) | ❌ | ❌ | ❌ |
| test_chat_streaming_multi_turn_tool_calling (get_then_create_event_tool) | ❌ | ❌ | ❌ |
| test_chat_streaming_multi_turn_tool_calling (text_then_weather_tool) | ❌ | ❌ | ❌ |
| test_chat_streaming_multi_turn_tool_calling (weather_tool_then_text) | ❌ | ❌ | ❌ |
| test_chat_streaming_structured_output (calendar) | ✅ | ❌ | ❌ |
| test_chat_streaming_structured_output (math) | ✅ | ❌ | ❌ |
| test_chat_streaming_tool_calling | ✅ | ❌ | ❌ |
| test_chat_streaming_tool_choice_none | ❌ | ❌ | ❌ |
| test_chat_streaming_tool_choice_required | ✅ | ❌ | ❌ |

## Fireworks

*Tests run on: 2025-04-10 16:44:44*
*Tests run on: 2025-04-17 12:29:53*

```bash
# Run all tests for this provider:
pytest tests/verifications/openai_api/test_chat_completion.py --provider=fireworks -v

# Example: Run only the 'earth' case of test_chat_non_streaming_basic:
pytest tests/verifications/openai_api/test_chat_completion.py --provider=fireworks -k "test_chat_non_streaming_basic and earth"
# Example: Run only the 'stream=False' case of test_chat_multi_turn_multiple_images:
pytest tests/verifications/openai_api/test_chat_completion.py --provider=fireworks -k "test_chat_multi_turn_multiple_images and stream=False"
```

@@ -82,29 +150,45 @@ pytest tests/verifications/openai_api/test_chat_completion.py --provider=fireworks -v

| Test | Llama-3.3-70B-Instruct | Llama-4-Maverick-Instruct | Llama-4-Scout-Instruct |
| --- | --- | --- | --- |
| test_chat_multi_turn_multiple_images (stream=False) | ⚪ | ✅ | ✅ |
| test_chat_multi_turn_multiple_images (stream=True) | ⚪ | ✅ | ✅ |
| test_chat_non_streaming_basic (earth) | ✅ | ✅ | ✅ |
| test_chat_non_streaming_basic (saturn) | ✅ | ✅ | ✅ |
| test_chat_non_streaming_image | ⚪ | ✅ | ✅ |
| test_chat_non_streaming_multi_turn_tool_calling (add_product_tool) | ❌ | ❌ | ❌ |
| test_chat_non_streaming_multi_turn_tool_calling (compare_monthly_expense_tool) | ❌ | ❌ | ❌ |
| test_chat_non_streaming_multi_turn_tool_calling (get_then_create_event_tool) | ❌ | ❌ | ❌ |
| test_chat_non_streaming_multi_turn_tool_calling (text_then_weather_tool) | ❌ | ❌ | ❌ |
| test_chat_non_streaming_multi_turn_tool_calling (weather_tool_then_text) | ❌ | ❌ | ❌ |
| test_chat_non_streaming_structured_output (calendar) | ✅ | ✅ | ✅ |
| test_chat_non_streaming_structured_output (math) | ✅ | ✅ | ✅ |
| test_chat_non_streaming_tool_calling | ❌ | ❌ | ❌ |
| test_chat_non_streaming_tool_choice_none | ✅ | ✅ | ✅ |
| test_chat_non_streaming_tool_choice_required | ✅ | ❌ | ❌ |
| test_chat_streaming_basic (earth) | ✅ | ✅ | ✅ |
| test_chat_streaming_basic (saturn) | ✅ | ✅ | ✅ |
| test_chat_streaming_image | ⚪ | ✅ | ✅ |
| test_chat_streaming_multi_turn_tool_calling (add_product_tool) | ❌ | ❌ | ❌ |
| test_chat_streaming_multi_turn_tool_calling (compare_monthly_expense_tool) | ❌ | ❌ | ❌ |
| test_chat_streaming_multi_turn_tool_calling (get_then_create_event_tool) | ❌ | ❌ | ❌ |
| test_chat_streaming_multi_turn_tool_calling (text_then_weather_tool) | ❌ | ❌ | ❌ |
| test_chat_streaming_multi_turn_tool_calling (weather_tool_then_text) | ❌ | ❌ | ❌ |
| test_chat_streaming_structured_output (calendar) | ✅ | ✅ | ✅ |
| test_chat_streaming_structured_output (math) | ✅ | ✅ | ✅ |
| test_chat_streaming_tool_calling | ❌ | ❌ | ❌ |
| test_chat_streaming_tool_choice_none | ✅ | ✅ | ✅ |
| test_chat_streaming_tool_choice_required | ✅ | ❌ | ❌ |

## Openai

*Tests run on: 2025-04-10 16:47:28*
*Tests run on: 2025-04-17 12:34:08*

```bash
# Run all tests for this provider:
pytest tests/verifications/openai_api/test_chat_completion.py --provider=openai -v

# Example: Run only the 'earth' case of test_chat_non_streaming_basic:
pytest tests/verifications/openai_api/test_chat_completion.py --provider=openai -k "test_chat_non_streaming_basic and earth"
# Example: Run only the 'stream=False' case of test_chat_multi_turn_multiple_images:
pytest tests/verifications/openai_api/test_chat_completion.py --provider=openai -k "test_chat_multi_turn_multiple_images and stream=False"
```

@@ -118,15 +202,31 @@ pytest tests/verifications/openai_api/test_chat_completion.py --provider=openai -v

| Test | gpt-4o | gpt-4o-mini |
| --- | --- | --- |
| test_chat_multi_turn_multiple_images (stream=False) | ✅ | ✅ |
| test_chat_multi_turn_multiple_images (stream=True) | ✅ | ✅ |
| test_chat_non_streaming_basic (earth) | ✅ | ✅ |
| test_chat_non_streaming_basic (saturn) | ✅ | ✅ |
| test_chat_non_streaming_image | ✅ | ✅ |
| test_chat_non_streaming_multi_turn_tool_calling (add_product_tool) | ✅ | ✅ |
| test_chat_non_streaming_multi_turn_tool_calling (compare_monthly_expense_tool) | ✅ | ✅ |
| test_chat_non_streaming_multi_turn_tool_calling (get_then_create_event_tool) | ✅ | ✅ |
| test_chat_non_streaming_multi_turn_tool_calling (text_then_weather_tool) | ✅ | ✅ |
| test_chat_non_streaming_multi_turn_tool_calling (weather_tool_then_text) | ✅ | ✅ |
| test_chat_non_streaming_structured_output (calendar) | ✅ | ✅ |
| test_chat_non_streaming_structured_output (math) | ✅ | ✅ |
| test_chat_non_streaming_tool_calling | ✅ | ✅ |
| test_chat_non_streaming_tool_choice_none | ✅ | ✅ |
| test_chat_non_streaming_tool_choice_required | ✅ | ✅ |
| test_chat_streaming_basic (earth) | ✅ | ✅ |
| test_chat_streaming_basic (saturn) | ✅ | ✅ |
| test_chat_streaming_image | ✅ | ✅ |
| test_chat_streaming_multi_turn_tool_calling (add_product_tool) | ✅ | ✅ |
| test_chat_streaming_multi_turn_tool_calling (compare_monthly_expense_tool) | ✅ | ✅ |
| test_chat_streaming_multi_turn_tool_calling (get_then_create_event_tool) | ✅ | ✅ |
| test_chat_streaming_multi_turn_tool_calling (text_then_weather_tool) | ✅ | ✅ |
| test_chat_streaming_multi_turn_tool_calling (weather_tool_then_text) | ✅ | ✅ |
| test_chat_streaming_structured_output (calendar) | ✅ | ✅ |
| test_chat_streaming_structured_output (math) | ✅ | ✅ |
| test_chat_streaming_tool_calling | ✅ | ✅ |
| test_chat_streaming_tool_choice_none | ✅ | ✅ |
| test_chat_streaming_tool_choice_required | ✅ | ✅ |

@@ -8,3 +8,4 @@ test_exclusions:
  llama-3.3-70b:
    - test_chat_non_streaming_image
    - test_chat_streaming_image
    - test_chat_multi_turn_multiple_images

@@ -12,3 +12,6 @@ test_exclusions:
  fireworks/llama-v3p3-70b-instruct:
    - test_chat_non_streaming_image
    - test_chat_streaming_image
    - test_chat_multi_turn_multiple_images
    - test_response_non_streaming_image
    - test_response_non_streaming_multi_turn_image

@@ -12,3 +12,4 @@ test_exclusions:
  accounts/fireworks/models/llama-v3p3-70b-instruct:
    - test_chat_non_streaming_image
    - test_chat_streaming_image
    - test_chat_multi_turn_multiple_images

@@ -12,3 +12,6 @@ test_exclusions:
  groq/llama-3.3-70b-versatile:
    - test_chat_non_streaming_image
    - test_chat_streaming_image
    - test_chat_multi_turn_multiple_images
    - test_response_non_streaming_image
    - test_response_non_streaming_multi_turn_image

@@ -12,3 +12,4 @@ test_exclusions:
  llama-3.3-70b-versatile:
    - test_chat_non_streaming_image
    - test_chat_streaming_image
    - test_chat_multi_turn_multiple_images

tests/verifications/conf/meta_reference.yaml (new file, +8)
@@ -0,0 +1,8 @@
# LLAMA_STACK_PORT=5002 llama stack run meta-reference-gpu --env INFERENCE_MODEL=meta-llama/Llama-4-Scout-17B-16E-Instruct --env INFERENCE_CHECKPOINT_DIR=<path_to_ckpt>
base_url: http://localhost:5002/v1/openai/v1
api_key_var: foo
models:
- meta-llama/Llama-4-Scout-17B-16E-Instruct
model_display_names:
  meta-llama/Llama-4-Scout-17B-16E-Instruct: Llama-4-Scout-Instruct
test_exclusions: {}

@@ -12,3 +12,6 @@ test_exclusions:
  together/meta-llama/Llama-3.3-70B-Instruct-Turbo:
    - test_chat_non_streaming_image
    - test_chat_streaming_image
    - test_chat_multi_turn_multiple_images
    - test_response_non_streaming_image
    - test_response_non_streaming_multi_turn_image

@@ -12,3 +12,4 @@ test_exclusions:
  meta-llama/Llama-3.3-70B-Instruct-Turbo:
    - test_chat_non_streaming_image
    - test_chat_streaming_image
    - test_chat_multi_turn_multiple_images
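
These per-model exclusion lists are consulted at test time by the `should_skip_test` helper added to `fixtures.py` later in this diff. A minimal self-contained sketch of that lookup, with an inline copy of the helper included purely for illustration:

```python
# Sketch only: mirrors the should_skip_test helper added to fixtures.py in this diff.
def should_skip_test(verification_config, provider, model, test_name_base):
    provider_config = verification_config.get("providers", {}).get(provider)
    if not provider_config:
        return False  # No config for provider, don't skip
    exclusions = provider_config.get("test_exclusions", {}).get(model, [])
    return test_name_base in exclusions


# Example: the groq.yaml exclusions above make image tests skip for this model.
config = {
    "providers": {
        "groq": {
            "test_exclusions": {
                "llama-3.3-70b-versatile": ["test_chat_non_streaming_image"],
            }
        }
    }
}
assert should_skip_test(config, "groq", "llama-3.3-70b-versatile", "test_chat_non_streaming_image")
assert not should_skip_test(config, "groq", "llama-3.3-70b-versatile", "test_chat_streaming_basic")
```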

tests/verifications/generate_report.py
@@ -1,16 +1,10 @@
#!/usr/bin/env python3

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# /// script
# requires-python = ">=3.10"
# dependencies = [
#   "pytest-json-report",
#   "pyyaml",
# ]
# ///
"""
Test Report Generator

@@ -24,7 +18,7 @@ Description:

Configuration:
- Provider details (models, display names) are loaded from `tests/verifications/config.yaml`.
- Provider details (models, display names) are loaded from `tests/verifications/conf/*.yaml`.
- Test cases are defined in YAML files within `tests/verifications/openai_api/fixtures/test_cases/`.
- Test results are stored in `tests/verifications/test_results/`.

@@ -56,7 +50,7 @@ import subprocess
import time
from collections import defaultdict
from pathlib import Path
from typing import Any, DefaultDict, Dict, Set, Tuple
from typing import Any

from tests.verifications.openai_api.fixtures.fixtures import _load_all_verification_configs

@@ -67,16 +61,11 @@ RESULTS_DIR.mkdir(exist_ok=True)
# Maximum number of test result files to keep per provider
MAX_RESULTS_PER_PROVIDER = 1

PROVIDER_ORDER = [
DEFAULT_PROVIDERS = [
    "meta_reference",
    "together",
    "fireworks",
    "groq",
    "cerebras",
    "openai",
    "together-llama-stack",
    "fireworks-llama-stack",
    "groq-llama-stack",
    "openai-llama-stack",
]

VERIFICATION_CONFIG = _load_all_verification_configs()

@@ -117,7 +106,7 @@ def run_tests(provider, keyword=None):

    # Check if the JSON file was created
    if temp_json_file.exists():
        with open(temp_json_file, "r") as f:
        with open(temp_json_file) as f:
            test_results = json.load(f)

            test_results["run_timestamp"] = timestamp

@@ -142,9 +131,17 @@ def run_tests(provider, keyword=None):
    return None


def run_multiple_tests(providers_to_run: list[str], keyword: str | None):
    """Runs tests for a list of providers."""
    print(f"Running tests for providers: {', '.join(providers_to_run)}")
    for provider in providers_to_run:
        run_tests(provider.strip(), keyword=keyword)
    print("Finished running tests.")


def parse_results(
    result_file,
) -> Tuple[DefaultDict[str, DefaultDict[str, Dict[str, bool]]], DefaultDict[str, Set[str]], Set[str], str]:
) -> tuple[defaultdict[str, defaultdict[str, dict[str, bool]]], defaultdict[str, set[str]], set[str], str]:
    """Parse a single test results file.

    Returns:

@@ -159,13 +156,13 @@ def parse_results(
        # Return empty defaultdicts/set matching the type hint
        return defaultdict(lambda: defaultdict(dict)), defaultdict(set), set(), ""

    with open(result_file, "r") as f:
    with open(result_file) as f:
        results = json.load(f)

    # Initialize results dictionary with specific types
    parsed_results: DefaultDict[str, DefaultDict[str, Dict[str, bool]]] = defaultdict(lambda: defaultdict(dict))
    providers_in_file: DefaultDict[str, Set[str]] = defaultdict(set)
    tests_in_file: Set[str] = set()
    parsed_results: defaultdict[str, defaultdict[str, dict[str, bool]]] = defaultdict(lambda: defaultdict(dict))
    providers_in_file: defaultdict[str, set[str]] = defaultdict(set)
    tests_in_file: set[str] = set()
    # Extract provider from filename (e.g., "openai.json" -> "openai")
    provider: str = result_file.stem

@@ -250,25 +247,11 @@ def parse_results(
    return parsed_results, providers_in_file, tests_in_file, run_timestamp_str


def get_all_result_files_by_provider():
    """Get all test result files, keyed by provider."""
    provider_results = {}

    result_files = list(RESULTS_DIR.glob("*.json"))

    for file in result_files:
        provider = file.stem
        if provider:
            provider_results[provider] = file

    return provider_results


def generate_report(
    results_dict: Dict[str, Any],
    providers: Dict[str, Set[str]],
    all_tests: Set[str],
    provider_timestamps: Dict[str, str],
    results_dict: dict[str, Any],
    providers: dict[str, set[str]],
    all_tests: set[str],
    provider_timestamps: dict[str, str],
    output_file=None,
):
    """Generate the markdown report.

@@ -276,6 +259,7 @@ def generate_report(
    Args:
        results_dict: Aggregated results [provider][model][test_name] -> status.
        providers: Dict of all providers and their models {provider: {models}}.
                   The order of keys in this dict determines the report order.
        all_tests: Set of all test names found.
        provider_timestamps: Dict of provider to timestamp when tests were run
        output_file: Optional path to save the report.

@@ -293,8 +277,8 @@ def generate_report(
    sorted_tests = sorted(all_tests)

    # Calculate counts for each base test name
    base_test_case_counts: DefaultDict[str, int] = defaultdict(int)
    base_test_name_map: Dict[str, str] = {}
    base_test_case_counts: defaultdict[str, int] = defaultdict(int)
    base_test_name_map: dict[str, str] = {}
    for test_name in sorted_tests:
        match = re.match(r"^(.*?)( \([^)]+\))?$", test_name)
        if match:

@@ -353,22 +337,17 @@ def generate_report(
                passed_tests += 1
        provider_totals[provider] = (provider_passed, provider_total)

    # Add summary table (use passed-in providers dict)
    # Add summary table (use the order from the providers dict keys)
    report.append("| Provider | Pass Rate | Tests Passed | Total Tests |")
    report.append("| --- | --- | --- | --- |")
    for provider in [p for p in PROVIDER_ORDER if p in providers]:  # Check against keys of passed-in dict
        passed, total = provider_totals.get(provider, (0, 0))
        pass_rate = f"{(passed / total * 100):.1f}%" if total > 0 else "N/A"
        report.append(f"| {provider.capitalize()} | {pass_rate} | {passed} | {total} |")
    for provider in [p for p in providers if p not in PROVIDER_ORDER]:  # Check against keys of passed-in dict
    # Iterate through providers in the order they appear in the input dict
    for provider in providers_sorted.keys():
        passed, total = provider_totals.get(provider, (0, 0))
        pass_rate = f"{(passed / total * 100):.1f}%" if total > 0 else "N/A"
        report.append(f"| {provider.capitalize()} | {pass_rate} | {passed} | {total} |")
    report.append("\n")

    for provider in sorted(
        providers_sorted.keys(), key=lambda p: (PROVIDER_ORDER.index(p) if p in PROVIDER_ORDER else float("inf"), p)
    ):
    for provider in providers_sorted.keys():
        provider_models = providers_sorted[provider]  # Use sorted models
        if not provider_models:
            continue

@@ -461,60 +440,62 @@ def main():
        "--providers",
        type=str,
        nargs="+",
        help="Specify providers to test (comma-separated or space-separated, default: all)",
        help="Specify providers to include/test (comma-separated or space-separated, default: uses DEFAULT_PROVIDERS)",
    )
    parser.add_argument("--output", type=str, help="Output file location (default: tests/verifications/REPORT.md)")
    parser.add_argument("--k", type=str, help="Keyword expression to filter tests (passed to pytest -k)")
    args = parser.parse_args()

    all_results = {}
    # Initialize collections to aggregate results in main
    aggregated_providers = defaultdict(set)
    final_providers_order = {}  # Dictionary to store results, preserving processing order
    aggregated_tests = set()
    provider_timestamps = {}

    if args.run_tests:
        # Get list of available providers from command line or use detected providers
        if args.providers:
            # Handle both comma-separated and space-separated lists
            test_providers = []
            for provider_arg in args.providers:
                # Split by comma if commas are present
                if "," in provider_arg:
                    test_providers.extend(provider_arg.split(","))
                else:
                    test_providers.append(provider_arg)
        else:
            # Default providers to test
            test_providers = PROVIDER_ORDER

        for provider in test_providers:
            provider = provider.strip()  # Remove any whitespace
            result_file = run_tests(provider, keyword=args.k)
            if result_file:
                # Parse and aggregate results
                parsed_results, providers_in_file, tests_in_file, run_timestamp = parse_results(result_file)
                all_results.update(parsed_results)
                for prov, models in providers_in_file.items():
                    aggregated_providers[prov].update(models)
                    if run_timestamp:
                        provider_timestamps[prov] = run_timestamp
                aggregated_tests.update(tests_in_file)
    # 1. Determine the desired list and order of providers
    if args.providers:
        desired_providers = []
        for provider_arg in args.providers:
            desired_providers.extend([p.strip() for p in provider_arg.split(",")])
    else:
        # Use existing results
        provider_result_files = get_all_result_files_by_provider()
        desired_providers = DEFAULT_PROVIDERS  # Use default order/list

        for result_file in provider_result_files.values():
            # Parse and aggregate results
            parsed_results, providers_in_file, tests_in_file, run_timestamp = parse_results(result_file)
            all_results.update(parsed_results)
            for prov, models in providers_in_file.items():
                aggregated_providers[prov].update(models)
                if run_timestamp:
                    provider_timestamps[prov] = run_timestamp
            aggregated_tests.update(tests_in_file)
    # 2. Run tests if requested (using the desired provider list)
    if args.run_tests:
        run_multiple_tests(desired_providers, args.k)

    generate_report(all_results, aggregated_providers, aggregated_tests, provider_timestamps, args.output)
    for provider in desired_providers:
        # Construct the expected result file path directly
        result_file = RESULTS_DIR / f"{provider}.json"

        if result_file.exists():  # Check if the specific file exists
            print(f"Loading results for {provider} from {result_file}")
            try:
                parsed_data = parse_results(result_file)
                parsed_results, providers_in_file, tests_in_file, run_timestamp = parsed_data
                all_results.update(parsed_results)
                aggregated_tests.update(tests_in_file)

                # Add models for this provider, ensuring it's added in the correct report order
                if provider in providers_in_file:
                    if provider not in final_providers_order:
                        final_providers_order[provider] = set()
                    final_providers_order[provider].update(providers_in_file[provider])
                    if run_timestamp != "Unknown":
                        provider_timestamps[provider] = run_timestamp
                else:
                    print(
                        f"Warning: Provider '{provider}' found in desired list but not within its result file data ({result_file})."
                    )

            except Exception as e:
                print(f"Error parsing results for provider {provider} from {result_file}: {e}")
        else:
            # Only print warning if we expected results (i.e., provider was in the desired list)
            print(f"Result file for desired provider '{provider}' not found at {result_file}. Skipping.")

    # 5. Generate the report using the filtered & ordered results
    print(f"Final Provider Order for Report: {list(final_providers_order.keys())}")
    generate_report(all_results, final_providers_order, aggregated_tests, provider_timestamps, args.output)


if __name__ == "__main__":
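
Taken together, the refactored `main()` means a typical invocation is `python tests/verifications/generate_report.py --providers openai,fireworks --k "tool_calling"`, with `--output` to redirect the report; the run-tests flag is only inferred here from `args.run_tests`, and its exact spelling is not shown in this hunk.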
@@ -1,10 +1,15 @@
# This is a temporary run file because model names used by the verification tests
# are not quite consistent with various pre-existing distributions.
#
version: '2'
image_name: openai-api-verification
apis:
- agents
- inference
- telemetry
- tool_runtime
- vector_io
- safety
providers:
  inference:
  - provider_id: together

@@ -16,12 +21,12 @@ providers:
    provider_type: remote::fireworks
    config:
      url: https://api.fireworks.ai/inference/v1
      api_key: ${env.FIREWORKS_API_KEY}
      api_key: ${env.FIREWORKS_API_KEY:}
  - provider_id: groq
    provider_type: remote::groq
    config:
      url: https://api.groq.com
      api_key: ${env.GROQ_API_KEY}
      api_key: ${env.GROQ_API_KEY:}
  - provider_id: openai
    provider_type: remote::openai
    config:

@@ -44,7 +49,20 @@ providers:
    config:
      service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
      sinks: ${env.TELEMETRY_SINKS:console,sqlite}
      sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/openai/trace_store.db}
      sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/openai-api-verification}/trace_store.db
  safety:
  - provider_id: llama-guard
    provider_type: inline::llama-guard
    config:
      excluded_categories: []
  agents:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      persistence_store:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/openai}/agents_store.db
  tool_runtime:
  - provider_id: brave-search
    provider_type: remote::brave-search

tests/verifications/openai_api/conftest.py (new file, +35)
@@ -0,0 +1,35 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from tests.verifications.openai_api.fixtures.fixtures import _load_all_verification_configs


def pytest_generate_tests(metafunc):
    """Dynamically parametrize tests based on the selected provider and config."""
    if "model" in metafunc.fixturenames:
        provider = metafunc.config.getoption("provider")
        if not provider:
            print("Warning: --provider not specified. Skipping model parametrization.")
            metafunc.parametrize("model", [])
            return

        try:
            config_data = _load_all_verification_configs()
        except (OSError, FileNotFoundError) as e:
            print(f"ERROR loading verification configs: {e}")
            config_data = {"providers": {}}

        provider_config = config_data.get("providers", {}).get(provider)
        if provider_config:
            models = provider_config.get("models", [])
            if models:
                metafunc.parametrize("model", models)
            else:
                print(f"Warning: No models found for provider '{provider}' in config.")
                metafunc.parametrize("model", [])  # Parametrize empty if no models found
        else:
            print(f"Warning: Provider '{provider}' not found in config. No models parametrized.")
            metafunc.parametrize("model", [])  # Parametrize empty if provider not found
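
The `metafunc.config.getoption("provider")` call above relies on a `--provider` option registered elsewhere; its registration is not part of this hunk. A plausible minimal sketch of that registration in a higher-level `conftest.py` (hypothetical, not the repository's exact code):

```python
# Hypothetical sketch: register the --provider option consumed by
# pytest_generate_tests above. The actual registration lives outside this hunk.
def pytest_addoption(parser):
    parser.addoption(
        "--provider",
        action="store",
        default=None,
        help="Provider whose models should be parametrized (e.g. openai, fireworks)",
    )
```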

tests/verifications/openai_api/fixtures/fixtures.py
@@ -5,14 +5,16 @@
# the root directory of this source tree.

import os
import re
from pathlib import Path

import pytest
import yaml
from openai import OpenAI

# --- Helper Functions ---


# --- Helper Function to Load Config ---
def _load_all_verification_configs():
    """Load and aggregate verification configs from the conf/ directory."""
    # Note: Path is relative to *this* file (fixtures.py)

@@ -31,7 +33,7 @@ def _load_all_verification_configs():
    for config_path in yaml_files:
        provider_name = config_path.stem
        try:
            with open(config_path, "r") as f:
            with open(config_path) as f:
                provider_config = yaml.safe_load(f)
                if provider_config:
                    all_provider_configs[provider_name] = provider_config

@@ -39,12 +41,35 @@ def _load_all_verification_configs():
                # Log warning if possible, or just skip empty files silently
                print(f"Warning: Config file {config_path} is empty or invalid.")
        except Exception as e:
            raise IOError(f"Error loading config file {config_path}: {e}") from e
            raise OSError(f"Error loading config file {config_path}: {e}") from e

    return {"providers": all_provider_configs}


# --- End Helper Function ---
def case_id_generator(case):
    """Generate a test ID from the case's 'case_id' field, or use a default."""
    case_id = case.get("case_id")
    if isinstance(case_id, str | int):
        return re.sub(r"\W|^(?=\d)", "_", str(case_id))
    return None


def should_skip_test(verification_config, provider, model, test_name_base):
    """Check if a test should be skipped based on config exclusions."""
    provider_config = verification_config.get("providers", {}).get(provider)
    if not provider_config:
        return False  # No config for provider, don't skip

    exclusions = provider_config.get("test_exclusions", {}).get(model, [])
    return test_name_base in exclusions


# Helper to get the base test name from the request object
def get_base_test_name(request):
    return request.node.originalname


# --- End Helper Functions ---


@pytest.fixture(scope="session")

@@ -52,7 +77,7 @@ def verification_config():
    """Pytest fixture to provide the loaded verification config."""
    try:
        return _load_all_verification_configs()
    except (FileNotFoundError, IOError) as e:
    except (OSError, FileNotFoundError) as e:
        pytest.fail(str(e))  # Fail test collection if config loading fails

BIN tests/verifications/openai_api/fixtures/images/vision_test_1.jpg (new binary file, 108 KiB; contents not shown)
BIN tests/verifications/openai_api/fixtures/images/vision_test_2.jpg (new binary file, 148 KiB; contents not shown)
BIN tests/verifications/openai_api/fixtures/images/vision_test_3.jpg (new binary file, 139 KiB; contents not shown)

tests/verifications/openai_api/fixtures/load.py
@@ -12,5 +12,5 @@ import yaml
def load_test_cases(name: str):
    fixture_dir = Path(__file__).parent / "test_cases"
    yaml_path = fixture_dir / f"{name}.yaml"
    with open(yaml_path, "r") as f:
    with open(yaml_path) as f:
        return yaml.safe_load(f)
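
For orientation, this is how the loader is consumed by the test modules in this diff (a small illustrative snippet; the dictionary shape follows the `chat_completion.yaml` fixture below):

```python
from tests.verifications.openai_api.fixtures.load import load_test_cases

# Reads fixtures/test_cases/chat_completion.yaml relative to the fixtures package.
chat_completion_test_cases = load_test_cases("chat_completion")

# Each top-level key holds test_params -> case, a list of parametrized cases.
for case in chat_completion_test_cases["test_chat_input_validation"]["test_params"]["case"]:
    print(case["case_id"], "->", case["output"]["error"]["status_code"])
```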

tests/verifications/openai_api/fixtures/test_cases/chat_completion.yaml
@@ -15,6 +15,52 @@ test_chat_basic:
          S?
        role: user
      output: Saturn
test_chat_input_validation:
  test_name: test_chat_input_validation
  test_params:
    case:
      - case_id: "messages_missing"
        input:
          messages: []
        output:
          error:
            status_code: 400
      - case_id: "messages_role_invalid"
        input:
          messages:
            - content: Which planet do humans live on?
              role: fake_role
        output:
          error:
            status_code: 400
      - case_id: "tool_choice_invalid"
        input:
          messages:
            - content: Which planet do humans live on?
              role: user
          tool_choice: invalid
        output:
          error:
            status_code: 400
      - case_id: "tool_choice_no_tools"
        input:
          messages:
            - content: Which planet do humans live on?
              role: user
          tool_choice: required
        output:
          error:
            status_code: 400
      - case_id: "tools_type_invalid"
        input:
          messages:
            - content: Which planet do humans live on?
              role: user
          tools:
            - type: invalid
        output:
          error:
            status_code: 400
test_chat_image:
  test_name: test_chat_image
  test_params:

@@ -131,3 +177,221 @@ test_tool_calling:
              type: object
          type: function
      output: get_weather_tool_call

test_chat_multi_turn_tool_calling:
  test_name: test_chat_multi_turn_tool_calling
  test_params:
    case:
      - case_id: "text_then_weather_tool"
        input:
          messages:
            - - role: user
                content: "What's the name of the Sun in latin?"
            - - role: user
                content: "What's the weather like in San Francisco?"
          tools:
            - function:
                description: Get the current weather
                name: get_weather
                parameters:
                  type: object
                  properties:
                    location:
                      description: "The city and state (both required), e.g. San Francisco, CA."
                      type: string
                  required: ["location"]
              type: function
        tool_responses:
          - response: "{'response': '70 degrees and foggy'}"
        expected:
          - num_tool_calls: 0
            answer: ["sol"]
          - num_tool_calls: 1
            tool_name: get_weather
            tool_arguments:
              location: "San Francisco, CA"
          - num_tool_calls: 0
            answer: ["foggy", "70 degrees"]
      - case_id: "weather_tool_then_text"
        input:
          messages:
            - - role: user
                content: "What's the weather like in San Francisco?"
          tools:
            - function:
                description: Get the current weather
                name: get_weather
                parameters:
                  type: object
                  properties:
                    location:
                      description: "The city and state (both required), e.g. San Francisco, CA."
                      type: string
                  required: ["location"]
              type: function
        tool_responses:
          - response: "{'response': '70 degrees and foggy'}"
        expected:
          - num_tool_calls: 1
            tool_name: get_weather
            tool_arguments:
              location: "San Francisco, CA"
          - num_tool_calls: 0
            answer: ["foggy", "70 degrees"]
      - case_id: "add_product_tool"
        input:
          messages:
            - - role: user
                content: "Please add a new product with name 'Widget', price 19.99, in stock, and tags ['new', 'sale'] and give me the product id."
          tools:
            - function:
                description: Add a new product
                name: addProduct
                parameters:
                  type: object
                  properties:
                    name:
                      description: "Name of the product"
                      type: string
                    price:
                      description: "Price of the product"
                      type: number
                    inStock:
                      description: "Availability status of the product."
                      type: boolean
                    tags:
                      description: "List of product tags"
                      type: array
                      items:
                        type: string
                  required: ["name", "price", "inStock"]
              type: function
        tool_responses:
          - response: "{'response': 'Successfully added product with id: 123'}"
        expected:
          - num_tool_calls: 1
            tool_name: addProduct
            tool_arguments:
              name: "Widget"
              price: 19.99
              inStock: true
              tags:
                - "new"
                - "sale"
          - num_tool_calls: 0
            answer: ["123", "product id: 123"]
      - case_id: "get_then_create_event_tool"
        input:
          messages:
            - - role: system
                content: "Todays date is 2025-03-01."
              - role: user
                content: "Do i have any meetings on March 3rd at 10 am? Yes or no?"
            - - role: user
                content: "Alright then, Create an event named 'Team Building', scheduled for that time same time, in the 'Main Conference Room' and add Alice, Bob, Charlie to it. Give me the created event id."
          tools:
            - function:
                description: Create a new event
                name: create_event
                parameters:
                  type: object
                  properties:
                    name:
                      description: "Name of the event"
                      type: string
                    date:
                      description: "Date of the event in ISO format"
                      type: string
                    time:
                      description: "Event Time (HH:MM)"
                      type: string
                    location:
                      description: "Location of the event"
                      type: string
                    participants:
                      description: "List of participant names"
                      type: array
                      items:
                        type: string
                  required: ["name", "date", "time", "location", "participants"]
              type: function
            - function:
                description: Get an event by date and time
                name: get_event
                parameters:
                  type: object
                  properties:
                    date:
                      description: "Date of the event in ISO format"
                      type: string
                    time:
                      description: "Event Time (HH:MM)"
                      type: string
                  required: ["date", "time"]
              type: function
        tool_responses:
          - response: "{'response': 'No events found for 2025-03-03 at 10:00'}"
          - response: "{'response': 'Successfully created new event with id: e_123'}"
        expected:
          - num_tool_calls: 1
            tool_name: get_event
            tool_arguments:
              date: "2025-03-03"
              time: "10:00"
          - num_tool_calls: 0
            answer: ["no", "no events found", "no meetings"]
          - num_tool_calls: 1
            tool_name: create_event
            tool_arguments:
              name: "Team Building"
              date: "2025-03-03"
              time: "10:00"
              location: "Main Conference Room"
              participants:
                - "Alice"
                - "Bob"
                - "Charlie"
          - num_tool_calls: 0
            answer: ["e_123", "event id: e_123"]
      - case_id: "compare_monthly_expense_tool"
        input:
          messages:
            - - role: system
                content: "Todays date is 2025-03-01."
              - role: user
                content: "what was my monthly expense in Jan of this year?"
            - - role: user
                content: "Was it less than Feb of last year? Only answer with yes or no."
          tools:
            - function:
                description: Get monthly expense summary
                name: getMonthlyExpenseSummary
                parameters:
                  type: object
                  properties:
                    month:
                      description: "Month of the year (1-12)"
                      type: integer
                    year:
                      description: "Year"
                      type: integer
                  required: ["month", "year"]
              type: function
        tool_responses:
          - response: "{'response': 'Total expenses for January 2025: $1000'}"
          - response: "{'response': 'Total expenses for February 2024: $2000'}"
        expected:
          - num_tool_calls: 1
            tool_name: getMonthlyExpenseSummary
            tool_arguments:
              month: 1
              year: 2025
          - num_tool_calls: 0
            answer: ["1000", "$1,000", "1,000"]
          - num_tool_calls: 1
            tool_name: getMonthlyExpenseSummary
            tool_arguments:
              month: 2
              year: 2024
          - num_tool_calls: 0
            answer: ["yes"]
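
A note on how these fixtures are consumed: the multi-turn tests later in this diff pop one `expected` entry per assistant turn, and pop the next `tool_responses` entry whenever the expected turn contains a tool call, feeding it back into the conversation as a `role: tool` message.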

@@ -0,0 +1,65 @@
test_response_basic:
  test_name: test_response_basic
  test_params:
    case:
      - case_id: "earth"
        input: "Which planet do humans live on?"
        output: "earth"
      - case_id: "saturn"
        input: "Which planet has rings around it with a name starting with letter S?"
        output: "saturn"

test_response_multi_turn:
  test_name: test_response_multi_turn
  test_params:
    case:
      - case_id: "earth"
        turns:
          - input: "Which planet do humans live on?"
            output: "earth"
          - input: "What is the name of the planet from your previous response?"
            output: "earth"

test_response_web_search:
  test_name: test_response_web_search
  test_params:
    case:
      - case_id: "llama_experts"
        input: "How many experts does the Llama 4 Maverick model have?"
        tools:
          - type: web_search
            search_context_size: "low"
        output: "128"

test_response_image:
  test_name: test_response_image
  test_params:
    case:
      - case_id: "llama_image"
        input:
          - role: user
            content:
              - type: input_text
                text: "Identify the type of animal in this image."
              - type: input_image
                image_url: "https://upload.wikimedia.org/wikipedia/commons/f/f7/Llamas%2C_Vernagt-Stausee%2C_Italy.jpg"
        output: "llama"

test_response_multi_turn_image:
  test_name: test_response_multi_turn_image
  test_params:
    case:
      - case_id: "llama_image_search"
        turns:
          - input:
              - role: user
                content:
                  - type: input_text
                    text: "What type of animal is in this image? Please respond with a single word that starts with the letter 'L'."
                  - type: input_image
                    image_url: "https://upload.wikimedia.org/wikipedia/commons/f/f7/Llamas%2C_Vernagt-Stausee%2C_Italy.jpg"
            output: "llama"
          - input: "Search the web using the search tool for the animal from the previous response. Your search query should be a single phrase that includes the animal's name and the words 'maverick' and 'scout'."
            tools:
              - type: web_search
            output: "model"

tests/verifications/openai_api/test_chat_completion.py
@@ -4,68 +4,41 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import base64
import copy
import json
import re
from pathlib import Path
from typing import Any

import pytest
from openai import APIError
from pydantic import BaseModel

from tests.verifications.openai_api.fixtures.fixtures import _load_all_verification_configs
from tests.verifications.openai_api.fixtures.fixtures import (
    case_id_generator,
    get_base_test_name,
    should_skip_test,
)
from tests.verifications.openai_api.fixtures.load import load_test_cases

chat_completion_test_cases = load_test_cases("chat_completion")


def case_id_generator(case):
    """Generate a test ID from the case's 'case_id' field, or use a default."""
    case_id = case.get("case_id")
    if isinstance(case_id, (str, int)):
        return re.sub(r"\W|^(?=\d)", "_", str(case_id))
    return None
THIS_DIR = Path(__file__).parent


def pytest_generate_tests(metafunc):
    """Dynamically parametrize tests based on the selected provider and config."""
    if "model" in metafunc.fixturenames:
        provider = metafunc.config.getoption("provider")
        if not provider:
            print("Warning: --provider not specified. Skipping model parametrization.")
            metafunc.parametrize("model", [])
            return

        try:
            config_data = _load_all_verification_configs()
        except (FileNotFoundError, IOError) as e:
            print(f"ERROR loading verification configs: {e}")
            config_data = {"providers": {}}

        provider_config = config_data.get("providers", {}).get(provider)
        if provider_config:
            models = provider_config.get("models", [])
            if models:
                metafunc.parametrize("model", models)
            else:
                print(f"Warning: No models found for provider '{provider}' in config.")
                metafunc.parametrize("model", [])  # Parametrize empty if no models found
        else:
            print(f"Warning: Provider '{provider}' not found in config. No models parametrized.")
            metafunc.parametrize("model", [])  # Parametrize empty if provider not found


def should_skip_test(verification_config, provider, model, test_name_base):
    """Check if a test should be skipped based on config exclusions."""
    provider_config = verification_config.get("providers", {}).get(provider)
    if not provider_config:
        return False  # No config for provider, don't skip

    exclusions = provider_config.get("test_exclusions", {}).get(model, [])
    return test_name_base in exclusions


# Helper to get the base test name from the request object
def get_base_test_name(request):
    return request.node.originalname
@pytest.fixture
def multi_image_data():
    files = [
        THIS_DIR / "fixtures/images/vision_test_1.jpg",
        THIS_DIR / "fixtures/images/vision_test_2.jpg",
        THIS_DIR / "fixtures/images/vision_test_3.jpg",
    ]
    encoded_files = []
    for file in files:
        with open(file, "rb") as image_file:
            base64_data = base64.b64encode(image_file.read()).decode("utf-8")
            encoded_files.append(f"data:image/jpeg;base64,{base64_data}")
    return encoded_files


# --- Test Functions ---

@@ -114,6 +87,50 @@ def test_chat_streaming_basic(request, openai_client, model, provider, verification_config, case):
    assert case["output"].lower() in content.lower()


@pytest.mark.parametrize(
    "case",
    chat_completion_test_cases["test_chat_input_validation"]["test_params"]["case"],
    ids=case_id_generator,
)
def test_chat_non_streaming_error_handling(request, openai_client, model, provider, verification_config, case):
    test_name_base = get_base_test_name(request)
    if should_skip_test(verification_config, provider, model, test_name_base):
        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")

    with pytest.raises(APIError) as e:
        openai_client.chat.completions.create(
            model=model,
            messages=case["input"]["messages"],
            stream=False,
            tool_choice=case["input"]["tool_choice"] if "tool_choice" in case["input"] else None,
            tools=case["input"]["tools"] if "tools" in case["input"] else None,
        )
    assert case["output"]["error"]["status_code"] == e.value.status_code


@pytest.mark.parametrize(
    "case",
    chat_completion_test_cases["test_chat_input_validation"]["test_params"]["case"],
    ids=case_id_generator,
)
def test_chat_streaming_error_handling(request, openai_client, model, provider, verification_config, case):
    test_name_base = get_base_test_name(request)
    if should_skip_test(verification_config, provider, model, test_name_base):
        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")

    with pytest.raises(APIError) as e:
        response = openai_client.chat.completions.create(
            model=model,
            messages=case["input"]["messages"],
            stream=True,
            tool_choice=case["input"]["tool_choice"] if "tool_choice" in case["input"] else None,
            tools=case["input"]["tools"] if "tools" in case["input"] else None,
        )
        for _chunk in response:
            pass
    assert str(case["output"]["error"]["status_code"]) in e.value.message


@pytest.mark.parametrize(
    "case",
    chat_completion_test_cases["test_chat_image"]["test_params"]["case"],

@@ -243,43 +260,373 @@ def test_chat_streaming_tool_calling(request, openai_client, model, provider, verification_config, case):
        stream=True,
    )

    # Accumulate partial tool_calls here
    tool_calls_buffer = {}
    current_id = None
    # Process streaming chunks
    for chunk in stream:
        choice = chunk.choices[0]
        delta = choice.delta

        if delta.tool_calls is None:
            continue

        for tool_call_delta in delta.tool_calls:
            if tool_call_delta.id:
                current_id = tool_call_delta.id
            call_id = current_id
            func_delta = tool_call_delta.function

            if call_id not in tool_calls_buffer:
                tool_calls_buffer[call_id] = {
                    "id": call_id,
                    "type": tool_call_delta.type,
                    "name": func_delta.name,
                    "arguments": "",
                }

            if func_delta.arguments:
                tool_calls_buffer[call_id]["arguments"] += func_delta.arguments

    _, tool_calls_buffer = _accumulate_streaming_tool_calls(stream)
    assert len(tool_calls_buffer) == 1
    for call in tool_calls_buffer.values():
    for call in tool_calls_buffer:
        assert len(call["id"]) > 0
        assert call["name"] == "get_weather"
        function = call["function"]
        assert function["name"] == "get_weather"

        args_dict = json.loads(call["arguments"])
        args_dict = json.loads(function["arguments"])
        assert "san francisco" in args_dict["location"].lower()
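
The refactor above replaces the inline accumulation loop with a shared `_accumulate_streaming_tool_calls` helper whose definition falls outside this hunk. A minimal sketch of what such a helper plausibly does, inferred from how its return values are used in these tests (hypothetical, not the repository's exact implementation):

```python
def _accumulate_streaming_tool_calls(stream):
    """Hypothetical sketch: accumulate content and tool-call deltas from a chat stream.

    Returns (accumulated_content, accumulated_tool_calls); each tool call is a dict
    with "id", "type", and a "function" dict holding "name" and "arguments".
    """
    tool_calls_buffer = {}
    current_id = None
    content = ""
    for chunk in stream:
        delta = chunk.choices[0].delta
        if delta.content:
            content += delta.content
        for tool_call_delta in delta.tool_calls or []:
            if tool_call_delta.id:
                current_id = tool_call_delta.id
            call_id = current_id
            func_delta = tool_call_delta.function
            if call_id not in tool_calls_buffer:
                tool_calls_buffer[call_id] = {
                    "id": call_id,
                    "type": tool_call_delta.type,
                    "function": {"name": func_delta.name, "arguments": ""},
                }
            if func_delta.arguments:
                tool_calls_buffer[call_id]["function"]["arguments"] += func_delta.arguments
    return content, list(tool_calls_buffer.values())
```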
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
chat_completion_test_cases["test_tool_calling"]["test_params"]["case"], # Reusing existing case for now
|
||||
ids=case_id_generator,
|
||||
)
|
||||
def test_chat_non_streaming_tool_choice_required(request, openai_client, model, provider, verification_config, case):
|
||||
test_name_base = get_base_test_name(request)
|
||||
if should_skip_test(verification_config, provider, model, test_name_base):
|
||||
pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
|
||||
|
||||
response = openai_client.chat.completions.create(
|
||||
model=model,
|
||||
messages=case["input"]["messages"],
|
||||
tools=case["input"]["tools"],
|
||||
tool_choice="required", # Force tool call
|
||||
stream=False,
|
||||
)
|
||||
|
||||
assert response.choices[0].message.role == "assistant"
|
||||
assert len(response.choices[0].message.tool_calls) > 0, "Expected tool call when tool_choice='required'"
|
||||
expected_tool_name = case["input"]["tools"][0]["function"]["name"]
|
||||
assert response.choices[0].message.tool_calls[0].function.name == expected_tool_name
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
chat_completion_test_cases["test_tool_calling"]["test_params"]["case"], # Reusing existing case for now
|
||||
ids=case_id_generator,
|
||||
)
|
||||
def test_chat_streaming_tool_choice_required(request, openai_client, model, provider, verification_config, case):
|
||||
test_name_base = get_base_test_name(request)
|
||||
if should_skip_test(verification_config, provider, model, test_name_base):
|
||||
pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
|
||||
|
||||
stream = openai_client.chat.completions.create(
|
||||
model=model,
|
||||
messages=case["input"]["messages"],
|
||||
tools=case["input"]["tools"],
|
||||
tool_choice="required", # Force tool call
|
||||
stream=True,
|
||||
)
|
||||
|
||||
_, tool_calls_buffer = _accumulate_streaming_tool_calls(stream)
|
||||
|
||||
assert len(tool_calls_buffer) > 0, "Expected tool call when tool_choice='required'"
|
||||
expected_tool_name = case["input"]["tools"][0]["function"]["name"]
|
||||
assert any(call["function"]["name"] == expected_tool_name for call in tool_calls_buffer), (
|
||||
f"Expected tool call '{expected_tool_name}' not found in stream"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
chat_completion_test_cases["test_tool_calling"]["test_params"]["case"], # Reusing existing case for now
|
||||
ids=case_id_generator,
|
||||
)
|
||||
def test_chat_non_streaming_tool_choice_none(request, openai_client, model, provider, verification_config, case):
|
||||
test_name_base = get_base_test_name(request)
|
||||
if should_skip_test(verification_config, provider, model, test_name_base):
|
||||
pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
|
||||
|
||||
response = openai_client.chat.completions.create(
|
||||
model=model,
|
||||
messages=case["input"]["messages"],
|
||||
tools=case["input"]["tools"],
|
||||
tool_choice="none",
|
||||
stream=False,
|
||||
)
|
||||
|
||||
assert response.choices[0].message.role == "assistant"
|
||||
assert response.choices[0].message.tool_calls is None, "Expected no tool calls when tool_choice='none'"
|
||||
assert response.choices[0].message.content is not None, "Expected content when tool_choice='none'"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
chat_completion_test_cases["test_tool_calling"]["test_params"]["case"], # Reusing existing case for now
|
||||
ids=case_id_generator,
|
||||
)
|
||||
def test_chat_streaming_tool_choice_none(request, openai_client, model, provider, verification_config, case):
|
||||
test_name_base = get_base_test_name(request)
|
||||
if should_skip_test(verification_config, provider, model, test_name_base):
|
||||
pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
|
||||
|
||||
stream = openai_client.chat.completions.create(
|
||||
model=model,
|
||||
messages=case["input"]["messages"],
|
||||
tools=case["input"]["tools"],
|
||||
tool_choice="none",
|
||||
stream=True,
|
||||
)
|
||||
|
||||
content = ""
|
||||
for chunk in stream:
|
||||
delta = chunk.choices[0].delta
|
||||
if delta.content:
|
||||
content += delta.content
|
||||
assert not delta.tool_calls, "Expected no tool call chunks when tool_choice='none'"
|
||||
|
||||
assert len(content) > 0, "Expected content when tool_choice='none'"
|
||||
|
||||
|
||||


@pytest.mark.parametrize(
    "case",
    chat_completion_test_cases.get("test_chat_multi_turn_tool_calling", {}).get("test_params", {}).get("case", []),
    ids=case_id_generator,
)
def test_chat_non_streaming_multi_turn_tool_calling(request, openai_client, model, provider, verification_config, case):
    """
    Test cases for multi-turn tool calling.
    Tool calls are asserted.
    Tool responses are provided in the test case.
    Final response is asserted.
    (See the _EXAMPLE_MULTI_TURN_CASE sketch below for the case shape.)
    """

    test_name_base = get_base_test_name(request)
    if should_skip_test(verification_config, provider, model, test_name_base):
        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")

    # Create a copy of the messages list to avoid modifying the original
    messages = []
    tools = case["input"]["tools"]
    # Use deepcopy to prevent modification across runs/parametrization
    expected_results = copy.deepcopy(case["expected"])
    tool_responses = copy.deepcopy(case.get("tool_responses", []))
    input_messages_turns = copy.deepcopy(case["input"]["messages"])

    # Keep looping while either:
    # 1. there are still turn messages left to feed in, or
    # 2. none remain but the last message is a tool response (the model still owes a reply)
    while len(input_messages_turns) > 0 or (len(messages) > 0 and messages[-1]["role"] == "tool"):
        # Do not take new messages if the last message is a tool response
        if len(messages) == 0 or messages[-1]["role"] != "tool":
            new_messages = input_messages_turns.pop(0)
            # Ensure new_messages is a list of message objects
            if isinstance(new_messages, list):
                messages.extend(new_messages)
            else:
                # If it's a single message object, add it directly
                messages.append(new_messages)

        # --- API Call ---
        response = openai_client.chat.completions.create(
            model=model,
            messages=messages,
            tools=tools,
            stream=False,
        )

        # --- Process Response ---
        assistant_message = response.choices[0].message
        messages.append(assistant_message.model_dump(exclude_unset=True))

        assert assistant_message.role == "assistant"

        # Get the expected result data
        expected = expected_results.pop(0)
        num_tool_calls = expected["num_tool_calls"]

        # --- Assertions based on expected result ---
        assert len(assistant_message.tool_calls or []) == num_tool_calls, (
            f"Expected {num_tool_calls} tool calls, but got {len(assistant_message.tool_calls or [])}"
        )

        if num_tool_calls > 0:
            tool_call = assistant_message.tool_calls[0]
            assert tool_call.function.name == expected["tool_name"], (
                f"Expected tool '{expected['tool_name']}', got '{tool_call.function.name}'"
            )
            # Parse the JSON string arguments before comparing
            actual_arguments = json.loads(tool_call.function.arguments)
            assert actual_arguments == expected["tool_arguments"], (
                f"Expected arguments '{expected['tool_arguments']}', got '{actual_arguments}'"
            )

            # Prepare and append the tool response for the next turn
            tool_response = tool_responses.pop(0)
            messages.append(
                {
                    "role": "tool",
                    "tool_call_id": tool_call.id,
                    "content": tool_response["response"],
                }
            )
        else:
            assert assistant_message.content is not None, "Expected content, but none received."
            expected_answers = expected["answer"]  # This is now a list
            content_lower = assistant_message.content.lower()
            assert any(ans.lower() in content_lower for ans in expected_answers), (
                f"Expected one of {expected_answers} in content, but got: '{assistant_message.content}'"
            )
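

# For reference, a minimal sketch of the case structure the two multi-turn tests
# consume. Field names mirror the accesses above; the concrete values (tool name,
# question, answer) are hypothetical, not taken from chat_completion_test_cases.
_EXAMPLE_MULTI_TURN_CASE = {
    "input": {
        "messages": [
            # one entry per turn; each entry may be a single message or a list
            [{"role": "user", "content": "What is the boiling point of polyjuice?"}],
        ],
        "tools": [
            {
                "type": "function",
                "function": {
                    "name": "get_boiling_point",
                    "parameters": {
                        "type": "object",
                        "properties": {"liquid_name": {"type": "string"}},
                    },
                },
            }
        ],
    },
    "expected": [
        # turn 1: the model should call the tool
        {"num_tool_calls": 1, "tool_name": "get_boiling_point", "tool_arguments": {"liquid_name": "polyjuice"}},
        # turn 2: after the tool response, a plain answer containing one of these strings
        {"num_tool_calls": 0, "answer": ["-100"]},
    ],
    "tool_responses": [{"response": "-100"}],
}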


@pytest.mark.parametrize(
    "case",
    chat_completion_test_cases.get("test_chat_multi_turn_tool_calling", {}).get("test_params", {}).get("case", []),
    ids=case_id_generator,
)
def test_chat_streaming_multi_turn_tool_calling(request, openai_client, model, provider, verification_config, case):
    """Streaming variant of test_chat_non_streaming_multi_turn_tool_calling."""
    test_name_base = get_base_test_name(request)
    if should_skip_test(verification_config, provider, model, test_name_base):
        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")

    messages = []
    tools = case["input"]["tools"]
    expected_results = copy.deepcopy(case["expected"])
    tool_responses = copy.deepcopy(case.get("tool_responses", []))
    input_messages_turns = copy.deepcopy(case["input"]["messages"])

    while len(input_messages_turns) > 0 or (len(messages) > 0 and messages[-1]["role"] == "tool"):
        if len(messages) == 0 or messages[-1]["role"] != "tool":
            new_messages = input_messages_turns.pop(0)
            if isinstance(new_messages, list):
                messages.extend(new_messages)
            else:
                messages.append(new_messages)

        # --- API Call (Streaming) ---
        stream = openai_client.chat.completions.create(
            model=model,
            messages=messages,
            tools=tools,
            stream=True,
        )

        # --- Process Stream ---
        accumulated_content, accumulated_tool_calls = _accumulate_streaming_tool_calls(stream)

        # --- Construct Assistant Message for History ---
        assistant_message_dict = {"role": "assistant"}
        if accumulated_content:
            assistant_message_dict["content"] = accumulated_content
        if accumulated_tool_calls:
            assistant_message_dict["tool_calls"] = accumulated_tool_calls

        messages.append(assistant_message_dict)

        # --- Assertions ---
        expected = expected_results.pop(0)
        num_tool_calls = expected["num_tool_calls"]

        assert len(accumulated_tool_calls or []) == num_tool_calls, (
            f"Expected {num_tool_calls} tool calls, but got {len(accumulated_tool_calls or [])}"
        )

        if num_tool_calls > 0:
            # Use the first accumulated tool call for assertion
            tool_call = accumulated_tool_calls[0]
            assert tool_call["function"]["name"] == expected["tool_name"], (
                f"Expected tool '{expected['tool_name']}', got '{tool_call['function']['name']}'"
            )
            # Parse the accumulated arguments string for comparison
            actual_arguments = json.loads(tool_call["function"]["arguments"])
            assert actual_arguments == expected["tool_arguments"], (
                f"Expected arguments '{expected['tool_arguments']}', got '{actual_arguments}'"
            )

            # Prepare and append the tool response for the next turn
            tool_response = tool_responses.pop(0)
            messages.append(
                {
                    "role": "tool",
                    "tool_call_id": tool_call["id"],
                    "content": tool_response["response"],
                }
            )
        else:
            assert accumulated_content is not None and accumulated_content != "", "Expected content, but none received."
            expected_answers = expected["answer"]
            content_lower = accumulated_content.lower()
            assert any(ans.lower() in content_lower for ans in expected_answers), (
                f"Expected one of {expected_answers} in content, but got: '{accumulated_content}'"
            )


@pytest.mark.parametrize("stream", [False, True], ids=["stream=False", "stream=True"])
def test_chat_multi_turn_multiple_images(
    request, openai_client, model, provider, verification_config, multi_image_data, stream
):
    test_name_base = get_base_test_name(request)
    if should_skip_test(verification_config, provider, model, test_name_base):
        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")

    messages_turn1 = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {
                        "url": multi_image_data[0],
                    },
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": multi_image_data[1],
                    },
                },
                {
                    "type": "text",
                    "text": "What furniture is in the first image that is not in the second image?",
                },
            ],
        },
    ]

    # First API call
    response1 = openai_client.chat.completions.create(
        model=model,
        messages=messages_turn1,
        stream=stream,
    )
    if stream:
        message_content1 = ""
        for chunk in response1:
            message_content1 += chunk.choices[0].delta.content or ""
    else:
        message_content1 = response1.choices[0].message.content
    assert len(message_content1) > 0
    assert any(expected in message_content1.lower().strip() for expected in {"chair", "table"}), message_content1

    # Prepare messages for the second turn
    messages_turn2 = messages_turn1 + [
        {"role": "assistant", "content": message_content1},
        {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {
                        "url": multi_image_data[2],
                    },
                },
                {"type": "text", "text": "What is in this image that is also in the first image?"},
            ],
        },
    ]

    # Second API call
    response2 = openai_client.chat.completions.create(
        model=model,
        messages=messages_turn2,
        stream=stream,
    )
    if stream:
        message_content2 = ""
        for chunk in response2:
            message_content2 += chunk.choices[0].delta.content or ""
    else:
        message_content2 = response2.choices[0].message.content
    assert len(message_content2) > 0
    assert any(expected in message_content2.lower().strip() for expected in {"bed"}), message_content2


# --- Helper functions (structured output validation) ---


@ -324,3 +671,47 @@ def validate_structured_output(maybe_json_content: str, schema_name: str) -> None:
        assert len(structured_output.participants) == 2
    elif schema_name == "valid_math_reasoning":
        assert len(structured_output.final_answer) > 0


def _accumulate_streaming_tool_calls(stream):
    """Accumulates tool calls and content from a streaming ChatCompletion response."""
    tool_calls_buffer = {}
    current_id = None
    full_content = ""  # Initialize content accumulator
    # Process streaming chunks
    for chunk in stream:
        choice = chunk.choices[0]
        delta = choice.delta

        # Accumulate content
        if delta.content:
            full_content += delta.content

        if delta.tool_calls is None:
            continue

        for tool_call_delta in delta.tool_calls:
            if tool_call_delta.id:
                current_id = tool_call_delta.id
            call_id = current_id
            # Skip if no ID seen yet for this tool call delta
            if not call_id:
                continue
            func_delta = tool_call_delta.function

            if call_id not in tool_calls_buffer:
                tool_calls_buffer[call_id] = {
                    "id": call_id,
                    "type": "function",  # Assume function type
                    "function": {"name": None, "arguments": ""},  # Nested structure
                }

            # Accumulate name and arguments into the nested function dict
            if func_delta:
                if func_delta.name:
                    tool_calls_buffer[call_id]["function"]["name"] = func_delta.name
                if func_delta.arguments:
                    tool_calls_buffer[call_id]["function"]["arguments"] += func_delta.arguments

    # Return content and tool calls as a list
    return full_content, list(tool_calls_buffer.values())
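

# A minimal self-check sketch for _accumulate_streaming_tool_calls, assuming the
# OpenAI SDK delta shape used above (choices[0].delta with .content and
# .tool_calls). SimpleNamespace fakes the streamed chunks; the chunk values are
# hypothetical, chosen only to show how partial argument strings get merged.
def _example_accumulate_streaming_tool_calls():
    from types import SimpleNamespace

    def _chunk(content=None, tool_calls=None):
        delta = SimpleNamespace(content=content, tool_calls=tool_calls)
        return SimpleNamespace(choices=[SimpleNamespace(delta=delta)])

    def _tc(tc_id=None, name=None, arguments=None):
        return SimpleNamespace(id=tc_id, function=SimpleNamespace(name=name, arguments=arguments))

    stream = [
        _chunk(tool_calls=[_tc(tc_id="call_1", name="get_weather", arguments="")]),
        _chunk(tool_calls=[_tc(arguments='{"city": ')]),  # id omitted: reuses call_1
        _chunk(tool_calls=[_tc(arguments='"Paris"}')]),
        _chunk(content="Checking the weather now."),
    ]
    content, tool_calls = _accumulate_streaming_tool_calls(stream)
    assert content == "Checking the weather now."
    assert tool_calls == [
        {"id": "call_1", "type": "function", "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'}}
    ]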

166 tests/verifications/openai_api/test_responses.py Normal file

@ -0,0 +1,166 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


import pytest

from tests.verifications.openai_api.fixtures.fixtures import (
    case_id_generator,
    get_base_test_name,
    should_skip_test,
)
from tests.verifications.openai_api.fixtures.load import load_test_cases

responses_test_cases = load_test_cases("responses")


@pytest.mark.parametrize(
    "case",
    responses_test_cases["test_response_basic"]["test_params"]["case"],
    ids=case_id_generator,
)
def test_response_non_streaming_basic(request, openai_client, model, provider, verification_config, case):
    test_name_base = get_base_test_name(request)
    if should_skip_test(verification_config, provider, model, test_name_base):
        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")

    response = openai_client.responses.create(
        model=model,
        input=case["input"],
        stream=False,
    )
    output_text = response.output_text.lower().strip()
    assert len(output_text) > 0
    assert case["output"].lower() in output_text

    retrieved_response = openai_client.responses.retrieve(response_id=response.id)
    assert retrieved_response.output_text == response.output_text

    next_response = openai_client.responses.create(
        model=model, input="Repeat your previous response in all caps.", previous_response_id=response.id
    )
    next_output_text = next_response.output_text.strip()
    assert case["output"].upper() in next_output_text


@pytest.mark.parametrize(
    "case",
    responses_test_cases["test_response_basic"]["test_params"]["case"],
    ids=case_id_generator,
)
def test_response_streaming_basic(request, openai_client, model, provider, verification_config, case):
    test_name_base = get_base_test_name(request)
    if should_skip_test(verification_config, provider, model, test_name_base):
        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")

    response = openai_client.responses.create(
        model=model,
        input=case["input"],
        stream=True,
    )
    streamed_content = []
    response_id = ""
    for chunk in response:
        if chunk.type == "response.completed":
            response_id = chunk.response.id
            streamed_content.append(chunk.response.output_text.strip())

    assert len(streamed_content) > 0
    assert case["output"].lower() in "".join(streamed_content).lower()

    retrieved_response = openai_client.responses.retrieve(response_id=response_id)
    assert retrieved_response.output_text == "".join(streamed_content)


@pytest.mark.parametrize(
    "case",
    responses_test_cases["test_response_multi_turn"]["test_params"]["case"],
    ids=case_id_generator,
)
def test_response_non_streaming_multi_turn(request, openai_client, model, provider, verification_config, case):
    test_name_base = get_base_test_name(request)
    if should_skip_test(verification_config, provider, model, test_name_base):
        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")

    previous_response_id = None
    for turn in case["turns"]:
        response = openai_client.responses.create(
            model=model,
            input=turn["input"],
            previous_response_id=previous_response_id,
            tools=turn["tools"] if "tools" in turn else None,
        )
        previous_response_id = response.id
        output_text = response.output_text.lower()
        assert turn["output"].lower() in output_text


@pytest.mark.parametrize(
    "case",
    responses_test_cases["test_response_web_search"]["test_params"]["case"],
    ids=case_id_generator,
)
def test_response_non_streaming_web_search(request, openai_client, model, provider, verification_config, case):
    test_name_base = get_base_test_name(request)
    if should_skip_test(verification_config, provider, model, test_name_base):
        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")

    response = openai_client.responses.create(
        model=model,
        input=case["input"],
        tools=case["tools"],
        stream=False,
    )
    assert len(response.output) > 1
    assert response.output[0].type == "web_search_call"
    assert response.output[0].status == "completed"
    assert response.output[1].type == "message"
    assert response.output[1].status == "completed"
    assert response.output[1].role == "assistant"
    assert len(response.output[1].content) > 0
    assert case["output"].lower() in response.output_text.lower().strip()
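

# For reference, a sketch of the output shape the assertions above expect from a
# Responses API web-search run (item ordering and fields follow the asserts; the
# bracketed values are illustrative, not actual SDK reprs):
#
#   response.output = [
#       <item: type="web_search_call", status="completed">,
#       <item: type="message", status="completed", role="assistant",
#              content=[...text parts containing the answer...]>,
#   ]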


@pytest.mark.parametrize(
    "case",
    responses_test_cases["test_response_image"]["test_params"]["case"],
    ids=case_id_generator,
)
def test_response_non_streaming_image(request, openai_client, model, provider, verification_config, case):
    test_name_base = get_base_test_name(request)
    if should_skip_test(verification_config, provider, model, test_name_base):
        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")

    response = openai_client.responses.create(
        model=model,
        input=case["input"],
        stream=False,
    )
    output_text = response.output_text.lower()
    assert case["output"].lower() in output_text


@pytest.mark.parametrize(
    "case",
    responses_test_cases["test_response_multi_turn_image"]["test_params"]["case"],
    ids=case_id_generator,
)
def test_response_non_streaming_multi_turn_image(request, openai_client, model, provider, verification_config, case):
    test_name_base = get_base_test_name(request)
    if should_skip_test(verification_config, provider, model, test_name_base):
        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")

    previous_response_id = None
    for turn in case["turns"]:
        response = openai_client.responses.create(
            model=model,
            input=turn["input"],
            previous_response_id=previous_response_id,
            tools=turn["tools"] if "tools" in turn else None,
        )
        previous_response_id = response.id
        output_text = response.output_text.lower()
        assert turn["output"].lower() in output_text

File diff suppressed because one or more lines are too long

1097 tests/verifications/test_results/meta_reference.json Normal file
File diff suppressed because it is too large. Load diff

File diff suppressed because one or more lines are too long