mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-12 12:06:04 +00:00
simplified some, walked back some decisions
This commit is contained in:
parent
af7472cdb0
commit
636764c2a1
90 changed files with 887 additions and 570 deletions
|
|
@ -12,9 +12,15 @@ import pytest
|
|||
|
||||
from llama_stack.core.access_control.access_control import default_policy
|
||||
from llama_stack.core.datatypes import User
|
||||
from llama_stack.core.storage.datatypes import SqlStoreReference
|
||||
from llama_stack.providers.utils.sqlstore.api import ColumnType
|
||||
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig, SqliteSqlStoreConfig, sqlstore_impl
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import (
|
||||
PostgresSqlStoreConfig,
|
||||
SqliteSqlStoreConfig,
|
||||
register_sqlstore_backends,
|
||||
sqlstore_impl,
|
||||
)
|
||||
|
||||
|
||||
def get_postgres_config():
|
||||
|
|
@ -55,8 +61,9 @@ def authorized_store(backend_config):
|
|||
config_func = backend_config
|
||||
|
||||
config = config_func()
|
||||
|
||||
base_sqlstore = sqlstore_impl(config)
|
||||
backend_name = f"sql_{type(config).__name__.lower()}"
|
||||
register_sqlstore_backends({backend_name: config})
|
||||
base_sqlstore = sqlstore_impl(SqlStoreReference(backend=backend_name, table_name="authorized_store"))
|
||||
authorized_store = AuthorizedSqlStore(base_sqlstore, default_policy())
|
||||
|
||||
yield authorized_store
|
||||
|
|
|
|||
|
|
@ -8,7 +8,9 @@ import yaml
|
|||
|
||||
from llama_stack.core.datatypes import StackRunConfig
|
||||
from llama_stack.core.storage.datatypes import (
|
||||
PostgresKVStoreConfig,
|
||||
PostgresSqlStoreConfig,
|
||||
SqliteKVStoreConfig,
|
||||
SqliteSqlStoreConfig,
|
||||
)
|
||||
|
||||
|
|
@ -20,21 +22,26 @@ def test_starter_distribution_config_loads_and_resolves():
|
|||
|
||||
config = StackRunConfig(**config_dict)
|
||||
|
||||
# Config should have storage with default backend
|
||||
# Config should have named backends and explicit store references
|
||||
assert config.storage is not None
|
||||
assert "default" in config.storage.backends
|
||||
assert isinstance(config.storage.backends["default"], SqliteSqlStoreConfig)
|
||||
assert "kv_default" in config.storage.backends
|
||||
assert "sql_default" in config.storage.backends
|
||||
assert isinstance(config.storage.backends["kv_default"], SqliteKVStoreConfig)
|
||||
assert isinstance(config.storage.backends["sql_default"], SqliteSqlStoreConfig)
|
||||
|
||||
# Stores should reference the default backend
|
||||
assert config.storage.metadata is not None
|
||||
assert config.storage.metadata.backend == "default"
|
||||
assert config.storage.metadata.namespace is not None
|
||||
assert config.metadata_store is not None
|
||||
assert config.metadata_store.backend == "kv_default"
|
||||
assert config.metadata_store.namespace == "registry"
|
||||
|
||||
assert config.storage.inference is not None
|
||||
assert config.storage.inference.backend == "default"
|
||||
assert config.storage.inference.table_name is not None
|
||||
assert config.storage.inference.max_write_queue_size > 0
|
||||
assert config.storage.inference.num_writers > 0
|
||||
assert config.inference_store is not None
|
||||
assert config.inference_store.backend == "sql_default"
|
||||
assert config.inference_store.table_name == "inference_store"
|
||||
assert config.inference_store.max_write_queue_size > 0
|
||||
assert config.inference_store.num_writers > 0
|
||||
|
||||
assert config.conversations_store is not None
|
||||
assert config.conversations_store.backend == "sql_default"
|
||||
assert config.conversations_store.table_name == "openai_conversations"
|
||||
|
||||
|
||||
def test_postgres_demo_distribution_config_loads():
|
||||
|
|
@ -46,17 +53,15 @@ def test_postgres_demo_distribution_config_loads():
|
|||
|
||||
# Should have postgres backend
|
||||
assert config.storage is not None
|
||||
assert "default" in config.storage.backends
|
||||
assert isinstance(config.storage.backends["default"], PostgresSqlStoreConfig)
|
||||
|
||||
# Both stores use same postgres backend
|
||||
assert config.storage.metadata is not None
|
||||
assert config.storage.metadata.backend == "default"
|
||||
|
||||
assert config.storage.inference is not None
|
||||
assert config.storage.inference.backend == "default"
|
||||
|
||||
# Backend config should be Postgres
|
||||
postgres_backend = config.storage.backends["default"]
|
||||
assert "kv_default" in config.storage.backends
|
||||
assert "sql_default" in config.storage.backends
|
||||
postgres_backend = config.storage.backends["sql_default"]
|
||||
assert isinstance(postgres_backend, PostgresSqlStoreConfig)
|
||||
assert postgres_backend.host == "${env.POSTGRES_HOST:=localhost}"
|
||||
|
||||
kv_backend = config.storage.backends["kv_default"]
|
||||
assert isinstance(kv_backend, PostgresKVStoreConfig)
|
||||
|
||||
# Stores target the Postgres backends explicitly
|
||||
assert config.metadata_store.backend == "kv_default"
|
||||
assert config.inference_store.backend == "sql_default"
|
||||
|
|
|
|||
|
|
@ -1,82 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from llama_stack.core.datatypes import (
|
||||
InferenceStoreReference,
|
||||
PersistenceConfig,
|
||||
StoreReference,
|
||||
StoresConfig,
|
||||
)
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import (
|
||||
PostgresSqlStoreConfig,
|
||||
SqliteSqlStoreConfig,
|
||||
)
|
||||
|
||||
|
||||
def test_backend_reference_validation_catches_missing_backend():
|
||||
"""Critical: Catch user typos in backend references before runtime."""
|
||||
with pytest.raises(ValidationError, match="not defined in persistence.backends"):
|
||||
PersistenceConfig(
|
||||
backends={
|
||||
"default": SqliteSqlStoreConfig(db_path="/tmp/store.db"),
|
||||
},
|
||||
stores=StoresConfig(
|
||||
metadata=StoreReference(backend="typo_backend"), # User typo
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def test_backend_reference_validation_accepts_valid_config():
|
||||
"""Valid config should parse without errors."""
|
||||
config = PersistenceConfig(
|
||||
backends={
|
||||
"default": SqliteSqlStoreConfig(db_path="/tmp/store.db"),
|
||||
},
|
||||
stores=StoresConfig(
|
||||
metadata=StoreReference(backend="default"),
|
||||
inference=InferenceStoreReference(backend="default"),
|
||||
),
|
||||
)
|
||||
assert config.stores.metadata.backend == "default"
|
||||
assert config.stores.inference.backend == "default"
|
||||
|
||||
|
||||
def test_multiple_stores_can_share_same_backend():
|
||||
"""Core use case: metadata and inference both use 'default' backend."""
|
||||
config = PersistenceConfig(
|
||||
backends={
|
||||
"default": SqliteSqlStoreConfig(db_path="/tmp/shared.db"),
|
||||
},
|
||||
stores=StoresConfig(
|
||||
metadata=StoreReference(backend="default", namespace="metadata"),
|
||||
inference=InferenceStoreReference(backend="default"),
|
||||
conversations=StoreReference(backend="default"),
|
||||
),
|
||||
)
|
||||
# All reference the same backend
|
||||
assert config.stores.metadata.backend == "default"
|
||||
assert config.stores.inference.backend == "default"
|
||||
assert config.stores.conversations.backend == "default"
|
||||
|
||||
|
||||
def test_mixed_backend_types_allowed():
|
||||
"""Should support KVStore and SqlStore backends simultaneously."""
|
||||
config = PersistenceConfig(
|
||||
backends={
|
||||
"kvstore": SqliteKVStoreConfig(db_path="/tmp/kv.db"),
|
||||
"sqlstore": PostgresSqlStoreConfig(user="test", password="test", host="localhost", db="test"),
|
||||
},
|
||||
stores=StoresConfig(
|
||||
metadata=StoreReference(backend="kvstore"),
|
||||
inference=InferenceStoreReference(backend="sqlstore"),
|
||||
),
|
||||
)
|
||||
assert isinstance(config.backends["kvstore"], SqliteKVStoreConfig)
|
||||
assert isinstance(config.backends["sqlstore"], PostgresSqlStoreConfig)
|
||||
77
tests/unit/core/test_storage_references.py
Normal file
77
tests/unit/core/test_storage_references.py
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""Unit tests for storage backend/reference validation."""
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from llama_stack.core.datatypes import (
|
||||
LLAMA_STACK_RUN_CONFIG_VERSION,
|
||||
StackRunConfig,
|
||||
)
|
||||
from llama_stack.core.storage.datatypes import (
|
||||
InferenceStoreReference,
|
||||
KVStoreReference,
|
||||
SqliteKVStoreConfig,
|
||||
SqliteSqlStoreConfig,
|
||||
SqlStoreReference,
|
||||
StorageConfig,
|
||||
)
|
||||
|
||||
|
||||
def _base_run_config(**overrides):
|
||||
storage = overrides.pop(
|
||||
"storage",
|
||||
StorageConfig(
|
||||
backends={
|
||||
"kv_default": SqliteKVStoreConfig(db_path="/tmp/kv.db"),
|
||||
"sql_default": SqliteSqlStoreConfig(db_path="/tmp/sql.db"),
|
||||
}
|
||||
),
|
||||
)
|
||||
return StackRunConfig(
|
||||
version=LLAMA_STACK_RUN_CONFIG_VERSION,
|
||||
image_name="test-distro",
|
||||
apis=[],
|
||||
providers={},
|
||||
storage=storage,
|
||||
metadata_store=overrides.pop(
|
||||
"metadata_store",
|
||||
KVStoreReference(backend="kv_default", namespace="registry"),
|
||||
),
|
||||
inference_store=overrides.pop(
|
||||
"inference_store",
|
||||
InferenceStoreReference(backend="sql_default", table_name="inference"),
|
||||
),
|
||||
conversations_store=overrides.pop(
|
||||
"conversations_store",
|
||||
SqlStoreReference(backend="sql_default", table_name="conversations"),
|
||||
),
|
||||
**overrides,
|
||||
)
|
||||
|
||||
|
||||
def test_references_require_known_backend():
|
||||
with pytest.raises(ValidationError, match="unknown backend 'missing'"):
|
||||
_base_run_config(metadata_store=KVStoreReference(backend="missing", namespace="registry"))
|
||||
|
||||
|
||||
def test_references_must_match_backend_family():
|
||||
with pytest.raises(ValidationError, match="kv_.* is required"):
|
||||
_base_run_config(metadata_store=KVStoreReference(backend="sql_default", namespace="registry"))
|
||||
|
||||
with pytest.raises(ValidationError, match="sql_.* is required"):
|
||||
_base_run_config(
|
||||
inference_store=InferenceStoreReference(backend="kv_default", table_name="inference"),
|
||||
)
|
||||
|
||||
|
||||
def test_valid_configuration_passes_validation():
|
||||
config = _base_run_config()
|
||||
assert config.metadata_store.backend == "kv_default"
|
||||
assert config.inference_store.backend == "sql_default"
|
||||
assert config.conversations_store.backend == "sql_default"
|
||||
|
|
@ -13,6 +13,14 @@ from pydantic import BaseModel, Field, ValidationError
|
|||
|
||||
from llama_stack.core.datatypes import Api, Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import INTERNAL_APIS, get_provider_registry, providable_apis
|
||||
from llama_stack.core.storage.datatypes import (
|
||||
InferenceStoreReference,
|
||||
KVStoreReference,
|
||||
SqliteKVStoreConfig,
|
||||
SqliteSqlStoreConfig,
|
||||
SqlStoreReference,
|
||||
StorageConfig,
|
||||
)
|
||||
from llama_stack.providers.datatypes import ProviderSpec
|
||||
|
||||
|
||||
|
|
@ -29,6 +37,42 @@ class SampleConfig(BaseModel):
|
|||
}
|
||||
|
||||
|
||||
def _default_storage() -> StorageConfig:
|
||||
return StorageConfig(
|
||||
backends={
|
||||
"kv_default": SqliteKVStoreConfig(db_path=":memory:"),
|
||||
"sql_default": SqliteSqlStoreConfig(db_path=":memory:"),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def make_stack_config(**overrides) -> StackRunConfig:
|
||||
storage = overrides.pop("storage", _default_storage())
|
||||
metadata_store = overrides.pop(
|
||||
"metadata_store",
|
||||
KVStoreReference(backend="kv_default", namespace="registry"),
|
||||
)
|
||||
inference_store = overrides.pop(
|
||||
"inference_store",
|
||||
InferenceStoreReference(backend="sql_default", table_name="inference_store"),
|
||||
)
|
||||
conversations_store = overrides.pop(
|
||||
"conversations_store",
|
||||
SqlStoreReference(backend="sql_default", table_name="conversations"),
|
||||
)
|
||||
defaults = dict(
|
||||
image_name="test_image",
|
||||
apis=[],
|
||||
providers={},
|
||||
storage=storage,
|
||||
metadata_store=metadata_store,
|
||||
inference_store=inference_store,
|
||||
conversations_store=conversations_store,
|
||||
)
|
||||
defaults.update(overrides)
|
||||
return make_stack_config(**defaults)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_providers():
|
||||
"""Mock the available_providers function to return test providers."""
|
||||
|
|
@ -47,8 +91,8 @@ def mock_providers():
|
|||
@pytest.fixture
|
||||
def base_config(tmp_path):
|
||||
"""Create a base StackRunConfig with common settings."""
|
||||
return StackRunConfig(
|
||||
image_name="test_image",
|
||||
return make_stack_config(
|
||||
apis=["inference"],
|
||||
providers={
|
||||
"inference": [
|
||||
Provider(
|
||||
|
|
@ -222,8 +266,8 @@ class TestProviderRegistry:
|
|||
|
||||
def test_missing_directory(self, mock_providers):
|
||||
"""Test handling of missing external providers directory."""
|
||||
config = StackRunConfig(
|
||||
image_name="test_image",
|
||||
config = make_stack_config(
|
||||
apis=["inference"],
|
||||
providers={
|
||||
"inference": [
|
||||
Provider(
|
||||
|
|
@ -278,7 +322,6 @@ pip_packages:
|
|||
"""Test loading an external provider from a module (success path)."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
|
||||
# Simulate a provider module with get_provider_spec
|
||||
|
|
@ -293,7 +336,7 @@ pip_packages:
|
|||
import_module_side_effect = make_import_module_side_effect(external_module=fake_module)
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_module_side_effect) as mock_import:
|
||||
config = StackRunConfig(
|
||||
config = make_stack_config(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
|
|
@ -317,12 +360,11 @@ pip_packages:
|
|||
|
||||
def test_external_provider_from_module_not_found(self, mock_providers):
|
||||
"""Test handling ModuleNotFoundError for missing provider module."""
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
|
||||
import_module_side_effect = make_import_module_side_effect(raise_for_external=True)
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_module_side_effect):
|
||||
config = StackRunConfig(
|
||||
config = make_stack_config(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
|
|
@ -341,12 +383,11 @@ pip_packages:
|
|||
|
||||
def test_external_provider_from_module_missing_get_provider_spec(self, mock_providers):
|
||||
"""Test handling missing get_provider_spec in provider module (should raise ValueError)."""
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
|
||||
import_module_side_effect = make_import_module_side_effect(missing_get_provider_spec=True)
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_module_side_effect):
|
||||
config = StackRunConfig(
|
||||
config = make_stack_config(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
|
|
@ -399,13 +440,12 @@ class TestGetExternalProvidersFromModule:
|
|||
|
||||
def test_stackrunconfig_provider_without_module(self, mock_providers):
|
||||
"""Test that providers without module attribute are skipped."""
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
|
||||
import_module_side_effect = make_import_module_side_effect()
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_module_side_effect):
|
||||
config = StackRunConfig(
|
||||
config = make_stack_config(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
|
|
@ -426,7 +466,6 @@ class TestGetExternalProvidersFromModule:
|
|||
"""Test provider with module containing version spec (e.g., package==1.0.0)."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
from llama_stack.providers.datatypes import ProviderSpec
|
||||
|
||||
|
|
@ -444,7 +483,7 @@ class TestGetExternalProvidersFromModule:
|
|||
raise ModuleNotFoundError(name)
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_side_effect):
|
||||
config = StackRunConfig(
|
||||
config = make_stack_config(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
|
|
@ -564,7 +603,6 @@ class TestGetExternalProvidersFromModule:
|
|||
"""Test when get_provider_spec returns a list of specs."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
from llama_stack.providers.datatypes import ProviderSpec
|
||||
|
||||
|
|
@ -589,7 +627,7 @@ class TestGetExternalProvidersFromModule:
|
|||
raise ModuleNotFoundError(name)
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_side_effect):
|
||||
config = StackRunConfig(
|
||||
config = make_stack_config(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
|
|
@ -613,7 +651,6 @@ class TestGetExternalProvidersFromModule:
|
|||
"""Test that list return filters specs by provider_type."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
from llama_stack.providers.datatypes import ProviderSpec
|
||||
|
||||
|
|
@ -638,7 +675,7 @@ class TestGetExternalProvidersFromModule:
|
|||
raise ModuleNotFoundError(name)
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_side_effect):
|
||||
config = StackRunConfig(
|
||||
config = make_stack_config(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
|
|
@ -662,7 +699,6 @@ class TestGetExternalProvidersFromModule:
|
|||
"""Test that list return adds multiple different provider_types when config requests them."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
from llama_stack.providers.datatypes import ProviderSpec
|
||||
|
||||
|
|
@ -688,7 +724,7 @@ class TestGetExternalProvidersFromModule:
|
|||
raise ModuleNotFoundError(name)
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_side_effect):
|
||||
config = StackRunConfig(
|
||||
config = make_stack_config(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
|
|
@ -718,7 +754,6 @@ class TestGetExternalProvidersFromModule:
|
|||
|
||||
def test_module_not_found_raises_value_error(self, mock_providers):
|
||||
"""Test that ModuleNotFoundError raises ValueError with helpful message."""
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
|
||||
def import_side_effect(name):
|
||||
|
|
@ -727,7 +762,7 @@ class TestGetExternalProvidersFromModule:
|
|||
raise ModuleNotFoundError(name)
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_side_effect):
|
||||
config = StackRunConfig(
|
||||
config = make_stack_config(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
|
|
@ -751,7 +786,6 @@ class TestGetExternalProvidersFromModule:
|
|||
"""Test that generic exceptions are properly raised."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
|
||||
def bad_spec():
|
||||
|
|
@ -765,7 +799,7 @@ class TestGetExternalProvidersFromModule:
|
|||
raise ModuleNotFoundError(name)
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_side_effect):
|
||||
config = StackRunConfig(
|
||||
config = make_stack_config(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
|
|
@ -787,10 +821,9 @@ class TestGetExternalProvidersFromModule:
|
|||
|
||||
def test_empty_provider_list(self, mock_providers):
|
||||
"""Test with empty provider list."""
|
||||
from llama_stack.core.datatypes import StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
|
||||
config = StackRunConfig(
|
||||
config = make_stack_config(
|
||||
image_name="test_image",
|
||||
providers={},
|
||||
)
|
||||
|
|
@ -805,7 +838,6 @@ class TestGetExternalProvidersFromModule:
|
|||
"""Test multiple APIs with providers."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
from llama_stack.providers.datatypes import ProviderSpec
|
||||
|
||||
|
|
@ -830,7 +862,7 @@ class TestGetExternalProvidersFromModule:
|
|||
raise ModuleNotFoundError(name)
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_side_effect):
|
||||
config = StackRunConfig(
|
||||
config = make_stack_config(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
|
|
|
|||
|
|
@ -11,11 +11,12 @@ from llama_stack.apis.common.errors import ResourceNotFoundError
|
|||
from llama_stack.apis.common.responses import Order
|
||||
from llama_stack.apis.files import OpenAIFilePurpose
|
||||
from llama_stack.core.access_control.access_control import default_policy
|
||||
from llama_stack.core.storage.datatypes import SqliteSqlStoreConfig, SqlStoreReference
|
||||
from llama_stack.providers.inline.files.localfs import (
|
||||
LocalfsFilesImpl,
|
||||
LocalfsFilesImplConfig,
|
||||
)
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
|
||||
|
||||
|
||||
class MockUploadFile:
|
||||
|
|
@ -36,8 +37,11 @@ async def files_provider(tmp_path):
|
|||
storage_dir = tmp_path / "files"
|
||||
db_path = tmp_path / "files_metadata.db"
|
||||
|
||||
backend_name = "sql_localfs_test"
|
||||
register_sqlstore_backends({backend_name: SqliteSqlStoreConfig(db_path=db_path.as_posix())})
|
||||
config = LocalfsFilesImplConfig(
|
||||
storage_dir=storage_dir.as_posix(), metadata_store=SqliteSqlStoreConfig(db_path=db_path.as_posix())
|
||||
storage_dir=storage_dir.as_posix(),
|
||||
metadata_store=SqlStoreReference(backend=backend_name, table_name="files_metadata"),
|
||||
)
|
||||
|
||||
provider = LocalfsFilesImpl(config, default_policy())
|
||||
|
|
|
|||
|
|
@ -9,7 +9,15 @@ import random
|
|||
import pytest
|
||||
|
||||
from llama_stack.core.prompts.prompts import PromptServiceConfig, PromptServiceImpl
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
from llama_stack.core.storage.datatypes import (
|
||||
InferenceStoreReference,
|
||||
KVStoreReference,
|
||||
SqliteKVStoreConfig,
|
||||
SqliteSqlStoreConfig,
|
||||
SqlStoreReference,
|
||||
StorageConfig,
|
||||
)
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl, register_kvstore_backends
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -19,12 +27,26 @@ async def temp_prompt_store(tmp_path_factory):
|
|||
db_path = str(temp_dir / f"{unique_id}.db")
|
||||
|
||||
from llama_stack.core.datatypes import StackRunConfig
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
|
||||
mock_run_config = StackRunConfig(image_name="test-distribution", apis=[], providers={})
|
||||
storage = StorageConfig(
|
||||
backends={
|
||||
"kv_test": SqliteKVStoreConfig(db_path=db_path),
|
||||
"sql_test": SqliteSqlStoreConfig(db_path=str(temp_dir / f"{unique_id}_sql.db")),
|
||||
}
|
||||
)
|
||||
mock_run_config = StackRunConfig(
|
||||
image_name="test-distribution",
|
||||
apis=[],
|
||||
providers={},
|
||||
storage=storage,
|
||||
metadata_store=KVStoreReference(backend="kv_test", namespace="registry"),
|
||||
inference_store=InferenceStoreReference(backend="sql_test", table_name="inference"),
|
||||
conversations_store=SqlStoreReference(backend="sql_test", table_name="conversations"),
|
||||
)
|
||||
config = PromptServiceConfig(run_config=mock_run_config)
|
||||
store = PromptServiceImpl(config, deps={})
|
||||
|
||||
store.kvstore = await kvstore_impl(SqliteKVStoreConfig(db_path=db_path))
|
||||
register_kvstore_backends({"kv_test": storage.backends["kv_test"]})
|
||||
store.kvstore = await kvstore_impl(KVStoreReference(backend="kv_test", namespace="prompts"))
|
||||
|
||||
yield store
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ from llama_stack.apis.inference import (
|
|||
)
|
||||
from llama_stack.apis.tools.tools import ListToolDefsResponse, ToolDef, ToolGroups, ToolInvocationResult, ToolRuntime
|
||||
from llama_stack.core.access_control.access_control import default_policy
|
||||
from llama_stack.core.datatypes import ResponsesStoreConfig
|
||||
from llama_stack.core.storage.datatypes import ResponsesStoreReference, SqliteSqlStoreConfig
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
|
||||
OpenAIResponsesImpl,
|
||||
)
|
||||
|
|
@ -50,7 +50,7 @@ from llama_stack.providers.utils.responses.responses_store import (
|
|||
ResponsesStore,
|
||||
_OpenAIResponseObjectWithInputAndMessages,
|
||||
)
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
|
||||
from tests.unit.providers.agents.meta_reference.fixtures import load_chat_completion_fixture
|
||||
|
||||
|
||||
|
|
@ -854,8 +854,10 @@ async def test_responses_store_list_input_items_logic():
|
|||
|
||||
# Create mock store and response store
|
||||
mock_sql_store = AsyncMock()
|
||||
backend_name = "sql_responses_test"
|
||||
register_sqlstore_backends({backend_name: SqliteSqlStoreConfig(db_path="mock_db_path")})
|
||||
responses_store = ResponsesStore(
|
||||
ResponsesStoreConfig(sql_store_config=SqliteSqlStoreConfig(db_path="mock_db_path")), policy=default_policy()
|
||||
ResponsesStoreReference(backend=backend_name, table_name="responses"), policy=default_policy()
|
||||
)
|
||||
responses_store.sql_store = mock_sql_store
|
||||
|
||||
|
|
|
|||
|
|
@ -12,10 +12,10 @@ from unittest.mock import AsyncMock
|
|||
|
||||
import pytest
|
||||
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig
|
||||
from llama_stack.providers.inline.batches.reference.batches import ReferenceBatchesImpl
|
||||
from llama_stack.providers.inline.batches.reference.config import ReferenceBatchesImplConfig
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl, register_kvstore_backends
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -23,8 +23,10 @@ async def provider():
|
|||
"""Create a test provider instance with temporary database."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
db_path = Path(tmpdir) / "test_batches.db"
|
||||
backend_name = "kv_batches_test"
|
||||
kvstore_config = SqliteKVStoreConfig(db_path=str(db_path))
|
||||
config = ReferenceBatchesImplConfig(kvstore=kvstore_config)
|
||||
register_kvstore_backends({backend_name: kvstore_config})
|
||||
config = ReferenceBatchesImplConfig(kvstore=KVStoreReference(backend=backend_name, namespace="batches"))
|
||||
|
||||
# Create kvstore and mock APIs
|
||||
kvstore = await kvstore_impl(config.kvstore)
|
||||
|
|
|
|||
|
|
@ -8,8 +8,9 @@ import boto3
|
|||
import pytest
|
||||
from moto import mock_aws
|
||||
|
||||
from llama_stack.core.storage.datatypes import SqliteSqlStoreConfig, SqlStoreReference
|
||||
from llama_stack.providers.remote.files.s3 import S3FilesImplConfig, get_adapter_impl
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
|
||||
|
||||
|
||||
class MockUploadFile:
|
||||
|
|
@ -38,11 +39,13 @@ def sample_text_file2():
|
|||
def s3_config(tmp_path):
|
||||
db_path = tmp_path / "s3_files_metadata.db"
|
||||
|
||||
backend_name = f"sql_s3_{tmp_path.name}"
|
||||
register_sqlstore_backends({backend_name: SqliteSqlStoreConfig(db_path=db_path.as_posix())})
|
||||
return S3FilesImplConfig(
|
||||
bucket_name=f"test-bucket-{tmp_path.name}",
|
||||
region="not-a-region",
|
||||
auto_create_bucket=True,
|
||||
metadata_store=SqliteSqlStoreConfig(db_path=db_path.as_posix()),
|
||||
metadata_store=SqlStoreReference(backend=backend_name, table_name="s3_files_metadata"),
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -12,13 +12,14 @@ import pytest
|
|||
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig
|
||||
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
|
||||
from llama_stack.providers.inline.vector_io.faiss.faiss import FaissIndex, FaissVectorIOAdapter
|
||||
from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig
|
||||
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter
|
||||
from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig
|
||||
from llama_stack.providers.remote.vector_io.pgvector.pgvector import PGVectorIndex, PGVectorVectorIOAdapter
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
from llama_stack.providers.utils.kvstore import register_kvstore_backends
|
||||
|
||||
EMBEDDING_DIMENSION = 768
|
||||
COLLECTION_PREFIX = "test_collection"
|
||||
|
|
@ -112,8 +113,9 @@ async def unique_kvstore_config(tmp_path_factory):
|
|||
unique_id = f"test_kv_{np.random.randint(1e6)}"
|
||||
temp_dir = tmp_path_factory.getbasetemp()
|
||||
db_path = str(temp_dir / f"{unique_id}.db")
|
||||
|
||||
return SqliteKVStoreConfig(db_path=db_path)
|
||||
backend_name = f"kv_vector_{unique_id}"
|
||||
register_kvstore_backends({backend_name: SqliteKVStoreConfig(db_path=db_path)})
|
||||
return KVStoreReference(backend=backend_name, namespace=f"vector_io::{unique_id}")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
|
|
|
|||
|
|
@ -10,13 +10,13 @@ import pytest
|
|||
from llama_stack.apis.inference import Model
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.core.datatypes import VectorDBWithOwner
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig
|
||||
from llama_stack.core.store.registry import (
|
||||
KEY_FORMAT,
|
||||
CachedDiskDistributionRegistry,
|
||||
DiskDistributionRegistry,
|
||||
)
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl, register_kvstore_backends
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -72,7 +72,11 @@ async def test_cached_registry_initialization(sqlite_kvstore, sample_vector_db,
|
|||
|
||||
# Test cached version loads from disk
|
||||
db_path = sqlite_kvstore.db_path
|
||||
cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(SqliteKVStoreConfig(db_path=db_path)))
|
||||
backend_name = "kv_cached_test"
|
||||
register_kvstore_backends({backend_name: SqliteKVStoreConfig(db_path=db_path)})
|
||||
cached_registry = CachedDiskDistributionRegistry(
|
||||
await kvstore_impl(KVStoreReference(backend=backend_name, namespace="registry"))
|
||||
)
|
||||
await cached_registry.initialize()
|
||||
|
||||
result_vector_db = await cached_registry.get("vector_db", "test_vector_db")
|
||||
|
|
@ -101,7 +105,11 @@ async def test_cached_registry_updates(cached_disk_dist_registry):
|
|||
|
||||
# Verify persisted to disk
|
||||
db_path = cached_disk_dist_registry.kvstore.db_path
|
||||
new_registry = DiskDistributionRegistry(await kvstore_impl(SqliteKVStoreConfig(db_path=db_path)))
|
||||
backend_name = "kv_cached_new"
|
||||
register_kvstore_backends({backend_name: SqliteKVStoreConfig(db_path=db_path)})
|
||||
new_registry = DiskDistributionRegistry(
|
||||
await kvstore_impl(KVStoreReference(backend=backend_name, namespace="registry"))
|
||||
)
|
||||
await new_registry.initialize()
|
||||
result_vector_db = await new_registry.get("vector_db", "test_vector_db_2")
|
||||
assert result_vector_db is not None
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.testclient import TestClient
|
||||
|
|
@ -11,7 +13,8 @@ from starlette.middleware.base import BaseHTTPMiddleware
|
|||
|
||||
from llama_stack.core.datatypes import QuotaConfig, QuotaPeriod
|
||||
from llama_stack.core.server.quota import QuotaMiddleware
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig
|
||||
from llama_stack.providers.utils.kvstore import register_kvstore_backends
|
||||
|
||||
|
||||
class InjectClientIDMiddleware(BaseHTTPMiddleware):
|
||||
|
|
@ -29,8 +32,10 @@ class InjectClientIDMiddleware(BaseHTTPMiddleware):
|
|||
|
||||
|
||||
def build_quota_config(db_path) -> QuotaConfig:
|
||||
backend_name = f"kv_quota_{uuid4().hex}"
|
||||
register_kvstore_backends({backend_name: SqliteKVStoreConfig(db_path=str(db_path))})
|
||||
return QuotaConfig(
|
||||
kvstore=SqliteKVStoreConfig(db_path=str(db_path)),
|
||||
kvstore=KVStoreReference(backend=backend_name, namespace="quota"),
|
||||
anonymous_max_requests=1,
|
||||
authenticated_max_requests=2,
|
||||
period=QuotaPeriod.DAY,
|
||||
|
|
|
|||
|
|
@ -12,14 +12,18 @@ from unittest.mock import AsyncMock, MagicMock
|
|||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.core.datatypes import (
|
||||
Api,
|
||||
Provider,
|
||||
StackRunConfig,
|
||||
)
|
||||
from llama_stack.core.datatypes import Api, Provider, StackRunConfig
|
||||
from llama_stack.core.resolver import resolve_impls
|
||||
from llama_stack.core.routers.inference import InferenceRouter
|
||||
from llama_stack.core.routing_tables.models import ModelsRoutingTable
|
||||
from llama_stack.core.storage.datatypes import (
|
||||
InferenceStoreReference,
|
||||
KVStoreReference,
|
||||
SqliteKVStoreConfig,
|
||||
SqliteSqlStoreConfig,
|
||||
SqlStoreReference,
|
||||
StorageConfig,
|
||||
)
|
||||
from llama_stack.providers.datatypes import InlineProviderSpec, ProviderSpec
|
||||
|
||||
|
||||
|
|
@ -65,6 +69,38 @@ class SampleImpl:
|
|||
pass
|
||||
|
||||
|
||||
def make_run_config(**overrides) -> StackRunConfig:
|
||||
storage = overrides.pop(
|
||||
"storage",
|
||||
StorageConfig(
|
||||
backends={
|
||||
"kv_default": SqliteKVStoreConfig(db_path=":memory:"),
|
||||
"sql_default": SqliteSqlStoreConfig(db_path=":memory:"),
|
||||
}
|
||||
),
|
||||
)
|
||||
defaults = dict(
|
||||
image_name="test_image",
|
||||
apis=[],
|
||||
providers={},
|
||||
storage=storage,
|
||||
metadata_store=overrides.pop(
|
||||
"metadata_store",
|
||||
KVStoreReference(backend="kv_default", namespace="registry"),
|
||||
),
|
||||
inference_store=overrides.pop(
|
||||
"inference_store",
|
||||
InferenceStoreReference(backend="sql_default", table_name="inference_store"),
|
||||
),
|
||||
conversations_store=overrides.pop(
|
||||
"conversations_store",
|
||||
SqlStoreReference(backend="sql_default", table_name="conversations"),
|
||||
),
|
||||
)
|
||||
defaults.update(overrides)
|
||||
return StackRunConfig(**defaults)
|
||||
|
||||
|
||||
async def test_resolve_impls_basic():
|
||||
# Create a real provider spec
|
||||
provider_spec = InlineProviderSpec(
|
||||
|
|
@ -78,7 +114,7 @@ async def test_resolve_impls_basic():
|
|||
# Create provider registry with our provider
|
||||
provider_registry = {Api.inference: {provider_spec.provider_type: provider_spec}}
|
||||
|
||||
run_config = StackRunConfig(
|
||||
run_config = make_run_config(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@
|
|||
|
||||
import time
|
||||
from tempfile import TemporaryDirectory
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
|
|
@ -15,8 +16,18 @@ from llama_stack.apis.agents.openai_responses import (
|
|||
OpenAIResponseObject,
|
||||
)
|
||||
from llama_stack.apis.inference import OpenAIMessageParam, OpenAIUserMessageParam
|
||||
from llama_stack.core.storage.datatypes import ResponsesStoreReference, SqliteSqlStoreConfig
|
||||
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
|
||||
|
||||
|
||||
def build_store(db_path: str, policy: list | None = None) -> ResponsesStore:
|
||||
backend_name = f"sql_responses_{uuid4().hex}"
|
||||
register_sqlstore_backends({backend_name: SqliteSqlStoreConfig(db_path=db_path)})
|
||||
return ResponsesStore(
|
||||
ResponsesStoreReference(backend=backend_name, table_name="responses"),
|
||||
policy=policy or [],
|
||||
)
|
||||
|
||||
|
||||
def create_test_response_object(
|
||||
|
|
@ -54,7 +65,7 @@ async def test_responses_store_pagination_basic():
|
|||
"""Test basic pagination functionality for responses store."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||
store = build_store(db_path)
|
||||
await store.initialize()
|
||||
|
||||
# Create test data with different timestamps
|
||||
|
|
@ -103,7 +114,7 @@ async def test_responses_store_pagination_ascending():
|
|||
"""Test pagination with ascending order."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||
store = build_store(db_path)
|
||||
await store.initialize()
|
||||
|
||||
# Create test data
|
||||
|
|
@ -141,7 +152,7 @@ async def test_responses_store_pagination_with_model_filter():
|
|||
"""Test pagination combined with model filtering."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||
store = build_store(db_path)
|
||||
await store.initialize()
|
||||
|
||||
# Create test data with different models
|
||||
|
|
@ -182,7 +193,7 @@ async def test_responses_store_pagination_invalid_after():
|
|||
"""Test error handling for invalid 'after' parameter."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||
store = build_store(db_path)
|
||||
await store.initialize()
|
||||
|
||||
# Try to paginate with non-existent ID
|
||||
|
|
@ -194,7 +205,7 @@ async def test_responses_store_pagination_no_limit():
|
|||
"""Test pagination behavior when no limit is specified."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||
store = build_store(db_path)
|
||||
await store.initialize()
|
||||
|
||||
# Create test data
|
||||
|
|
@ -226,7 +237,7 @@ async def test_responses_store_get_response_object():
|
|||
"""Test retrieving a single response object."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||
store = build_store(db_path)
|
||||
await store.initialize()
|
||||
|
||||
# Store a test response
|
||||
|
|
@ -254,7 +265,7 @@ async def test_responses_store_input_items_pagination():
|
|||
"""Test pagination functionality for input items."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||
store = build_store(db_path)
|
||||
await store.initialize()
|
||||
|
||||
# Store a test response with many inputs with explicit IDs
|
||||
|
|
@ -335,7 +346,7 @@ async def test_responses_store_input_items_before_pagination():
|
|||
"""Test before pagination functionality for input items."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||
store = build_store(db_path)
|
||||
await store.initialize()
|
||||
|
||||
# Store a test response with many inputs with explicit IDs
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue