Merge branch 'main' into remove-deprecated-embeddings

This commit is contained in:
Matthew Farrellee 2025-09-27 15:01:32 -04:00
commit 5c44dcdf0e
770 changed files with 176834 additions and 27431 deletions

View file

@ -10,6 +10,7 @@ from unittest.mock import AsyncMock
import pytest
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource
from llama_stack.apis.datatypes import Api
@ -105,6 +106,9 @@ class ScoringFunctionsImpl(Impl):
async def register_scoring_function(self, scoring_fn):
return scoring_fn
async def unregister_scoring_function(self, scoring_fn_id: str):
return scoring_fn_id
class BenchmarksImpl(Impl):
def __init__(self):
@ -113,6 +117,9 @@ class BenchmarksImpl(Impl):
async def register_benchmark(self, benchmark):
return benchmark
async def unregister_benchmark(self, benchmark_id: str):
return benchmark_id
class ToolGroupsImpl(Impl):
def __init__(self):
@ -146,6 +153,20 @@ class VectorDBImpl(Impl):
async def unregister_vector_db(self, vector_db_id: str):
return vector_db_id
async def openai_create_vector_store(self, **kwargs):
import time
import uuid
from llama_stack.apis.vector_io.vector_io import VectorStoreFileCounts, VectorStoreObject
vector_store_id = kwargs.get("provider_vector_db_id") or f"vs_{uuid.uuid4()}"
return VectorStoreObject(
id=vector_store_id,
name=kwargs.get("name", vector_store_id),
created_at=int(time.time()),
file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0),
)
async def test_models_routing_table(cached_disk_dist_registry):
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
@ -247,17 +268,21 @@ async def test_vectordbs_routing_table(cached_disk_dist_registry):
)
# Register multiple vector databases and verify listing
await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test_provider/test-model")
await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test_provider/test-model")
vdb1 = await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test_provider/test-model")
vdb2 = await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test_provider/test-model")
vector_dbs = await table.list_vector_dbs()
assert len(vector_dbs.data) == 2
vector_db_ids = {v.identifier for v in vector_dbs.data}
assert "test-vectordb" in vector_db_ids
assert "test-vectordb-2" in vector_db_ids
assert vdb1.identifier in vector_db_ids
assert vdb2.identifier in vector_db_ids
await table.unregister_vector_db(vector_db_id="test-vectordb")
await table.unregister_vector_db(vector_db_id="test-vectordb-2")
# Verify they have UUID-based identifiers
assert vdb1.identifier.startswith("vs_")
assert vdb2.identifier.startswith("vs_")
await table.unregister_vector_db(vector_db_id=vdb1.identifier)
await table.unregister_vector_db(vector_db_id=vdb2.identifier)
vector_dbs = await table.list_vector_dbs()
assert len(vector_dbs.data) == 0
@ -312,6 +337,13 @@ async def test_scoring_functions_routing_table(cached_disk_dist_registry):
assert "test-scoring-fn" in scoring_fn_ids
assert "test-scoring-fn-2" in scoring_fn_ids
# Unregister scoring functions and verify listing
for i in range(len(scoring_functions.data)):
await table.unregister_scoring_function(scoring_functions.data[i].scoring_fn_id)
scoring_functions_list_after_deletion = await table.list_scoring_functions()
assert len(scoring_functions_list_after_deletion.data) == 0
async def test_benchmarks_routing_table(cached_disk_dist_registry):
table = BenchmarksRoutingTable({"test_provider": BenchmarksImpl()}, cached_disk_dist_registry, {})
@ -329,6 +361,15 @@ async def test_benchmarks_routing_table(cached_disk_dist_registry):
benchmark_ids = {b.identifier for b in benchmarks.data}
assert "test-benchmark" in benchmark_ids
# Unregister the benchmark and verify removal
await table.unregister_benchmark(benchmark_id="test-benchmark")
benchmarks_after = await table.list_benchmarks()
assert len(benchmarks_after.data) == 0
# Unregistering a non-existent benchmark should raise a clear error
with pytest.raises(ValueError, match="Benchmark 'dummy_benchmark' not found"):
await table.unregister_benchmark(benchmark_id="dummy_benchmark")
async def test_tool_groups_routing_table(cached_disk_dist_registry):
table = ToolGroupsRoutingTable({"test_provider": ToolGroupsImpl()}, cached_disk_dist_registry, {})
@ -605,3 +646,25 @@ async def test_models_source_interaction_cleanup_provider_models(cached_disk_dis
# Cleanup
await table.shutdown()
async def test_tool_groups_routing_table_exception_handling(cached_disk_dist_registry):
"""Test that the tool group routing table handles exceptions when listing tools, like if an MCP server is unreachable."""
exception_throwing_tool_groups_impl = ToolGroupsImpl()
exception_throwing_tool_groups_impl.list_runtime_tools = AsyncMock(side_effect=Exception("Test exception"))
table = ToolGroupsRoutingTable(
{"test_provider": exception_throwing_tool_groups_impl}, cached_disk_dist_registry, {}
)
await table.initialize()
await table.register_tool_group(
toolgroup_id="test-toolgroup-exceptions",
provider_id="test_provider",
mcp_endpoint=URL(uri="http://localhost:8479/foo/bar"),
)
tools = await table.list_tools(toolgroup_id="test-toolgroup-exceptions")
assert len(tools.data) == 0

View file

@ -7,6 +7,7 @@
# Unit tests for the routing tables vector_dbs
import time
import uuid
from unittest.mock import AsyncMock
import pytest
@ -34,6 +35,7 @@ from tests.unit.distribution.routers.test_routing_tables import Impl, InferenceI
class VectorDBImpl(Impl):
def __init__(self):
super().__init__(Api.vector_io)
self.vector_stores = {}
async def register_vector_db(self, vector_db: VectorDB):
return vector_db
@ -114,8 +116,35 @@ class VectorDBImpl(Impl):
async def openai_delete_vector_store_file(self, vector_store_id, file_id):
return VectorStoreFileDeleteResponse(id=file_id, deleted=True)
async def openai_create_vector_store(
self,
name=None,
embedding_model=None,
embedding_dimension=None,
provider_id=None,
provider_vector_db_id=None,
**kwargs,
):
vector_store_id = provider_vector_db_id or f"vs_{uuid.uuid4()}"
vector_store = VectorStoreObject(
id=vector_store_id,
name=name or vector_store_id,
created_at=int(time.time()),
file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0),
)
self.vector_stores[vector_store_id] = vector_store
return vector_store
async def openai_list_vector_stores(self, **kwargs):
from llama_stack.apis.vector_io.vector_io import VectorStoreListResponse
return VectorStoreListResponse(
data=list(self.vector_stores.values()), has_more=False, first_id=None, last_id=None
)
async def test_vectordbs_routing_table(cached_disk_dist_registry):
n = 10
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})
await table.initialize()
@ -129,22 +158,98 @@ async def test_vectordbs_routing_table(cached_disk_dist_registry):
)
# Register multiple vector databases and verify listing
await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test-model")
await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test-model")
vdb_dict = {}
for i in range(n):
vdb_dict[i] = await table.register_vector_db(vector_db_id=f"test-vectordb-{i}", embedding_model="test-model")
vector_dbs = await table.list_vector_dbs()
assert len(vector_dbs.data) == 2
assert len(vector_dbs.data) == len(vdb_dict)
vector_db_ids = {v.identifier for v in vector_dbs.data}
assert "test-vectordb" in vector_db_ids
assert "test-vectordb-2" in vector_db_ids
await table.unregister_vector_db(vector_db_id="test-vectordb")
await table.unregister_vector_db(vector_db_id="test-vectordb-2")
for k in vdb_dict:
assert vdb_dict[k].identifier in vector_db_ids
for k in vdb_dict:
await table.unregister_vector_db(vector_db_id=vdb_dict[k].identifier)
vector_dbs = await table.list_vector_dbs()
assert len(vector_dbs.data) == 0
async def test_vector_db_and_vector_store_id_mapping(cached_disk_dist_registry):
n = 10
impl = VectorDBImpl()
table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, {})
await table.initialize()
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
await m_table.initialize()
await m_table.register_model(
model_id="test-model",
provider_id="test_provider",
metadata={"embedding_dimension": 128},
model_type=ModelType.embedding,
)
vdb_dict = {}
for i in range(n):
vdb_dict[i] = await table.register_vector_db(vector_db_id=f"test-vectordb-{i}", embedding_model="test-model")
vector_dbs = await table.list_vector_dbs()
vector_db_ids = {v.identifier for v in vector_dbs.data}
vector_stores = await impl.openai_list_vector_stores()
vector_store_ids = {v.id for v in vector_stores.data}
assert vector_db_ids == vector_store_ids, (
f"Vector DB IDs {vector_db_ids} don't match vector store IDs {vector_store_ids}"
)
for vector_store in vector_stores.data:
vector_db = await table.get_vector_db(vector_store.id)
assert vector_store.name == vector_db.vector_db_name, (
f"Vector store name {vector_store.name} doesn't match vector store ID {vector_store.id}"
)
for vector_db_id in vector_db_ids:
await table.unregister_vector_db(vector_db_id)
assert len((await table.list_vector_dbs()).data) == 0
async def test_vector_db_id_becomes_vector_store_name(cached_disk_dist_registry):
impl = VectorDBImpl()
table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, {})
await table.initialize()
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
await m_table.initialize()
await m_table.register_model(
model_id="test-model",
provider_id="test_provider",
metadata={"embedding_dimension": 128},
model_type=ModelType.embedding,
)
user_provided_id = "my-custom-vector-db"
await table.register_vector_db(vector_db_id=user_provided_id, embedding_model="test-model")
vector_stores = await impl.openai_list_vector_stores()
assert len(vector_stores.data) == 1
vector_store = vector_stores.data[0]
assert vector_store.name == user_provided_id
assert vector_store.id.startswith("vs_")
assert vector_store.id != user_provided_id
vector_dbs = await table.list_vector_dbs()
assert len(vector_dbs.data) == 1
assert vector_dbs.data[0].identifier == vector_store.id
await table.unregister_vector_db(vector_store.id)
async def test_openai_vector_stores_routing_table_roles(cached_disk_dist_registry):
impl = VectorDBImpl()
impl.openai_retrieve_vector_store = AsyncMock(return_value="OK")
@ -164,7 +269,8 @@ async def test_openai_vector_stores_routing_table_roles(cached_disk_dist_registr
authorized_user = User(principal="alice", attributes={"roles": [authorized_team]})
with request_provider_data_context({}, authorized_user):
_ = await table.register_vector_db(vector_db_id="vs1", embedding_model="test-model")
registered_vdb = await table.register_vector_db(vector_db_id="vs1", embedding_model="test-model")
authorized_table = registered_vdb.identifier # Use the actual generated ID
# Authorized reader
with request_provider_data_context({}, authorized_user):
@ -227,7 +333,8 @@ async def test_openai_vector_stores_routing_table_actions(cached_disk_dist_regis
)
with request_provider_data_context({}, admin_user):
await table.register_vector_db(vector_db_id=vector_db_id, embedding_model="test-model")
registered_vdb = await table.register_vector_db(vector_db_id=vector_db_id, embedding_model="test-model")
vector_db_id = registered_vdb.identifier # Use the actual generated ID
read_methods = [
(table.openai_retrieve_vector_store, (vector_db_id,), {}),

View file

@ -12,7 +12,7 @@ import yaml
from pydantic import BaseModel, Field, ValidationError
from llama_stack.core.datatypes import Api, Provider, StackRunConfig
from llama_stack.core.distribution import get_provider_registry
from llama_stack.core.distribution import INTERNAL_APIS, get_provider_registry, providable_apis
from llama_stack.providers.datatypes import ProviderSpec
@ -66,10 +66,9 @@ def base_config(tmp_path):
def provider_spec_yaml():
"""Common provider spec YAML for testing."""
return """
adapter:
adapter_type: test_provider
config_class: test_provider.config.TestProviderConfig
module: test_provider
adapter_type: test_provider
config_class: test_provider.config.TestProviderConfig
module: test_provider
api_dependencies:
- safety
"""
@ -152,6 +151,24 @@ class TestProviderRegistry:
assert registry[Api.inference]["test_provider"].provider_type == "test_provider"
assert registry[Api.inference]["test_provider"].api == Api.inference
def test_internal_apis_excluded(self):
"""Test that internal APIs are excluded and APIs without provider registries are marked as internal."""
import importlib
apis = providable_apis()
for internal_api in INTERNAL_APIS:
assert internal_api not in apis, f"Internal API {internal_api} should not be in providable_apis"
for api in apis:
module_name = f"llama_stack.providers.registry.{api.name.lower()}"
try:
importlib.import_module(module_name)
except ImportError as err:
raise AssertionError(
f"API {api} is in providable_apis but has no provider registry module ({module_name})"
) from err
def test_external_remote_providers(self, api_directories, mock_providers, base_config, provider_spec_yaml):
"""Test loading external remote providers from YAML files."""
remote_dir, _ = api_directories
@ -164,9 +181,9 @@ class TestProviderRegistry:
assert Api.inference in registry
assert "remote::test_provider" in registry[Api.inference]
provider = registry[Api.inference]["remote::test_provider"]
assert provider.adapter.adapter_type == "test_provider"
assert provider.adapter.module == "test_provider"
assert provider.adapter.config_class == "test_provider.config.TestProviderConfig"
assert provider.adapter_type == "test_provider"
assert provider.module == "test_provider"
assert provider.config_class == "test_provider.config.TestProviderConfig"
assert Api.safety in provider.api_dependencies
def test_external_inline_providers(self, api_directories, mock_providers, base_config, inline_provider_spec_yaml):
@ -228,8 +245,7 @@ class TestProviderRegistry:
"""Test handling of malformed remote provider spec (missing required fields)."""
remote_dir, _ = api_directories
malformed_spec = """
adapter:
adapter_type: test_provider
adapter_type: test_provider
# Missing required fields
api_dependencies:
- safety
@ -252,7 +268,7 @@ pip_packages:
with open(inline_dir / "malformed.yaml", "w") as f:
f.write(malformed_spec)
with pytest.raises(KeyError) as exc_info:
with pytest.raises(ValidationError) as exc_info:
get_provider_registry(base_config)
assert "config_class" in str(exc_info.value)

View file

@ -6,16 +6,18 @@
import tempfile
from pathlib import Path
from unittest.mock import patch
from unittest.mock import AsyncMock, Mock, patch
import pytest
from openai import AsyncOpenAI
from openai import NOT_GIVEN, AsyncOpenAI
from openai.types.model import Model as OpenAIModel
# Import the real Pydantic response types instead of using Mocks
from llama_stack.apis.inference import (
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
OpenAIChoice,
OpenAICompletion,
OpenAIEmbeddingData,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
@ -153,24 +155,22 @@ class TestInferenceRecording:
async def test_recording_mode(self, temp_storage_dir, real_openai_chat_response):
"""Test that recording mode captures and stores responses."""
async def mock_create(*args, **kwargs):
return real_openai_chat_response
temp_storage_dir = temp_storage_dir / "test_recording_mode"
with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create):
with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
client.chat.completions._post = AsyncMock(return_value=real_openai_chat_response)
response = await client.chat.completions.create(
model="llama3.2:3b",
messages=[{"role": "user", "content": "Hello, how are you?"}],
temperature=0.7,
max_tokens=50,
)
response = await client.chat.completions.create(
model="llama3.2:3b",
messages=[{"role": "user", "content": "Hello, how are you?"}],
temperature=0.7,
max_tokens=50,
user=NOT_GIVEN,
)
# Verify the response was returned correctly
assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."
# Verify the response was returned correctly
assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."
client.chat.completions._post.assert_called_once()
# Verify recording was stored
storage = ResponseStorage(temp_storage_dir)
@ -178,40 +178,74 @@ class TestInferenceRecording:
async def test_replay_mode(self, temp_storage_dir, real_openai_chat_response):
"""Test that replay mode returns stored responses without making real calls."""
async def mock_create(*args, **kwargs):
return real_openai_chat_response
temp_storage_dir = temp_storage_dir / "test_replay_mode"
# First, record a response
with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create):
with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
client.chat.completions._post = AsyncMock(return_value=real_openai_chat_response)
response = await client.chat.completions.create(
model="llama3.2:3b",
messages=[{"role": "user", "content": "Hello, how are you?"}],
temperature=0.7,
max_tokens=50,
)
response = await client.chat.completions.create(
model="llama3.2:3b",
messages=[{"role": "user", "content": "Hello, how are you?"}],
temperature=0.7,
max_tokens=50,
user=NOT_GIVEN,
)
client.chat.completions._post.assert_called_once()
# Now test replay mode - should not call the original method
with patch("openai.resources.chat.completions.AsyncCompletions.create") as mock_create_patch:
with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
client.chat.completions._post = AsyncMock(return_value=real_openai_chat_response)
response = await client.chat.completions.create(
model="llama3.2:3b",
messages=[{"role": "user", "content": "Hello, how are you?"}],
temperature=0.7,
max_tokens=50,
)
response = await client.chat.completions.create(
model="llama3.2:3b",
messages=[{"role": "user", "content": "Hello, how are you?"}],
temperature=0.7,
max_tokens=50,
)
# Verify we got the recorded response
assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."
# Verify we got the recorded response
assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."
# Verify the original method was NOT called
mock_create_patch.assert_not_called()
# Verify the original method was NOT called
client.chat.completions._post.assert_not_called()
async def test_replay_mode_models(self, temp_storage_dir):
"""Test that replay mode returns stored responses without making real model listing calls."""
async def _async_iterator(models):
for model in models:
yield model
models = [
OpenAIModel(id="foo", created=1, object="model", owned_by="test"),
OpenAIModel(id="bar", created=2, object="model", owned_by="test"),
]
expected_ids = {m.id for m in models}
temp_storage_dir = temp_storage_dir / "test_replay_mode_models"
# baseline - mock works without recording
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
client.models._get_api_list = Mock(return_value=_async_iterator(models))
assert {m.id async for m in client.models.list()} == expected_ids
client.models._get_api_list.assert_called_once()
# record the call
with inference_recording(mode=InferenceMode.RECORD, storage_dir=temp_storage_dir):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
client.models._get_api_list = Mock(return_value=_async_iterator(models))
assert {m.id async for m in client.models.list()} == expected_ids
client.models._get_api_list.assert_called_once()
# replay the call
with inference_recording(mode=InferenceMode.REPLAY, storage_dir=temp_storage_dir):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
client.models._get_api_list = Mock(return_value=_async_iterator(models))
assert {m.id async for m in client.models.list()} == expected_ids
client.models._get_api_list.assert_not_called()
async def test_replay_missing_recording(self, temp_storage_dir):
"""Test that replay mode fails when no recording is found."""
@ -228,36 +262,110 @@ class TestInferenceRecording:
async def test_embeddings_recording(self, temp_storage_dir, real_embeddings_response):
"""Test recording and replay of embeddings calls."""
async def mock_create(*args, **kwargs):
return real_embeddings_response
# baseline - mock works without recording
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
client.embeddings._post = AsyncMock(return_value=real_embeddings_response)
response = await client.embeddings.create(
model=real_embeddings_response.model,
input=["Hello world", "Test embedding"],
encoding_format=NOT_GIVEN,
)
assert len(response.data) == 2
assert response.data[0].embedding == [0.1, 0.2, 0.3]
client.embeddings._post.assert_called_once()
temp_storage_dir = temp_storage_dir / "test_embeddings_recording"
# Record
with patch("openai.resources.embeddings.AsyncEmbeddings.create", side_effect=mock_create):
with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
client.embeddings._post = AsyncMock(return_value=real_embeddings_response)
response = await client.embeddings.create(
model="nomic-embed-text", input=["Hello world", "Test embedding"]
)
response = await client.embeddings.create(
model=real_embeddings_response.model,
input=["Hello world", "Test embedding"],
encoding_format=NOT_GIVEN,
dimensions=NOT_GIVEN,
user=NOT_GIVEN,
)
assert len(response.data) == 2
assert len(response.data) == 2
# Replay
with patch("openai.resources.embeddings.AsyncEmbeddings.create") as mock_create_patch:
with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
client.embeddings._post = AsyncMock(return_value=real_embeddings_response)
response = await client.embeddings.create(
model="nomic-embed-text", input=["Hello world", "Test embedding"]
)
response = await client.embeddings.create(
model=real_embeddings_response.model,
input=["Hello world", "Test embedding"],
)
# Verify we got the recorded response
assert len(response.data) == 2
assert response.data[0].embedding == [0.1, 0.2, 0.3]
# Verify we got the recorded response
assert len(response.data) == 2
assert response.data[0].embedding == [0.1, 0.2, 0.3]
# Verify original method was not called
mock_create_patch.assert_not_called()
# Verify original method was not called
client.embeddings._post.assert_not_called()
async def test_completions_recording(self, temp_storage_dir):
real_completions_response = OpenAICompletion(
id="test_completion",
object="text_completion",
created=1234567890,
model="llama3.2:3b",
choices=[
{
"text": "Hello! I'm doing well, thank you for asking.",
"index": 0,
"logprobs": None,
"finish_reason": "stop",
}
],
)
temp_storage_dir = temp_storage_dir / "test_completions_recording"
# baseline - mock works without recording
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
client.completions._post = AsyncMock(return_value=real_completions_response)
response = await client.completions.create(
model=real_completions_response.model,
prompt="Hello, how are you?",
temperature=0.7,
max_tokens=50,
user=NOT_GIVEN,
)
assert response.choices[0].text == real_completions_response.choices[0].text
client.completions._post.assert_called_once()
# Record
with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
client.completions._post = AsyncMock(return_value=real_completions_response)
response = await client.completions.create(
model=real_completions_response.model,
prompt="Hello, how are you?",
temperature=0.7,
max_tokens=50,
user=NOT_GIVEN,
)
assert response.choices[0].text == real_completions_response.choices[0].text
client.completions._post.assert_called_once()
# Replay
with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
client.completions._post = AsyncMock(return_value=real_completions_response)
response = await client.completions.create(
model=real_completions_response.model,
prompt="Hello, how are you?",
temperature=0.7,
max_tokens=50,
)
assert response.choices[0].text == real_completions_response.choices[0].text
client.completions._post.assert_not_called()
async def test_live_mode(self, real_openai_chat_response):
"""Test that live mode passes through to original methods."""
@ -266,7 +374,7 @@ class TestInferenceRecording:
return real_openai_chat_response
with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create):
with inference_recording(mode=InferenceMode.LIVE):
with inference_recording(mode=InferenceMode.LIVE, storage_dir="foo"):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
response = await client.chat.completions.create(

View file

@ -27,13 +27,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
mock_impls = {}
mock_route_impls = RouteImpls({})
async def mock_construct_stack(config, custom_provider_registry):
return mock_impls
class MockStack:
def __init__(self, config, custom_provider_registry=None):
self.impls = mock_impls
async def initialize(self):
pass
def mock_initialize_route_impls(impls):
return mock_route_impls
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
client = LlamaStackAsLibraryClient("ci-tests")
@ -46,13 +50,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
mock_impls = {}
mock_route_impls = RouteImpls({})
async def mock_construct_stack(config, custom_provider_registry):
return mock_impls
class MockStack:
def __init__(self, config, custom_provider_registry=None):
self.impls = mock_impls
async def initialize(self):
pass
def mock_initialize_route_impls(impls):
return mock_route_impls
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
client = AsyncLlamaStackAsLibraryClient("ci-tests")
@ -68,13 +76,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
mock_impls = {}
mock_route_impls = RouteImpls({})
async def mock_construct_stack(config, custom_provider_registry):
return mock_impls
class MockStack:
def __init__(self, config, custom_provider_registry=None):
self.impls = mock_impls
async def initialize(self):
pass
def mock_initialize_route_impls(impls):
return mock_route_impls
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
client = LlamaStackAsLibraryClient("ci-tests")
@ -90,13 +102,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
mock_impls = {}
mock_route_impls = RouteImpls({})
async def mock_construct_stack(config, custom_provider_registry):
return mock_impls
class MockStack:
def __init__(self, config, custom_provider_registry=None):
self.impls = mock_impls
async def initialize(self):
pass
def mock_initialize_route_impls(impls):
return mock_route_impls
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
client = AsyncLlamaStackAsLibraryClient("ci-tests")
@ -112,13 +128,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
mock_impls = {}
mock_route_impls = RouteImpls({})
async def mock_construct_stack(config, custom_provider_registry):
return mock_impls
class MockStack:
def __init__(self, config, custom_provider_registry=None):
self.impls = mock_impls
async def initialize(self):
pass
def mock_initialize_route_impls(impls):
return mock_route_impls
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
sync_client = LlamaStackAsLibraryClient("ci-tests")

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,30 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import random
import pytest
from llama_stack.core.prompts.prompts import PromptServiceConfig, PromptServiceImpl
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
@pytest.fixture
async def temp_prompt_store(tmp_path_factory):
unique_id = f"prompt_store_{random.randint(1, 1000000)}"
temp_dir = tmp_path_factory.getbasetemp()
db_path = str(temp_dir / f"{unique_id}.db")
from llama_stack.core.datatypes import StackRunConfig
from llama_stack.providers.utils.kvstore import kvstore_impl
mock_run_config = StackRunConfig(image_name="test-distribution", apis=[], providers={})
config = PromptServiceConfig(run_config=mock_run_config)
store = PromptServiceImpl(config, deps={})
store.kvstore = await kvstore_impl(SqliteKVStoreConfig(db_path=db_path))
yield store

View file

@ -0,0 +1,144 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
class TestPrompts:
async def test_create_and_get_prompt(self, temp_prompt_store):
prompt = await temp_prompt_store.create_prompt("Hello world!", ["name"])
assert prompt.prompt == "Hello world!"
assert prompt.version == 1
assert prompt.prompt_id.startswith("pmpt_")
assert prompt.variables == ["name"]
retrieved = await temp_prompt_store.get_prompt(prompt.prompt_id)
assert retrieved.prompt_id == prompt.prompt_id
assert retrieved.prompt == prompt.prompt
async def test_update_prompt(self, temp_prompt_store):
prompt = await temp_prompt_store.create_prompt("Original")
updated = await temp_prompt_store.update_prompt(prompt.prompt_id, "Updated", 1, ["v"])
assert updated.version == 2
assert updated.prompt == "Updated"
async def test_update_prompt_with_version(self, temp_prompt_store):
version_for_update = 1
prompt = await temp_prompt_store.create_prompt("Original")
assert prompt.version == 1
prompt = await temp_prompt_store.update_prompt(prompt.prompt_id, "Updated", version_for_update, ["v"])
assert prompt.version == 2
with pytest.raises(ValueError):
# now this is a stale version
await temp_prompt_store.update_prompt(prompt.prompt_id, "Another Update", version_for_update, ["v"])
with pytest.raises(ValueError):
# this version does not exist
await temp_prompt_store.update_prompt(prompt.prompt_id, "Another Update", 99, ["v"])
async def test_delete_prompt(self, temp_prompt_store):
prompt = await temp_prompt_store.create_prompt("to be deleted")
await temp_prompt_store.delete_prompt(prompt.prompt_id)
with pytest.raises(ValueError):
await temp_prompt_store.get_prompt(prompt.prompt_id)
async def test_list_prompts(self, temp_prompt_store):
response = await temp_prompt_store.list_prompts()
assert response.data == []
await temp_prompt_store.create_prompt("first")
await temp_prompt_store.create_prompt("second")
response = await temp_prompt_store.list_prompts()
assert len(response.data) == 2
async def test_version(self, temp_prompt_store):
prompt = await temp_prompt_store.create_prompt("V1")
await temp_prompt_store.update_prompt(prompt.prompt_id, "V2", 1)
v1 = await temp_prompt_store.get_prompt(prompt.prompt_id, version=1)
assert v1.version == 1 and v1.prompt == "V1"
latest = await temp_prompt_store.get_prompt(prompt.prompt_id)
assert latest.version == 2 and latest.prompt == "V2"
async def test_set_default_version(self, temp_prompt_store):
prompt0 = await temp_prompt_store.create_prompt("V1")
prompt1 = await temp_prompt_store.update_prompt(prompt0.prompt_id, "V2", 1)
assert (await temp_prompt_store.get_prompt(prompt0.prompt_id)).version == 2
prompt_default = await temp_prompt_store.set_default_version(prompt0.prompt_id, 1)
assert (await temp_prompt_store.get_prompt(prompt0.prompt_id)).version == 1
assert prompt_default.version == 1
prompt2 = await temp_prompt_store.update_prompt(prompt0.prompt_id, "V3", prompt1.version)
assert prompt2.version == 3
async def test_prompt_id_generation_and_validation(self, temp_prompt_store):
prompt = await temp_prompt_store.create_prompt("Test")
assert prompt.prompt_id.startswith("pmpt_")
assert len(prompt.prompt_id) == 53
with pytest.raises(ValueError):
await temp_prompt_store.get_prompt("invalid_id")
async def test_list_shows_default_versions(self, temp_prompt_store):
prompt = await temp_prompt_store.create_prompt("V1")
await temp_prompt_store.update_prompt(prompt.prompt_id, "V2", 1)
await temp_prompt_store.update_prompt(prompt.prompt_id, "V3", 2)
response = await temp_prompt_store.list_prompts()
listed_prompt = response.data[0]
assert listed_prompt.version == 3 and listed_prompt.prompt == "V3"
await temp_prompt_store.set_default_version(prompt.prompt_id, 1)
response = await temp_prompt_store.list_prompts()
listed_prompt = response.data[0]
assert listed_prompt.version == 1 and listed_prompt.prompt == "V1"
assert not (await temp_prompt_store.get_prompt(prompt.prompt_id, 3)).is_default
async def test_get_all_prompt_versions(self, temp_prompt_store):
prompt = await temp_prompt_store.create_prompt("V1")
await temp_prompt_store.update_prompt(prompt.prompt_id, "V2", 1)
await temp_prompt_store.update_prompt(prompt.prompt_id, "V3", 2)
versions = (await temp_prompt_store.list_prompt_versions(prompt.prompt_id)).data
assert len(versions) == 3
assert [v.version for v in versions] == [1, 2, 3]
assert [v.is_default for v in versions] == [False, False, True]
await temp_prompt_store.set_default_version(prompt.prompt_id, 2)
versions = (await temp_prompt_store.list_prompt_versions(prompt.prompt_id)).data
assert [v.is_default for v in versions] == [False, True, False]
with pytest.raises(ValueError):
await temp_prompt_store.list_prompt_versions("nonexistent")
async def test_prompt_variable_validation(self, temp_prompt_store):
prompt = await temp_prompt_store.create_prompt("Hello {{ name }}, you live in {{ city }}!", ["name", "city"])
assert prompt.variables == ["name", "city"]
prompt_no_vars = await temp_prompt_store.create_prompt("Hello world!", [])
assert prompt_no_vars.variables == []
with pytest.raises(ValueError, match="undeclared variables"):
await temp_prompt_store.create_prompt("Hello {{ name }}, invalid {{ unknown }}!", ["name"])
async def test_update_prompt_set_as_default_behavior(self, temp_prompt_store):
prompt = await temp_prompt_store.create_prompt("V1")
assert (await temp_prompt_store.get_prompt(prompt.prompt_id)).version == 1
prompt_v2 = await temp_prompt_store.update_prompt(prompt.prompt_id, "V2", 1, [], set_as_default=True)
assert prompt_v2.version == 2
assert (await temp_prompt_store.get_prompt(prompt.prompt_id)).version == 2
prompt_v3 = await temp_prompt_store.update_prompt(prompt.prompt_id, "V3", 2, [], set_as_default=False)
assert prompt_v3.version == 3
assert (await temp_prompt_store.get_prompt(prompt.prompt_id)).version == 2

View file

@ -16,9 +16,11 @@ from llama_stack.apis.agents import (
)
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.inference import Inference
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.tools import ListToolsResponse, Tool, ToolGroups, ToolParameter, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.providers.inline.agents.meta_reference.agent_instance import ChatAgent
from llama_stack.providers.inline.agents.meta_reference.agents import MetaReferenceAgentsImpl
from llama_stack.providers.inline.agents.meta_reference.config import MetaReferenceAgentsImplConfig
from llama_stack.providers.inline.agents.meta_reference.persistence import AgentInfo
@ -75,11 +77,11 @@ def sample_agent_config():
},
input_shields=["string"],
output_shields=["string"],
toolgroups=["string"],
toolgroups=["mcp::my_mcp_server"],
client_tools=[
{
"name": "string",
"description": "string",
"name": "client_tool",
"description": "Client Tool",
"parameters": [
{
"name": "string",
@ -226,3 +228,83 @@ async def test_delete_agent(agents_impl, sample_agent_config):
# Verify the agent was deleted
with pytest.raises(ValueError):
await agents_impl.get_agent(agent_id)
async def test__initialize_tools(agents_impl, sample_agent_config):
# Mock tool_groups_api.list_tools()
agents_impl.tool_groups_api.list_tools.return_value = ListToolsResponse(
data=[
Tool(
identifier="story_maker",
provider_id="model-context-protocol",
type=ResourceType.tool,
toolgroup_id="mcp::my_mcp_server",
description="Make a story",
parameters=[
ToolParameter(
name="story_title",
parameter_type="string",
description="Title of the story",
required=True,
title="Story Title",
),
ToolParameter(
name="input_words",
parameter_type="array",
description="Input words",
required=False,
items={"type": "string"},
title="Input Words",
default=[],
),
],
)
]
)
create_response = await agents_impl.create_agent(sample_agent_config)
agent_id = create_response.agent_id
# Get an instance of ChatAgent
chat_agent = await agents_impl._get_agent_impl(agent_id)
assert chat_agent is not None
assert isinstance(chat_agent, ChatAgent)
# Initialize tool definitions
await chat_agent._initialize_tools()
assert len(chat_agent.tool_defs) == 2
# Verify the first tool, which is a client tool
first_tool = chat_agent.tool_defs[0]
assert first_tool.tool_name == "client_tool"
assert first_tool.description == "Client Tool"
# Verify the second tool, which is an MCP tool that has an array-type property
second_tool = chat_agent.tool_defs[1]
assert second_tool.tool_name == "story_maker"
assert second_tool.description == "Make a story"
parameters = second_tool.parameters
assert len(parameters) == 2
# Verify a string property
story_title = parameters.get("story_title")
assert story_title is not None
assert story_title.param_type == "string"
assert story_title.description == "Title of the story"
assert story_title.required
assert story_title.items is None
assert story_title.title == "Story Title"
assert story_title.default is None
# Verify an array property
input_words = parameters.get("input_words")
assert input_words is not None
assert input_words.param_type == "array"
assert input_words.description == "Input words"
assert not input_words.required
assert input_words.items is not None
assert len(input_words.items) == 1
assert input_words.items.get("type") == "string"
assert input_words.title == "Input Words"
assert input_words.default == []

View file

@ -46,7 +46,8 @@ The tests are categorized and outlined below, keep this updated:
* test_validate_input_url_mismatch (negative)
* test_validate_input_multiple_errors_per_request (negative)
* test_validate_input_invalid_request_format (negative)
* test_validate_input_missing_parameters (parametrized negative - custom_id, method, url, body, model, messages missing validation)
* test_validate_input_missing_parameters_chat_completions (parametrized negative - custom_id, method, url, body, model, messages missing validation for chat/completions)
* test_validate_input_missing_parameters_completions (parametrized negative - custom_id, method, url, body, model, prompt missing validation for completions)
* test_validate_input_invalid_parameter_types (parametrized negative - custom_id, url, method, body, model, messages type validation)
The tests use temporary SQLite databases for isolation and mock external
@ -213,7 +214,6 @@ class TestReferenceBatchesImpl:
"endpoint",
[
"/v1/embeddings",
"/v1/completions",
"/v1/invalid/endpoint",
"",
],
@ -499,8 +499,10 @@ class TestReferenceBatchesImpl:
("messages", "body.messages", "invalid_request", "Messages parameter is required"),
],
)
async def test_validate_input_missing_parameters(self, provider, param_name, param_path, error_code, error_message):
"""Test _validate_input when file contains request with missing required parameters."""
async def test_validate_input_missing_parameters_chat_completions(
self, provider, param_name, param_path, error_code, error_message
):
"""Test _validate_input when file contains request with missing required parameters for chat completions."""
provider.files_api.openai_retrieve_file = AsyncMock()
mock_response = MagicMock()
@ -541,6 +543,61 @@ class TestReferenceBatchesImpl:
assert errors[0].message == error_message
assert errors[0].param == param_path
@pytest.mark.parametrize(
"param_name,param_path,error_code,error_message",
[
("custom_id", "custom_id", "missing_required_parameter", "Missing required parameter: custom_id"),
("method", "method", "missing_required_parameter", "Missing required parameter: method"),
("url", "url", "missing_required_parameter", "Missing required parameter: url"),
("body", "body", "missing_required_parameter", "Missing required parameter: body"),
("model", "body.model", "invalid_request", "Model parameter is required"),
("prompt", "body.prompt", "invalid_request", "Prompt parameter is required"),
],
)
async def test_validate_input_missing_parameters_completions(
self, provider, param_name, param_path, error_code, error_message
):
"""Test _validate_input when file contains request with missing required parameters for text completions."""
provider.files_api.openai_retrieve_file = AsyncMock()
mock_response = MagicMock()
base_request = {
"custom_id": "req-1",
"method": "POST",
"url": "/v1/completions",
"body": {"model": "test-model", "prompt": "Hello"},
}
# Remove the specific parameter being tested
if "." in param_path:
top_level, nested_param = param_path.split(".", 1)
del base_request[top_level][nested_param]
else:
del base_request[param_name]
mock_response.body = json.dumps(base_request).encode()
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
batch = BatchObject(
id="batch_test",
object="batch",
endpoint="/v1/completions",
input_file_id=f"missing_{param_name}_file",
completion_window="24h",
status="validating",
created_at=1234567890,
)
errors, requests = await provider._validate_input(batch)
assert len(errors) == 1
assert len(requests) == 0
assert errors[0].code == error_code
assert errors[0].line == 1
assert errors[0].message == error_message
assert errors[0].param == param_path
async def test_validate_input_url_mismatch(self, provider):
"""Test _validate_input when file contains request with URL that doesn't match batch endpoint."""
provider.files_api.openai_retrieve_file = AsyncMock()

View file

@ -0,0 +1,63 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from unittest.mock import patch
from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig
class TestBedrockBaseConfig:
def test_defaults_work_without_env_vars(self):
with patch.dict(os.environ, {}, clear=True):
config = BedrockBaseConfig()
# Basic creds should be None
assert config.aws_access_key_id is None
assert config.aws_secret_access_key is None
assert config.region_name is None
# Timeouts get defaults
assert config.connect_timeout == 60.0
assert config.read_timeout == 60.0
assert config.session_ttl == 3600
def test_env_vars_get_picked_up(self):
env_vars = {
"AWS_ACCESS_KEY_ID": "AKIATEST123",
"AWS_SECRET_ACCESS_KEY": "secret123",
"AWS_DEFAULT_REGION": "us-west-2",
"AWS_MAX_ATTEMPTS": "5",
"AWS_RETRY_MODE": "adaptive",
"AWS_CONNECT_TIMEOUT": "30",
}
with patch.dict(os.environ, env_vars, clear=True):
config = BedrockBaseConfig()
assert config.aws_access_key_id == "AKIATEST123"
assert config.aws_secret_access_key == "secret123"
assert config.region_name == "us-west-2"
assert config.total_max_attempts == 5
assert config.retry_mode == "adaptive"
assert config.connect_timeout == 30.0
def test_partial_env_setup(self):
# Just setting one timeout var
with patch.dict(os.environ, {"AWS_CONNECT_TIMEOUT": "120"}, clear=True):
config = BedrockBaseConfig()
assert config.connect_timeout == 120.0
assert config.read_timeout == 60.0 # still default
assert config.aws_access_key_id is None
def test_bad_max_attempts_breaks(self):
with patch.dict(os.environ, {"AWS_MAX_ATTEMPTS": "not_a_number"}, clear=True):
try:
BedrockBaseConfig()
raise AssertionError("Should have failed on bad int conversion")
except ValueError:
pass # expected

View file

@ -33,8 +33,7 @@ def test_groq_provider_openai_client_caching():
with request_provider_data_context(
{"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
):
openai_client = inference_adapter._get_openai_client()
assert openai_client.api_key == api_key
assert inference_adapter.client.api_key == api_key
def test_openai_provider_openai_client_caching():

View file

@ -26,7 +26,6 @@ class TestProviderDataValidator(BaseModel):
class TestLiteLLMAdapter(LiteLLMOpenAIMixin):
def __init__(self, config: TestConfig):
super().__init__(
model_entries=[],
litellm_provider_name="test",
api_key_from_config=config.api_key,
provider_data_api_key_field="test_api_key",

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
import os
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import MagicMock, patch
from llama_stack.core.stack import replace_env_vars
from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
@ -80,11 +80,22 @@ class TestOpenAIBaseURLConfig:
# Mock the get_api_key method
adapter.get_api_key = MagicMock(return_value="test-key")
# Mock the AsyncOpenAI client and its models.retrieve method
# Mock a model object that will be returned by models.list()
mock_model = MagicMock()
mock_model.id = "gpt-4"
# Create an async iterator that yields our mock model
async def mock_async_iterator():
yield mock_model
# Mock the AsyncOpenAI client and its models.list method
mock_client = MagicMock()
mock_client.models.retrieve = AsyncMock(return_value=MagicMock())
mock_client.models.list = MagicMock(return_value=mock_async_iterator())
mock_openai_class.return_value = mock_client
# Set the __provider_id__ attribute that's expected by list_models
adapter.__provider_id__ = "openai"
# Call check_model_availability and verify it returns True
assert await adapter.check_model_availability("gpt-4")
@ -94,8 +105,8 @@ class TestOpenAIBaseURLConfig:
base_url=custom_url,
)
# Verify the method was called and returned True
mock_client.models.retrieve.assert_called_once_with("gpt-4")
# Verify the models.list method was called
mock_client.models.list.assert_called_once()
@patch.dict(os.environ, {"OPENAI_BASE_URL": "https://proxy.openai.com/v1"})
@patch("llama_stack.providers.utils.inference.openai_mixin.AsyncOpenAI")
@ -110,11 +121,22 @@ class TestOpenAIBaseURLConfig:
# Mock the get_api_key method
adapter.get_api_key = MagicMock(return_value="test-key")
# Mock the AsyncOpenAI client
# Mock a model object that will be returned by models.list()
mock_model = MagicMock()
mock_model.id = "gpt-4"
# Create an async iterator that yields our mock model
async def mock_async_iterator():
yield mock_model
# Mock the AsyncOpenAI client and its models.list method
mock_client = MagicMock()
mock_client.models.retrieve = AsyncMock(return_value=MagicMock())
mock_client.models.list = MagicMock(return_value=mock_async_iterator())
mock_openai_class.return_value = mock_client
# Set the __provider_id__ attribute that's expected by list_models
adapter.__provider_id__ = "openai"
# Call check_model_availability and verify it returns True
assert await adapter.check_model_availability("gpt-4")

View file

@ -6,19 +6,15 @@
import asyncio
import json
import logging # allow-direct-logging
import threading
import time
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch
import pytest
from openai.types.chat.chat_completion_chunk import (
ChatCompletionChunk as OpenAIChatCompletionChunk,
)
from openai.types.chat.chat_completion_chunk import (
Choice as OpenAIChoice,
Choice as OpenAIChoiceChunk,
)
from openai.types.chat.chat_completion_chunk import (
ChoiceDelta as OpenAIChoiceDelta,
@ -35,6 +31,9 @@ from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponseEventType,
CompletionMessage,
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
OpenAIChoice,
SystemMessage,
ToolChoice,
ToolConfig,
@ -61,52 +60,21 @@ from llama_stack.providers.remote.inference.vllm.vllm import (
# -v -s --tb=short --disable-warnings
class MockInferenceAdapterWithSleep:
def __init__(self, sleep_time: int, response: dict[str, Any]):
self.httpd = None
class DelayedRequestHandler(BaseHTTPRequestHandler):
# ruff: noqa: N802
def do_POST(self):
time.sleep(sleep_time)
response_body = json.dumps(response).encode("utf-8")
self.send_response(code=200)
self.send_header("Content-Type", "application/json")
self.send_header("Content-Length", len(response_body))
self.end_headers()
self.wfile.write(response_body)
self.request_handler = DelayedRequestHandler
def __enter__(self):
httpd = HTTPServer(("", 0), self.request_handler)
self.httpd = httpd
host, port = httpd.server_address
httpd_thread = threading.Thread(target=httpd.serve_forever)
httpd_thread.daemon = True # stop server if this thread terminates
httpd_thread.start()
config = VLLMInferenceAdapterConfig(url=f"http://{host}:{port}")
inference_adapter = VLLMInferenceAdapter(config)
return inference_adapter
def __exit__(self, _exc_type, _exc_value, _traceback):
if self.httpd:
self.httpd.shutdown()
self.httpd.server_close()
@pytest.fixture(scope="module")
def mock_openai_models_list():
with patch("openai.resources.models.AsyncModels.list", new_callable=AsyncMock) as mock_list:
with patch("openai.resources.models.AsyncModels.list") as mock_list:
yield mock_list
@pytest.fixture(scope="module")
@pytest.fixture(scope="function")
async def vllm_inference_adapter():
config = VLLMInferenceAdapterConfig(url="http://mocked.localhost:12345")
inference_adapter = VLLMInferenceAdapter(config)
inference_adapter.model_store = AsyncMock()
# Mock the __provider_spec__ attribute that would normally be set by the resolver
inference_adapter.__provider_spec__ = MagicMock()
inference_adapter.__provider_spec__.provider_type = "vllm-inference"
inference_adapter.__provider_spec__.provider_data_validator = MagicMock()
await inference_adapter.initialize()
return inference_adapter
@ -150,10 +118,16 @@ async def test_tool_call_response(vllm_inference_adapter):
"""Verify that tool call arguments from a CompletionMessage are correctly converted
into the expected JSON format."""
# Patch the call to vllm so we can inspect the arguments sent were correct
with patch.object(
vllm_inference_adapter.client.chat.completions, "create", new_callable=AsyncMock
) as mock_nonstream_completion:
# Patch the client property to avoid instantiating a real AsyncOpenAI client
with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_create_client:
mock_client = MagicMock()
mock_client.chat.completions.create = AsyncMock()
mock_create_client.return_value = mock_client
# Mock the model to return a proper provider_resource_id
mock_model = Model(identifier="mock-model", provider_resource_id="mock-model", provider_id="vllm-inference")
vllm_inference_adapter.model_store.get_model.return_value = mock_model
messages = [
SystemMessage(content="You are a helpful assistant"),
UserMessage(content="How many?"),
@ -179,7 +153,7 @@ async def test_tool_call_response(vllm_inference_adapter):
tool_config=ToolConfig(tool_choice=ToolChoice.auto),
)
assert mock_nonstream_completion.call_args.kwargs["messages"][2]["tool_calls"] == [
assert mock_client.chat.completions.create.call_args.kwargs["messages"][2]["tool_calls"] == [
{
"id": "foo",
"type": "function",
@ -199,7 +173,7 @@ async def test_tool_call_delta_empty_tool_call_buf():
async def mock_stream():
delta = OpenAIChoiceDelta(content="", tool_calls=None)
choices = [OpenAIChoice(delta=delta, finish_reason="stop", index=0)]
choices = [OpenAIChoiceChunk(delta=delta, finish_reason="stop", index=0)]
mock_chunk = OpenAIChatCompletionChunk(
id="chunk-1",
created=1,
@ -225,7 +199,7 @@ async def test_tool_call_delta_streaming_arguments_dict():
model="foo",
object="chat.completion.chunk",
choices=[
OpenAIChoice(
OpenAIChoiceChunk(
delta=OpenAIChoiceDelta(
content="",
tool_calls=[
@ -250,7 +224,7 @@ async def test_tool_call_delta_streaming_arguments_dict():
model="foo",
object="chat.completion.chunk",
choices=[
OpenAIChoice(
OpenAIChoiceChunk(
delta=OpenAIChoiceDelta(
content="",
tool_calls=[
@ -275,7 +249,9 @@ async def test_tool_call_delta_streaming_arguments_dict():
model="foo",
object="chat.completion.chunk",
choices=[
OpenAIChoice(delta=OpenAIChoiceDelta(content="", tool_calls=None), finish_reason="tool_calls", index=0)
OpenAIChoiceChunk(
delta=OpenAIChoiceDelta(content="", tool_calls=None), finish_reason="tool_calls", index=0
)
],
)
for chunk in [mock_chunk_1, mock_chunk_2, mock_chunk_3]:
@ -299,7 +275,7 @@ async def test_multiple_tool_calls():
model="foo",
object="chat.completion.chunk",
choices=[
OpenAIChoice(
OpenAIChoiceChunk(
delta=OpenAIChoiceDelta(
content="",
tool_calls=[
@ -324,7 +300,7 @@ async def test_multiple_tool_calls():
model="foo",
object="chat.completion.chunk",
choices=[
OpenAIChoice(
OpenAIChoiceChunk(
delta=OpenAIChoiceDelta(
content="",
tool_calls=[
@ -349,7 +325,9 @@ async def test_multiple_tool_calls():
model="foo",
object="chat.completion.chunk",
choices=[
OpenAIChoice(delta=OpenAIChoiceDelta(content="", tool_calls=None), finish_reason="tool_calls", index=0)
OpenAIChoiceChunk(
delta=OpenAIChoiceDelta(content="", tool_calls=None), finish_reason="tool_calls", index=0
)
],
)
for chunk in [mock_chunk_1, mock_chunk_2, mock_chunk_3]:
@ -393,59 +371,6 @@ async def test_process_vllm_chat_completion_stream_response_no_choices():
assert chunks[0].event.event_type.value == "start"
@pytest.mark.allow_network
def test_chat_completion_doesnt_block_event_loop(caplog):
loop = asyncio.new_event_loop()
loop.set_debug(True)
caplog.set_level(logging.WARNING)
# Log when event loop is blocked for more than 200ms
loop.slow_callback_duration = 0.5
# Sleep for 500ms in our delayed http response
sleep_time = 0.5
mock_model = Model(identifier="mock-model", provider_resource_id="mock-model", provider_id="vllm-inference")
mock_response = {
"id": "chatcmpl-abc123",
"object": "chat.completion",
"created": 1,
"modle": "mock-model",
"choices": [
{
"message": {"content": ""},
"logprobs": None,
"finish_reason": "stop",
"index": 0,
}
],
}
async def do_chat_completion():
await inference_adapter.chat_completion(
"mock-model",
[],
stream=False,
tools=None,
tool_config=ToolConfig(tool_choice=ToolChoice.auto),
)
with MockInferenceAdapterWithSleep(sleep_time, mock_response) as inference_adapter:
inference_adapter.model_store = AsyncMock()
inference_adapter.model_store.get_model.return_value = mock_model
loop.run_until_complete(inference_adapter.initialize())
# Clear the logs so far and run the actual chat completion we care about
caplog.clear()
loop.run_until_complete(do_chat_completion())
# Ensure we don't have any asyncio warnings in the captured log
# records from our chat completion call. A message gets logged
# here any time we exceed the slow_callback_duration configured
# above.
asyncio_warnings = [record.message for record in caplog.records if record.name == "asyncio"]
assert not asyncio_warnings
async def test_get_params_empty_tools(vllm_inference_adapter):
request = ChatCompletionRequest(
tools=[],
@ -638,33 +563,29 @@ async def test_health_status_success(vllm_inference_adapter):
"""
Test the health method of VLLM InferenceAdapter when the connection is successful.
This test verifies that the health method returns a HealthResponse with status OK, only
when the connection to the vLLM server is successful.
This test verifies that the health method returns a HealthResponse with status OK
when the /health endpoint responds successfully.
"""
# Set vllm_inference_adapter.client to None to ensure _create_client is called
vllm_inference_adapter.client = None
with patch.object(vllm_inference_adapter, "_create_client") as mock_create_client:
# Create mock client and models
mock_client = MagicMock()
mock_models = MagicMock()
with patch("httpx.AsyncClient") as mock_client_class:
# Create mock response
mock_response = MagicMock()
mock_response.raise_for_status.return_value = None
# Create a mock async iterator that yields a model when iterated
async def mock_list():
for model in [MagicMock()]:
yield model
# Set up the models.list to return our mock async iterator
mock_models.list.return_value = mock_list()
mock_client.models = mock_models
mock_create_client.return_value = mock_client
# Create mock client instance
mock_client_instance = MagicMock()
mock_client_instance.get = AsyncMock(return_value=mock_response)
mock_client_class.return_value.__aenter__.return_value = mock_client_instance
# Call the health method
health_response = await vllm_inference_adapter.health()
# Verify the response
assert health_response["status"] == HealthStatus.OK
# Verify that models.list was called
mock_models.list.assert_called_once()
# Verify that the health endpoint was called
mock_client_instance.get.assert_called_once()
call_args = mock_client_instance.get.call_args[0]
assert call_args[0].endswith("/health")
async def test_health_status_failure(vllm_inference_adapter):
@ -674,26 +595,190 @@ async def test_health_status_failure(vllm_inference_adapter):
This test verifies that the health method returns a HealthResponse with status ERROR
and an appropriate error message when the connection to the vLLM server fails.
"""
vllm_inference_adapter.client = None
with patch.object(vllm_inference_adapter, "_create_client") as mock_create_client:
# Create mock client and models
mock_client = MagicMock()
mock_models = MagicMock()
# Create a mock async iterator that raises an exception when iterated
async def mock_list():
raise Exception("Connection failed")
yield # Unreachable code
# Set up the models.list to return our mock async iterator
mock_models.list.return_value = mock_list()
mock_client.models = mock_models
mock_create_client.return_value = mock_client
with patch("httpx.AsyncClient") as mock_client_class:
# Create mock client instance that raises an exception
mock_client_instance = MagicMock()
mock_client_instance.get.side_effect = Exception("Connection failed")
mock_client_class.return_value.__aenter__.return_value = mock_client_instance
# Call the health method
health_response = await vllm_inference_adapter.health()
# Verify the response
assert health_response["status"] == HealthStatus.ERROR
assert "Health check failed: Connection failed" in health_response["message"]
mock_models.list.assert_called_once()
async def test_health_status_no_static_api_key(vllm_inference_adapter):
"""
Test the health method of VLLM InferenceAdapter when no static API key is provided.
This test verifies that the health method returns a HealthResponse with status OK
when the /health endpoint responds successfully, regardless of API token configuration.
"""
with patch("httpx.AsyncClient") as mock_client_class:
# Create mock response
mock_response = MagicMock()
mock_response.raise_for_status.return_value = None
# Create mock client instance
mock_client_instance = MagicMock()
mock_client_instance.get = AsyncMock(return_value=mock_response)
mock_client_class.return_value.__aenter__.return_value = mock_client_instance
# Call the health method
health_response = await vllm_inference_adapter.health()
# Verify the response
assert health_response["status"] == HealthStatus.OK
async def test_openai_chat_completion_is_async(vllm_inference_adapter):
"""
Verify that openai_chat_completion is async and doesn't block the event loop.
To do this we mock the underlying inference with a sleep, start multiple
inference calls in parallel, and ensure the total time taken is less
than the sum of the individual sleep times.
"""
sleep_time = 0.5
async def mock_create(*args, **kwargs):
await asyncio.sleep(sleep_time)
return OpenAIChatCompletion(
id="chatcmpl-abc123",
created=1,
model="mock-model",
choices=[
OpenAIChoice(
message=OpenAIAssistantMessageParam(
content="nothing interesting",
),
finish_reason="stop",
index=0,
)
],
)
async def do_inference():
await vllm_inference_adapter.openai_chat_completion(
"mock-model", messages=["one fish", "two fish"], stream=False
)
with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_create_client:
mock_client = MagicMock()
mock_client.chat.completions.create = AsyncMock(side_effect=mock_create)
mock_create_client.return_value = mock_client
start_time = time.time()
await asyncio.gather(do_inference(), do_inference(), do_inference(), do_inference())
total_time = time.time() - start_time
assert mock_create_client.call_count == 4 # no cheating
assert total_time < (sleep_time * 2), f"Total time taken: {total_time}s exceeded expected max"
async def test_should_refresh_models():
"""
Test the should_refresh_models method with different refresh_models configurations.
This test verifies that:
1. When refresh_models is True, should_refresh_models returns True regardless of api_token
2. When refresh_models is False, should_refresh_models returns False regardless of api_token
"""
# Test case 1: refresh_models is True, api_token is None
config1 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token=None, refresh_models=True)
adapter1 = VLLMInferenceAdapter(config1)
result1 = await adapter1.should_refresh_models()
assert result1 is True, "should_refresh_models should return True when refresh_models is True"
# Test case 2: refresh_models is True, api_token is empty string
config2 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="", refresh_models=True)
adapter2 = VLLMInferenceAdapter(config2)
result2 = await adapter2.should_refresh_models()
assert result2 is True, "should_refresh_models should return True when refresh_models is True"
# Test case 3: refresh_models is True, api_token is "fake" (default)
config3 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="fake", refresh_models=True)
adapter3 = VLLMInferenceAdapter(config3)
result3 = await adapter3.should_refresh_models()
assert result3 is True, "should_refresh_models should return True when refresh_models is True"
# Test case 4: refresh_models is True, api_token is real token
config4 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="real-token-123", refresh_models=True)
adapter4 = VLLMInferenceAdapter(config4)
result4 = await adapter4.should_refresh_models()
assert result4 is True, "should_refresh_models should return True when refresh_models is True"
# Test case 5: refresh_models is False, api_token is real token
config5 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="real-token-456", refresh_models=False)
adapter5 = VLLMInferenceAdapter(config5)
result5 = await adapter5.should_refresh_models()
assert result5 is False, "should_refresh_models should return False when refresh_models is False"
async def test_provider_data_var_context_propagation(vllm_inference_adapter):
"""
Test that PROVIDER_DATA_VAR context is properly propagated through the vLLM inference adapter.
This ensures that dynamic provider data (like API tokens) can be passed through context.
Note: The base URL is always taken from config.url, not from provider data.
"""
# Mock the AsyncOpenAI class to capture provider data
with (
patch("llama_stack.providers.utils.inference.openai_mixin.AsyncOpenAI") as mock_openai_class,
patch.object(vllm_inference_adapter, "get_request_provider_data") as mock_get_provider_data,
):
mock_client = AsyncMock()
mock_client.chat.completions.create = AsyncMock()
mock_openai_class.return_value = mock_client
# Mock provider data to return test data
mock_provider_data = MagicMock()
mock_provider_data.vllm_api_token = "test-token-123"
mock_provider_data.vllm_url = "http://test-server:8000/v1"
mock_get_provider_data.return_value = mock_provider_data
# Mock the model
mock_model = Model(identifier="test-model", provider_resource_id="test-model", provider_id="vllm-inference")
vllm_inference_adapter.model_store.get_model.return_value = mock_model
try:
# Execute chat completion
await vllm_inference_adapter.chat_completion(
"test-model",
[UserMessage(content="Hello")],
stream=False,
tools=None,
tool_config=ToolConfig(tool_choice=ToolChoice.auto),
)
# Verify that ALL client calls were made with the correct parameters
calls = mock_openai_class.call_args_list
incorrect_calls = []
for i, call in enumerate(calls):
api_key = call[1]["api_key"]
base_url = call[1]["base_url"]
if api_key != "test-token-123" or base_url != "http://mocked.localhost:12345":
incorrect_calls.append({"call_index": i, "api_key": api_key, "base_url": base_url})
if incorrect_calls:
error_msg = (
f"Found {len(incorrect_calls)} calls with incorrect parameters out of {len(calls)} total calls:\n"
)
for incorrect_call in incorrect_calls:
error_msg += f" Call {incorrect_call['call_index']}: api_key='{incorrect_call['api_key']}', base_url='{incorrect_call['base_url']}'\n"
error_msg += "Expected: api_key='test-token-123', base_url='http://mocked.localhost:12345'"
raise AssertionError(error_msg)
# Ensure at least one call was made
assert len(calls) >= 1, "No AsyncOpenAI client calls were made"
# Verify that chat completion was called
mock_client.chat.completions.create.assert_called_once()
finally:
# Clean up context
pass

View file

@ -52,14 +52,19 @@ class TestNVIDIAEvalImpl(unittest.TestCase):
self.evaluator_post_patcher = patch(
"llama_stack.providers.remote.eval.nvidia.eval.NVIDIAEvalImpl._evaluator_post"
)
self.evaluator_delete_patcher = patch(
"llama_stack.providers.remote.eval.nvidia.eval.NVIDIAEvalImpl._evaluator_delete"
)
self.mock_evaluator_get = self.evaluator_get_patcher.start()
self.mock_evaluator_post = self.evaluator_post_patcher.start()
self.mock_evaluator_delete = self.evaluator_delete_patcher.start()
def tearDown(self):
"""Clean up after each test."""
self.evaluator_get_patcher.stop()
self.evaluator_post_patcher.stop()
self.evaluator_delete_patcher.stop()
def _assert_request_body(self, expected_json):
"""Helper method to verify request body in Evaluator POST request is correct"""
@ -115,6 +120,13 @@ class TestNVIDIAEvalImpl(unittest.TestCase):
self.mock_evaluator_post.assert_called_once()
self._assert_request_body({"namespace": benchmark.provider_id, "name": benchmark.identifier, **eval_config})
def test_unregister_benchmark(self):
# Unregister the benchmark
self.run_async(self.eval_impl.unregister_benchmark(benchmark_id=MOCK_BENCHMARK_ID))
# Verify the Evaluator API was called correctly
self.mock_evaluator_delete.assert_called_once_with(f"/v1/evaluation/configs/nvidia/{MOCK_BENCHMARK_ID}")
def test_run_eval(self):
benchmark_config = BenchmarkConfig(
eval_candidate=ModelCandidate(
@ -138,7 +150,7 @@ class TestNVIDIAEvalImpl(unittest.TestCase):
self._assert_request_body(
{
"config": f"nvidia/{MOCK_BENCHMARK_ID}",
"target": {"type": "model", "model": "meta/llama-3.1-8b-instruct"},
"target": {"type": "model", "model": "Llama3.1-8B-Instruct"},
}
)

View file

@ -0,0 +1,53 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.remote.inference.bedrock.bedrock import (
_get_region_prefix,
_to_inference_profile_id,
)
def test_region_prefixes():
assert _get_region_prefix("us-east-1") == "us."
assert _get_region_prefix("eu-west-1") == "eu."
assert _get_region_prefix("ap-south-1") == "ap."
assert _get_region_prefix("ca-central-1") == "us."
# Test case insensitive
assert _get_region_prefix("US-EAST-1") == "us."
assert _get_region_prefix("EU-WEST-1") == "eu."
assert _get_region_prefix("Ap-South-1") == "ap."
# Test None region
assert _get_region_prefix(None) == "us."
def test_model_id_conversion():
# Basic conversion
assert (
_to_inference_profile_id("meta.llama3-1-70b-instruct-v1:0", "us-east-1") == "us.meta.llama3-1-70b-instruct-v1:0"
)
# Already has prefix
assert (
_to_inference_profile_id("us.meta.llama3-1-70b-instruct-v1:0", "us-east-1")
== "us.meta.llama3-1-70b-instruct-v1:0"
)
# ARN should be returned unchanged
arn = "arn:aws:bedrock:us-east-1:123456789012:inference-profile/us.meta.llama3-1-70b-instruct-v1:0"
assert _to_inference_profile_id(arn, "us-east-1") == arn
# ARN should be returned unchanged even without region
assert _to_inference_profile_id(arn) == arn
# Optional region parameter defaults to us-east-1
assert _to_inference_profile_id("meta.llama3-1-70b-instruct-v1:0") == "us.meta.llama3-1-70b-instruct-v1:0"
# Different regions work with optional parameter
assert (
_to_inference_profile_id("meta.llama3-1-70b-instruct-v1:0", "eu-west-1") == "eu.meta.llama3-1-70b-instruct-v1:0"
)

View file

@ -0,0 +1,368 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch
import pytest
from llama_stack.apis.inference import Model, OpenAIUserMessageParam
from llama_stack.apis.models import ModelType
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
class OpenAIMixinImpl(OpenAIMixin):
def __init__(self):
self.__provider_id__ = "test-provider"
def get_api_key(self) -> str:
raise NotImplementedError("This method should be mocked in tests")
def get_base_url(self) -> str:
raise NotImplementedError("This method should be mocked in tests")
class OpenAIMixinWithEmbeddingsImpl(OpenAIMixin):
"""Test implementation with embedding model metadata"""
embedding_model_metadata = {
"text-embedding-3-small": {"embedding_dimension": 1536, "context_length": 8192},
"text-embedding-ada-002": {"embedding_dimension": 1536, "context_length": 8192},
}
__provider_id__ = "test-provider"
def get_api_key(self) -> str:
raise NotImplementedError("This method should be mocked in tests")
def get_base_url(self) -> str:
raise NotImplementedError("This method should be mocked in tests")
@pytest.fixture
def mixin():
"""Create a test instance of OpenAIMixin with mocked model_store"""
mixin_instance = OpenAIMixinImpl()
# just enough to satisfy _get_provider_model_id calls
mock_model_store = MagicMock()
mock_model = MagicMock()
mock_model.provider_resource_id = "test-provider-resource-id"
mock_model_store.get_model = AsyncMock(return_value=mock_model)
mixin_instance.model_store = mock_model_store
return mixin_instance
@pytest.fixture
def mixin_with_embeddings():
"""Create a test instance of OpenAIMixin with embedding model metadata"""
return OpenAIMixinWithEmbeddingsImpl()
@pytest.fixture
def mock_models():
"""Create multiple mock OpenAI model objects"""
models = [MagicMock(id=id) for id in ["some-mock-model-id", "another-mock-model-id", "final-mock-model-id"]]
return models
@pytest.fixture
def mock_client_with_models(mock_models):
"""Create a mock client with models.list() set up to return mock_models"""
mock_client = MagicMock()
async def mock_models_list():
for model in mock_models:
yield model
mock_client.models.list.return_value = mock_models_list()
return mock_client
@pytest.fixture
def mock_client_with_empty_models():
"""Create a mock client with models.list() set up to return empty list"""
mock_client = MagicMock()
async def mock_empty_models_list():
return
yield # Make it an async generator but don't yield anything
mock_client.models.list.return_value = mock_empty_models_list()
return mock_client
@pytest.fixture
def mock_client_with_exception():
"""Create a mock client with models.list() set up to raise an exception"""
mock_client = MagicMock()
mock_client.models.list.side_effect = Exception("API Error")
return mock_client
@pytest.fixture
def mock_client_context():
"""Fixture that provides a context manager for mocking the OpenAI client"""
def _mock_client_context(mixin, mock_client):
return patch.object(type(mixin), "client", new_callable=PropertyMock, return_value=mock_client)
return _mock_client_context
class TestOpenAIMixinListModels:
"""Test cases for the list_models method"""
async def test_list_models_success(self, mixin, mock_client_with_models, mock_client_context):
"""Test successful model listing"""
assert len(mixin._model_cache) == 0
with mock_client_context(mixin, mock_client_with_models):
result = await mixin.list_models()
assert result is not None
assert len(result) == 3
model_ids = [model.identifier for model in result]
assert "some-mock-model-id" in model_ids
assert "another-mock-model-id" in model_ids
assert "final-mock-model-id" in model_ids
for model in result:
assert model.provider_id == "test-provider"
assert model.model_type == ModelType.llm
assert model.provider_resource_id == model.identifier
assert len(mixin._model_cache) == 3
for model_id in ["some-mock-model-id", "another-mock-model-id", "final-mock-model-id"]:
assert model_id in mixin._model_cache
cached_model = mixin._model_cache[model_id]
assert cached_model.identifier == model_id
assert cached_model.provider_resource_id == model_id
async def test_list_models_empty_response(self, mixin, mock_client_with_empty_models, mock_client_context):
"""Test handling of empty model list"""
with mock_client_context(mixin, mock_client_with_empty_models):
result = await mixin.list_models()
assert result is not None
assert len(result) == 0
assert len(mixin._model_cache) == 0
class TestOpenAIMixinCheckModelAvailability:
"""Test cases for the check_model_availability method"""
async def test_check_model_availability_with_cache(self, mixin, mock_client_with_models, mock_client_context):
"""Test model availability check when cache is populated"""
with mock_client_context(mixin, mock_client_with_models):
mock_client_with_models.models.list.assert_not_called()
await mixin.list_models()
mock_client_with_models.models.list.assert_called_once()
assert await mixin.check_model_availability("some-mock-model-id")
assert await mixin.check_model_availability("another-mock-model-id")
assert await mixin.check_model_availability("final-mock-model-id")
assert not await mixin.check_model_availability("non-existent-model")
mock_client_with_models.models.list.assert_called_once()
async def test_check_model_availability_without_cache(self, mixin, mock_client_with_models, mock_client_context):
"""Test model availability check when cache is empty (calls list_models)"""
assert len(mixin._model_cache) == 0
with mock_client_context(mixin, mock_client_with_models):
mock_client_with_models.models.list.assert_not_called()
assert await mixin.check_model_availability("some-mock-model-id")
mock_client_with_models.models.list.assert_called_once()
assert len(mixin._model_cache) == 3
assert "some-mock-model-id" in mixin._model_cache
async def test_check_model_availability_model_not_found(self, mixin, mock_client_with_models, mock_client_context):
"""Test model availability check for non-existent model"""
with mock_client_context(mixin, mock_client_with_models):
mock_client_with_models.models.list.assert_not_called()
assert not await mixin.check_model_availability("non-existent-model")
mock_client_with_models.models.list.assert_called_once()
assert len(mixin._model_cache) == 3
class TestOpenAIMixinCacheBehavior:
"""Test cases for cache behavior and edge cases"""
async def test_cache_overwrites_on_list_models_call(self, mixin, mock_client_with_models, mock_client_context):
"""Test that calling list_models overwrites existing cache"""
initial_model = Model(
provider_id="test-provider",
provider_resource_id="old-model",
identifier="old-model",
model_type=ModelType.llm,
)
mixin._model_cache = {"old-model": initial_model}
with mock_client_context(mixin, mock_client_with_models):
await mixin.list_models()
assert len(mixin._model_cache) == 3
assert "old-model" not in mixin._model_cache
assert "some-mock-model-id" in mixin._model_cache
assert "another-mock-model-id" in mixin._model_cache
assert "final-mock-model-id" in mixin._model_cache
class TestOpenAIMixinImagePreprocessing:
"""Test cases for image preprocessing functionality"""
async def test_openai_chat_completion_with_image_preprocessing_enabled(self, mixin):
"""Test that image URLs are converted to base64 when download_images is True"""
mixin.download_images = True
message = OpenAIUserMessageParam(
role="user",
content=[
{"type": "text", "text": "What's in this image?"},
{"type": "image_url", "image_url": {"url": "http://example.com/image.jpg"}},
],
)
mock_client = MagicMock()
mock_response = MagicMock()
mock_client.chat.completions.create = AsyncMock(return_value=mock_response)
with patch.object(type(mixin), "client", new_callable=PropertyMock, return_value=mock_client):
with patch("llama_stack.providers.utils.inference.openai_mixin.localize_image_content") as mock_localize:
mock_localize.return_value = (b"fake_image_data", "jpeg")
await mixin.openai_chat_completion(model="test-model", messages=[message])
mock_localize.assert_called_once_with("http://example.com/image.jpg")
mock_client.chat.completions.create.assert_called_once()
call_args = mock_client.chat.completions.create.call_args
processed_messages = call_args[1]["messages"]
assert len(processed_messages) == 1
content = processed_messages[0]["content"]
assert len(content) == 2
assert content[0]["type"] == "text"
assert content[1]["type"] == "image_url"
assert content[1]["image_url"]["url"] == ""
async def test_openai_chat_completion_with_image_preprocessing_disabled(self, mixin):
"""Test that image URLs are not modified when download_images is False"""
mixin.download_images = False # explicitly set to False
message = OpenAIUserMessageParam(
role="user",
content=[
{"type": "text", "text": "What's in this image?"},
{"type": "image_url", "image_url": {"url": "http://example.com/image.jpg"}},
],
)
mock_client = MagicMock()
mock_response = MagicMock()
mock_client.chat.completions.create = AsyncMock(return_value=mock_response)
with patch.object(type(mixin), "client", new_callable=PropertyMock, return_value=mock_client):
with patch("llama_stack.providers.utils.inference.openai_mixin.localize_image_content") as mock_localize:
await mixin.openai_chat_completion(model="test-model", messages=[message])
mock_localize.assert_not_called()
mock_client.chat.completions.create.assert_called_once()
call_args = mock_client.chat.completions.create.call_args
processed_messages = call_args[1]["messages"]
assert len(processed_messages) == 1
content = processed_messages[0]["content"]
assert len(content) == 2
assert content[1]["image_url"]["url"] == "http://example.com/image.jpg"
class TestOpenAIMixinEmbeddingModelMetadata:
"""Test cases for embedding_model_metadata attribute functionality"""
async def test_embedding_model_identified_and_augmented(self, mixin_with_embeddings, mock_client_context):
"""Test that models in embedding_model_metadata are correctly identified as embeddings with metadata"""
# Create mock models: 1 embedding model and 1 LLM, while there are 2 known embedding models
mock_embedding_model = MagicMock(id="text-embedding-3-small")
mock_llm_model = MagicMock(id="gpt-4")
mock_models = [mock_embedding_model, mock_llm_model]
mock_client = MagicMock()
async def mock_models_list():
for model in mock_models:
yield model
mock_client.models.list.return_value = mock_models_list()
with mock_client_context(mixin_with_embeddings, mock_client):
result = await mixin_with_embeddings.list_models()
assert result is not None
assert len(result) == 2
# Find the models in the result
embedding_model = next(m for m in result if m.identifier == "text-embedding-3-small")
llm_model = next(m for m in result if m.identifier == "gpt-4")
# Check embedding model
assert embedding_model.model_type == ModelType.embedding
assert embedding_model.metadata == {"embedding_dimension": 1536, "context_length": 8192}
assert embedding_model.provider_id == "test-provider"
assert embedding_model.provider_resource_id == "text-embedding-3-small"
# Check LLM model
assert llm_model.model_type == ModelType.llm
assert llm_model.metadata == {} # No metadata for LLMs
assert llm_model.provider_id == "test-provider"
assert llm_model.provider_resource_id == "gpt-4"
class TestOpenAIMixinAllowedModels:
"""Test cases for allowed_models filtering functionality"""
async def test_list_models_with_allowed_models_filter(self, mixin, mock_client_with_models, mock_client_context):
"""Test that list_models filters models based on allowed_models set"""
mixin.allowed_models = {"some-mock-model-id", "another-mock-model-id"}
with mock_client_context(mixin, mock_client_with_models):
result = await mixin.list_models()
assert result is not None
assert len(result) == 2
model_ids = [model.identifier for model in result]
assert "some-mock-model-id" in model_ids
assert "another-mock-model-id" in model_ids
assert "final-mock-model-id" not in model_ids
async def test_list_models_with_empty_allowed_models(self, mixin, mock_client_with_models, mock_client_context):
"""Test that empty allowed_models set allows all models"""
assert len(mixin.allowed_models) == 0
with mock_client_context(mixin, mock_client_with_models):
result = await mixin.list_models()
assert result is not None
assert len(result) == 3 # All models should be included
model_ids = [model.identifier for model in result]
assert "some-mock-model-id" in model_ids
assert "another-mock-model-id" in model_ids
assert "final-mock-model-id" in model_ids
async def test_check_model_availability_with_allowed_models(
self, mixin, mock_client_with_models, mock_client_context
):
"""Test that check_model_availability respects allowed_models"""
mixin.allowed_models = {"final-mock-model-id"}
with mock_client_context(mixin, mock_client_with_models):
assert await mixin.check_model_availability("final-mock-model-id")
assert not await mixin.check_model_availability("some-mock-model-id")
assert not await mixin.check_model_availability("another-mock-model-id")

View file

@ -178,3 +178,41 @@ def test_content_from_data_and_mime_type_both_encodings_fail():
# Should raise an exception instead of returning empty string
with pytest.raises(UnicodeDecodeError):
content_from_data_and_mime_type(data, mime_type)
async def test_memory_tool_error_handling():
"""Test that memory tool handles various failures gracefully without crashing."""
from llama_stack.providers.inline.tool_runtime.rag.config import RagToolRuntimeConfig
from llama_stack.providers.inline.tool_runtime.rag.memory import MemoryToolRuntimeImpl
config = RagToolRuntimeConfig()
memory_tool = MemoryToolRuntimeImpl(
config=config,
vector_io_api=AsyncMock(),
inference_api=AsyncMock(),
files_api=AsyncMock(),
)
docs = [
RAGDocument(document_id="good_doc", content="Good content", metadata={}),
RAGDocument(document_id="bad_url_doc", content=URL(uri="https://bad.url"), metadata={}),
RAGDocument(document_id="another_good_doc", content="Another good content", metadata={}),
]
mock_file1 = MagicMock()
mock_file1.id = "file_good1"
mock_file2 = MagicMock()
mock_file2.id = "file_good2"
memory_tool.files_api.openai_upload_file.side_effect = [mock_file1, mock_file2]
with patch("httpx.AsyncClient") as mock_client:
mock_instance = AsyncMock()
mock_instance.get.side_effect = Exception("Bad URL")
mock_client.return_value.__aenter__.return_value = mock_instance
# won't raise exception despite one document failing
await memory_tool.insert(docs, "vector_store_123")
# processed 2 documents successfully, skipped 1
assert memory_tool.files_api.openai_upload_file.call_count == 2
assert memory_tool.vector_io_api.openai_attach_file_to_vector_store.call_count == 2

View file

@ -84,14 +84,14 @@ def unknown_model() -> Model:
@pytest.fixture
def helper(known_provider_model: ProviderModelEntry, known_provider_model2: ProviderModelEntry) -> ModelRegistryHelper:
return ModelRegistryHelper([known_provider_model, known_provider_model2])
return ModelRegistryHelper(model_entries=[known_provider_model, known_provider_model2])
class MockModelRegistryHelperWithDynamicModels(ModelRegistryHelper):
"""Test helper that simulates a provider with dynamically available models."""
def __init__(self, model_entries: list[ProviderModelEntry], available_models: list[str]):
super().__init__(model_entries)
super().__init__(model_entries=model_entries)
self._available_models = available_models
async def check_model_availability(self, model: str) -> bool:

View file

@ -54,7 +54,9 @@ def mock_vector_db(vector_db_id) -> MagicMock:
mock_vector_db.identifier = vector_db_id
mock_vector_db.embedding_dimension = 384
mock_vector_db.model_dump_json.return_value = (
'{"identifier": "' + vector_db_id + '", "embedding_model": "embedding_model", "embedding_dimension": 384}'
'{"identifier": "'
+ vector_db_id
+ '", "provider_id": "qdrant", "embedding_model": "embedding_model", "embedding_dimension": 384}'
)
return mock_vector_db

View file

@ -26,9 +26,9 @@ def test_generate_chunk_id():
chunk_ids = sorted([chunk.chunk_id for chunk in chunks])
assert chunk_ids == [
"177a1368-f6a8-0c50-6e92-18677f2c3de3",
"bc744db3-1b25-0a9c-cdff-b6ba3df73c36",
"f68df25d-d9aa-ab4d-5684-64a233add20d",
"31d1f9a3-c8d2-66e7-3c37-af2acd329778",
"d07dade7-29c0-cda7-df29-0249a1dcbc3e",
"d14f75a1-5855-7f72-2c78-d9fc4275a346",
]
@ -36,14 +36,14 @@ def test_generate_chunk_id_with_window():
chunk = Chunk(content="test", metadata={"document_id": "doc-1"})
chunk_id1 = generate_chunk_id("doc-1", chunk, chunk_window="0-1")
chunk_id2 = generate_chunk_id("doc-1", chunk, chunk_window="1-2")
assert chunk_id1 == "149018fe-d0eb-0f8d-5f7f-726bdd2aeedb"
assert chunk_id2 == "4562c1ee-9971-1f3b-51a6-7d05e5211154"
assert chunk_id1 == "8630321a-d9cb-2bb6-cd28-ebf68dafd866"
assert chunk_id2 == "13a1c09a-cbda-b61a-2d1a-7baa90888685"
def test_chunk_id():
# Test with existing chunk ID
chunk_with_id = Chunk(content="test", metadata={"document_id": "existing-id"})
assert chunk_with_id.chunk_id == "84ededcc-b80b-a83e-1a20-ca6515a11350"
assert chunk_with_id.chunk_id == "11704f92-42b6-61df-bf85-6473e7708fbd"
# Test with document ID in metadata
chunk_with_doc_id = Chunk(content="test", metadata={"document_id": "doc-1"})

View file

@ -19,12 +19,16 @@ from llama_stack.providers.inline.tool_runtime.rag.memory import MemoryToolRunti
class TestRagQuery:
async def test_query_raises_on_empty_vector_db_ids(self):
rag_tool = MemoryToolRuntimeImpl(config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock())
rag_tool = MemoryToolRuntimeImpl(
config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock(), files_api=MagicMock()
)
with pytest.raises(ValueError):
await rag_tool.query(content=MagicMock(), vector_db_ids=[])
async def test_query_chunk_metadata_handling(self):
rag_tool = MemoryToolRuntimeImpl(config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock())
rag_tool = MemoryToolRuntimeImpl(
config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock(), files_api=MagicMock()
)
content = "test query content"
vector_db_ids = ["db1"]
@ -77,3 +81,58 @@ class TestRagQuery:
# Test that invalid mode raises an error
with pytest.raises(ValueError):
RAGQueryConfig(mode="wrong_mode")
async def test_query_adds_vector_db_id_to_chunk_metadata(self):
rag_tool = MemoryToolRuntimeImpl(
config=MagicMock(),
vector_io_api=MagicMock(),
inference_api=MagicMock(),
files_api=MagicMock(),
)
vector_db_ids = ["db1", "db2"]
# Fake chunks from each DB
chunk_metadata1 = ChunkMetadata(
document_id="doc1",
chunk_id="chunk1",
source="test_source1",
metadata_token_count=5,
)
chunk1 = Chunk(
content="chunk from db1",
metadata={"vector_db_id": "db1", "document_id": "doc1"},
stored_chunk_id="c1",
chunk_metadata=chunk_metadata1,
)
chunk_metadata2 = ChunkMetadata(
document_id="doc2",
chunk_id="chunk2",
source="test_source2",
metadata_token_count=5,
)
chunk2 = Chunk(
content="chunk from db2",
metadata={"vector_db_id": "db2", "document_id": "doc2"},
stored_chunk_id="c2",
chunk_metadata=chunk_metadata2,
)
rag_tool.vector_io_api.query_chunks = AsyncMock(
side_effect=[
QueryChunksResponse(chunks=[chunk1], scores=[0.9]),
QueryChunksResponse(chunks=[chunk2], scores=[0.8]),
]
)
result = await rag_tool.query(content="test", vector_db_ids=vector_db_ids)
returned_chunks = result.metadata["chunks"]
returned_scores = result.metadata["scores"]
returned_doc_ids = result.metadata["document_ids"]
returned_vector_db_ids = result.metadata["vector_db_ids"]
assert returned_chunks == ["chunk from db1", "chunk from db2"]
assert returned_scores == (0.9, 0.8)
assert returned_doc_ids == ["doc1", "doc2"]
assert returned_vector_db_ids == ["db1", "db2"]

View file

@ -774,3 +774,136 @@ def test_has_required_scope_function():
# Test no user (auth disabled)
assert _has_required_scope("test.read", None)
@pytest.fixture
def mock_kubernetes_api_server():
return "https://api.cluster.example.com:6443"
@pytest.fixture
def kubernetes_auth_app(mock_kubernetes_api_server):
app = FastAPI()
auth_config = AuthenticationConfig(
provider_config={
"type": "kubernetes",
"api_server_url": mock_kubernetes_api_server,
"verify_tls": False,
"claims_mapping": {
"username": "roles",
"groups": "roles",
"uid": "uid_attr",
},
},
)
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config, impls={})
@app.get("/test")
def test_endpoint():
return {"message": "Authentication successful"}
return app
@pytest.fixture
def kubernetes_auth_client(kubernetes_auth_app):
return TestClient(kubernetes_auth_app)
def test_missing_auth_header_kubernetes_auth(kubernetes_auth_client):
response = kubernetes_auth_client.get("/test")
assert response.status_code == 401
assert "Authentication required" in response.json()["error"]["message"]
def test_invalid_auth_header_format_kubernetes_auth(kubernetes_auth_client):
response = kubernetes_auth_client.get("/test", headers={"Authorization": "InvalidFormat token123"})
assert response.status_code == 401
assert "Invalid Authorization header format" in response.json()["error"]["message"]
async def mock_kubernetes_selfsubjectreview_success(*args, **kwargs):
return MockResponse(
201,
{
"apiVersion": "authentication.k8s.io/v1",
"kind": "SelfSubjectReview",
"metadata": {"creationTimestamp": "2025-07-15T13:53:56Z"},
"status": {
"userInfo": {
"username": "alice",
"uid": "alice-uid-123",
"groups": ["system:authenticated", "developers", "admins"],
"extra": {"scopes.authorization.openshift.io": ["user:full"]},
}
},
},
)
async def mock_kubernetes_selfsubjectreview_failure(*args, **kwargs):
return MockResponse(401, {"message": "Unauthorized"})
async def mock_kubernetes_selfsubjectreview_http_error(*args, **kwargs):
return MockResponse(500, {"message": "Internal Server Error"})
@patch("httpx.AsyncClient.post", new=mock_kubernetes_selfsubjectreview_success)
def test_valid_kubernetes_auth_authentication(kubernetes_auth_client, valid_token):
response = kubernetes_auth_client.get("/test", headers={"Authorization": f"Bearer {valid_token}"})
assert response.status_code == 200
assert response.json() == {"message": "Authentication successful"}
@patch("httpx.AsyncClient.post", new=mock_kubernetes_selfsubjectreview_failure)
def test_invalid_kubernetes_auth_authentication(kubernetes_auth_client, invalid_token):
response = kubernetes_auth_client.get("/test", headers={"Authorization": f"Bearer {invalid_token}"})
assert response.status_code == 401
assert "Invalid token" in response.json()["error"]["message"]
@patch("httpx.AsyncClient.post", new=mock_kubernetes_selfsubjectreview_http_error)
def test_kubernetes_auth_http_error(kubernetes_auth_client, valid_token):
response = kubernetes_auth_client.get("/test", headers={"Authorization": f"Bearer {valid_token}"})
assert response.status_code == 401
assert "Token validation failed" in response.json()["error"]["message"]
def test_kubernetes_auth_request_payload(kubernetes_auth_client, valid_token, mock_kubernetes_api_server):
with patch("httpx.AsyncClient.post") as mock_post:
mock_response = MockResponse(
200,
{
"apiVersion": "authentication.k8s.io/v1",
"kind": "SelfSubjectReview",
"metadata": {"creationTimestamp": "2025-07-15T13:53:56Z"},
"status": {
"userInfo": {
"username": "test-user",
"uid": "test-uid",
"groups": ["test-group"],
}
},
},
)
mock_post.return_value = mock_response
kubernetes_auth_client.get("/test", headers={"Authorization": f"Bearer {valid_token}"})
# Verify the request was made with correct parameters
mock_post.assert_called_once()
call_args = mock_post.call_args
# Check URL (passed as positional argument)
assert call_args[0][0] == f"{mock_kubernetes_api_server}/apis/authentication.k8s.io/v1/selfsubjectreviews"
# Check headers (passed as keyword argument)
headers = call_args[1]["headers"]
assert headers["Authorization"] == f"Bearer {valid_token}"
assert headers["Content-Type"] == "application/json"
# Check request body (passed as keyword argument)
request_body = call_args[1]["json"]
assert request_body["apiVersion"] == "authentication.k8s.io/v1"
assert request_body["kind"] == "SelfSubjectReview"

View file

@ -113,6 +113,15 @@ class TestTranslateException:
assert result.status_code == 504
assert result.detail == "Operation timed out: "
def test_translate_connection_error(self):
"""Test that ConnectionError is translated to 502 HTTP status."""
exc = ConnectionError("Failed to connect to MCP server at http://localhost:9999/sse: Connection refused")
result = translate_exception(exc)
assert isinstance(result, HTTPException)
assert result.status_code == 502
assert result.detail == "Failed to connect to MCP server at http://localhost:9999/sse: Connection refused"
def test_translate_not_implemented_error(self):
"""Test that NotImplementedError is translated to 501 HTTP status."""
exc = NotImplementedError("Not implemented")

View file

@ -65,6 +65,9 @@ async def test_inference_store_pagination_basic():
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
await store.store_chat_completion(completion, input_messages)
# Wait for all queued writes to complete
await store.flush()
# Test 1: First page with limit=2, descending order (default)
result = await store.list_chat_completions(limit=2, order=Order.desc)
assert len(result.data) == 2
@ -108,6 +111,9 @@ async def test_inference_store_pagination_ascending():
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
await store.store_chat_completion(completion, input_messages)
# Wait for all queued writes to complete
await store.flush()
# Test ascending order pagination
result = await store.list_chat_completions(limit=1, order=Order.asc)
assert len(result.data) == 1
@ -143,6 +149,9 @@ async def test_inference_store_pagination_with_model_filter():
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
await store.store_chat_completion(completion, input_messages)
# Wait for all queued writes to complete
await store.flush()
# Test pagination with model filter
result = await store.list_chat_completions(limit=1, model="model-a", order=Order.desc)
assert len(result.data) == 1
@ -190,6 +199,9 @@ async def test_inference_store_pagination_no_limit():
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
await store.store_chat_completion(completion, input_messages)
# Wait for all queued writes to complete
await store.flush()
# Test without limit
result = await store.list_chat_completions(order=Order.desc)
assert len(result.data) == 2

View file

@ -26,7 +26,7 @@ async def test_authorized_fetch_with_where_sql_access_control(mock_get_authentic
db_path=tmp_dir + "/" + db_name,
)
)
sqlstore = AuthorizedSqlStore(base_sqlstore)
sqlstore = AuthorizedSqlStore(base_sqlstore, default_policy())
# Create table with access control
await sqlstore.create_table(
@ -56,24 +56,24 @@ async def test_authorized_fetch_with_where_sql_access_control(mock_get_authentic
mock_get_authenticated_user.return_value = admin_user
# Admin should see both documents
result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 1})
result = await sqlstore.fetch_all("documents", where={"id": 1})
assert len(result.data) == 1
assert result.data[0]["title"] == "Admin Document"
# User should only see their document
mock_get_authenticated_user.return_value = regular_user
result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 1})
result = await sqlstore.fetch_all("documents", where={"id": 1})
assert len(result.data) == 0
result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 2})
result = await sqlstore.fetch_all("documents", where={"id": 2})
assert len(result.data) == 1
assert result.data[0]["title"] == "User Document"
row = await sqlstore.fetch_one("documents", policy=default_policy(), where={"id": 1})
row = await sqlstore.fetch_one("documents", where={"id": 1})
assert row is None
row = await sqlstore.fetch_one("documents", policy=default_policy(), where={"id": 2})
row = await sqlstore.fetch_one("documents", where={"id": 2})
assert row is not None
assert row["title"] == "User Document"
@ -88,7 +88,7 @@ async def test_sql_policy_consistency(mock_get_authenticated_user):
db_path=tmp_dir + "/" + db_name,
)
)
sqlstore = AuthorizedSqlStore(base_sqlstore)
sqlstore = AuthorizedSqlStore(base_sqlstore, default_policy())
await sqlstore.create_table(
table="resources",
@ -144,7 +144,7 @@ async def test_sql_policy_consistency(mock_get_authenticated_user):
user = User(principal=user_data["principal"], attributes=user_data["attributes"])
mock_get_authenticated_user.return_value = user
sql_results = await sqlstore.fetch_all("resources", policy=policy)
sql_results = await sqlstore.fetch_all("resources")
sql_ids = {row["id"] for row in sql_results.data}
policy_ids = set()
for scenario in test_scenarios:
@ -174,7 +174,7 @@ async def test_authorized_store_user_attribute_capture(mock_get_authenticated_us
db_path=tmp_dir + "/" + db_name,
)
)
authorized_store = AuthorizedSqlStore(base_sqlstore)
authorized_store = AuthorizedSqlStore(base_sqlstore, default_policy())
await authorized_store.create_table(
table="user_data",