mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-16 08:02:36 +00:00
Merge remote-tracking branch 'origin/main' into dependabot/uv/openai-2.5.0
This commit is contained in:
commit
13450c1a68
317 changed files with 86802 additions and 18957 deletions
|
|
@ -23,6 +23,27 @@ def config_with_image_name_int():
|
|||
image_name: 1234
|
||||
apis_to_serve: []
|
||||
built_at: {datetime.now().isoformat()}
|
||||
storage:
|
||||
backends:
|
||||
kv_default:
|
||||
type: kv_sqlite
|
||||
db_path: /tmp/test_kv.db
|
||||
sql_default:
|
||||
type: sql_sqlite
|
||||
db_path: /tmp/test_sql.db
|
||||
stores:
|
||||
metadata:
|
||||
backend: kv_default
|
||||
namespace: metadata
|
||||
inference:
|
||||
backend: sql_default
|
||||
table_name: inference
|
||||
conversations:
|
||||
backend: sql_default
|
||||
table_name: conversations
|
||||
responses:
|
||||
backend: sql_default
|
||||
table_name: responses
|
||||
providers:
|
||||
inference:
|
||||
- provider_id: provider1
|
||||
|
|
@ -54,6 +75,27 @@ def up_to_date_config():
|
|||
image_name: foo
|
||||
apis_to_serve: []
|
||||
built_at: {datetime.now().isoformat()}
|
||||
storage:
|
||||
backends:
|
||||
kv_default:
|
||||
type: kv_sqlite
|
||||
db_path: /tmp/test_kv.db
|
||||
sql_default:
|
||||
type: sql_sqlite
|
||||
db_path: /tmp/test_sql.db
|
||||
stores:
|
||||
metadata:
|
||||
backend: kv_default
|
||||
namespace: metadata
|
||||
inference:
|
||||
backend: sql_default
|
||||
table_name: inference
|
||||
conversations:
|
||||
backend: sql_default
|
||||
table_name: conversations
|
||||
responses:
|
||||
backend: sql_default
|
||||
table_name: responses
|
||||
providers:
|
||||
inference:
|
||||
- provider_id: provider1
|
||||
|
|
|
|||
|
|
@ -20,7 +20,14 @@ from llama_stack.core.conversations.conversations import (
|
|||
ConversationServiceConfig,
|
||||
ConversationServiceImpl,
|
||||
)
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
|
||||
from llama_stack.core.datatypes import StackRunConfig
|
||||
from llama_stack.core.storage.datatypes import (
|
||||
ServerStoresConfig,
|
||||
SqliteSqlStoreConfig,
|
||||
SqlStoreReference,
|
||||
StorageConfig,
|
||||
)
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -28,7 +35,18 @@ async def service():
|
|||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
db_path = Path(tmpdir) / "test_conversations.db"
|
||||
|
||||
config = ConversationServiceConfig(conversations_store=SqliteSqlStoreConfig(db_path=str(db_path)), policy=[])
|
||||
storage = StorageConfig(
|
||||
backends={
|
||||
"sql_test": SqliteSqlStoreConfig(db_path=str(db_path)),
|
||||
},
|
||||
stores=ServerStoresConfig(
|
||||
conversations=SqlStoreReference(backend="sql_test", table_name="openai_conversations"),
|
||||
),
|
||||
)
|
||||
register_sqlstore_backends({"sql_test": storage.backends["sql_test"]})
|
||||
run_config = StackRunConfig(image_name="test", apis=[], providers={}, storage=storage)
|
||||
|
||||
config = ConversationServiceConfig(run_config=run_config, policy=[])
|
||||
service = ConversationServiceImpl(config, {})
|
||||
await service.initialize()
|
||||
yield service
|
||||
|
|
@ -121,9 +139,18 @@ async def test_policy_configuration():
|
|||
AccessRule(forbid=Scope(principal="test_user", actions=[Action.CREATE, Action.READ], resource="*"))
|
||||
]
|
||||
|
||||
config = ConversationServiceConfig(
|
||||
conversations_store=SqliteSqlStoreConfig(db_path=str(db_path)), policy=restrictive_policy
|
||||
storage = StorageConfig(
|
||||
backends={
|
||||
"sql_test": SqliteSqlStoreConfig(db_path=str(db_path)),
|
||||
},
|
||||
stores=ServerStoresConfig(
|
||||
conversations=SqlStoreReference(backend="sql_test", table_name="openai_conversations"),
|
||||
),
|
||||
)
|
||||
register_sqlstore_backends({"sql_test": storage.backends["sql_test"]})
|
||||
run_config = StackRunConfig(image_name="test", apis=[], providers={}, storage=storage)
|
||||
|
||||
config = ConversationServiceConfig(run_config=run_config, policy=restrictive_policy)
|
||||
service = ConversationServiceImpl(config, {})
|
||||
await service.initialize()
|
||||
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ async def test_single_provider_auto_selection():
|
|||
Mock(identifier="all-MiniLM-L6-v2", model_type="embedding", metadata={"embedding_dimension": 384})
|
||||
]
|
||||
)
|
||||
mock_routing_table.register_vector_db = AsyncMock(
|
||||
mock_routing_table.register_vector_store = AsyncMock(
|
||||
return_value=Mock(identifier="vs_123", provider_id="inline::faiss", provider_resource_id="vs_123")
|
||||
)
|
||||
mock_routing_table.get_provider_impl = AsyncMock(
|
||||
|
|
|
|||
|
|
@ -4,90 +4,64 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Unit tests for Stack validation functions.
|
||||
"""
|
||||
"""Unit tests for Stack validation functions."""
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.core.stack import validate_default_embedding_model
|
||||
from llama_stack.apis.models import ListModelsResponse, Model, ModelType
|
||||
from llama_stack.core.datatypes import QualifiedModel, StackRunConfig, StorageConfig, VectorStoresConfig
|
||||
from llama_stack.core.stack import validate_vector_stores_config
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
|
||||
class TestStackValidation:
|
||||
"""Test Stack validation functions."""
|
||||
class TestVectorStoresValidation:
|
||||
async def test_validate_missing_model(self):
|
||||
"""Test validation fails when model not found."""
|
||||
run_config = StackRunConfig(
|
||||
image_name="test",
|
||||
providers={},
|
||||
storage=StorageConfig(backends={}, stores={}),
|
||||
vector_stores=VectorStoresConfig(
|
||||
default_provider_id="faiss",
|
||||
default_embedding_model=QualifiedModel(
|
||||
provider_id="p",
|
||||
model_id="missing",
|
||||
),
|
||||
),
|
||||
)
|
||||
mock_models = AsyncMock()
|
||||
mock_models.list_models.return_value = ListModelsResponse(data=[])
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"models,should_raise",
|
||||
[
|
||||
([], False), # No models
|
||||
(
|
||||
[
|
||||
Model(
|
||||
identifier="emb1",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={"default_configured": True},
|
||||
provider_id="p",
|
||||
provider_resource_id="emb1",
|
||||
)
|
||||
],
|
||||
False,
|
||||
), # Single default
|
||||
(
|
||||
[
|
||||
Model(
|
||||
identifier="emb1",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={"default_configured": True},
|
||||
provider_id="p",
|
||||
provider_resource_id="emb1",
|
||||
),
|
||||
Model(
|
||||
identifier="emb2",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={"default_configured": True},
|
||||
provider_id="p",
|
||||
provider_resource_id="emb2",
|
||||
),
|
||||
],
|
||||
True,
|
||||
), # Multiple defaults
|
||||
(
|
||||
[
|
||||
Model(
|
||||
identifier="emb1",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={"default_configured": True},
|
||||
provider_id="p",
|
||||
provider_resource_id="emb1",
|
||||
),
|
||||
Model(
|
||||
identifier="llm1",
|
||||
model_type=ModelType.llm,
|
||||
metadata={"default_configured": True},
|
||||
provider_id="p",
|
||||
provider_resource_id="llm1",
|
||||
),
|
||||
],
|
||||
False,
|
||||
), # Ignores non-embedding
|
||||
],
|
||||
)
|
||||
async def test_validate_default_embedding_model(self, models, should_raise):
|
||||
"""Test validation with various model configurations."""
|
||||
mock_models_impl = AsyncMock()
|
||||
mock_models_impl.list_models.return_value = models
|
||||
impls = {Api.models: mock_models_impl}
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
await validate_vector_stores_config(run_config.vector_stores, {Api.models: mock_models})
|
||||
|
||||
if should_raise:
|
||||
with pytest.raises(ValueError, match="Multiple embedding models marked as default_configured=True"):
|
||||
await validate_default_embedding_model(impls)
|
||||
else:
|
||||
await validate_default_embedding_model(impls)
|
||||
async def test_validate_success(self):
|
||||
"""Test validation passes with valid model."""
|
||||
run_config = StackRunConfig(
|
||||
image_name="test",
|
||||
providers={},
|
||||
storage=StorageConfig(backends={}, stores={}),
|
||||
vector_stores=VectorStoresConfig(
|
||||
default_provider_id="faiss",
|
||||
default_embedding_model=QualifiedModel(
|
||||
provider_id="p",
|
||||
model_id="valid",
|
||||
),
|
||||
),
|
||||
)
|
||||
mock_models = AsyncMock()
|
||||
mock_models.list_models.return_value = ListModelsResponse(
|
||||
data=[
|
||||
Model(
|
||||
identifier="p/valid", # Must match provider_id/model_id format
|
||||
model_type=ModelType.embedding,
|
||||
metadata={"embedding_dimension": 768},
|
||||
provider_id="p",
|
||||
provider_resource_id="valid",
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
async def test_validate_default_embedding_model_no_models_api(self):
|
||||
"""Test validation when models API is not available."""
|
||||
await validate_default_embedding_model({})
|
||||
await validate_vector_stores_config(run_config.vector_stores, {Api.models: mock_models})
|
||||
|
|
|
|||
84
tests/unit/core/test_storage_references.py
Normal file
84
tests/unit/core/test_storage_references.py
Normal file
|
|
@ -0,0 +1,84 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""Unit tests for storage backend/reference validation."""
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from llama_stack.core.datatypes import (
|
||||
LLAMA_STACK_RUN_CONFIG_VERSION,
|
||||
StackRunConfig,
|
||||
)
|
||||
from llama_stack.core.storage.datatypes import (
|
||||
InferenceStoreReference,
|
||||
KVStoreReference,
|
||||
ServerStoresConfig,
|
||||
SqliteKVStoreConfig,
|
||||
SqliteSqlStoreConfig,
|
||||
SqlStoreReference,
|
||||
StorageConfig,
|
||||
)
|
||||
|
||||
|
||||
def _base_run_config(**overrides):
|
||||
metadata_reference = overrides.pop(
|
||||
"metadata_reference",
|
||||
KVStoreReference(backend="kv_default", namespace="registry"),
|
||||
)
|
||||
inference_reference = overrides.pop(
|
||||
"inference_reference",
|
||||
InferenceStoreReference(backend="sql_default", table_name="inference"),
|
||||
)
|
||||
conversations_reference = overrides.pop(
|
||||
"conversations_reference",
|
||||
SqlStoreReference(backend="sql_default", table_name="conversations"),
|
||||
)
|
||||
storage = overrides.pop(
|
||||
"storage",
|
||||
StorageConfig(
|
||||
backends={
|
||||
"kv_default": SqliteKVStoreConfig(db_path="/tmp/kv.db"),
|
||||
"sql_default": SqliteSqlStoreConfig(db_path="/tmp/sql.db"),
|
||||
},
|
||||
stores=ServerStoresConfig(
|
||||
metadata=metadata_reference,
|
||||
inference=inference_reference,
|
||||
conversations=conversations_reference,
|
||||
),
|
||||
),
|
||||
)
|
||||
return StackRunConfig(
|
||||
version=LLAMA_STACK_RUN_CONFIG_VERSION,
|
||||
image_name="test-distro",
|
||||
apis=[],
|
||||
providers={},
|
||||
storage=storage,
|
||||
**overrides,
|
||||
)
|
||||
|
||||
|
||||
def test_references_require_known_backend():
|
||||
with pytest.raises(ValidationError, match="unknown backend 'missing'"):
|
||||
_base_run_config(metadata_reference=KVStoreReference(backend="missing", namespace="registry"))
|
||||
|
||||
|
||||
def test_references_must_match_backend_family():
|
||||
with pytest.raises(ValidationError, match="kv_.* is required"):
|
||||
_base_run_config(metadata_reference=KVStoreReference(backend="sql_default", namespace="registry"))
|
||||
|
||||
with pytest.raises(ValidationError, match="sql_.* is required"):
|
||||
_base_run_config(
|
||||
inference_reference=InferenceStoreReference(backend="kv_default", table_name="inference"),
|
||||
)
|
||||
|
||||
|
||||
def test_valid_configuration_passes_validation():
|
||||
config = _base_run_config()
|
||||
stores = config.storage.stores
|
||||
assert stores.metadata is not None and stores.metadata.backend == "kv_default"
|
||||
assert stores.inference is not None and stores.inference.backend == "sql_default"
|
||||
assert stores.conversations is not None and stores.conversations.backend == "sql_default"
|
||||
|
|
@ -1,40 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from llama_stack.cli.stack._build import (
|
||||
_run_stack_build_command_from_build_config,
|
||||
)
|
||||
from llama_stack.core.datatypes import BuildConfig, DistributionSpec
|
||||
from llama_stack.core.utils.image_types import LlamaStackImageType
|
||||
|
||||
|
||||
def test_container_build_passes_path(monkeypatch, tmp_path):
|
||||
called_with = {}
|
||||
|
||||
def spy_build_image(build_config, image_name, distro_or_config, run_config=None):
|
||||
called_with["path"] = distro_or_config
|
||||
called_with["run_config"] = run_config
|
||||
return 0
|
||||
|
||||
monkeypatch.setattr(
|
||||
"llama_stack.cli.stack._build.build_image",
|
||||
spy_build_image,
|
||||
raising=True,
|
||||
)
|
||||
|
||||
cfg = BuildConfig(
|
||||
image_type=LlamaStackImageType.CONTAINER.value,
|
||||
distribution_spec=DistributionSpec(providers={}, description=""),
|
||||
)
|
||||
|
||||
_run_stack_build_command_from_build_config(cfg, image_name="dummy")
|
||||
|
||||
assert "path" in called_with
|
||||
assert isinstance(called_with["path"], str)
|
||||
assert Path(called_with["path"]).exists()
|
||||
assert called_with["run_config"] is None
|
||||
|
|
@ -13,6 +13,15 @@ from pydantic import BaseModel, Field, ValidationError
|
|||
|
||||
from llama_stack.core.datatypes import Api, Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import INTERNAL_APIS, get_provider_registry, providable_apis
|
||||
from llama_stack.core.storage.datatypes import (
|
||||
InferenceStoreReference,
|
||||
KVStoreReference,
|
||||
ServerStoresConfig,
|
||||
SqliteKVStoreConfig,
|
||||
SqliteSqlStoreConfig,
|
||||
SqlStoreReference,
|
||||
StorageConfig,
|
||||
)
|
||||
from llama_stack.providers.datatypes import ProviderSpec
|
||||
|
||||
|
||||
|
|
@ -29,6 +38,32 @@ class SampleConfig(BaseModel):
|
|||
}
|
||||
|
||||
|
||||
def _default_storage() -> StorageConfig:
|
||||
return StorageConfig(
|
||||
backends={
|
||||
"kv_default": SqliteKVStoreConfig(db_path=":memory:"),
|
||||
"sql_default": SqliteSqlStoreConfig(db_path=":memory:"),
|
||||
},
|
||||
stores=ServerStoresConfig(
|
||||
metadata=KVStoreReference(backend="kv_default", namespace="registry"),
|
||||
inference=InferenceStoreReference(backend="sql_default", table_name="inference_store"),
|
||||
conversations=SqlStoreReference(backend="sql_default", table_name="conversations"),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def make_stack_config(**overrides) -> StackRunConfig:
|
||||
storage = overrides.pop("storage", _default_storage())
|
||||
defaults = dict(
|
||||
image_name="test_image",
|
||||
apis=[],
|
||||
providers={},
|
||||
storage=storage,
|
||||
)
|
||||
defaults.update(overrides)
|
||||
return StackRunConfig(**defaults)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_providers():
|
||||
"""Mock the available_providers function to return test providers."""
|
||||
|
|
@ -47,8 +82,8 @@ def mock_providers():
|
|||
@pytest.fixture
|
||||
def base_config(tmp_path):
|
||||
"""Create a base StackRunConfig with common settings."""
|
||||
return StackRunConfig(
|
||||
image_name="test_image",
|
||||
return make_stack_config(
|
||||
apis=["inference"],
|
||||
providers={
|
||||
"inference": [
|
||||
Provider(
|
||||
|
|
@ -222,8 +257,8 @@ class TestProviderRegistry:
|
|||
|
||||
def test_missing_directory(self, mock_providers):
|
||||
"""Test handling of missing external providers directory."""
|
||||
config = StackRunConfig(
|
||||
image_name="test_image",
|
||||
config = make_stack_config(
|
||||
apis=["inference"],
|
||||
providers={
|
||||
"inference": [
|
||||
Provider(
|
||||
|
|
@ -278,7 +313,6 @@ pip_packages:
|
|||
"""Test loading an external provider from a module (success path)."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
|
||||
# Simulate a provider module with get_provider_spec
|
||||
|
|
@ -293,7 +327,7 @@ pip_packages:
|
|||
import_module_side_effect = make_import_module_side_effect(external_module=fake_module)
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_module_side_effect) as mock_import:
|
||||
config = StackRunConfig(
|
||||
config = make_stack_config(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
|
|
@ -317,12 +351,11 @@ pip_packages:
|
|||
|
||||
def test_external_provider_from_module_not_found(self, mock_providers):
|
||||
"""Test handling ModuleNotFoundError for missing provider module."""
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
|
||||
import_module_side_effect = make_import_module_side_effect(raise_for_external=True)
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_module_side_effect):
|
||||
config = StackRunConfig(
|
||||
config = make_stack_config(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
|
|
@ -341,12 +374,11 @@ pip_packages:
|
|||
|
||||
def test_external_provider_from_module_missing_get_provider_spec(self, mock_providers):
|
||||
"""Test handling missing get_provider_spec in provider module (should raise ValueError)."""
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
|
||||
import_module_side_effect = make_import_module_side_effect(missing_get_provider_spec=True)
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_module_side_effect):
|
||||
config = StackRunConfig(
|
||||
config = make_stack_config(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
|
|
@ -399,13 +431,12 @@ class TestGetExternalProvidersFromModule:
|
|||
|
||||
def test_stackrunconfig_provider_without_module(self, mock_providers):
|
||||
"""Test that providers without module attribute are skipped."""
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
|
||||
import_module_side_effect = make_import_module_side_effect()
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_module_side_effect):
|
||||
config = StackRunConfig(
|
||||
config = make_stack_config(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
|
|
@ -426,7 +457,6 @@ class TestGetExternalProvidersFromModule:
|
|||
"""Test provider with module containing version spec (e.g., package==1.0.0)."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
from llama_stack.providers.datatypes import ProviderSpec
|
||||
|
||||
|
|
@ -444,7 +474,7 @@ class TestGetExternalProvidersFromModule:
|
|||
raise ModuleNotFoundError(name)
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_side_effect):
|
||||
config = StackRunConfig(
|
||||
config = make_stack_config(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
|
|
@ -564,7 +594,6 @@ class TestGetExternalProvidersFromModule:
|
|||
"""Test when get_provider_spec returns a list of specs."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
from llama_stack.providers.datatypes import ProviderSpec
|
||||
|
||||
|
|
@ -589,7 +618,7 @@ class TestGetExternalProvidersFromModule:
|
|||
raise ModuleNotFoundError(name)
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_side_effect):
|
||||
config = StackRunConfig(
|
||||
config = make_stack_config(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
|
|
@ -613,7 +642,6 @@ class TestGetExternalProvidersFromModule:
|
|||
"""Test that list return filters specs by provider_type."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
from llama_stack.providers.datatypes import ProviderSpec
|
||||
|
||||
|
|
@ -638,7 +666,7 @@ class TestGetExternalProvidersFromModule:
|
|||
raise ModuleNotFoundError(name)
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_side_effect):
|
||||
config = StackRunConfig(
|
||||
config = make_stack_config(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
|
|
@ -662,7 +690,6 @@ class TestGetExternalProvidersFromModule:
|
|||
"""Test that list return adds multiple different provider_types when config requests them."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
from llama_stack.providers.datatypes import ProviderSpec
|
||||
|
||||
|
|
@ -688,7 +715,7 @@ class TestGetExternalProvidersFromModule:
|
|||
raise ModuleNotFoundError(name)
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_side_effect):
|
||||
config = StackRunConfig(
|
||||
config = make_stack_config(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
|
|
@ -718,7 +745,6 @@ class TestGetExternalProvidersFromModule:
|
|||
|
||||
def test_module_not_found_raises_value_error(self, mock_providers):
|
||||
"""Test that ModuleNotFoundError raises ValueError with helpful message."""
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
|
||||
def import_side_effect(name):
|
||||
|
|
@ -727,7 +753,7 @@ class TestGetExternalProvidersFromModule:
|
|||
raise ModuleNotFoundError(name)
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_side_effect):
|
||||
config = StackRunConfig(
|
||||
config = make_stack_config(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
|
|
@ -751,7 +777,6 @@ class TestGetExternalProvidersFromModule:
|
|||
"""Test that generic exceptions are properly raised."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
|
||||
def bad_spec():
|
||||
|
|
@ -765,7 +790,7 @@ class TestGetExternalProvidersFromModule:
|
|||
raise ModuleNotFoundError(name)
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_side_effect):
|
||||
config = StackRunConfig(
|
||||
config = make_stack_config(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
|
|
@ -787,10 +812,9 @@ class TestGetExternalProvidersFromModule:
|
|||
|
||||
def test_empty_provider_list(self, mock_providers):
|
||||
"""Test with empty provider list."""
|
||||
from llama_stack.core.datatypes import StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
|
||||
config = StackRunConfig(
|
||||
config = make_stack_config(
|
||||
image_name="test_image",
|
||||
providers={},
|
||||
)
|
||||
|
|
@ -805,7 +829,6 @@ class TestGetExternalProvidersFromModule:
|
|||
"""Test multiple APIs with providers."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
from llama_stack.providers.datatypes import ProviderSpec
|
||||
|
||||
|
|
@ -830,7 +853,7 @@ class TestGetExternalProvidersFromModule:
|
|||
raise ModuleNotFoundError(name)
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_side_effect):
|
||||
config = StackRunConfig(
|
||||
config = make_stack_config(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
|
|
|
|||
|
|
@ -11,11 +11,12 @@ from llama_stack.apis.common.errors import ResourceNotFoundError
|
|||
from llama_stack.apis.common.responses import Order
|
||||
from llama_stack.apis.files import OpenAIFilePurpose
|
||||
from llama_stack.core.access_control.access_control import default_policy
|
||||
from llama_stack.core.storage.datatypes import SqliteSqlStoreConfig, SqlStoreReference
|
||||
from llama_stack.providers.inline.files.localfs import (
|
||||
LocalfsFilesImpl,
|
||||
LocalfsFilesImplConfig,
|
||||
)
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
|
||||
|
||||
|
||||
class MockUploadFile:
|
||||
|
|
@ -36,8 +37,11 @@ async def files_provider(tmp_path):
|
|||
storage_dir = tmp_path / "files"
|
||||
db_path = tmp_path / "files_metadata.db"
|
||||
|
||||
backend_name = "sql_localfs_test"
|
||||
register_sqlstore_backends({backend_name: SqliteSqlStoreConfig(db_path=db_path.as_posix())})
|
||||
config = LocalfsFilesImplConfig(
|
||||
storage_dir=storage_dir.as_posix(), metadata_store=SqliteSqlStoreConfig(db_path=db_path.as_posix())
|
||||
storage_dir=storage_dir.as_posix(),
|
||||
metadata_store=SqlStoreReference(backend=backend_name, table_name="files_metadata"),
|
||||
)
|
||||
|
||||
provider = LocalfsFilesImpl(config, default_policy())
|
||||
|
|
|
|||
|
|
@ -9,7 +9,16 @@ import random
|
|||
import pytest
|
||||
|
||||
from llama_stack.core.prompts.prompts import PromptServiceConfig, PromptServiceImpl
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
from llama_stack.core.storage.datatypes import (
|
||||
InferenceStoreReference,
|
||||
KVStoreReference,
|
||||
ServerStoresConfig,
|
||||
SqliteKVStoreConfig,
|
||||
SqliteSqlStoreConfig,
|
||||
SqlStoreReference,
|
||||
StorageConfig,
|
||||
)
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl, register_kvstore_backends
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -19,12 +28,28 @@ async def temp_prompt_store(tmp_path_factory):
|
|||
db_path = str(temp_dir / f"{unique_id}.db")
|
||||
|
||||
from llama_stack.core.datatypes import StackRunConfig
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
|
||||
mock_run_config = StackRunConfig(image_name="test-distribution", apis=[], providers={})
|
||||
storage = StorageConfig(
|
||||
backends={
|
||||
"kv_test": SqliteKVStoreConfig(db_path=db_path),
|
||||
"sql_test": SqliteSqlStoreConfig(db_path=str(temp_dir / f"{unique_id}_sql.db")),
|
||||
},
|
||||
stores=ServerStoresConfig(
|
||||
metadata=KVStoreReference(backend="kv_test", namespace="registry"),
|
||||
inference=InferenceStoreReference(backend="sql_test", table_name="inference"),
|
||||
conversations=SqlStoreReference(backend="sql_test", table_name="conversations"),
|
||||
),
|
||||
)
|
||||
mock_run_config = StackRunConfig(
|
||||
image_name="test-distribution",
|
||||
apis=[],
|
||||
providers={},
|
||||
storage=storage,
|
||||
)
|
||||
config = PromptServiceConfig(run_config=mock_run_config)
|
||||
store = PromptServiceImpl(config, deps={})
|
||||
|
||||
store.kvstore = await kvstore_impl(SqliteKVStoreConfig(db_path=db_path))
|
||||
register_kvstore_backends({"kv_test": storage.backends["kv_test"]})
|
||||
store.kvstore = await kvstore_impl(KVStoreReference(backend="kv_test", namespace="prompts"))
|
||||
|
||||
yield store
|
||||
|
|
|
|||
|
|
@ -26,6 +26,20 @@ from llama_stack.providers.inline.agents.meta_reference.config import MetaRefere
|
|||
from llama_stack.providers.inline.agents.meta_reference.persistence import AgentInfo
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_backends(tmp_path):
|
||||
"""Register KV and SQL store backends for testing."""
|
||||
from llama_stack.core.storage.datatypes import SqliteKVStoreConfig, SqliteSqlStoreConfig
|
||||
from llama_stack.providers.utils.kvstore.kvstore import register_kvstore_backends
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
|
||||
|
||||
kv_path = str(tmp_path / "test_kv.db")
|
||||
sql_path = str(tmp_path / "test_sql.db")
|
||||
|
||||
register_kvstore_backends({"kv_default": SqliteKVStoreConfig(db_path=kv_path)})
|
||||
register_sqlstore_backends({"sql_default": SqliteSqlStoreConfig(db_path=sql_path)})
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_apis():
|
||||
return {
|
||||
|
|
@ -40,15 +54,20 @@ def mock_apis():
|
|||
|
||||
@pytest.fixture
|
||||
def config(tmp_path):
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference, ResponsesStoreReference
|
||||
from llama_stack.providers.inline.agents.meta_reference.config import AgentPersistenceConfig
|
||||
|
||||
return MetaReferenceAgentsImplConfig(
|
||||
persistence_store={
|
||||
"type": "sqlite",
|
||||
"db_path": str(tmp_path / "test.db"),
|
||||
},
|
||||
responses_store={
|
||||
"type": "sqlite",
|
||||
"db_path": str(tmp_path / "test.db"),
|
||||
},
|
||||
persistence=AgentPersistenceConfig(
|
||||
agent_state=KVStoreReference(
|
||||
backend="kv_default",
|
||||
namespace="agents",
|
||||
),
|
||||
responses=ResponsesStoreReference(
|
||||
backend="sql_default",
|
||||
table_name="responses",
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ from llama_stack.apis.inference import (
|
|||
)
|
||||
from llama_stack.apis.tools.tools import ListToolDefsResponse, ToolDef, ToolGroups, ToolInvocationResult, ToolRuntime
|
||||
from llama_stack.core.access_control.access_control import default_policy
|
||||
from llama_stack.core.datatypes import ResponsesStoreConfig
|
||||
from llama_stack.core.storage.datatypes import ResponsesStoreReference, SqliteSqlStoreConfig
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
|
||||
OpenAIResponsesImpl,
|
||||
)
|
||||
|
|
@ -50,7 +50,7 @@ from llama_stack.providers.utils.responses.responses_store import (
|
|||
ResponsesStore,
|
||||
_OpenAIResponseObjectWithInputAndMessages,
|
||||
)
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
|
||||
from tests.unit.providers.agents.meta_reference.fixtures import load_chat_completion_fixture
|
||||
|
||||
|
||||
|
|
@ -814,6 +814,69 @@ async def test_create_openai_response_with_instructions_and_previous_response(
|
|||
assert sent_messages[3].content == "Which is the largest?"
|
||||
|
||||
|
||||
async def test_create_openai_response_with_previous_response_instructions(
    openai_responses_impl, mock_responses_store, mock_inference_api
):
    """Check the messages sent when a previous response had its own instructions.

    The stored previous response carries ``instructions``; a follow-up request
    supplies different instructions. The test asserts that the chat completion
    receives exactly four messages: the new instructions as the system message,
    the stored user/assistant exchange, and the new user input.
    """
    # Stored previous response, created with its own (older) instructions.
    previous_response = _OpenAIResponseObjectWithInputAndMessages(
        created_at=1,
        id="resp_123",
        model="fake_model",
        output=[
            OpenAIResponseMessage(
                id="123",
                content="Galway, Longford, Sligo",
                status="completed",
                role="assistant",
            )
        ],
        status="completed",
        text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")),
        input=[
            OpenAIResponseMessage(
                id="123",
                content="Name some towns in Ireland",
                role="user",
            )
        ],
        messages=[
            OpenAIUserMessageParam(content="Name some towns in Ireland"),
            OpenAIAssistantMessageParam(content="Galway, Longford, Sligo"),
        ],
        instructions="You are a helpful assistant.",
    )
    mock_responses_store.get_response_object.return_value = previous_response
    mock_inference_api.openai_chat_completion.return_value = fake_stream()

    new_instructions = "You are a geography expert. Provide concise answers."

    # Execute
    await openai_responses_impl.create_openai_response(
        input="Which is the largest?",
        model="meta-llama/Llama-3.1-8B-Instruct",
        instructions=new_instructions,
        previous_response_id="123",
    )

    # Verify
    mock_inference_api.openai_chat_completion.assert_called_once()
    sent_messages = mock_inference_api.openai_chat_completion.call_args.args[0].messages

    # The new instructions lead as the system message (the previous response's
    # instructions are not carried over), followed by the prior exchange and
    # the new user input.
    expected_conversation = [
        ("system", new_instructions),
        ("user", "Name some towns in Ireland"),
        ("assistant", "Galway, Longford, Sligo"),
        ("user", "Which is the largest?"),
    ]
    assert len(sent_messages) == 4, sent_messages
    for message, (role, content) in zip(sent_messages, expected_conversation):
        assert message.role == role
        assert message.content == content
|
||||
|
||||
|
||||
async def test_list_openai_response_input_items_delegation(openai_responses_impl, mock_responses_store):
|
||||
"""Test that list_openai_response_input_items properly delegates to responses_store with correct parameters."""
|
||||
# Setup
|
||||
|
|
@ -854,8 +917,10 @@ async def test_responses_store_list_input_items_logic():
|
|||
|
||||
# Create mock store and response store
|
||||
mock_sql_store = AsyncMock()
|
||||
backend_name = "sql_responses_test"
|
||||
register_sqlstore_backends({backend_name: SqliteSqlStoreConfig(db_path="mock_db_path")})
|
||||
responses_store = ResponsesStore(
|
||||
ResponsesStoreConfig(sql_store_config=SqliteSqlStoreConfig(db_path="mock_db_path")), policy=default_policy()
|
||||
ResponsesStoreReference(backend=backend_name, table_name="responses"), policy=default_policy()
|
||||
)
|
||||
responses_store.sql_store = mock_sql_store
|
||||
|
||||
|
|
|
|||
|
|
@ -12,10 +12,10 @@ from unittest.mock import AsyncMock
|
|||
|
||||
import pytest
|
||||
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig
|
||||
from llama_stack.providers.inline.batches.reference.batches import ReferenceBatchesImpl
|
||||
from llama_stack.providers.inline.batches.reference.config import ReferenceBatchesImplConfig
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl, register_kvstore_backends
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -23,8 +23,10 @@ async def provider():
|
|||
"""Create a test provider instance with temporary database."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
db_path = Path(tmpdir) / "test_batches.db"
|
||||
backend_name = "kv_batches_test"
|
||||
kvstore_config = SqliteKVStoreConfig(db_path=str(db_path))
|
||||
config = ReferenceBatchesImplConfig(kvstore=kvstore_config)
|
||||
register_kvstore_backends({backend_name: kvstore_config})
|
||||
config = ReferenceBatchesImplConfig(kvstore=KVStoreReference(backend=backend_name, namespace="batches"))
|
||||
|
||||
# Create kvstore and mock APIs
|
||||
kvstore = await kvstore_impl(config.kvstore)
|
||||
|
|
|
|||
|
|
@ -8,8 +8,9 @@ import boto3
|
|||
import pytest
|
||||
from moto import mock_aws
|
||||
|
||||
from llama_stack.core.storage.datatypes import SqliteSqlStoreConfig, SqlStoreReference
|
||||
from llama_stack.providers.remote.files.s3 import S3FilesImplConfig, get_adapter_impl
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
|
||||
|
||||
|
||||
class MockUploadFile:
|
||||
|
|
@ -38,11 +39,13 @@ def sample_text_file2():
|
|||
def s3_config(tmp_path):
|
||||
db_path = tmp_path / "s3_files_metadata.db"
|
||||
|
||||
backend_name = f"sql_s3_{tmp_path.name}"
|
||||
register_sqlstore_backends({backend_name: SqliteSqlStoreConfig(db_path=db_path.as_posix())})
|
||||
return S3FilesImplConfig(
|
||||
bucket_name=f"test-bucket-{tmp_path.name}",
|
||||
region="not-a-region",
|
||||
auto_create_bucket=True,
|
||||
metadata_store=SqliteSqlStoreConfig(db_path=db_path.as_posix()),
|
||||
metadata_store=SqlStoreReference(backend=backend_name, table_name="s3_files_metadata"),
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -38,6 +38,28 @@ class OpenAIMixinWithEmbeddingsImpl(OpenAIMixinImpl):
|
|||
}
|
||||
|
||||
|
||||
class OpenAIMixinWithCustomModelConstruction(OpenAIMixinImpl):
    """Test mixin whose construct_model_from_identifier override yields rerank models."""

    embedding_model_metadata: dict[str, dict[str, int]] = {
        "text-embedding-3-small": {"embedding_dimension": 1536, "context_length": 8192},
        "text-embedding-ada-002": {"embedding_dimension": 1536, "context_length": 8192},
    }

    # Identifiers that the override below turns into rerank models.
    rerank_model_ids: set[str] = {"rerank-model-1", "rerank-model-2"}

    def construct_model_from_identifier(self, identifier: str) -> Model:
        """Return a rerank Model for known rerank ids; otherwise defer to the parent."""
        if identifier not in self.rerank_model_ids:
            return super().construct_model_from_identifier(identifier)

        return Model(
            provider_id=self.__provider_id__,  # type: ignore[attr-defined]
            provider_resource_id=identifier,
            identifier=identifier,
            model_type=ModelType.rerank,
        )
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mixin():
|
||||
"""Create a test instance of OpenAIMixin with mocked model_store"""
|
||||
|
|
@ -62,6 +84,13 @@ def mixin_with_embeddings():
|
|||
return OpenAIMixinWithEmbeddingsImpl(config=config)
|
||||
|
||||
|
||||
@pytest.fixture
def mixin_with_custom_model_construction():
    """Provide an OpenAIMixinWithCustomModelConstruction built from a default provider config."""
    return OpenAIMixinWithCustomModelConstruction(config=RemoteInferenceProviderConfig())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_models():
|
||||
"""Create multiple mock OpenAI model objects"""
|
||||
|
|
@ -113,6 +142,19 @@ def mock_client_context():
|
|||
return _mock_client_context
|
||||
|
||||
|
||||
def _assert_models_match_expected(actual_models, expected_models):
|
||||
"""Verify the models match expected attributes.
|
||||
|
||||
Args:
|
||||
actual_models: List of models to verify
|
||||
expected_models: Mapping of model identifier to expected attribute values
|
||||
"""
|
||||
for identifier, expected_attrs in expected_models.items():
|
||||
model = next(m for m in actual_models if m.identifier == identifier)
|
||||
for attr_name, expected_value in expected_attrs.items():
|
||||
assert getattr(model, attr_name) == expected_value
|
||||
|
||||
|
||||
class TestOpenAIMixinListModels:
|
||||
"""Test cases for the list_models method"""
|
||||
|
||||
|
|
@ -342,21 +384,71 @@ class TestOpenAIMixinEmbeddingModelMetadata:
|
|||
assert result is not None
|
||||
assert len(result) == 2
|
||||
|
||||
# Find the models in the result
|
||||
embedding_model = next(m for m in result if m.identifier == "text-embedding-3-small")
|
||||
llm_model = next(m for m in result if m.identifier == "gpt-4")
|
||||
expected_models = {
|
||||
"text-embedding-3-small": {
|
||||
"model_type": ModelType.embedding,
|
||||
"metadata": {"embedding_dimension": 1536, "context_length": 8192},
|
||||
"provider_id": "test-provider",
|
||||
"provider_resource_id": "text-embedding-3-small",
|
||||
},
|
||||
"gpt-4": {
|
||||
"model_type": ModelType.llm,
|
||||
"metadata": {},
|
||||
"provider_id": "test-provider",
|
||||
"provider_resource_id": "gpt-4",
|
||||
},
|
||||
}
|
||||
|
||||
# Check embedding model
|
||||
assert embedding_model.model_type == ModelType.embedding
|
||||
assert embedding_model.metadata == {"embedding_dimension": 1536, "context_length": 8192}
|
||||
assert embedding_model.provider_id == "test-provider"
|
||||
assert embedding_model.provider_resource_id == "text-embedding-3-small"
|
||||
_assert_models_match_expected(result, expected_models)
|
||||
|
||||
# Check LLM model
|
||||
assert llm_model.model_type == ModelType.llm
|
||||
assert llm_model.metadata == {} # No metadata for LLMs
|
||||
assert llm_model.provider_id == "test-provider"
|
||||
assert llm_model.provider_resource_id == "gpt-4"
|
||||
|
||||
class TestOpenAIMixinCustomModelConstruction:
    """Covers list_models when construct_model_from_identifier yields LLM, embedding and rerank models."""

    async def test_mixed_model_types_identification(self, mixin_with_custom_model_construction, mock_client_context):
        """Each listed model gets the type, metadata and provider ids expected for its identifier."""
        # One embedding, one rerank and one LLM identifier, in that order.
        listed_ids = ("text-embedding-3-small", "rerank-model-1", "gpt-4")
        remote_models = [MagicMock(id=model_id) for model_id in listed_ids]

        async def iter_remote_models():
            for remote_model in remote_models:
                yield remote_model

        client = MagicMock()
        client.models.list.return_value = iter_remote_models()

        expected_models = {
            "text-embedding-3-small": {
                "model_type": ModelType.embedding,
                "metadata": {"embedding_dimension": 1536, "context_length": 8192},
                "provider_id": "test-provider",
                "provider_resource_id": "text-embedding-3-small",
            },
            "rerank-model-1": {
                "model_type": ModelType.rerank,
                "metadata": {},
                "provider_id": "test-provider",
                "provider_resource_id": "rerank-model-1",
            },
            "gpt-4": {
                "model_type": ModelType.llm,
                "metadata": {},
                "provider_id": "test-provider",
                "provider_resource_id": "gpt-4",
            },
        }

        with mock_client_context(mixin_with_custom_model_construction, client):
            result = await mixin_with_custom_model_construction.list_models()

        assert result is not None
        assert len(result) == 3

        _assert_models_match_expected(result, expected_models)
|
||||
|
||||
|
||||
class TestOpenAIMixinAllowedModels:
|
||||
|
|
|
|||
|
|
@ -10,15 +10,16 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse
|
||||
from llama_stack.apis.vector_stores import VectorStore
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig
|
||||
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
|
||||
from llama_stack.providers.inline.vector_io.faiss.faiss import FaissIndex, FaissVectorIOAdapter
|
||||
from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig
|
||||
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter
|
||||
from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig
|
||||
from llama_stack.providers.remote.vector_io.pgvector.pgvector import PGVectorIndex, PGVectorVectorIOAdapter
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
from llama_stack.providers.utils.kvstore import register_kvstore_backends
|
||||
|
||||
EMBEDDING_DIMENSION = 768
|
||||
COLLECTION_PREFIX = "test_collection"
|
||||
|
|
@ -30,7 +31,7 @@ def vector_provider(request):
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def vector_db_id() -> str:
|
||||
def vector_store_id() -> str:
|
||||
return f"test-vector-db-{random.randint(1, 100)}"
|
||||
|
||||
|
||||
|
|
@ -112,8 +113,9 @@ async def unique_kvstore_config(tmp_path_factory):
|
|||
unique_id = f"test_kv_{np.random.randint(1e6)}"
|
||||
temp_dir = tmp_path_factory.getbasetemp()
|
||||
db_path = str(temp_dir / f"{unique_id}.db")
|
||||
|
||||
return SqliteKVStoreConfig(db_path=db_path)
|
||||
backend_name = f"kv_vector_{unique_id}"
|
||||
register_kvstore_backends({backend_name: SqliteKVStoreConfig(db_path=db_path)})
|
||||
return KVStoreReference(backend=backend_name, namespace=f"vector_io::{unique_id}")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
|
|
@ -138,18 +140,17 @@ async def sqlite_vec_vec_index(embedding_dimension, tmp_path_factory):
|
|||
async def sqlite_vec_adapter(sqlite_vec_db_path, unique_kvstore_config, mock_inference_api, embedding_dimension):
|
||||
config = SQLiteVectorIOConfig(
|
||||
db_path=sqlite_vec_db_path,
|
||||
kvstore=unique_kvstore_config,
|
||||
persistence=unique_kvstore_config,
|
||||
)
|
||||
adapter = SQLiteVecVectorIOAdapter(
|
||||
config=config,
|
||||
inference_api=mock_inference_api,
|
||||
files_api=None,
|
||||
models_api=None,
|
||||
)
|
||||
collection_id = f"sqlite_test_collection_{np.random.randint(1e6)}"
|
||||
await adapter.initialize()
|
||||
await adapter.register_vector_db(
|
||||
VectorDB(
|
||||
await adapter.register_vector_store(
|
||||
VectorStore(
|
||||
identifier=collection_id,
|
||||
provider_id="test_provider",
|
||||
embedding_model="test_model",
|
||||
|
|
@ -177,17 +178,16 @@ async def faiss_vec_index(embedding_dimension):
|
|||
@pytest.fixture
|
||||
async def faiss_vec_adapter(unique_kvstore_config, mock_inference_api, embedding_dimension):
|
||||
config = FaissVectorIOConfig(
|
||||
kvstore=unique_kvstore_config,
|
||||
persistence=unique_kvstore_config,
|
||||
)
|
||||
adapter = FaissVectorIOAdapter(
|
||||
config=config,
|
||||
inference_api=mock_inference_api,
|
||||
files_api=None,
|
||||
models_api=None,
|
||||
)
|
||||
await adapter.initialize()
|
||||
await adapter.register_vector_db(
|
||||
VectorDB(
|
||||
await adapter.register_vector_store(
|
||||
VectorStore(
|
||||
identifier=f"faiss_test_collection_{np.random.randint(1e6)}",
|
||||
provider_id="test_provider",
|
||||
embedding_model="test_model",
|
||||
|
|
@ -215,7 +215,7 @@ def mock_psycopg2_connection():
|
|||
async def pgvector_vec_index(embedding_dimension, mock_psycopg2_connection):
|
||||
connection, cursor = mock_psycopg2_connection
|
||||
|
||||
vector_db = VectorDB(
|
||||
vector_store = VectorStore(
|
||||
identifier="test-vector-db",
|
||||
embedding_model="test-model",
|
||||
embedding_dimension=embedding_dimension,
|
||||
|
|
@ -225,7 +225,7 @@ async def pgvector_vec_index(embedding_dimension, mock_psycopg2_connection):
|
|||
|
||||
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"):
|
||||
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.execute_values"):
|
||||
index = PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="COSINE")
|
||||
index = PGVectorIndex(vector_store, embedding_dimension, connection, distance_metric="COSINE")
|
||||
index._test_chunks = []
|
||||
original_add_chunks = index.add_chunks
|
||||
|
||||
|
|
@ -253,7 +253,7 @@ async def pgvector_vec_adapter(unique_kvstore_config, mock_inference_api, embedd
|
|||
db="test_db",
|
||||
user="test_user",
|
||||
password="test_password",
|
||||
kvstore=unique_kvstore_config,
|
||||
persistence=unique_kvstore_config,
|
||||
)
|
||||
|
||||
adapter = PGVectorVectorIOAdapter(config, mock_inference_api, None)
|
||||
|
|
@ -281,30 +281,30 @@ async def pgvector_vec_adapter(unique_kvstore_config, mock_inference_api, embedd
|
|||
await adapter.initialize()
|
||||
adapter.conn = mock_conn
|
||||
|
||||
async def mock_insert_chunks(vector_db_id, chunks, ttl_seconds=None):
|
||||
index = await adapter._get_and_cache_vector_db_index(vector_db_id)
|
||||
async def mock_insert_chunks(vector_store_id, chunks, ttl_seconds=None):
|
||||
index = await adapter._get_and_cache_vector_store_index(vector_store_id)
|
||||
if not index:
|
||||
raise ValueError(f"Vector DB {vector_db_id} not found")
|
||||
raise ValueError(f"Vector DB {vector_store_id} not found")
|
||||
await index.insert_chunks(chunks)
|
||||
|
||||
adapter.insert_chunks = mock_insert_chunks
|
||||
|
||||
async def mock_query_chunks(vector_db_id, query, params=None):
|
||||
index = await adapter._get_and_cache_vector_db_index(vector_db_id)
|
||||
async def mock_query_chunks(vector_store_id, query, params=None):
|
||||
index = await adapter._get_and_cache_vector_store_index(vector_store_id)
|
||||
if not index:
|
||||
raise ValueError(f"Vector DB {vector_db_id} not found")
|
||||
raise ValueError(f"Vector DB {vector_store_id} not found")
|
||||
return await index.query_chunks(query, params)
|
||||
|
||||
adapter.query_chunks = mock_query_chunks
|
||||
|
||||
test_vector_db = VectorDB(
|
||||
test_vector_store = VectorStore(
|
||||
identifier=f"pgvector_test_collection_{random.randint(1, 1_000_000)}",
|
||||
provider_id="test_provider",
|
||||
embedding_model="test_model",
|
||||
embedding_dimension=embedding_dimension,
|
||||
)
|
||||
await adapter.register_vector_db(test_vector_db)
|
||||
adapter.test_collection_id = test_vector_db.identifier
|
||||
await adapter.register_vector_store(test_vector_store)
|
||||
adapter.test_collection_id = test_vector_store.identifier
|
||||
|
||||
yield adapter
|
||||
await adapter.shutdown()
|
||||
|
|
|
|||
|
|
@ -11,9 +11,8 @@ import numpy as np
|
|||
import pytest
|
||||
|
||||
from llama_stack.apis.files import Files
|
||||
from llama_stack.apis.models import Models
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
|
||||
from llama_stack.apis.vector_stores import VectorStore
|
||||
from llama_stack.providers.datatypes import HealthStatus
|
||||
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
|
||||
from llama_stack.providers.inline.vector_io.faiss.faiss import (
|
||||
|
|
@ -44,8 +43,8 @@ def embedding_dimension():
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def vector_db_id():
|
||||
return "test_vector_db"
|
||||
def vector_store_id():
|
||||
return "test_vector_store"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -62,12 +61,12 @@ def sample_embeddings(embedding_dimension):
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_vector_db(vector_db_id, embedding_dimension) -> MagicMock:
|
||||
mock_vector_db = MagicMock(spec=VectorDB)
|
||||
mock_vector_db.embedding_model = "mock_embedding_model"
|
||||
mock_vector_db.identifier = vector_db_id
|
||||
mock_vector_db.embedding_dimension = embedding_dimension
|
||||
return mock_vector_db
|
||||
def mock_vector_store(vector_store_id, embedding_dimension) -> MagicMock:
|
||||
mock_vector_store = MagicMock(spec=VectorStore)
|
||||
mock_vector_store.embedding_model = "mock_embedding_model"
|
||||
mock_vector_store.identifier = vector_store_id
|
||||
mock_vector_store.embedding_dimension = embedding_dimension
|
||||
return mock_vector_store
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -76,12 +75,6 @@ def mock_files_api():
|
|||
return mock_api
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_models_api():
|
||||
mock_api = MagicMock(spec=Models)
|
||||
return mock_api
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def faiss_config():
|
||||
config = MagicMock(spec=FaissVectorIOConfig)
|
||||
|
|
@ -117,7 +110,7 @@ async def test_faiss_query_vector_returns_infinity_when_query_and_embedding_are_
|
|||
assert response.chunks[1] == sample_chunks[1]
|
||||
|
||||
|
||||
async def test_health_success(mock_models_api):
|
||||
async def test_health_success():
|
||||
"""Test that the health check returns OK status when faiss is working correctly."""
|
||||
# Create a fresh instance of FaissVectorIOAdapter for testing
|
||||
config = MagicMock()
|
||||
|
|
@ -126,9 +119,7 @@ async def test_health_success(mock_models_api):
|
|||
|
||||
with patch("llama_stack.providers.inline.vector_io.faiss.faiss.faiss.IndexFlatL2") as mock_index_flat:
|
||||
mock_index_flat.return_value = MagicMock()
|
||||
adapter = FaissVectorIOAdapter(
|
||||
config=config, inference_api=inference_api, models_api=mock_models_api, files_api=files_api
|
||||
)
|
||||
adapter = FaissVectorIOAdapter(config=config, inference_api=inference_api, files_api=files_api)
|
||||
|
||||
# Calling the health method directly
|
||||
response = await adapter.health()
|
||||
|
|
@ -142,7 +133,7 @@ async def test_health_success(mock_models_api):
|
|||
mock_index_flat.assert_called_once_with(128) # VECTOR_DIMENSION is 128
|
||||
|
||||
|
||||
async def test_health_failure(mock_models_api):
|
||||
async def test_health_failure():
|
||||
"""Test that the health check returns ERROR status when faiss encounters an error."""
|
||||
# Create a fresh instance of FaissVectorIOAdapter for testing
|
||||
config = MagicMock()
|
||||
|
|
@ -152,9 +143,7 @@ async def test_health_failure(mock_models_api):
|
|||
with patch("llama_stack.providers.inline.vector_io.faiss.faiss.faiss.IndexFlatL2") as mock_index_flat:
|
||||
mock_index_flat.side_effect = Exception("Test error")
|
||||
|
||||
adapter = FaissVectorIOAdapter(
|
||||
config=config, inference_api=inference_api, models_api=mock_models_api, files_api=files_api
|
||||
)
|
||||
adapter = FaissVectorIOAdapter(config=config, inference_api=inference_api, files_api=files_api)
|
||||
|
||||
# Calling the health method directly
|
||||
response = await adapter.health()
|
||||
|
|
|
|||
|
|
@ -6,14 +6,12 @@
|
|||
|
||||
import json
|
||||
import time
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.common.errors import VectorStoreNotFoundError
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import (
|
||||
Chunk,
|
||||
OpenAICreateVectorStoreFileBatchRequestWithExtraBody,
|
||||
|
|
@ -22,6 +20,7 @@ from llama_stack.apis.vector_io import (
|
|||
VectorStoreChunkingStrategyAuto,
|
||||
VectorStoreFileObject,
|
||||
)
|
||||
from llama_stack.apis.vector_stores import VectorStore
|
||||
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import VECTOR_DBS_PREFIX
|
||||
|
||||
# This test is a unit test for the inline VectorIO providers. This should only contain
|
||||
|
|
@ -72,7 +71,7 @@ async def test_chunk_id_conflict(vector_index, sample_chunks, embedding_dimensio
|
|||
|
||||
async def test_initialize_adapter_with_existing_kvstore(vector_io_adapter):
|
||||
key = f"{VECTOR_DBS_PREFIX}db1"
|
||||
dummy = VectorDB(
|
||||
dummy = VectorStore(
|
||||
identifier="foo_db", provider_id="test_provider", embedding_model="test_model", embedding_dimension=128
|
||||
)
|
||||
await vector_io_adapter.kvstore.set(key=key, value=json.dumps(dummy.model_dump()))
|
||||
|
|
@ -82,10 +81,10 @@ async def test_initialize_adapter_with_existing_kvstore(vector_io_adapter):
|
|||
|
||||
async def test_persistence_across_adapter_restarts(vector_io_adapter):
|
||||
await vector_io_adapter.initialize()
|
||||
dummy = VectorDB(
|
||||
dummy = VectorStore(
|
||||
identifier="foo_db", provider_id="test_provider", embedding_model="test_model", embedding_dimension=128
|
||||
)
|
||||
await vector_io_adapter.register_vector_db(dummy)
|
||||
await vector_io_adapter.register_vector_store(dummy)
|
||||
await vector_io_adapter.shutdown()
|
||||
|
||||
await vector_io_adapter.initialize()
|
||||
|
|
@ -93,15 +92,15 @@ async def test_persistence_across_adapter_restarts(vector_io_adapter):
|
|||
await vector_io_adapter.shutdown()
|
||||
|
||||
|
||||
async def test_register_and_unregister_vector_db(vector_io_adapter):
|
||||
async def test_register_and_unregister_vector_store(vector_io_adapter):
|
||||
unique_id = f"foo_db_{np.random.randint(1e6)}"
|
||||
dummy = VectorDB(
|
||||
dummy = VectorStore(
|
||||
identifier=unique_id, provider_id="test_provider", embedding_model="test_model", embedding_dimension=128
|
||||
)
|
||||
|
||||
await vector_io_adapter.register_vector_db(dummy)
|
||||
await vector_io_adapter.register_vector_store(dummy)
|
||||
assert dummy.identifier in vector_io_adapter.cache
|
||||
await vector_io_adapter.unregister_vector_db(dummy.identifier)
|
||||
await vector_io_adapter.unregister_vector_store(dummy.identifier)
|
||||
assert dummy.identifier not in vector_io_adapter.cache
|
||||
|
||||
|
||||
|
|
@ -122,7 +121,7 @@ async def test_insert_chunks_calls_underlying_index(vector_io_adapter):
|
|||
|
||||
|
||||
async def test_insert_chunks_missing_db_raises(vector_io_adapter):
|
||||
vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=None)
|
||||
vector_io_adapter._get_and_cache_vector_store_index = AsyncMock(return_value=None)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await vector_io_adapter.insert_chunks("db_not_exist", [])
|
||||
|
|
@ -171,7 +170,7 @@ async def test_query_chunks_calls_underlying_index_and_returns(vector_io_adapter
|
|||
|
||||
|
||||
async def test_query_chunks_missing_db_raises(vector_io_adapter):
|
||||
vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=None)
|
||||
vector_io_adapter._get_and_cache_vector_store_index = AsyncMock(return_value=None)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await vector_io_adapter.query_chunks("db_missing", "q", None)
|
||||
|
|
@ -183,7 +182,7 @@ async def test_save_openai_vector_store(vector_io_adapter):
|
|||
"id": store_id,
|
||||
"name": "Test Store",
|
||||
"description": "A test OpenAI vector store",
|
||||
"vector_db_id": "test_db",
|
||||
"vector_store_id": "test_db",
|
||||
"embedding_model": "test_model",
|
||||
}
|
||||
|
||||
|
|
@ -199,7 +198,7 @@ async def test_update_openai_vector_store(vector_io_adapter):
|
|||
"id": store_id,
|
||||
"name": "Test Store",
|
||||
"description": "A test OpenAI vector store",
|
||||
"vector_db_id": "test_db",
|
||||
"vector_store_id": "test_db",
|
||||
"embedding_model": "test_model",
|
||||
}
|
||||
|
||||
|
|
@ -215,7 +214,7 @@ async def test_delete_openai_vector_store(vector_io_adapter):
|
|||
"id": store_id,
|
||||
"name": "Test Store",
|
||||
"description": "A test OpenAI vector store",
|
||||
"vector_db_id": "test_db",
|
||||
"vector_store_id": "test_db",
|
||||
"embedding_model": "test_model",
|
||||
}
|
||||
|
||||
|
|
@ -230,7 +229,7 @@ async def test_load_openai_vector_stores(vector_io_adapter):
|
|||
"id": store_id,
|
||||
"name": "Test Store",
|
||||
"description": "A test OpenAI vector store",
|
||||
"vector_db_id": "test_db",
|
||||
"vector_store_id": "test_db",
|
||||
"embedding_model": "test_model",
|
||||
}
|
||||
|
||||
|
|
@ -996,101 +995,11 @@ async def test_max_concurrent_files_per_batch(vector_io_adapter):
|
|||
assert batch.file_counts.in_progress == 8
|
||||
|
||||
|
||||
async def test_get_default_embedding_model_success(vector_io_adapter):
|
||||
"""Test successful default embedding model detection."""
|
||||
# Mock models API with a default model
|
||||
mock_models_api = Mock()
|
||||
mock_models_api.list_models = AsyncMock(
|
||||
return_value=Mock(
|
||||
data=[
|
||||
Model(
|
||||
identifier="nomic-embed-text-v1.5",
|
||||
model_type=ModelType.embedding,
|
||||
provider_id="test-provider",
|
||||
metadata={
|
||||
"embedding_dimension": 768,
|
||||
"default_configured": True,
|
||||
},
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
vector_io_adapter.models_api = mock_models_api
|
||||
result = await vector_io_adapter._get_default_embedding_model_and_dimension()
|
||||
|
||||
assert result is not None
|
||||
model_id, dimension = result
|
||||
assert model_id == "nomic-embed-text-v1.5"
|
||||
assert dimension == 768
|
||||
|
||||
|
||||
async def test_get_default_embedding_model_multiple_defaults_error(vector_io_adapter):
|
||||
"""Test error when multiple models are marked as default."""
|
||||
mock_models_api = Mock()
|
||||
mock_models_api.list_models = AsyncMock(
|
||||
return_value=Mock(
|
||||
data=[
|
||||
Model(
|
||||
identifier="model1",
|
||||
model_type=ModelType.embedding,
|
||||
provider_id="test-provider",
|
||||
metadata={"embedding_dimension": 768, "default_configured": True},
|
||||
),
|
||||
Model(
|
||||
identifier="model2",
|
||||
model_type=ModelType.embedding,
|
||||
provider_id="test-provider",
|
||||
metadata={"embedding_dimension": 512, "default_configured": True},
|
||||
),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
vector_io_adapter.models_api = mock_models_api
|
||||
|
||||
with pytest.raises(ValueError, match="Multiple embedding models marked as default_configured=True"):
|
||||
await vector_io_adapter._get_default_embedding_model_and_dimension()
|
||||
|
||||
|
||||
async def test_openai_create_vector_store_uses_default_model(vector_io_adapter):
|
||||
"""Test that vector store creation uses default embedding model when none specified."""
|
||||
# Mock models API and dependencies
|
||||
mock_models_api = Mock()
|
||||
mock_models_api.list_models = AsyncMock(
|
||||
return_value=Mock(
|
||||
data=[
|
||||
Model(
|
||||
identifier="default-model",
|
||||
model_type=ModelType.embedding,
|
||||
provider_id="test-provider",
|
||||
metadata={"embedding_dimension": 512, "default_configured": True},
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
vector_io_adapter.models_api = mock_models_api
|
||||
vector_io_adapter.register_vector_db = AsyncMock()
|
||||
vector_io_adapter.__provider_id__ = "test-provider"
|
||||
|
||||
# Create vector store without specifying embedding model
|
||||
params = OpenAICreateVectorStoreRequestWithExtraBody(name="test-store")
|
||||
result = await vector_io_adapter.openai_create_vector_store(params)
|
||||
|
||||
# Verify the vector store was created with default model
|
||||
assert result.name == "test-store"
|
||||
vector_io_adapter.register_vector_db.assert_called_once()
|
||||
call_args = vector_io_adapter.register_vector_db.call_args[0][0]
|
||||
assert call_args.embedding_model == "default-model"
|
||||
assert call_args.embedding_dimension == 512
|
||||
|
||||
|
||||
async def test_embedding_config_from_metadata(vector_io_adapter):
|
||||
"""Test that embedding configuration is correctly extracted from metadata."""
|
||||
|
||||
# Mock register_vector_db to avoid actual registration
|
||||
vector_io_adapter.register_vector_db = AsyncMock()
|
||||
# Mock register_vector_store to avoid actual registration
|
||||
vector_io_adapter.register_vector_store = AsyncMock()
|
||||
# Set provider_id attribute for the adapter
|
||||
vector_io_adapter.__provider_id__ = "test_provider"
|
||||
|
||||
|
|
@ -1106,9 +1015,9 @@ async def test_embedding_config_from_metadata(vector_io_adapter):
|
|||
|
||||
await vector_io_adapter.openai_create_vector_store(params)
|
||||
|
||||
# Verify VectorDB was registered with correct embedding config from metadata
|
||||
vector_io_adapter.register_vector_db.assert_called_once()
|
||||
call_args = vector_io_adapter.register_vector_db.call_args[0][0]
|
||||
# Verify VectorStore was registered with correct embedding config from metadata
|
||||
vector_io_adapter.register_vector_store.assert_called_once()
|
||||
call_args = vector_io_adapter.register_vector_store.call_args[0][0]
|
||||
assert call_args.embedding_model == "test-embedding-model"
|
||||
assert call_args.embedding_dimension == 512
|
||||
|
||||
|
|
@ -1116,8 +1025,8 @@ async def test_embedding_config_from_metadata(vector_io_adapter):
|
|||
async def test_embedding_config_from_extra_body(vector_io_adapter):
|
||||
"""Test that embedding configuration is correctly extracted from extra_body when metadata is empty."""
|
||||
|
||||
# Mock register_vector_db to avoid actual registration
|
||||
vector_io_adapter.register_vector_db = AsyncMock()
|
||||
# Mock register_vector_store to avoid actual registration
|
||||
vector_io_adapter.register_vector_store = AsyncMock()
|
||||
# Set provider_id attribute for the adapter
|
||||
vector_io_adapter.__provider_id__ = "test_provider"
|
||||
|
||||
|
|
@ -1133,9 +1042,9 @@ async def test_embedding_config_from_extra_body(vector_io_adapter):
|
|||
|
||||
await vector_io_adapter.openai_create_vector_store(params)
|
||||
|
||||
# Verify VectorDB was registered with correct embedding config from extra_body
|
||||
vector_io_adapter.register_vector_db.assert_called_once()
|
||||
call_args = vector_io_adapter.register_vector_db.call_args[0][0]
|
||||
# Verify VectorStore was registered with correct embedding config from extra_body
|
||||
vector_io_adapter.register_vector_store.assert_called_once()
|
||||
call_args = vector_io_adapter.register_vector_store.call_args[0][0]
|
||||
assert call_args.embedding_model == "extra-body-model"
|
||||
assert call_args.embedding_dimension == 1024
|
||||
|
||||
|
|
@ -1143,8 +1052,8 @@ async def test_embedding_config_from_extra_body(vector_io_adapter):
|
|||
async def test_embedding_config_consistency_check_passes(vector_io_adapter):
|
||||
"""Test that consistent embedding config in both metadata and extra_body passes validation."""
|
||||
|
||||
# Mock register_vector_db to avoid actual registration
|
||||
vector_io_adapter.register_vector_db = AsyncMock()
|
||||
# Mock register_vector_store to avoid actual registration
|
||||
vector_io_adapter.register_vector_store = AsyncMock()
|
||||
# Set provider_id attribute for the adapter
|
||||
vector_io_adapter.__provider_id__ = "test_provider"
|
||||
|
||||
|
|
@ -1164,61 +1073,17 @@ async def test_embedding_config_consistency_check_passes(vector_io_adapter):
|
|||
await vector_io_adapter.openai_create_vector_store(params)
|
||||
|
||||
# Should not raise any error and use metadata config
|
||||
vector_io_adapter.register_vector_db.assert_called_once()
|
||||
call_args = vector_io_adapter.register_vector_db.call_args[0][0]
|
||||
vector_io_adapter.register_vector_store.assert_called_once()
|
||||
call_args = vector_io_adapter.register_vector_store.call_args[0][0]
|
||||
assert call_args.embedding_model == "consistent-model"
|
||||
assert call_args.embedding_dimension == 768
|
||||
|
||||
|
||||
async def test_embedding_config_inconsistency_errors(vector_io_adapter):
|
||||
"""Test that inconsistent embedding config between metadata and extra_body raises errors."""
|
||||
|
||||
# Mock register_vector_db to avoid actual registration
|
||||
vector_io_adapter.register_vector_db = AsyncMock()
|
||||
# Set provider_id attribute for the adapter
|
||||
vector_io_adapter.__provider_id__ = "test_provider"
|
||||
|
||||
# Test with inconsistent embedding model
|
||||
params = OpenAICreateVectorStoreRequestWithExtraBody(
|
||||
name="test_store",
|
||||
metadata={
|
||||
"embedding_model": "metadata-model",
|
||||
"embedding_dimension": "768",
|
||||
},
|
||||
**{
|
||||
"embedding_model": "extra-body-model",
|
||||
"embedding_dimension": 768,
|
||||
},
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Embedding model inconsistent between metadata"):
|
||||
await vector_io_adapter.openai_create_vector_store(params)
|
||||
|
||||
# Reset mock for second test
|
||||
vector_io_adapter.register_vector_db.reset_mock()
|
||||
|
||||
# Test with inconsistent embedding dimension
|
||||
params = OpenAICreateVectorStoreRequestWithExtraBody(
|
||||
name="test_store",
|
||||
metadata={
|
||||
"embedding_model": "same-model",
|
||||
"embedding_dimension": "512",
|
||||
},
|
||||
**{
|
||||
"embedding_model": "same-model",
|
||||
"embedding_dimension": 1024,
|
||||
},
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Embedding dimension inconsistent between metadata"):
|
||||
await vector_io_adapter.openai_create_vector_store(params)
|
||||
|
||||
|
||||
async def test_embedding_config_defaults_when_missing(vector_io_adapter):
|
||||
"""Test that embedding dimension defaults to 768 when not provided."""
|
||||
|
||||
# Mock register_vector_db to avoid actual registration
|
||||
vector_io_adapter.register_vector_db = AsyncMock()
|
||||
# Mock register_vector_store to avoid actual registration
|
||||
vector_io_adapter.register_vector_store = AsyncMock()
|
||||
# Set provider_id attribute for the adapter
|
||||
vector_io_adapter.__provider_id__ = "test_provider"
|
||||
|
||||
|
|
@ -1234,8 +1099,8 @@ async def test_embedding_config_defaults_when_missing(vector_io_adapter):
|
|||
await vector_io_adapter.openai_create_vector_store(params)
|
||||
|
||||
# Should default to 768 dimensions
|
||||
vector_io_adapter.register_vector_db.assert_called_once()
|
||||
call_args = vector_io_adapter.register_vector_db.call_args[0][0]
|
||||
vector_io_adapter.register_vector_store.assert_called_once()
|
||||
call_args = vector_io_adapter.register_vector_store.call_args[0][0]
|
||||
assert call_args.embedding_model == "model-without-dimension"
|
||||
assert call_args.embedding_dimension == 768
|
||||
|
||||
|
|
@ -1243,8 +1108,8 @@ async def test_embedding_config_defaults_when_missing(vector_io_adapter):
|
|||
async def test_embedding_config_required_model_missing(vector_io_adapter):
|
||||
"""Test that missing embedding model raises error."""
|
||||
|
||||
# Mock register_vector_db to avoid actual registration
|
||||
vector_io_adapter.register_vector_db = AsyncMock()
|
||||
# Mock register_vector_store to avoid actual registration
|
||||
vector_io_adapter.register_vector_store = AsyncMock()
|
||||
# Set provider_id attribute for the adapter
|
||||
vector_io_adapter.__provider_id__ = "test_provider"
|
||||
# Mock the default model lookup to return None (no default model available)
|
||||
|
|
@ -1253,5 +1118,5 @@ async def test_embedding_config_required_model_missing(vector_io_adapter):
|
|||
# Test with no embedding model provided
|
||||
params = OpenAICreateVectorStoreRequestWithExtraBody(name="test_store", metadata={})
|
||||
|
||||
with pytest.raises(ValueError, match="embedding_model is required in extra_body when creating a vector store"):
|
||||
with pytest.raises(ValueError, match="embedding_model is required"):
|
||||
await vector_io_adapter.openai_create_vector_store(params)
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ from llama_stack.providers.inline.tool_runtime.rag.memory import MemoryToolRunti
|
|||
|
||||
|
||||
class TestRagQuery:
|
||||
async def test_query_raises_on_empty_vector_db_ids(self):
|
||||
async def test_query_raises_on_empty_vector_store_ids(self):
|
||||
rag_tool = MemoryToolRuntimeImpl(
|
||||
config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock(), files_api=MagicMock()
|
||||
)
|
||||
|
|
@ -82,7 +82,7 @@ class TestRagQuery:
|
|||
with pytest.raises(ValueError):
|
||||
RAGQueryConfig(mode="wrong_mode")
|
||||
|
||||
async def test_query_adds_vector_db_id_to_chunk_metadata(self):
|
||||
async def test_query_adds_vector_store_id_to_chunk_metadata(self):
|
||||
rag_tool = MemoryToolRuntimeImpl(
|
||||
config=MagicMock(),
|
||||
vector_io_api=MagicMock(),
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ from llama_stack.apis.tools import RAGDocument
|
|||
from llama_stack.apis.vector_io import Chunk
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
URL,
|
||||
VectorDBWithIndex,
|
||||
VectorStoreWithIndex,
|
||||
_validate_embedding,
|
||||
content_from_doc,
|
||||
make_overlapped_chunks,
|
||||
|
|
@ -206,15 +206,15 @@ class TestVectorStore:
|
|||
assert str(excinfo.value.__cause__) == "Cannot convert to string"
|
||||
|
||||
|
||||
class TestVectorDBWithIndex:
|
||||
class TestVectorStoreWithIndex:
|
||||
async def test_insert_chunks_without_embeddings(self):
|
||||
mock_vector_db = MagicMock()
|
||||
mock_vector_db.embedding_model = "test-model without embeddings"
|
||||
mock_vector_store = MagicMock()
|
||||
mock_vector_store.embedding_model = "test-model without embeddings"
|
||||
mock_index = AsyncMock()
|
||||
mock_inference_api = AsyncMock()
|
||||
|
||||
vector_db_with_index = VectorDBWithIndex(
|
||||
vector_db=mock_vector_db, index=mock_index, inference_api=mock_inference_api
|
||||
vector_store_with_index = VectorStoreWithIndex(
|
||||
vector_store=mock_vector_store, index=mock_index, inference_api=mock_inference_api
|
||||
)
|
||||
|
||||
chunks = [
|
||||
|
|
@ -227,7 +227,7 @@ class TestVectorDBWithIndex:
|
|||
OpenAIEmbeddingData(embedding=[0.4, 0.5, 0.6], index=1),
|
||||
]
|
||||
|
||||
await vector_db_with_index.insert_chunks(chunks)
|
||||
await vector_store_with_index.insert_chunks(chunks)
|
||||
|
||||
# Verify openai_embeddings was called with correct params
|
||||
mock_inference_api.openai_embeddings.assert_called_once()
|
||||
|
|
@ -243,14 +243,14 @@ class TestVectorDBWithIndex:
|
|||
assert np.array_equal(args[1], np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32))
|
||||
|
||||
async def test_insert_chunks_with_valid_embeddings(self):
|
||||
mock_vector_db = MagicMock()
|
||||
mock_vector_db.embedding_model = "test-model with embeddings"
|
||||
mock_vector_db.embedding_dimension = 3
|
||||
mock_vector_store = MagicMock()
|
||||
mock_vector_store.embedding_model = "test-model with embeddings"
|
||||
mock_vector_store.embedding_dimension = 3
|
||||
mock_index = AsyncMock()
|
||||
mock_inference_api = AsyncMock()
|
||||
|
||||
vector_db_with_index = VectorDBWithIndex(
|
||||
vector_db=mock_vector_db, index=mock_index, inference_api=mock_inference_api
|
||||
vector_store_with_index = VectorStoreWithIndex(
|
||||
vector_store=mock_vector_store, index=mock_index, inference_api=mock_inference_api
|
||||
)
|
||||
|
||||
chunks = [
|
||||
|
|
@ -258,7 +258,7 @@ class TestVectorDBWithIndex:
|
|||
Chunk(content="Test 2", embedding=[0.4, 0.5, 0.6], metadata={}),
|
||||
]
|
||||
|
||||
await vector_db_with_index.insert_chunks(chunks)
|
||||
await vector_store_with_index.insert_chunks(chunks)
|
||||
|
||||
mock_inference_api.openai_embeddings.assert_not_called()
|
||||
mock_index.add_chunks.assert_called_once()
|
||||
|
|
@ -267,14 +267,14 @@ class TestVectorDBWithIndex:
|
|||
assert np.array_equal(args[1], np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32))
|
||||
|
||||
async def test_insert_chunks_with_invalid_embeddings(self):
|
||||
mock_vector_db = MagicMock()
|
||||
mock_vector_db.embedding_dimension = 3
|
||||
mock_vector_db.embedding_model = "test-model with invalid embeddings"
|
||||
mock_vector_store = MagicMock()
|
||||
mock_vector_store.embedding_dimension = 3
|
||||
mock_vector_store.embedding_model = "test-model with invalid embeddings"
|
||||
mock_index = AsyncMock()
|
||||
mock_inference_api = AsyncMock()
|
||||
|
||||
vector_db_with_index = VectorDBWithIndex(
|
||||
vector_db=mock_vector_db, index=mock_index, inference_api=mock_inference_api
|
||||
vector_store_with_index = VectorStoreWithIndex(
|
||||
vector_store=mock_vector_store, index=mock_index, inference_api=mock_inference_api
|
||||
)
|
||||
|
||||
# Verify Chunk raises ValueError for invalid embedding type
|
||||
|
|
@ -283,7 +283,7 @@ class TestVectorDBWithIndex:
|
|||
|
||||
# Verify Chunk raises ValueError for invalid embedding type in insert_chunks (i.e., Chunk errors before insert_chunks is called)
|
||||
with pytest.raises(ValueError, match="Input should be a valid list"):
|
||||
await vector_db_with_index.insert_chunks(
|
||||
await vector_store_with_index.insert_chunks(
|
||||
[
|
||||
Chunk(content="Test 1", embedding=None, metadata={}),
|
||||
Chunk(content="Test 2", embedding="invalid_type", metadata={}),
|
||||
|
|
@ -292,7 +292,7 @@ class TestVectorDBWithIndex:
|
|||
|
||||
# Verify Chunk raises ValueError for invalid embedding element type in insert_chunks (i.e., Chunk errors before insert_chunks is called)
|
||||
with pytest.raises(ValueError, match=" Input should be a valid number, unable to parse string as a number "):
|
||||
await vector_db_with_index.insert_chunks(
|
||||
await vector_store_with_index.insert_chunks(
|
||||
Chunk(content="Test 1", embedding=[0.1, "string", 0.3], metadata={})
|
||||
)
|
||||
|
||||
|
|
@ -300,20 +300,20 @@ class TestVectorDBWithIndex:
|
|||
Chunk(content="Test 1", embedding=[0.1, 0.2, 0.3, 0.4], metadata={}),
|
||||
]
|
||||
with pytest.raises(ValueError, match="has dimension 4, expected 3"):
|
||||
await vector_db_with_index.insert_chunks(chunks_wrong_dim)
|
||||
await vector_store_with_index.insert_chunks(chunks_wrong_dim)
|
||||
|
||||
mock_inference_api.openai_embeddings.assert_not_called()
|
||||
mock_index.add_chunks.assert_not_called()
|
||||
|
||||
async def test_insert_chunks_with_partially_precomputed_embeddings(self):
|
||||
mock_vector_db = MagicMock()
|
||||
mock_vector_db.embedding_model = "test-model with partial embeddings"
|
||||
mock_vector_db.embedding_dimension = 3
|
||||
mock_vector_store = MagicMock()
|
||||
mock_vector_store.embedding_model = "test-model with partial embeddings"
|
||||
mock_vector_store.embedding_dimension = 3
|
||||
mock_index = AsyncMock()
|
||||
mock_inference_api = AsyncMock()
|
||||
|
||||
vector_db_with_index = VectorDBWithIndex(
|
||||
vector_db=mock_vector_db, index=mock_index, inference_api=mock_inference_api
|
||||
vector_store_with_index = VectorStoreWithIndex(
|
||||
vector_store=mock_vector_store, index=mock_index, inference_api=mock_inference_api
|
||||
)
|
||||
|
||||
chunks = [
|
||||
|
|
@ -327,7 +327,7 @@ class TestVectorDBWithIndex:
|
|||
OpenAIEmbeddingData(embedding=[0.3, 0.3, 0.3], index=1),
|
||||
]
|
||||
|
||||
await vector_db_with_index.insert_chunks(chunks)
|
||||
await vector_store_with_index.insert_chunks(chunks)
|
||||
|
||||
# Verify openai_embeddings was called with correct params
|
||||
mock_inference_api.openai_embeddings.assert_called_once()
|
||||
|
|
|
|||
|
|
@ -8,24 +8,24 @@
|
|||
import pytest
|
||||
|
||||
from llama_stack.apis.inference import Model
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.core.datatypes import VectorDBWithOwner
|
||||
from llama_stack.apis.vector_stores import VectorStore
|
||||
from llama_stack.core.datatypes import VectorStoreWithOwner
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig
|
||||
from llama_stack.core.store.registry import (
|
||||
KEY_FORMAT,
|
||||
CachedDiskDistributionRegistry,
|
||||
DiskDistributionRegistry,
|
||||
)
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl, register_kvstore_backends
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_vector_db():
|
||||
return VectorDB(
|
||||
identifier="test_vector_db",
|
||||
def sample_vector_store():
|
||||
return VectorStore(
|
||||
identifier="test_vector_store",
|
||||
embedding_model="nomic-embed-text-v1.5",
|
||||
embedding_dimension=768,
|
||||
provider_resource_id="test_vector_db",
|
||||
provider_resource_id="test_vector_store",
|
||||
provider_id="test-provider",
|
||||
)
|
||||
|
||||
|
|
@ -45,17 +45,17 @@ async def test_registry_initialization(disk_dist_registry):
|
|||
assert result is None
|
||||
|
||||
|
||||
async def test_basic_registration(disk_dist_registry, sample_vector_db, sample_model):
|
||||
print(f"Registering {sample_vector_db}")
|
||||
await disk_dist_registry.register(sample_vector_db)
|
||||
async def test_basic_registration(disk_dist_registry, sample_vector_store, sample_model):
|
||||
print(f"Registering {sample_vector_store}")
|
||||
await disk_dist_registry.register(sample_vector_store)
|
||||
print(f"Registering {sample_model}")
|
||||
await disk_dist_registry.register(sample_model)
|
||||
print("Getting vector_db")
|
||||
result_vector_db = await disk_dist_registry.get("vector_db", "test_vector_db")
|
||||
assert result_vector_db is not None
|
||||
assert result_vector_db.identifier == sample_vector_db.identifier
|
||||
assert result_vector_db.embedding_model == sample_vector_db.embedding_model
|
||||
assert result_vector_db.provider_id == sample_vector_db.provider_id
|
||||
print("Getting vector_store")
|
||||
result_vector_store = await disk_dist_registry.get("vector_store", "test_vector_store")
|
||||
assert result_vector_store is not None
|
||||
assert result_vector_store.identifier == sample_vector_store.identifier
|
||||
assert result_vector_store.embedding_model == sample_vector_store.embedding_model
|
||||
assert result_vector_store.provider_id == sample_vector_store.provider_id
|
||||
|
||||
result_model = await disk_dist_registry.get("model", "test_model")
|
||||
assert result_model is not None
|
||||
|
|
@ -63,127 +63,137 @@ async def test_basic_registration(disk_dist_registry, sample_vector_db, sample_m
|
|||
assert result_model.provider_id == sample_model.provider_id
|
||||
|
||||
|
||||
async def test_cached_registry_initialization(sqlite_kvstore, sample_vector_db, sample_model):
|
||||
async def test_cached_registry_initialization(sqlite_kvstore, sample_vector_store, sample_model):
|
||||
# First populate the disk registry
|
||||
disk_registry = DiskDistributionRegistry(sqlite_kvstore)
|
||||
await disk_registry.initialize()
|
||||
await disk_registry.register(sample_vector_db)
|
||||
await disk_registry.register(sample_vector_store)
|
||||
await disk_registry.register(sample_model)
|
||||
|
||||
# Test cached version loads from disk
|
||||
db_path = sqlite_kvstore.db_path
|
||||
cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(SqliteKVStoreConfig(db_path=db_path)))
|
||||
backend_name = "kv_cached_test"
|
||||
register_kvstore_backends({backend_name: SqliteKVStoreConfig(db_path=db_path)})
|
||||
cached_registry = CachedDiskDistributionRegistry(
|
||||
await kvstore_impl(KVStoreReference(backend=backend_name, namespace="registry"))
|
||||
)
|
||||
await cached_registry.initialize()
|
||||
|
||||
result_vector_db = await cached_registry.get("vector_db", "test_vector_db")
|
||||
assert result_vector_db is not None
|
||||
assert result_vector_db.identifier == sample_vector_db.identifier
|
||||
assert result_vector_db.embedding_model == sample_vector_db.embedding_model
|
||||
assert result_vector_db.embedding_dimension == sample_vector_db.embedding_dimension
|
||||
assert result_vector_db.provider_id == sample_vector_db.provider_id
|
||||
result_vector_store = await cached_registry.get("vector_store", "test_vector_store")
|
||||
assert result_vector_store is not None
|
||||
assert result_vector_store.identifier == sample_vector_store.identifier
|
||||
assert result_vector_store.embedding_model == sample_vector_store.embedding_model
|
||||
assert result_vector_store.embedding_dimension == sample_vector_store.embedding_dimension
|
||||
assert result_vector_store.provider_id == sample_vector_store.provider_id
|
||||
|
||||
|
||||
async def test_cached_registry_updates(cached_disk_dist_registry):
|
||||
new_vector_db = VectorDB(
|
||||
identifier="test_vector_db_2",
|
||||
new_vector_store = VectorStore(
|
||||
identifier="test_vector_store_2",
|
||||
embedding_model="nomic-embed-text-v1.5",
|
||||
embedding_dimension=768,
|
||||
provider_resource_id="test_vector_db_2",
|
||||
provider_resource_id="test_vector_store_2",
|
||||
provider_id="baz",
|
||||
)
|
||||
await cached_disk_dist_registry.register(new_vector_db)
|
||||
await cached_disk_dist_registry.register(new_vector_store)
|
||||
|
||||
# Verify in cache
|
||||
result_vector_db = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2")
|
||||
assert result_vector_db is not None
|
||||
assert result_vector_db.identifier == new_vector_db.identifier
|
||||
assert result_vector_db.provider_id == new_vector_db.provider_id
|
||||
result_vector_store = await cached_disk_dist_registry.get("vector_store", "test_vector_store_2")
|
||||
assert result_vector_store is not None
|
||||
assert result_vector_store.identifier == new_vector_store.identifier
|
||||
assert result_vector_store.provider_id == new_vector_store.provider_id
|
||||
|
||||
# Verify persisted to disk
|
||||
db_path = cached_disk_dist_registry.kvstore.db_path
|
||||
new_registry = DiskDistributionRegistry(await kvstore_impl(SqliteKVStoreConfig(db_path=db_path)))
|
||||
backend_name = "kv_cached_new"
|
||||
register_kvstore_backends({backend_name: SqliteKVStoreConfig(db_path=db_path)})
|
||||
new_registry = DiskDistributionRegistry(
|
||||
await kvstore_impl(KVStoreReference(backend=backend_name, namespace="registry"))
|
||||
)
|
||||
await new_registry.initialize()
|
||||
result_vector_db = await new_registry.get("vector_db", "test_vector_db_2")
|
||||
assert result_vector_db is not None
|
||||
assert result_vector_db.identifier == new_vector_db.identifier
|
||||
assert result_vector_db.provider_id == new_vector_db.provider_id
|
||||
result_vector_store = await new_registry.get("vector_store", "test_vector_store_2")
|
||||
assert result_vector_store is not None
|
||||
assert result_vector_store.identifier == new_vector_store.identifier
|
||||
assert result_vector_store.provider_id == new_vector_store.provider_id
|
||||
|
||||
|
||||
async def test_duplicate_provider_registration(cached_disk_dist_registry):
|
||||
original_vector_db = VectorDB(
|
||||
identifier="test_vector_db_2",
|
||||
original_vector_store = VectorStore(
|
||||
identifier="test_vector_store_2",
|
||||
embedding_model="nomic-embed-text-v1.5",
|
||||
embedding_dimension=768,
|
||||
provider_resource_id="test_vector_db_2",
|
||||
provider_resource_id="test_vector_store_2",
|
||||
provider_id="baz",
|
||||
)
|
||||
assert await cached_disk_dist_registry.register(original_vector_db)
|
||||
assert await cached_disk_dist_registry.register(original_vector_store)
|
||||
|
||||
duplicate_vector_db = VectorDB(
|
||||
identifier="test_vector_db_2",
|
||||
duplicate_vector_store = VectorStore(
|
||||
identifier="test_vector_store_2",
|
||||
embedding_model="different-model",
|
||||
embedding_dimension=768,
|
||||
provider_resource_id="test_vector_db_2",
|
||||
provider_resource_id="test_vector_store_2",
|
||||
provider_id="baz", # Same provider_id
|
||||
)
|
||||
with pytest.raises(ValueError, match="Object of type 'vector_db' and identifier 'test_vector_db_2' already exists"):
|
||||
await cached_disk_dist_registry.register(duplicate_vector_db)
|
||||
with pytest.raises(
|
||||
ValueError, match="Object of type 'vector_store' and identifier 'test_vector_store_2' already exists"
|
||||
):
|
||||
await cached_disk_dist_registry.register(duplicate_vector_store)
|
||||
|
||||
result = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2")
|
||||
result = await cached_disk_dist_registry.get("vector_store", "test_vector_store_2")
|
||||
assert result is not None
|
||||
assert result.embedding_model == original_vector_db.embedding_model # Original values preserved
|
||||
assert result.embedding_model == original_vector_store.embedding_model # Original values preserved
|
||||
|
||||
|
||||
async def test_get_all_objects(cached_disk_dist_registry):
|
||||
# Create multiple test banks
|
||||
# Create multiple test banks
|
||||
test_vector_dbs = [
|
||||
VectorDB(
|
||||
identifier=f"test_vector_db_{i}",
|
||||
test_vector_stores = [
|
||||
VectorStore(
|
||||
identifier=f"test_vector_store_{i}",
|
||||
embedding_model="nomic-embed-text-v1.5",
|
||||
embedding_dimension=768,
|
||||
provider_resource_id=f"test_vector_db_{i}",
|
||||
provider_resource_id=f"test_vector_store_{i}",
|
||||
provider_id=f"provider_{i}",
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
|
||||
# Register all vector_dbs
|
||||
for vector_db in test_vector_dbs:
|
||||
await cached_disk_dist_registry.register(vector_db)
|
||||
# Register all vector_stores
|
||||
for vector_store in test_vector_stores:
|
||||
await cached_disk_dist_registry.register(vector_store)
|
||||
|
||||
# Test get_all retrieval
|
||||
all_results = await cached_disk_dist_registry.get_all()
|
||||
assert len(all_results) == 3
|
||||
|
||||
# Verify each vector_db was stored correctly
|
||||
for original_vector_db in test_vector_dbs:
|
||||
matching_vector_dbs = [v for v in all_results if v.identifier == original_vector_db.identifier]
|
||||
assert len(matching_vector_dbs) == 1
|
||||
stored_vector_db = matching_vector_dbs[0]
|
||||
assert stored_vector_db.embedding_model == original_vector_db.embedding_model
|
||||
assert stored_vector_db.provider_id == original_vector_db.provider_id
|
||||
assert stored_vector_db.embedding_dimension == original_vector_db.embedding_dimension
|
||||
# Verify each vector_store was stored correctly
|
||||
for original_vector_store in test_vector_stores:
|
||||
matching_vector_stores = [v for v in all_results if v.identifier == original_vector_store.identifier]
|
||||
assert len(matching_vector_stores) == 1
|
||||
stored_vector_store = matching_vector_stores[0]
|
||||
assert stored_vector_store.embedding_model == original_vector_store.embedding_model
|
||||
assert stored_vector_store.provider_id == original_vector_store.provider_id
|
||||
assert stored_vector_store.embedding_dimension == original_vector_store.embedding_dimension
|
||||
|
||||
|
||||
async def test_parse_registry_values_error_handling(sqlite_kvstore):
|
||||
valid_db = VectorDB(
|
||||
identifier="valid_vector_db",
|
||||
valid_db = VectorStore(
|
||||
identifier="valid_vector_store",
|
||||
embedding_model="nomic-embed-text-v1.5",
|
||||
embedding_dimension=768,
|
||||
provider_resource_id="valid_vector_db",
|
||||
provider_resource_id="valid_vector_store",
|
||||
provider_id="test-provider",
|
||||
)
|
||||
|
||||
await sqlite_kvstore.set(
|
||||
KEY_FORMAT.format(type="vector_db", identifier="valid_vector_db"), valid_db.model_dump_json()
|
||||
KEY_FORMAT.format(type="vector_store", identifier="valid_vector_store"), valid_db.model_dump_json()
|
||||
)
|
||||
|
||||
await sqlite_kvstore.set(KEY_FORMAT.format(type="vector_db", identifier="corrupted_json"), "{not valid json")
|
||||
await sqlite_kvstore.set(KEY_FORMAT.format(type="vector_store", identifier="corrupted_json"), "{not valid json")
|
||||
|
||||
await sqlite_kvstore.set(
|
||||
KEY_FORMAT.format(type="vector_db", identifier="missing_fields"),
|
||||
'{"type": "vector_db", "identifier": "missing_fields"}',
|
||||
KEY_FORMAT.format(type="vector_store", identifier="missing_fields"),
|
||||
'{"type": "vector_store", "identifier": "missing_fields"}',
|
||||
)
|
||||
|
||||
test_registry = DiskDistributionRegistry(sqlite_kvstore)
|
||||
|
|
@ -194,18 +204,18 @@ async def test_parse_registry_values_error_handling(sqlite_kvstore):
|
|||
|
||||
# Should have filtered out the invalid entries
|
||||
assert len(all_objects) == 1
|
||||
assert all_objects[0].identifier == "valid_vector_db"
|
||||
assert all_objects[0].identifier == "valid_vector_store"
|
||||
|
||||
# Check that the get method also handles errors correctly
|
||||
invalid_obj = await test_registry.get("vector_db", "corrupted_json")
|
||||
invalid_obj = await test_registry.get("vector_store", "corrupted_json")
|
||||
assert invalid_obj is None
|
||||
|
||||
invalid_obj = await test_registry.get("vector_db", "missing_fields")
|
||||
invalid_obj = await test_registry.get("vector_store", "missing_fields")
|
||||
assert invalid_obj is None
|
||||
|
||||
|
||||
async def test_cached_registry_error_handling(sqlite_kvstore):
|
||||
valid_db = VectorDB(
|
||||
valid_db = VectorStore(
|
||||
identifier="valid_cached_db",
|
||||
embedding_model="nomic-embed-text-v1.5",
|
||||
embedding_dimension=768,
|
||||
|
|
@ -214,12 +224,12 @@ async def test_cached_registry_error_handling(sqlite_kvstore):
|
|||
)
|
||||
|
||||
await sqlite_kvstore.set(
|
||||
KEY_FORMAT.format(type="vector_db", identifier="valid_cached_db"), valid_db.model_dump_json()
|
||||
KEY_FORMAT.format(type="vector_store", identifier="valid_cached_db"), valid_db.model_dump_json()
|
||||
)
|
||||
|
||||
await sqlite_kvstore.set(
|
||||
KEY_FORMAT.format(type="vector_db", identifier="invalid_cached_db"),
|
||||
'{"type": "vector_db", "identifier": "invalid_cached_db", "embedding_model": 12345}', # Should be string
|
||||
KEY_FORMAT.format(type="vector_store", identifier="invalid_cached_db"),
|
||||
'{"type": "vector_store", "identifier": "invalid_cached_db", "embedding_model": 12345}', # Should be string
|
||||
)
|
||||
|
||||
cached_registry = CachedDiskDistributionRegistry(sqlite_kvstore)
|
||||
|
|
@ -229,63 +239,65 @@ async def test_cached_registry_error_handling(sqlite_kvstore):
|
|||
assert len(all_objects) == 1
|
||||
assert all_objects[0].identifier == "valid_cached_db"
|
||||
|
||||
invalid_obj = await cached_registry.get("vector_db", "invalid_cached_db")
|
||||
invalid_obj = await cached_registry.get("vector_store", "invalid_cached_db")
|
||||
assert invalid_obj is None
|
||||
|
||||
|
||||
async def test_double_registration_identical_objects(disk_dist_registry):
|
||||
"""Test that registering identical objects succeeds (idempotent)."""
|
||||
vector_db = VectorDBWithOwner(
|
||||
identifier="test_vector_db",
|
||||
vector_store = VectorStoreWithOwner(
|
||||
identifier="test_vector_store",
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
embedding_dimension=384,
|
||||
provider_resource_id="test_vector_db",
|
||||
provider_resource_id="test_vector_store",
|
||||
provider_id="test-provider",
|
||||
)
|
||||
|
||||
# First registration should succeed
|
||||
result1 = await disk_dist_registry.register(vector_db)
|
||||
result1 = await disk_dist_registry.register(vector_store)
|
||||
assert result1 is True
|
||||
|
||||
# Second registration of identical object should also succeed (idempotent)
|
||||
result2 = await disk_dist_registry.register(vector_db)
|
||||
result2 = await disk_dist_registry.register(vector_store)
|
||||
assert result2 is True
|
||||
|
||||
# Verify object exists and is unchanged
|
||||
retrieved = await disk_dist_registry.get("vector_db", "test_vector_db")
|
||||
retrieved = await disk_dist_registry.get("vector_store", "test_vector_store")
|
||||
assert retrieved is not None
|
||||
assert retrieved.identifier == vector_db.identifier
|
||||
assert retrieved.embedding_model == vector_db.embedding_model
|
||||
assert retrieved.identifier == vector_store.identifier
|
||||
assert retrieved.embedding_model == vector_store.embedding_model
|
||||
|
||||
|
||||
async def test_double_registration_different_objects(disk_dist_registry):
|
||||
"""Test that registering different objects with same identifier fails."""
|
||||
vector_db1 = VectorDBWithOwner(
|
||||
identifier="test_vector_db",
|
||||
vector_store1 = VectorStoreWithOwner(
|
||||
identifier="test_vector_store",
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
embedding_dimension=384,
|
||||
provider_resource_id="test_vector_db",
|
||||
provider_resource_id="test_vector_store",
|
||||
provider_id="test-provider",
|
||||
)
|
||||
|
||||
vector_db2 = VectorDBWithOwner(
|
||||
identifier="test_vector_db", # Same identifier
|
||||
vector_store2 = VectorStoreWithOwner(
|
||||
identifier="test_vector_store", # Same identifier
|
||||
embedding_model="different-model", # Different embedding model
|
||||
embedding_dimension=384,
|
||||
provider_resource_id="test_vector_db",
|
||||
provider_resource_id="test_vector_store",
|
||||
provider_id="test-provider",
|
||||
)
|
||||
|
||||
# First registration should succeed
|
||||
result1 = await disk_dist_registry.register(vector_db1)
|
||||
result1 = await disk_dist_registry.register(vector_store1)
|
||||
assert result1 is True
|
||||
|
||||
# Second registration with different data should fail
|
||||
with pytest.raises(ValueError, match="Object of type 'vector_db' and identifier 'test_vector_db' already exists"):
|
||||
await disk_dist_registry.register(vector_db2)
|
||||
with pytest.raises(
|
||||
ValueError, match="Object of type 'vector_store' and identifier 'test_vector_store' already exists"
|
||||
):
|
||||
await disk_dist_registry.register(vector_store2)
|
||||
|
||||
# Verify original object is unchanged
|
||||
retrieved = await disk_dist_registry.get("vector_db", "test_vector_db")
|
||||
retrieved = await disk_dist_registry.get("vector_store", "test_vector_store")
|
||||
assert retrieved is not None
|
||||
assert retrieved.embedding_model == "all-MiniLM-L6-v2" # Original value
|
||||
|
||||
|
|
|
|||
|
|
@ -516,6 +516,82 @@ def test_get_attributes_from_claims():
|
|||
assert set(attributes["teams"]) == {"my-team", "group1", "group2"}
|
||||
assert attributes["namespaces"] == ["my-tenant"]
|
||||
|
||||
# Test nested claims with dot notation (e.g., Keycloak resource_access structure)
|
||||
claims = {
|
||||
"sub": "user123",
|
||||
"resource_access": {"llamastack": {"roles": ["inference_max", "admin"]}, "other-client": {"roles": ["viewer"]}},
|
||||
"realm_access": {"roles": ["offline_access", "uma_authorization"]},
|
||||
}
|
||||
attributes = get_attributes_from_claims(
|
||||
claims, {"resource_access.llamastack.roles": "roles", "realm_access.roles": "realm_roles"}
|
||||
)
|
||||
assert set(attributes["roles"]) == {"inference_max", "admin"}
|
||||
assert set(attributes["realm_roles"]) == {"offline_access", "uma_authorization"}
|
||||
|
||||
# Test that dot notation takes precedence over literal keys with dots
|
||||
claims = {
|
||||
"my.dotted.key": "literal-value",
|
||||
"my": {"dotted": {"key": "nested-value"}},
|
||||
}
|
||||
attributes = get_attributes_from_claims(claims, {"my.dotted.key": "test"})
|
||||
assert attributes["test"] == ["nested-value"]
|
||||
|
||||
# Test that literal key works when nested traversal doesn't exist
|
||||
claims = {
|
||||
"my.dotted.key": "literal-value",
|
||||
}
|
||||
attributes = get_attributes_from_claims(claims, {"my.dotted.key": "test"})
|
||||
assert attributes["test"] == ["literal-value"]
|
||||
|
||||
# Test missing nested paths are handled gracefully
|
||||
claims = {
|
||||
"sub": "user123",
|
||||
"resource_access": {"other-client": {"roles": ["viewer"]}},
|
||||
}
|
||||
attributes = get_attributes_from_claims(
|
||||
claims,
|
||||
{
|
||||
"resource_access.llamastack.roles": "roles", # Missing nested path
|
||||
"resource_access.missing.key": "missing_attr", # Missing nested path
|
||||
"completely.missing.path": "another_missing", # Completely missing
|
||||
"sub": "username", # Existing path
|
||||
},
|
||||
)
|
||||
# Only the existing claim should be in attributes
|
||||
assert attributes["username"] == ["user123"]
|
||||
assert "roles" not in attributes
|
||||
assert "missing_attr" not in attributes
|
||||
assert "another_missing" not in attributes
|
||||
|
||||
# Test mixture of flat and nested claims paths
|
||||
claims = {
|
||||
"sub": "user456",
|
||||
"flat_key": "flat-value",
|
||||
"scope": "read write admin",
|
||||
"resource_access": {"app1": {"roles": ["role1", "role2"]}, "app2": {"roles": ["role3"]}},
|
||||
"groups": ["group1", "group2"],
|
||||
"metadata": {"tenant": "tenant1", "region": "us-west"},
|
||||
}
|
||||
attributes = get_attributes_from_claims(
|
||||
claims,
|
||||
{
|
||||
"sub": "user_id", # Flat string
|
||||
"scope": "permissions", # Flat string with spaces
|
||||
"groups": "teams", # Flat list
|
||||
"resource_access.app1.roles": "app1_roles", # Nested list
|
||||
"resource_access.app2.roles": "app2_roles", # Nested list
|
||||
"metadata.tenant": "tenant", # Nested string
|
||||
"metadata.region": "region", # Nested string
|
||||
},
|
||||
)
|
||||
assert attributes["user_id"] == ["user456"]
|
||||
assert set(attributes["permissions"]) == {"read", "write", "admin"}
|
||||
assert set(attributes["teams"]) == {"group1", "group2"}
|
||||
assert set(attributes["app1_roles"]) == {"role1", "role2"}
|
||||
assert attributes["app2_roles"] == ["role3"]
|
||||
assert attributes["tenant"] == ["tenant1"]
|
||||
assert attributes["region"] == ["us-west"]
|
||||
|
||||
|
||||
# TODO: add more tests for oauth2 token provider
|
||||
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.testclient import TestClient
|
||||
|
|
@ -11,7 +13,8 @@ from starlette.middleware.base import BaseHTTPMiddleware
|
|||
|
||||
from llama_stack.core.datatypes import QuotaConfig, QuotaPeriod
|
||||
from llama_stack.core.server.quota import QuotaMiddleware
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig
|
||||
from llama_stack.providers.utils.kvstore import register_kvstore_backends
|
||||
|
||||
|
||||
class InjectClientIDMiddleware(BaseHTTPMiddleware):
|
||||
|
|
@ -29,8 +32,10 @@ class InjectClientIDMiddleware(BaseHTTPMiddleware):
|
|||
|
||||
|
||||
def build_quota_config(db_path) -> QuotaConfig:
|
||||
backend_name = f"kv_quota_{uuid4().hex}"
|
||||
register_kvstore_backends({backend_name: SqliteKVStoreConfig(db_path=str(db_path))})
|
||||
return QuotaConfig(
|
||||
kvstore=SqliteKVStoreConfig(db_path=str(db_path)),
|
||||
kvstore=KVStoreReference(backend=backend_name, namespace="quota"),
|
||||
anonymous_max_requests=1,
|
||||
authenticated_max_requests=2,
|
||||
period=QuotaPeriod.DAY,
|
||||
|
|
|
|||
|
|
@ -12,15 +12,22 @@ from unittest.mock import AsyncMock, MagicMock
|
|||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.core.datatypes import (
|
||||
Api,
|
||||
Provider,
|
||||
StackRunConfig,
|
||||
)
|
||||
from llama_stack.core.datatypes import Api, Provider, StackRunConfig
|
||||
from llama_stack.core.resolver import resolve_impls
|
||||
from llama_stack.core.routers.inference import InferenceRouter
|
||||
from llama_stack.core.routing_tables.models import ModelsRoutingTable
|
||||
from llama_stack.core.storage.datatypes import (
|
||||
InferenceStoreReference,
|
||||
KVStoreReference,
|
||||
ServerStoresConfig,
|
||||
SqliteKVStoreConfig,
|
||||
SqliteSqlStoreConfig,
|
||||
SqlStoreReference,
|
||||
StorageConfig,
|
||||
)
|
||||
from llama_stack.providers.datatypes import InlineProviderSpec, ProviderSpec
|
||||
from llama_stack.providers.utils.kvstore import register_kvstore_backends
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
|
||||
|
||||
|
||||
def add_protocol_methods(cls: type, protocol: type[Protocol]) -> None:
|
||||
|
|
@ -65,6 +72,35 @@ class SampleImpl:
|
|||
pass
|
||||
|
||||
|
||||
def make_run_config(**overrides) -> StackRunConfig:
|
||||
storage = overrides.pop(
|
||||
"storage",
|
||||
StorageConfig(
|
||||
backends={
|
||||
"kv_default": SqliteKVStoreConfig(db_path=":memory:"),
|
||||
"sql_default": SqliteSqlStoreConfig(db_path=":memory:"),
|
||||
},
|
||||
stores=ServerStoresConfig(
|
||||
metadata=KVStoreReference(backend="kv_default", namespace="registry"),
|
||||
inference=InferenceStoreReference(backend="sql_default", table_name="inference_store"),
|
||||
conversations=SqlStoreReference(backend="sql_default", table_name="conversations"),
|
||||
),
|
||||
),
|
||||
)
|
||||
register_kvstore_backends({name: cfg for name, cfg in storage.backends.items() if cfg.type.value.startswith("kv_")})
|
||||
register_sqlstore_backends(
|
||||
{name: cfg for name, cfg in storage.backends.items() if cfg.type.value.startswith("sql_")}
|
||||
)
|
||||
defaults = dict(
|
||||
image_name="test_image",
|
||||
apis=[],
|
||||
providers={},
|
||||
storage=storage,
|
||||
)
|
||||
defaults.update(overrides)
|
||||
return StackRunConfig(**defaults)
|
||||
|
||||
|
||||
async def test_resolve_impls_basic():
|
||||
# Create a real provider spec
|
||||
provider_spec = InlineProviderSpec(
|
||||
|
|
@ -78,7 +114,7 @@ async def test_resolve_impls_basic():
|
|||
# Create provider registry with our provider
|
||||
provider_registry = {Api.inference: {provider_spec.provider_type: provider_spec}}
|
||||
|
||||
run_config = StackRunConfig(
|
||||
run_config = make_run_config(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@ class TestTranslateException:
|
|||
self.identifier = identifier
|
||||
self.owner = owner
|
||||
|
||||
resource = MockResource("vector_db", "test-db")
|
||||
resource = MockResource("vector_store", "test-db")
|
||||
|
||||
exc = AccessDeniedError("create", resource, user)
|
||||
result = translate_exception(exc)
|
||||
|
|
@ -49,7 +49,7 @@ class TestTranslateException:
|
|||
assert isinstance(result, HTTPException)
|
||||
assert result.status_code == 403
|
||||
assert "test-user" in result.detail
|
||||
assert "vector_db::test-db" in result.detail
|
||||
assert "vector_store::test-db" in result.detail
|
||||
assert "create" in result.detail
|
||||
assert "roles=['user']" in result.detail
|
||||
assert "teams=['dev']" in result.detail
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import time
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
import pytest
|
||||
|
||||
|
|
@ -16,8 +15,16 @@ from llama_stack.apis.inference import (
|
|||
OpenAIUserMessageParam,
|
||||
Order,
|
||||
)
|
||||
from llama_stack.core.storage.datatypes import InferenceStoreReference, SqliteSqlStoreConfig
|
||||
from llama_stack.providers.utils.inference.inference_store import InferenceStore
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_backends(tmp_path):
|
||||
"""Register SQL store backends for testing."""
|
||||
db_path = str(tmp_path / "test.db")
|
||||
register_sqlstore_backends({"sql_default": SqliteSqlStoreConfig(db_path=db_path)})
|
||||
|
||||
|
||||
def create_test_chat_completion(
|
||||
|
|
@ -44,167 +51,162 @@ def create_test_chat_completion(
|
|||
|
||||
async def test_inference_store_pagination_basic():
|
||||
"""Test basic pagination functionality."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||
await store.initialize()
|
||||
reference = InferenceStoreReference(backend="sql_default", table_name="chat_completions")
|
||||
store = InferenceStore(reference, policy=[])
|
||||
await store.initialize()
|
||||
|
||||
# Create test data with different timestamps
|
||||
base_time = int(time.time())
|
||||
test_data = [
|
||||
("zebra-task", base_time + 1),
|
||||
("apple-job", base_time + 2),
|
||||
("moon-work", base_time + 3),
|
||||
("banana-run", base_time + 4),
|
||||
("car-exec", base_time + 5),
|
||||
]
|
||||
# Create test data with different timestamps
|
||||
base_time = int(time.time())
|
||||
test_data = [
|
||||
("zebra-task", base_time + 1),
|
||||
("apple-job", base_time + 2),
|
||||
("moon-work", base_time + 3),
|
||||
("banana-run", base_time + 4),
|
||||
("car-exec", base_time + 5),
|
||||
]
|
||||
|
||||
# Store test chat completions
|
||||
for completion_id, timestamp in test_data:
|
||||
completion = create_test_chat_completion(completion_id, timestamp)
|
||||
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
|
||||
await store.store_chat_completion(completion, input_messages)
|
||||
# Store test chat completions
|
||||
for completion_id, timestamp in test_data:
|
||||
completion = create_test_chat_completion(completion_id, timestamp)
|
||||
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
|
||||
await store.store_chat_completion(completion, input_messages)
|
||||
|
||||
# Wait for all queued writes to complete
|
||||
await store.flush()
|
||||
# Wait for all queued writes to complete
|
||||
await store.flush()
|
||||
|
||||
# Test 1: First page with limit=2, descending order (default)
|
||||
result = await store.list_chat_completions(limit=2, order=Order.desc)
|
||||
assert len(result.data) == 2
|
||||
assert result.data[0].id == "car-exec" # Most recent first
|
||||
assert result.data[1].id == "banana-run"
|
||||
assert result.has_more is True
|
||||
assert result.last_id == "banana-run"
|
||||
# Test 1: First page with limit=2, descending order (default)
|
||||
result = await store.list_chat_completions(limit=2, order=Order.desc)
|
||||
assert len(result.data) == 2
|
||||
assert result.data[0].id == "car-exec" # Most recent first
|
||||
assert result.data[1].id == "banana-run"
|
||||
assert result.has_more is True
|
||||
assert result.last_id == "banana-run"
|
||||
|
||||
# Test 2: Second page using 'after' parameter
|
||||
result2 = await store.list_chat_completions(after="banana-run", limit=2, order=Order.desc)
|
||||
assert len(result2.data) == 2
|
||||
assert result2.data[0].id == "moon-work"
|
||||
assert result2.data[1].id == "apple-job"
|
||||
assert result2.has_more is True
|
||||
# Test 2: Second page using 'after' parameter
|
||||
result2 = await store.list_chat_completions(after="banana-run", limit=2, order=Order.desc)
|
||||
assert len(result2.data) == 2
|
||||
assert result2.data[0].id == "moon-work"
|
||||
assert result2.data[1].id == "apple-job"
|
||||
assert result2.has_more is True
|
||||
|
||||
# Test 3: Final page
|
||||
result3 = await store.list_chat_completions(after="apple-job", limit=2, order=Order.desc)
|
||||
assert len(result3.data) == 1
|
||||
assert result3.data[0].id == "zebra-task"
|
||||
assert result3.has_more is False
|
||||
# Test 3: Final page
|
||||
result3 = await store.list_chat_completions(after="apple-job", limit=2, order=Order.desc)
|
||||
assert len(result3.data) == 1
|
||||
assert result3.data[0].id == "zebra-task"
|
||||
assert result3.has_more is False
|
||||
|
||||
|
||||
async def test_inference_store_pagination_ascending():
|
||||
"""Test pagination with ascending order."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||
await store.initialize()
|
||||
reference = InferenceStoreReference(backend="sql_default", table_name="chat_completions")
|
||||
store = InferenceStore(reference, policy=[])
|
||||
await store.initialize()
|
||||
|
||||
# Create test data
|
||||
base_time = int(time.time())
|
||||
test_data = [
|
||||
("delta-item", base_time + 1),
|
||||
("charlie-task", base_time + 2),
|
||||
("alpha-work", base_time + 3),
|
||||
]
|
||||
# Create test data
|
||||
base_time = int(time.time())
|
||||
test_data = [
|
||||
("delta-item", base_time + 1),
|
||||
("charlie-task", base_time + 2),
|
||||
("alpha-work", base_time + 3),
|
||||
]
|
||||
|
||||
# Store test chat completions
|
||||
for completion_id, timestamp in test_data:
|
||||
completion = create_test_chat_completion(completion_id, timestamp)
|
||||
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
|
||||
await store.store_chat_completion(completion, input_messages)
|
||||
# Store test chat completions
|
||||
for completion_id, timestamp in test_data:
|
||||
completion = create_test_chat_completion(completion_id, timestamp)
|
||||
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
|
||||
await store.store_chat_completion(completion, input_messages)
|
||||
|
||||
# Wait for all queued writes to complete
|
||||
await store.flush()
|
||||
# Wait for all queued writes to complete
|
||||
await store.flush()
|
||||
|
||||
# Test ascending order pagination
|
||||
result = await store.list_chat_completions(limit=1, order=Order.asc)
|
||||
assert len(result.data) == 1
|
||||
assert result.data[0].id == "delta-item" # Oldest first
|
||||
assert result.has_more is True
|
||||
# Test ascending order pagination
|
||||
result = await store.list_chat_completions(limit=1, order=Order.asc)
|
||||
assert len(result.data) == 1
|
||||
assert result.data[0].id == "delta-item" # Oldest first
|
||||
assert result.has_more is True
|
||||
|
||||
# Second page with ascending order
|
||||
result2 = await store.list_chat_completions(after="delta-item", limit=1, order=Order.asc)
|
||||
assert len(result2.data) == 1
|
||||
assert result2.data[0].id == "charlie-task"
|
||||
assert result2.has_more is True
|
||||
# Second page with ascending order
|
||||
result2 = await store.list_chat_completions(after="delta-item", limit=1, order=Order.asc)
|
||||
assert len(result2.data) == 1
|
||||
assert result2.data[0].id == "charlie-task"
|
||||
assert result2.has_more is True
|
||||
|
||||
|
||||
async def test_inference_store_pagination_with_model_filter():
|
||||
"""Test pagination combined with model filtering."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||
await store.initialize()
|
||||
reference = InferenceStoreReference(backend="sql_default", table_name="chat_completions")
|
||||
store = InferenceStore(reference, policy=[])
|
||||
await store.initialize()
|
||||
|
||||
# Create test data with different models
|
||||
base_time = int(time.time())
|
||||
test_data = [
|
||||
("xyz-task", base_time + 1, "model-a"),
|
||||
("def-work", base_time + 2, "model-b"),
|
||||
("pqr-job", base_time + 3, "model-a"),
|
||||
("abc-run", base_time + 4, "model-b"),
|
||||
]
|
||||
# Create test data with different models
|
||||
base_time = int(time.time())
|
||||
test_data = [
|
||||
("xyz-task", base_time + 1, "model-a"),
|
||||
("def-work", base_time + 2, "model-b"),
|
||||
("pqr-job", base_time + 3, "model-a"),
|
||||
("abc-run", base_time + 4, "model-b"),
|
||||
]
|
||||
|
||||
# Store test chat completions
|
||||
for completion_id, timestamp, model in test_data:
|
||||
completion = create_test_chat_completion(completion_id, timestamp, model)
|
||||
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
|
||||
await store.store_chat_completion(completion, input_messages)
|
||||
# Store test chat completions
|
||||
for completion_id, timestamp, model in test_data:
|
||||
completion = create_test_chat_completion(completion_id, timestamp, model)
|
||||
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
|
||||
await store.store_chat_completion(completion, input_messages)
|
||||
|
||||
# Wait for all queued writes to complete
|
||||
await store.flush()
|
||||
# Wait for all queued writes to complete
|
||||
await store.flush()
|
||||
|
||||
# Test pagination with model filter
|
||||
result = await store.list_chat_completions(limit=1, model="model-a", order=Order.desc)
|
||||
assert len(result.data) == 1
|
||||
assert result.data[0].id == "pqr-job" # Most recent model-a
|
||||
assert result.data[0].model == "model-a"
|
||||
assert result.has_more is True
|
||||
# Test pagination with model filter
|
||||
result = await store.list_chat_completions(limit=1, model="model-a", order=Order.desc)
|
||||
assert len(result.data) == 1
|
||||
assert result.data[0].id == "pqr-job" # Most recent model-a
|
||||
assert result.data[0].model == "model-a"
|
||||
assert result.has_more is True
|
||||
|
||||
# Second page with model filter
|
||||
result2 = await store.list_chat_completions(after="pqr-job", limit=1, model="model-a", order=Order.desc)
|
||||
assert len(result2.data) == 1
|
||||
assert result2.data[0].id == "xyz-task"
|
||||
assert result2.data[0].model == "model-a"
|
||||
assert result2.has_more is False
|
||||
# Second page with model filter
|
||||
result2 = await store.list_chat_completions(after="pqr-job", limit=1, model="model-a", order=Order.desc)
|
||||
assert len(result2.data) == 1
|
||||
assert result2.data[0].id == "xyz-task"
|
||||
assert result2.data[0].model == "model-a"
|
||||
assert result2.has_more is False
|
||||
|
||||
|
||||
async def test_inference_store_pagination_invalid_after():
|
||||
"""Test error handling for invalid 'after' parameter."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||
await store.initialize()
|
||||
reference = InferenceStoreReference(backend="sql_default", table_name="chat_completions")
|
||||
store = InferenceStore(reference, policy=[])
|
||||
await store.initialize()
|
||||
|
||||
# Try to paginate with non-existent ID
|
||||
with pytest.raises(ValueError, match="Record with id='non-existent' not found in table 'chat_completions'"):
|
||||
await store.list_chat_completions(after="non-existent", limit=2)
|
||||
# Try to paginate with non-existent ID
|
||||
with pytest.raises(ValueError, match="Record with id='non-existent' not found in table 'chat_completions'"):
|
||||
await store.list_chat_completions(after="non-existent", limit=2)
|
||||
|
||||
|
||||
async def test_inference_store_pagination_no_limit():
|
||||
"""Test pagination behavior when no limit is specified."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||
await store.initialize()
|
||||
reference = InferenceStoreReference(backend="sql_default", table_name="chat_completions")
|
||||
store = InferenceStore(reference, policy=[])
|
||||
await store.initialize()
|
||||
|
||||
# Create test data
|
||||
base_time = int(time.time())
|
||||
test_data = [
|
||||
("omega-first", base_time + 1),
|
||||
("beta-second", base_time + 2),
|
||||
]
|
||||
# Create test data
|
||||
base_time = int(time.time())
|
||||
test_data = [
|
||||
("omega-first", base_time + 1),
|
||||
("beta-second", base_time + 2),
|
||||
]
|
||||
|
||||
# Store test chat completions
|
||||
for completion_id, timestamp in test_data:
|
||||
completion = create_test_chat_completion(completion_id, timestamp)
|
||||
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
|
||||
await store.store_chat_completion(completion, input_messages)
|
||||
# Store test chat completions
|
||||
for completion_id, timestamp in test_data:
|
||||
completion = create_test_chat_completion(completion_id, timestamp)
|
||||
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
|
||||
await store.store_chat_completion(completion, input_messages)
|
||||
|
||||
# Wait for all queued writes to complete
|
||||
await store.flush()
|
||||
# Wait for all queued writes to complete
|
||||
await store.flush()
|
||||
|
||||
# Test without limit
|
||||
result = await store.list_chat_completions(order=Order.desc)
|
||||
assert len(result.data) == 2
|
||||
assert result.data[0].id == "beta-second" # Most recent first
|
||||
assert result.data[1].id == "omega-first"
|
||||
assert result.has_more is False
|
||||
# Test without limit
|
||||
result = await store.list_chat_completions(order=Order.desc)
|
||||
assert len(result.data) == 2
|
||||
assert result.data[0].id == "beta-second" # Most recent first
|
||||
assert result.data[1].id == "omega-first"
|
||||
assert result.has_more is False
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@
|
|||
|
||||
import time
|
||||
from tempfile import TemporaryDirectory
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
|
|
@ -15,8 +16,18 @@ from llama_stack.apis.agents.openai_responses import (
|
|||
OpenAIResponseObject,
|
||||
)
|
||||
from llama_stack.apis.inference import OpenAIMessageParam, OpenAIUserMessageParam
|
||||
from llama_stack.core.storage.datatypes import ResponsesStoreReference, SqliteSqlStoreConfig
|
||||
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
|
||||
|
||||
|
||||
def build_store(db_path: str, policy: list | None = None) -> ResponsesStore:
|
||||
backend_name = f"sql_responses_{uuid4().hex}"
|
||||
register_sqlstore_backends({backend_name: SqliteSqlStoreConfig(db_path=db_path)})
|
||||
return ResponsesStore(
|
||||
ResponsesStoreReference(backend=backend_name, table_name="responses"),
|
||||
policy=policy or [],
|
||||
)
|
||||
|
||||
|
||||
def create_test_response_object(
|
||||
|
|
@ -54,7 +65,7 @@ async def test_responses_store_pagination_basic():
|
|||
"""Test basic pagination functionality for responses store."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||
store = build_store(db_path)
|
||||
await store.initialize()
|
||||
|
||||
# Create test data with different timestamps
|
||||
|
|
@ -103,7 +114,7 @@ async def test_responses_store_pagination_ascending():
|
|||
"""Test pagination with ascending order."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||
store = build_store(db_path)
|
||||
await store.initialize()
|
||||
|
||||
# Create test data
|
||||
|
|
@ -141,7 +152,7 @@ async def test_responses_store_pagination_with_model_filter():
|
|||
"""Test pagination combined with model filtering."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||
store = build_store(db_path)
|
||||
await store.initialize()
|
||||
|
||||
# Create test data with different models
|
||||
|
|
@ -182,7 +193,7 @@ async def test_responses_store_pagination_invalid_after():
|
|||
"""Test error handling for invalid 'after' parameter."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||
store = build_store(db_path)
|
||||
await store.initialize()
|
||||
|
||||
# Try to paginate with non-existent ID
|
||||
|
|
@ -194,7 +205,7 @@ async def test_responses_store_pagination_no_limit():
|
|||
"""Test pagination behavior when no limit is specified."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||
store = build_store(db_path)
|
||||
await store.initialize()
|
||||
|
||||
# Create test data
|
||||
|
|
@ -226,7 +237,7 @@ async def test_responses_store_get_response_object():
|
|||
"""Test retrieving a single response object."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||
store = build_store(db_path)
|
||||
await store.initialize()
|
||||
|
||||
# Store a test response
|
||||
|
|
@ -254,7 +265,7 @@ async def test_responses_store_input_items_pagination():
|
|||
"""Test pagination functionality for input items."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||
store = build_store(db_path)
|
||||
await store.initialize()
|
||||
|
||||
# Store a test response with many inputs with explicit IDs
|
||||
|
|
@ -335,7 +346,7 @@ async def test_responses_store_input_items_before_pagination():
|
|||
"""Test before pagination functionality for input items."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||
store = build_store(db_path)
|
||||
await store.initialize()
|
||||
|
||||
# Store a test response with many inputs with explicit IDs
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue