mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-22 08:17:18 +00:00
feat(stores)!: use backend storage references instead of configs (#3697)
**This PR changes configurations in a backward incompatible way.** Run configs today repeat full SQLite/Postgres snippets everywhere a store is needed, which means duplicated credentials, extra connection pools, and lots of drift between files. This PR introduces named storage backends so the stack and providers can share a single catalog and reference those backends by name. ## Key Changes - Add `storage.backends` to `StackRunConfig`, register each KV/SQL backend once at startup, and validate that references point to the right family. - Move server stores under `storage.stores` with lightweight references (backend + namespace/table) instead of full configs. - Update every provider/config/doc to use the new reference style; docs/codegen now surface the simplified YAML. ## Migration Before: ```yaml metadata_store: type: sqlite db_path: ~/.llama/distributions/foo/registry.db inference_store: type: postgres host: ${env.POSTGRES_HOST} port: ${env.POSTGRES_PORT} db: ${env.POSTGRES_DB} user: ${env.POSTGRES_USER} password: ${env.POSTGRES_PASSWORD} conversations_store: type: postgres host: ${env.POSTGRES_HOST} port: ${env.POSTGRES_PORT} db: ${env.POSTGRES_DB} user: ${env.POSTGRES_USER} password: ${env.POSTGRES_PASSWORD} ``` After: ```yaml storage: backends: kv_default: type: kv_sqlite db_path: ~/.llama/distributions/foo/kvstore.db sql_default: type: sql_postgres host: ${env.POSTGRES_HOST} port: ${env.POSTGRES_PORT} db: ${env.POSTGRES_DB} user: ${env.POSTGRES_USER} password: ${env.POSTGRES_PASSWORD} stores: metadata: backend: kv_default namespace: registry inference: backend: sql_default table_name: inference_store max_write_queue_size: 10000 num_writers: 4 conversations: backend: sql_default table_name: openai_conversations ``` Provider configs follow the same pattern—for example, a Chroma vector adapter switches from: ```yaml providers: vector_io: - provider_id: chromadb provider_type: remote::chromadb config: url: ${env.CHROMADB_URL} kvstore: type: sqlite db_path: ~/.llama/distributions/foo/chroma.db ``` to: ```yaml providers: vector_io: - provider_id: chromadb provider_type: remote::chromadb config: url: ${env.CHROMADB_URL} persistence: backend: kv_default namespace: vector_io::chroma_remote ``` Once the backends are declared, everything else just points at them, so rotating credentials or swapping to Postgres happens in one place and the stack reuses a single connection pool.
This commit is contained in:
parent
add64e8e2a
commit
2c43285e22
105 changed files with 2290 additions and 1292 deletions
18
tests/external/run-byoa.yaml
vendored
18
tests/external/run-byoa.yaml
vendored
|
@ -7,6 +7,24 @@ providers:
|
|||
- provider_id: kaze
|
||||
provider_type: remote::kaze
|
||||
config: {}
|
||||
storage:
|
||||
backends:
|
||||
kv_default:
|
||||
type: kv_sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/external}/kvstore.db
|
||||
sql_default:
|
||||
type: sql_sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/external}/sql_store.db
|
||||
stores:
|
||||
metadata:
|
||||
namespace: registry
|
||||
backend: kv_default
|
||||
inference:
|
||||
table_name: inference_store
|
||||
backend: sql_default
|
||||
conversations:
|
||||
table_name: openai_conversations
|
||||
backend: sql_default
|
||||
external_apis_dir: ~/.llama/apis.d
|
||||
external_providers_dir: ~/.llama/providers.d
|
||||
server:
|
||||
|
|
|
@ -238,7 +238,7 @@ def instantiate_llama_stack_client(session):
|
|||
run_config = run_config_from_adhoc_config_spec(config)
|
||||
run_config_file = tempfile.NamedTemporaryFile(delete=False, suffix=".yaml")
|
||||
with open(run_config_file.name, "w") as f:
|
||||
yaml.dump(run_config.model_dump(), f)
|
||||
yaml.dump(run_config.model_dump(mode="json"), f)
|
||||
config = run_config_file.name
|
||||
|
||||
client = LlamaStackAsLibraryClient(
|
||||
|
|
|
@ -12,9 +12,15 @@ import pytest
|
|||
|
||||
from llama_stack.core.access_control.access_control import default_policy
|
||||
from llama_stack.core.datatypes import User
|
||||
from llama_stack.core.storage.datatypes import SqlStoreReference
|
||||
from llama_stack.providers.utils.sqlstore.api import ColumnType
|
||||
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig, SqliteSqlStoreConfig, sqlstore_impl
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import (
|
||||
PostgresSqlStoreConfig,
|
||||
SqliteSqlStoreConfig,
|
||||
register_sqlstore_backends,
|
||||
sqlstore_impl,
|
||||
)
|
||||
|
||||
|
||||
def get_postgres_config():
|
||||
|
@ -55,8 +61,9 @@ def authorized_store(backend_config):
|
|||
config_func = backend_config
|
||||
|
||||
config = config_func()
|
||||
|
||||
base_sqlstore = sqlstore_impl(config)
|
||||
backend_name = f"sql_{type(config).__name__.lower()}"
|
||||
register_sqlstore_backends({backend_name: config})
|
||||
base_sqlstore = sqlstore_impl(SqlStoreReference(backend=backend_name, table_name="authorized_store"))
|
||||
authorized_store = AuthorizedSqlStore(base_sqlstore, default_policy())
|
||||
|
||||
yield authorized_store
|
||||
|
|
71
tests/integration/test_persistence_integration.py
Normal file
71
tests/integration/test_persistence_integration.py
Normal file
|
@ -0,0 +1,71 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import yaml
|
||||
|
||||
from llama_stack.core.datatypes import StackRunConfig
|
||||
from llama_stack.core.storage.datatypes import (
|
||||
PostgresKVStoreConfig,
|
||||
PostgresSqlStoreConfig,
|
||||
SqliteKVStoreConfig,
|
||||
SqliteSqlStoreConfig,
|
||||
)
|
||||
|
||||
|
||||
def test_starter_distribution_config_loads_and_resolves():
|
||||
"""Integration: Actual starter config should parse and have correct storage structure."""
|
||||
with open("llama_stack/distributions/starter/run.yaml") as f:
|
||||
config_dict = yaml.safe_load(f)
|
||||
|
||||
config = StackRunConfig(**config_dict)
|
||||
|
||||
# Config should have named backends and explicit store references
|
||||
assert config.storage is not None
|
||||
assert "kv_default" in config.storage.backends
|
||||
assert "sql_default" in config.storage.backends
|
||||
assert isinstance(config.storage.backends["kv_default"], SqliteKVStoreConfig)
|
||||
assert isinstance(config.storage.backends["sql_default"], SqliteSqlStoreConfig)
|
||||
|
||||
stores = config.storage.stores
|
||||
assert stores.metadata is not None
|
||||
assert stores.metadata.backend == "kv_default"
|
||||
assert stores.metadata.namespace == "registry"
|
||||
|
||||
assert stores.inference is not None
|
||||
assert stores.inference.backend == "sql_default"
|
||||
assert stores.inference.table_name == "inference_store"
|
||||
assert stores.inference.max_write_queue_size > 0
|
||||
assert stores.inference.num_writers > 0
|
||||
|
||||
assert stores.conversations is not None
|
||||
assert stores.conversations.backend == "sql_default"
|
||||
assert stores.conversations.table_name == "openai_conversations"
|
||||
|
||||
|
||||
def test_postgres_demo_distribution_config_loads():
|
||||
"""Integration: Postgres demo should use Postgres backend for all stores."""
|
||||
with open("llama_stack/distributions/postgres-demo/run.yaml") as f:
|
||||
config_dict = yaml.safe_load(f)
|
||||
|
||||
config = StackRunConfig(**config_dict)
|
||||
|
||||
# Should have postgres backend
|
||||
assert config.storage is not None
|
||||
assert "kv_default" in config.storage.backends
|
||||
assert "sql_default" in config.storage.backends
|
||||
postgres_backend = config.storage.backends["sql_default"]
|
||||
assert isinstance(postgres_backend, PostgresSqlStoreConfig)
|
||||
assert postgres_backend.host == "${env.POSTGRES_HOST:=localhost}"
|
||||
|
||||
kv_backend = config.storage.backends["kv_default"]
|
||||
assert isinstance(kv_backend, PostgresKVStoreConfig)
|
||||
|
||||
stores = config.storage.stores
|
||||
# Stores target the Postgres backends explicitly
|
||||
assert stores.metadata is not None
|
||||
assert stores.metadata.backend == "kv_default"
|
||||
assert stores.inference is not None
|
||||
assert stores.inference.backend == "sql_default"
|
|
@ -23,6 +23,27 @@ def config_with_image_name_int():
|
|||
image_name: 1234
|
||||
apis_to_serve: []
|
||||
built_at: {datetime.now().isoformat()}
|
||||
storage:
|
||||
backends:
|
||||
kv_default:
|
||||
type: kv_sqlite
|
||||
db_path: /tmp/test_kv.db
|
||||
sql_default:
|
||||
type: sql_sqlite
|
||||
db_path: /tmp/test_sql.db
|
||||
stores:
|
||||
metadata:
|
||||
backend: kv_default
|
||||
namespace: metadata
|
||||
inference:
|
||||
backend: sql_default
|
||||
table_name: inference
|
||||
conversations:
|
||||
backend: sql_default
|
||||
table_name: conversations
|
||||
responses:
|
||||
backend: sql_default
|
||||
table_name: responses
|
||||
providers:
|
||||
inference:
|
||||
- provider_id: provider1
|
||||
|
@ -54,6 +75,27 @@ def up_to_date_config():
|
|||
image_name: foo
|
||||
apis_to_serve: []
|
||||
built_at: {datetime.now().isoformat()}
|
||||
storage:
|
||||
backends:
|
||||
kv_default:
|
||||
type: kv_sqlite
|
||||
db_path: /tmp/test_kv.db
|
||||
sql_default:
|
||||
type: sql_sqlite
|
||||
db_path: /tmp/test_sql.db
|
||||
stores:
|
||||
metadata:
|
||||
backend: kv_default
|
||||
namespace: metadata
|
||||
inference:
|
||||
backend: sql_default
|
||||
table_name: inference
|
||||
conversations:
|
||||
backend: sql_default
|
||||
table_name: conversations
|
||||
responses:
|
||||
backend: sql_default
|
||||
table_name: responses
|
||||
providers:
|
||||
inference:
|
||||
- provider_id: provider1
|
||||
|
|
|
@ -20,7 +20,14 @@ from llama_stack.core.conversations.conversations import (
|
|||
ConversationServiceConfig,
|
||||
ConversationServiceImpl,
|
||||
)
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
|
||||
from llama_stack.core.datatypes import StackRunConfig
|
||||
from llama_stack.core.storage.datatypes import (
|
||||
ServerStoresConfig,
|
||||
SqliteSqlStoreConfig,
|
||||
SqlStoreReference,
|
||||
StorageConfig,
|
||||
)
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -28,7 +35,18 @@ async def service():
|
|||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
db_path = Path(tmpdir) / "test_conversations.db"
|
||||
|
||||
config = ConversationServiceConfig(conversations_store=SqliteSqlStoreConfig(db_path=str(db_path)), policy=[])
|
||||
storage = StorageConfig(
|
||||
backends={
|
||||
"sql_test": SqliteSqlStoreConfig(db_path=str(db_path)),
|
||||
},
|
||||
stores=ServerStoresConfig(
|
||||
conversations=SqlStoreReference(backend="sql_test", table_name="openai_conversations"),
|
||||
),
|
||||
)
|
||||
register_sqlstore_backends({"sql_test": storage.backends["sql_test"]})
|
||||
run_config = StackRunConfig(image_name="test", apis=[], providers={}, storage=storage)
|
||||
|
||||
config = ConversationServiceConfig(run_config=run_config, policy=[])
|
||||
service = ConversationServiceImpl(config, {})
|
||||
await service.initialize()
|
||||
yield service
|
||||
|
@ -121,9 +139,18 @@ async def test_policy_configuration():
|
|||
AccessRule(forbid=Scope(principal="test_user", actions=[Action.CREATE, Action.READ], resource="*"))
|
||||
]
|
||||
|
||||
config = ConversationServiceConfig(
|
||||
conversations_store=SqliteSqlStoreConfig(db_path=str(db_path)), policy=restrictive_policy
|
||||
storage = StorageConfig(
|
||||
backends={
|
||||
"sql_test": SqliteSqlStoreConfig(db_path=str(db_path)),
|
||||
},
|
||||
stores=ServerStoresConfig(
|
||||
conversations=SqlStoreReference(backend="sql_test", table_name="openai_conversations"),
|
||||
),
|
||||
)
|
||||
register_sqlstore_backends({"sql_test": storage.backends["sql_test"]})
|
||||
run_config = StackRunConfig(image_name="test", apis=[], providers={}, storage=storage)
|
||||
|
||||
config = ConversationServiceConfig(run_config=run_config, policy=restrictive_policy)
|
||||
service = ConversationServiceImpl(config, {})
|
||||
await service.initialize()
|
||||
|
||||
|
|
84
tests/unit/core/test_storage_references.py
Normal file
84
tests/unit/core/test_storage_references.py
Normal file
|
@ -0,0 +1,84 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""Unit tests for storage backend/reference validation."""
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from llama_stack.core.datatypes import (
|
||||
LLAMA_STACK_RUN_CONFIG_VERSION,
|
||||
StackRunConfig,
|
||||
)
|
||||
from llama_stack.core.storage.datatypes import (
|
||||
InferenceStoreReference,
|
||||
KVStoreReference,
|
||||
ServerStoresConfig,
|
||||
SqliteKVStoreConfig,
|
||||
SqliteSqlStoreConfig,
|
||||
SqlStoreReference,
|
||||
StorageConfig,
|
||||
)
|
||||
|
||||
|
||||
def _base_run_config(**overrides):
|
||||
metadata_reference = overrides.pop(
|
||||
"metadata_reference",
|
||||
KVStoreReference(backend="kv_default", namespace="registry"),
|
||||
)
|
||||
inference_reference = overrides.pop(
|
||||
"inference_reference",
|
||||
InferenceStoreReference(backend="sql_default", table_name="inference"),
|
||||
)
|
||||
conversations_reference = overrides.pop(
|
||||
"conversations_reference",
|
||||
SqlStoreReference(backend="sql_default", table_name="conversations"),
|
||||
)
|
||||
storage = overrides.pop(
|
||||
"storage",
|
||||
StorageConfig(
|
||||
backends={
|
||||
"kv_default": SqliteKVStoreConfig(db_path="/tmp/kv.db"),
|
||||
"sql_default": SqliteSqlStoreConfig(db_path="/tmp/sql.db"),
|
||||
},
|
||||
stores=ServerStoresConfig(
|
||||
metadata=metadata_reference,
|
||||
inference=inference_reference,
|
||||
conversations=conversations_reference,
|
||||
),
|
||||
),
|
||||
)
|
||||
return StackRunConfig(
|
||||
version=LLAMA_STACK_RUN_CONFIG_VERSION,
|
||||
image_name="test-distro",
|
||||
apis=[],
|
||||
providers={},
|
||||
storage=storage,
|
||||
**overrides,
|
||||
)
|
||||
|
||||
|
||||
def test_references_require_known_backend():
|
||||
with pytest.raises(ValidationError, match="unknown backend 'missing'"):
|
||||
_base_run_config(metadata_reference=KVStoreReference(backend="missing", namespace="registry"))
|
||||
|
||||
|
||||
def test_references_must_match_backend_family():
|
||||
with pytest.raises(ValidationError, match="kv_.* is required"):
|
||||
_base_run_config(metadata_reference=KVStoreReference(backend="sql_default", namespace="registry"))
|
||||
|
||||
with pytest.raises(ValidationError, match="sql_.* is required"):
|
||||
_base_run_config(
|
||||
inference_reference=InferenceStoreReference(backend="kv_default", table_name="inference"),
|
||||
)
|
||||
|
||||
|
||||
def test_valid_configuration_passes_validation():
|
||||
config = _base_run_config()
|
||||
stores = config.storage.stores
|
||||
assert stores.metadata is not None and stores.metadata.backend == "kv_default"
|
||||
assert stores.inference is not None and stores.inference.backend == "sql_default"
|
||||
assert stores.conversations is not None and stores.conversations.backend == "sql_default"
|
|
@ -13,6 +13,15 @@ from pydantic import BaseModel, Field, ValidationError
|
|||
|
||||
from llama_stack.core.datatypes import Api, Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import INTERNAL_APIS, get_provider_registry, providable_apis
|
||||
from llama_stack.core.storage.datatypes import (
|
||||
InferenceStoreReference,
|
||||
KVStoreReference,
|
||||
ServerStoresConfig,
|
||||
SqliteKVStoreConfig,
|
||||
SqliteSqlStoreConfig,
|
||||
SqlStoreReference,
|
||||
StorageConfig,
|
||||
)
|
||||
from llama_stack.providers.datatypes import ProviderSpec
|
||||
|
||||
|
||||
|
@ -29,6 +38,32 @@ class SampleConfig(BaseModel):
|
|||
}
|
||||
|
||||
|
||||
def _default_storage() -> StorageConfig:
|
||||
return StorageConfig(
|
||||
backends={
|
||||
"kv_default": SqliteKVStoreConfig(db_path=":memory:"),
|
||||
"sql_default": SqliteSqlStoreConfig(db_path=":memory:"),
|
||||
},
|
||||
stores=ServerStoresConfig(
|
||||
metadata=KVStoreReference(backend="kv_default", namespace="registry"),
|
||||
inference=InferenceStoreReference(backend="sql_default", table_name="inference_store"),
|
||||
conversations=SqlStoreReference(backend="sql_default", table_name="conversations"),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def make_stack_config(**overrides) -> StackRunConfig:
|
||||
storage = overrides.pop("storage", _default_storage())
|
||||
defaults = dict(
|
||||
image_name="test_image",
|
||||
apis=[],
|
||||
providers={},
|
||||
storage=storage,
|
||||
)
|
||||
defaults.update(overrides)
|
||||
return StackRunConfig(**defaults)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_providers():
|
||||
"""Mock the available_providers function to return test providers."""
|
||||
|
@ -47,8 +82,8 @@ def mock_providers():
|
|||
@pytest.fixture
|
||||
def base_config(tmp_path):
|
||||
"""Create a base StackRunConfig with common settings."""
|
||||
return StackRunConfig(
|
||||
image_name="test_image",
|
||||
return make_stack_config(
|
||||
apis=["inference"],
|
||||
providers={
|
||||
"inference": [
|
||||
Provider(
|
||||
|
@ -222,8 +257,8 @@ class TestProviderRegistry:
|
|||
|
||||
def test_missing_directory(self, mock_providers):
|
||||
"""Test handling of missing external providers directory."""
|
||||
config = StackRunConfig(
|
||||
image_name="test_image",
|
||||
config = make_stack_config(
|
||||
apis=["inference"],
|
||||
providers={
|
||||
"inference": [
|
||||
Provider(
|
||||
|
@ -278,7 +313,6 @@ pip_packages:
|
|||
"""Test loading an external provider from a module (success path)."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
|
||||
# Simulate a provider module with get_provider_spec
|
||||
|
@ -293,7 +327,7 @@ pip_packages:
|
|||
import_module_side_effect = make_import_module_side_effect(external_module=fake_module)
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_module_side_effect) as mock_import:
|
||||
config = StackRunConfig(
|
||||
config = make_stack_config(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
|
@ -317,12 +351,11 @@ pip_packages:
|
|||
|
||||
def test_external_provider_from_module_not_found(self, mock_providers):
|
||||
"""Test handling ModuleNotFoundError for missing provider module."""
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
|
||||
import_module_side_effect = make_import_module_side_effect(raise_for_external=True)
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_module_side_effect):
|
||||
config = StackRunConfig(
|
||||
config = make_stack_config(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
|
@ -341,12 +374,11 @@ pip_packages:
|
|||
|
||||
def test_external_provider_from_module_missing_get_provider_spec(self, mock_providers):
|
||||
"""Test handling missing get_provider_spec in provider module (should raise ValueError)."""
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
|
||||
import_module_side_effect = make_import_module_side_effect(missing_get_provider_spec=True)
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_module_side_effect):
|
||||
config = StackRunConfig(
|
||||
config = make_stack_config(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
|
@ -399,13 +431,12 @@ class TestGetExternalProvidersFromModule:
|
|||
|
||||
def test_stackrunconfig_provider_without_module(self, mock_providers):
|
||||
"""Test that providers without module attribute are skipped."""
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
|
||||
import_module_side_effect = make_import_module_side_effect()
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_module_side_effect):
|
||||
config = StackRunConfig(
|
||||
config = make_stack_config(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
|
@ -426,7 +457,6 @@ class TestGetExternalProvidersFromModule:
|
|||
"""Test provider with module containing version spec (e.g., package==1.0.0)."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
from llama_stack.providers.datatypes import ProviderSpec
|
||||
|
||||
|
@ -444,7 +474,7 @@ class TestGetExternalProvidersFromModule:
|
|||
raise ModuleNotFoundError(name)
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_side_effect):
|
||||
config = StackRunConfig(
|
||||
config = make_stack_config(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
|
@ -564,7 +594,6 @@ class TestGetExternalProvidersFromModule:
|
|||
"""Test when get_provider_spec returns a list of specs."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
from llama_stack.providers.datatypes import ProviderSpec
|
||||
|
||||
|
@ -589,7 +618,7 @@ class TestGetExternalProvidersFromModule:
|
|||
raise ModuleNotFoundError(name)
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_side_effect):
|
||||
config = StackRunConfig(
|
||||
config = make_stack_config(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
|
@ -613,7 +642,6 @@ class TestGetExternalProvidersFromModule:
|
|||
"""Test that list return filters specs by provider_type."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
from llama_stack.providers.datatypes import ProviderSpec
|
||||
|
||||
|
@ -638,7 +666,7 @@ class TestGetExternalProvidersFromModule:
|
|||
raise ModuleNotFoundError(name)
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_side_effect):
|
||||
config = StackRunConfig(
|
||||
config = make_stack_config(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
|
@ -662,7 +690,6 @@ class TestGetExternalProvidersFromModule:
|
|||
"""Test that list return adds multiple different provider_types when config requests them."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
from llama_stack.providers.datatypes import ProviderSpec
|
||||
|
||||
|
@ -688,7 +715,7 @@ class TestGetExternalProvidersFromModule:
|
|||
raise ModuleNotFoundError(name)
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_side_effect):
|
||||
config = StackRunConfig(
|
||||
config = make_stack_config(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
|
@ -718,7 +745,6 @@ class TestGetExternalProvidersFromModule:
|
|||
|
||||
def test_module_not_found_raises_value_error(self, mock_providers):
|
||||
"""Test that ModuleNotFoundError raises ValueError with helpful message."""
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
|
||||
def import_side_effect(name):
|
||||
|
@ -727,7 +753,7 @@ class TestGetExternalProvidersFromModule:
|
|||
raise ModuleNotFoundError(name)
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_side_effect):
|
||||
config = StackRunConfig(
|
||||
config = make_stack_config(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
|
@ -751,7 +777,6 @@ class TestGetExternalProvidersFromModule:
|
|||
"""Test that generic exceptions are properly raised."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
|
||||
def bad_spec():
|
||||
|
@ -765,7 +790,7 @@ class TestGetExternalProvidersFromModule:
|
|||
raise ModuleNotFoundError(name)
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_side_effect):
|
||||
config = StackRunConfig(
|
||||
config = make_stack_config(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
|
@ -787,10 +812,9 @@ class TestGetExternalProvidersFromModule:
|
|||
|
||||
def test_empty_provider_list(self, mock_providers):
|
||||
"""Test with empty provider list."""
|
||||
from llama_stack.core.datatypes import StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
|
||||
config = StackRunConfig(
|
||||
config = make_stack_config(
|
||||
image_name="test_image",
|
||||
providers={},
|
||||
)
|
||||
|
@ -805,7 +829,6 @@ class TestGetExternalProvidersFromModule:
|
|||
"""Test multiple APIs with providers."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
from llama_stack.providers.datatypes import ProviderSpec
|
||||
|
||||
|
@ -830,7 +853,7 @@ class TestGetExternalProvidersFromModule:
|
|||
raise ModuleNotFoundError(name)
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_side_effect):
|
||||
config = StackRunConfig(
|
||||
config = make_stack_config(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
|
|
|
@ -11,11 +11,12 @@ from llama_stack.apis.common.errors import ResourceNotFoundError
|
|||
from llama_stack.apis.common.responses import Order
|
||||
from llama_stack.apis.files import OpenAIFilePurpose
|
||||
from llama_stack.core.access_control.access_control import default_policy
|
||||
from llama_stack.core.storage.datatypes import SqliteSqlStoreConfig, SqlStoreReference
|
||||
from llama_stack.providers.inline.files.localfs import (
|
||||
LocalfsFilesImpl,
|
||||
LocalfsFilesImplConfig,
|
||||
)
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
|
||||
|
||||
|
||||
class MockUploadFile:
|
||||
|
@ -36,8 +37,11 @@ async def files_provider(tmp_path):
|
|||
storage_dir = tmp_path / "files"
|
||||
db_path = tmp_path / "files_metadata.db"
|
||||
|
||||
backend_name = "sql_localfs_test"
|
||||
register_sqlstore_backends({backend_name: SqliteSqlStoreConfig(db_path=db_path.as_posix())})
|
||||
config = LocalfsFilesImplConfig(
|
||||
storage_dir=storage_dir.as_posix(), metadata_store=SqliteSqlStoreConfig(db_path=db_path.as_posix())
|
||||
storage_dir=storage_dir.as_posix(),
|
||||
metadata_store=SqlStoreReference(backend=backend_name, table_name="files_metadata"),
|
||||
)
|
||||
|
||||
provider = LocalfsFilesImpl(config, default_policy())
|
||||
|
|
|
@ -9,7 +9,16 @@ import random
|
|||
import pytest
|
||||
|
||||
from llama_stack.core.prompts.prompts import PromptServiceConfig, PromptServiceImpl
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
from llama_stack.core.storage.datatypes import (
|
||||
InferenceStoreReference,
|
||||
KVStoreReference,
|
||||
ServerStoresConfig,
|
||||
SqliteKVStoreConfig,
|
||||
SqliteSqlStoreConfig,
|
||||
SqlStoreReference,
|
||||
StorageConfig,
|
||||
)
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl, register_kvstore_backends
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -19,12 +28,28 @@ async def temp_prompt_store(tmp_path_factory):
|
|||
db_path = str(temp_dir / f"{unique_id}.db")
|
||||
|
||||
from llama_stack.core.datatypes import StackRunConfig
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
|
||||
mock_run_config = StackRunConfig(image_name="test-distribution", apis=[], providers={})
|
||||
storage = StorageConfig(
|
||||
backends={
|
||||
"kv_test": SqliteKVStoreConfig(db_path=db_path),
|
||||
"sql_test": SqliteSqlStoreConfig(db_path=str(temp_dir / f"{unique_id}_sql.db")),
|
||||
},
|
||||
stores=ServerStoresConfig(
|
||||
metadata=KVStoreReference(backend="kv_test", namespace="registry"),
|
||||
inference=InferenceStoreReference(backend="sql_test", table_name="inference"),
|
||||
conversations=SqlStoreReference(backend="sql_test", table_name="conversations"),
|
||||
),
|
||||
)
|
||||
mock_run_config = StackRunConfig(
|
||||
image_name="test-distribution",
|
||||
apis=[],
|
||||
providers={},
|
||||
storage=storage,
|
||||
)
|
||||
config = PromptServiceConfig(run_config=mock_run_config)
|
||||
store = PromptServiceImpl(config, deps={})
|
||||
|
||||
store.kvstore = await kvstore_impl(SqliteKVStoreConfig(db_path=db_path))
|
||||
register_kvstore_backends({"kv_test": storage.backends["kv_test"]})
|
||||
store.kvstore = await kvstore_impl(KVStoreReference(backend="kv_test", namespace="prompts"))
|
||||
|
||||
yield store
|
||||
|
|
|
@ -26,6 +26,20 @@ from llama_stack.providers.inline.agents.meta_reference.config import MetaRefere
|
|||
from llama_stack.providers.inline.agents.meta_reference.persistence import AgentInfo
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_backends(tmp_path):
|
||||
"""Register KV and SQL store backends for testing."""
|
||||
from llama_stack.core.storage.datatypes import SqliteKVStoreConfig, SqliteSqlStoreConfig
|
||||
from llama_stack.providers.utils.kvstore.kvstore import register_kvstore_backends
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
|
||||
|
||||
kv_path = str(tmp_path / "test_kv.db")
|
||||
sql_path = str(tmp_path / "test_sql.db")
|
||||
|
||||
register_kvstore_backends({"kv_default": SqliteKVStoreConfig(db_path=kv_path)})
|
||||
register_sqlstore_backends({"sql_default": SqliteSqlStoreConfig(db_path=sql_path)})
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_apis():
|
||||
return {
|
||||
|
@ -40,15 +54,20 @@ def mock_apis():
|
|||
|
||||
@pytest.fixture
|
||||
def config(tmp_path):
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference, ResponsesStoreReference
|
||||
from llama_stack.providers.inline.agents.meta_reference.config import AgentPersistenceConfig
|
||||
|
||||
return MetaReferenceAgentsImplConfig(
|
||||
persistence_store={
|
||||
"type": "sqlite",
|
||||
"db_path": str(tmp_path / "test.db"),
|
||||
},
|
||||
responses_store={
|
||||
"type": "sqlite",
|
||||
"db_path": str(tmp_path / "test.db"),
|
||||
},
|
||||
persistence=AgentPersistenceConfig(
|
||||
agent_state=KVStoreReference(
|
||||
backend="kv_default",
|
||||
namespace="agents",
|
||||
),
|
||||
responses=ResponsesStoreReference(
|
||||
backend="sql_default",
|
||||
table_name="responses",
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -42,7 +42,7 @@ from llama_stack.apis.inference import (
|
|||
)
|
||||
from llama_stack.apis.tools.tools import ListToolDefsResponse, ToolDef, ToolGroups, ToolInvocationResult, ToolRuntime
|
||||
from llama_stack.core.access_control.access_control import default_policy
|
||||
from llama_stack.core.datatypes import ResponsesStoreConfig
|
||||
from llama_stack.core.storage.datatypes import ResponsesStoreReference, SqliteSqlStoreConfig
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
|
||||
OpenAIResponsesImpl,
|
||||
)
|
||||
|
@ -50,7 +50,7 @@ from llama_stack.providers.utils.responses.responses_store import (
|
|||
ResponsesStore,
|
||||
_OpenAIResponseObjectWithInputAndMessages,
|
||||
)
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
|
||||
from tests.unit.providers.agents.meta_reference.fixtures import load_chat_completion_fixture
|
||||
|
||||
|
||||
|
@ -917,8 +917,10 @@ async def test_responses_store_list_input_items_logic():
|
|||
|
||||
# Create mock store and response store
|
||||
mock_sql_store = AsyncMock()
|
||||
backend_name = "sql_responses_test"
|
||||
register_sqlstore_backends({backend_name: SqliteSqlStoreConfig(db_path="mock_db_path")})
|
||||
responses_store = ResponsesStore(
|
||||
ResponsesStoreConfig(sql_store_config=SqliteSqlStoreConfig(db_path="mock_db_path")), policy=default_policy()
|
||||
ResponsesStoreReference(backend=backend_name, table_name="responses"), policy=default_policy()
|
||||
)
|
||||
responses_store.sql_store = mock_sql_store
|
||||
|
||||
|
|
|
@ -12,10 +12,10 @@ from unittest.mock import AsyncMock
|
|||
|
||||
import pytest
|
||||
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig
|
||||
from llama_stack.providers.inline.batches.reference.batches import ReferenceBatchesImpl
|
||||
from llama_stack.providers.inline.batches.reference.config import ReferenceBatchesImplConfig
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl, register_kvstore_backends
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -23,8 +23,10 @@ async def provider():
|
|||
"""Create a test provider instance with temporary database."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
db_path = Path(tmpdir) / "test_batches.db"
|
||||
backend_name = "kv_batches_test"
|
||||
kvstore_config = SqliteKVStoreConfig(db_path=str(db_path))
|
||||
config = ReferenceBatchesImplConfig(kvstore=kvstore_config)
|
||||
register_kvstore_backends({backend_name: kvstore_config})
|
||||
config = ReferenceBatchesImplConfig(kvstore=KVStoreReference(backend=backend_name, namespace="batches"))
|
||||
|
||||
# Create kvstore and mock APIs
|
||||
kvstore = await kvstore_impl(config.kvstore)
|
||||
|
|
|
@ -8,8 +8,9 @@ import boto3
|
|||
import pytest
|
||||
from moto import mock_aws
|
||||
|
||||
from llama_stack.core.storage.datatypes import SqliteSqlStoreConfig, SqlStoreReference
|
||||
from llama_stack.providers.remote.files.s3 import S3FilesImplConfig, get_adapter_impl
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
|
||||
|
||||
|
||||
class MockUploadFile:
|
||||
|
@ -38,11 +39,13 @@ def sample_text_file2():
|
|||
def s3_config(tmp_path):
|
||||
db_path = tmp_path / "s3_files_metadata.db"
|
||||
|
||||
backend_name = f"sql_s3_{tmp_path.name}"
|
||||
register_sqlstore_backends({backend_name: SqliteSqlStoreConfig(db_path=db_path.as_posix())})
|
||||
return S3FilesImplConfig(
|
||||
bucket_name=f"test-bucket-{tmp_path.name}",
|
||||
region="not-a-region",
|
||||
auto_create_bucket=True,
|
||||
metadata_store=SqliteSqlStoreConfig(db_path=db_path.as_posix()),
|
||||
metadata_store=SqlStoreReference(backend=backend_name, table_name="s3_files_metadata"),
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -12,13 +12,14 @@ import pytest
|
|||
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig
|
||||
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
|
||||
from llama_stack.providers.inline.vector_io.faiss.faiss import FaissIndex, FaissVectorIOAdapter
|
||||
from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig
|
||||
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter
|
||||
from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig
|
||||
from llama_stack.providers.remote.vector_io.pgvector.pgvector import PGVectorIndex, PGVectorVectorIOAdapter
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
from llama_stack.providers.utils.kvstore import register_kvstore_backends
|
||||
|
||||
EMBEDDING_DIMENSION = 768
|
||||
COLLECTION_PREFIX = "test_collection"
|
||||
|
@ -112,8 +113,9 @@ async def unique_kvstore_config(tmp_path_factory):
|
|||
unique_id = f"test_kv_{np.random.randint(1e6)}"
|
||||
temp_dir = tmp_path_factory.getbasetemp()
|
||||
db_path = str(temp_dir / f"{unique_id}.db")
|
||||
|
||||
return SqliteKVStoreConfig(db_path=db_path)
|
||||
backend_name = f"kv_vector_{unique_id}"
|
||||
register_kvstore_backends({backend_name: SqliteKVStoreConfig(db_path=db_path)})
|
||||
return KVStoreReference(backend=backend_name, namespace=f"vector_io::{unique_id}")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
|
@ -138,7 +140,7 @@ async def sqlite_vec_vec_index(embedding_dimension, tmp_path_factory):
|
|||
async def sqlite_vec_adapter(sqlite_vec_db_path, unique_kvstore_config, mock_inference_api, embedding_dimension):
|
||||
config = SQLiteVectorIOConfig(
|
||||
db_path=sqlite_vec_db_path,
|
||||
kvstore=unique_kvstore_config,
|
||||
persistence=unique_kvstore_config,
|
||||
)
|
||||
adapter = SQLiteVecVectorIOAdapter(
|
||||
config=config,
|
||||
|
@ -177,7 +179,7 @@ async def faiss_vec_index(embedding_dimension):
|
|||
@pytest.fixture
|
||||
async def faiss_vec_adapter(unique_kvstore_config, mock_inference_api, embedding_dimension):
|
||||
config = FaissVectorIOConfig(
|
||||
kvstore=unique_kvstore_config,
|
||||
persistence=unique_kvstore_config,
|
||||
)
|
||||
adapter = FaissVectorIOAdapter(
|
||||
config=config,
|
||||
|
@ -253,7 +255,7 @@ async def pgvector_vec_adapter(unique_kvstore_config, mock_inference_api, embedd
|
|||
db="test_db",
|
||||
user="test_user",
|
||||
password="test_password",
|
||||
kvstore=unique_kvstore_config,
|
||||
persistence=unique_kvstore_config,
|
||||
)
|
||||
|
||||
adapter = PGVectorVectorIOAdapter(config, mock_inference_api, None)
|
||||
|
|
|
@ -10,13 +10,13 @@ import pytest
|
|||
from llama_stack.apis.inference import Model
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.core.datatypes import VectorDBWithOwner
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig
|
||||
from llama_stack.core.store.registry import (
|
||||
KEY_FORMAT,
|
||||
CachedDiskDistributionRegistry,
|
||||
DiskDistributionRegistry,
|
||||
)
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl, register_kvstore_backends
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -72,7 +72,11 @@ async def test_cached_registry_initialization(sqlite_kvstore, sample_vector_db,
|
|||
|
||||
# Test cached version loads from disk
|
||||
db_path = sqlite_kvstore.db_path
|
||||
cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(SqliteKVStoreConfig(db_path=db_path)))
|
||||
backend_name = "kv_cached_test"
|
||||
register_kvstore_backends({backend_name: SqliteKVStoreConfig(db_path=db_path)})
|
||||
cached_registry = CachedDiskDistributionRegistry(
|
||||
await kvstore_impl(KVStoreReference(backend=backend_name, namespace="registry"))
|
||||
)
|
||||
await cached_registry.initialize()
|
||||
|
||||
result_vector_db = await cached_registry.get("vector_db", "test_vector_db")
|
||||
|
@ -101,7 +105,11 @@ async def test_cached_registry_updates(cached_disk_dist_registry):
|
|||
|
||||
# Verify persisted to disk
|
||||
db_path = cached_disk_dist_registry.kvstore.db_path
|
||||
new_registry = DiskDistributionRegistry(await kvstore_impl(SqliteKVStoreConfig(db_path=db_path)))
|
||||
backend_name = "kv_cached_new"
|
||||
register_kvstore_backends({backend_name: SqliteKVStoreConfig(db_path=db_path)})
|
||||
new_registry = DiskDistributionRegistry(
|
||||
await kvstore_impl(KVStoreReference(backend=backend_name, namespace="registry"))
|
||||
)
|
||||
await new_registry.initialize()
|
||||
result_vector_db = await new_registry.get("vector_db", "test_vector_db_2")
|
||||
assert result_vector_db is not None
|
||||
|
|
|
@ -4,6 +4,8 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.testclient import TestClient
|
||||
|
@ -11,7 +13,8 @@ from starlette.middleware.base import BaseHTTPMiddleware
|
|||
|
||||
from llama_stack.core.datatypes import QuotaConfig, QuotaPeriod
|
||||
from llama_stack.core.server.quota import QuotaMiddleware
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig
|
||||
from llama_stack.providers.utils.kvstore import register_kvstore_backends
|
||||
|
||||
|
||||
class InjectClientIDMiddleware(BaseHTTPMiddleware):
|
||||
|
@ -29,8 +32,10 @@ class InjectClientIDMiddleware(BaseHTTPMiddleware):
|
|||
|
||||
|
||||
def build_quota_config(db_path) -> QuotaConfig:
|
||||
backend_name = f"kv_quota_{uuid4().hex}"
|
||||
register_kvstore_backends({backend_name: SqliteKVStoreConfig(db_path=str(db_path))})
|
||||
return QuotaConfig(
|
||||
kvstore=SqliteKVStoreConfig(db_path=str(db_path)),
|
||||
kvstore=KVStoreReference(backend=backend_name, namespace="quota"),
|
||||
anonymous_max_requests=1,
|
||||
authenticated_max_requests=2,
|
||||
period=QuotaPeriod.DAY,
|
||||
|
|
|
@ -12,15 +12,22 @@ from unittest.mock import AsyncMock, MagicMock
|
|||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.core.datatypes import (
|
||||
Api,
|
||||
Provider,
|
||||
StackRunConfig,
|
||||
)
|
||||
from llama_stack.core.datatypes import Api, Provider, StackRunConfig
|
||||
from llama_stack.core.resolver import resolve_impls
|
||||
from llama_stack.core.routers.inference import InferenceRouter
|
||||
from llama_stack.core.routing_tables.models import ModelsRoutingTable
|
||||
from llama_stack.core.storage.datatypes import (
|
||||
InferenceStoreReference,
|
||||
KVStoreReference,
|
||||
ServerStoresConfig,
|
||||
SqliteKVStoreConfig,
|
||||
SqliteSqlStoreConfig,
|
||||
SqlStoreReference,
|
||||
StorageConfig,
|
||||
)
|
||||
from llama_stack.providers.datatypes import InlineProviderSpec, ProviderSpec
|
||||
from llama_stack.providers.utils.kvstore import register_kvstore_backends
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
|
||||
|
||||
|
||||
def add_protocol_methods(cls: type, protocol: type[Protocol]) -> None:
|
||||
|
@ -65,6 +72,35 @@ class SampleImpl:
|
|||
pass
|
||||
|
||||
|
||||
def make_run_config(**overrides) -> StackRunConfig:
|
||||
storage = overrides.pop(
|
||||
"storage",
|
||||
StorageConfig(
|
||||
backends={
|
||||
"kv_default": SqliteKVStoreConfig(db_path=":memory:"),
|
||||
"sql_default": SqliteSqlStoreConfig(db_path=":memory:"),
|
||||
},
|
||||
stores=ServerStoresConfig(
|
||||
metadata=KVStoreReference(backend="kv_default", namespace="registry"),
|
||||
inference=InferenceStoreReference(backend="sql_default", table_name="inference_store"),
|
||||
conversations=SqlStoreReference(backend="sql_default", table_name="conversations"),
|
||||
),
|
||||
),
|
||||
)
|
||||
register_kvstore_backends({name: cfg for name, cfg in storage.backends.items() if cfg.type.value.startswith("kv_")})
|
||||
register_sqlstore_backends(
|
||||
{name: cfg for name, cfg in storage.backends.items() if cfg.type.value.startswith("sql_")}
|
||||
)
|
||||
defaults = dict(
|
||||
image_name="test_image",
|
||||
apis=[],
|
||||
providers={},
|
||||
storage=storage,
|
||||
)
|
||||
defaults.update(overrides)
|
||||
return StackRunConfig(**defaults)
|
||||
|
||||
|
||||
async def test_resolve_impls_basic():
|
||||
# Create a real provider spec
|
||||
provider_spec = InlineProviderSpec(
|
||||
|
@ -78,7 +114,7 @@ async def test_resolve_impls_basic():
|
|||
# Create provider registry with our provider
|
||||
provider_registry = {Api.inference: {provider_spec.provider_type: provider_spec}}
|
||||
|
||||
run_config = StackRunConfig(
|
||||
run_config = make_run_config(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import time
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
import pytest
|
||||
|
||||
|
@ -16,8 +15,16 @@ from llama_stack.apis.inference import (
|
|||
OpenAIUserMessageParam,
|
||||
Order,
|
||||
)
|
||||
from llama_stack.core.storage.datatypes import InferenceStoreReference, SqliteSqlStoreConfig
|
||||
from llama_stack.providers.utils.inference.inference_store import InferenceStore
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_backends(tmp_path):
|
||||
"""Register SQL store backends for testing."""
|
||||
db_path = str(tmp_path / "test.db")
|
||||
register_sqlstore_backends({"sql_default": SqliteSqlStoreConfig(db_path=db_path)})
|
||||
|
||||
|
||||
def create_test_chat_completion(
|
||||
|
@ -44,167 +51,162 @@ def create_test_chat_completion(
|
|||
|
||||
async def test_inference_store_pagination_basic():
|
||||
"""Test basic pagination functionality."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||
await store.initialize()
|
||||
reference = InferenceStoreReference(backend="sql_default", table_name="chat_completions")
|
||||
store = InferenceStore(reference, policy=[])
|
||||
await store.initialize()
|
||||
|
||||
# Create test data with different timestamps
|
||||
base_time = int(time.time())
|
||||
test_data = [
|
||||
("zebra-task", base_time + 1),
|
||||
("apple-job", base_time + 2),
|
||||
("moon-work", base_time + 3),
|
||||
("banana-run", base_time + 4),
|
||||
("car-exec", base_time + 5),
|
||||
]
|
||||
# Create test data with different timestamps
|
||||
base_time = int(time.time())
|
||||
test_data = [
|
||||
("zebra-task", base_time + 1),
|
||||
("apple-job", base_time + 2),
|
||||
("moon-work", base_time + 3),
|
||||
("banana-run", base_time + 4),
|
||||
("car-exec", base_time + 5),
|
||||
]
|
||||
|
||||
# Store test chat completions
|
||||
for completion_id, timestamp in test_data:
|
||||
completion = create_test_chat_completion(completion_id, timestamp)
|
||||
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
|
||||
await store.store_chat_completion(completion, input_messages)
|
||||
# Store test chat completions
|
||||
for completion_id, timestamp in test_data:
|
||||
completion = create_test_chat_completion(completion_id, timestamp)
|
||||
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
|
||||
await store.store_chat_completion(completion, input_messages)
|
||||
|
||||
# Wait for all queued writes to complete
|
||||
await store.flush()
|
||||
# Wait for all queued writes to complete
|
||||
await store.flush()
|
||||
|
||||
# Test 1: First page with limit=2, descending order (default)
|
||||
result = await store.list_chat_completions(limit=2, order=Order.desc)
|
||||
assert len(result.data) == 2
|
||||
assert result.data[0].id == "car-exec" # Most recent first
|
||||
assert result.data[1].id == "banana-run"
|
||||
assert result.has_more is True
|
||||
assert result.last_id == "banana-run"
|
||||
# Test 1: First page with limit=2, descending order (default)
|
||||
result = await store.list_chat_completions(limit=2, order=Order.desc)
|
||||
assert len(result.data) == 2
|
||||
assert result.data[0].id == "car-exec" # Most recent first
|
||||
assert result.data[1].id == "banana-run"
|
||||
assert result.has_more is True
|
||||
assert result.last_id == "banana-run"
|
||||
|
||||
# Test 2: Second page using 'after' parameter
|
||||
result2 = await store.list_chat_completions(after="banana-run", limit=2, order=Order.desc)
|
||||
assert len(result2.data) == 2
|
||||
assert result2.data[0].id == "moon-work"
|
||||
assert result2.data[1].id == "apple-job"
|
||||
assert result2.has_more is True
|
||||
# Test 2: Second page using 'after' parameter
|
||||
result2 = await store.list_chat_completions(after="banana-run", limit=2, order=Order.desc)
|
||||
assert len(result2.data) == 2
|
||||
assert result2.data[0].id == "moon-work"
|
||||
assert result2.data[1].id == "apple-job"
|
||||
assert result2.has_more is True
|
||||
|
||||
# Test 3: Final page
|
||||
result3 = await store.list_chat_completions(after="apple-job", limit=2, order=Order.desc)
|
||||
assert len(result3.data) == 1
|
||||
assert result3.data[0].id == "zebra-task"
|
||||
assert result3.has_more is False
|
||||
# Test 3: Final page
|
||||
result3 = await store.list_chat_completions(after="apple-job", limit=2, order=Order.desc)
|
||||
assert len(result3.data) == 1
|
||||
assert result3.data[0].id == "zebra-task"
|
||||
assert result3.has_more is False
|
||||
|
||||
|
||||
async def test_inference_store_pagination_ascending():
|
||||
"""Test pagination with ascending order."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||
await store.initialize()
|
||||
reference = InferenceStoreReference(backend="sql_default", table_name="chat_completions")
|
||||
store = InferenceStore(reference, policy=[])
|
||||
await store.initialize()
|
||||
|
||||
# Create test data
|
||||
base_time = int(time.time())
|
||||
test_data = [
|
||||
("delta-item", base_time + 1),
|
||||
("charlie-task", base_time + 2),
|
||||
("alpha-work", base_time + 3),
|
||||
]
|
||||
# Create test data
|
||||
base_time = int(time.time())
|
||||
test_data = [
|
||||
("delta-item", base_time + 1),
|
||||
("charlie-task", base_time + 2),
|
||||
("alpha-work", base_time + 3),
|
||||
]
|
||||
|
||||
# Store test chat completions
|
||||
for completion_id, timestamp in test_data:
|
||||
completion = create_test_chat_completion(completion_id, timestamp)
|
||||
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
|
||||
await store.store_chat_completion(completion, input_messages)
|
||||
# Store test chat completions
|
||||
for completion_id, timestamp in test_data:
|
||||
completion = create_test_chat_completion(completion_id, timestamp)
|
||||
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
|
||||
await store.store_chat_completion(completion, input_messages)
|
||||
|
||||
# Wait for all queued writes to complete
|
||||
await store.flush()
|
||||
# Wait for all queued writes to complete
|
||||
await store.flush()
|
||||
|
||||
# Test ascending order pagination
|
||||
result = await store.list_chat_completions(limit=1, order=Order.asc)
|
||||
assert len(result.data) == 1
|
||||
assert result.data[0].id == "delta-item" # Oldest first
|
||||
assert result.has_more is True
|
||||
# Test ascending order pagination
|
||||
result = await store.list_chat_completions(limit=1, order=Order.asc)
|
||||
assert len(result.data) == 1
|
||||
assert result.data[0].id == "delta-item" # Oldest first
|
||||
assert result.has_more is True
|
||||
|
||||
# Second page with ascending order
|
||||
result2 = await store.list_chat_completions(after="delta-item", limit=1, order=Order.asc)
|
||||
assert len(result2.data) == 1
|
||||
assert result2.data[0].id == "charlie-task"
|
||||
assert result2.has_more is True
|
||||
# Second page with ascending order
|
||||
result2 = await store.list_chat_completions(after="delta-item", limit=1, order=Order.asc)
|
||||
assert len(result2.data) == 1
|
||||
assert result2.data[0].id == "charlie-task"
|
||||
assert result2.has_more is True
|
||||
|
||||
|
||||
async def test_inference_store_pagination_with_model_filter():
|
||||
"""Test pagination combined with model filtering."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||
await store.initialize()
|
||||
reference = InferenceStoreReference(backend="sql_default", table_name="chat_completions")
|
||||
store = InferenceStore(reference, policy=[])
|
||||
await store.initialize()
|
||||
|
||||
# Create test data with different models
|
||||
base_time = int(time.time())
|
||||
test_data = [
|
||||
("xyz-task", base_time + 1, "model-a"),
|
||||
("def-work", base_time + 2, "model-b"),
|
||||
("pqr-job", base_time + 3, "model-a"),
|
||||
("abc-run", base_time + 4, "model-b"),
|
||||
]
|
||||
# Create test data with different models
|
||||
base_time = int(time.time())
|
||||
test_data = [
|
||||
("xyz-task", base_time + 1, "model-a"),
|
||||
("def-work", base_time + 2, "model-b"),
|
||||
("pqr-job", base_time + 3, "model-a"),
|
||||
("abc-run", base_time + 4, "model-b"),
|
||||
]
|
||||
|
||||
# Store test chat completions
|
||||
for completion_id, timestamp, model in test_data:
|
||||
completion = create_test_chat_completion(completion_id, timestamp, model)
|
||||
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
|
||||
await store.store_chat_completion(completion, input_messages)
|
||||
# Store test chat completions
|
||||
for completion_id, timestamp, model in test_data:
|
||||
completion = create_test_chat_completion(completion_id, timestamp, model)
|
||||
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
|
||||
await store.store_chat_completion(completion, input_messages)
|
||||
|
||||
# Wait for all queued writes to complete
|
||||
await store.flush()
|
||||
# Wait for all queued writes to complete
|
||||
await store.flush()
|
||||
|
||||
# Test pagination with model filter
|
||||
result = await store.list_chat_completions(limit=1, model="model-a", order=Order.desc)
|
||||
assert len(result.data) == 1
|
||||
assert result.data[0].id == "pqr-job" # Most recent model-a
|
||||
assert result.data[0].model == "model-a"
|
||||
assert result.has_more is True
|
||||
# Test pagination with model filter
|
||||
result = await store.list_chat_completions(limit=1, model="model-a", order=Order.desc)
|
||||
assert len(result.data) == 1
|
||||
assert result.data[0].id == "pqr-job" # Most recent model-a
|
||||
assert result.data[0].model == "model-a"
|
||||
assert result.has_more is True
|
||||
|
||||
# Second page with model filter
|
||||
result2 = await store.list_chat_completions(after="pqr-job", limit=1, model="model-a", order=Order.desc)
|
||||
assert len(result2.data) == 1
|
||||
assert result2.data[0].id == "xyz-task"
|
||||
assert result2.data[0].model == "model-a"
|
||||
assert result2.has_more is False
|
||||
# Second page with model filter
|
||||
result2 = await store.list_chat_completions(after="pqr-job", limit=1, model="model-a", order=Order.desc)
|
||||
assert len(result2.data) == 1
|
||||
assert result2.data[0].id == "xyz-task"
|
||||
assert result2.data[0].model == "model-a"
|
||||
assert result2.has_more is False
|
||||
|
||||
|
||||
async def test_inference_store_pagination_invalid_after():
|
||||
"""Test error handling for invalid 'after' parameter."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||
await store.initialize()
|
||||
reference = InferenceStoreReference(backend="sql_default", table_name="chat_completions")
|
||||
store = InferenceStore(reference, policy=[])
|
||||
await store.initialize()
|
||||
|
||||
# Try to paginate with non-existent ID
|
||||
with pytest.raises(ValueError, match="Record with id='non-existent' not found in table 'chat_completions'"):
|
||||
await store.list_chat_completions(after="non-existent", limit=2)
|
||||
# Try to paginate with non-existent ID
|
||||
with pytest.raises(ValueError, match="Record with id='non-existent' not found in table 'chat_completions'"):
|
||||
await store.list_chat_completions(after="non-existent", limit=2)
|
||||
|
||||
|
||||
async def test_inference_store_pagination_no_limit():
|
||||
"""Test pagination behavior when no limit is specified."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||
await store.initialize()
|
||||
reference = InferenceStoreReference(backend="sql_default", table_name="chat_completions")
|
||||
store = InferenceStore(reference, policy=[])
|
||||
await store.initialize()
|
||||
|
||||
# Create test data
|
||||
base_time = int(time.time())
|
||||
test_data = [
|
||||
("omega-first", base_time + 1),
|
||||
("beta-second", base_time + 2),
|
||||
]
|
||||
# Create test data
|
||||
base_time = int(time.time())
|
||||
test_data = [
|
||||
("omega-first", base_time + 1),
|
||||
("beta-second", base_time + 2),
|
||||
]
|
||||
|
||||
# Store test chat completions
|
||||
for completion_id, timestamp in test_data:
|
||||
completion = create_test_chat_completion(completion_id, timestamp)
|
||||
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
|
||||
await store.store_chat_completion(completion, input_messages)
|
||||
# Store test chat completions
|
||||
for completion_id, timestamp in test_data:
|
||||
completion = create_test_chat_completion(completion_id, timestamp)
|
||||
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
|
||||
await store.store_chat_completion(completion, input_messages)
|
||||
|
||||
# Wait for all queued writes to complete
|
||||
await store.flush()
|
||||
# Wait for all queued writes to complete
|
||||
await store.flush()
|
||||
|
||||
# Test without limit
|
||||
result = await store.list_chat_completions(order=Order.desc)
|
||||
assert len(result.data) == 2
|
||||
assert result.data[0].id == "beta-second" # Most recent first
|
||||
assert result.data[1].id == "omega-first"
|
||||
assert result.has_more is False
|
||||
# Test without limit
|
||||
result = await store.list_chat_completions(order=Order.desc)
|
||||
assert len(result.data) == 2
|
||||
assert result.data[0].id == "beta-second" # Most recent first
|
||||
assert result.data[1].id == "omega-first"
|
||||
assert result.has_more is False
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
|
||||
import time
|
||||
from tempfile import TemporaryDirectory
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
|
@ -15,8 +16,18 @@ from llama_stack.apis.agents.openai_responses import (
|
|||
OpenAIResponseObject,
|
||||
)
|
||||
from llama_stack.apis.inference import OpenAIMessageParam, OpenAIUserMessageParam
|
||||
from llama_stack.core.storage.datatypes import ResponsesStoreReference, SqliteSqlStoreConfig
|
||||
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
|
||||
|
||||
|
||||
def build_store(db_path: str, policy: list | None = None) -> ResponsesStore:
|
||||
backend_name = f"sql_responses_{uuid4().hex}"
|
||||
register_sqlstore_backends({backend_name: SqliteSqlStoreConfig(db_path=db_path)})
|
||||
return ResponsesStore(
|
||||
ResponsesStoreReference(backend=backend_name, table_name="responses"),
|
||||
policy=policy or [],
|
||||
)
|
||||
|
||||
|
||||
def create_test_response_object(
|
||||
|
@ -54,7 +65,7 @@ async def test_responses_store_pagination_basic():
|
|||
"""Test basic pagination functionality for responses store."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||
store = build_store(db_path)
|
||||
await store.initialize()
|
||||
|
||||
# Create test data with different timestamps
|
||||
|
@ -103,7 +114,7 @@ async def test_responses_store_pagination_ascending():
|
|||
"""Test pagination with ascending order."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||
store = build_store(db_path)
|
||||
await store.initialize()
|
||||
|
||||
# Create test data
|
||||
|
@ -141,7 +152,7 @@ async def test_responses_store_pagination_with_model_filter():
|
|||
"""Test pagination combined with model filtering."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||
store = build_store(db_path)
|
||||
await store.initialize()
|
||||
|
||||
# Create test data with different models
|
||||
|
@ -182,7 +193,7 @@ async def test_responses_store_pagination_invalid_after():
|
|||
"""Test error handling for invalid 'after' parameter."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||
store = build_store(db_path)
|
||||
await store.initialize()
|
||||
|
||||
# Try to paginate with non-existent ID
|
||||
|
@ -194,7 +205,7 @@ async def test_responses_store_pagination_no_limit():
|
|||
"""Test pagination behavior when no limit is specified."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||
store = build_store(db_path)
|
||||
await store.initialize()
|
||||
|
||||
# Create test data
|
||||
|
@ -226,7 +237,7 @@ async def test_responses_store_get_response_object():
|
|||
"""Test retrieving a single response object."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||
store = build_store(db_path)
|
||||
await store.initialize()
|
||||
|
||||
# Store a test response
|
||||
|
@ -254,7 +265,7 @@ async def test_responses_store_input_items_pagination():
|
|||
"""Test pagination functionality for input items."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||
store = build_store(db_path)
|
||||
await store.initialize()
|
||||
|
||||
# Store a test response with many inputs with explicit IDs
|
||||
|
@ -335,7 +346,7 @@ async def test_responses_store_input_items_before_pagination():
|
|||
"""Test before pagination functionality for input items."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||
store = build_store(db_path)
|
||||
await store.initialize()
|
||||
|
||||
# Store a test response with many inputs with explicit IDs
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue