Merge upstream/main into add-mongodb-vector_io

Resolved conflicts:
- Integrated MongoDB provider with newly added Qdrant and Weaviate providers
- Updated distribution configs to include all three providers
- Merged build.yaml and run.yaml configs for ci-tests, starter, and starter-gpu distributions
- Updated starter.py to include MongoDB, Qdrant, and Weaviate provider initialization
- Added MongoDB provider files to src/ directory structure
- Updated MongoDB provider to use new VectorStore API (was VectorDB)
- Updated MongoDB config to use KVStoreReference instead of KVStoreConfig
- Applied auto-formatting changes from pre-commit hooks
This commit is contained in:
Young Han 2025-10-28 16:11:03 -07:00
commit efe9c04849
1820 changed files with 402590 additions and 32499 deletions

View file

@ -23,6 +23,30 @@ def config_with_image_name_int():
image_name: 1234
apis_to_serve: []
built_at: {datetime.now().isoformat()}
storage:
backends:
kv_default:
type: kv_sqlite
db_path: /tmp/test_kv.db
sql_default:
type: sql_sqlite
db_path: /tmp/test_sql.db
stores:
metadata:
backend: kv_default
namespace: metadata
inference:
backend: sql_default
table_name: inference
conversations:
backend: sql_default
table_name: conversations
responses:
backend: sql_default
table_name: responses
prompts:
backend: kv_default
namespace: prompts
providers:
inference:
- provider_id: provider1
@ -54,6 +78,27 @@ def up_to_date_config():
image_name: foo
apis_to_serve: []
built_at: {datetime.now().isoformat()}
storage:
backends:
kv_default:
type: kv_sqlite
db_path: /tmp/test_kv.db
sql_default:
type: sql_sqlite
db_path: /tmp/test_sql.db
stores:
metadata:
backend: kv_default
namespace: metadata
inference:
backend: sql_default
table_name: inference
conversations:
backend: sql_default
table_name: conversations
responses:
backend: sql_default
table_name: responses
providers:
inference:
- provider_id: provider1

View file

@ -4,17 +4,26 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest_socket
import logging # allow-direct-logging
import os
import warnings
# We need to import the fixtures here so that pytest can find them
# but ruff doesn't think they are used and removes the import. "noqa: F401" prevents them from being removed
from .fixtures import cached_disk_dist_registry, disk_dist_registry, sqlite_kvstore # noqa: F401
import pytest
def pytest_runtest_setup(item):
"""Setup for each test - check if network access should be allowed."""
if "allow_network" in item.keywords:
pytest_socket.enable_socket()
else:
# Allowing Unix sockets is necessary for some tests that use local servers and mocks
pytest_socket.disable_socket(allow_unix_socket=True)
def pytest_sessionstart(session) -> None:
if "LLAMA_STACK_LOGGING" not in os.environ:
os.environ["LLAMA_STACK_LOGGING"] = "all=WARNING"
# Silence common deprecation spam during unit tests.
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=PendingDeprecationWarning)
@pytest.fixture(autouse=True)
def suppress_httpx_logs(caplog):
"""Suppress httpx INFO logs for all unit tests"""
caplog.set_level(logging.WARNING, logger="httpx")
pytest_plugins = ["tests.unit.fixtures"]

View file

@ -20,7 +20,14 @@ from llama_stack.core.conversations.conversations import (
ConversationServiceConfig,
ConversationServiceImpl,
)
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
from llama_stack.core.datatypes import StackRunConfig
from llama_stack.core.storage.datatypes import (
ServerStoresConfig,
SqliteSqlStoreConfig,
SqlStoreReference,
StorageConfig,
)
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
@pytest.fixture
@ -28,7 +35,18 @@ async def service():
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test_conversations.db"
config = ConversationServiceConfig(conversations_store=SqliteSqlStoreConfig(db_path=str(db_path)), policy=[])
storage = StorageConfig(
backends={
"sql_test": SqliteSqlStoreConfig(db_path=str(db_path)),
},
stores=ServerStoresConfig(
conversations=SqlStoreReference(backend="sql_test", table_name="openai_conversations"),
),
)
register_sqlstore_backends({"sql_test": storage.backends["sql_test"]})
run_config = StackRunConfig(image_name="test", apis=[], providers={}, storage=storage)
config = ConversationServiceConfig(run_config=run_config, policy=[])
service = ConversationServiceImpl(config, {})
await service.initialize()
yield service
@ -64,7 +82,7 @@ async def test_conversation_items(service):
assert len(item_list.data) == 1
assert item_list.data[0].id == "msg_test123"
items = await service.list(conversation.id)
items = await service.list_items(conversation.id)
assert len(items.data) == 1
@ -102,7 +120,7 @@ async def test_openai_type_compatibility(service):
assert hasattr(item_list, attr)
assert item_list.object == "list"
items = await service.list(conversation.id)
items = await service.list_items(conversation.id)
item = await service.retrieve(conversation.id, items.data[0].id)
item_dict = item.model_dump()
@ -121,9 +139,18 @@ async def test_policy_configuration():
AccessRule(forbid=Scope(principal="test_user", actions=[Action.CREATE, Action.READ], resource="*"))
]
config = ConversationServiceConfig(
conversations_store=SqliteSqlStoreConfig(db_path=str(db_path)), policy=restrictive_policy
storage = StorageConfig(
backends={
"sql_test": SqliteSqlStoreConfig(db_path=str(db_path)),
},
stores=ServerStoresConfig(
conversations=SqlStoreReference(backend="sql_test", table_name="openai_conversations"),
),
)
register_sqlstore_backends({"sql_test": storage.backends["sql_test"]})
run_config = StackRunConfig(image_name="test", apis=[], providers={}, storage=storage)
config = ConversationServiceConfig(run_config=run_config, policy=restrictive_policy)
service = ConversationServiceImpl(config, {})
await service.initialize()

View file

@ -0,0 +1,43 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import AsyncMock
from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults
from llama_stack.apis.shields import ListShieldsResponse, Shield
from llama_stack.core.datatypes import SafetyConfig
from llama_stack.core.routers.safety import SafetyRouter
async def test_run_moderation_uses_default_shield_when_model_missing():
routing_table = AsyncMock()
shield = Shield(
identifier="shield-1",
provider_resource_id="provider/shield-model",
provider_id="provider-id",
params={},
)
routing_table.list_shields.return_value = ListShieldsResponse(data=[shield])
moderation_response = ModerationObject(
id="mid",
model="shield-1",
results=[ModerationObjectResults(flagged=False)],
)
provider = AsyncMock()
provider.run_moderation.return_value = moderation_response
routing_table.get_provider_impl.return_value = provider
router = SafetyRouter(routing_table=routing_table, safety_config=SafetyConfig(default_shield_id="shield-1"))
result = await router.run_moderation("hello world")
assert result is moderation_response
routing_table.get_provider_impl.assert_awaited_once_with("shield-1")
provider.run_moderation.assert_awaited_once()
_, kwargs = provider.run_moderation.call_args
assert kwargs["model"] == "provider/shield-model"
assert kwargs["input"] == "hello world"

View file

@ -0,0 +1,57 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import AsyncMock, Mock
import pytest
from llama_stack.apis.vector_io import OpenAICreateVectorStoreRequestWithExtraBody
from llama_stack.core.routers.vector_io import VectorIORouter
async def test_single_provider_auto_selection():
# provider_id automatically selected during vector store create() when only one provider available
mock_routing_table = Mock()
mock_routing_table.impls_by_provider_id = {"inline::faiss": "mock_provider"}
mock_routing_table.get_all_with_type = AsyncMock(
return_value=[
Mock(identifier="all-MiniLM-L6-v2", model_type="embedding", metadata={"embedding_dimension": 384})
]
)
mock_routing_table.register_vector_store = AsyncMock(
return_value=Mock(identifier="vs_123", provider_id="inline::faiss", provider_resource_id="vs_123")
)
mock_routing_table.get_provider_impl = AsyncMock(
return_value=Mock(openai_create_vector_store=AsyncMock(return_value=Mock(id="vs_123")))
)
router = VectorIORouter(mock_routing_table)
request = OpenAICreateVectorStoreRequestWithExtraBody.model_validate(
{"name": "test_store", "embedding_model": "all-MiniLM-L6-v2"}
)
result = await router.openai_create_vector_store(request)
assert result.id == "vs_123"
async def test_create_vector_stores_multiple_providers_missing_provider_id_error():
# if multiple providers are available, vector store create will error without provider_id
mock_routing_table = Mock()
mock_routing_table.impls_by_provider_id = {
"inline::faiss": "mock_provider_1",
"inline::sqlite-vec": "mock_provider_2",
}
mock_routing_table.get_all_with_type = AsyncMock(
return_value=[
Mock(identifier="all-MiniLM-L6-v2", model_type="embedding", metadata={"embedding_dimension": 384})
]
)
router = VectorIORouter(mock_routing_table)
request = OpenAICreateVectorStoreRequestWithExtraBody.model_validate(
{"name": "test_store", "embedding_model": "all-MiniLM-L6-v2"}
)
with pytest.raises(ValueError, match="Multiple vector_io providers available"):
await router.openai_create_vector_store(request)

View file

@ -0,0 +1,59 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import json
from contextlib import contextmanager
from contextvars import ContextVar
from llama_stack.core.utils.context import preserve_contexts_async_generator
# Define provider data context variable and context manager locally
PROVIDER_DATA_VAR = ContextVar("provider_data", default=None)
@contextmanager
def request_provider_data_context(headers):
val = headers.get("X-LlamaStack-Provider-Data")
provider_data = json.loads(val) if val else {}
token = PROVIDER_DATA_VAR.set(provider_data)
try:
yield
finally:
PROVIDER_DATA_VAR.reset(token)
def create_sse_event(data):
return f"data: {json.dumps(data)}\n\n"
async def sse_generator(event_gen_coroutine):
event_gen = await event_gen_coroutine
async for item in event_gen:
yield create_sse_event(item)
await asyncio.sleep(0)
async def async_event_gen():
async def event_gen():
yield PROVIDER_DATA_VAR.get()
return event_gen()
async def test_provider_data_context_cleared_between_sse_requests():
headers = {"X-LlamaStack-Provider-Data": json.dumps({"api_key": "abc"})}
with request_provider_data_context(headers):
gen1 = preserve_contexts_async_generator(sse_generator(async_event_gen()), [PROVIDER_DATA_VAR])
events1 = [event async for event in gen1]
assert events1 == [create_sse_event({"api_key": "abc"})]
assert PROVIDER_DATA_VAR.get() is None
gen2 = preserve_contexts_async_generator(sse_generator(async_event_gen()), [PROVIDER_DATA_VAR])
events2 = [event async for event in gen2]
assert events2 == [create_sse_event(None)]
assert PROVIDER_DATA_VAR.get() is None

View file

@ -0,0 +1,102 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""Unit tests for Stack validation functions."""
from unittest.mock import AsyncMock
import pytest
from llama_stack.apis.models import ListModelsResponse, Model, ModelType
from llama_stack.apis.shields import ListShieldsResponse, Shield
from llama_stack.core.datatypes import QualifiedModel, SafetyConfig, StackRunConfig, StorageConfig, VectorStoresConfig
from llama_stack.core.stack import validate_safety_config, validate_vector_stores_config
from llama_stack.providers.datatypes import Api
class TestVectorStoresValidation:
async def test_validate_missing_model(self):
"""Test validation fails when model not found."""
run_config = StackRunConfig(
image_name="test",
providers={},
storage=StorageConfig(backends={}, stores={}),
vector_stores=VectorStoresConfig(
default_provider_id="faiss",
default_embedding_model=QualifiedModel(
provider_id="p",
model_id="missing",
),
),
)
mock_models = AsyncMock()
mock_models.list_models.return_value = ListModelsResponse(data=[])
with pytest.raises(ValueError, match="not found"):
await validate_vector_stores_config(run_config.vector_stores, {Api.models: mock_models})
async def test_validate_success(self):
"""Test validation passes with valid model."""
run_config = StackRunConfig(
image_name="test",
providers={},
storage=StorageConfig(backends={}, stores={}),
vector_stores=VectorStoresConfig(
default_provider_id="faiss",
default_embedding_model=QualifiedModel(
provider_id="p",
model_id="valid",
),
),
)
mock_models = AsyncMock()
mock_models.list_models.return_value = ListModelsResponse(
data=[
Model(
identifier="p/valid", # Must match provider_id/model_id format
model_type=ModelType.embedding,
metadata={"embedding_dimension": 768},
provider_id="p",
provider_resource_id="valid",
)
]
)
await validate_vector_stores_config(run_config.vector_stores, {Api.models: mock_models})
class TestSafetyConfigValidation:
async def test_validate_success(self):
safety_config = SafetyConfig(default_shield_id="shield-1")
shield = Shield(
identifier="shield-1",
provider_id="provider-x",
provider_resource_id="model-x",
params={},
)
shields_impl = AsyncMock()
shields_impl.list_shields.return_value = ListShieldsResponse(data=[shield])
await validate_safety_config(safety_config, {Api.shields: shields_impl, Api.safety: AsyncMock()})
async def test_validate_wrong_shield_id(self):
safety_config = SafetyConfig(default_shield_id="wrong-shield-id")
shields_impl = AsyncMock()
shields_impl.list_shields.return_value = ListShieldsResponse(
data=[
Shield(
identifier="shield-1",
provider_resource_id="model-x",
provider_id="provider-x",
params={},
)
]
)
with pytest.raises(ValueError, match="wrong-shield-id"):
await validate_safety_config(safety_config, {Api.shields: shields_impl, Api.safety: AsyncMock()})

View file

@ -0,0 +1,84 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""Unit tests for storage backend/reference validation."""
import pytest
from pydantic import ValidationError
from llama_stack.core.datatypes import (
LLAMA_STACK_RUN_CONFIG_VERSION,
StackRunConfig,
)
from llama_stack.core.storage.datatypes import (
InferenceStoreReference,
KVStoreReference,
ServerStoresConfig,
SqliteKVStoreConfig,
SqliteSqlStoreConfig,
SqlStoreReference,
StorageConfig,
)
def _base_run_config(**overrides):
metadata_reference = overrides.pop(
"metadata_reference",
KVStoreReference(backend="kv_default", namespace="registry"),
)
inference_reference = overrides.pop(
"inference_reference",
InferenceStoreReference(backend="sql_default", table_name="inference"),
)
conversations_reference = overrides.pop(
"conversations_reference",
SqlStoreReference(backend="sql_default", table_name="conversations"),
)
storage = overrides.pop(
"storage",
StorageConfig(
backends={
"kv_default": SqliteKVStoreConfig(db_path="/tmp/kv.db"),
"sql_default": SqliteSqlStoreConfig(db_path="/tmp/sql.db"),
},
stores=ServerStoresConfig(
metadata=metadata_reference,
inference=inference_reference,
conversations=conversations_reference,
),
),
)
return StackRunConfig(
version=LLAMA_STACK_RUN_CONFIG_VERSION,
image_name="test-distro",
apis=[],
providers={},
storage=storage,
**overrides,
)
def test_references_require_known_backend():
with pytest.raises(ValidationError, match="unknown backend 'missing'"):
_base_run_config(metadata_reference=KVStoreReference(backend="missing", namespace="registry"))
def test_references_must_match_backend_family():
with pytest.raises(ValidationError, match="kv_.* is required"):
_base_run_config(metadata_reference=KVStoreReference(backend="sql_default", namespace="registry"))
with pytest.raises(ValidationError, match="sql_.* is required"):
_base_run_config(
inference_reference=InferenceStoreReference(backend="kv_default", table_name="inference"),
)
def test_valid_configuration_passes_validation():
config = _base_run_config()
stores = config.storage.stores
assert stores.metadata is not None and stores.metadata.backend == "kv_default"
assert stores.inference is not None and stores.inference.backend == "sql_default"
assert stores.conversations is not None and stores.conversations.backend == "sql_default"

View file

@ -11,13 +11,13 @@ from unittest.mock import AsyncMock
import pytest
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.errors import ModelNotFoundError
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource
from llama_stack.apis.datatypes import Api
from llama_stack.apis.models import Model, ModelType
from llama_stack.apis.shields.shields import Shield
from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroup
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.core.datatypes import RegistryEntrySource
from llama_stack.core.routing_tables.benchmarks import BenchmarksRoutingTable
from llama_stack.core.routing_tables.datasets import DatasetsRoutingTable
@ -25,7 +25,6 @@ from llama_stack.core.routing_tables.models import ModelsRoutingTable
from llama_stack.core.routing_tables.scoring_functions import ScoringFunctionsRoutingTable
from llama_stack.core.routing_tables.shields import ShieldsRoutingTable
from llama_stack.core.routing_tables.toolgroups import ToolGroupsRoutingTable
from llama_stack.core.routing_tables.vector_dbs import VectorDBsRoutingTable
class Impl:
@ -146,31 +145,6 @@ class ToolGroupsImpl(Impl):
)
class VectorDBImpl(Impl):
def __init__(self):
super().__init__(Api.vector_io)
async def register_vector_db(self, vector_db: VectorDB):
return vector_db
async def unregister_vector_db(self, vector_db_id: str):
return vector_db_id
async def openai_create_vector_store(self, **kwargs):
import time
import uuid
from llama_stack.apis.vector_io.vector_io import VectorStoreFileCounts, VectorStoreObject
vector_store_id = kwargs.get("provider_vector_db_id") or f"vs_{uuid.uuid4()}"
return VectorStoreObject(
id=vector_store_id,
name=kwargs.get("name", vector_store_id),
created_at=int(time.time()),
file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0),
)
async def test_models_routing_table(cached_disk_dist_registry):
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
await table.initialize()
@ -263,40 +237,6 @@ async def test_shields_routing_table(cached_disk_dist_registry):
await table.unregister_shield(identifier="non-existent")
async def test_vectordbs_routing_table(cached_disk_dist_registry):
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})
await table.initialize()
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
await m_table.initialize()
await m_table.register_model(
model_id="test-model",
provider_id="test_provider",
metadata={"embedding_dimension": 128},
model_type=ModelType.embedding,
)
# Register multiple vector databases and verify listing
vdb1 = await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test_provider/test-model")
vdb2 = await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test_provider/test-model")
vector_dbs = await table.list_vector_dbs()
assert len(vector_dbs.data) == 2
vector_db_ids = {v.identifier for v in vector_dbs.data}
assert vdb1.identifier in vector_db_ids
assert vdb2.identifier in vector_db_ids
# Verify they have UUID-based identifiers
assert vdb1.identifier.startswith("vs_")
assert vdb2.identifier.startswith("vs_")
await table.unregister_vector_db(vector_db_id=vdb1.identifier)
await table.unregister_vector_db(vector_db_id=vdb2.identifier)
vector_dbs = await table.list_vector_dbs()
assert len(vector_dbs.data) == 0
async def test_datasets_routing_table(cached_disk_dist_registry):
table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, cached_disk_dist_registry, {})
await table.initialize()
@ -511,6 +451,7 @@ async def test_models_alias_registration_and_lookup(cached_disk_dist_registry):
await table.initialize()
# Register model with alias (model_id different from provider_model_id)
# NOTE: Aliases are not supported anymore, so this is a no-op
await table.register_model(
model_id="my-alias", provider_model_id="actual-provider-model", provider_id="test_provider"
)
@ -519,12 +460,15 @@ async def test_models_alias_registration_and_lookup(cached_disk_dist_registry):
models = await table.list_models()
assert len(models.data) == 1
model = models.data[0]
assert model.identifier == "my-alias" # Uses alias as identifier
assert model.identifier == "test_provider/actual-provider-model"
assert model.provider_resource_id == "actual-provider-model"
# Test lookup by alias works
retrieved_model = await table.get_model("my-alias")
assert retrieved_model.identifier == "my-alias"
# Test lookup by alias fails
with pytest.raises(ModelNotFoundError, match="Model 'my-alias' not found"):
await table.get_model("my-alias")
retrieved_model = await table.get_model("test_provider/actual-provider-model")
assert retrieved_model.identifier == "test_provider/actual-provider-model"
assert retrieved_model.provider_resource_id == "actual-provider-model"
@ -555,12 +499,8 @@ async def test_models_multi_provider_disambiguation(cached_disk_dist_registry):
assert model2.provider_resource_id == "common-model"
# Test lookup by unscoped provider_model_id fails with multiple providers error
try:
with pytest.raises(ModelNotFoundError, match="Model 'common-model' not found"):
await table.get_model("common-model")
raise AssertionError("Should have raised ValueError for multiple providers")
except ValueError as e:
assert "Multiple providers found" in str(e)
assert "provider1" in str(e) and "provider2" in str(e)
async def test_models_fallback_lookup_behavior(cached_disk_dist_registry):
@ -583,16 +523,12 @@ async def test_models_fallback_lookup_behavior(cached_disk_dist_registry):
assert retrieved_model.identifier == "test_provider/test-model"
# Test lookup by unscoped provider_model_id (fallback via iteration)
retrieved_model = await table.get_model("test-model")
assert retrieved_model.identifier == "test_provider/test-model"
assert retrieved_model.provider_resource_id == "test-model"
with pytest.raises(ModelNotFoundError, match="Model 'test-model' not found"):
await table.get_model("test-model")
# Test lookup of non-existent model fails
try:
with pytest.raises(ModelNotFoundError, match="Model 'non-existent' not found"):
await table.get_model("non-existent")
raise AssertionError("Should have raised ValueError for non-existent model")
except ValueError as e:
assert "not found" in str(e)
async def test_models_source_tracking_default(cached_disk_dist_registry):
@ -664,7 +600,7 @@ async def test_models_source_interaction_preserves_default(cached_disk_dist_regi
assert len(models.data) == 1
user_model = models.data[0]
assert user_model.source == RegistryEntrySource.via_register_api
assert user_model.identifier == "my-custom-alias"
assert user_model.identifier == "test_provider/provider-model-1"
assert user_model.provider_resource_id == "provider-model-1"
# Now simulate provider refresh
@ -691,7 +627,7 @@ async def test_models_source_interaction_preserves_default(cached_disk_dist_regi
assert len(models.data) == 2
# Find the user model and provider model
user_model = next((m for m in models.data if m.identifier == "my-custom-alias"), None)
user_model = next((m for m in models.data if m.identifier == "test_provider/provider-model-1"), None)
provider_model = next((m for m in models.data if m.identifier == "test_provider/different-model"), None)
assert user_model is not None

View file

@ -1,381 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Unit tests for the routing tables vector_dbs
import time
import uuid
from unittest.mock import AsyncMock
import pytest
from llama_stack.apis.datatypes import Api
from llama_stack.apis.models import ModelType
from llama_stack.apis.vector_dbs.vector_dbs import VectorDB
from llama_stack.apis.vector_io.vector_io import (
VectorStoreContent,
VectorStoreDeleteResponse,
VectorStoreFileContentsResponse,
VectorStoreFileCounts,
VectorStoreFileDeleteResponse,
VectorStoreFileObject,
VectorStoreObject,
VectorStoreSearchResponsePage,
)
from llama_stack.core.access_control.datatypes import AccessRule, Scope
from llama_stack.core.datatypes import User
from llama_stack.core.request_headers import request_provider_data_context
from llama_stack.core.routing_tables.vector_dbs import VectorDBsRoutingTable
from tests.unit.distribution.routers.test_routing_tables import Impl, InferenceImpl, ModelsRoutingTable
class VectorDBImpl(Impl):
def __init__(self):
super().__init__(Api.vector_io)
self.vector_stores = {}
async def register_vector_db(self, vector_db: VectorDB):
return vector_db
async def unregister_vector_db(self, vector_db_id: str):
return vector_db_id
async def openai_retrieve_vector_store(self, vector_store_id):
return VectorStoreObject(
id=vector_store_id,
name="Test Store",
created_at=int(time.time()),
file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0),
)
async def openai_update_vector_store(self, vector_store_id, **kwargs):
return VectorStoreObject(
id=vector_store_id,
name="Updated Store",
created_at=int(time.time()),
file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0),
)
async def openai_delete_vector_store(self, vector_store_id):
return VectorStoreDeleteResponse(id=vector_store_id, object="vector_store.deleted", deleted=True)
async def openai_search_vector_store(self, vector_store_id, query, **kwargs):
return VectorStoreSearchResponsePage(
object="vector_store.search_results.page", search_query="query", data=[], has_more=False, next_page=None
)
async def openai_attach_file_to_vector_store(self, vector_store_id, file_id, **kwargs):
return VectorStoreFileObject(
id=file_id,
status="completed",
chunking_strategy={"type": "auto"},
created_at=int(time.time()),
vector_store_id=vector_store_id,
)
async def openai_list_files_in_vector_store(self, vector_store_id, **kwargs):
return [
VectorStoreFileObject(
id="1",
status="completed",
chunking_strategy={"type": "auto"},
created_at=int(time.time()),
vector_store_id=vector_store_id,
)
]
async def openai_retrieve_vector_store_file(self, vector_store_id, file_id):
return VectorStoreFileObject(
id=file_id,
status="completed",
chunking_strategy={"type": "auto"},
created_at=int(time.time()),
vector_store_id=vector_store_id,
)
async def openai_retrieve_vector_store_file_contents(self, vector_store_id, file_id):
return VectorStoreFileContentsResponse(
file_id=file_id,
filename="Sample File name",
attributes={"key": "value"},
content=[VectorStoreContent(type="text", text="Sample content")],
)
async def openai_update_vector_store_file(self, vector_store_id, file_id, **kwargs):
return VectorStoreFileObject(
id=file_id,
status="completed",
chunking_strategy={"type": "auto"},
created_at=int(time.time()),
vector_store_id=vector_store_id,
)
async def openai_delete_vector_store_file(self, vector_store_id, file_id):
return VectorStoreFileDeleteResponse(id=file_id, deleted=True)
async def openai_create_vector_store(
self,
name=None,
embedding_model=None,
embedding_dimension=None,
provider_id=None,
provider_vector_db_id=None,
**kwargs,
):
vector_store_id = provider_vector_db_id or f"vs_{uuid.uuid4()}"
vector_store = VectorStoreObject(
id=vector_store_id,
name=name or vector_store_id,
created_at=int(time.time()),
file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0),
)
self.vector_stores[vector_store_id] = vector_store
return vector_store
async def openai_list_vector_stores(self, **kwargs):
from llama_stack.apis.vector_io.vector_io import VectorStoreListResponse
return VectorStoreListResponse(
data=list(self.vector_stores.values()), has_more=False, first_id=None, last_id=None
)
async def test_vectordbs_routing_table(cached_disk_dist_registry):
n = 10
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})
await table.initialize()
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
await m_table.initialize()
await m_table.register_model(
model_id="test-model",
provider_id="test_provider",
metadata={"embedding_dimension": 128},
model_type=ModelType.embedding,
)
# Register multiple vector databases and verify listing
vdb_dict = {}
for i in range(n):
vdb_dict[i] = await table.register_vector_db(vector_db_id=f"test-vectordb-{i}", embedding_model="test-model")
vector_dbs = await table.list_vector_dbs()
assert len(vector_dbs.data) == len(vdb_dict)
vector_db_ids = {v.identifier for v in vector_dbs.data}
for k in vdb_dict:
assert vdb_dict[k].identifier in vector_db_ids
for k in vdb_dict:
await table.unregister_vector_db(vector_db_id=vdb_dict[k].identifier)
vector_dbs = await table.list_vector_dbs()
assert len(vector_dbs.data) == 0
async def test_vector_db_and_vector_store_id_mapping(cached_disk_dist_registry):
n = 10
impl = VectorDBImpl()
table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, {})
await table.initialize()
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
await m_table.initialize()
await m_table.register_model(
model_id="test-model",
provider_id="test_provider",
metadata={"embedding_dimension": 128},
model_type=ModelType.embedding,
)
vdb_dict = {}
for i in range(n):
vdb_dict[i] = await table.register_vector_db(vector_db_id=f"test-vectordb-{i}", embedding_model="test-model")
vector_dbs = await table.list_vector_dbs()
vector_db_ids = {v.identifier for v in vector_dbs.data}
vector_stores = await impl.openai_list_vector_stores()
vector_store_ids = {v.id for v in vector_stores.data}
assert vector_db_ids == vector_store_ids, (
f"Vector DB IDs {vector_db_ids} don't match vector store IDs {vector_store_ids}"
)
for vector_store in vector_stores.data:
vector_db = await table.get_vector_db(vector_store.id)
assert vector_store.name == vector_db.vector_db_name, (
f"Vector store name {vector_store.name} doesn't match vector store ID {vector_store.id}"
)
for vector_db_id in vector_db_ids:
await table.unregister_vector_db(vector_db_id)
assert len((await table.list_vector_dbs()).data) == 0
async def test_vector_db_id_becomes_vector_store_name(cached_disk_dist_registry):
impl = VectorDBImpl()
table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, {})
await table.initialize()
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
await m_table.initialize()
await m_table.register_model(
model_id="test-model",
provider_id="test_provider",
metadata={"embedding_dimension": 128},
model_type=ModelType.embedding,
)
user_provided_id = "my-custom-vector-db"
await table.register_vector_db(vector_db_id=user_provided_id, embedding_model="test-model")
vector_stores = await impl.openai_list_vector_stores()
assert len(vector_stores.data) == 1
vector_store = vector_stores.data[0]
assert vector_store.name == user_provided_id
assert vector_store.id.startswith("vs_")
assert vector_store.id != user_provided_id
vector_dbs = await table.list_vector_dbs()
assert len(vector_dbs.data) == 1
assert vector_dbs.data[0].identifier == vector_store.id
await table.unregister_vector_db(vector_store.id)
async def test_openai_vector_stores_routing_table_roles(cached_disk_dist_registry):
impl = VectorDBImpl()
impl.openai_retrieve_vector_store = AsyncMock(return_value="OK")
table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, policy=[])
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, policy=[])
authorized_table = "vs1"
authorized_team = "team1"
unauthorized_team = "team2"
await m_table.initialize()
await m_table.register_model(
model_id="test-model",
provider_id="test_provider",
metadata={"embedding_dimension": 128},
model_type=ModelType.embedding,
)
authorized_user = User(principal="alice", attributes={"roles": [authorized_team]})
with request_provider_data_context({}, authorized_user):
registered_vdb = await table.register_vector_db(vector_db_id="vs1", embedding_model="test-model")
authorized_table = registered_vdb.identifier # Use the actual generated ID
# Authorized reader
with request_provider_data_context({}, authorized_user):
res = await table.openai_retrieve_vector_store(authorized_table)
assert res == "OK"
# Authorized updater
impl.openai_update_vector_store_file = AsyncMock(return_value="UPDATED")
with request_provider_data_context({}, authorized_user):
res = await table.openai_update_vector_store_file(authorized_table, file_id="file1", attributes={"foo": "bar"})
assert res == "UPDATED"
# Unauthorized reader
unauthorized_user = User(principal="eve", attributes={"roles": [unauthorized_team]})
with request_provider_data_context({}, unauthorized_user):
with pytest.raises(ValueError):
await table.openai_retrieve_vector_store(authorized_table)
# Unauthorized updater
with request_provider_data_context({}, unauthorized_user):
with pytest.raises(ValueError):
await table.openai_update_vector_store_file(authorized_table, file_id="file1", attributes={"foo": "bar"})
# Authorized deleter
impl.openai_delete_vector_store_file = AsyncMock(return_value="DELETED")
with request_provider_data_context({}, authorized_user):
res = await table.openai_delete_vector_store_file(authorized_table, file_id="file1")
assert res == "DELETED"
# Unauthorized deleter
with request_provider_data_context({}, unauthorized_user):
with pytest.raises(ValueError):
await table.openai_delete_vector_store_file(authorized_table, file_id="file1")
async def test_openai_vector_stores_routing_table_actions(cached_disk_dist_registry):
impl = VectorDBImpl()
policy = [
AccessRule(permit=Scope(actions=["create", "read", "update", "delete"]), when="user with admin in roles"),
AccessRule(permit=Scope(actions=["read"]), when="user with reader in roles"),
]
table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, policy=policy)
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, policy=[])
vector_db_id = "vs1"
file_id = "file-1"
admin_user = User(principal="admin", attributes={"roles": ["admin"]})
read_only_user = User(principal="reader", attributes={"roles": ["reader"]})
no_access_user = User(principal="outsider", attributes={"roles": ["no_access"]})
await m_table.initialize()
await m_table.register_model(
model_id="test-model",
provider_id="test_provider",
metadata={"embedding_dimension": 128},
model_type=ModelType.embedding,
)
with request_provider_data_context({}, admin_user):
registered_vdb = await table.register_vector_db(vector_db_id=vector_db_id, embedding_model="test-model")
vector_db_id = registered_vdb.identifier # Use the actual generated ID
read_methods = [
(table.openai_retrieve_vector_store, (vector_db_id,), {}),
(table.openai_search_vector_store, (vector_db_id, "query"), {}),
(table.openai_list_files_in_vector_store, (vector_db_id,), {}),
(table.openai_retrieve_vector_store_file, (vector_db_id, file_id), {}),
(table.openai_retrieve_vector_store_file_contents, (vector_db_id, file_id), {}),
]
update_methods = [
(table.openai_update_vector_store, (vector_db_id,), {"name": "Updated DB"}),
(table.openai_attach_file_to_vector_store, (vector_db_id, file_id), {}),
(table.openai_update_vector_store_file, (vector_db_id, file_id), {"attributes": {"key": "value"}}),
]
delete_methods = [
(table.openai_delete_vector_store_file, (vector_db_id, file_id), {}),
(table.openai_delete_vector_store, (vector_db_id,), {}),
]
for user in [admin_user, read_only_user]:
with request_provider_data_context({}, user):
for method, args, kwargs in read_methods:
result = await method(*args, **kwargs)
assert result is not None, f"Read operation failed with user {user.principal}"
with request_provider_data_context({}, no_access_user):
for method, args, kwargs in read_methods:
with pytest.raises(ValueError):
await method(*args, **kwargs)
with request_provider_data_context({}, admin_user):
for method, args, kwargs in update_methods:
result = await method(*args, **kwargs)
assert result is not None, "Update operation failed with admin user"
with request_provider_data_context({}, admin_user):
for method, args, kwargs in delete_methods:
result = await method(*args, **kwargs)
assert result is not None, "Delete operation failed with admin user"
for user in [read_only_user, no_access_user]:
with request_provider_data_context({}, user):
for method, args, kwargs in delete_methods:
with pytest.raises(ValueError):
await method(*args, **kwargs)

