Merge remote-tracking branch 'origin/main' into dependabot/uv/openai-2.5.0

commit 13450c1a68
Author: Ashwin Bharambe
Date: 2025-10-22 12:17:03 -07:00

317 changed files with 86802 additions and 18957 deletions

View file

@ -23,6 +23,27 @@ def config_with_image_name_int():
image_name: 1234
apis_to_serve: []
built_at: {datetime.now().isoformat()}
storage:
backends:
kv_default:
type: kv_sqlite
db_path: /tmp/test_kv.db
sql_default:
type: sql_sqlite
db_path: /tmp/test_sql.db
stores:
metadata:
backend: kv_default
namespace: metadata
inference:
backend: sql_default
table_name: inference
conversations:
backend: sql_default
table_name: conversations
responses:
backend: sql_default
table_name: responses
providers:
inference:
- provider_id: provider1
@ -54,6 +75,27 @@ def up_to_date_config():
image_name: foo
apis_to_serve: []
built_at: {datetime.now().isoformat()}
storage:
backends:
kv_default:
type: kv_sqlite
db_path: /tmp/test_kv.db
sql_default:
type: sql_sqlite
db_path: /tmp/test_sql.db
stores:
metadata:
backend: kv_default
namespace: metadata
inference:
backend: sql_default
table_name: inference
conversations:
backend: sql_default
table_name: conversations
responses:
backend: sql_default
table_name: responses
providers:
inference:
- provider_id: provider1
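Both fixtures in this file gain the same `storage` block: named backends plus store references that point back at them. As a reading aid, a minimal sketch of the equivalent objects in Python, assuming the classes exported from `llama_stack.core.storage.datatypes` (the same ones imported by tests later in this commit):

from llama_stack.core.storage.datatypes import (
    KVStoreReference,
    ServerStoresConfig,
    SqliteKVStoreConfig,
    SqliteSqlStoreConfig,
    SqlStoreReference,
    StorageConfig,
)

# Mirrors the YAML above: backends are named once, stores reference them.
storage = StorageConfig(
    backends={
        "kv_default": SqliteKVStoreConfig(db_path="/tmp/test_kv.db"),
        "sql_default": SqliteSqlStoreConfig(db_path="/tmp/test_sql.db"),
    },
    stores=ServerStoresConfig(
        metadata=KVStoreReference(backend="kv_default", namespace="metadata"),
        conversations=SqlStoreReference(backend="sql_default", table_name="conversations"),
    ),
)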

View file

@ -20,7 +20,14 @@ from llama_stack.core.conversations.conversations import (
ConversationServiceConfig,
ConversationServiceImpl,
)
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
from llama_stack.core.datatypes import StackRunConfig
from llama_stack.core.storage.datatypes import (
ServerStoresConfig,
SqliteSqlStoreConfig,
SqlStoreReference,
StorageConfig,
)
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
@pytest.fixture
@ -28,7 +35,18 @@ async def service():
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test_conversations.db"
config = ConversationServiceConfig(conversations_store=SqliteSqlStoreConfig(db_path=str(db_path)), policy=[])
storage = StorageConfig(
backends={
"sql_test": SqliteSqlStoreConfig(db_path=str(db_path)),
},
stores=ServerStoresConfig(
conversations=SqlStoreReference(backend="sql_test", table_name="openai_conversations"),
),
)
register_sqlstore_backends({"sql_test": storage.backends["sql_test"]})
run_config = StackRunConfig(image_name="test", apis=[], providers={}, storage=storage)
config = ConversationServiceConfig(run_config=run_config, policy=[])
service = ConversationServiceImpl(config, {})
await service.initialize()
yield service
@ -121,9 +139,18 @@ async def test_policy_configuration():
AccessRule(forbid=Scope(principal="test_user", actions=[Action.CREATE, Action.READ], resource="*"))
]
config = ConversationServiceConfig(
conversations_store=SqliteSqlStoreConfig(db_path=str(db_path)), policy=restrictive_policy
storage = StorageConfig(
backends={
"sql_test": SqliteSqlStoreConfig(db_path=str(db_path)),
},
stores=ServerStoresConfig(
conversations=SqlStoreReference(backend="sql_test", table_name="openai_conversations"),
),
)
register_sqlstore_backends({"sql_test": storage.backends["sql_test"]})
run_config = StackRunConfig(image_name="test", apis=[], providers={}, storage=storage)
config = ConversationServiceConfig(run_config=run_config, policy=restrictive_policy)
service = ConversationServiceImpl(config, {})
await service.initialize()
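The fixture and the policy test above now repeat the same four setup steps. A condensed sketch of the pattern, using a hypothetical helper (`make_conversations_config` is not in the diff; every call inside it is, with imports as shown above):

def make_conversations_config(db_path: str, policy):
    # Hypothetical helper: build storage, register the backend, wrap in a run config.
    storage = StorageConfig(
        backends={"sql_test": SqliteSqlStoreConfig(db_path=db_path)},
        stores=ServerStoresConfig(
            conversations=SqlStoreReference(backend="sql_test", table_name="openai_conversations"),
        ),
    )
    register_sqlstore_backends({"sql_test": storage.backends["sql_test"]})
    run_config = StackRunConfig(image_name="test", apis=[], providers={}, storage=storage)
    return ConversationServiceConfig(run_config=run_config, policy=policy)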

View file

@ -21,7 +21,7 @@ async def test_single_provider_auto_selection():
Mock(identifier="all-MiniLM-L6-v2", model_type="embedding", metadata={"embedding_dimension": 384})
]
)
mock_routing_table.register_vector_db = AsyncMock(
mock_routing_table.register_vector_store = AsyncMock(
return_value=Mock(identifier="vs_123", provider_id="inline::faiss", provider_resource_id="vs_123")
)
mock_routing_table.get_provider_impl = AsyncMock(

View file

@ -4,90 +4,64 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""
Unit tests for Stack validation functions.
"""
"""Unit tests for Stack validation functions."""
from unittest.mock import AsyncMock
import pytest
from llama_stack.apis.models import Model, ModelType
from llama_stack.core.stack import validate_default_embedding_model
from llama_stack.apis.models import ListModelsResponse, Model, ModelType
from llama_stack.core.datatypes import QualifiedModel, StackRunConfig, StorageConfig, VectorStoresConfig
from llama_stack.core.stack import validate_vector_stores_config
from llama_stack.providers.datatypes import Api
class TestStackValidation:
"""Test Stack validation functions."""
class TestVectorStoresValidation:
async def test_validate_missing_model(self):
"""Test validation fails when model not found."""
run_config = StackRunConfig(
image_name="test",
providers={},
storage=StorageConfig(backends={}, stores={}),
vector_stores=VectorStoresConfig(
default_provider_id="faiss",
default_embedding_model=QualifiedModel(
provider_id="p",
model_id="missing",
),
),
)
mock_models = AsyncMock()
mock_models.list_models.return_value = ListModelsResponse(data=[])
@pytest.mark.parametrize(
"models,should_raise",
[
([], False), # No models
(
[
Model(
identifier="emb1",
model_type=ModelType.embedding,
metadata={"default_configured": True},
provider_id="p",
provider_resource_id="emb1",
)
],
False,
), # Single default
(
[
Model(
identifier="emb1",
model_type=ModelType.embedding,
metadata={"default_configured": True},
provider_id="p",
provider_resource_id="emb1",
),
Model(
identifier="emb2",
model_type=ModelType.embedding,
metadata={"default_configured": True},
provider_id="p",
provider_resource_id="emb2",
),
],
True,
), # Multiple defaults
(
[
Model(
identifier="emb1",
model_type=ModelType.embedding,
metadata={"default_configured": True},
provider_id="p",
provider_resource_id="emb1",
),
Model(
identifier="llm1",
model_type=ModelType.llm,
metadata={"default_configured": True},
provider_id="p",
provider_resource_id="llm1",
),
],
False,
), # Ignores non-embedding
],
)
async def test_validate_default_embedding_model(self, models, should_raise):
"""Test validation with various model configurations."""
mock_models_impl = AsyncMock()
mock_models_impl.list_models.return_value = models
impls = {Api.models: mock_models_impl}
with pytest.raises(ValueError, match="not found"):
await validate_vector_stores_config(run_config.vector_stores, {Api.models: mock_models})
if should_raise:
with pytest.raises(ValueError, match="Multiple embedding models marked as default_configured=True"):
await validate_default_embedding_model(impls)
else:
await validate_default_embedding_model(impls)
async def test_validate_success(self):
"""Test validation passes with valid model."""
run_config = StackRunConfig(
image_name="test",
providers={},
storage=StorageConfig(backends={}, stores={}),
vector_stores=VectorStoresConfig(
default_provider_id="faiss",
default_embedding_model=QualifiedModel(
provider_id="p",
model_id="valid",
),
),
)
mock_models = AsyncMock()
mock_models.list_models.return_value = ListModelsResponse(
data=[
Model(
identifier="p/valid", # Must match provider_id/model_id format
model_type=ModelType.embedding,
metadata={"embedding_dimension": 768},
provider_id="p",
provider_resource_id="valid",
)
]
)
async def test_validate_default_embedding_model_no_models_api(self):
"""Test validation when models API is not available."""
await validate_default_embedding_model({})
await validate_vector_stores_config(run_config.vector_stores, {Api.models: mock_models})

View file

@ -0,0 +1,84 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""Unit tests for storage backend/reference validation."""
import pytest
from pydantic import ValidationError
from llama_stack.core.datatypes import (
LLAMA_STACK_RUN_CONFIG_VERSION,
StackRunConfig,
)
from llama_stack.core.storage.datatypes import (
InferenceStoreReference,
KVStoreReference,
ServerStoresConfig,
SqliteKVStoreConfig,
SqliteSqlStoreConfig,
SqlStoreReference,
StorageConfig,
)
def _base_run_config(**overrides):
metadata_reference = overrides.pop(
"metadata_reference",
KVStoreReference(backend="kv_default", namespace="registry"),
)
inference_reference = overrides.pop(
"inference_reference",
InferenceStoreReference(backend="sql_default", table_name="inference"),
)
conversations_reference = overrides.pop(
"conversations_reference",
SqlStoreReference(backend="sql_default", table_name="conversations"),
)
storage = overrides.pop(
"storage",
StorageConfig(
backends={
"kv_default": SqliteKVStoreConfig(db_path="/tmp/kv.db"),
"sql_default": SqliteSqlStoreConfig(db_path="/tmp/sql.db"),
},
stores=ServerStoresConfig(
metadata=metadata_reference,
inference=inference_reference,
conversations=conversations_reference,
),
),
)
return StackRunConfig(
version=LLAMA_STACK_RUN_CONFIG_VERSION,
image_name="test-distro",
apis=[],
providers={},
storage=storage,
**overrides,
)
def test_references_require_known_backend():
with pytest.raises(ValidationError, match="unknown backend 'missing'"):
_base_run_config(metadata_reference=KVStoreReference(backend="missing", namespace="registry"))
def test_references_must_match_backend_family():
with pytest.raises(ValidationError, match="kv_.* is required"):
_base_run_config(metadata_reference=KVStoreReference(backend="sql_default", namespace="registry"))
with pytest.raises(ValidationError, match="sql_.* is required"):
_base_run_config(
inference_reference=InferenceStoreReference(backend="kv_default", table_name="inference"),
)
def test_valid_configuration_passes_validation():
config = _base_run_config()
stores = config.storage.stores
assert stores.metadata is not None and stores.metadata.backend == "kv_default"
assert stores.inference is not None and stores.inference.backend == "sql_default"
assert stores.conversations is not None and stores.conversations.backend == "sql_default"

View file

@ -1,40 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pathlib import Path
from llama_stack.cli.stack._build import (
_run_stack_build_command_from_build_config,
)
from llama_stack.core.datatypes import BuildConfig, DistributionSpec
from llama_stack.core.utils.image_types import LlamaStackImageType
def test_container_build_passes_path(monkeypatch, tmp_path):
called_with = {}
def spy_build_image(build_config, image_name, distro_or_config, run_config=None):
called_with["path"] = distro_or_config
called_with["run_config"] = run_config
return 0
monkeypatch.setattr(
"llama_stack.cli.stack._build.build_image",
spy_build_image,
raising=True,
)
cfg = BuildConfig(
image_type=LlamaStackImageType.CONTAINER.value,
distribution_spec=DistributionSpec(providers={}, description=""),
)
_run_stack_build_command_from_build_config(cfg, image_name="dummy")
assert "path" in called_with
assert isinstance(called_with["path"], str)
assert Path(called_with["path"]).exists()
assert called_with["run_config"] is None

View file

@ -13,6 +13,15 @@ from pydantic import BaseModel, Field, ValidationError
from llama_stack.core.datatypes import Api, Provider, StackRunConfig
from llama_stack.core.distribution import INTERNAL_APIS, get_provider_registry, providable_apis
from llama_stack.core.storage.datatypes import (
InferenceStoreReference,
KVStoreReference,
ServerStoresConfig,
SqliteKVStoreConfig,
SqliteSqlStoreConfig,
SqlStoreReference,
StorageConfig,
)
from llama_stack.providers.datatypes import ProviderSpec
@ -29,6 +38,32 @@ class SampleConfig(BaseModel):
}
def _default_storage() -> StorageConfig:
return StorageConfig(
backends={
"kv_default": SqliteKVStoreConfig(db_path=":memory:"),
"sql_default": SqliteSqlStoreConfig(db_path=":memory:"),
},
stores=ServerStoresConfig(
metadata=KVStoreReference(backend="kv_default", namespace="registry"),
inference=InferenceStoreReference(backend="sql_default", table_name="inference_store"),
conversations=SqlStoreReference(backend="sql_default", table_name="conversations"),
),
)
def make_stack_config(**overrides) -> StackRunConfig:
storage = overrides.pop("storage", _default_storage())
defaults = dict(
image_name="test_image",
apis=[],
providers={},
storage=storage,
)
defaults.update(overrides)
return StackRunConfig(**defaults)
@pytest.fixture
def mock_providers():
"""Mock the available_providers function to return test providers."""
@ -47,8 +82,8 @@ def mock_providers():
@pytest.fixture
def base_config(tmp_path):
"""Create a base StackRunConfig with common settings."""
return StackRunConfig(
image_name="test_image",
return make_stack_config(
apis=["inference"],
providers={
"inference": [
Provider(
@ -222,8 +257,8 @@ class TestProviderRegistry:
def test_missing_directory(self, mock_providers):
"""Test handling of missing external providers directory."""
config = StackRunConfig(
image_name="test_image",
config = make_stack_config(
apis=["inference"],
providers={
"inference": [
Provider(
@ -278,7 +313,6 @@ pip_packages:
"""Test loading an external provider from a module (success path)."""
from types import SimpleNamespace
from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.providers.datatypes import Api, ProviderSpec
# Simulate a provider module with get_provider_spec
@ -293,7 +327,7 @@ pip_packages:
import_module_side_effect = make_import_module_side_effect(external_module=fake_module)
with patch("importlib.import_module", side_effect=import_module_side_effect) as mock_import:
config = StackRunConfig(
config = make_stack_config(
image_name="test_image",
providers={
"inference": [
@ -317,12 +351,11 @@ pip_packages:
def test_external_provider_from_module_not_found(self, mock_providers):
"""Test handling ModuleNotFoundError for missing provider module."""
from llama_stack.core.datatypes import Provider, StackRunConfig
import_module_side_effect = make_import_module_side_effect(raise_for_external=True)
with patch("importlib.import_module", side_effect=import_module_side_effect):
config = StackRunConfig(
config = make_stack_config(
image_name="test_image",
providers={
"inference": [
@ -341,12 +374,11 @@ pip_packages:
def test_external_provider_from_module_missing_get_provider_spec(self, mock_providers):
"""Test handling missing get_provider_spec in provider module (should raise ValueError)."""
from llama_stack.core.datatypes import Provider, StackRunConfig
import_module_side_effect = make_import_module_side_effect(missing_get_provider_spec=True)
with patch("importlib.import_module", side_effect=import_module_side_effect):
config = StackRunConfig(
config = make_stack_config(
image_name="test_image",
providers={
"inference": [
@ -399,13 +431,12 @@ class TestGetExternalProvidersFromModule:
def test_stackrunconfig_provider_without_module(self, mock_providers):
"""Test that providers without module attribute are skipped."""
from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.core.distribution import get_external_providers_from_module
import_module_side_effect = make_import_module_side_effect()
with patch("importlib.import_module", side_effect=import_module_side_effect):
config = StackRunConfig(
config = make_stack_config(
image_name="test_image",
providers={
"inference": [
@ -426,7 +457,6 @@ class TestGetExternalProvidersFromModule:
"""Test provider with module containing version spec (e.g., package==1.0.0)."""
from types import SimpleNamespace
from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.core.distribution import get_external_providers_from_module
from llama_stack.providers.datatypes import ProviderSpec
@ -444,7 +474,7 @@ class TestGetExternalProvidersFromModule:
raise ModuleNotFoundError(name)
with patch("importlib.import_module", side_effect=import_side_effect):
config = StackRunConfig(
config = make_stack_config(
image_name="test_image",
providers={
"inference": [
@ -564,7 +594,6 @@ class TestGetExternalProvidersFromModule:
"""Test when get_provider_spec returns a list of specs."""
from types import SimpleNamespace
from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.core.distribution import get_external_providers_from_module
from llama_stack.providers.datatypes import ProviderSpec
@ -589,7 +618,7 @@ class TestGetExternalProvidersFromModule:
raise ModuleNotFoundError(name)
with patch("importlib.import_module", side_effect=import_side_effect):
config = StackRunConfig(
config = make_stack_config(
image_name="test_image",
providers={
"inference": [
@ -613,7 +642,6 @@ class TestGetExternalProvidersFromModule:
"""Test that list return filters specs by provider_type."""
from types import SimpleNamespace
from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.core.distribution import get_external_providers_from_module
from llama_stack.providers.datatypes import ProviderSpec
@ -638,7 +666,7 @@ class TestGetExternalProvidersFromModule:
raise ModuleNotFoundError(name)
with patch("importlib.import_module", side_effect=import_side_effect):
config = StackRunConfig(
config = make_stack_config(
image_name="test_image",
providers={
"inference": [
@ -662,7 +690,6 @@ class TestGetExternalProvidersFromModule:
"""Test that list return adds multiple different provider_types when config requests them."""
from types import SimpleNamespace
from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.core.distribution import get_external_providers_from_module
from llama_stack.providers.datatypes import ProviderSpec
@ -688,7 +715,7 @@ class TestGetExternalProvidersFromModule:
raise ModuleNotFoundError(name)
with patch("importlib.import_module", side_effect=import_side_effect):
config = StackRunConfig(
config = make_stack_config(
image_name="test_image",
providers={
"inference": [
@ -718,7 +745,6 @@ class TestGetExternalProvidersFromModule:
def test_module_not_found_raises_value_error(self, mock_providers):
"""Test that ModuleNotFoundError raises ValueError with helpful message."""
from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.core.distribution import get_external_providers_from_module
def import_side_effect(name):
@ -727,7 +753,7 @@ class TestGetExternalProvidersFromModule:
raise ModuleNotFoundError(name)
with patch("importlib.import_module", side_effect=import_side_effect):
config = StackRunConfig(
config = make_stack_config(
image_name="test_image",
providers={
"inference": [
@ -751,7 +777,6 @@ class TestGetExternalProvidersFromModule:
"""Test that generic exceptions are properly raised."""
from types import SimpleNamespace
from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.core.distribution import get_external_providers_from_module
def bad_spec():
@ -765,7 +790,7 @@ class TestGetExternalProvidersFromModule:
raise ModuleNotFoundError(name)
with patch("importlib.import_module", side_effect=import_side_effect):
config = StackRunConfig(
config = make_stack_config(
image_name="test_image",
providers={
"inference": [
@ -787,10 +812,9 @@ class TestGetExternalProvidersFromModule:
def test_empty_provider_list(self, mock_providers):
"""Test with empty provider list."""
from llama_stack.core.datatypes import StackRunConfig
from llama_stack.core.distribution import get_external_providers_from_module
config = StackRunConfig(
config = make_stack_config(
image_name="test_image",
providers={},
)
@ -805,7 +829,6 @@ class TestGetExternalProvidersFromModule:
"""Test multiple APIs with providers."""
from types import SimpleNamespace
from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.core.distribution import get_external_providers_from_module
from llama_stack.providers.datatypes import ProviderSpec
@ -830,7 +853,7 @@ class TestGetExternalProvidersFromModule:
raise ModuleNotFoundError(name)
with patch("importlib.import_module", side_effect=import_side_effect):
config = StackRunConfig(
config = make_stack_config(
image_name="test_image",
providers={
"inference": [

View file

@ -11,11 +11,12 @@ from llama_stack.apis.common.errors import ResourceNotFoundError
from llama_stack.apis.common.responses import Order
from llama_stack.apis.files import OpenAIFilePurpose
from llama_stack.core.access_control.access_control import default_policy
from llama_stack.core.storage.datatypes import SqliteSqlStoreConfig, SqlStoreReference
from llama_stack.providers.inline.files.localfs import (
LocalfsFilesImpl,
LocalfsFilesImplConfig,
)
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
class MockUploadFile:
@ -36,8 +37,11 @@ async def files_provider(tmp_path):
storage_dir = tmp_path / "files"
db_path = tmp_path / "files_metadata.db"
backend_name = "sql_localfs_test"
register_sqlstore_backends({backend_name: SqliteSqlStoreConfig(db_path=db_path.as_posix())})
config = LocalfsFilesImplConfig(
storage_dir=storage_dir.as_posix(), metadata_store=SqliteSqlStoreConfig(db_path=db_path.as_posix())
storage_dir=storage_dir.as_posix(),
metadata_store=SqlStoreReference(backend=backend_name, table_name="files_metadata"),
)
provider = LocalfsFilesImpl(config, default_policy())

View file

@ -9,7 +9,16 @@ import random
import pytest
from llama_stack.core.prompts.prompts import PromptServiceConfig, PromptServiceImpl
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from llama_stack.core.storage.datatypes import (
InferenceStoreReference,
KVStoreReference,
ServerStoresConfig,
SqliteKVStoreConfig,
SqliteSqlStoreConfig,
SqlStoreReference,
StorageConfig,
)
from llama_stack.providers.utils.kvstore import kvstore_impl, register_kvstore_backends
@pytest.fixture
@ -19,12 +28,28 @@ async def temp_prompt_store(tmp_path_factory):
db_path = str(temp_dir / f"{unique_id}.db")
from llama_stack.core.datatypes import StackRunConfig
from llama_stack.providers.utils.kvstore import kvstore_impl
mock_run_config = StackRunConfig(image_name="test-distribution", apis=[], providers={})
storage = StorageConfig(
backends={
"kv_test": SqliteKVStoreConfig(db_path=db_path),
"sql_test": SqliteSqlStoreConfig(db_path=str(temp_dir / f"{unique_id}_sql.db")),
},
stores=ServerStoresConfig(
metadata=KVStoreReference(backend="kv_test", namespace="registry"),
inference=InferenceStoreReference(backend="sql_test", table_name="inference"),
conversations=SqlStoreReference(backend="sql_test", table_name="conversations"),
),
)
mock_run_config = StackRunConfig(
image_name="test-distribution",
apis=[],
providers={},
storage=storage,
)
config = PromptServiceConfig(run_config=mock_run_config)
store = PromptServiceImpl(config, deps={})
store.kvstore = await kvstore_impl(SqliteKVStoreConfig(db_path=db_path))
register_kvstore_backends({"kv_test": storage.backends["kv_test"]})
store.kvstore = await kvstore_impl(KVStoreReference(backend="kv_test", namespace="prompts"))
yield store
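Note the signature change here: `kvstore_impl` now resolves a `KVStoreReference` against a previously registered backend instead of accepting a concrete config. The two lines that matter, condensed from the fixture above:

# Register the named backend once, then resolve references against it.
register_kvstore_backends({"kv_test": SqliteKVStoreConfig(db_path=db_path)})
store.kvstore = await kvstore_impl(KVStoreReference(backend="kv_test", namespace="prompts"))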

View file

@ -26,6 +26,20 @@ from llama_stack.providers.inline.agents.meta_reference.config import MetaRefere
from llama_stack.providers.inline.agents.meta_reference.persistence import AgentInfo
@pytest.fixture(autouse=True)
def setup_backends(tmp_path):
"""Register KV and SQL store backends for testing."""
from llama_stack.core.storage.datatypes import SqliteKVStoreConfig, SqliteSqlStoreConfig
from llama_stack.providers.utils.kvstore.kvstore import register_kvstore_backends
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
kv_path = str(tmp_path / "test_kv.db")
sql_path = str(tmp_path / "test_sql.db")
register_kvstore_backends({"kv_default": SqliteKVStoreConfig(db_path=kv_path)})
register_sqlstore_backends({"sql_default": SqliteSqlStoreConfig(db_path=sql_path)})
@pytest.fixture
def mock_apis():
return {
@ -40,15 +54,20 @@ def mock_apis():
@pytest.fixture
def config(tmp_path):
from llama_stack.core.storage.datatypes import KVStoreReference, ResponsesStoreReference
from llama_stack.providers.inline.agents.meta_reference.config import AgentPersistenceConfig
return MetaReferenceAgentsImplConfig(
persistence_store={
"type": "sqlite",
"db_path": str(tmp_path / "test.db"),
},
responses_store={
"type": "sqlite",
"db_path": str(tmp_path / "test.db"),
},
persistence=AgentPersistenceConfig(
agent_state=KVStoreReference(
backend="kv_default",
namespace="agents",
),
responses=ResponsesStoreReference(
backend="sql_default",
table_name="responses",
),
)
)
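Putting the two hunks together: the autouse fixture registers the named backends, and the config then refers to them by name. A minimal end-to-end sketch using only calls shown in this file (db paths are illustrative):

from llama_stack.core.storage.datatypes import (
    KVStoreReference,
    ResponsesStoreReference,
    SqliteKVStoreConfig,
    SqliteSqlStoreConfig,
)
from llama_stack.providers.inline.agents.meta_reference.config import (
    AgentPersistenceConfig,
    MetaReferenceAgentsImplConfig,
)
from llama_stack.providers.utils.kvstore.kvstore import register_kvstore_backends
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends

# Register the backends the references below will resolve against.
register_kvstore_backends({"kv_default": SqliteKVStoreConfig(db_path="/tmp/agents_kv.db")})
register_sqlstore_backends({"sql_default": SqliteSqlStoreConfig(db_path="/tmp/agents_sql.db")})

config = MetaReferenceAgentsImplConfig(
    persistence=AgentPersistenceConfig(
        agent_state=KVStoreReference(backend="kv_default", namespace="agents"),
        responses=ResponsesStoreReference(backend="sql_default", table_name="responses"),
    )
)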

View file

@ -42,7 +42,7 @@ from llama_stack.apis.inference import (
)
from llama_stack.apis.tools.tools import ListToolDefsResponse, ToolDef, ToolGroups, ToolInvocationResult, ToolRuntime
from llama_stack.core.access_control.access_control import default_policy
from llama_stack.core.datatypes import ResponsesStoreConfig
from llama_stack.core.storage.datatypes import ResponsesStoreReference, SqliteSqlStoreConfig
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
OpenAIResponsesImpl,
)
@ -50,7 +50,7 @@ from llama_stack.providers.utils.responses.responses_store import (
ResponsesStore,
_OpenAIResponseObjectWithInputAndMessages,
)
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
from tests.unit.providers.agents.meta_reference.fixtures import load_chat_completion_fixture
@ -814,6 +814,69 @@ async def test_create_openai_response_with_instructions_and_previous_response(
assert sent_messages[3].content == "Which is the largest?"
async def test_create_openai_response_with_previous_response_instructions(
openai_responses_impl, mock_responses_store, mock_inference_api
):
"""Test prepending instructions and previous response with instructions."""
input_item_message = OpenAIResponseMessage(
id="123",
content="Name some towns in Ireland",
role="user",
)
response_output_message = OpenAIResponseMessage(
id="123",
content="Galway, Longford, Sligo",
status="completed",
role="assistant",
)
response = _OpenAIResponseObjectWithInputAndMessages(
created_at=1,
id="resp_123",
model="fake_model",
output=[response_output_message],
status="completed",
text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")),
input=[input_item_message],
messages=[
OpenAIUserMessageParam(content="Name some towns in Ireland"),
OpenAIAssistantMessageParam(content="Galway, Longford, Sligo"),
],
instructions="You are a helpful assistant.",
)
mock_responses_store.get_response_object.return_value = response
model = "meta-llama/Llama-3.1-8B-Instruct"
instructions = "You are a geography expert. Provide concise answers."
mock_inference_api.openai_chat_completion.return_value = fake_stream()
# Execute
await openai_responses_impl.create_openai_response(
input="Which is the largest?", model=model, instructions=instructions, previous_response_id="123"
)
# Verify
mock_inference_api.openai_chat_completion.assert_called_once()
call_args = mock_inference_api.openai_chat_completion.call_args
params = call_args.args[0]
sent_messages = params.messages
# Check that instructions were prepended as a system message
# and that the previous response instructions were not carried over
assert len(sent_messages) == 4, sent_messages
assert sent_messages[0].role == "system"
assert sent_messages[0].content == instructions
# Check the rest of the messages were converted correctly
assert sent_messages[1].role == "user"
assert sent_messages[1].content == "Name some towns in Ireland"
assert sent_messages[2].role == "assistant"
assert sent_messages[2].content == "Galway, Longford, Sligo"
assert sent_messages[3].role == "user"
assert sent_messages[3].content == "Which is the largest?"
async def test_list_openai_response_input_items_delegation(openai_responses_impl, mock_responses_store):
"""Test that list_openai_response_input_items properly delegates to responses_store with correct parameters."""
# Setup
@ -854,8 +917,10 @@ async def test_responses_store_list_input_items_logic():
# Create mock store and response store
mock_sql_store = AsyncMock()
backend_name = "sql_responses_test"
register_sqlstore_backends({backend_name: SqliteSqlStoreConfig(db_path="mock_db_path")})
responses_store = ResponsesStore(
ResponsesStoreConfig(sql_store_config=SqliteSqlStoreConfig(db_path="mock_db_path")), policy=default_policy()
ResponsesStoreReference(backend=backend_name, table_name="responses"), policy=default_policy()
)
responses_store.sql_store = mock_sql_store

View file

@ -12,10 +12,10 @@ from unittest.mock import AsyncMock
import pytest
from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig
from llama_stack.providers.inline.batches.reference.batches import ReferenceBatchesImpl
from llama_stack.providers.inline.batches.reference.config import ReferenceBatchesImplConfig
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore import kvstore_impl, register_kvstore_backends
@pytest.fixture
@ -23,8 +23,10 @@ async def provider():
"""Create a test provider instance with temporary database."""
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test_batches.db"
backend_name = "kv_batches_test"
kvstore_config = SqliteKVStoreConfig(db_path=str(db_path))
config = ReferenceBatchesImplConfig(kvstore=kvstore_config)
register_kvstore_backends({backend_name: kvstore_config})
config = ReferenceBatchesImplConfig(kvstore=KVStoreReference(backend=backend_name, namespace="batches"))
# Create kvstore and mock APIs
kvstore = await kvstore_impl(config.kvstore)

View file

@ -8,8 +8,9 @@ import boto3
import pytest
from moto import mock_aws
from llama_stack.core.storage.datatypes import SqliteSqlStoreConfig, SqlStoreReference
from llama_stack.providers.remote.files.s3 import S3FilesImplConfig, get_adapter_impl
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
class MockUploadFile:
@ -38,11 +39,13 @@ def sample_text_file2():
def s3_config(tmp_path):
db_path = tmp_path / "s3_files_metadata.db"
backend_name = f"sql_s3_{tmp_path.name}"
register_sqlstore_backends({backend_name: SqliteSqlStoreConfig(db_path=db_path.as_posix())})
return S3FilesImplConfig(
bucket_name=f"test-bucket-{tmp_path.name}",
region="not-a-region",
auto_create_bucket=True,
metadata_store=SqliteSqlStoreConfig(db_path=db_path.as_posix()),
metadata_store=SqlStoreReference(backend=backend_name, table_name="s3_files_metadata"),
)

View file

@ -38,6 +38,28 @@ class OpenAIMixinWithEmbeddingsImpl(OpenAIMixinImpl):
}
class OpenAIMixinWithCustomModelConstruction(OpenAIMixinImpl):
"""Test implementation that uses construct_model_from_identifier to add rerank models"""
embedding_model_metadata: dict[str, dict[str, int]] = {
"text-embedding-3-small": {"embedding_dimension": 1536, "context_length": 8192},
"text-embedding-ada-002": {"embedding_dimension": 1536, "context_length": 8192},
}
# Adds rerank models via construct_model_from_identifier
rerank_model_ids: set[str] = {"rerank-model-1", "rerank-model-2"}
def construct_model_from_identifier(self, identifier: str) -> Model:
if identifier in self.rerank_model_ids:
return Model(
provider_id=self.__provider_id__, # type: ignore[attr-defined]
provider_resource_id=identifier,
identifier=identifier,
model_type=ModelType.rerank,
)
return super().construct_model_from_identifier(identifier)
@pytest.fixture
def mixin():
"""Create a test instance of OpenAIMixin with mocked model_store"""
@ -62,6 +84,13 @@ def mixin_with_embeddings():
return OpenAIMixinWithEmbeddingsImpl(config=config)
@pytest.fixture
def mixin_with_custom_model_construction():
"""Create a test instance using custom construct_model_from_identifier"""
config = RemoteInferenceProviderConfig()
return OpenAIMixinWithCustomModelConstruction(config=config)
@pytest.fixture
def mock_models():
"""Create multiple mock OpenAI model objects"""
@ -113,6 +142,19 @@ def mock_client_context():
return _mock_client_context
def _assert_models_match_expected(actual_models, expected_models):
"""Verify the models match expected attributes.
Args:
actual_models: List of models to verify
expected_models: Mapping of model identifier to expected attribute values
"""
for identifier, expected_attrs in expected_models.items():
model = next(m for m in actual_models if m.identifier == identifier)
for attr_name, expected_value in expected_attrs.items():
assert getattr(model, attr_name) == expected_value
class TestOpenAIMixinListModels:
"""Test cases for the list_models method"""
@ -342,21 +384,71 @@ class TestOpenAIMixinEmbeddingModelMetadata:
assert result is not None
assert len(result) == 2
# Find the models in the result
embedding_model = next(m for m in result if m.identifier == "text-embedding-3-small")
llm_model = next(m for m in result if m.identifier == "gpt-4")
expected_models = {
"text-embedding-3-small": {
"model_type": ModelType.embedding,
"metadata": {"embedding_dimension": 1536, "context_length": 8192},
"provider_id": "test-provider",
"provider_resource_id": "text-embedding-3-small",
},
"gpt-4": {
"model_type": ModelType.llm,
"metadata": {},
"provider_id": "test-provider",
"provider_resource_id": "gpt-4",
},
}
# Check embedding model
assert embedding_model.model_type == ModelType.embedding
assert embedding_model.metadata == {"embedding_dimension": 1536, "context_length": 8192}
assert embedding_model.provider_id == "test-provider"
assert embedding_model.provider_resource_id == "text-embedding-3-small"
_assert_models_match_expected(result, expected_models)
# Check LLM model
assert llm_model.model_type == ModelType.llm
assert llm_model.metadata == {} # No metadata for LLMs
assert llm_model.provider_id == "test-provider"
assert llm_model.provider_resource_id == "gpt-4"
class TestOpenAIMixinCustomModelConstruction:
"""Test cases for mixed model types (LLM, embedding, rerank) through construct_model_from_identifier"""
async def test_mixed_model_types_identification(self, mixin_with_custom_model_construction, mock_client_context):
"""Test that LLM, embedding, and rerank models are correctly identified with proper types and metadata"""
# Create mock models: 1 embedding, 1 rerank, 1 LLM
mock_embedding_model = MagicMock(id="text-embedding-3-small")
mock_rerank_model = MagicMock(id="rerank-model-1")
mock_llm_model = MagicMock(id="gpt-4")
mock_models = [mock_embedding_model, mock_rerank_model, mock_llm_model]
mock_client = MagicMock()
async def mock_models_list():
for model in mock_models:
yield model
mock_client.models.list.return_value = mock_models_list()
with mock_client_context(mixin_with_custom_model_construction, mock_client):
result = await mixin_with_custom_model_construction.list_models()
assert result is not None
assert len(result) == 3
expected_models = {
"text-embedding-3-small": {
"model_type": ModelType.embedding,
"metadata": {"embedding_dimension": 1536, "context_length": 8192},
"provider_id": "test-provider",
"provider_resource_id": "text-embedding-3-small",
},
"rerank-model-1": {
"model_type": ModelType.rerank,
"metadata": {},
"provider_id": "test-provider",
"provider_resource_id": "rerank-model-1",
},
"gpt-4": {
"model_type": ModelType.llm,
"metadata": {},
"provider_id": "test-provider",
"provider_resource_id": "gpt-4",
},
}
_assert_models_match_expected(result, expected_models)
class TestOpenAIMixinAllowedModels:

View file

@ -10,15 +10,16 @@ from unittest.mock import AsyncMock, MagicMock, patch
import numpy as np
import pytest
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.inline.vector_io.faiss.faiss import FaissIndex, FaissVectorIOAdapter
from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter
from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig
from llama_stack.providers.remote.vector_io.pgvector.pgvector import PGVectorIndex, PGVectorVectorIOAdapter
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore import register_kvstore_backends
EMBEDDING_DIMENSION = 768
COLLECTION_PREFIX = "test_collection"
@ -30,7 +31,7 @@ def vector_provider(request):
@pytest.fixture
def vector_db_id() -> str:
def vector_store_id() -> str:
return f"test-vector-db-{random.randint(1, 100)}"
@ -112,8 +113,9 @@ async def unique_kvstore_config(tmp_path_factory):
unique_id = f"test_kv_{np.random.randint(1e6)}"
temp_dir = tmp_path_factory.getbasetemp()
db_path = str(temp_dir / f"{unique_id}.db")
return SqliteKVStoreConfig(db_path=db_path)
backend_name = f"kv_vector_{unique_id}"
register_kvstore_backends({backend_name: SqliteKVStoreConfig(db_path=db_path)})
return KVStoreReference(backend=backend_name, namespace=f"vector_io::{unique_id}")
@pytest.fixture(scope="session")
@ -138,18 +140,17 @@ async def sqlite_vec_vec_index(embedding_dimension, tmp_path_factory):
async def sqlite_vec_adapter(sqlite_vec_db_path, unique_kvstore_config, mock_inference_api, embedding_dimension):
config = SQLiteVectorIOConfig(
db_path=sqlite_vec_db_path,
kvstore=unique_kvstore_config,
persistence=unique_kvstore_config,
)
adapter = SQLiteVecVectorIOAdapter(
config=config,
inference_api=mock_inference_api,
files_api=None,
models_api=None,
)
collection_id = f"sqlite_test_collection_{np.random.randint(1e6)}"
await adapter.initialize()
await adapter.register_vector_db(
VectorDB(
await adapter.register_vector_store(
VectorStore(
identifier=collection_id,
provider_id="test_provider",
embedding_model="test_model",
@ -177,17 +178,16 @@ async def faiss_vec_index(embedding_dimension):
@pytest.fixture
async def faiss_vec_adapter(unique_kvstore_config, mock_inference_api, embedding_dimension):
config = FaissVectorIOConfig(
kvstore=unique_kvstore_config,
persistence=unique_kvstore_config,
)
adapter = FaissVectorIOAdapter(
config=config,
inference_api=mock_inference_api,
files_api=None,
models_api=None,
)
await adapter.initialize()
await adapter.register_vector_db(
VectorDB(
await adapter.register_vector_store(
VectorStore(
identifier=f"faiss_test_collection_{np.random.randint(1e6)}",
provider_id="test_provider",
embedding_model="test_model",
@ -215,7 +215,7 @@ def mock_psycopg2_connection():
async def pgvector_vec_index(embedding_dimension, mock_psycopg2_connection):
connection, cursor = mock_psycopg2_connection
vector_db = VectorDB(
vector_store = VectorStore(
identifier="test-vector-db",
embedding_model="test-model",
embedding_dimension=embedding_dimension,
@ -225,7 +225,7 @@ async def pgvector_vec_index(embedding_dimension, mock_psycopg2_connection):
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"):
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.execute_values"):
index = PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="COSINE")
index = PGVectorIndex(vector_store, embedding_dimension, connection, distance_metric="COSINE")
index._test_chunks = []
original_add_chunks = index.add_chunks
@ -253,7 +253,7 @@ async def pgvector_vec_adapter(unique_kvstore_config, mock_inference_api, embedd
db="test_db",
user="test_user",
password="test_password",
kvstore=unique_kvstore_config,
persistence=unique_kvstore_config,
)
adapter = PGVectorVectorIOAdapter(config, mock_inference_api, None)
@ -281,30 +281,30 @@ async def pgvector_vec_adapter(unique_kvstore_config, mock_inference_api, embedd
await adapter.initialize()
adapter.conn = mock_conn
async def mock_insert_chunks(vector_db_id, chunks, ttl_seconds=None):
index = await adapter._get_and_cache_vector_db_index(vector_db_id)
async def mock_insert_chunks(vector_store_id, chunks, ttl_seconds=None):
index = await adapter._get_and_cache_vector_store_index(vector_store_id)
if not index:
raise ValueError(f"Vector DB {vector_db_id} not found")
raise ValueError(f"Vector DB {vector_store_id} not found")
await index.insert_chunks(chunks)
adapter.insert_chunks = mock_insert_chunks
async def mock_query_chunks(vector_db_id, query, params=None):
index = await adapter._get_and_cache_vector_db_index(vector_db_id)
async def mock_query_chunks(vector_store_id, query, params=None):
index = await adapter._get_and_cache_vector_store_index(vector_store_id)
if not index:
raise ValueError(f"Vector DB {vector_db_id} not found")
raise ValueError(f"Vector DB {vector_store_id} not found")
return await index.query_chunks(query, params)
adapter.query_chunks = mock_query_chunks
test_vector_db = VectorDB(
test_vector_store = VectorStore(
identifier=f"pgvector_test_collection_{random.randint(1, 1_000_000)}",
provider_id="test_provider",
embedding_model="test_model",
embedding_dimension=embedding_dimension,
)
await adapter.register_vector_db(test_vector_db)
adapter.test_collection_id = test_vector_db.identifier
await adapter.register_vector_store(test_vector_store)
adapter.test_collection_id = test_vector_store.identifier
yield adapter
await adapter.shutdown()

View file

@ -11,9 +11,8 @@ import numpy as np
import pytest
from llama_stack.apis.files import Files
from llama_stack.apis.models import Models
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.providers.datatypes import HealthStatus
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.inline.vector_io.faiss.faiss import (
@ -44,8 +43,8 @@ def embedding_dimension():
@pytest.fixture
def vector_db_id():
return "test_vector_db"
def vector_store_id():
return "test_vector_store"
@pytest.fixture
@ -62,12 +61,12 @@ def sample_embeddings(embedding_dimension):
@pytest.fixture
def mock_vector_db(vector_db_id, embedding_dimension) -> MagicMock:
mock_vector_db = MagicMock(spec=VectorDB)
mock_vector_db.embedding_model = "mock_embedding_model"
mock_vector_db.identifier = vector_db_id
mock_vector_db.embedding_dimension = embedding_dimension
return mock_vector_db
def mock_vector_store(vector_store_id, embedding_dimension) -> MagicMock:
mock_vector_store = MagicMock(spec=VectorStore)
mock_vector_store.embedding_model = "mock_embedding_model"
mock_vector_store.identifier = vector_store_id
mock_vector_store.embedding_dimension = embedding_dimension
return mock_vector_store
@pytest.fixture
@ -76,12 +75,6 @@ def mock_files_api():
return mock_api
@pytest.fixture
def mock_models_api():
mock_api = MagicMock(spec=Models)
return mock_api
@pytest.fixture
def faiss_config():
config = MagicMock(spec=FaissVectorIOConfig)
@ -117,7 +110,7 @@ async def test_faiss_query_vector_returns_infinity_when_query_and_embedding_are_
assert response.chunks[1] == sample_chunks[1]
async def test_health_success(mock_models_api):
async def test_health_success():
"""Test that the health check returns OK status when faiss is working correctly."""
# Create a fresh instance of FaissVectorIOAdapter for testing
config = MagicMock()
@ -126,9 +119,7 @@ async def test_health_success(mock_models_api):
with patch("llama_stack.providers.inline.vector_io.faiss.faiss.faiss.IndexFlatL2") as mock_index_flat:
mock_index_flat.return_value = MagicMock()
adapter = FaissVectorIOAdapter(
config=config, inference_api=inference_api, models_api=mock_models_api, files_api=files_api
)
adapter = FaissVectorIOAdapter(config=config, inference_api=inference_api, files_api=files_api)
# Calling the health method directly
response = await adapter.health()
@ -142,7 +133,7 @@ async def test_health_success(mock_models_api):
mock_index_flat.assert_called_once_with(128) # VECTOR_DIMENSION is 128
async def test_health_failure(mock_models_api):
async def test_health_failure():
"""Test that the health check returns ERROR status when faiss encounters an error."""
# Create a fresh instance of FaissVectorIOAdapter for testing
config = MagicMock()
@ -152,9 +143,7 @@ async def test_health_failure(mock_models_api):
with patch("llama_stack.providers.inline.vector_io.faiss.faiss.faiss.IndexFlatL2") as mock_index_flat:
mock_index_flat.side_effect = Exception("Test error")
adapter = FaissVectorIOAdapter(
config=config, inference_api=inference_api, models_api=mock_models_api, files_api=files_api
)
adapter = FaissVectorIOAdapter(config=config, inference_api=inference_api, files_api=files_api)
# Calling the health method directly
response = await adapter.health()

View file

@ -6,14 +6,12 @@
import json
import time
from unittest.mock import AsyncMock, Mock, patch
from unittest.mock import AsyncMock, patch
import numpy as np
import pytest
from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.models import Model, ModelType
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
Chunk,
OpenAICreateVectorStoreFileBatchRequestWithExtraBody,
@ -22,6 +20,7 @@ from llama_stack.apis.vector_io import (
VectorStoreChunkingStrategyAuto,
VectorStoreFileObject,
)
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import VECTOR_DBS_PREFIX
# This test is a unit test for the inline VectorIO providers. This should only contain
@ -72,7 +71,7 @@ async def test_chunk_id_conflict(vector_index, sample_chunks, embedding_dimensio
async def test_initialize_adapter_with_existing_kvstore(vector_io_adapter):
key = f"{VECTOR_DBS_PREFIX}db1"
dummy = VectorDB(
dummy = VectorStore(
identifier="foo_db", provider_id="test_provider", embedding_model="test_model", embedding_dimension=128
)
await vector_io_adapter.kvstore.set(key=key, value=json.dumps(dummy.model_dump()))
@ -82,10 +81,10 @@ async def test_initialize_adapter_with_existing_kvstore(vector_io_adapter):
async def test_persistence_across_adapter_restarts(vector_io_adapter):
await vector_io_adapter.initialize()
dummy = VectorDB(
dummy = VectorStore(
identifier="foo_db", provider_id="test_provider", embedding_model="test_model", embedding_dimension=128
)
await vector_io_adapter.register_vector_db(dummy)
await vector_io_adapter.register_vector_store(dummy)
await vector_io_adapter.shutdown()
await vector_io_adapter.initialize()
@ -93,15 +92,15 @@ async def test_persistence_across_adapter_restarts(vector_io_adapter):
await vector_io_adapter.shutdown()
async def test_register_and_unregister_vector_db(vector_io_adapter):
async def test_register_and_unregister_vector_store(vector_io_adapter):
unique_id = f"foo_db_{np.random.randint(1e6)}"
dummy = VectorDB(
dummy = VectorStore(
identifier=unique_id, provider_id="test_provider", embedding_model="test_model", embedding_dimension=128
)
await vector_io_adapter.register_vector_db(dummy)
await vector_io_adapter.register_vector_store(dummy)
assert dummy.identifier in vector_io_adapter.cache
await vector_io_adapter.unregister_vector_db(dummy.identifier)
await vector_io_adapter.unregister_vector_store(dummy.identifier)
assert dummy.identifier not in vector_io_adapter.cache
@ -122,7 +121,7 @@ async def test_insert_chunks_calls_underlying_index(vector_io_adapter):
async def test_insert_chunks_missing_db_raises(vector_io_adapter):
vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=None)
vector_io_adapter._get_and_cache_vector_store_index = AsyncMock(return_value=None)
with pytest.raises(ValueError):
await vector_io_adapter.insert_chunks("db_not_exist", [])
@ -171,7 +170,7 @@ async def test_query_chunks_calls_underlying_index_and_returns(vector_io_adapter
async def test_query_chunks_missing_db_raises(vector_io_adapter):
vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=None)
vector_io_adapter._get_and_cache_vector_store_index = AsyncMock(return_value=None)
with pytest.raises(ValueError):
await vector_io_adapter.query_chunks("db_missing", "q", None)
@ -183,7 +182,7 @@ async def test_save_openai_vector_store(vector_io_adapter):
"id": store_id,
"name": "Test Store",
"description": "A test OpenAI vector store",
"vector_db_id": "test_db",
"vector_store_id": "test_db",
"embedding_model": "test_model",
}
@ -199,7 +198,7 @@ async def test_update_openai_vector_store(vector_io_adapter):
"id": store_id,
"name": "Test Store",
"description": "A test OpenAI vector store",
"vector_db_id": "test_db",
"vector_store_id": "test_db",
"embedding_model": "test_model",
}
@ -215,7 +214,7 @@ async def test_delete_openai_vector_store(vector_io_adapter):
"id": store_id,
"name": "Test Store",
"description": "A test OpenAI vector store",
"vector_db_id": "test_db",
"vector_store_id": "test_db",
"embedding_model": "test_model",
}
@ -230,7 +229,7 @@ async def test_load_openai_vector_stores(vector_io_adapter):
"id": store_id,
"name": "Test Store",
"description": "A test OpenAI vector store",
"vector_db_id": "test_db",
"vector_store_id": "test_db",
"embedding_model": "test_model",
}
@ -996,101 +995,11 @@ async def test_max_concurrent_files_per_batch(vector_io_adapter):
assert batch.file_counts.in_progress == 8
async def test_get_default_embedding_model_success(vector_io_adapter):
"""Test successful default embedding model detection."""
# Mock models API with a default model
mock_models_api = Mock()
mock_models_api.list_models = AsyncMock(
return_value=Mock(
data=[
Model(
identifier="nomic-embed-text-v1.5",
model_type=ModelType.embedding,
provider_id="test-provider",
metadata={
"embedding_dimension": 768,
"default_configured": True,
},
)
]
)
)
vector_io_adapter.models_api = mock_models_api
result = await vector_io_adapter._get_default_embedding_model_and_dimension()
assert result is not None
model_id, dimension = result
assert model_id == "nomic-embed-text-v1.5"
assert dimension == 768
async def test_get_default_embedding_model_multiple_defaults_error(vector_io_adapter):
"""Test error when multiple models are marked as default."""
mock_models_api = Mock()
mock_models_api.list_models = AsyncMock(
return_value=Mock(
data=[
Model(
identifier="model1",
model_type=ModelType.embedding,
provider_id="test-provider",
metadata={"embedding_dimension": 768, "default_configured": True},
),
Model(
identifier="model2",
model_type=ModelType.embedding,
provider_id="test-provider",
metadata={"embedding_dimension": 512, "default_configured": True},
),
]
)
)
vector_io_adapter.models_api = mock_models_api
with pytest.raises(ValueError, match="Multiple embedding models marked as default_configured=True"):
await vector_io_adapter._get_default_embedding_model_and_dimension()
async def test_openai_create_vector_store_uses_default_model(vector_io_adapter):
"""Test that vector store creation uses default embedding model when none specified."""
# Mock models API and dependencies
mock_models_api = Mock()
mock_models_api.list_models = AsyncMock(
return_value=Mock(
data=[
Model(
identifier="default-model",
model_type=ModelType.embedding,
provider_id="test-provider",
metadata={"embedding_dimension": 512, "default_configured": True},
)
]
)
)
vector_io_adapter.models_api = mock_models_api
vector_io_adapter.register_vector_db = AsyncMock()
vector_io_adapter.__provider_id__ = "test-provider"
# Create vector store without specifying embedding model
params = OpenAICreateVectorStoreRequestWithExtraBody(name="test-store")
result = await vector_io_adapter.openai_create_vector_store(params)
# Verify the vector store was created with default model
assert result.name == "test-store"
vector_io_adapter.register_vector_db.assert_called_once()
call_args = vector_io_adapter.register_vector_db.call_args[0][0]
assert call_args.embedding_model == "default-model"
assert call_args.embedding_dimension == 512
async def test_embedding_config_from_metadata(vector_io_adapter):
"""Test that embedding configuration is correctly extracted from metadata."""
# Mock register_vector_db to avoid actual registration
vector_io_adapter.register_vector_db = AsyncMock()
# Mock register_vector_store to avoid actual registration
vector_io_adapter.register_vector_store = AsyncMock()
# Set provider_id attribute for the adapter
vector_io_adapter.__provider_id__ = "test_provider"
@ -1106,9 +1015,9 @@ async def test_embedding_config_from_metadata(vector_io_adapter):
await vector_io_adapter.openai_create_vector_store(params)
# Verify VectorDB was registered with correct embedding config from metadata
vector_io_adapter.register_vector_db.assert_called_once()
call_args = vector_io_adapter.register_vector_db.call_args[0][0]
# Verify VectorStore was registered with correct embedding config from metadata
vector_io_adapter.register_vector_store.assert_called_once()
call_args = vector_io_adapter.register_vector_store.call_args[0][0]
assert call_args.embedding_model == "test-embedding-model"
assert call_args.embedding_dimension == 512
@ -1116,8 +1025,8 @@ async def test_embedding_config_from_metadata(vector_io_adapter):
async def test_embedding_config_from_extra_body(vector_io_adapter):
"""Test that embedding configuration is correctly extracted from extra_body when metadata is empty."""
# Mock register_vector_db to avoid actual registration
vector_io_adapter.register_vector_db = AsyncMock()
# Mock register_vector_store to avoid actual registration
vector_io_adapter.register_vector_store = AsyncMock()
# Set provider_id attribute for the adapter
vector_io_adapter.__provider_id__ = "test_provider"
@ -1133,9 +1042,9 @@ async def test_embedding_config_from_extra_body(vector_io_adapter):
await vector_io_adapter.openai_create_vector_store(params)
# Verify VectorDB was registered with correct embedding config from extra_body
vector_io_adapter.register_vector_db.assert_called_once()
call_args = vector_io_adapter.register_vector_db.call_args[0][0]
# Verify VectorStore was registered with correct embedding config from extra_body
vector_io_adapter.register_vector_store.assert_called_once()
call_args = vector_io_adapter.register_vector_store.call_args[0][0]
assert call_args.embedding_model == "extra-body-model"
assert call_args.embedding_dimension == 1024
@ -1143,8 +1052,8 @@ async def test_embedding_config_from_extra_body(vector_io_adapter):
async def test_embedding_config_consistency_check_passes(vector_io_adapter):
"""Test that consistent embedding config in both metadata and extra_body passes validation."""
# Mock register_vector_db to avoid actual registration
vector_io_adapter.register_vector_db = AsyncMock()
# Mock register_vector_store to avoid actual registration
vector_io_adapter.register_vector_store = AsyncMock()
# Set provider_id attribute for the adapter
vector_io_adapter.__provider_id__ = "test_provider"
@ -1164,61 +1073,17 @@ async def test_embedding_config_consistency_check_passes(vector_io_adapter):
await vector_io_adapter.openai_create_vector_store(params)
# Should not raise any error and use metadata config
vector_io_adapter.register_vector_db.assert_called_once()
call_args = vector_io_adapter.register_vector_db.call_args[0][0]
vector_io_adapter.register_vector_store.assert_called_once()
call_args = vector_io_adapter.register_vector_store.call_args[0][0]
assert call_args.embedding_model == "consistent-model"
assert call_args.embedding_dimension == 768
async def test_embedding_config_inconsistency_errors(vector_io_adapter):
"""Test that inconsistent embedding config between metadata and extra_body raises errors."""
# Mock register_vector_db to avoid actual registration
vector_io_adapter.register_vector_db = AsyncMock()
# Set provider_id attribute for the adapter
vector_io_adapter.__provider_id__ = "test_provider"
# Test with inconsistent embedding model
params = OpenAICreateVectorStoreRequestWithExtraBody(
name="test_store",
metadata={
"embedding_model": "metadata-model",
"embedding_dimension": "768",
},
**{
"embedding_model": "extra-body-model",
"embedding_dimension": 768,
},
)
with pytest.raises(ValueError, match="Embedding model inconsistent between metadata"):
await vector_io_adapter.openai_create_vector_store(params)
# Reset mock for second test
vector_io_adapter.register_vector_db.reset_mock()
# Test with inconsistent embedding dimension
params = OpenAICreateVectorStoreRequestWithExtraBody(
name="test_store",
metadata={
"embedding_model": "same-model",
"embedding_dimension": "512",
},
**{
"embedding_model": "same-model",
"embedding_dimension": 1024,
},
)
with pytest.raises(ValueError, match="Embedding dimension inconsistent between metadata"):
await vector_io_adapter.openai_create_vector_store(params)
async def test_embedding_config_defaults_when_missing(vector_io_adapter):
"""Test that embedding dimension defaults to 768 when not provided."""
# Mock register_vector_db to avoid actual registration
vector_io_adapter.register_vector_db = AsyncMock()
# Mock register_vector_store to avoid actual registration
vector_io_adapter.register_vector_store = AsyncMock()
# Set provider_id attribute for the adapter
vector_io_adapter.__provider_id__ = "test_provider"
@ -1234,8 +1099,8 @@ async def test_embedding_config_defaults_when_missing(vector_io_adapter):
await vector_io_adapter.openai_create_vector_store(params)
# Should default to 768 dimensions
vector_io_adapter.register_vector_db.assert_called_once()
call_args = vector_io_adapter.register_vector_db.call_args[0][0]
vector_io_adapter.register_vector_store.assert_called_once()
call_args = vector_io_adapter.register_vector_store.call_args[0][0]
assert call_args.embedding_model == "model-without-dimension"
assert call_args.embedding_dimension == 768
@ -1243,8 +1108,8 @@ async def test_embedding_config_defaults_when_missing(vector_io_adapter):
async def test_embedding_config_required_model_missing(vector_io_adapter):
"""Test that missing embedding model raises error."""
# Mock register_vector_db to avoid actual registration
vector_io_adapter.register_vector_db = AsyncMock()
# Mock register_vector_store to avoid actual registration
vector_io_adapter.register_vector_store = AsyncMock()
# Set provider_id attribute for the adapter
vector_io_adapter.__provider_id__ = "test_provider"
# Mock the default model lookup to return None (no default model available)
@ -1253,5 +1118,5 @@ async def test_embedding_config_required_model_missing(vector_io_adapter):
# Test with no embedding model provided
params = OpenAICreateVectorStoreRequestWithExtraBody(name="test_store", metadata={})
with pytest.raises(ValueError, match="embedding_model is required in extra_body when creating a vector store"):
with pytest.raises(ValueError, match="embedding_model is required"):
await vector_io_adapter.openai_create_vector_store(params)
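A minimal sketch of the extra_body pattern exercised above, assuming the request class and the vector_io_adapter fixture already imported in this test module; the model id is illustrative:

async def test_sketch_embedding_via_extra_body(vector_io_adapter):
    # Embedding settings ride alongside the standard OpenAI fields via
    # extra_body; field names mirror the assertions in the tests above.
    params = OpenAICreateVectorStoreRequestWithExtraBody(
        name="my_store",
        **{
            "embedding_model": "nomic-embed-text-v1.5",
            "embedding_dimension": 768,  # defaults to 768 when omitted
        },
    )
    await vector_io_adapter.openai_create_vector_store(params)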


@ -18,7 +18,7 @@ from llama_stack.providers.inline.tool_runtime.rag.memory import MemoryToolRunti
class TestRagQuery:
async def test_query_raises_on_empty_vector_db_ids(self):
async def test_query_raises_on_empty_vector_store_ids(self):
rag_tool = MemoryToolRuntimeImpl(
config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock(), files_api=MagicMock()
)
@ -82,7 +82,7 @@ class TestRagQuery:
with pytest.raises(ValueError):
RAGQueryConfig(mode="wrong_mode")
async def test_query_adds_vector_db_id_to_chunk_metadata(self):
async def test_query_adds_vector_store_id_to_chunk_metadata(self):
rag_tool = MemoryToolRuntimeImpl(
config=MagicMock(),
vector_io_api=MagicMock(),


@ -21,7 +21,7 @@ from llama_stack.apis.tools import RAGDocument
from llama_stack.apis.vector_io import Chunk
from llama_stack.providers.utils.memory.vector_store import (
URL,
VectorDBWithIndex,
VectorStoreWithIndex,
_validate_embedding,
content_from_doc,
make_overlapped_chunks,
@ -206,15 +206,15 @@ class TestVectorStore:
assert str(excinfo.value.__cause__) == "Cannot convert to string"
class TestVectorDBWithIndex:
class TestVectorStoreWithIndex:
async def test_insert_chunks_without_embeddings(self):
mock_vector_db = MagicMock()
mock_vector_db.embedding_model = "test-model without embeddings"
mock_vector_store = MagicMock()
mock_vector_store.embedding_model = "test-model without embeddings"
mock_index = AsyncMock()
mock_inference_api = AsyncMock()
vector_db_with_index = VectorDBWithIndex(
vector_db=mock_vector_db, index=mock_index, inference_api=mock_inference_api
vector_store_with_index = VectorStoreWithIndex(
vector_store=mock_vector_store, index=mock_index, inference_api=mock_inference_api
)
chunks = [
@ -227,7 +227,7 @@ class TestVectorDBWithIndex:
OpenAIEmbeddingData(embedding=[0.4, 0.5, 0.6], index=1),
]
await vector_db_with_index.insert_chunks(chunks)
await vector_store_with_index.insert_chunks(chunks)
# Verify openai_embeddings was called with correct params
mock_inference_api.openai_embeddings.assert_called_once()
@ -243,14 +243,14 @@ class TestVectorDBWithIndex:
assert np.array_equal(args[1], np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32))
async def test_insert_chunks_with_valid_embeddings(self):
mock_vector_db = MagicMock()
mock_vector_db.embedding_model = "test-model with embeddings"
mock_vector_db.embedding_dimension = 3
mock_vector_store = MagicMock()
mock_vector_store.embedding_model = "test-model with embeddings"
mock_vector_store.embedding_dimension = 3
mock_index = AsyncMock()
mock_inference_api = AsyncMock()
vector_db_with_index = VectorDBWithIndex(
vector_db=mock_vector_db, index=mock_index, inference_api=mock_inference_api
vector_store_with_index = VectorStoreWithIndex(
vector_store=mock_vector_store, index=mock_index, inference_api=mock_inference_api
)
chunks = [
@ -258,7 +258,7 @@ class TestVectorDBWithIndex:
Chunk(content="Test 2", embedding=[0.4, 0.5, 0.6], metadata={}),
]
await vector_db_with_index.insert_chunks(chunks)
await vector_store_with_index.insert_chunks(chunks)
mock_inference_api.openai_embeddings.assert_not_called()
mock_index.add_chunks.assert_called_once()
@ -267,14 +267,14 @@ class TestVectorDBWithIndex:
assert np.array_equal(args[1], np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32))
async def test_insert_chunks_with_invalid_embeddings(self):
mock_vector_db = MagicMock()
mock_vector_db.embedding_dimension = 3
mock_vector_db.embedding_model = "test-model with invalid embeddings"
mock_vector_store = MagicMock()
mock_vector_store.embedding_dimension = 3
mock_vector_store.embedding_model = "test-model with invalid embeddings"
mock_index = AsyncMock()
mock_inference_api = AsyncMock()
vector_db_with_index = VectorDBWithIndex(
vector_db=mock_vector_db, index=mock_index, inference_api=mock_inference_api
vector_store_with_index = VectorStoreWithIndex(
vector_store=mock_vector_store, index=mock_index, inference_api=mock_inference_api
)
# Verify Chunk raises ValueError for invalid embedding type
@ -283,7 +283,7 @@ class TestVectorDBWithIndex:
# Verify Chunk raises ValueError for invalid embedding type in insert_chunks (i.e., Chunk errors before insert_chunks is called)
with pytest.raises(ValueError, match="Input should be a valid list"):
await vector_db_with_index.insert_chunks(
await vector_store_with_index.insert_chunks(
[
Chunk(content="Test 1", embedding=None, metadata={}),
Chunk(content="Test 2", embedding="invalid_type", metadata={}),
@ -292,7 +292,7 @@ class TestVectorDBWithIndex:
# Verify Chunk raises ValueError for invalid embedding element type in insert_chunks (i.e., Chunk errors before insert_chunks is called)
with pytest.raises(ValueError, match=" Input should be a valid number, unable to parse string as a number "):
await vector_db_with_index.insert_chunks(
await vector_store_with_index.insert_chunks(
Chunk(content="Test 1", embedding=[0.1, "string", 0.3], metadata={})
)
@ -300,20 +300,20 @@ class TestVectorDBWithIndex:
Chunk(content="Test 1", embedding=[0.1, 0.2, 0.3, 0.4], metadata={}),
]
with pytest.raises(ValueError, match="has dimension 4, expected 3"):
await vector_db_with_index.insert_chunks(chunks_wrong_dim)
await vector_store_with_index.insert_chunks(chunks_wrong_dim)
mock_inference_api.openai_embeddings.assert_not_called()
mock_index.add_chunks.assert_not_called()
async def test_insert_chunks_with_partially_precomputed_embeddings(self):
mock_vector_db = MagicMock()
mock_vector_db.embedding_model = "test-model with partial embeddings"
mock_vector_db.embedding_dimension = 3
mock_vector_store = MagicMock()
mock_vector_store.embedding_model = "test-model with partial embeddings"
mock_vector_store.embedding_dimension = 3
mock_index = AsyncMock()
mock_inference_api = AsyncMock()
vector_db_with_index = VectorDBWithIndex(
vector_db=mock_vector_db, index=mock_index, inference_api=mock_inference_api
vector_store_with_index = VectorStoreWithIndex(
vector_store=mock_vector_store, index=mock_index, inference_api=mock_inference_api
)
chunks = [
@ -327,7 +327,7 @@ class TestVectorDBWithIndex:
OpenAIEmbeddingData(embedding=[0.3, 0.3, 0.3], index=1),
]
await vector_db_with_index.insert_chunks(chunks)
await vector_store_with_index.insert_chunks(chunks)
# Verify openai_embeddings was called with correct params
mock_inference_api.openai_embeddings.assert_called_once()
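A minimal sketch of the renamed wrapper these tests construct, using the import shown at the top of this file; the mock attribute names follow the tests:

from unittest.mock import AsyncMock, MagicMock

from llama_stack.providers.utils.memory.vector_store import VectorStoreWithIndex

# VectorDBWithIndex(vector_db=...) is now VectorStoreWithIndex(vector_store=...).
mock_store = MagicMock()
mock_store.embedding_model = "test-model"
mock_store.embedding_dimension = 3
wrapper = VectorStoreWithIndex(
    vector_store=mock_store,
    index=AsyncMock(),
    inference_api=AsyncMock(),
)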


@ -8,24 +8,24 @@
import pytest
from llama_stack.apis.inference import Model
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.core.datatypes import VectorDBWithOwner
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.core.datatypes import VectorStoreWithOwner
from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig
from llama_stack.core.store.registry import (
KEY_FORMAT,
CachedDiskDistributionRegistry,
DiskDistributionRegistry,
)
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore import kvstore_impl, register_kvstore_backends
@pytest.fixture
def sample_vector_db():
return VectorDB(
identifier="test_vector_db",
def sample_vector_store():
return VectorStore(
identifier="test_vector_store",
embedding_model="nomic-embed-text-v1.5",
embedding_dimension=768,
provider_resource_id="test_vector_db",
provider_resource_id="test_vector_store",
provider_id="test-provider",
)
@ -45,17 +45,17 @@ async def test_registry_initialization(disk_dist_registry):
assert result is None
async def test_basic_registration(disk_dist_registry, sample_vector_db, sample_model):
print(f"Registering {sample_vector_db}")
await disk_dist_registry.register(sample_vector_db)
async def test_basic_registration(disk_dist_registry, sample_vector_store, sample_model):
print(f"Registering {sample_vector_store}")
await disk_dist_registry.register(sample_vector_store)
print(f"Registering {sample_model}")
await disk_dist_registry.register(sample_model)
print("Getting vector_db")
result_vector_db = await disk_dist_registry.get("vector_db", "test_vector_db")
assert result_vector_db is not None
assert result_vector_db.identifier == sample_vector_db.identifier
assert result_vector_db.embedding_model == sample_vector_db.embedding_model
assert result_vector_db.provider_id == sample_vector_db.provider_id
print("Getting vector_store")
result_vector_store = await disk_dist_registry.get("vector_store", "test_vector_store")
assert result_vector_store is not None
assert result_vector_store.identifier == sample_vector_store.identifier
assert result_vector_store.embedding_model == sample_vector_store.embedding_model
assert result_vector_store.provider_id == sample_vector_store.provider_id
result_model = await disk_dist_registry.get("model", "test_model")
assert result_model is not None
@ -63,127 +63,137 @@ async def test_basic_registration(disk_dist_registry, sample_vector_db, sample_m
assert result_model.provider_id == sample_model.provider_id
async def test_cached_registry_initialization(sqlite_kvstore, sample_vector_db, sample_model):
async def test_cached_registry_initialization(sqlite_kvstore, sample_vector_store, sample_model):
# First populate the disk registry
disk_registry = DiskDistributionRegistry(sqlite_kvstore)
await disk_registry.initialize()
await disk_registry.register(sample_vector_db)
await disk_registry.register(sample_vector_store)
await disk_registry.register(sample_model)
# Test cached version loads from disk
db_path = sqlite_kvstore.db_path
cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(SqliteKVStoreConfig(db_path=db_path)))
backend_name = "kv_cached_test"
register_kvstore_backends({backend_name: SqliteKVStoreConfig(db_path=db_path)})
cached_registry = CachedDiskDistributionRegistry(
await kvstore_impl(KVStoreReference(backend=backend_name, namespace="registry"))
)
await cached_registry.initialize()
result_vector_db = await cached_registry.get("vector_db", "test_vector_db")
assert result_vector_db is not None
assert result_vector_db.identifier == sample_vector_db.identifier
assert result_vector_db.embedding_model == sample_vector_db.embedding_model
assert result_vector_db.embedding_dimension == sample_vector_db.embedding_dimension
assert result_vector_db.provider_id == sample_vector_db.provider_id
result_vector_store = await cached_registry.get("vector_store", "test_vector_store")
assert result_vector_store is not None
assert result_vector_store.identifier == sample_vector_store.identifier
assert result_vector_store.embedding_model == sample_vector_store.embedding_model
assert result_vector_store.embedding_dimension == sample_vector_store.embedding_dimension
assert result_vector_store.provider_id == sample_vector_store.provider_id
async def test_cached_registry_updates(cached_disk_dist_registry):
new_vector_db = VectorDB(
identifier="test_vector_db_2",
new_vector_store = VectorStore(
identifier="test_vector_store_2",
embedding_model="nomic-embed-text-v1.5",
embedding_dimension=768,
provider_resource_id="test_vector_db_2",
provider_resource_id="test_vector_store_2",
provider_id="baz",
)
await cached_disk_dist_registry.register(new_vector_db)
await cached_disk_dist_registry.register(new_vector_store)
# Verify in cache
result_vector_db = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2")
assert result_vector_db is not None
assert result_vector_db.identifier == new_vector_db.identifier
assert result_vector_db.provider_id == new_vector_db.provider_id
result_vector_store = await cached_disk_dist_registry.get("vector_store", "test_vector_store_2")
assert result_vector_store is not None
assert result_vector_store.identifier == new_vector_store.identifier
assert result_vector_store.provider_id == new_vector_store.provider_id
# Verify persisted to disk
db_path = cached_disk_dist_registry.kvstore.db_path
new_registry = DiskDistributionRegistry(await kvstore_impl(SqliteKVStoreConfig(db_path=db_path)))
backend_name = "kv_cached_new"
register_kvstore_backends({backend_name: SqliteKVStoreConfig(db_path=db_path)})
new_registry = DiskDistributionRegistry(
await kvstore_impl(KVStoreReference(backend=backend_name, namespace="registry"))
)
await new_registry.initialize()
result_vector_db = await new_registry.get("vector_db", "test_vector_db_2")
assert result_vector_db is not None
assert result_vector_db.identifier == new_vector_db.identifier
assert result_vector_db.provider_id == new_vector_db.provider_id
result_vector_store = await new_registry.get("vector_store", "test_vector_store_2")
assert result_vector_store is not None
assert result_vector_store.identifier == new_vector_store.identifier
assert result_vector_store.provider_id == new_vector_store.provider_id
async def test_duplicate_provider_registration(cached_disk_dist_registry):
original_vector_db = VectorDB(
identifier="test_vector_db_2",
original_vector_store = VectorStore(
identifier="test_vector_store_2",
embedding_model="nomic-embed-text-v1.5",
embedding_dimension=768,
provider_resource_id="test_vector_db_2",
provider_resource_id="test_vector_store_2",
provider_id="baz",
)
assert await cached_disk_dist_registry.register(original_vector_db)
assert await cached_disk_dist_registry.register(original_vector_store)
duplicate_vector_db = VectorDB(
identifier="test_vector_db_2",
duplicate_vector_store = VectorStore(
identifier="test_vector_store_2",
embedding_model="different-model",
embedding_dimension=768,
provider_resource_id="test_vector_db_2",
provider_resource_id="test_vector_store_2",
provider_id="baz", # Same provider_id
)
with pytest.raises(ValueError, match="Object of type 'vector_db' and identifier 'test_vector_db_2' already exists"):
await cached_disk_dist_registry.register(duplicate_vector_db)
with pytest.raises(
ValueError, match="Object of type 'vector_store' and identifier 'test_vector_store_2' already exists"
):
await cached_disk_dist_registry.register(duplicate_vector_store)
result = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2")
result = await cached_disk_dist_registry.get("vector_store", "test_vector_store_2")
assert result is not None
assert result.embedding_model == original_vector_db.embedding_model # Original values preserved
assert result.embedding_model == original_vector_store.embedding_model # Original values preserved
async def test_get_all_objects(cached_disk_dist_registry):
# Create multiple test banks
test_vector_dbs = [
VectorDB(
identifier=f"test_vector_db_{i}",
test_vector_stores = [
VectorStore(
identifier=f"test_vector_store_{i}",
embedding_model="nomic-embed-text-v1.5",
embedding_dimension=768,
provider_resource_id=f"test_vector_db_{i}",
provider_resource_id=f"test_vector_store_{i}",
provider_id=f"provider_{i}",
)
for i in range(3)
]
# Register all vector_dbs
for vector_db in test_vector_dbs:
await cached_disk_dist_registry.register(vector_db)
# Register all vector_stores
for vector_store in test_vector_stores:
await cached_disk_dist_registry.register(vector_store)
# Test get_all retrieval
all_results = await cached_disk_dist_registry.get_all()
assert len(all_results) == 3
# Verify each vector_db was stored correctly
for original_vector_db in test_vector_dbs:
matching_vector_dbs = [v for v in all_results if v.identifier == original_vector_db.identifier]
assert len(matching_vector_dbs) == 1
stored_vector_db = matching_vector_dbs[0]
assert stored_vector_db.embedding_model == original_vector_db.embedding_model
assert stored_vector_db.provider_id == original_vector_db.provider_id
assert stored_vector_db.embedding_dimension == original_vector_db.embedding_dimension
# Verify each vector_store was stored correctly
for original_vector_store in test_vector_stores:
matching_vector_stores = [v for v in all_results if v.identifier == original_vector_store.identifier]
assert len(matching_vector_stores) == 1
stored_vector_store = matching_vector_stores[0]
assert stored_vector_store.embedding_model == original_vector_store.embedding_model
assert stored_vector_store.provider_id == original_vector_store.provider_id
assert stored_vector_store.embedding_dimension == original_vector_store.embedding_dimension
async def test_parse_registry_values_error_handling(sqlite_kvstore):
valid_db = VectorDB(
identifier="valid_vector_db",
valid_db = VectorStore(
identifier="valid_vector_store",
embedding_model="nomic-embed-text-v1.5",
embedding_dimension=768,
provider_resource_id="valid_vector_db",
provider_resource_id="valid_vector_store",
provider_id="test-provider",
)
await sqlite_kvstore.set(
KEY_FORMAT.format(type="vector_db", identifier="valid_vector_db"), valid_db.model_dump_json()
KEY_FORMAT.format(type="vector_store", identifier="valid_vector_store"), valid_db.model_dump_json()
)
await sqlite_kvstore.set(KEY_FORMAT.format(type="vector_db", identifier="corrupted_json"), "{not valid json")
await sqlite_kvstore.set(KEY_FORMAT.format(type="vector_store", identifier="corrupted_json"), "{not valid json")
await sqlite_kvstore.set(
KEY_FORMAT.format(type="vector_db", identifier="missing_fields"),
'{"type": "vector_db", "identifier": "missing_fields"}',
KEY_FORMAT.format(type="vector_store", identifier="missing_fields"),
'{"type": "vector_store", "identifier": "missing_fields"}',
)
test_registry = DiskDistributionRegistry(sqlite_kvstore)
@ -194,18 +204,18 @@ async def test_parse_registry_values_error_handling(sqlite_kvstore):
# Should have filtered out the invalid entries
assert len(all_objects) == 1
assert all_objects[0].identifier == "valid_vector_db"
assert all_objects[0].identifier == "valid_vector_store"
# Check that the get method also handles errors correctly
invalid_obj = await test_registry.get("vector_db", "corrupted_json")
invalid_obj = await test_registry.get("vector_store", "corrupted_json")
assert invalid_obj is None
invalid_obj = await test_registry.get("vector_db", "missing_fields")
invalid_obj = await test_registry.get("vector_store", "missing_fields")
assert invalid_obj is None
async def test_cached_registry_error_handling(sqlite_kvstore):
valid_db = VectorDB(
valid_db = VectorStore(
identifier="valid_cached_db",
embedding_model="nomic-embed-text-v1.5",
embedding_dimension=768,
@ -214,12 +224,12 @@ async def test_cached_registry_error_handling(sqlite_kvstore):
)
await sqlite_kvstore.set(
KEY_FORMAT.format(type="vector_db", identifier="valid_cached_db"), valid_db.model_dump_json()
KEY_FORMAT.format(type="vector_store", identifier="valid_cached_db"), valid_db.model_dump_json()
)
await sqlite_kvstore.set(
KEY_FORMAT.format(type="vector_db", identifier="invalid_cached_db"),
'{"type": "vector_db", "identifier": "invalid_cached_db", "embedding_model": 12345}', # Should be string
KEY_FORMAT.format(type="vector_store", identifier="invalid_cached_db"),
'{"type": "vector_store", "identifier": "invalid_cached_db", "embedding_model": 12345}', # Should be string
)
cached_registry = CachedDiskDistributionRegistry(sqlite_kvstore)
@ -229,63 +239,65 @@ async def test_cached_registry_error_handling(sqlite_kvstore):
assert len(all_objects) == 1
assert all_objects[0].identifier == "valid_cached_db"
invalid_obj = await cached_registry.get("vector_db", "invalid_cached_db")
invalid_obj = await cached_registry.get("vector_store", "invalid_cached_db")
assert invalid_obj is None
async def test_double_registration_identical_objects(disk_dist_registry):
"""Test that registering identical objects succeeds (idempotent)."""
vector_db = VectorDBWithOwner(
identifier="test_vector_db",
vector_store = VectorStoreWithOwner(
identifier="test_vector_store",
embedding_model="all-MiniLM-L6-v2",
embedding_dimension=384,
provider_resource_id="test_vector_db",
provider_resource_id="test_vector_store",
provider_id="test-provider",
)
# First registration should succeed
result1 = await disk_dist_registry.register(vector_db)
result1 = await disk_dist_registry.register(vector_store)
assert result1 is True
# Second registration of identical object should also succeed (idempotent)
result2 = await disk_dist_registry.register(vector_db)
result2 = await disk_dist_registry.register(vector_store)
assert result2 is True
# Verify object exists and is unchanged
retrieved = await disk_dist_registry.get("vector_db", "test_vector_db")
retrieved = await disk_dist_registry.get("vector_store", "test_vector_store")
assert retrieved is not None
assert retrieved.identifier == vector_db.identifier
assert retrieved.embedding_model == vector_db.embedding_model
assert retrieved.identifier == vector_store.identifier
assert retrieved.embedding_model == vector_store.embedding_model
async def test_double_registration_different_objects(disk_dist_registry):
"""Test that registering different objects with same identifier fails."""
vector_db1 = VectorDBWithOwner(
identifier="test_vector_db",
vector_store1 = VectorStoreWithOwner(
identifier="test_vector_store",
embedding_model="all-MiniLM-L6-v2",
embedding_dimension=384,
provider_resource_id="test_vector_db",
provider_resource_id="test_vector_store",
provider_id="test-provider",
)
vector_db2 = VectorDBWithOwner(
identifier="test_vector_db", # Same identifier
vector_store2 = VectorStoreWithOwner(
identifier="test_vector_store", # Same identifier
embedding_model="different-model", # Different embedding model
embedding_dimension=384,
provider_resource_id="test_vector_db",
provider_resource_id="test_vector_store",
provider_id="test-provider",
)
# First registration should succeed
result1 = await disk_dist_registry.register(vector_db1)
result1 = await disk_dist_registry.register(vector_store1)
assert result1 is True
# Second registration with different data should fail
with pytest.raises(ValueError, match="Object of type 'vector_db' and identifier 'test_vector_db' already exists"):
await disk_dist_registry.register(vector_db2)
with pytest.raises(
ValueError, match="Object of type 'vector_store' and identifier 'test_vector_store' already exists"
):
await disk_dist_registry.register(vector_store2)
# Verify original object is unchanged
retrieved = await disk_dist_registry.get("vector_db", "test_vector_db")
retrieved = await disk_dist_registry.get("vector_store", "test_vector_store")
assert retrieved is not None
assert retrieved.embedding_model == "all-MiniLM-L6-v2" # Original value
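A minimal sketch of the backend-reference pattern these registry tests now follow, assuming a SQLite-backed KV store registered under a chosen name before the reference is resolved:

from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore import kvstore_impl, register_kvstore_backends


async def open_registry_kvstore(db_path: str):
    # Backends are registered once by name; stores are then opened through
    # lightweight references instead of inline configs.
    register_kvstore_backends({"kv_default": SqliteKVStoreConfig(db_path=db_path)})
    return await kvstore_impl(KVStoreReference(backend="kv_default", namespace="registry"))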


@ -516,6 +516,82 @@ def test_get_attributes_from_claims():
assert set(attributes["teams"]) == {"my-team", "group1", "group2"}
assert attributes["namespaces"] == ["my-tenant"]
# Test nested claims with dot notation (e.g., Keycloak resource_access structure)
claims = {
"sub": "user123",
"resource_access": {"llamastack": {"roles": ["inference_max", "admin"]}, "other-client": {"roles": ["viewer"]}},
"realm_access": {"roles": ["offline_access", "uma_authorization"]},
}
attributes = get_attributes_from_claims(
claims, {"resource_access.llamastack.roles": "roles", "realm_access.roles": "realm_roles"}
)
assert set(attributes["roles"]) == {"inference_max", "admin"}
assert set(attributes["realm_roles"]) == {"offline_access", "uma_authorization"}
# Test that dot notation takes precedence over literal keys with dots
claims = {
"my.dotted.key": "literal-value",
"my": {"dotted": {"key": "nested-value"}},
}
attributes = get_attributes_from_claims(claims, {"my.dotted.key": "test"})
assert attributes["test"] == ["nested-value"]
# Test that literal key works when nested traversal doesn't exist
claims = {
"my.dotted.key": "literal-value",
}
attributes = get_attributes_from_claims(claims, {"my.dotted.key": "test"})
assert attributes["test"] == ["literal-value"]
# Test missing nested paths are handled gracefully
claims = {
"sub": "user123",
"resource_access": {"other-client": {"roles": ["viewer"]}},
}
attributes = get_attributes_from_claims(
claims,
{
"resource_access.llamastack.roles": "roles", # Missing nested path
"resource_access.missing.key": "missing_attr", # Missing nested path
"completely.missing.path": "another_missing", # Completely missing
"sub": "username", # Existing path
},
)
# Only the existing claim should be in attributes
assert attributes["username"] == ["user123"]
assert "roles" not in attributes
assert "missing_attr" not in attributes
assert "another_missing" not in attributes
# Test mixture of flat and nested claims paths
claims = {
"sub": "user456",
"flat_key": "flat-value",
"scope": "read write admin",
"resource_access": {"app1": {"roles": ["role1", "role2"]}, "app2": {"roles": ["role3"]}},
"groups": ["group1", "group2"],
"metadata": {"tenant": "tenant1", "region": "us-west"},
}
attributes = get_attributes_from_claims(
claims,
{
"sub": "user_id", # Flat string
"scope": "permissions", # Flat string with spaces
"groups": "teams", # Flat list
"resource_access.app1.roles": "app1_roles", # Nested list
"resource_access.app2.roles": "app2_roles", # Nested list
"metadata.tenant": "tenant", # Nested string
"metadata.region": "region", # Nested string
},
)
assert attributes["user_id"] == ["user456"]
assert set(attributes["permissions"]) == {"read", "write", "admin"}
assert set(attributes["teams"]) == {"group1", "group2"}
assert set(attributes["app1_roles"]) == {"role1", "role2"}
assert attributes["app2_roles"] == ["role3"]
assert attributes["tenant"] == ["tenant1"]
assert attributes["region"] == ["us-west"]
# TODO: add more tests for oauth2 token provider
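A compact sketch of the mapping semantics these assertions pin down, assuming only the behavior shown above (dot-notation traversal of nested claims, nested values preferred over literal dotted keys, every extracted value normalized to a list):

claims = {
    "sub": "user123",
    "resource_access": {"llamastack": {"roles": ["inference_max", "admin"]}},
}
attributes = get_attributes_from_claims(
    claims,
    {"resource_access.llamastack.roles": "roles", "sub": "username"},
)
assert attributes["username"] == ["user123"]
assert set(attributes["roles"]) == {"inference_max", "admin"}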


@ -4,6 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from uuid import uuid4
import pytest
from fastapi import FastAPI, Request
from fastapi.testclient import TestClient
@ -11,7 +13,8 @@ from starlette.middleware.base import BaseHTTPMiddleware
from llama_stack.core.datatypes import QuotaConfig, QuotaPeriod
from llama_stack.core.server.quota import QuotaMiddleware
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore import register_kvstore_backends
class InjectClientIDMiddleware(BaseHTTPMiddleware):
@ -29,8 +32,10 @@ class InjectClientIDMiddleware(BaseHTTPMiddleware):
def build_quota_config(db_path) -> QuotaConfig:
backend_name = f"kv_quota_{uuid4().hex}"
register_kvstore_backends({backend_name: SqliteKVStoreConfig(db_path=str(db_path))})
return QuotaConfig(
kvstore=SqliteKVStoreConfig(db_path=str(db_path)),
kvstore=KVStoreReference(backend=backend_name, namespace="quota"),
anonymous_max_requests=1,
authenticated_max_requests=2,
period=QuotaPeriod.DAY,


@ -12,15 +12,22 @@ from unittest.mock import AsyncMock, MagicMock
from pydantic import BaseModel, Field
from llama_stack.apis.inference import Inference
from llama_stack.core.datatypes import (
Api,
Provider,
StackRunConfig,
)
from llama_stack.core.datatypes import Api, Provider, StackRunConfig
from llama_stack.core.resolver import resolve_impls
from llama_stack.core.routers.inference import InferenceRouter
from llama_stack.core.routing_tables.models import ModelsRoutingTable
from llama_stack.core.storage.datatypes import (
InferenceStoreReference,
KVStoreReference,
ServerStoresConfig,
SqliteKVStoreConfig,
SqliteSqlStoreConfig,
SqlStoreReference,
StorageConfig,
)
from llama_stack.providers.datatypes import InlineProviderSpec, ProviderSpec
from llama_stack.providers.utils.kvstore import register_kvstore_backends
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
def add_protocol_methods(cls: type, protocol: type[Protocol]) -> None:
@ -65,6 +72,35 @@ class SampleImpl:
pass
def make_run_config(**overrides) -> StackRunConfig:
storage = overrides.pop(
"storage",
StorageConfig(
backends={
"kv_default": SqliteKVStoreConfig(db_path=":memory:"),
"sql_default": SqliteSqlStoreConfig(db_path=":memory:"),
},
stores=ServerStoresConfig(
metadata=KVStoreReference(backend="kv_default", namespace="registry"),
inference=InferenceStoreReference(backend="sql_default", table_name="inference_store"),
conversations=SqlStoreReference(backend="sql_default", table_name="conversations"),
),
),
)
register_kvstore_backends({name: cfg for name, cfg in storage.backends.items() if cfg.type.value.startswith("kv_")})
register_sqlstore_backends(
{name: cfg for name, cfg in storage.backends.items() if cfg.type.value.startswith("sql_")}
)
defaults = dict(
image_name="test_image",
apis=[],
providers={},
storage=storage,
)
defaults.update(overrides)
return StackRunConfig(**defaults)
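A hedged usage sketch for the helper: any StackRunConfig field can be overridden, while the in-memory storage defaults and their backend registration are supplied automatically (the provider value here is illustrative):

run_config = make_run_config(
    image_name="another_test_image",
    providers={"inference": []},
)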
async def test_resolve_impls_basic():
# Create a real provider spec
provider_spec = InlineProviderSpec(
@ -78,7 +114,7 @@ async def test_resolve_impls_basic():
# Create provider registry with our provider
provider_registry = {Api.inference: {provider_spec.provider_type: provider_spec}}
run_config = StackRunConfig(
run_config = make_run_config(
image_name="test_image",
providers={
"inference": [


@ -41,7 +41,7 @@ class TestTranslateException:
self.identifier = identifier
self.owner = owner
resource = MockResource("vector_db", "test-db")
resource = MockResource("vector_store", "test-db")
exc = AccessDeniedError("create", resource, user)
result = translate_exception(exc)
@ -49,7 +49,7 @@ class TestTranslateException:
assert isinstance(result, HTTPException)
assert result.status_code == 403
assert "test-user" in result.detail
assert "vector_db::test-db" in result.detail
assert "vector_store::test-db" in result.detail
assert "create" in result.detail
assert "roles=['user']" in result.detail
assert "teams=['dev']" in result.detail


@ -5,7 +5,6 @@
# the root directory of this source tree.
import time
from tempfile import TemporaryDirectory
import pytest
@ -16,8 +15,16 @@ from llama_stack.apis.inference import (
OpenAIUserMessageParam,
Order,
)
from llama_stack.core.storage.datatypes import InferenceStoreReference, SqliteSqlStoreConfig
from llama_stack.providers.utils.inference.inference_store import InferenceStore
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
@pytest.fixture(autouse=True)
def setup_backends(tmp_path):
"""Register SQL store backends for testing."""
db_path = str(tmp_path / "test.db")
register_sqlstore_backends({"sql_default": SqliteSqlStoreConfig(db_path=db_path)})
def create_test_chat_completion(
@ -44,167 +51,162 @@ def create_test_chat_completion(
async def test_inference_store_pagination_basic():
"""Test basic pagination functionality."""
with TemporaryDirectory() as tmp_dir:
db_path = tmp_dir + "/test.db"
store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
await store.initialize()
reference = InferenceStoreReference(backend="sql_default", table_name="chat_completions")
store = InferenceStore(reference, policy=[])
await store.initialize()
# Create test data with different timestamps
base_time = int(time.time())
test_data = [
("zebra-task", base_time + 1),
("apple-job", base_time + 2),
("moon-work", base_time + 3),
("banana-run", base_time + 4),
("car-exec", base_time + 5),
]
# Create test data with different timestamps
base_time = int(time.time())
test_data = [
("zebra-task", base_time + 1),
("apple-job", base_time + 2),
("moon-work", base_time + 3),
("banana-run", base_time + 4),
("car-exec", base_time + 5),
]
# Store test chat completions
for completion_id, timestamp in test_data:
completion = create_test_chat_completion(completion_id, timestamp)
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
await store.store_chat_completion(completion, input_messages)
# Store test chat completions
for completion_id, timestamp in test_data:
completion = create_test_chat_completion(completion_id, timestamp)
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
await store.store_chat_completion(completion, input_messages)
# Wait for all queued writes to complete
await store.flush()
# Wait for all queued writes to complete
await store.flush()
# Test 1: First page with limit=2, descending order (default)
result = await store.list_chat_completions(limit=2, order=Order.desc)
assert len(result.data) == 2
assert result.data[0].id == "car-exec" # Most recent first
assert result.data[1].id == "banana-run"
assert result.has_more is True
assert result.last_id == "banana-run"
# Test 1: First page with limit=2, descending order (default)
result = await store.list_chat_completions(limit=2, order=Order.desc)
assert len(result.data) == 2
assert result.data[0].id == "car-exec" # Most recent first
assert result.data[1].id == "banana-run"
assert result.has_more is True
assert result.last_id == "banana-run"
# Test 2: Second page using 'after' parameter
result2 = await store.list_chat_completions(after="banana-run", limit=2, order=Order.desc)
assert len(result2.data) == 2
assert result2.data[0].id == "moon-work"
assert result2.data[1].id == "apple-job"
assert result2.has_more is True
# Test 2: Second page using 'after' parameter
result2 = await store.list_chat_completions(after="banana-run", limit=2, order=Order.desc)
assert len(result2.data) == 2
assert result2.data[0].id == "moon-work"
assert result2.data[1].id == "apple-job"
assert result2.has_more is True
# Test 3: Final page
result3 = await store.list_chat_completions(after="apple-job", limit=2, order=Order.desc)
assert len(result3.data) == 1
assert result3.data[0].id == "zebra-task"
assert result3.has_more is False
# Test 3: Final page
result3 = await store.list_chat_completions(after="apple-job", limit=2, order=Order.desc)
assert len(result3.data) == 1
assert result3.data[0].id == "zebra-task"
assert result3.has_more is False
async def test_inference_store_pagination_ascending():
"""Test pagination with ascending order."""
with TemporaryDirectory() as tmp_dir:
db_path = tmp_dir + "/test.db"
store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
await store.initialize()
reference = InferenceStoreReference(backend="sql_default", table_name="chat_completions")
store = InferenceStore(reference, policy=[])
await store.initialize()
# Create test data
base_time = int(time.time())
test_data = [
("delta-item", base_time + 1),
("charlie-task", base_time + 2),
("alpha-work", base_time + 3),
]
# Create test data
base_time = int(time.time())
test_data = [
("delta-item", base_time + 1),
("charlie-task", base_time + 2),
("alpha-work", base_time + 3),
]
# Store test chat completions
for completion_id, timestamp in test_data:
completion = create_test_chat_completion(completion_id, timestamp)
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
await store.store_chat_completion(completion, input_messages)
# Store test chat completions
for completion_id, timestamp in test_data:
completion = create_test_chat_completion(completion_id, timestamp)
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
await store.store_chat_completion(completion, input_messages)
# Wait for all queued writes to complete
await store.flush()
# Wait for all queued writes to complete
await store.flush()
# Test ascending order pagination
result = await store.list_chat_completions(limit=1, order=Order.asc)
assert len(result.data) == 1
assert result.data[0].id == "delta-item" # Oldest first
assert result.has_more is True
# Test ascending order pagination
result = await store.list_chat_completions(limit=1, order=Order.asc)
assert len(result.data) == 1
assert result.data[0].id == "delta-item" # Oldest first
assert result.has_more is True
# Second page with ascending order
result2 = await store.list_chat_completions(after="delta-item", limit=1, order=Order.asc)
assert len(result2.data) == 1
assert result2.data[0].id == "charlie-task"
assert result2.has_more is True
# Second page with ascending order
result2 = await store.list_chat_completions(after="delta-item", limit=1, order=Order.asc)
assert len(result2.data) == 1
assert result2.data[0].id == "charlie-task"
assert result2.has_more is True
async def test_inference_store_pagination_with_model_filter():
"""Test pagination combined with model filtering."""
with TemporaryDirectory() as tmp_dir:
db_path = tmp_dir + "/test.db"
store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
await store.initialize()
reference = InferenceStoreReference(backend="sql_default", table_name="chat_completions")
store = InferenceStore(reference, policy=[])
await store.initialize()
# Create test data with different models
base_time = int(time.time())
test_data = [
("xyz-task", base_time + 1, "model-a"),
("def-work", base_time + 2, "model-b"),
("pqr-job", base_time + 3, "model-a"),
("abc-run", base_time + 4, "model-b"),
]
# Create test data with different models
base_time = int(time.time())
test_data = [
("xyz-task", base_time + 1, "model-a"),
("def-work", base_time + 2, "model-b"),
("pqr-job", base_time + 3, "model-a"),
("abc-run", base_time + 4, "model-b"),
]
# Store test chat completions
for completion_id, timestamp, model in test_data:
completion = create_test_chat_completion(completion_id, timestamp, model)
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
await store.store_chat_completion(completion, input_messages)
# Store test chat completions
for completion_id, timestamp, model in test_data:
completion = create_test_chat_completion(completion_id, timestamp, model)
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
await store.store_chat_completion(completion, input_messages)
# Wait for all queued writes to complete
await store.flush()
# Wait for all queued writes to complete
await store.flush()
# Test pagination with model filter
result = await store.list_chat_completions(limit=1, model="model-a", order=Order.desc)
assert len(result.data) == 1
assert result.data[0].id == "pqr-job" # Most recent model-a
assert result.data[0].model == "model-a"
assert result.has_more is True
# Test pagination with model filter
result = await store.list_chat_completions(limit=1, model="model-a", order=Order.desc)
assert len(result.data) == 1
assert result.data[0].id == "pqr-job" # Most recent model-a
assert result.data[0].model == "model-a"
assert result.has_more is True
# Second page with model filter
result2 = await store.list_chat_completions(after="pqr-job", limit=1, model="model-a", order=Order.desc)
assert len(result2.data) == 1
assert result2.data[0].id == "xyz-task"
assert result2.data[0].model == "model-a"
assert result2.has_more is False
# Second page with model filter
result2 = await store.list_chat_completions(after="pqr-job", limit=1, model="model-a", order=Order.desc)
assert len(result2.data) == 1
assert result2.data[0].id == "xyz-task"
assert result2.data[0].model == "model-a"
assert result2.has_more is False
async def test_inference_store_pagination_invalid_after():
"""Test error handling for invalid 'after' parameter."""
with TemporaryDirectory() as tmp_dir:
db_path = tmp_dir + "/test.db"
store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
await store.initialize()
reference = InferenceStoreReference(backend="sql_default", table_name="chat_completions")
store = InferenceStore(reference, policy=[])
await store.initialize()
# Try to paginate with non-existent ID
with pytest.raises(ValueError, match="Record with id='non-existent' not found in table 'chat_completions'"):
await store.list_chat_completions(after="non-existent", limit=2)
# Try to paginate with non-existent ID
with pytest.raises(ValueError, match="Record with id='non-existent' not found in table 'chat_completions'"):
await store.list_chat_completions(after="non-existent", limit=2)
async def test_inference_store_pagination_no_limit():
"""Test pagination behavior when no limit is specified."""
with TemporaryDirectory() as tmp_dir:
db_path = tmp_dir + "/test.db"
store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
await store.initialize()
reference = InferenceStoreReference(backend="sql_default", table_name="chat_completions")
store = InferenceStore(reference, policy=[])
await store.initialize()
# Create test data
base_time = int(time.time())
test_data = [
("omega-first", base_time + 1),
("beta-second", base_time + 2),
]
# Create test data
base_time = int(time.time())
test_data = [
("omega-first", base_time + 1),
("beta-second", base_time + 2),
]
# Store test chat completions
for completion_id, timestamp in test_data:
completion = create_test_chat_completion(completion_id, timestamp)
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
await store.store_chat_completion(completion, input_messages)
# Store test chat completions
for completion_id, timestamp in test_data:
completion = create_test_chat_completion(completion_id, timestamp)
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
await store.store_chat_completion(completion, input_messages)
# Wait for all queued writes to complete
await store.flush()
# Wait for all queued writes to complete
await store.flush()
# Test without limit
result = await store.list_chat_completions(order=Order.desc)
assert len(result.data) == 2
assert result.data[0].id == "beta-second" # Most recent first
assert result.data[1].id == "omega-first"
assert result.has_more is False
# Test without limit
result = await store.list_chat_completions(order=Order.desc)
assert len(result.data) == 2
assert result.data[0].id == "beta-second" # Most recent first
assert result.data[1].id == "omega-first"
assert result.has_more is False


@ -6,6 +6,7 @@
import time
from tempfile import TemporaryDirectory
from uuid import uuid4
import pytest
@ -15,8 +16,18 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseObject,
)
from llama_stack.apis.inference import OpenAIMessageParam, OpenAIUserMessageParam
from llama_stack.core.storage.datatypes import ResponsesStoreReference, SqliteSqlStoreConfig
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
def build_store(db_path: str, policy: list | None = None) -> ResponsesStore:
backend_name = f"sql_responses_{uuid4().hex}"
register_sqlstore_backends({backend_name: SqliteSqlStoreConfig(db_path=db_path)})
return ResponsesStore(
ResponsesStoreReference(backend=backend_name, table_name="responses"),
policy=policy or [],
)
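A usage sketch, assuming it runs inside an async test body; because each call registers a uniquely named backend (sql_responses_<hex>), stores built in different tests never collide:

async def example(tmp_path):
    store = build_store(str(tmp_path / "responses.db"))  # illustrative path
    await store.initialize()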
def create_test_response_object(
@ -54,7 +65,7 @@ async def test_responses_store_pagination_basic():
"""Test basic pagination functionality for responses store."""
with TemporaryDirectory() as tmp_dir:
db_path = tmp_dir + "/test.db"
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
store = build_store(db_path)
await store.initialize()
# Create test data with different timestamps
@ -103,7 +114,7 @@ async def test_responses_store_pagination_ascending():
"""Test pagination with ascending order."""
with TemporaryDirectory() as tmp_dir:
db_path = tmp_dir + "/test.db"
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
store = build_store(db_path)
await store.initialize()
# Create test data
@ -141,7 +152,7 @@ async def test_responses_store_pagination_with_model_filter():
"""Test pagination combined with model filtering."""
with TemporaryDirectory() as tmp_dir:
db_path = tmp_dir + "/test.db"
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
store = build_store(db_path)
await store.initialize()
# Create test data with different models
@ -182,7 +193,7 @@ async def test_responses_store_pagination_invalid_after():
"""Test error handling for invalid 'after' parameter."""
with TemporaryDirectory() as tmp_dir:
db_path = tmp_dir + "/test.db"
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
store = build_store(db_path)
await store.initialize()
# Try to paginate with non-existent ID
@ -194,7 +205,7 @@ async def test_responses_store_pagination_no_limit():
"""Test pagination behavior when no limit is specified."""
with TemporaryDirectory() as tmp_dir:
db_path = tmp_dir + "/test.db"
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
store = build_store(db_path)
await store.initialize()
# Create test data
@ -226,7 +237,7 @@ async def test_responses_store_get_response_object():
"""Test retrieving a single response object."""
with TemporaryDirectory() as tmp_dir:
db_path = tmp_dir + "/test.db"
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
store = build_store(db_path)
await store.initialize()
# Store a test response
@ -254,7 +265,7 @@ async def test_responses_store_input_items_pagination():
"""Test pagination functionality for input items."""
with TemporaryDirectory() as tmp_dir:
db_path = tmp_dir + "/test.db"
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
store = build_store(db_path)
await store.initialize()
# Store a test response with many inputs with explicit IDs
@ -335,7 +346,7 @@ async def test_responses_store_input_items_before_pagination():
"""Test before pagination functionality for input items."""
with TemporaryDirectory() as tmp_dir:
db_path = tmp_dir + "/test.db"
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
store = build_store(db_path)
await store.initialize()
# Store a test response with many inputs with explicit IDs