View file

@ -1,40 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pathlib import Path
from llama_stack.cli.stack._build import (
_run_stack_build_command_from_build_config,
)
from llama_stack.core.datatypes import BuildConfig, DistributionSpec
from llama_stack.core.utils.image_types import LlamaStackImageType
def test_container_build_passes_path(monkeypatch, tmp_path):
called_with = {}
def spy_build_image(build_config, image_name, distro_or_config, run_config=None):
called_with["path"] = distro_or_config
called_with["run_config"] = run_config
return 0
monkeypatch.setattr(
"llama_stack.cli.stack._build.build_image",
spy_build_image,
raising=True,
)
cfg = BuildConfig(
image_type=LlamaStackImageType.CONTAINER.value,
distribution_spec=DistributionSpec(providers={}, description=""),
)
_run_stack_build_command_from_build_config(cfg, image_name="dummy")
assert "path" in called_with
assert isinstance(called_with["path"], str)
assert Path(called_with["path"]).exists()
assert called_with["run_config"] is None

View file

@ -13,6 +13,15 @@ from pydantic import BaseModel, Field, ValidationError
from llama_stack.core.datatypes import Api, Provider, StackRunConfig
from llama_stack.core.distribution import INTERNAL_APIS, get_provider_registry, providable_apis
from llama_stack.core.storage.datatypes import (
InferenceStoreReference,
KVStoreReference,
ServerStoresConfig,
SqliteKVStoreConfig,
SqliteSqlStoreConfig,
SqlStoreReference,
StorageConfig,
)
from llama_stack.providers.datatypes import ProviderSpec
@ -29,6 +38,33 @@ class SampleConfig(BaseModel):
}
def _default_storage() -> StorageConfig:
return StorageConfig(
backends={
"kv_default": SqliteKVStoreConfig(db_path=":memory:"),
"sql_default": SqliteSqlStoreConfig(db_path=":memory:"),
},
stores=ServerStoresConfig(
metadata=KVStoreReference(backend="kv_default", namespace="registry"),
inference=InferenceStoreReference(backend="sql_default", table_name="inference_store"),
conversations=SqlStoreReference(backend="sql_default", table_name="conversations"),
prompts=KVStoreReference(backend="kv_default", namespace="prompts"),
),
)
def make_stack_config(**overrides) -> StackRunConfig:
storage = overrides.pop("storage", _default_storage())
defaults = dict(
image_name="test_image",
apis=[],
providers={},
storage=storage,
)
defaults.update(overrides)
return StackRunConfig(**defaults)
@pytest.fixture
def mock_providers():
"""Mock the available_providers function to return test providers."""
@ -47,8 +83,8 @@ def mock_providers():
@pytest.fixture
def base_config(tmp_path):
"""Create a base StackRunConfig with common settings."""
return StackRunConfig(
image_name="test_image",
return make_stack_config(
apis=["inference"],
providers={
"inference": [
Provider(
@ -220,8 +256,8 @@ class TestProviderRegistry:
def test_missing_directory(self, mock_providers):
"""Test handling of missing external providers directory."""
config = StackRunConfig(
image_name="test_image",
config = make_stack_config(
apis=["inference"],
providers={
"inference": [
Provider(
@ -276,7 +312,6 @@ pip_packages:
"""Test loading an external provider from a module (success path)."""
from types import SimpleNamespace
from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.providers.datatypes import Api, ProviderSpec
# Simulate a provider module with get_provider_spec
@ -291,7 +326,7 @@ pip_packages:
import_module_side_effect = make_import_module_side_effect(external_module=fake_module)
with patch("importlib.import_module", side_effect=import_module_side_effect) as mock_import:
config = StackRunConfig(
config = make_stack_config(
image_name="test_image",
providers={
"inference": [
@ -315,12 +350,11 @@ pip_packages:
def test_external_provider_from_module_not_found(self, mock_providers):
"""Test handling ModuleNotFoundError for missing provider module."""
from llama_stack.core.datatypes import Provider, StackRunConfig
import_module_side_effect = make_import_module_side_effect(raise_for_external=True)
with patch("importlib.import_module", side_effect=import_module_side_effect):
config = StackRunConfig(
config = make_stack_config(
image_name="test_image",
providers={
"inference": [
@ -339,12 +373,11 @@ pip_packages:
def test_external_provider_from_module_missing_get_provider_spec(self, mock_providers):
"""Test handling missing get_provider_spec in provider module (should raise ValueError)."""
from llama_stack.core.datatypes import Provider, StackRunConfig
import_module_side_effect = make_import_module_side_effect(missing_get_provider_spec=True)
with patch("importlib.import_module", side_effect=import_module_side_effect):
config = StackRunConfig(
config = make_stack_config(
image_name="test_image",
providers={
"inference": [
@ -397,13 +430,12 @@ class TestGetExternalProvidersFromModule:
def test_stackrunconfig_provider_without_module(self, mock_providers):
"""Test that providers without module attribute are skipped."""
from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.core.distribution import get_external_providers_from_module
import_module_side_effect = make_import_module_side_effect()
with patch("importlib.import_module", side_effect=import_module_side_effect):
config = StackRunConfig(
config = make_stack_config(
image_name="test_image",
providers={
"inference": [
@ -424,7 +456,6 @@ class TestGetExternalProvidersFromModule:
"""Test provider with module containing version spec (e.g., package==1.0.0)."""
from types import SimpleNamespace
from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.core.distribution import get_external_providers_from_module
from llama_stack.providers.datatypes import ProviderSpec
@ -442,7 +473,7 @@ class TestGetExternalProvidersFromModule:
raise ModuleNotFoundError(name)
with patch("importlib.import_module", side_effect=import_side_effect):
config = StackRunConfig(
config = make_stack_config(
image_name="test_image",
providers={
"inference": [
@ -562,7 +593,6 @@ class TestGetExternalProvidersFromModule:
"""Test when get_provider_spec returns a list of specs."""
from types import SimpleNamespace
from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.core.distribution import get_external_providers_from_module
from llama_stack.providers.datatypes import ProviderSpec
@ -587,7 +617,7 @@ class TestGetExternalProvidersFromModule:
raise ModuleNotFoundError(name)
with patch("importlib.import_module", side_effect=import_side_effect):
config = StackRunConfig(
config = make_stack_config(
image_name="test_image",
providers={
"inference": [
@ -611,7 +641,6 @@ class TestGetExternalProvidersFromModule:
"""Test that list return filters specs by provider_type."""
from types import SimpleNamespace
from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.core.distribution import get_external_providers_from_module
from llama_stack.providers.datatypes import ProviderSpec
@ -636,7 +665,7 @@ class TestGetExternalProvidersFromModule:
raise ModuleNotFoundError(name)
with patch("importlib.import_module", side_effect=import_side_effect):
config = StackRunConfig(
config = make_stack_config(
image_name="test_image",
providers={
"inference": [
@ -660,7 +689,6 @@ class TestGetExternalProvidersFromModule:
"""Test that list return adds multiple different provider_types when config requests them."""
from types import SimpleNamespace
from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.core.distribution import get_external_providers_from_module
from llama_stack.providers.datatypes import ProviderSpec
@ -686,7 +714,7 @@ class TestGetExternalProvidersFromModule:
raise ModuleNotFoundError(name)
with patch("importlib.import_module", side_effect=import_side_effect):
config = StackRunConfig(
config = make_stack_config(
image_name="test_image",
providers={
"inference": [
@ -716,7 +744,6 @@ class TestGetExternalProvidersFromModule:
def test_module_not_found_raises_value_error(self, mock_providers):
"""Test that ModuleNotFoundError raises ValueError with helpful message."""
from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.core.distribution import get_external_providers_from_module
def import_side_effect(name):
@ -725,7 +752,7 @@ class TestGetExternalProvidersFromModule:
raise ModuleNotFoundError(name)
with patch("importlib.import_module", side_effect=import_side_effect):
config = StackRunConfig(
config = make_stack_config(
image_name="test_image",
providers={
"inference": [
@ -749,7 +776,6 @@ class TestGetExternalProvidersFromModule:
"""Test that generic exceptions are properly raised."""
from types import SimpleNamespace
from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.core.distribution import get_external_providers_from_module
def bad_spec():
@ -763,7 +789,7 @@ class TestGetExternalProvidersFromModule:
raise ModuleNotFoundError(name)
with patch("importlib.import_module", side_effect=import_side_effect):
config = StackRunConfig(
config = make_stack_config(
image_name="test_image",
providers={
"inference": [
@ -785,10 +811,9 @@ class TestGetExternalProvidersFromModule:
def test_empty_provider_list(self, mock_providers):
"""Test with empty provider list."""
from llama_stack.core.datatypes import StackRunConfig
from llama_stack.core.distribution import get_external_providers_from_module
config = StackRunConfig(
config = make_stack_config(
image_name="test_image",
providers={},
)
@ -803,7 +828,6 @@ class TestGetExternalProvidersFromModule:
"""Test multiple APIs with providers."""
from types import SimpleNamespace
from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.core.distribution import get_external_providers_from_module
from llama_stack.providers.datatypes import ProviderSpec
@ -828,7 +852,7 @@ class TestGetExternalProvidersFromModule:
raise ModuleNotFoundError(name)
with patch("importlib.import_module", side_effect=import_side_effect):
config = StackRunConfig(
config = make_stack_config(
image_name="test_image",
providers={
"inference": [

View file

@ -0,0 +1,50 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
from io import StringIO
from unittest.mock import patch
from llama_stack.cli.stack._list_deps import (
run_stack_list_deps_command,
)
def test_stack_list_deps_basic():
args = argparse.Namespace(
config=None,
env_name="test-env",
providers="inference=remote::ollama",
format="deps-only",
)
with patch("sys.stdout", new_callable=StringIO) as mock_stdout:
run_stack_list_deps_command(args)
output = mock_stdout.getvalue()
# deps-only format should NOT include "uv pip install" or "Dependencies for"
assert "uv pip install" not in output
assert "Dependencies for" not in output
# Check that expected dependencies are present
assert "ollama" in output
assert "aiohttp" in output
assert "fastapi" in output
def test_stack_list_deps_with_distro_uv():
args = argparse.Namespace(
config="starter",
env_name=None,
providers=None,
format="uv",
)
with patch("sys.stdout", new_callable=StringIO) as mock_stdout:
run_stack_list_deps_command(args)
output = mock_stdout.getvalue()
assert "uv pip install" in output

View file

@ -11,11 +11,12 @@ from llama_stack.apis.common.errors import ResourceNotFoundError
from llama_stack.apis.common.responses import Order
from llama_stack.apis.files import OpenAIFilePurpose
from llama_stack.core.access_control.access_control import default_policy
from llama_stack.core.storage.datatypes import SqliteSqlStoreConfig, SqlStoreReference
from llama_stack.providers.inline.files.localfs import (
LocalfsFilesImpl,
LocalfsFilesImplConfig,
)
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
class MockUploadFile:
@ -36,8 +37,11 @@ async def files_provider(tmp_path):
storage_dir = tmp_path / "files"
db_path = tmp_path / "files_metadata.db"
backend_name = "sql_localfs_test"
register_sqlstore_backends({backend_name: SqliteSqlStoreConfig(db_path=db_path.as_posix())})
config = LocalfsFilesImplConfig(
storage_dir=storage_dir.as_posix(), metadata_store=SqliteSqlStoreConfig(db_path=db_path.as_posix())
storage_dir=storage_dir.as_posix(),
metadata_store=SqlStoreReference(backend=backend_name, table_name="files_metadata"),
)
provider = LocalfsFilesImpl(config, default_policy())

View file

@ -9,7 +9,16 @@ import random
import pytest
from llama_stack.core.prompts.prompts import PromptServiceConfig, PromptServiceImpl
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from llama_stack.core.storage.datatypes import (
InferenceStoreReference,
KVStoreReference,
ServerStoresConfig,
SqliteKVStoreConfig,
SqliteSqlStoreConfig,
SqlStoreReference,
StorageConfig,
)
from llama_stack.providers.utils.kvstore import register_kvstore_backends
@pytest.fixture
@ -19,12 +28,29 @@ async def temp_prompt_store(tmp_path_factory):
db_path = str(temp_dir / f"{unique_id}.db")
from llama_stack.core.datatypes import StackRunConfig
from llama_stack.providers.utils.kvstore import kvstore_impl
mock_run_config = StackRunConfig(image_name="test-distribution", apis=[], providers={})
storage = StorageConfig(
backends={
"kv_test": SqliteKVStoreConfig(db_path=db_path),
"sql_test": SqliteSqlStoreConfig(db_path=str(temp_dir / f"{unique_id}_sql.db")),
},
stores=ServerStoresConfig(
metadata=KVStoreReference(backend="kv_test", namespace="registry"),
inference=InferenceStoreReference(backend="sql_test", table_name="inference"),
conversations=SqlStoreReference(backend="sql_test", table_name="conversations"),
prompts=KVStoreReference(backend="kv_test", namespace="prompts"),
),
)
mock_run_config = StackRunConfig(
image_name="test-distribution",
apis=[],
providers={},
storage=storage,
)
config = PromptServiceConfig(run_config=mock_run_config)
store = PromptServiceImpl(config, deps={})
store.kvstore = await kvstore_impl(SqliteKVStoreConfig(db_path=db_path))
register_kvstore_backends({"kv_test": storage.backends["kv_test"]})
await store.initialize()
yield store

View file

@ -15,6 +15,7 @@ from llama_stack.apis.agents import (
AgentCreateResponse,
)
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.conversations import Conversations
from llama_stack.apis.inference import Inference
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroups, ToolRuntime
@ -25,6 +26,20 @@ from llama_stack.providers.inline.agents.meta_reference.config import MetaRefere
from llama_stack.providers.inline.agents.meta_reference.persistence import AgentInfo
@pytest.fixture(autouse=True)
def setup_backends(tmp_path):
"""Register KV and SQL store backends for testing."""
from llama_stack.core.storage.datatypes import SqliteKVStoreConfig, SqliteSqlStoreConfig
from llama_stack.providers.utils.kvstore.kvstore import register_kvstore_backends
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
kv_path = str(tmp_path / "test_kv.db")
sql_path = str(tmp_path / "test_sql.db")
register_kvstore_backends({"kv_default": SqliteKVStoreConfig(db_path=kv_path)})
register_sqlstore_backends({"sql_default": SqliteSqlStoreConfig(db_path=sql_path)})
@pytest.fixture
def mock_apis():
return {
@ -33,20 +48,26 @@ def mock_apis():
"safety_api": AsyncMock(spec=Safety),
"tool_runtime_api": AsyncMock(spec=ToolRuntime),
"tool_groups_api": AsyncMock(spec=ToolGroups),
"conversations_api": AsyncMock(spec=Conversations),
}
@pytest.fixture
def config(tmp_path):
from llama_stack.core.storage.datatypes import KVStoreReference, ResponsesStoreReference
from llama_stack.providers.inline.agents.meta_reference.config import AgentPersistenceConfig
return MetaReferenceAgentsImplConfig(
persistence_store={
"type": "sqlite",
"db_path": str(tmp_path / "test.db"),
},
responses_store={
"type": "sqlite",
"db_path": str(tmp_path / "test.db"),
},
persistence=AgentPersistenceConfig(
agent_state=KVStoreReference(
backend="kv_default",
namespace="agents",
),
responses=ResponsesStoreReference(
backend="sql_default",
table_name="responses",
),
)
)
@ -59,7 +80,8 @@ async def agents_impl(config, mock_apis):
mock_apis["safety_api"],
mock_apis["tool_runtime_api"],
mock_apis["tool_groups_api"],
{},
mock_apis["conversations_api"],
[],
)
await impl.initialize()
yield impl

View file

@ -24,6 +24,7 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseInputToolWebSearch,
OpenAIResponseMessage,
OpenAIResponseOutputMessageContentOutputText,
OpenAIResponseOutputMessageFunctionToolCall,
OpenAIResponseOutputMessageMCPCall,
OpenAIResponseOutputMessageWebSearchToolCall,
OpenAIResponseText,
@ -33,6 +34,7 @@ from llama_stack.apis.agents.openai_responses import (
from llama_stack.apis.inference import (
OpenAIAssistantMessageParam,
OpenAIChatCompletionContentPartTextParam,
OpenAIChatCompletionRequestWithExtraBody,
OpenAIDeveloperMessageParam,
OpenAIJSONSchema,
OpenAIResponseFormatJSONObject,
@ -41,7 +43,7 @@ from llama_stack.apis.inference import (
)
from llama_stack.apis.tools.tools import ListToolDefsResponse, ToolDef, ToolGroups, ToolInvocationResult, ToolRuntime
from llama_stack.core.access_control.access_control import default_policy
from llama_stack.core.datatypes import ResponsesStoreConfig
from llama_stack.core.storage.datatypes import ResponsesStoreReference, SqliteSqlStoreConfig
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
OpenAIResponsesImpl,
)
@ -49,7 +51,7 @@ from llama_stack.providers.utils.responses.responses_store import (
ResponsesStore,
_OpenAIResponseObjectWithInputAndMessages,
)
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
from tests.unit.providers.agents.meta_reference.fixtures import load_chat_completion_fixture
@ -83,9 +85,28 @@ def mock_vector_io_api():
return vector_io_api
@pytest.fixture
def mock_conversations_api():
"""Mock conversations API for testing."""
mock_api = AsyncMock()
return mock_api
@pytest.fixture
def mock_safety_api():
safety_api = AsyncMock()
return safety_api
@pytest.fixture
def openai_responses_impl(
mock_inference_api, mock_tool_groups_api, mock_tool_runtime_api, mock_responses_store, mock_vector_io_api
mock_inference_api,
mock_tool_groups_api,
mock_tool_runtime_api,
mock_responses_store,
mock_vector_io_api,
mock_safety_api,
mock_conversations_api,
):
return OpenAIResponsesImpl(
inference_api=mock_inference_api,
@ -93,6 +114,8 @@ def openai_responses_impl(
tool_runtime_api=mock_tool_runtime_api,
responses_store=mock_responses_store,
vector_io_api=mock_vector_io_api,
safety_api=mock_safety_api,
conversations_api=mock_conversations_api,
)
@ -148,12 +171,17 @@ async def test_create_openai_response_with_string_input(openai_responses_impl, m
chunks = [chunk async for chunk in result]
mock_inference_api.openai_chat_completion.assert_called_once_with(
model=model,
messages=[OpenAIUserMessageParam(role="user", content="What is the capital of Ireland?", name=None)],
response_format=None,
tools=None,
stream=True,
temperature=0.1,
OpenAIChatCompletionRequestWithExtraBody(
model=model,
messages=[OpenAIUserMessageParam(role="user", content="What is the capital of Ireland?", name=None)],
response_format=None,
tools=None,
stream=True,
temperature=0.1,
stream_options={
"include_usage": True,
},
)
)
# Should have content part events for text streaming
@ -240,13 +268,15 @@ async def test_create_openai_response_with_string_input_with_tools(openai_respon
# Verify
first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
assert first_call.kwargs["messages"][0].content == "What is the capital of Ireland?"
assert first_call.kwargs["tools"] is not None
assert first_call.kwargs["temperature"] == 0.1
first_params = first_call.args[0]
assert first_params.messages[0].content == "What is the capital of Ireland?"
assert first_params.tools is not None
assert first_params.temperature == 0.1
second_call = mock_inference_api.openai_chat_completion.call_args_list[1]
assert second_call.kwargs["messages"][-1].content == "Dublin"
assert second_call.kwargs["temperature"] == 0.1
second_params = second_call.args[0]
assert second_params.messages[-1].content == "Dublin"
assert second_params.temperature == 0.1
openai_responses_impl.tool_groups_api.get_tool.assert_called_once_with("web_search")
openai_responses_impl.tool_runtime_api.invoke_tool.assert_called_once_with(
@ -332,9 +362,10 @@ async def test_create_openai_response_with_tool_call_type_none(openai_responses_
# Verify inference API was called correctly (after iterating over result)
first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
assert first_call.kwargs["messages"][0].content == input_text
assert first_call.kwargs["tools"] is not None
assert first_call.kwargs["temperature"] == 0.1
first_params = first_call.args[0]
assert first_params.messages[0].content == input_text
assert first_params.tools is not None
assert first_params.temperature == 0.1
# Check response.created event (should have empty output)
assert len(chunks[0].response.output) == 0
@ -378,9 +409,10 @@ async def test_create_openai_response_with_tool_call_function_arguments_none(ope
def assert_common_expectations(chunks) -> None:
first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
assert first_call.kwargs["messages"][0].content == input_text
assert first_call.kwargs["tools"] is not None
assert first_call.kwargs["temperature"] == 0.1
first_params = first_call.args[0]
assert first_params.messages[0].content == input_text
assert first_params.tools is not None
assert first_params.temperature == 0.1
assert len(chunks[0].response.output) == 0
completed_chunk = chunks[-1]
assert completed_chunk.type == "response.completed"
@ -496,7 +528,9 @@ async def test_create_openai_response_with_multiple_messages(openai_responses_im
# Verify the the correct messages were sent to the inference API i.e.
# All of the responses message were convered to the chat completion message objects
inference_messages = mock_inference_api.openai_chat_completion.call_args_list[0].kwargs["messages"]
call_args = mock_inference_api.openai_chat_completion.call_args_list[0]
params = call_args.args[0]
inference_messages = params.messages
for i, m in enumerate(input_messages):
if isinstance(m.content, str):
assert inference_messages[i].content == m.content
@ -664,7 +698,8 @@ async def test_create_openai_response_with_instructions(openai_responses_impl, m
# Verify
mock_inference_api.openai_chat_completion.assert_called_once()
call_args = mock_inference_api.openai_chat_completion.call_args
sent_messages = call_args.kwargs["messages"]
params = call_args.args[0]
sent_messages = params.messages
# Check that instructions were prepended as a system message
assert len(sent_messages) == 2
@ -702,7 +737,8 @@ async def test_create_openai_response_with_instructions_and_multiple_messages(
# Verify
mock_inference_api.openai_chat_completion.assert_called_once()
call_args = mock_inference_api.openai_chat_completion.call_args
sent_messages = call_args.kwargs["messages"]
params = call_args.args[0]
sent_messages = params.messages
# Check that instructions were prepended as a system message
assert len(sent_messages) == 4 # 1 system + 3 input messages
@ -762,7 +798,8 @@ async def test_create_openai_response_with_instructions_and_previous_response(
# Verify
mock_inference_api.openai_chat_completion.assert_called_once()
call_args = mock_inference_api.openai_chat_completion.call_args
sent_messages = call_args.kwargs["messages"]
params = call_args.args[0]
sent_messages = params.messages
# Check that instructions were prepended as a system message
assert len(sent_messages) == 4, sent_messages
@ -778,6 +815,69 @@ async def test_create_openai_response_with_instructions_and_previous_response(
assert sent_messages[3].content == "Which is the largest?"
async def test_create_openai_response_with_previous_response_instructions(
openai_responses_impl, mock_responses_store, mock_inference_api
):
"""Test prepending instructions and previous response with instructions."""
input_item_message = OpenAIResponseMessage(
id="123",
content="Name some towns in Ireland",
role="user",
)
response_output_message = OpenAIResponseMessage(
id="123",
content="Galway, Longford, Sligo",
status="completed",
role="assistant",
)
response = _OpenAIResponseObjectWithInputAndMessages(
created_at=1,
id="resp_123",
model="fake_model",
output=[response_output_message],
status="completed",
text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")),
input=[input_item_message],
messages=[
OpenAIUserMessageParam(content="Name some towns in Ireland"),
OpenAIAssistantMessageParam(content="Galway, Longford, Sligo"),
],
instructions="You are a helpful assistant.",
)
mock_responses_store.get_response_object.return_value = response
model = "meta-llama/Llama-3.1-8B-Instruct"
instructions = "You are a geography expert. Provide concise answers."
mock_inference_api.openai_chat_completion.return_value = fake_stream()
# Execute
await openai_responses_impl.create_openai_response(
input="Which is the largest?", model=model, instructions=instructions, previous_response_id="123"
)
# Verify
mock_inference_api.openai_chat_completion.assert_called_once()
call_args = mock_inference_api.openai_chat_completion.call_args
params = call_args.args[0]
sent_messages = params.messages
# Check that instructions were prepended as a system message
# and that the previous response instructions were not carried over
assert len(sent_messages) == 4, sent_messages
assert sent_messages[0].role == "system"
assert sent_messages[0].content == instructions
# Check the rest of the messages were converted correctly
assert sent_messages[1].role == "user"
assert sent_messages[1].content == "Name some towns in Ireland"
assert sent_messages[2].role == "assistant"
assert sent_messages[2].content == "Galway, Longford, Sligo"
assert sent_messages[3].role == "user"
assert sent_messages[3].content == "Which is the largest?"
async def test_list_openai_response_input_items_delegation(openai_responses_impl, mock_responses_store):
"""Test that list_openai_response_input_items properly delegates to responses_store with correct parameters."""
# Setup
@ -818,8 +918,10 @@ async def test_responses_store_list_input_items_logic():
# Create mock store and response store
mock_sql_store = AsyncMock()
backend_name = "sql_responses_test"
register_sqlstore_backends({backend_name: SqliteSqlStoreConfig(db_path="mock_db_path")})
responses_store = ResponsesStore(
ResponsesStoreConfig(sql_store_config=SqliteSqlStoreConfig(db_path="mock_db_path")), policy=default_policy()
ResponsesStoreReference(backend=backend_name, table_name="responses"), policy=default_policy()
)
responses_store.sql_store = mock_sql_store
@ -1002,7 +1104,8 @@ async def test_reuse_mcp_tool_list(
)
assert len(mock_inference_api.openai_chat_completion.call_args_list) == 2
second_call = mock_inference_api.openai_chat_completion.call_args_list[1]
tools_seen = second_call.kwargs["tools"]
second_params = second_call.args[0]
tools_seen = second_params.tools
assert len(tools_seen) == 1
assert tools_seen[0]["function"]["name"] == "test_tool"
assert tools_seen[0]["function"]["description"] == "a test tool"
@ -1049,8 +1152,9 @@ async def test_create_openai_response_with_text_format(
# Verify
first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
assert first_call.kwargs["messages"][0].content == input_text
assert first_call.kwargs["response_format"] == response_format
first_params = first_call.args[0]
assert first_params.messages[0].content == input_text
assert first_params.response_format == response_format
async def test_create_openai_response_with_invalid_text_format(openai_responses_impl, mock_inference_api):
@ -1066,3 +1170,75 @@ async def test_create_openai_response_with_invalid_text_format(openai_responses_
model=model,
text=OpenAIResponseText(format={"type": "invalid"}),
)
async def test_create_openai_response_with_output_types_as_input(
openai_responses_impl, mock_inference_api, mock_responses_store
):
"""Test that response outputs can be used as inputs in multi-turn conversations.
Before adding OpenAIResponseOutput types to OpenAIResponseInput,
creating a _OpenAIResponseObjectWithInputAndMessages with some output types
in the input field would fail with a Pydantic ValidationError.
This test simulates storing a response where the input contains output message
types (MCP calls, function calls), which happens in multi-turn conversations.
"""
model = "meta-llama/Llama-3.1-8B-Instruct"
# Mock the inference response
mock_inference_api.openai_chat_completion.return_value = fake_stream()
# Create a response with store=True to trigger the storage path
result = await openai_responses_impl.create_openai_response(
input="What's the weather?",
model=model,
stream=True,
temperature=0.1,
store=True,
)
# Consume the stream
_ = [chunk async for chunk in result]
# Verify store was called
assert mock_responses_store.store_response_object.called
# Get the stored data
store_call_args = mock_responses_store.store_response_object.call_args
stored_response = store_call_args.kwargs["response_object"]
# Now simulate a multi-turn conversation where outputs become inputs
input_with_output_types = [
OpenAIResponseMessage(role="user", content="What's the weather?", name=None),
# These output types need to be valid OpenAIResponseInput
OpenAIResponseOutputMessageFunctionToolCall(
call_id="call_123",
name="get_weather",
arguments='{"city": "Tokyo"}',
type="function_call",
),
OpenAIResponseOutputMessageMCPCall(
id="mcp_456",
type="mcp_call",
server_label="weather_server",
name="get_temperature",
arguments='{"location": "Tokyo"}',
output="25°C",
),
]
# This simulates storing a response in a multi-turn conversation
# where previous outputs are included in the input.
stored_with_outputs = _OpenAIResponseObjectWithInputAndMessages(
id=stored_response.id,
created_at=stored_response.created_at,
model=stored_response.model,
status=stored_response.status,
output=stored_response.output,
input=input_with_output_types, # This will trigger Pydantic validation
messages=None,
)
assert stored_with_outputs.input == input_with_output_types
assert len(stored_with_outputs.input) == 3

View file

@ -0,0 +1,249 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseMessage,
OpenAIResponseObject,
OpenAIResponseObjectStreamResponseCompleted,
OpenAIResponseObjectStreamResponseOutputItemDone,
OpenAIResponseOutputMessageContentOutputText,
)
from llama_stack.apis.common.errors import (
ConversationNotFoundError,
InvalidConversationIdError,
)
from llama_stack.apis.conversations.conversations import (
ConversationItemList,
)
# Import existing fixtures from the main responses test file
pytest_plugins = ["tests.unit.providers.agents.meta_reference.test_openai_responses"]
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
OpenAIResponsesImpl,
)
@pytest.fixture
def responses_impl_with_conversations(
mock_inference_api,
mock_tool_groups_api,
mock_tool_runtime_api,
mock_responses_store,
mock_vector_io_api,
mock_conversations_api,
mock_safety_api,
):
"""Create OpenAIResponsesImpl instance with conversations API."""
return OpenAIResponsesImpl(
inference_api=mock_inference_api,
tool_groups_api=mock_tool_groups_api,
tool_runtime_api=mock_tool_runtime_api,
responses_store=mock_responses_store,
vector_io_api=mock_vector_io_api,
conversations_api=mock_conversations_api,
safety_api=mock_safety_api,
)
class TestConversationValidation:
"""Test conversation ID validation logic."""
async def test_nonexistent_conversation_raises_error(
self, responses_impl_with_conversations, mock_conversations_api
):
"""Test that ConversationNotFoundError is raised for non-existent conversation."""
conv_id = "conv_nonexistent"
# Mock conversation not found
mock_conversations_api.list_items.side_effect = ConversationNotFoundError("conv_nonexistent")
with pytest.raises(ConversationNotFoundError):
await responses_impl_with_conversations.create_openai_response(
input="Hello", model="test-model", conversation=conv_id, stream=False
)
class TestMessageSyncing:
"""Test message syncing to conversations."""
async def test_sync_response_to_conversation_simple(
self, responses_impl_with_conversations, mock_conversations_api
):
"""Test syncing simple response to conversation."""
conv_id = "conv_test123"
input_text = "What are the 5 Ds of dodgeball?"
# Output items (what the model generated)
output_items = [
OpenAIResponseMessage(
id="msg_response",
content=[
OpenAIResponseOutputMessageContentOutputText(
text="The 5 Ds are: Dodge, Duck, Dip, Dive, and Dodge.", type="output_text", annotations=[]
)
],
role="assistant",
status="completed",
type="message",
)
]
await responses_impl_with_conversations._sync_response_to_conversation(conv_id, input_text, output_items)
# should call add_items with user input and assistant response
mock_conversations_api.add_items.assert_called_once()
call_args = mock_conversations_api.add_items.call_args
assert call_args[0][0] == conv_id # conversation_id
items = call_args[0][1] # conversation_items
assert len(items) == 2
# User message
assert items[0].type == "message"
assert items[0].role == "user"
assert items[0].content[0].type == "input_text"
assert items[0].content[0].text == input_text
# Assistant message
assert items[1].type == "message"
assert items[1].role == "assistant"
async def test_sync_response_to_conversation_api_error(
self, responses_impl_with_conversations, mock_conversations_api
):
mock_conversations_api.add_items.side_effect = Exception("API Error")
output_items = []
# matching the behavior of OpenAI here
with pytest.raises(Exception, match="API Error"):
await responses_impl_with_conversations._sync_response_to_conversation(
"conv_test123", "Hello", output_items
)
async def test_sync_with_list_input(self, responses_impl_with_conversations, mock_conversations_api):
"""Test syncing with list of input messages."""
conv_id = "conv_test123"
input_messages = [
OpenAIResponseMessage(role="user", content=[{"type": "input_text", "text": "First message"}]),
]
output_items = [
OpenAIResponseMessage(
id="msg_response",
content=[OpenAIResponseOutputMessageContentOutputText(text="Response", type="output_text")],
role="assistant",
status="completed",
type="message",
)
]
await responses_impl_with_conversations._sync_response_to_conversation(conv_id, input_messages, output_items)
mock_conversations_api.add_items.assert_called_once()
call_args = mock_conversations_api.add_items.call_args
items = call_args[0][1]
# Should have input message + output message
assert len(items) == 2
class TestIntegrationWorkflow:
"""Integration tests for the full conversation workflow."""
async def test_create_response_with_valid_conversation(
self, responses_impl_with_conversations, mock_conversations_api
):
"""Test creating a response with a valid conversation parameter."""
mock_conversations_api.list_items.return_value = ConversationItemList(
data=[], first_id=None, has_more=False, last_id=None, object="list"
)
async def mock_streaming_response(*args, **kwargs):
message_item = OpenAIResponseMessage(
id="msg_response",
content=[
OpenAIResponseOutputMessageContentOutputText(
text="Test response", type="output_text", annotations=[]
)
],
role="assistant",
status="completed",
type="message",
)
# Emit output_item.done event first (needed for conversation sync)
yield OpenAIResponseObjectStreamResponseOutputItemDone(
response_id="resp_test123",
item=message_item,
output_index=0,
sequence_number=1,
type="response.output_item.done",
)
# Then emit response.completed
mock_response = OpenAIResponseObject(
id="resp_test123",
created_at=1234567890,
model="test-model",
object="response",
output=[message_item],
status="completed",
)
yield OpenAIResponseObjectStreamResponseCompleted(response=mock_response, type="response.completed")
responses_impl_with_conversations._create_streaming_response = mock_streaming_response
input_text = "Hello, how are you?"
conversation_id = "conv_test123"
response = await responses_impl_with_conversations.create_openai_response(
input=input_text, model="test-model", conversation=conversation_id, stream=False
)
assert response is not None
assert response.id == "resp_test123"
# Note: conversation sync happens inside _create_streaming_response,
# which we're mocking here, so we can't test it in this unit test.
# The sync logic is tested separately in TestMessageSyncing.
async def test_create_response_with_invalid_conversation_id(self, responses_impl_with_conversations):
"""Test creating a response with an invalid conversation ID."""
with pytest.raises(InvalidConversationIdError) as exc_info:
await responses_impl_with_conversations.create_openai_response(
input="Hello", model="test-model", conversation="invalid_id", stream=False
)
assert "Expected an ID that begins with 'conv_'" in str(exc_info.value)
async def test_create_response_with_nonexistent_conversation(
self, responses_impl_with_conversations, mock_conversations_api
):
"""Test creating a response with a non-existent conversation."""
mock_conversations_api.list_items.side_effect = ConversationNotFoundError("conv_nonexistent")
with pytest.raises(ConversationNotFoundError) as exc_info:
await responses_impl_with_conversations.create_openai_response(
input="Hello", model="test-model", conversation="conv_nonexistent", stream=False
)
assert "not found" in str(exc_info.value)
async def test_conversation_and_previous_response_id(
self, responses_impl_with_conversations, mock_conversations_api, mock_responses_store
):
with pytest.raises(ValueError) as exc_info:
await responses_impl_with_conversations.create_openai_response(
input="test", model="test", conversation="conv_123", previous_response_id="resp_123"
)
assert "Mutually exclusive parameters" in str(exc_info.value)
assert "previous_response_id" in str(exc_info.value)
assert "conversation" in str(exc_info.value)

View file

@ -0,0 +1,155 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import AsyncMock
import pytest
from llama_stack.apis.agents.agents import ResponseGuardrailSpec
from llama_stack.apis.safety import ModerationObject, ModerationObjectResults
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
OpenAIResponsesImpl,
)
from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
extract_guardrail_ids,
run_guardrails,
)
@pytest.fixture
def mock_apis():
"""Create mock APIs for testing."""
return {
"inference_api": AsyncMock(),
"tool_groups_api": AsyncMock(),
"tool_runtime_api": AsyncMock(),
"responses_store": AsyncMock(),
"vector_io_api": AsyncMock(),
"conversations_api": AsyncMock(),
"safety_api": AsyncMock(),
}
@pytest.fixture
def responses_impl(mock_apis):
"""Create OpenAIResponsesImpl instance with mocked dependencies."""
return OpenAIResponsesImpl(**mock_apis)
def test_extract_guardrail_ids_from_strings(responses_impl):
"""Test extraction from simple string guardrail IDs."""
guardrails = ["llama-guard", "content-filter", "nsfw-detector"]
result = extract_guardrail_ids(guardrails)
assert result == ["llama-guard", "content-filter", "nsfw-detector"]
def test_extract_guardrail_ids_from_objects(responses_impl):
"""Test extraction from ResponseGuardrailSpec objects."""
guardrails = [
ResponseGuardrailSpec(type="llama-guard"),
ResponseGuardrailSpec(type="content-filter"),
]
result = extract_guardrail_ids(guardrails)
assert result == ["llama-guard", "content-filter"]
def test_extract_guardrail_ids_mixed_formats(responses_impl):
"""Test extraction from mixed string and object formats."""
guardrails = [
"llama-guard",
ResponseGuardrailSpec(type="content-filter"),
"nsfw-detector",
]
result = extract_guardrail_ids(guardrails)
assert result == ["llama-guard", "content-filter", "nsfw-detector"]
def test_extract_guardrail_ids_none_input(responses_impl):
"""Test extraction with None input."""
result = extract_guardrail_ids(None)
assert result == []
def test_extract_guardrail_ids_empty_list(responses_impl):
"""Test extraction with empty list."""
result = extract_guardrail_ids([])
assert result == []
def test_extract_guardrail_ids_unknown_format(responses_impl):
"""Test extraction with unknown guardrail format raises ValueError."""
# Create an object that's neither string nor ResponseGuardrailSpec
unknown_object = {"invalid": "format"} # Plain dict, not ResponseGuardrailSpec
guardrails = ["valid-guardrail", unknown_object, "another-guardrail"]
with pytest.raises(ValueError, match="Unknown guardrail format.*expected str or ResponseGuardrailSpec"):
extract_guardrail_ids(guardrails)
@pytest.fixture
def mock_safety_api():
"""Create mock safety API for guardrails testing."""
safety_api = AsyncMock()
# Mock the routing table and shields list for guardrails lookup
safety_api.routing_table = AsyncMock()
shield = AsyncMock()
shield.identifier = "llama-guard"
shield.provider_resource_id = "llama-guard-model"
safety_api.routing_table.list_shields.return_value = AsyncMock(data=[shield])
return safety_api
async def test_run_guardrails_no_violation(mock_safety_api):
"""Test guardrails validation with no violations."""
text = "Hello world"
guardrail_ids = ["llama-guard"]
# Mock moderation to return non-flagged content
unflagged_result = ModerationObjectResults(flagged=False, categories={"violence": False})
mock_moderation_object = ModerationObject(id="test-mod-id", model="llama-guard-model", results=[unflagged_result])
mock_safety_api.run_moderation.return_value = mock_moderation_object
result = await run_guardrails(mock_safety_api, text, guardrail_ids)
assert result is None
# Verify run_moderation was called with the correct model
mock_safety_api.run_moderation.assert_called_once()
call_args = mock_safety_api.run_moderation.call_args
assert call_args[1]["model"] == "llama-guard-model"
async def test_run_guardrails_with_violation(mock_safety_api):
"""Test guardrails validation with safety violation."""
text = "Harmful content"
guardrail_ids = ["llama-guard"]
# Mock moderation to return flagged content
flagged_result = ModerationObjectResults(
flagged=True,
categories={"violence": True},
user_message="Content flagged by moderation",
metadata={"violation_type": ["S1"]},
)
mock_moderation_object = ModerationObject(id="test-mod-id", model="llama-guard-model", results=[flagged_result])
mock_safety_api.run_moderation.return_value = mock_moderation_object
result = await run_guardrails(mock_safety_api, text, guardrail_ids)
assert result == "Content flagged by moderation (flagged for: violence) (violation type: S1)"
async def test_run_guardrails_empty_inputs(mock_safety_api):
"""Test guardrails validation with empty inputs."""
# Test empty guardrail_ids
result = await run_guardrails(mock_safety_api, "test", [])
assert result is None
# Test empty text
result = await run_guardrails(mock_safety_api, "", ["llama-guard"])
assert result is None
# Test both empty
result = await run_guardrails(mock_safety_api, "", [])
assert result is None

View file

@ -12,10 +12,10 @@ from unittest.mock import AsyncMock
import pytest
from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig
from llama_stack.providers.inline.batches.reference.batches import ReferenceBatchesImpl
from llama_stack.providers.inline.batches.reference.config import ReferenceBatchesImplConfig
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore import kvstore_impl, register_kvstore_backends
@pytest.fixture
@ -23,8 +23,10 @@ async def provider():
"""Create a test provider instance with temporary database."""
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test_batches.db"
backend_name = "kv_batches_test"
kvstore_config = SqliteKVStoreConfig(db_path=str(db_path))
config = ReferenceBatchesImplConfig(kvstore=kvstore_config)
register_kvstore_backends({backend_name: kvstore_config})
config = ReferenceBatchesImplConfig(kvstore=KVStoreReference(backend=backend_name, namespace="batches"))
# Create kvstore and mock APIs
kvstore = await kvstore_impl(config.kvstore)

View file

@ -213,7 +213,6 @@ class TestReferenceBatchesImpl:
@pytest.mark.parametrize(
"endpoint",
[
"/v1/embeddings",
"/v1/invalid/endpoint",
"",
],
@ -765,3 +764,12 @@ class TestReferenceBatchesImpl:
await asyncio.sleep(0.042) # let tasks start
assert active_batches == 2, f"Expected 2 active batches, got {active_batches}"
async def test_create_batch_embeddings_endpoint(self, provider):
"""Test that batch creation succeeds with embeddings endpoint."""
batch = await provider.create_batch(
input_file_id="file_123",
endpoint="/v1/embeddings",
completion_window="24h",
)
assert batch.endpoint == "/v1/embeddings"

View file

@ -8,8 +8,9 @@ import boto3
import pytest
from moto import mock_aws
from llama_stack.core.storage.datatypes import SqliteSqlStoreConfig, SqlStoreReference
from llama_stack.providers.remote.files.s3 import S3FilesImplConfig, get_adapter_impl
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
class MockUploadFile:
@ -38,11 +39,13 @@ def sample_text_file2():
def s3_config(tmp_path):
db_path = tmp_path / "s3_files_metadata.db"
backend_name = f"sql_s3_{tmp_path.name}"
register_sqlstore_backends({backend_name: SqliteSqlStoreConfig(db_path=db_path.as_posix())})
return S3FilesImplConfig(
bucket_name=f"test-bucket-{tmp_path.name}",
region="not-a-region",
auto_create_bucket=True,
metadata_store=SqliteSqlStoreConfig(db_path=db_path.as_posix()),
metadata_store=SqlStoreReference(backend=backend_name, table_name="s3_files_metadata"),
)

View file

@ -10,47 +10,124 @@ from unittest.mock import MagicMock
import pytest
from llama_stack.core.request_headers import request_provider_data_context
from llama_stack.providers.remote.inference.anthropic.anthropic import AnthropicInferenceAdapter
from llama_stack.providers.remote.inference.anthropic.config import AnthropicConfig
from llama_stack.providers.remote.inference.cerebras.cerebras import CerebrasInferenceAdapter
from llama_stack.providers.remote.inference.cerebras.config import CerebrasImplConfig
from llama_stack.providers.remote.inference.databricks.config import DatabricksImplConfig
from llama_stack.providers.remote.inference.databricks.databricks import DatabricksInferenceAdapter
from llama_stack.providers.remote.inference.fireworks.config import FireworksImplConfig
from llama_stack.providers.remote.inference.fireworks.fireworks import FireworksInferenceAdapter
from llama_stack.providers.remote.inference.gemini.config import GeminiConfig
from llama_stack.providers.remote.inference.gemini.gemini import GeminiInferenceAdapter
from llama_stack.providers.remote.inference.groq.config import GroqConfig
from llama_stack.providers.remote.inference.groq.groq import GroqInferenceAdapter
from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
from llama_stack.providers.remote.inference.llama_openai_compat.llama import LlamaCompatInferenceAdapter
from llama_stack.providers.remote.inference.nvidia.config import NVIDIAConfig
from llama_stack.providers.remote.inference.nvidia.nvidia import NVIDIAInferenceAdapter
from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
from llama_stack.providers.remote.inference.openai.openai import OpenAIInferenceAdapter
from llama_stack.providers.remote.inference.runpod.config import RunpodImplConfig
from llama_stack.providers.remote.inference.runpod.runpod import RunpodInferenceAdapter
from llama_stack.providers.remote.inference.sambanova.config import SambaNovaImplConfig
from llama_stack.providers.remote.inference.sambanova.sambanova import SambaNovaInferenceAdapter
from llama_stack.providers.remote.inference.together.config import TogetherImplConfig
from llama_stack.providers.remote.inference.together.together import TogetherInferenceAdapter
from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig
from llama_stack.providers.remote.inference.vllm.vllm import VLLMInferenceAdapter
from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
from llama_stack.providers.remote.inference.watsonx.watsonx import WatsonXInferenceAdapter
@pytest.mark.parametrize(
"config_cls,adapter_cls,provider_data_validator",
"config_cls,adapter_cls,provider_data_validator,config_params",
[
(
GroqConfig,
GroqInferenceAdapter,
"llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator",
{},
),
(
OpenAIConfig,
OpenAIInferenceAdapter,
"llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator",
{},
),
(
TogetherImplConfig,
TogetherInferenceAdapter,
"llama_stack.providers.remote.inference.together.TogetherProviderDataValidator",
{},
),
(
LlamaCompatConfig,
LlamaCompatInferenceAdapter,
"llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator",
{},
),
(
CerebrasImplConfig,
CerebrasInferenceAdapter,
"llama_stack.providers.remote.inference.cerebras.config.CerebrasProviderDataValidator",
{},
),
(
DatabricksImplConfig,
DatabricksInferenceAdapter,
"llama_stack.providers.remote.inference.databricks.config.DatabricksProviderDataValidator",
{},
),
(
NVIDIAConfig,
NVIDIAInferenceAdapter,
"llama_stack.providers.remote.inference.nvidia.config.NVIDIAProviderDataValidator",
{},
),
(
RunpodImplConfig,
RunpodInferenceAdapter,
"llama_stack.providers.remote.inference.runpod.config.RunpodProviderDataValidator",
{},
),
(
FireworksImplConfig,
FireworksInferenceAdapter,
"llama_stack.providers.remote.inference.fireworks.FireworksProviderDataValidator",
{},
),
(
AnthropicConfig,
AnthropicInferenceAdapter,
"llama_stack.providers.remote.inference.anthropic.config.AnthropicProviderDataValidator",
{},
),
(
GeminiConfig,
GeminiInferenceAdapter,
"llama_stack.providers.remote.inference.gemini.config.GeminiProviderDataValidator",
{},
),
(
SambaNovaImplConfig,
SambaNovaInferenceAdapter,
"llama_stack.providers.remote.inference.sambanova.config.SambaNovaProviderDataValidator",
{},
),
(
VLLMInferenceAdapterConfig,
VLLMInferenceAdapter,
"llama_stack.providers.remote.inference.vllm.VLLMProviderDataValidator",
{
"url": "http://fake",
},
),
],
)
def test_openai_provider_data_used(config_cls, adapter_cls, provider_data_validator: str):
def test_openai_provider_data_used(config_cls, adapter_cls, provider_data_validator: str, config_params: dict):
"""Ensure the OpenAI provider does not cache api keys across client requests"""
inference_adapter = adapter_cls(config=config_cls())
inference_adapter = adapter_cls(config=config_cls(**config_params))
inference_adapter.__provider_spec__ = MagicMock()
inference_adapter.__provider_spec__.provider_data_validator = provider_data_validator

View file

@ -15,16 +15,16 @@ from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOp
# Test fixtures and helper classes
class TestConfig(BaseModel):
class FakeConfig(BaseModel):
api_key: str | None = Field(default=None)
class TestProviderDataValidator(BaseModel):
class FakeProviderDataValidator(BaseModel):
test_api_key: str | None = Field(default=None)
class TestLiteLLMAdapter(LiteLLMOpenAIMixin):
def __init__(self, config: TestConfig):
class FakeLiteLLMAdapter(LiteLLMOpenAIMixin):
def __init__(self, config: FakeConfig):
super().__init__(
litellm_provider_name="test",
api_key_from_config=config.api_key,
@ -36,11 +36,11 @@ class TestLiteLLMAdapter(LiteLLMOpenAIMixin):
@pytest.fixture
def adapter_with_config_key():
"""Fixture to create adapter with API key in config"""
config = TestConfig(api_key="config-api-key")
adapter = TestLiteLLMAdapter(config)
config = FakeConfig(api_key="config-api-key")
adapter = FakeLiteLLMAdapter(config)
adapter.__provider_spec__ = MagicMock()
adapter.__provider_spec__.provider_data_validator = (
"tests.unit.providers.inference.test_litellm_openai_mixin.TestProviderDataValidator"
"tests.unit.providers.inference.test_litellm_openai_mixin.FakeProviderDataValidator"
)
return adapter
@ -48,11 +48,11 @@ def adapter_with_config_key():
@pytest.fixture
def adapter_without_config_key():
"""Fixture to create adapter without API key in config"""
config = TestConfig(api_key=None)
adapter = TestLiteLLMAdapter(config)
config = FakeConfig(api_key=None)
adapter = FakeLiteLLMAdapter(config)
adapter.__provider_spec__ = MagicMock()
adapter.__provider_spec__.provider_data_validator = (
"tests.unit.providers.inference.test_litellm_openai_mixin.TestProviderDataValidator"
"tests.unit.providers.inference.test_litellm_openai_mixin.FakeProviderDataValidator"
)
return adapter

View file

@ -13,10 +13,16 @@ import pytest
from llama_stack.apis.inference import (
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
OpenAIChatCompletionRequestWithExtraBody,
OpenAIChoice,
OpenAICompletion,
OpenAICompletionChoice,
OpenAICompletionRequestWithExtraBody,
ToolChoice,
)
from llama_stack.apis.models import Model
from llama_stack.core.routers.inference import InferenceRouter
from llama_stack.core.routing_tables.models import ModelsRoutingTable
from llama_stack.providers.datatypes import HealthStatus
from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig
from llama_stack.providers.remote.inference.vllm.vllm import VLLMInferenceAdapter
@ -56,13 +62,14 @@ async def test_old_vllm_tool_choice(vllm_inference_adapter):
mock_client_property.return_value = mock_client
# No tools but auto tool choice
await vllm_inference_adapter.openai_chat_completion(
"mock-model",
[],
params = OpenAIChatCompletionRequestWithExtraBody(
model="mock-model",
messages=[{"role": "user", "content": "test"}],
stream=False,
tools=None,
tool_choice=ToolChoice.auto.value,
)
await vllm_inference_adapter.openai_chat_completion(params)
mock_client.chat.completions.create.assert_called()
call_args = mock_client.chat.completions.create.call_args
# Ensure tool_choice gets converted to none for older vLLM versions
@ -171,9 +178,12 @@ async def test_openai_chat_completion_is_async(vllm_inference_adapter):
)
async def do_inference():
await vllm_inference_adapter.openai_chat_completion(
"mock-model", messages=["one fish", "two fish"], stream=False
params = OpenAIChatCompletionRequestWithExtraBody(
model="mock-model",
messages=[{"role": "user", "content": "one fish two fish"}],
stream=False,
)
await vllm_inference_adapter.openai_chat_completion(params)
with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_create_client:
mock_client = MagicMock()
@ -186,3 +196,148 @@ async def test_openai_chat_completion_is_async(vllm_inference_adapter):
assert mock_create_client.call_count == 4 # no cheating
assert total_time < (sleep_time * 2), f"Total time taken: {total_time}s exceeded expected max"
async def test_vllm_completion_extra_body():
"""
Test that vLLM-specific guided_choice and prompt_logprobs parameters are correctly forwarded
via extra_body to the underlying OpenAI client through the InferenceRouter.
"""
# Set up the vLLM adapter
config = VLLMInferenceAdapterConfig(url="http://mocked.localhost:12345")
vllm_adapter = VLLMInferenceAdapter(config=config)
vllm_adapter.__provider_id__ = "vllm"
await vllm_adapter.initialize()
# Create a mock model store
mock_model_store = AsyncMock()
mock_model = Model(identifier="mock-model", provider_resource_id="mock-model", provider_id="vllm")
mock_model_store.get_model.return_value = mock_model
mock_model_store.has_model.return_value = True
# Create a mock dist_registry
mock_dist_registry = MagicMock()
mock_dist_registry.get = AsyncMock(return_value=mock_model)
mock_dist_registry.set = AsyncMock()
# Set up the routing table
routing_table = ModelsRoutingTable(
impls_by_provider_id={"vllm": vllm_adapter},
dist_registry=mock_dist_registry,
policy=[],
)
# Inject the model store into the adapter
vllm_adapter.model_store = routing_table
# Create the InferenceRouter
router = InferenceRouter(routing_table=routing_table)
# Patch the OpenAI client
with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_client_property:
mock_client = MagicMock()
mock_client.completions.create = AsyncMock(
return_value=OpenAICompletion(
id="cmpl-abc123",
created=1,
model="mock-model",
choices=[
OpenAICompletionChoice(
text="joy",
finish_reason="stop",
index=0,
)
],
)
)
mock_client_property.return_value = mock_client
# Test with guided_choice and prompt_logprobs as extra fields
params = OpenAICompletionRequestWithExtraBody(
model="mock-model",
prompt="I am feeling happy",
stream=False,
guided_choice=["joy", "sadness"],
prompt_logprobs=5,
)
await router.openai_completion(params)
# Verify that the client was called with extra_body containing both parameters
mock_client.completions.create.assert_called_once()
call_kwargs = mock_client.completions.create.call_args.kwargs
assert "extra_body" in call_kwargs
assert "guided_choice" in call_kwargs["extra_body"]
assert call_kwargs["extra_body"]["guided_choice"] == ["joy", "sadness"]
assert "prompt_logprobs" in call_kwargs["extra_body"]
assert call_kwargs["extra_body"]["prompt_logprobs"] == 5
async def test_vllm_chat_completion_extra_body():
"""
Test that vLLM-specific parameters (e.g., chat_template_kwargs) are correctly forwarded
via extra_body to the underlying OpenAI client through the InferenceRouter for chat completion.
"""
# Set up the vLLM adapter
config = VLLMInferenceAdapterConfig(url="http://mocked.localhost:12345")
vllm_adapter = VLLMInferenceAdapter(config=config)
vllm_adapter.__provider_id__ = "vllm"
await vllm_adapter.initialize()
# Create a mock model store
mock_model_store = AsyncMock()
mock_model = Model(identifier="mock-model", provider_resource_id="mock-model", provider_id="vllm")
mock_model_store.get_model.return_value = mock_model
mock_model_store.has_model.return_value = True
# Create a mock dist_registry
mock_dist_registry = MagicMock()
mock_dist_registry.get = AsyncMock(return_value=mock_model)
mock_dist_registry.set = AsyncMock()
# Set up the routing table
routing_table = ModelsRoutingTable(
impls_by_provider_id={"vllm": vllm_adapter},
dist_registry=mock_dist_registry,
policy=[],
)
# Inject the model store into the adapter
vllm_adapter.model_store = routing_table
# Create the InferenceRouter
router = InferenceRouter(routing_table=routing_table)
# Patch the OpenAI client
with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_client_property:
mock_client = MagicMock()
mock_client.chat.completions.create = AsyncMock(
return_value=OpenAIChatCompletion(
id="chatcmpl-abc123",
created=1,
model="mock-model",
choices=[
OpenAIChoice(
message=OpenAIAssistantMessageParam(
content="test response",
),
finish_reason="stop",
index=0,
)
],
)
)
mock_client_property.return_value = mock_client
# Test with chat_template_kwargs as extra field
params = OpenAIChatCompletionRequestWithExtraBody(
model="mock-model",
messages=[{"role": "user", "content": "test"}],
stream=False,
chat_template_kwargs={"thinking": True},
)
await router.openai_chat_completion(params)
# Verify that the client was called with extra_body containing chat_template_kwargs
mock_client.chat.completions.create.assert_called_once()
call_kwargs = mock_client.chat.completions.create.call_args.kwargs
assert "extra_body" in call_kwargs
assert "chat_template_kwargs" in call_kwargs["extra_body"]
assert call_kwargs["extra_body"]["chat_template_kwargs"] == {"thinking": True}

View file

@ -4,10 +4,45 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import AsyncMock
import pytest
from llama_stack.apis.tools import ToolDef
from llama_stack.providers.inline.agents.meta_reference.responses.streaming import (
convert_tooldef_to_chat_tool,
)
from llama_stack.providers.inline.agents.meta_reference.responses.types import ChatCompletionContext
@pytest.fixture
def mock_safety_api():
safety_api = AsyncMock()
# Mock the routing table and shields list for guardrails lookup
safety_api.routing_table = AsyncMock()
shield = AsyncMock()
shield.identifier = "llama-guard"
shield.provider_resource_id = "llama-guard-model"
safety_api.routing_table.list_shields.return_value = AsyncMock(data=[shield])
# Mock run_moderation to return non-flagged result by default
safety_api.run_moderation.return_value = AsyncMock(flagged=False)
return safety_api
@pytest.fixture
def mock_inference_api():
inference_api = AsyncMock()
return inference_api
@pytest.fixture
def mock_context():
context = AsyncMock(spec=ChatCompletionContext)
# Add required attributes that StreamingResponseOrchestrator expects
context.tool_context = AsyncMock()
context.tool_context.previous_tools = {}
context.messages = []
return context
def test_convert_tooldef_to_chat_tool_preserves_items_field():

View file

@ -19,7 +19,7 @@ from llama_stack.providers.remote.safety.nvidia.config import NVIDIASafetyConfig
from llama_stack.providers.remote.safety.nvidia.nvidia import NVIDIASafetyAdapter
class TestNVIDIASafetyAdapter(NVIDIASafetyAdapter):
class FakeNVIDIASafetyAdapter(NVIDIASafetyAdapter):
"""Test implementation that provides the required shield_store."""
def __init__(self, config: NVIDIASafetyConfig, shield_store):
@ -41,7 +41,7 @@ def nvidia_adapter():
shield_store = AsyncMock()
shield_store.get_shield = AsyncMock()
adapter = TestNVIDIASafetyAdapter(config=config, shield_store=shield_store)
adapter = FakeNVIDIASafetyAdapter(config=config, shield_store=shield_store)
return adapter

View file

@ -12,7 +12,7 @@ from unittest.mock import AsyncMock, MagicMock, Mock, PropertyMock, patch
import pytest
from pydantic import BaseModel, Field
from llama_stack.apis.inference import Model, OpenAIUserMessageParam
from llama_stack.apis.inference import Model, OpenAIChatCompletionRequestWithExtraBody, OpenAIUserMessageParam
from llama_stack.apis.models import ModelType
from llama_stack.core.request_headers import request_provider_data_context
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
@ -23,10 +23,10 @@ class OpenAIMixinImpl(OpenAIMixin):
__provider_id__: str = "test-provider"
def get_api_key(self) -> str:
raise NotImplementedError("This method should be mocked in tests")
return "test-api-key"
def get_base_url(self) -> str:
raise NotImplementedError("This method should be mocked in tests")
return "http://test-base-url"
class OpenAIMixinWithEmbeddingsImpl(OpenAIMixinImpl):
@ -38,6 +38,28 @@ class OpenAIMixinWithEmbeddingsImpl(OpenAIMixinImpl):
}
class OpenAIMixinWithCustomModelConstruction(OpenAIMixinImpl):
"""Test implementation that uses construct_model_from_identifier to add rerank models"""
embedding_model_metadata: dict[str, dict[str, int]] = {
"text-embedding-3-small": {"embedding_dimension": 1536, "context_length": 8192},
"text-embedding-ada-002": {"embedding_dimension": 1536, "context_length": 8192},
}
# Adds rerank models via construct_model_from_identifier
rerank_model_ids: set[str] = {"rerank-model-1", "rerank-model-2"}
def construct_model_from_identifier(self, identifier: str) -> Model:
if identifier in self.rerank_model_ids:
return Model(
provider_id=self.__provider_id__, # type: ignore[attr-defined]
provider_resource_id=identifier,
identifier=identifier,
model_type=ModelType.rerank,
)
return super().construct_model_from_identifier(identifier)
@pytest.fixture
def mixin():
"""Create a test instance of OpenAIMixin with mocked model_store"""
@ -62,6 +84,13 @@ def mixin_with_embeddings():
return OpenAIMixinWithEmbeddingsImpl(config=config)
@pytest.fixture
def mixin_with_custom_model_construction():
"""Create a test instance using custom construct_model_from_identifier"""
config = RemoteInferenceProviderConfig()
return OpenAIMixinWithCustomModelConstruction(config=config)
@pytest.fixture
def mock_models():
"""Create multiple mock OpenAI model objects"""
@ -113,6 +142,19 @@ def mock_client_context():
return _mock_client_context
def _assert_models_match_expected(actual_models, expected_models):
"""Verify the models match expected attributes.
Args:
actual_models: List of models to verify
expected_models: Mapping of model identifier to expected attribute values
"""
for identifier, expected_attrs in expected_models.items():
model = next(m for m in actual_models if m.identifier == identifier)
for attr_name, expected_value in expected_attrs.items():
assert getattr(model, attr_name) == expected_value
class TestOpenAIMixinListModels:
"""Test cases for the list_models method"""
@ -205,7 +247,7 @@ class TestOpenAIMixinCheckModelAvailability:
assert await mixin.check_model_availability("pre-registered-model")
# Should not call the provider's list_models since model was found in store
mock_client_with_models.models.list.assert_not_called()
mock_model_store.has_model.assert_called_once_with("pre-registered-model")
mock_model_store.has_model.assert_called_once_with("test-provider/pre-registered-model")
async def test_check_model_availability_fallback_to_provider_when_not_in_store(
self, mixin, mock_client_with_models, mock_client_context
@ -222,7 +264,7 @@ class TestOpenAIMixinCheckModelAvailability:
assert await mixin.check_model_availability("some-mock-model-id")
# Should call the provider's list_models since model was not found in store
mock_client_with_models.models.list.assert_called_once()
mock_model_store.has_model.assert_called_once_with("some-mock-model-id")
mock_model_store.has_model.assert_called_once_with("test-provider/some-mock-model-id")
class TestOpenAIMixinCacheBehavior:
@ -271,7 +313,8 @@ class TestOpenAIMixinImagePreprocessing:
with patch("llama_stack.providers.utils.inference.openai_mixin.localize_image_content") as mock_localize:
mock_localize.return_value = (b"fake_image_data", "jpeg")
await mixin.openai_chat_completion(model="test-model", messages=[message])
params = OpenAIChatCompletionRequestWithExtraBody(model="test-model", messages=[message])
await mixin.openai_chat_completion(params)
mock_localize.assert_called_once_with("http://example.com/image.jpg")
@ -303,7 +346,8 @@ class TestOpenAIMixinImagePreprocessing:
with patch.object(type(mixin), "client", new_callable=PropertyMock, return_value=mock_client):
with patch("llama_stack.providers.utils.inference.openai_mixin.localize_image_content") as mock_localize:
await mixin.openai_chat_completion(model="test-model", messages=[message])
params = OpenAIChatCompletionRequestWithExtraBody(model="test-model", messages=[message])
await mixin.openai_chat_completion(params)
mock_localize.assert_not_called()
@ -340,21 +384,71 @@ class TestOpenAIMixinEmbeddingModelMetadata:
assert result is not None
assert len(result) == 2
# Find the models in the result
embedding_model = next(m for m in result if m.identifier == "text-embedding-3-small")
llm_model = next(m for m in result if m.identifier == "gpt-4")
expected_models = {
"text-embedding-3-small": {
"model_type": ModelType.embedding,
"metadata": {"embedding_dimension": 1536, "context_length": 8192},
"provider_id": "test-provider",
"provider_resource_id": "text-embedding-3-small",
},
"gpt-4": {
"model_type": ModelType.llm,
"metadata": {},
"provider_id": "test-provider",
"provider_resource_id": "gpt-4",
},
}
# Check embedding model
assert embedding_model.model_type == ModelType.embedding
assert embedding_model.metadata == {"embedding_dimension": 1536, "context_length": 8192}
assert embedding_model.provider_id == "test-provider"
assert embedding_model.provider_resource_id == "text-embedding-3-small"
_assert_models_match_expected(result, expected_models)
# Check LLM model
assert llm_model.model_type == ModelType.llm
assert llm_model.metadata == {} # No metadata for LLMs
assert llm_model.provider_id == "test-provider"
assert llm_model.provider_resource_id == "gpt-4"
class TestOpenAIMixinCustomModelConstruction:
"""Test cases for mixed model types (LLM, embedding, rerank) through construct_model_from_identifier"""
async def test_mixed_model_types_identification(self, mixin_with_custom_model_construction, mock_client_context):
"""Test that LLM, embedding, and rerank models are correctly identified with proper types and metadata"""
# Create mock models: 1 embedding, 1 rerank, 1 LLM
mock_embedding_model = MagicMock(id="text-embedding-3-small")
mock_rerank_model = MagicMock(id="rerank-model-1")
mock_llm_model = MagicMock(id="gpt-4")
mock_models = [mock_embedding_model, mock_rerank_model, mock_llm_model]
mock_client = MagicMock()
async def mock_models_list():
for model in mock_models:
yield model
mock_client.models.list.return_value = mock_models_list()
with mock_client_context(mixin_with_custom_model_construction, mock_client):
result = await mixin_with_custom_model_construction.list_models()
assert result is not None
assert len(result) == 3
expected_models = {
"text-embedding-3-small": {
"model_type": ModelType.embedding,
"metadata": {"embedding_dimension": 1536, "context_length": 8192},
"provider_id": "test-provider",
"provider_resource_id": "text-embedding-3-small",
},
"rerank-model-1": {
"model_type": ModelType.rerank,
"metadata": {},
"provider_id": "test-provider",
"provider_resource_id": "rerank-model-1",
},
"gpt-4": {
"model_type": ModelType.llm,
"metadata": {},
"provider_id": "test-provider",
"provider_resource_id": "gpt-4",
},
}
_assert_models_match_expected(result, expected_models)
class TestOpenAIMixinAllowedModels:

View file

@ -10,17 +10,18 @@ from unittest.mock import AsyncMock, MagicMock, patch
import numpy as np
import pytest
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.inline.vector_io.faiss.faiss import FaissIndex, FaissVectorIOAdapter
from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter
from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig
from llama_stack.providers.remote.vector_io.pgvector.pgvector import PGVectorIndex, PGVectorVectorIOAdapter
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore import register_kvstore_backends
EMBEDDING_DIMENSION = 384
EMBEDDING_DIMENSION = 768
COLLECTION_PREFIX = "test_collection"
@ -30,7 +31,7 @@ def vector_provider(request):
@pytest.fixture
def vector_db_id() -> str:
def vector_store_id() -> str:
return f"test-vector-db-{random.randint(1, 100)}"
@ -112,8 +113,9 @@ async def unique_kvstore_config(tmp_path_factory):
unique_id = f"test_kv_{np.random.randint(1e6)}"
temp_dir = tmp_path_factory.getbasetemp()
db_path = str(temp_dir / f"{unique_id}.db")
return SqliteKVStoreConfig(db_path=db_path)
backend_name = f"kv_vector_{unique_id}"
register_kvstore_backends({backend_name: SqliteKVStoreConfig(db_path=db_path)})
return KVStoreReference(backend=backend_name, namespace=f"vector_io::{unique_id}")
@pytest.fixture(scope="session")
@ -138,7 +140,7 @@ async def sqlite_vec_vec_index(embedding_dimension, tmp_path_factory):
async def sqlite_vec_adapter(sqlite_vec_db_path, unique_kvstore_config, mock_inference_api, embedding_dimension):
config = SQLiteVectorIOConfig(
db_path=sqlite_vec_db_path,
kvstore=unique_kvstore_config,
persistence=unique_kvstore_config,
)
adapter = SQLiteVecVectorIOAdapter(
config=config,
@ -147,8 +149,8 @@ async def sqlite_vec_adapter(sqlite_vec_db_path, unique_kvstore_config, mock_inf
)
collection_id = f"sqlite_test_collection_{np.random.randint(1e6)}"
await adapter.initialize()
await adapter.register_vector_db(
VectorDB(
await adapter.register_vector_store(
VectorStore(
identifier=collection_id,
provider_id="test_provider",
embedding_model="test_model",
@ -176,7 +178,7 @@ async def faiss_vec_index(embedding_dimension):
@pytest.fixture
async def faiss_vec_adapter(unique_kvstore_config, mock_inference_api, embedding_dimension):
config = FaissVectorIOConfig(
kvstore=unique_kvstore_config,
persistence=unique_kvstore_config,
)
adapter = FaissVectorIOAdapter(
config=config,
@ -184,8 +186,8 @@ async def faiss_vec_adapter(unique_kvstore_config, mock_inference_api, embedding
files_api=None,
)
await adapter.initialize()
await adapter.register_vector_db(
VectorDB(
await adapter.register_vector_store(
VectorStore(
identifier=f"faiss_test_collection_{np.random.randint(1e6)}",
provider_id="test_provider",
embedding_model="test_model",
@ -213,7 +215,7 @@ def mock_psycopg2_connection():
async def pgvector_vec_index(embedding_dimension, mock_psycopg2_connection):
connection, cursor = mock_psycopg2_connection
vector_db = VectorDB(
vector_store = VectorStore(
identifier="test-vector-db",
embedding_model="test-model",
embedding_dimension=embedding_dimension,
@ -223,7 +225,7 @@ async def pgvector_vec_index(embedding_dimension, mock_psycopg2_connection):
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"):
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.execute_values"):
index = PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="COSINE")
index = PGVectorIndex(vector_store, embedding_dimension, connection, distance_metric="COSINE")
index._test_chunks = []
original_add_chunks = index.add_chunks
@ -251,7 +253,7 @@ async def pgvector_vec_adapter(unique_kvstore_config, mock_inference_api, embedd
db="test_db",
user="test_user",
password="test_password",
kvstore=unique_kvstore_config,
persistence=unique_kvstore_config,
)
adapter = PGVectorVectorIOAdapter(config, mock_inference_api, None)
@ -279,30 +281,30 @@ async def pgvector_vec_adapter(unique_kvstore_config, mock_inference_api, embedd
await adapter.initialize()
adapter.conn = mock_conn
async def mock_insert_chunks(vector_db_id, chunks, ttl_seconds=None):
index = await adapter._get_and_cache_vector_db_index(vector_db_id)
async def mock_insert_chunks(vector_store_id, chunks, ttl_seconds=None):
index = await adapter._get_and_cache_vector_store_index(vector_store_id)
if not index:
raise ValueError(f"Vector DB {vector_db_id} not found")
raise ValueError(f"Vector DB {vector_store_id} not found")
await index.insert_chunks(chunks)
adapter.insert_chunks = mock_insert_chunks
async def mock_query_chunks(vector_db_id, query, params=None):
index = await adapter._get_and_cache_vector_db_index(vector_db_id)
async def mock_query_chunks(vector_store_id, query, params=None):
index = await adapter._get_and_cache_vector_store_index(vector_store_id)
if not index:
raise ValueError(f"Vector DB {vector_db_id} not found")
raise ValueError(f"Vector DB {vector_store_id} not found")
return await index.query_chunks(query, params)
adapter.query_chunks = mock_query_chunks
test_vector_db = VectorDB(
test_vector_store = VectorStore(
identifier=f"pgvector_test_collection_{random.randint(1, 1_000_000)}",
provider_id="test_provider",
embedding_model="test_model",
embedding_dimension=embedding_dimension,
)
await adapter.register_vector_db(test_vector_db)
adapter.test_collection_id = test_vector_db.identifier
await adapter.register_vector_store(test_vector_store)
adapter.test_collection_id = test_vector_store.identifier
yield adapter
await adapter.shutdown()

View file

@ -11,8 +11,8 @@ import numpy as np
import pytest
from llama_stack.apis.files import Files
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.providers.datatypes import HealthStatus
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.inline.vector_io.faiss.faiss import (
@ -39,12 +39,12 @@ def loop():
@pytest.fixture
def embedding_dimension():
return 384
return 768
@pytest.fixture
def vector_db_id():
return "test_vector_db"
def vector_store_id():
return "test_vector_store"
@pytest.fixture
@ -61,12 +61,12 @@ def sample_embeddings(embedding_dimension):
@pytest.fixture
def mock_vector_db(vector_db_id, embedding_dimension) -> MagicMock:
mock_vector_db = MagicMock(spec=VectorDB)
mock_vector_db.embedding_model = "mock_embedding_model"
mock_vector_db.identifier = vector_db_id
mock_vector_db.embedding_dimension = embedding_dimension
return mock_vector_db
def mock_vector_store(vector_store_id, embedding_dimension) -> MagicMock:
mock_vector_store = MagicMock(spec=VectorStore)
mock_vector_store.embedding_model = "mock_embedding_model"
mock_vector_store.identifier = vector_store_id
mock_vector_store.embedding_dimension = embedding_dimension
return mock_vector_store
@pytest.fixture

View file

@ -12,13 +12,15 @@ import numpy as np
import pytest
from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
Chunk,
OpenAICreateVectorStoreFileBatchRequestWithExtraBody,
OpenAICreateVectorStoreRequestWithExtraBody,
QueryChunksResponse,
VectorStoreChunkingStrategyAuto,
VectorStoreFileObject,
)
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import VECTOR_DBS_PREFIX
# This test is a unit test for the inline VectorIO providers. This should only contain
@ -69,7 +71,7 @@ async def test_chunk_id_conflict(vector_index, sample_chunks, embedding_dimensio
async def test_initialize_adapter_with_existing_kvstore(vector_io_adapter):
key = f"{VECTOR_DBS_PREFIX}db1"
dummy = VectorDB(
dummy = VectorStore(
identifier="foo_db", provider_id="test_provider", embedding_model="test_model", embedding_dimension=128
)
await vector_io_adapter.kvstore.set(key=key, value=json.dumps(dummy.model_dump()))
@ -79,10 +81,10 @@ async def test_initialize_adapter_with_existing_kvstore(vector_io_adapter):
async def test_persistence_across_adapter_restarts(vector_io_adapter):
await vector_io_adapter.initialize()
dummy = VectorDB(
dummy = VectorStore(
identifier="foo_db", provider_id="test_provider", embedding_model="test_model", embedding_dimension=128
)
await vector_io_adapter.register_vector_db(dummy)
await vector_io_adapter.register_vector_store(dummy)
await vector_io_adapter.shutdown()
await vector_io_adapter.initialize()
@ -90,15 +92,15 @@ async def test_persistence_across_adapter_restarts(vector_io_adapter):
await vector_io_adapter.shutdown()
async def test_register_and_unregister_vector_db(vector_io_adapter):
async def test_register_and_unregister_vector_store(vector_io_adapter):
unique_id = f"foo_db_{np.random.randint(1e6)}"
dummy = VectorDB(
dummy = VectorStore(
identifier=unique_id, provider_id="test_provider", embedding_model="test_model", embedding_dimension=128
)
await vector_io_adapter.register_vector_db(dummy)
await vector_io_adapter.register_vector_store(dummy)
assert dummy.identifier in vector_io_adapter.cache
await vector_io_adapter.unregister_vector_db(dummy.identifier)
await vector_io_adapter.unregister_vector_store(dummy.identifier)
assert dummy.identifier not in vector_io_adapter.cache
@ -119,12 +121,43 @@ async def test_insert_chunks_calls_underlying_index(vector_io_adapter):
async def test_insert_chunks_missing_db_raises(vector_io_adapter):
vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=None)
vector_io_adapter._get_and_cache_vector_store_index = AsyncMock(return_value=None)
with pytest.raises(ValueError):
await vector_io_adapter.insert_chunks("db_not_exist", [])
async def test_insert_chunks_with_missing_document_id(vector_io_adapter):
"""Ensure no KeyError when document_id is missing or in different places."""
from llama_stack.apis.vector_io import Chunk, ChunkMetadata
fake_index = AsyncMock()
vector_io_adapter.cache["db1"] = fake_index
# Various document_id scenarios that shouldn't crash
chunks = [
Chunk(content="has doc_id in metadata", metadata={"document_id": "doc-1"}),
Chunk(content="no doc_id anywhere", metadata={"source": "test"}),
Chunk(content="doc_id in chunk_metadata", chunk_metadata=ChunkMetadata(document_id="doc-3")),
]
# Should work without KeyError
await vector_io_adapter.insert_chunks("db1", chunks)
fake_index.insert_chunks.assert_awaited_once()
async def test_document_id_with_invalid_type_raises_error():
"""Ensure TypeError is raised when document_id is not a string."""
from llama_stack.apis.vector_io import Chunk
# Integer document_id should raise TypeError
chunk = Chunk(content="test", metadata={"document_id": 12345})
with pytest.raises(TypeError) as exc_info:
_ = chunk.document_id
assert "metadata['document_id'] must be a string" in str(exc_info.value)
assert "got int" in str(exc_info.value)
async def test_query_chunks_calls_underlying_index_and_returns(vector_io_adapter):
expected = QueryChunksResponse(chunks=[Chunk(content="c1")], scores=[0.1])
fake_index = AsyncMock(query_chunks=AsyncMock(return_value=expected))
@ -137,7 +170,7 @@ async def test_query_chunks_calls_underlying_index_and_returns(vector_io_adapter
async def test_query_chunks_missing_db_raises(vector_io_adapter):
vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=None)
vector_io_adapter._get_and_cache_vector_store_index = AsyncMock(return_value=None)
with pytest.raises(ValueError):
await vector_io_adapter.query_chunks("db_missing", "q", None)
@ -149,7 +182,7 @@ async def test_save_openai_vector_store(vector_io_adapter):
"id": store_id,
"name": "Test Store",
"description": "A test OpenAI vector store",
"vector_db_id": "test_db",
"vector_store_id": "test_db",
"embedding_model": "test_model",
}
@ -165,7 +198,7 @@ async def test_update_openai_vector_store(vector_io_adapter):
"id": store_id,
"name": "Test Store",
"description": "A test OpenAI vector store",
"vector_db_id": "test_db",
"vector_store_id": "test_db",
"embedding_model": "test_model",
}
@ -181,7 +214,7 @@ async def test_delete_openai_vector_store(vector_io_adapter):
"id": store_id,
"name": "Test Store",
"description": "A test OpenAI vector store",
"vector_db_id": "test_db",
"vector_store_id": "test_db",
"embedding_model": "test_model",
}
@ -196,7 +229,7 @@ async def test_load_openai_vector_stores(vector_io_adapter):
"id": store_id,
"name": "Test Store",
"description": "A test OpenAI vector store",
"vector_db_id": "test_db",
"vector_store_id": "test_db",
"embedding_model": "test_model",
}
@ -326,8 +359,7 @@ async def test_create_vector_store_file_batch(vector_io_adapter):
vector_io_adapter._process_file_batch_async = AsyncMock()
batch = await vector_io_adapter.openai_create_vector_store_file_batch(
vector_store_id=store_id,
file_ids=file_ids,
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=file_ids)
)
assert batch.vector_store_id == store_id
@ -354,8 +386,7 @@ async def test_retrieve_vector_store_file_batch(vector_io_adapter):
# Create batch first
created_batch = await vector_io_adapter.openai_create_vector_store_file_batch(
vector_store_id=store_id,
file_ids=file_ids,
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=file_ids)
)
# Retrieve batch
@ -388,8 +419,7 @@ async def test_cancel_vector_store_file_batch(vector_io_adapter):
# Create batch
batch = await vector_io_adapter.openai_create_vector_store_file_batch(
vector_store_id=store_id,
file_ids=file_ids,
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=file_ids)
)
# Cancel batch
@ -434,8 +464,7 @@ async def test_list_files_in_vector_store_file_batch(vector_io_adapter):
# Create batch
batch = await vector_io_adapter.openai_create_vector_store_file_batch(
vector_store_id=store_id,
file_ids=file_ids,
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=file_ids)
)
# List files
@ -455,7 +484,7 @@ async def test_file_batch_validation_errors(vector_io_adapter):
with pytest.raises(VectorStoreNotFoundError):
await vector_io_adapter.openai_create_vector_store_file_batch(
vector_store_id="nonexistent",
file_ids=["file_1"],
params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=["file_1"]),
)
# Setup store for remaining tests
@ -472,8 +501,7 @@ async def test_file_batch_validation_errors(vector_io_adapter):
# Test wrong vector store for batch
vector_io_adapter.openai_attach_file_to_vector_store = AsyncMock()
batch = await vector_io_adapter.openai_create_vector_store_file_batch(
vector_store_id=store_id,
file_ids=["file_1"],
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=["file_1"])
)
# Create wrong_store so it exists but the batch doesn't belong to it
@ -520,8 +548,7 @@ async def test_file_batch_pagination(vector_io_adapter):
# Create batch
batch = await vector_io_adapter.openai_create_vector_store_file_batch(
vector_store_id=store_id,
file_ids=file_ids,
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=file_ids)
)
# Test pagination with limit
@ -593,8 +620,7 @@ async def test_file_batch_status_filtering(vector_io_adapter):
# Create batch
batch = await vector_io_adapter.openai_create_vector_store_file_batch(
vector_store_id=store_id,
file_ids=file_ids,
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=file_ids)
)
# Test filtering by completed status
@ -636,8 +662,7 @@ async def test_cancel_completed_batch_fails(vector_io_adapter):
# Create batch
batch = await vector_io_adapter.openai_create_vector_store_file_batch(
vector_store_id=store_id,
file_ids=file_ids,
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=file_ids)
)
# Manually update status to completed
@ -671,8 +696,7 @@ async def test_file_batch_persistence_across_restarts(vector_io_adapter):
# Create batch
batch = await vector_io_adapter.openai_create_vector_store_file_batch(
vector_store_id=store_id,
file_ids=file_ids,
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=file_ids)
)
batch_id = batch.id
@ -727,8 +751,7 @@ async def test_cancelled_batch_persists_in_storage(vector_io_adapter):
# Create batch
batch = await vector_io_adapter.openai_create_vector_store_file_batch(
vector_store_id=store_id,
file_ids=file_ids,
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=file_ids)
)
batch_id = batch.id
@ -775,10 +798,10 @@ async def test_only_in_progress_batches_resumed(vector_io_adapter):
# Create multiple batches
batch1 = await vector_io_adapter.openai_create_vector_store_file_batch(
vector_store_id=store_id, file_ids=["file_1"]
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=["file_1"])
)
batch2 = await vector_io_adapter.openai_create_vector_store_file_batch(
vector_store_id=store_id, file_ids=["file_2"]
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=["file_2"])
)
# Complete one batch (should persist with completed status)
@ -791,7 +814,7 @@ async def test_only_in_progress_batches_resumed(vector_io_adapter):
# Create a third batch that stays in progress
batch3 = await vector_io_adapter.openai_create_vector_store_file_batch(
vector_store_id=store_id, file_ids=["file_3"]
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=["file_3"])
)
# Simulate restart - clear memory and reload from persistence
@ -952,8 +975,7 @@ async def test_max_concurrent_files_per_batch(vector_io_adapter):
file_ids = [f"file_{i}" for i in range(8)] # 8 files, but limit should be 5
batch = await vector_io_adapter.openai_create_vector_store_file_batch(
vector_store_id=store_id,
file_ids=file_ids,
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=file_ids)
)
# Give time for the semaphore logic to start processing files
@ -971,3 +993,130 @@ async def test_max_concurrent_files_per_batch(vector_io_adapter):
assert batch.status == "in_progress"
assert batch.file_counts.total == 8
assert batch.file_counts.in_progress == 8
async def test_embedding_config_from_metadata(vector_io_adapter):
"""Test that embedding configuration is correctly extracted from metadata."""
# Mock register_vector_store to avoid actual registration
vector_io_adapter.register_vector_store = AsyncMock()
# Set provider_id attribute for the adapter
vector_io_adapter.__provider_id__ = "test_provider"
# Test with embedding config in metadata
params = OpenAICreateVectorStoreRequestWithExtraBody(
name="test_store",
metadata={
"embedding_model": "test-embedding-model",
"embedding_dimension": "512",
},
model_extra={},
)
await vector_io_adapter.openai_create_vector_store(params)
# Verify VectorStore was registered with correct embedding config from metadata
vector_io_adapter.register_vector_store.assert_called_once()
call_args = vector_io_adapter.register_vector_store.call_args[0][0]
assert call_args.embedding_model == "test-embedding-model"
assert call_args.embedding_dimension == 512
async def test_embedding_config_from_extra_body(vector_io_adapter):
"""Test that embedding configuration is correctly extracted from extra_body when metadata is empty."""
# Mock register_vector_store to avoid actual registration
vector_io_adapter.register_vector_store = AsyncMock()
# Set provider_id attribute for the adapter
vector_io_adapter.__provider_id__ = "test_provider"
# Test with embedding config in extra_body only (metadata has no embedding_model)
params = OpenAICreateVectorStoreRequestWithExtraBody(
name="test_store",
metadata={}, # Empty metadata to ensure extra_body is used
**{
"embedding_model": "extra-body-model",
"embedding_dimension": 1024,
},
)
await vector_io_adapter.openai_create_vector_store(params)
# Verify VectorStore was registered with correct embedding config from extra_body
vector_io_adapter.register_vector_store.assert_called_once()
call_args = vector_io_adapter.register_vector_store.call_args[0][0]
assert call_args.embedding_model == "extra-body-model"
assert call_args.embedding_dimension == 1024
async def test_embedding_config_consistency_check_passes(vector_io_adapter):
"""Test that consistent embedding config in both metadata and extra_body passes validation."""
# Mock register_vector_store to avoid actual registration
vector_io_adapter.register_vector_store = AsyncMock()
# Set provider_id attribute for the adapter
vector_io_adapter.__provider_id__ = "test_provider"
# Test with consistent embedding config in both metadata and extra_body
params = OpenAICreateVectorStoreRequestWithExtraBody(
name="test_store",
metadata={
"embedding_model": "consistent-model",
"embedding_dimension": "768",
},
**{
"embedding_model": "consistent-model",
"embedding_dimension": 768,
},
)
await vector_io_adapter.openai_create_vector_store(params)
# Should not raise any error and use metadata config
vector_io_adapter.register_vector_store.assert_called_once()
call_args = vector_io_adapter.register_vector_store.call_args[0][0]
assert call_args.embedding_model == "consistent-model"
assert call_args.embedding_dimension == 768
async def test_embedding_config_defaults_when_missing(vector_io_adapter):
"""Test that embedding dimension defaults to 768 when not provided."""
# Mock register_vector_store to avoid actual registration
vector_io_adapter.register_vector_store = AsyncMock()
# Set provider_id attribute for the adapter
vector_io_adapter.__provider_id__ = "test_provider"
# Test with only embedding model, no dimension (metadata empty to use extra_body)
params = OpenAICreateVectorStoreRequestWithExtraBody(
name="test_store",
metadata={}, # Empty metadata to ensure extra_body is used
**{
"embedding_model": "model-without-dimension",
},
)
await vector_io_adapter.openai_create_vector_store(params)
# Should default to 768 dimensions
vector_io_adapter.register_vector_store.assert_called_once()
call_args = vector_io_adapter.register_vector_store.call_args[0][0]
assert call_args.embedding_model == "model-without-dimension"
assert call_args.embedding_dimension == 768
async def test_embedding_config_required_model_missing(vector_io_adapter):
"""Test that missing embedding model raises error."""
# Mock register_vector_store to avoid actual registration
vector_io_adapter.register_vector_store = AsyncMock()
# Set provider_id attribute for the adapter
vector_io_adapter.__provider_id__ = "test_provider"
# Mock the default model lookup to return None (no default model available)
vector_io_adapter._get_default_embedding_model_and_dimension = AsyncMock(return_value=None)
# Test with no embedding model provided
params = OpenAICreateVectorStoreRequestWithExtraBody(name="test_store", metadata={})
with pytest.raises(ValueError, match="embedding_model is required"):
await vector_io_adapter.openai_create_vector_store(params)

View file

@ -18,19 +18,19 @@ from llama_stack.providers.inline.tool_runtime.rag.memory import MemoryToolRunti
class TestRagQuery:
async def test_query_raises_on_empty_vector_db_ids(self):
async def test_query_raises_on_empty_vector_store_ids(self):
rag_tool = MemoryToolRuntimeImpl(
config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock(), files_api=MagicMock()
)
with pytest.raises(ValueError):
await rag_tool.query(content=MagicMock(), vector_db_ids=[])
await rag_tool.query(content=MagicMock(), vector_store_ids=[])
async def test_query_chunk_metadata_handling(self):
rag_tool = MemoryToolRuntimeImpl(
config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock(), files_api=MagicMock()
)
content = "test query content"
vector_db_ids = ["db1"]
vector_store_ids = ["db1"]
chunk_metadata = ChunkMetadata(
document_id="doc1",
@ -55,7 +55,7 @@ class TestRagQuery:
query_response = QueryChunksResponse(chunks=[chunk], scores=[1.0])
rag_tool.vector_io_api.query_chunks = AsyncMock(return_value=query_response)
result = await rag_tool.query(content=content, vector_db_ids=vector_db_ids)
result = await rag_tool.query(content=content, vector_store_ids=vector_store_ids)
assert result is not None
expected_metadata_string = (
@ -82,7 +82,7 @@ class TestRagQuery:
with pytest.raises(ValueError):
RAGQueryConfig(mode="wrong_mode")
async def test_query_adds_vector_db_id_to_chunk_metadata(self):
async def test_query_adds_vector_store_id_to_chunk_metadata(self):
rag_tool = MemoryToolRuntimeImpl(
config=MagicMock(),
vector_io_api=MagicMock(),
@ -90,7 +90,7 @@ class TestRagQuery:
files_api=MagicMock(),
)
vector_db_ids = ["db1", "db2"]
vector_store_ids = ["db1", "db2"]
# Fake chunks from each DB
chunk_metadata1 = ChunkMetadata(
@ -101,7 +101,7 @@ class TestRagQuery:
)
chunk1 = Chunk(
content="chunk from db1",
metadata={"vector_db_id": "db1", "document_id": "doc1"},
metadata={"vector_store_id": "db1", "document_id": "doc1"},
stored_chunk_id="c1",
chunk_metadata=chunk_metadata1,
)
@ -114,7 +114,7 @@ class TestRagQuery:
)
chunk2 = Chunk(
content="chunk from db2",
metadata={"vector_db_id": "db2", "document_id": "doc2"},
metadata={"vector_store_id": "db2", "document_id": "doc2"},
stored_chunk_id="c2",
chunk_metadata=chunk_metadata2,
)
@ -126,13 +126,13 @@ class TestRagQuery:
]
)
result = await rag_tool.query(content="test", vector_db_ids=vector_db_ids)
result = await rag_tool.query(content="test", vector_store_ids=vector_store_ids)
returned_chunks = result.metadata["chunks"]
returned_scores = result.metadata["scores"]
returned_doc_ids = result.metadata["document_ids"]
returned_vector_db_ids = result.metadata["vector_db_ids"]
returned_vector_store_ids = result.metadata["vector_store_ids"]
assert returned_chunks == ["chunk from db1", "chunk from db2"]
assert returned_scores == (0.9, 0.8)
assert returned_doc_ids == ["doc1", "doc2"]
assert returned_vector_db_ids == ["db1", "db2"]
assert returned_vector_store_ids == ["db1", "db2"]

View file

@ -13,12 +13,15 @@ from unittest.mock import AsyncMock, MagicMock
import numpy as np
import pytest
from llama_stack.apis.inference.inference import OpenAIEmbeddingData
from llama_stack.apis.inference.inference import (
OpenAIEmbeddingData,
OpenAIEmbeddingsRequestWithExtraBody,
)
from llama_stack.apis.tools import RAGDocument
from llama_stack.apis.vector_io import Chunk
from llama_stack.providers.utils.memory.vector_store import (
URL,
VectorDBWithIndex,
VectorStoreWithIndex,
_validate_embedding,
content_from_doc,
make_overlapped_chunks,
@ -203,15 +206,15 @@ class TestVectorStore:
assert str(excinfo.value.__cause__) == "Cannot convert to string"
class TestVectorDBWithIndex:
class TestVectorStoreWithIndex:
async def test_insert_chunks_without_embeddings(self):
mock_vector_db = MagicMock()
mock_vector_db.embedding_model = "test-model without embeddings"
mock_vector_store = MagicMock()
mock_vector_store.embedding_model = "test-model without embeddings"
mock_index = AsyncMock()
mock_inference_api = AsyncMock()
vector_db_with_index = VectorDBWithIndex(
vector_db=mock_vector_db, index=mock_index, inference_api=mock_inference_api
vector_store_with_index = VectorStoreWithIndex(
vector_store=mock_vector_store, index=mock_index, inference_api=mock_inference_api
)
chunks = [
@ -224,25 +227,30 @@ class TestVectorDBWithIndex:
OpenAIEmbeddingData(embedding=[0.4, 0.5, 0.6], index=1),
]
await vector_db_with_index.insert_chunks(chunks)
await vector_store_with_index.insert_chunks(chunks)
mock_inference_api.openai_embeddings.assert_called_once_with(
"test-model without embeddings", ["Test 1", "Test 2"]
)
# Verify openai_embeddings was called with correct params
mock_inference_api.openai_embeddings.assert_called_once()
call_args = mock_inference_api.openai_embeddings.call_args[0]
assert len(call_args) == 1
params = call_args[0]
assert isinstance(params, OpenAIEmbeddingsRequestWithExtraBody)
assert params.model == "test-model without embeddings"
assert params.input == ["Test 1", "Test 2"]
mock_index.add_chunks.assert_called_once()
args = mock_index.add_chunks.call_args[0]
assert args[0] == chunks
assert np.array_equal(args[1], np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32))
async def test_insert_chunks_with_valid_embeddings(self):
mock_vector_db = MagicMock()
mock_vector_db.embedding_model = "test-model with embeddings"
mock_vector_db.embedding_dimension = 3
mock_vector_store = MagicMock()
mock_vector_store.embedding_model = "test-model with embeddings"
mock_vector_store.embedding_dimension = 3
mock_index = AsyncMock()
mock_inference_api = AsyncMock()
vector_db_with_index = VectorDBWithIndex(
vector_db=mock_vector_db, index=mock_index, inference_api=mock_inference_api
vector_store_with_index = VectorStoreWithIndex(
vector_store=mock_vector_store, index=mock_index, inference_api=mock_inference_api
)
chunks = [
@ -250,7 +258,7 @@ class TestVectorDBWithIndex:
Chunk(content="Test 2", embedding=[0.4, 0.5, 0.6], metadata={}),
]
await vector_db_with_index.insert_chunks(chunks)
await vector_store_with_index.insert_chunks(chunks)
mock_inference_api.openai_embeddings.assert_not_called()
mock_index.add_chunks.assert_called_once()
@ -259,14 +267,14 @@ class TestVectorDBWithIndex:
assert np.array_equal(args[1], np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32))
async def test_insert_chunks_with_invalid_embeddings(self):
mock_vector_db = MagicMock()
mock_vector_db.embedding_dimension = 3
mock_vector_db.embedding_model = "test-model with invalid embeddings"
mock_vector_store = MagicMock()
mock_vector_store.embedding_dimension = 3
mock_vector_store.embedding_model = "test-model with invalid embeddings"
mock_index = AsyncMock()
mock_inference_api = AsyncMock()
vector_db_with_index = VectorDBWithIndex(
vector_db=mock_vector_db, index=mock_index, inference_api=mock_inference_api
vector_store_with_index = VectorStoreWithIndex(
vector_store=mock_vector_store, index=mock_index, inference_api=mock_inference_api
)
# Verify Chunk raises ValueError for invalid embedding type
@ -275,7 +283,7 @@ class TestVectorDBWithIndex:
# Verify Chunk raises ValueError for invalid embedding type in insert_chunks (i.e., Chunk errors before insert_chunks is called)
with pytest.raises(ValueError, match="Input should be a valid list"):
await vector_db_with_index.insert_chunks(
await vector_store_with_index.insert_chunks(
[
Chunk(content="Test 1", embedding=None, metadata={}),
Chunk(content="Test 2", embedding="invalid_type", metadata={}),
@ -284,7 +292,7 @@ class TestVectorDBWithIndex:
# Verify Chunk raises ValueError for invalid embedding element type in insert_chunks (i.e., Chunk errors before insert_chunks is called)
with pytest.raises(ValueError, match=" Input should be a valid number, unable to parse string as a number "):
await vector_db_with_index.insert_chunks(
await vector_store_with_index.insert_chunks(
Chunk(content="Test 1", embedding=[0.1, "string", 0.3], metadata={})
)
@ -292,20 +300,20 @@ class TestVectorDBWithIndex:
Chunk(content="Test 1", embedding=[0.1, 0.2, 0.3, 0.4], metadata={}),
]
with pytest.raises(ValueError, match="has dimension 4, expected 3"):
await vector_db_with_index.insert_chunks(chunks_wrong_dim)
await vector_store_with_index.insert_chunks(chunks_wrong_dim)
mock_inference_api.openai_embeddings.assert_not_called()
mock_index.add_chunks.assert_not_called()
async def test_insert_chunks_with_partially_precomputed_embeddings(self):
mock_vector_db = MagicMock()
mock_vector_db.embedding_model = "test-model with partial embeddings"
mock_vector_db.embedding_dimension = 3
mock_vector_store = MagicMock()
mock_vector_store.embedding_model = "test-model with partial embeddings"
mock_vector_store.embedding_dimension = 3
mock_index = AsyncMock()
mock_inference_api = AsyncMock()
vector_db_with_index = VectorDBWithIndex(
vector_db=mock_vector_db, index=mock_index, inference_api=mock_inference_api
vector_store_with_index = VectorStoreWithIndex(
vector_store=mock_vector_store, index=mock_index, inference_api=mock_inference_api
)
chunks = [
@ -319,11 +327,16 @@ class TestVectorDBWithIndex:
OpenAIEmbeddingData(embedding=[0.3, 0.3, 0.3], index=1),
]
await vector_db_with_index.insert_chunks(chunks)
await vector_store_with_index.insert_chunks(chunks)
mock_inference_api.openai_embeddings.assert_called_once_with(
"test-model with partial embeddings", ["Test 1", "Test 3"]
)
# Verify openai_embeddings was called with correct params
mock_inference_api.openai_embeddings.assert_called_once()
call_args = mock_inference_api.openai_embeddings.call_args[0]
assert len(call_args) == 1
params = call_args[0]
assert isinstance(params, OpenAIEmbeddingsRequestWithExtraBody)
assert params.model == "test-model with partial embeddings"
assert params.input == ["Test 1", "Test 3"]
mock_index.add_chunks.assert_called_once()
args = mock_index.add_chunks.call_args[0]
assert len(args[0]) == 3

View file

@ -8,24 +8,24 @@
import pytest
from llama_stack.apis.inference import Model
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.core.datatypes import VectorDBWithOwner
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.core.datatypes import VectorStoreWithOwner
from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig
from llama_stack.core.store.registry import (
KEY_FORMAT,
CachedDiskDistributionRegistry,
DiskDistributionRegistry,
)
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore import kvstore_impl, register_kvstore_backends
@pytest.fixture
def sample_vector_db():
return VectorDB(
identifier="test_vector_db",
embedding_model="all-MiniLM-L6-v2",
embedding_dimension=384,
provider_resource_id="test_vector_db",
def sample_vector_store():
return VectorStore(
identifier="test_vector_store",
embedding_model="nomic-embed-text-v1.5",
embedding_dimension=768,
provider_resource_id="test_vector_store",
provider_id="test-provider",
)
@ -45,17 +45,17 @@ async def test_registry_initialization(disk_dist_registry):
assert result is None
async def test_basic_registration(disk_dist_registry, sample_vector_db, sample_model):
print(f"Registering {sample_vector_db}")
await disk_dist_registry.register(sample_vector_db)
async def test_basic_registration(disk_dist_registry, sample_vector_store, sample_model):
print(f"Registering {sample_vector_store}")
await disk_dist_registry.register(sample_vector_store)
print(f"Registering {sample_model}")
await disk_dist_registry.register(sample_model)
print("Getting vector_db")
result_vector_db = await disk_dist_registry.get("vector_db", "test_vector_db")
assert result_vector_db is not None
assert result_vector_db.identifier == sample_vector_db.identifier
assert result_vector_db.embedding_model == sample_vector_db.embedding_model
assert result_vector_db.provider_id == sample_vector_db.provider_id
print("Getting vector_store")
result_vector_store = await disk_dist_registry.get("vector_store", "test_vector_store")
assert result_vector_store is not None
assert result_vector_store.identifier == sample_vector_store.identifier
assert result_vector_store.embedding_model == sample_vector_store.embedding_model
assert result_vector_store.provider_id == sample_vector_store.provider_id
result_model = await disk_dist_registry.get("model", "test_model")
assert result_model is not None
@ -63,127 +63,137 @@ async def test_basic_registration(disk_dist_registry, sample_vector_db, sample_m
assert result_model.provider_id == sample_model.provider_id
async def test_cached_registry_initialization(sqlite_kvstore, sample_vector_db, sample_model):
async def test_cached_registry_initialization(sqlite_kvstore, sample_vector_store, sample_model):
# First populate the disk registry
disk_registry = DiskDistributionRegistry(sqlite_kvstore)
await disk_registry.initialize()
await disk_registry.register(sample_vector_db)
await disk_registry.register(sample_vector_store)
await disk_registry.register(sample_model)
# Test cached version loads from disk
db_path = sqlite_kvstore.db_path
cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(SqliteKVStoreConfig(db_path=db_path)))
backend_name = "kv_cached_test"
register_kvstore_backends({backend_name: SqliteKVStoreConfig(db_path=db_path)})
cached_registry = CachedDiskDistributionRegistry(
await kvstore_impl(KVStoreReference(backend=backend_name, namespace="registry"))
)
await cached_registry.initialize()
result_vector_db = await cached_registry.get("vector_db", "test_vector_db")
assert result_vector_db is not None
assert result_vector_db.identifier == sample_vector_db.identifier
assert result_vector_db.embedding_model == sample_vector_db.embedding_model
assert result_vector_db.embedding_dimension == sample_vector_db.embedding_dimension
assert result_vector_db.provider_id == sample_vector_db.provider_id
result_vector_store = await cached_registry.get("vector_store", "test_vector_store")
assert result_vector_store is not None
assert result_vector_store.identifier == sample_vector_store.identifier
assert result_vector_store.embedding_model == sample_vector_store.embedding_model
assert result_vector_store.embedding_dimension == sample_vector_store.embedding_dimension
assert result_vector_store.provider_id == sample_vector_store.provider_id
async def test_cached_registry_updates(cached_disk_dist_registry):
new_vector_db = VectorDB(
identifier="test_vector_db_2",
embedding_model="all-MiniLM-L6-v2",
embedding_dimension=384,
provider_resource_id="test_vector_db_2",
new_vector_store = VectorStore(
identifier="test_vector_store_2",
embedding_model="nomic-embed-text-v1.5",
embedding_dimension=768,
provider_resource_id="test_vector_store_2",
provider_id="baz",
)
await cached_disk_dist_registry.register(new_vector_db)
await cached_disk_dist_registry.register(new_vector_store)
# Verify in cache
result_vector_db = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2")
assert result_vector_db is not None
assert result_vector_db.identifier == new_vector_db.identifier
assert result_vector_db.provider_id == new_vector_db.provider_id
result_vector_store = await cached_disk_dist_registry.get("vector_store", "test_vector_store_2")
assert result_vector_store is not None
assert result_vector_store.identifier == new_vector_store.identifier
assert result_vector_store.provider_id == new_vector_store.provider_id
# Verify persisted to disk
db_path = cached_disk_dist_registry.kvstore.db_path
new_registry = DiskDistributionRegistry(await kvstore_impl(SqliteKVStoreConfig(db_path=db_path)))
backend_name = "kv_cached_new"
register_kvstore_backends({backend_name: SqliteKVStoreConfig(db_path=db_path)})
new_registry = DiskDistributionRegistry(
await kvstore_impl(KVStoreReference(backend=backend_name, namespace="registry"))
)
await new_registry.initialize()
result_vector_db = await new_registry.get("vector_db", "test_vector_db_2")
assert result_vector_db is not None
assert result_vector_db.identifier == new_vector_db.identifier
assert result_vector_db.provider_id == new_vector_db.provider_id
result_vector_store = await new_registry.get("vector_store", "test_vector_store_2")
assert result_vector_store is not None
assert result_vector_store.identifier == new_vector_store.identifier
assert result_vector_store.provider_id == new_vector_store.provider_id
async def test_duplicate_provider_registration(cached_disk_dist_registry):
original_vector_db = VectorDB(
identifier="test_vector_db_2",
embedding_model="all-MiniLM-L6-v2",
embedding_dimension=384,
provider_resource_id="test_vector_db_2",
original_vector_store = VectorStore(
identifier="test_vector_store_2",
embedding_model="nomic-embed-text-v1.5",
embedding_dimension=768,
provider_resource_id="test_vector_store_2",
provider_id="baz",
)
assert await cached_disk_dist_registry.register(original_vector_db)
assert await cached_disk_dist_registry.register(original_vector_store)
duplicate_vector_db = VectorDB(
identifier="test_vector_db_2",
duplicate_vector_store = VectorStore(
identifier="test_vector_store_2",
embedding_model="different-model",
embedding_dimension=384,
provider_resource_id="test_vector_db_2",
embedding_dimension=768,
provider_resource_id="test_vector_store_2",
provider_id="baz", # Same provider_id
)
with pytest.raises(ValueError, match="Object of type 'vector_db' and identifier 'test_vector_db_2' already exists"):
await cached_disk_dist_registry.register(duplicate_vector_db)
with pytest.raises(
ValueError, match="Object of type 'vector_store' and identifier 'test_vector_store_2' already exists"
):
await cached_disk_dist_registry.register(duplicate_vector_store)
result = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2")
result = await cached_disk_dist_registry.get("vector_store", "test_vector_store_2")
assert result is not None
assert result.embedding_model == original_vector_db.embedding_model # Original values preserved
assert result.embedding_model == original_vector_store.embedding_model # Original values preserved
async def test_get_all_objects(cached_disk_dist_registry):
# Create multiple test banks
# Create multiple test banks
test_vector_dbs = [
VectorDB(
identifier=f"test_vector_db_{i}",
embedding_model="all-MiniLM-L6-v2",
embedding_dimension=384,
provider_resource_id=f"test_vector_db_{i}",
test_vector_stores = [
VectorStore(
identifier=f"test_vector_store_{i}",
embedding_model="nomic-embed-text-v1.5",
embedding_dimension=768,
provider_resource_id=f"test_vector_store_{i}",
provider_id=f"provider_{i}",
)
for i in range(3)
]
# Register all vector_dbs
for vector_db in test_vector_dbs:
await cached_disk_dist_registry.register(vector_db)
# Register all vector_stores
for vector_store in test_vector_stores:
await cached_disk_dist_registry.register(vector_store)
# Test get_all retrieval
all_results = await cached_disk_dist_registry.get_all()
assert len(all_results) == 3
# Verify each vector_db was stored correctly
for original_vector_db in test_vector_dbs:
matching_vector_dbs = [v for v in all_results if v.identifier == original_vector_db.identifier]
assert len(matching_vector_dbs) == 1
stored_vector_db = matching_vector_dbs[0]
assert stored_vector_db.embedding_model == original_vector_db.embedding_model
assert stored_vector_db.provider_id == original_vector_db.provider_id
assert stored_vector_db.embedding_dimension == original_vector_db.embedding_dimension
# Verify each vector_store was stored correctly
for original_vector_store in test_vector_stores:
matching_vector_stores = [v for v in all_results if v.identifier == original_vector_store.identifier]
assert len(matching_vector_stores) == 1
stored_vector_store = matching_vector_stores[0]
assert stored_vector_store.embedding_model == original_vector_store.embedding_model
assert stored_vector_store.provider_id == original_vector_store.provider_id
assert stored_vector_store.embedding_dimension == original_vector_store.embedding_dimension
async def test_parse_registry_values_error_handling(sqlite_kvstore):
valid_db = VectorDB(
identifier="valid_vector_db",
embedding_model="all-MiniLM-L6-v2",
embedding_dimension=384,
provider_resource_id="valid_vector_db",
valid_db = VectorStore(
identifier="valid_vector_store",
embedding_model="nomic-embed-text-v1.5",
embedding_dimension=768,
provider_resource_id="valid_vector_store",
provider_id="test-provider",
)
await sqlite_kvstore.set(
KEY_FORMAT.format(type="vector_db", identifier="valid_vector_db"), valid_db.model_dump_json()
KEY_FORMAT.format(type="vector_store", identifier="valid_vector_store"), valid_db.model_dump_json()
)
await sqlite_kvstore.set(KEY_FORMAT.format(type="vector_db", identifier="corrupted_json"), "{not valid json")
await sqlite_kvstore.set(KEY_FORMAT.format(type="vector_store", identifier="corrupted_json"), "{not valid json")
await sqlite_kvstore.set(
KEY_FORMAT.format(type="vector_db", identifier="missing_fields"),
'{"type": "vector_db", "identifier": "missing_fields"}',
KEY_FORMAT.format(type="vector_store", identifier="missing_fields"),
'{"type": "vector_store", "identifier": "missing_fields"}',
)
test_registry = DiskDistributionRegistry(sqlite_kvstore)
@ -194,32 +204,32 @@ async def test_parse_registry_values_error_handling(sqlite_kvstore):
# Should have filtered out the invalid entries
assert len(all_objects) == 1
assert all_objects[0].identifier == "valid_vector_db"
assert all_objects[0].identifier == "valid_vector_store"
# Check that the get method also handles errors correctly
invalid_obj = await test_registry.get("vector_db", "corrupted_json")
invalid_obj = await test_registry.get("vector_store", "corrupted_json")
assert invalid_obj is None
invalid_obj = await test_registry.get("vector_db", "missing_fields")
invalid_obj = await test_registry.get("vector_store", "missing_fields")
assert invalid_obj is None
async def test_cached_registry_error_handling(sqlite_kvstore):
valid_db = VectorDB(
valid_db = VectorStore(
identifier="valid_cached_db",
embedding_model="all-MiniLM-L6-v2",
embedding_dimension=384,
embedding_model="nomic-embed-text-v1.5",
embedding_dimension=768,
provider_resource_id="valid_cached_db",
provider_id="test-provider",
)
await sqlite_kvstore.set(
KEY_FORMAT.format(type="vector_db", identifier="valid_cached_db"), valid_db.model_dump_json()
KEY_FORMAT.format(type="vector_store", identifier="valid_cached_db"), valid_db.model_dump_json()
)
await sqlite_kvstore.set(
KEY_FORMAT.format(type="vector_db", identifier="invalid_cached_db"),
'{"type": "vector_db", "identifier": "invalid_cached_db", "embedding_model": 12345}', # Should be string
KEY_FORMAT.format(type="vector_store", identifier="invalid_cached_db"),
'{"type": "vector_store", "identifier": "invalid_cached_db", "embedding_model": 12345}', # Should be string
)
cached_registry = CachedDiskDistributionRegistry(sqlite_kvstore)
@ -229,63 +239,65 @@ async def test_cached_registry_error_handling(sqlite_kvstore):
assert len(all_objects) == 1
assert all_objects[0].identifier == "valid_cached_db"
invalid_obj = await cached_registry.get("vector_db", "invalid_cached_db")
invalid_obj = await cached_registry.get("vector_store", "invalid_cached_db")
assert invalid_obj is None
async def test_double_registration_identical_objects(disk_dist_registry):
"""Test that registering identical objects succeeds (idempotent)."""
vector_db = VectorDBWithOwner(
identifier="test_vector_db",
vector_store = VectorStoreWithOwner(
identifier="test_vector_store",
embedding_model="all-MiniLM-L6-v2",
embedding_dimension=384,
provider_resource_id="test_vector_db",
provider_resource_id="test_vector_store",
provider_id="test-provider",
)
# First registration should succeed
result1 = await disk_dist_registry.register(vector_db)
result1 = await disk_dist_registry.register(vector_store)
assert result1 is True
# Second registration of identical object should also succeed (idempotent)
result2 = await disk_dist_registry.register(vector_db)
result2 = await disk_dist_registry.register(vector_store)
assert result2 is True
# Verify object exists and is unchanged
retrieved = await disk_dist_registry.get("vector_db", "test_vector_db")
retrieved = await disk_dist_registry.get("vector_store", "test_vector_store")
assert retrieved is not None
assert retrieved.identifier == vector_db.identifier
assert retrieved.embedding_model == vector_db.embedding_model
assert retrieved.identifier == vector_store.identifier
assert retrieved.embedding_model == vector_store.embedding_model
async def test_double_registration_different_objects(disk_dist_registry):
"""Test that registering different objects with same identifier fails."""
vector_db1 = VectorDBWithOwner(
identifier="test_vector_db",
vector_store1 = VectorStoreWithOwner(
identifier="test_vector_store",
embedding_model="all-MiniLM-L6-v2",
embedding_dimension=384,
provider_resource_id="test_vector_db",
provider_resource_id="test_vector_store",
provider_id="test-provider",
)
vector_db2 = VectorDBWithOwner(
identifier="test_vector_db", # Same identifier
vector_store2 = VectorStoreWithOwner(
identifier="test_vector_store", # Same identifier
embedding_model="different-model", # Different embedding model
embedding_dimension=384,
provider_resource_id="test_vector_db",
provider_resource_id="test_vector_store",
provider_id="test-provider",
)
# First registration should succeed
result1 = await disk_dist_registry.register(vector_db1)
result1 = await disk_dist_registry.register(vector_store1)
assert result1 is True
# Second registration with different data should fail
with pytest.raises(ValueError, match="Object of type 'vector_db' and identifier 'test_vector_db' already exists"):
await disk_dist_registry.register(vector_db2)
with pytest.raises(
ValueError, match="Object of type 'vector_store' and identifier 'test_vector_store' already exists"
):
await disk_dist_registry.register(vector_store2)
# Verify original object is unchanged
retrieved = await disk_dist_registry.get("vector_db", "test_vector_db")
retrieved = await disk_dist_registry.get("vector_store", "test_vector_store")
assert retrieved is not None
assert retrieved.embedding_model == "all-MiniLM-L6-v2" # Original value

View file

@ -256,12 +256,12 @@ async def test_setup_with_access_policy(cached_disk_dist_registry):
- permit:
principal: user-2
actions: [read]
resource: model::model-1
resource: model::test_provider/model-1
description: user-2 has read access to model-1 only
- permit:
principal: user-3
actions: [read]
resource: model::model-2
resource: model::test_provider/model-2
description: user-3 has read access to model-2 only
- forbid:
actions: [create, read, delete]
@ -285,21 +285,15 @@ async def test_access_policy(mock_get_authenticated_user, test_setup_with_access
"projects": ["foo", "bar"],
},
)
await routing_table.register_model(
"model-1", provider_model_id="test_provider/model-1", provider_id="test_provider"
)
await routing_table.register_model(
"model-2", provider_model_id="test_provider/model-2", provider_id="test_provider"
)
await routing_table.register_model(
"model-3", provider_model_id="test_provider/model-3", provider_id="test_provider"
)
model = await routing_table.get_model("model-1")
assert model.identifier == "model-1"
model = await routing_table.get_model("model-2")
assert model.identifier == "model-2"
model = await routing_table.get_model("model-3")
assert model.identifier == "model-3"
await routing_table.register_model("model-1", provider_model_id="model-1", provider_id="test_provider")
await routing_table.register_model("model-2", provider_model_id="model-2", provider_id="test_provider")
await routing_table.register_model("model-3", provider_model_id="model-3", provider_id="test_provider")
model = await routing_table.get_model("test_provider/model-1")
assert model.identifier == "test_provider/model-1"
model = await routing_table.get_model("test_provider/model-2")
assert model.identifier == "test_provider/model-2"
model = await routing_table.get_model("test_provider/model-3")
assert model.identifier == "test_provider/model-3"
mock_get_authenticated_user.return_value = User(
"user-2",
@ -308,16 +302,16 @@ async def test_access_policy(mock_get_authenticated_user, test_setup_with_access
"projects": ["foo"],
},
)
model = await routing_table.get_model("model-1")
assert model.identifier == "model-1"
model = await routing_table.get_model("test_provider/model-1")
assert model.identifier == "test_provider/model-1"
with pytest.raises(ValueError):
await routing_table.get_model("model-2")
await routing_table.get_model("test_provider/model-2")
with pytest.raises(ValueError):
await routing_table.get_model("model-3")
await routing_table.get_model("test_provider/model-3")
with pytest.raises(AccessDeniedError):
await routing_table.register_model("model-4", provider_id="test_provider")
with pytest.raises(AccessDeniedError):
await routing_table.unregister_model("model-1")
await routing_table.unregister_model("test_provider/model-1")
mock_get_authenticated_user.return_value = User(
"user-3",
@ -326,16 +320,16 @@ async def test_access_policy(mock_get_authenticated_user, test_setup_with_access
"projects": ["bar"],
},
)
model = await routing_table.get_model("model-2")
assert model.identifier == "model-2"
model = await routing_table.get_model("test_provider/model-2")
assert model.identifier == "test_provider/model-2"
with pytest.raises(ValueError):
await routing_table.get_model("model-1")
await routing_table.get_model("test_provider/model-1")
with pytest.raises(ValueError):
await routing_table.get_model("model-3")
await routing_table.get_model("test_provider/model-3")
with pytest.raises(AccessDeniedError):
await routing_table.register_model("model-5", provider_id="test_provider")
with pytest.raises(AccessDeniedError):
await routing_table.unregister_model("model-2")
await routing_table.unregister_model("test_provider/model-2")
mock_get_authenticated_user.return_value = User(
"user-1",
@ -344,9 +338,9 @@ async def test_access_policy(mock_get_authenticated_user, test_setup_with_access
"projects": ["foo", "bar"],
},
)
await routing_table.unregister_model("model-3")
await routing_table.unregister_model("test_provider/model-3")
with pytest.raises(ValueError):
await routing_table.get_model("model-3")
await routing_table.get_model("test_provider/model-3")
def test_permit_when():

View file

@ -5,7 +5,9 @@
# the root directory of this source tree.
import base64
from unittest.mock import AsyncMock, patch
import json
import logging # allow-direct-logging
from unittest.mock import AsyncMock, Mock, patch
import pytest
from fastapi import FastAPI
@ -26,6 +28,13 @@ from llama_stack.core.server.auth_providers import (
)
@pytest.fixture
def suppress_auth_errors(caplog):
"""Suppress expected ERROR/WARNING logs for tests that deliberately trigger authentication errors"""
caplog.set_level(logging.CRITICAL, logger="llama_stack.core.server.auth")
caplog.set_level(logging.CRITICAL, logger="llama_stack.core.server.auth_providers")
class MockResponse:
def __init__(self, status_code, json_data):
self.status_code = status_code
@ -122,7 +131,7 @@ def mock_impls():
@pytest.fixture
def scope_middleware_with_mocks(mock_auth_endpoint):
def middleware_with_mocks(mock_auth_endpoint):
"""Create AuthenticationMiddleware with mocked route implementations"""
mock_app = AsyncMock()
auth_config = AuthenticationConfig(
@ -137,18 +146,20 @@ def scope_middleware_with_mocks(mock_auth_endpoint):
# Mock the route_impls to simulate finding routes with required scopes
from llama_stack.schema_utils import WebMethod
scoped_webmethod = WebMethod(route="/test/scoped", method="POST", required_scope="test.read")
public_webmethod = WebMethod(route="/test/public", method="GET")
routes = {
("POST", "/test/scoped"): WebMethod(route="/test/scoped", method="POST", required_scope="test.read"),
("GET", "/test/public"): WebMethod(route="/test/public", method="GET"),
("GET", "/health"): WebMethod(route="/health", method="GET", require_authentication=False),
("GET", "/version"): WebMethod(route="/version", method="GET", require_authentication=False),
("GET", "/models/list"): WebMethod(route="/models/list", method="GET", require_authentication=True),
}
# Mock the route finding logic
def mock_find_matching_route(method, path, route_impls):
if method == "POST" and path == "/test/scoped":
return None, {}, "/test/scoped", scoped_webmethod
elif method == "GET" and path == "/test/public":
return None, {}, "/test/public", public_webmethod
else:
raise ValueError("No matching route")
webmethod = routes.get((method, path))
if webmethod:
return None, {}, path, webmethod
raise ValueError("No matching route")
import llama_stack.core.server.auth
@ -234,20 +245,20 @@ def test_valid_http_authentication(http_client, valid_api_key):
@patch("httpx.AsyncClient.post", new=mock_post_failure)
def test_invalid_http_authentication(http_client, invalid_api_key):
def test_invalid_http_authentication(http_client, invalid_api_key, suppress_auth_errors):
response = http_client.get("/test", headers={"Authorization": f"Bearer {invalid_api_key}"})
assert response.status_code == 401
assert "Authentication failed" in response.json()["error"]["message"]
@patch("httpx.AsyncClient.post", new=mock_post_exception)
def test_http_auth_service_error(http_client, valid_api_key):
def test_http_auth_service_error(http_client, valid_api_key, suppress_auth_errors):
response = http_client.get("/test", headers={"Authorization": f"Bearer {valid_api_key}"})
assert response.status_code == 401
assert "Authentication service error" in response.json()["error"]["message"]
def test_http_auth_request_payload(http_client, valid_api_key, mock_auth_endpoint):
def test_http_auth_request_payload(http_client, valid_api_key, mock_auth_endpoint, suppress_auth_errors):
with patch("httpx.AsyncClient.post") as mock_post:
mock_response = MockResponse(200, {"message": "Authentication successful"})
mock_post.return_value = mock_response
@ -372,7 +383,7 @@ async def mock_jwks_response(*args, **kwargs):
@pytest.fixture
def jwt_token_valid():
from jose import jwt
import jwt
return jwt.encode(
{
@ -387,15 +398,37 @@ def jwt_token_valid():
)
@patch("httpx.AsyncClient.get", new=mock_jwks_response)
def test_valid_oauth2_authentication(oauth2_client, jwt_token_valid):
@pytest.fixture
def mock_jwks_urlopen():
"""Mock urllib.request.urlopen for PyJWKClient JWKS requests."""
with patch("urllib.request.urlopen") as mock_urlopen:
# Mock the JWKS response for PyJWKClient
mock_response = Mock()
mock_response.read.return_value = json.dumps(
{
"keys": [
{
"kid": "1234567890",
"kty": "oct",
"alg": "HS256",
"use": "sig",
"k": base64.b64encode(b"foobarbaz").decode(),
}
]
}
).encode()
mock_urlopen.return_value.__enter__.return_value = mock_response
yield mock_urlopen
def test_valid_oauth2_authentication(oauth2_client, jwt_token_valid, mock_jwks_urlopen):
response = oauth2_client.get("/test", headers={"Authorization": f"Bearer {jwt_token_valid}"})
assert response.status_code == 200
assert response.json() == {"message": "Authentication successful"}
@patch("httpx.AsyncClient.get", new=mock_jwks_response)
def test_invalid_oauth2_authentication(oauth2_client, invalid_token):
def test_invalid_oauth2_authentication(oauth2_client, invalid_token, suppress_auth_errors):
response = oauth2_client.get("/test", headers={"Authorization": f"Bearer {invalid_token}"})
assert response.status_code == 401
assert "Invalid JWT token" in response.json()["error"]["message"]
@ -440,13 +473,12 @@ def oauth2_client_with_jwks_token(oauth2_app_with_jwks_token):
@patch("httpx.AsyncClient.get", new=mock_auth_jwks_response)
def test_oauth2_with_jwks_token_expected(oauth2_client, jwt_token_valid):
def test_oauth2_with_jwks_token_expected(oauth2_client, jwt_token_valid, suppress_auth_errors):
response = oauth2_client.get("/test", headers={"Authorization": f"Bearer {jwt_token_valid}"})
assert response.status_code == 401
@patch("httpx.AsyncClient.get", new=mock_auth_jwks_response)
def test_oauth2_with_jwks_token_configured(oauth2_client_with_jwks_token, jwt_token_valid):
def test_oauth2_with_jwks_token_configured(oauth2_client_with_jwks_token, jwt_token_valid, mock_jwks_urlopen):
response = oauth2_client_with_jwks_token.get("/test", headers={"Authorization": f"Bearer {jwt_token_valid}"})
assert response.status_code == 200
assert response.json() == {"message": "Authentication successful"}
@ -492,6 +524,82 @@ def test_get_attributes_from_claims():
assert set(attributes["teams"]) == {"my-team", "group1", "group2"}
assert attributes["namespaces"] == ["my-tenant"]
# Test nested claims with dot notation (e.g., Keycloak resource_access structure)
claims = {
"sub": "user123",
"resource_access": {"llamastack": {"roles": ["inference_max", "admin"]}, "other-client": {"roles": ["viewer"]}},
"realm_access": {"roles": ["offline_access", "uma_authorization"]},
}
attributes = get_attributes_from_claims(
claims, {"resource_access.llamastack.roles": "roles", "realm_access.roles": "realm_roles"}
)
assert set(attributes["roles"]) == {"inference_max", "admin"}
assert set(attributes["realm_roles"]) == {"offline_access", "uma_authorization"}
# Test that dot notation takes precedence over literal keys with dots
claims = {
"my.dotted.key": "literal-value",
"my": {"dotted": {"key": "nested-value"}},
}
attributes = get_attributes_from_claims(claims, {"my.dotted.key": "test"})
assert attributes["test"] == ["nested-value"]
# Test that literal key works when nested traversal doesn't exist
claims = {
"my.dotted.key": "literal-value",
}
attributes = get_attributes_from_claims(claims, {"my.dotted.key": "test"})
assert attributes["test"] == ["literal-value"]
# Test missing nested paths are handled gracefully
claims = {
"sub": "user123",
"resource_access": {"other-client": {"roles": ["viewer"]}},
}
attributes = get_attributes_from_claims(
claims,
{
"resource_access.llamastack.roles": "roles", # Missing nested path
"resource_access.missing.key": "missing_attr", # Missing nested path
"completely.missing.path": "another_missing", # Completely missing
"sub": "username", # Existing path
},
)
# Only the existing claim should be in attributes
assert attributes["username"] == ["user123"]
assert "roles" not in attributes
assert "missing_attr" not in attributes
assert "another_missing" not in attributes
# Test mixture of flat and nested claims paths
claims = {
"sub": "user456",
"flat_key": "flat-value",
"scope": "read write admin",
"resource_access": {"app1": {"roles": ["role1", "role2"]}, "app2": {"roles": ["role3"]}},
"groups": ["group1", "group2"],
"metadata": {"tenant": "tenant1", "region": "us-west"},
}
attributes = get_attributes_from_claims(
claims,
{
"sub": "user_id", # Flat string
"scope": "permissions", # Flat string with spaces
"groups": "teams", # Flat list
"resource_access.app1.roles": "app1_roles", # Nested list
"resource_access.app2.roles": "app2_roles", # Nested list
"metadata.tenant": "tenant", # Nested string
"metadata.region": "region", # Nested string
},
)
assert attributes["user_id"] == ["user456"]
assert set(attributes["permissions"]) == {"read", "write", "admin"}
assert set(attributes["teams"]) == {"group1", "group2"}
assert set(attributes["app1_roles"]) == {"role1", "role2"}
assert attributes["app2_roles"] == ["role3"]
assert attributes["tenant"] == ["tenant1"]
assert attributes["region"] == ["us-west"]
# TODO: add more tests for oauth2 token provider
@ -626,21 +734,21 @@ def test_valid_introspection_authentication(introspection_client, valid_api_key)
@patch("httpx.AsyncClient.post", new=mock_introspection_inactive)
def test_inactive_introspection_authentication(introspection_client, invalid_api_key):
def test_inactive_introspection_authentication(introspection_client, invalid_api_key, suppress_auth_errors):
response = introspection_client.get("/test", headers={"Authorization": f"Bearer {invalid_api_key}"})
assert response.status_code == 401
assert "Token not active" in response.json()["error"]["message"]
@patch("httpx.AsyncClient.post", new=mock_introspection_invalid)
def test_invalid_introspection_authentication(introspection_client, invalid_api_key):
def test_invalid_introspection_authentication(introspection_client, invalid_api_key, suppress_auth_errors):
response = introspection_client.get("/test", headers={"Authorization": f"Bearer {invalid_api_key}"})
assert response.status_code == 401
assert "Not JSON" in response.json()["error"]["message"]
@patch("httpx.AsyncClient.post", new=mock_introspection_failed)
def test_failed_introspection_authentication(introspection_client, invalid_api_key):
def test_failed_introspection_authentication(introspection_client, invalid_api_key, suppress_auth_errors):
response = introspection_client.get("/test", headers={"Authorization": f"Bearer {invalid_api_key}"})
assert response.status_code == 401
assert "Token introspection failed: 500" in response.json()["error"]["message"]
@ -659,9 +767,9 @@ def test_valid_introspection_with_custom_mapping_authentication(
# Scope-based authorization tests
@patch("httpx.AsyncClient.post", new=mock_post_success_with_scope)
async def test_scope_authorization_success(scope_middleware_with_mocks, valid_api_key):
async def test_scope_authorization_success(middleware_with_mocks, valid_api_key):
"""Test that user with required scope can access protected endpoint"""
middleware, mock_app = scope_middleware_with_mocks
middleware, mock_app = middleware_with_mocks
mock_receive = AsyncMock()
mock_send = AsyncMock()
@ -680,9 +788,9 @@ async def test_scope_authorization_success(scope_middleware_with_mocks, valid_ap
@patch("httpx.AsyncClient.post", new=mock_post_success_no_scope)
async def test_scope_authorization_denied(scope_middleware_with_mocks, valid_api_key):
async def test_scope_authorization_denied(middleware_with_mocks, valid_api_key):
"""Test that user without required scope gets 403 access denied"""
middleware, mock_app = scope_middleware_with_mocks
middleware, mock_app = middleware_with_mocks
mock_receive = AsyncMock()
mock_send = AsyncMock()
@ -710,9 +818,9 @@ async def test_scope_authorization_denied(scope_middleware_with_mocks, valid_api
@patch("httpx.AsyncClient.post", new=mock_post_success_no_scope)
async def test_public_endpoint_no_scope_required(scope_middleware_with_mocks, valid_api_key):
async def test_public_endpoint_no_scope_required(middleware_with_mocks, valid_api_key):
"""Test that public endpoints work without specific scopes"""
middleware, mock_app = scope_middleware_with_mocks
middleware, mock_app = middleware_with_mocks
mock_receive = AsyncMock()
mock_send = AsyncMock()
@ -730,9 +838,9 @@ async def test_public_endpoint_no_scope_required(scope_middleware_with_mocks, va
mock_send.assert_not_called()
async def test_scope_authorization_no_auth_disabled(scope_middleware_with_mocks):
async def test_scope_authorization_no_auth_disabled(middleware_with_mocks):
"""Test that when auth is disabled (no user), scope checks are bypassed"""
middleware, mock_app = scope_middleware_with_mocks
middleware, mock_app = middleware_with_mocks
mock_receive = AsyncMock()
mock_send = AsyncMock()
@ -857,20 +965,22 @@ def test_valid_kubernetes_auth_authentication(kubernetes_auth_client, valid_toke
@patch("httpx.AsyncClient.post", new=mock_kubernetes_selfsubjectreview_failure)
def test_invalid_kubernetes_auth_authentication(kubernetes_auth_client, invalid_token):
def test_invalid_kubernetes_auth_authentication(kubernetes_auth_client, invalid_token, suppress_auth_errors):
response = kubernetes_auth_client.get("/test", headers={"Authorization": f"Bearer {invalid_token}"})
assert response.status_code == 401
assert "Invalid token" in response.json()["error"]["message"]
@patch("httpx.AsyncClient.post", new=mock_kubernetes_selfsubjectreview_http_error)
def test_kubernetes_auth_http_error(kubernetes_auth_client, valid_token):
def test_kubernetes_auth_http_error(kubernetes_auth_client, valid_token, suppress_auth_errors):
response = kubernetes_auth_client.get("/test", headers={"Authorization": f"Bearer {valid_token}"})
assert response.status_code == 401
assert "Token validation failed" in response.json()["error"]["message"]
def test_kubernetes_auth_request_payload(kubernetes_auth_client, valid_token, mock_kubernetes_api_server):
def test_kubernetes_auth_request_payload(
kubernetes_auth_client, valid_token, mock_kubernetes_api_server, suppress_auth_errors
):
with patch("httpx.AsyncClient.post") as mock_post:
mock_response = MockResponse(
200,
@ -907,3 +1017,41 @@ def test_kubernetes_auth_request_payload(kubernetes_auth_client, valid_token, mo
request_body = call_args[1]["json"]
assert request_body["apiVersion"] == "authentication.k8s.io/v1"
assert request_body["kind"] == "SelfSubjectReview"
async def test_unauthenticated_endpoint_access_health(middleware_with_mocks):
"""Test that /health endpoints can be accessed without authentication"""
middleware, mock_app = middleware_with_mocks
# Test request to /health without auth header (level prefix v1 is added by router)
scope = {"type": "http", "path": "/health", "headers": [], "method": "GET"}
receive = AsyncMock()
send = AsyncMock()
# Should allow the request to proceed without authentication
await middleware(scope, receive, send)
# Verify that the request was passed to the app
mock_app.assert_called_once_with(scope, receive, send)
# Verify that no error response was sent
assert not any(call[0][0].get("status") == 401 for call in send.call_args_list)
async def test_unauthenticated_endpoint_denied_for_other_paths(middleware_with_mocks):
"""Test that endpoints other than /health and /version require authentication"""
middleware, mock_app = middleware_with_mocks
# Test request to /models/list without auth header
scope = {"type": "http", "path": "/models/list", "headers": [], "method": "GET"}
receive = AsyncMock()
send = AsyncMock()
# Should return 401 error
await middleware(scope, receive, send)
# Verify that the app was NOT called
mock_app.assert_not_called()
# Verify that a 401 error response was sent
assert any(call[0][0].get("status") == 401 for call in send.call_args_list)

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging # allow-direct-logging
from unittest.mock import AsyncMock, patch
import httpx
@ -15,6 +16,13 @@ from llama_stack.core.datatypes import AuthenticationConfig, AuthProviderType, G
from llama_stack.core.server.auth import AuthenticationMiddleware
@pytest.fixture
def suppress_auth_errors(caplog):
"""Suppress expected ERROR logs for tests that deliberately trigger authentication errors"""
caplog.set_level(logging.CRITICAL, logger="llama_stack.core.server.auth")
caplog.set_level(logging.CRITICAL, logger="llama_stack.core.server.auth_providers")
class MockResponse:
def __init__(self, status_code, json_data):
self.status_code = status_code
@ -119,7 +127,7 @@ def test_authenticated_endpoint_with_valid_github_token(mock_client_class, githu
@patch("llama_stack.core.server.auth_providers.httpx.AsyncClient")
def test_authenticated_endpoint_with_invalid_github_token(mock_client_class, github_token_client):
def test_authenticated_endpoint_with_invalid_github_token(mock_client_class, github_token_client, suppress_auth_errors):
"""Test accessing protected endpoint with invalid GitHub token"""
# Mock the GitHub API to return 401 Unauthorized
mock_client = AsyncMock()

View file

@ -4,6 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging # allow-direct-logging
from uuid import uuid4
import pytest
from fastapi import FastAPI, Request
from fastapi.testclient import TestClient
@ -11,7 +14,14 @@ from starlette.middleware.base import BaseHTTPMiddleware
from llama_stack.core.datatypes import QuotaConfig, QuotaPeriod
from llama_stack.core.server.quota import QuotaMiddleware
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore import register_kvstore_backends
@pytest.fixture
def suppress_quota_warnings(caplog):
"""Suppress expected WARNING logs for SQLite backend and quota exceeded"""
caplog.set_level(logging.CRITICAL, logger="llama_stack.core.server.quota")
class InjectClientIDMiddleware(BaseHTTPMiddleware):
@ -29,8 +39,10 @@ class InjectClientIDMiddleware(BaseHTTPMiddleware):
def build_quota_config(db_path) -> QuotaConfig:
backend_name = f"kv_quota_{uuid4().hex}"
register_kvstore_backends({backend_name: SqliteKVStoreConfig(db_path=str(db_path))})
return QuotaConfig(
kvstore=SqliteKVStoreConfig(db_path=str(db_path)),
kvstore=KVStoreReference(backend=backend_name, namespace="quota"),
anonymous_max_requests=1,
authenticated_max_requests=2,
period=QuotaPeriod.DAY,
@ -65,13 +77,13 @@ def auth_app(tmp_path, request):
return app
def test_authenticated_quota_allows_up_to_limit(auth_app):
def test_authenticated_quota_allows_up_to_limit(auth_app, suppress_quota_warnings):
client = TestClient(auth_app)
assert client.get("/test").status_code == 200
assert client.get("/test").status_code == 200
def test_authenticated_quota_blocks_after_limit(auth_app):
def test_authenticated_quota_blocks_after_limit(auth_app, suppress_quota_warnings):
client = TestClient(auth_app)
client.get("/test")
client.get("/test")
@ -80,7 +92,7 @@ def test_authenticated_quota_blocks_after_limit(auth_app):
assert resp.json()["error"]["message"] == "Quota exceeded"
def test_anonymous_quota_allows_up_to_limit(tmp_path, request):
def test_anonymous_quota_allows_up_to_limit(tmp_path, request, suppress_quota_warnings):
inner_app = FastAPI()
@inner_app.get("/test")
@ -102,7 +114,7 @@ def test_anonymous_quota_allows_up_to_limit(tmp_path, request):
assert client.get("/test").status_code == 200
def test_anonymous_quota_blocks_after_limit(tmp_path, request):
def test_anonymous_quota_blocks_after_limit(tmp_path, request, suppress_quota_warnings):
inner_app = FastAPI()
@inner_app.get("/test")

View file

@ -12,15 +12,22 @@ from unittest.mock import AsyncMock, MagicMock
from pydantic import BaseModel, Field
from llama_stack.apis.inference import Inference
from llama_stack.core.datatypes import (
Api,
Provider,
StackRunConfig,
)
from llama_stack.core.datatypes import Api, Provider, StackRunConfig
from llama_stack.core.resolver import resolve_impls
from llama_stack.core.routers.inference import InferenceRouter
from llama_stack.core.routing_tables.models import ModelsRoutingTable
from llama_stack.core.storage.datatypes import (
InferenceStoreReference,
KVStoreReference,
ServerStoresConfig,
SqliteKVStoreConfig,
SqliteSqlStoreConfig,
SqlStoreReference,
StorageConfig,
)
from llama_stack.providers.datatypes import InlineProviderSpec, ProviderSpec
from llama_stack.providers.utils.kvstore import register_kvstore_backends
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
def add_protocol_methods(cls: type, protocol: type[Protocol]) -> None:
@ -65,6 +72,35 @@ class SampleImpl:
pass
def make_run_config(**overrides) -> StackRunConfig:
storage = overrides.pop(
"storage",
StorageConfig(
backends={
"kv_default": SqliteKVStoreConfig(db_path=":memory:"),
"sql_default": SqliteSqlStoreConfig(db_path=":memory:"),
},
stores=ServerStoresConfig(
metadata=KVStoreReference(backend="kv_default", namespace="registry"),
inference=InferenceStoreReference(backend="sql_default", table_name="inference_store"),
conversations=SqlStoreReference(backend="sql_default", table_name="conversations"),
),
),
)
register_kvstore_backends({name: cfg for name, cfg in storage.backends.items() if cfg.type.value.startswith("kv_")})
register_sqlstore_backends(
{name: cfg for name, cfg in storage.backends.items() if cfg.type.value.startswith("sql_")}
)
defaults = dict(
image_name="test_image",
apis=[],
providers={},
storage=storage,
)
defaults.update(overrides)
return StackRunConfig(**defaults)
async def test_resolve_impls_basic():
# Create a real provider spec
provider_spec = InlineProviderSpec(
@ -78,7 +114,7 @@ async def test_resolve_impls_basic():
# Create provider registry with our provider
provider_registry = {Api.inference: {provider_spec.provider_type: provider_spec}}
run_config = StackRunConfig(
run_config = make_run_config(
image_name="test_image",
providers={
"inference": [

View file

@ -41,7 +41,7 @@ class TestTranslateException:
self.identifier = identifier
self.owner = owner
resource = MockResource("vector_db", "test-db")
resource = MockResource("vector_store", "test-db")
exc = AccessDeniedError("create", resource, user)
result = translate_exception(exc)
@ -49,7 +49,7 @@ class TestTranslateException:
assert isinstance(result, HTTPException)
assert result.status_code == 403
assert "test-user" in result.detail
assert "vector_db::test-db" in result.detail
assert "vector_store::test-db" in result.detail
assert "create" in result.detail
assert "roles=['user']" in result.detail
assert "teams=['dev']" in result.detail

View file

@ -5,12 +5,21 @@
# the root directory of this source tree.
import asyncio
import logging # allow-direct-logging
from unittest.mock import AsyncMock, MagicMock
import pytest
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.core.server.server import create_dynamic_typed_route, create_sse_event, sse_generator
@pytest.fixture
def suppress_sse_errors(caplog):
"""Suppress expected ERROR logs for tests that deliberately trigger SSE errors"""
caplog.set_level(logging.CRITICAL, logger="llama_stack.core.server.server")
async def test_sse_generator_basic():
# An AsyncIterator wrapped in an Awaitable, just like our web methods
async def async_event_gen():
@ -70,7 +79,7 @@ async def test_sse_generator_client_disconnected_before_response_starts():
assert len(seen_events) == 0
async def test_sse_generator_error_before_response_starts():
async def test_sse_generator_error_before_response_starts(suppress_sse_errors):
# Raise an error before the response starts
async def async_event_gen():
raise Exception("Test error")

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
import time
from tempfile import TemporaryDirectory
import pytest
@ -16,8 +15,16 @@ from llama_stack.apis.inference import (
OpenAIUserMessageParam,
Order,
)
from llama_stack.core.storage.datatypes import InferenceStoreReference, SqliteSqlStoreConfig
from llama_stack.providers.utils.inference.inference_store import InferenceStore
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
@pytest.fixture(autouse=True)
def setup_backends(tmp_path):
"""Register SQL store backends for testing."""
db_path = str(tmp_path / "test.db")
register_sqlstore_backends({"sql_default": SqliteSqlStoreConfig(db_path=db_path)})
def create_test_chat_completion(
@ -44,167 +51,162 @@ def create_test_chat_completion(
async def test_inference_store_pagination_basic():
"""Test basic pagination functionality."""
with TemporaryDirectory() as tmp_dir:
db_path = tmp_dir + "/test.db"
store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
await store.initialize()
reference = InferenceStoreReference(backend="sql_default", table_name="chat_completions")
store = InferenceStore(reference, policy=[])
await store.initialize()
# Create test data with different timestamps
base_time = int(time.time())
test_data = [
("zebra-task", base_time + 1),
("apple-job", base_time + 2),
("moon-work", base_time + 3),
("banana-run", base_time + 4),
("car-exec", base_time + 5),
]
# Create test data with different timestamps
base_time = int(time.time())
test_data = [
("zebra-task", base_time + 1),
("apple-job", base_time + 2),
("moon-work", base_time + 3),
("banana-run", base_time + 4),
("car-exec", base_time + 5),
]
# Store test chat completions
for completion_id, timestamp in test_data:
completion = create_test_chat_completion(completion_id, timestamp)
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
await store.store_chat_completion(completion, input_messages)
# Store test chat completions
for completion_id, timestamp in test_data:
completion = create_test_chat_completion(completion_id, timestamp)
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
await store.store_chat_completion(completion, input_messages)
# Wait for all queued writes to complete
await store.flush()
# Wait for all queued writes to complete
await store.flush()
# Test 1: First page with limit=2, descending order (default)
result = await store.list_chat_completions(limit=2, order=Order.desc)
assert len(result.data) == 2
assert result.data[0].id == "car-exec" # Most recent first
assert result.data[1].id == "banana-run"
assert result.has_more is True
assert result.last_id == "banana-run"
# Test 1: First page with limit=2, descending order (default)
result = await store.list_chat_completions(limit=2, order=Order.desc)
assert len(result.data) == 2
assert result.data[0].id == "car-exec" # Most recent first
assert result.data[1].id == "banana-run"
assert result.has_more is True
assert result.last_id == "banana-run"
# Test 2: Second page using 'after' parameter
result2 = await store.list_chat_completions(after="banana-run", limit=2, order=Order.desc)
assert len(result2.data) == 2
assert result2.data[0].id == "moon-work"
assert result2.data[1].id == "apple-job"
assert result2.has_more is True
# Test 2: Second page using 'after' parameter
result2 = await store.list_chat_completions(after="banana-run", limit=2, order=Order.desc)
assert len(result2.data) == 2
assert result2.data[0].id == "moon-work"
assert result2.data[1].id == "apple-job"
assert result2.has_more is True
# Test 3: Final page
result3 = await store.list_chat_completions(after="apple-job", limit=2, order=Order.desc)
assert len(result3.data) == 1
assert result3.data[0].id == "zebra-task"
assert result3.has_more is False
# Test 3: Final page
result3 = await store.list_chat_completions(after="apple-job", limit=2, order=Order.desc)
assert len(result3.data) == 1
assert result3.data[0].id == "zebra-task"
assert result3.has_more is False
async def test_inference_store_pagination_ascending():
"""Test pagination with ascending order."""
with TemporaryDirectory() as tmp_dir:
db_path = tmp_dir + "/test.db"
store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
await store.initialize()
reference = InferenceStoreReference(backend="sql_default", table_name="chat_completions")
store = InferenceStore(reference, policy=[])
await store.initialize()
# Create test data
base_time = int(time.time())
test_data = [
("delta-item", base_time + 1),
("charlie-task", base_time + 2),
("alpha-work", base_time + 3),
]
# Create test data
base_time = int(time.time())
test_data = [
("delta-item", base_time + 1),
("charlie-task", base_time + 2),
("alpha-work", base_time + 3),
]
# Store test chat completions
for completion_id, timestamp in test_data:
completion = create_test_chat_completion(completion_id, timestamp)
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
await store.store_chat_completion(completion, input_messages)
# Store test chat completions
for completion_id, timestamp in test_data:
completion = create_test_chat_completion(completion_id, timestamp)
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
await store.store_chat_completion(completion, input_messages)
# Wait for all queued writes to complete
await store.flush()
# Wait for all queued writes to complete
await store.flush()
# Test ascending order pagination
result = await store.list_chat_completions(limit=1, order=Order.asc)
assert len(result.data) == 1
assert result.data[0].id == "delta-item" # Oldest first
assert result.has_more is True
# Test ascending order pagination
result = await store.list_chat_completions(limit=1, order=Order.asc)
assert len(result.data) == 1
assert result.data[0].id == "delta-item" # Oldest first
assert result.has_more is True
# Second page with ascending order
result2 = await store.list_chat_completions(after="delta-item", limit=1, order=Order.asc)
assert len(result2.data) == 1
assert result2.data[0].id == "charlie-task"
assert result2.has_more is True
# Second page with ascending order
result2 = await store.list_chat_completions(after="delta-item", limit=1, order=Order.asc)
assert len(result2.data) == 1
assert result2.data[0].id == "charlie-task"
assert result2.has_more is True
async def test_inference_store_pagination_with_model_filter():
"""Test pagination combined with model filtering."""
with TemporaryDirectory() as tmp_dir:
db_path = tmp_dir + "/test.db"
store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
await store.initialize()
reference = InferenceStoreReference(backend="sql_default", table_name="chat_completions")
store = InferenceStore(reference, policy=[])
await store.initialize()
# Create test data with different models
base_time = int(time.time())
test_data = [
("xyz-task", base_time + 1, "model-a"),
("def-work", base_time + 2, "model-b"),
("pqr-job", base_time + 3, "model-a"),
("abc-run", base_time + 4, "model-b"),
]
# Create test data with different models
base_time = int(time.time())
test_data = [
("xyz-task", base_time + 1, "model-a"),
("def-work", base_time + 2, "model-b"),
("pqr-job", base_time + 3, "model-a"),
("abc-run", base_time + 4, "model-b"),
]
# Store test chat completions
for completion_id, timestamp, model in test_data:
completion = create_test_chat_completion(completion_id, timestamp, model)
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
await store.store_chat_completion(completion, input_messages)
# Store test chat completions
for completion_id, timestamp, model in test_data:
completion = create_test_chat_completion(completion_id, timestamp, model)
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
await store.store_chat_completion(completion, input_messages)
# Wait for all queued writes to complete
await store.flush()
# Wait for all queued writes to complete
await store.flush()
# Test pagination with model filter
result = await store.list_chat_completions(limit=1, model="model-a", order=Order.desc)
assert len(result.data) == 1
assert result.data[0].id == "pqr-job" # Most recent model-a
assert result.data[0].model == "model-a"
assert result.has_more is True
# Test pagination with model filter
result = await store.list_chat_completions(limit=1, model="model-a", order=Order.desc)
assert len(result.data) == 1
assert result.data[0].id == "pqr-job" # Most recent model-a
assert result.data[0].model == "model-a"
assert result.has_more is True
# Second page with model filter
result2 = await store.list_chat_completions(after="pqr-job", limit=1, model="model-a", order=Order.desc)
assert len(result2.data) == 1
assert result2.data[0].id == "xyz-task"
assert result2.data[0].model == "model-a"
assert result2.has_more is False
# Second page with model filter
result2 = await store.list_chat_completions(after="pqr-job", limit=1, model="model-a", order=Order.desc)
assert len(result2.data) == 1
assert result2.data[0].id == "xyz-task"
assert result2.data[0].model == "model-a"
assert result2.has_more is False
async def test_inference_store_pagination_invalid_after():
"""Test error handling for invalid 'after' parameter."""
with TemporaryDirectory() as tmp_dir:
db_path = tmp_dir + "/test.db"
store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
await store.initialize()
reference = InferenceStoreReference(backend="sql_default", table_name="chat_completions")
store = InferenceStore(reference, policy=[])
await store.initialize()
# Try to paginate with non-existent ID
with pytest.raises(ValueError, match="Record with id='non-existent' not found in table 'chat_completions'"):
await store.list_chat_completions(after="non-existent", limit=2)
# Try to paginate with non-existent ID
with pytest.raises(ValueError, match="Record with id='non-existent' not found in table 'chat_completions'"):
await store.list_chat_completions(after="non-existent", limit=2)
async def test_inference_store_pagination_no_limit():
"""Test pagination behavior when no limit is specified."""
with TemporaryDirectory() as tmp_dir:
db_path = tmp_dir + "/test.db"
store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
await store.initialize()
reference = InferenceStoreReference(backend="sql_default", table_name="chat_completions")
store = InferenceStore(reference, policy=[])
await store.initialize()
# Create test data
base_time = int(time.time())
test_data = [
("omega-first", base_time + 1),
("beta-second", base_time + 2),
]
# Create test data
base_time = int(time.time())
test_data = [
("omega-first", base_time + 1),
("beta-second", base_time + 2),
]
# Store test chat completions
for completion_id, timestamp in test_data:
completion = create_test_chat_completion(completion_id, timestamp)
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
await store.store_chat_completion(completion, input_messages)
# Store test chat completions
for completion_id, timestamp in test_data:
completion = create_test_chat_completion(completion_id, timestamp)
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
await store.store_chat_completion(completion, input_messages)
# Wait for all queued writes to complete
await store.flush()
# Wait for all queued writes to complete
await store.flush()
# Test without limit
result = await store.list_chat_completions(order=Order.desc)
assert len(result.data) == 2
assert result.data[0].id == "beta-second" # Most recent first
assert result.data[1].id == "omega-first"
assert result.has_more is False
# Test without limit
result = await store.list_chat_completions(order=Order.desc)
assert len(result.data) == 2
assert result.data[0].id == "beta-second" # Most recent first
assert result.data[1].id == "omega-first"
assert result.has_more is False

View file

@ -0,0 +1,30 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore.sqlite.sqlite import SqliteKVStoreImpl
async def test_memory_kvstore_persistence_behavior():
"""Test that :memory: database doesn't persist across instances."""
config = SqliteKVStoreConfig(db_path=":memory:")
# First instance
store1 = SqliteKVStoreImpl(config)
await store1.initialize()
await store1.set("persist_test", "should_not_persist")
await store1.shutdown()
# Second instance with same config
store2 = SqliteKVStoreImpl(config)
await store2.initialize()
# Data should not be present
result = await store2.get("persist_test")
assert result is None
await store2.shutdown()

View file

@ -6,6 +6,7 @@
import time
from tempfile import TemporaryDirectory
from uuid import uuid4
import pytest
@ -15,8 +16,18 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseObject,
)
from llama_stack.apis.inference import OpenAIMessageParam, OpenAIUserMessageParam
from llama_stack.core.storage.datatypes import ResponsesStoreReference, SqliteSqlStoreConfig
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
def build_store(db_path: str, policy: list | None = None) -> ResponsesStore:
backend_name = f"sql_responses_{uuid4().hex}"
register_sqlstore_backends({backend_name: SqliteSqlStoreConfig(db_path=db_path)})
return ResponsesStore(
ResponsesStoreReference(backend=backend_name, table_name="responses"),
policy=policy or [],
)
def create_test_response_object(
@ -54,7 +65,7 @@ async def test_responses_store_pagination_basic():
"""Test basic pagination functionality for responses store."""
with TemporaryDirectory() as tmp_dir:
db_path = tmp_dir + "/test.db"
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
store = build_store(db_path)
await store.initialize()
# Create test data with different timestamps
@ -103,7 +114,7 @@ async def test_responses_store_pagination_ascending():
"""Test pagination with ascending order."""
with TemporaryDirectory() as tmp_dir:
db_path = tmp_dir + "/test.db"
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
store = build_store(db_path)
await store.initialize()
# Create test data
@ -141,7 +152,7 @@ async def test_responses_store_pagination_with_model_filter():
"""Test pagination combined with model filtering."""
with TemporaryDirectory() as tmp_dir:
db_path = tmp_dir + "/test.db"
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
store = build_store(db_path)
await store.initialize()
# Create test data with different models
@ -182,7 +193,7 @@ async def test_responses_store_pagination_invalid_after():
"""Test error handling for invalid 'after' parameter."""
with TemporaryDirectory() as tmp_dir:
db_path = tmp_dir + "/test.db"
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
store = build_store(db_path)
await store.initialize()
# Try to paginate with non-existent ID
@ -194,7 +205,7 @@ async def test_responses_store_pagination_no_limit():
"""Test pagination behavior when no limit is specified."""
with TemporaryDirectory() as tmp_dir:
db_path = tmp_dir + "/test.db"
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
store = build_store(db_path)
await store.initialize()
# Create test data
@ -226,7 +237,7 @@ async def test_responses_store_get_response_object():
"""Test retrieving a single response object."""
with TemporaryDirectory() as tmp_dir:
db_path = tmp_dir + "/test.db"
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
store = build_store(db_path)
await store.initialize()
# Store a test response
@ -254,7 +265,7 @@ async def test_responses_store_input_items_pagination():
"""Test pagination functionality for input items."""
with TemporaryDirectory() as tmp_dir:
db_path = tmp_dir + "/test.db"
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
store = build_store(db_path)
await store.initialize()
# Store a test response with many inputs with explicit IDs
@ -335,7 +346,7 @@ async def test_responses_store_input_items_before_pagination():
"""Test before pagination functionality for input items."""
with TemporaryDirectory() as tmp_dir:
db_path = tmp_dir + "/test.db"
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
store = build_store(db_path)
await store.initialize()
# Store a test response with many inputs with explicit IDs