From a3580e6bc012535a43e0b08bfae4f6e6563a4bbd Mon Sep 17 00:00:00 2001 From: Anastas Stoyanovsky Date: Tue, 18 Nov 2025 14:25:08 -0500 Subject: [PATCH 1/3] feat!: Wire through parallel_tool_calls to Responses API (#4124) # What does this PR do? Initial PR against #4123 Adds `parallel_tool_calls` spec to Responses API and basic initial implementation where no more than one function call is generated when set to `False`. ## Test Plan * Unit tests have been added to verify no more than one function call is generated. * A followup PR will verify passing through `parallel_tool_calls` to providers. * A followup PR will address verification and/or implementation of incremental function calling across multiple conversational turns. --------- Signed-off-by: Anastas Stoyanovsky --- client-sdks/stainless/openapi.yml | 19 +++++++++++++------ docs/static/deprecated-llama-stack-spec.yaml | 19 +++++++++++++------ .../static/experimental-llama-stack-spec.yaml | 14 ++++++++------ docs/static/llama-stack-spec.yaml | 19 +++++++++++++------ docs/static/stainless-llama-stack-spec.yaml | 19 +++++++++++++------ .../inline/agents/meta_reference/agents.py | 2 ++ .../responses/openai_responses.py | 4 ++++ .../meta_reference/responses/streaming.py | 4 ++++ src/llama_stack_api/agents.py | 1 + src/llama_stack_api/openai_responses.py | 4 ++-- 10 files changed, 73 insertions(+), 32 deletions(-) diff --git a/client-sdks/stainless/openapi.yml b/client-sdks/stainless/openapi.yml index 3a6735cbc..a6ebc868c 100644 --- a/client-sdks/stainless/openapi.yml +++ b/client-sdks/stainless/openapi.yml @@ -6723,9 +6723,10 @@ components: type: array title: Output parallel_tool_calls: - type: boolean - title: Parallel Tool Calls - default: false + anyOf: + - type: boolean + - type: 'null' + default: true previous_response_id: anyOf: - type: string @@ -7125,6 +7126,11 @@ components: anyOf: - type: string - type: 'null' + parallel_tool_calls: + anyOf: + - type: boolean + - type: 'null' + default: true previous_response_id: anyOf: - type: string @@ -7251,9 +7257,10 @@ components: type: array title: Output parallel_tool_calls: - type: boolean - title: Parallel Tool Calls - default: false + anyOf: + - type: boolean + - type: 'null' + default: true previous_response_id: anyOf: - type: string diff --git a/docs/static/deprecated-llama-stack-spec.yaml b/docs/static/deprecated-llama-stack-spec.yaml index 0bade1866..207af8926 100644 --- a/docs/static/deprecated-llama-stack-spec.yaml +++ b/docs/static/deprecated-llama-stack-spec.yaml @@ -3566,9 +3566,10 @@ components: type: array title: Output parallel_tool_calls: - type: boolean - title: Parallel Tool Calls - default: false + anyOf: + - type: boolean + - type: 'null' + default: true previous_response_id: anyOf: - type: string @@ -3968,6 +3969,11 @@ components: anyOf: - type: string - type: 'null' + parallel_tool_calls: + anyOf: + - type: boolean + - type: 'null' + default: true previous_response_id: anyOf: - type: string @@ -4094,9 +4100,10 @@ components: type: array title: Output parallel_tool_calls: - type: boolean - title: Parallel Tool Calls - default: false + anyOf: + - type: boolean + - type: 'null' + default: true previous_response_id: anyOf: - type: string diff --git a/docs/static/experimental-llama-stack-spec.yaml b/docs/static/experimental-llama-stack-spec.yaml index 4271989d6..f81a93d33 100644 --- a/docs/static/experimental-llama-stack-spec.yaml +++ b/docs/static/experimental-llama-stack-spec.yaml @@ -3263,9 +3263,10 @@ components: type: array title: Output parallel_tool_calls: - type: 
boolean - title: Parallel Tool Calls - default: false + anyOf: + - type: boolean + - type: 'null' + default: true previous_response_id: anyOf: - type: string @@ -3662,9 +3663,10 @@ components: type: array title: Output parallel_tool_calls: - type: boolean - title: Parallel Tool Calls - default: false + anyOf: + - type: boolean + - type: 'null' + default: true previous_response_id: anyOf: - type: string diff --git a/docs/static/llama-stack-spec.yaml b/docs/static/llama-stack-spec.yaml index a12ac342f..816f3d0fb 100644 --- a/docs/static/llama-stack-spec.yaml +++ b/docs/static/llama-stack-spec.yaml @@ -5744,9 +5744,10 @@ components: type: array title: Output parallel_tool_calls: - type: boolean - title: Parallel Tool Calls - default: false + anyOf: + - type: boolean + - type: 'null' + default: true previous_response_id: anyOf: - type: string @@ -6146,6 +6147,11 @@ components: anyOf: - type: string - type: 'null' + parallel_tool_calls: + anyOf: + - type: boolean + - type: 'null' + default: true previous_response_id: anyOf: - type: string @@ -6272,9 +6278,10 @@ components: type: array title: Output parallel_tool_calls: - type: boolean - title: Parallel Tool Calls - default: false + anyOf: + - type: boolean + - type: 'null' + default: true previous_response_id: anyOf: - type: string diff --git a/docs/static/stainless-llama-stack-spec.yaml b/docs/static/stainless-llama-stack-spec.yaml index 3a6735cbc..a6ebc868c 100644 --- a/docs/static/stainless-llama-stack-spec.yaml +++ b/docs/static/stainless-llama-stack-spec.yaml @@ -6723,9 +6723,10 @@ components: type: array title: Output parallel_tool_calls: - type: boolean - title: Parallel Tool Calls - default: false + anyOf: + - type: boolean + - type: 'null' + default: true previous_response_id: anyOf: - type: string @@ -7125,6 +7126,11 @@ components: anyOf: - type: string - type: 'null' + parallel_tool_calls: + anyOf: + - type: boolean + - type: 'null' + default: true previous_response_id: anyOf: - type: string @@ -7251,9 +7257,10 @@ components: type: array title: Output parallel_tool_calls: - type: boolean - title: Parallel Tool Calls - default: false + anyOf: + - type: boolean + - type: 'null' + default: true previous_response_id: anyOf: - type: string diff --git a/src/llama_stack/providers/inline/agents/meta_reference/agents.py b/src/llama_stack/providers/inline/agents/meta_reference/agents.py index 347f6fdb1..e47e757be 100644 --- a/src/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/src/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -92,6 +92,7 @@ class MetaReferenceAgentsImpl(Agents): model: str, prompt: OpenAIResponsePrompt | None = None, instructions: str | None = None, + parallel_tool_calls: bool | None = True, previous_response_id: str | None = None, conversation: str | None = None, store: bool | None = True, @@ -120,6 +121,7 @@ class MetaReferenceAgentsImpl(Agents): include, max_infer_iters, guardrails, + parallel_tool_calls, max_tool_calls, ) return result # type: ignore[no-any-return] diff --git a/src/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py b/src/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py index cb0fe284e..7e080a675 100644 --- a/src/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +++ b/src/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py @@ -252,6 +252,7 @@ class OpenAIResponsesImpl: include: list[str] | None = None, max_infer_iters: int | None = 10, guardrails: 
list[str | ResponseGuardrailSpec] | None = None, + parallel_tool_calls: bool | None = None, max_tool_calls: int | None = None, ): stream = bool(stream) @@ -296,6 +297,7 @@ class OpenAIResponsesImpl: tools=tools, max_infer_iters=max_infer_iters, guardrail_ids=guardrail_ids, + parallel_tool_calls=parallel_tool_calls, max_tool_calls=max_tool_calls, ) @@ -346,6 +348,7 @@ class OpenAIResponsesImpl: tools: list[OpenAIResponseInputTool] | None = None, max_infer_iters: int | None = 10, guardrail_ids: list[str] | None = None, + parallel_tool_calls: bool | None = True, max_tool_calls: int | None = None, ) -> AsyncIterator[OpenAIResponseObjectStream]: # These should never be None when called from create_openai_response (which sets defaults) @@ -385,6 +388,7 @@ class OpenAIResponsesImpl: created_at=created_at, text=text, max_infer_iters=max_infer_iters, + parallel_tool_calls=parallel_tool_calls, tool_executor=self.tool_executor, safety_api=self.safety_api, guardrail_ids=guardrail_ids, diff --git a/src/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py b/src/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py index 95c690147..cdbd87244 100644 --- a/src/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +++ b/src/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py @@ -114,6 +114,7 @@ class StreamingResponseOrchestrator: safety_api, guardrail_ids: list[str] | None = None, prompt: OpenAIResponsePrompt | None = None, + parallel_tool_calls: bool | None = None, max_tool_calls: int | None = None, ): self.inference_api = inference_api @@ -128,6 +129,8 @@ class StreamingResponseOrchestrator: self.prompt = prompt # System message that is inserted into the model's context self.instructions = instructions + # Whether to allow more than one function tool call generated per turn. + self.parallel_tool_calls = parallel_tool_calls # Max number of total calls to built-in tools that can be processed in a response self.max_tool_calls = max_tool_calls self.sequence_number = 0 @@ -190,6 +193,7 @@ class StreamingResponseOrchestrator: usage=self.accumulated_usage, instructions=self.instructions, prompt=self.prompt, + parallel_tool_calls=self.parallel_tool_calls, max_tool_calls=self.max_tool_calls, ) diff --git a/src/llama_stack_api/agents.py b/src/llama_stack_api/agents.py index ca0611746..9b767608a 100644 --- a/src/llama_stack_api/agents.py +++ b/src/llama_stack_api/agents.py @@ -72,6 +72,7 @@ class Agents(Protocol): model: str, prompt: OpenAIResponsePrompt | None = None, instructions: str | None = None, + parallel_tool_calls: bool | None = True, previous_response_id: str | None = None, conversation: str | None = None, store: bool | None = True, diff --git a/src/llama_stack_api/openai_responses.py b/src/llama_stack_api/openai_responses.py index 952418f1c..e20004487 100644 --- a/src/llama_stack_api/openai_responses.py +++ b/src/llama_stack_api/openai_responses.py @@ -585,7 +585,7 @@ class OpenAIResponseObject(BaseModel): :param model: Model identifier used for generation :param object: Object type identifier, always "response" :param output: List of generated output items (messages, tool calls, etc.) - :param parallel_tool_calls: Whether tool calls can be executed in parallel + :param parallel_tool_calls: (Optional) Whether to allow more than one function tool call generated per turn. 
:param previous_response_id: (Optional) ID of the previous response in a conversation :param prompt: (Optional) Reference to a prompt template and its variables. :param status: Current status of the response generation @@ -605,7 +605,7 @@ class OpenAIResponseObject(BaseModel): model: str object: Literal["response"] = "response" output: Sequence[OpenAIResponseOutput] - parallel_tool_calls: bool = False + parallel_tool_calls: bool | None = True previous_response_id: str | None = None prompt: OpenAIResponsePrompt | None = None status: str From bd5ad2963e496e78f6e115dfc9910d55ce2121b5 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Tue, 18 Nov 2025 13:15:16 -0800 Subject: [PATCH 2/3] refactor(storage): make { kvstore, sqlstore } as llama stack "internal" APIs (#4181) These primitives (used both by the Stack as well as provider implementations) can be thought of fruitfully as internal-only APIs which can themselves have multiple implementations. We use the new `llama_stack_api.internal` namespace for this. In addition: the change moves kv/sql store impls, configs, and dependency helpers under `core/storage` ## Testing `pytest tests/unit/utils/test_authorized_sqlstore.py`, other existing CI --- pyproject.toml | 4 + .../core/conversations/conversations.py | 6 +- src/llama_stack/core/prompts/prompts.py | 2 +- src/llama_stack/core/server/quota.py | 4 +- src/llama_stack/core/stack.py | 4 +- .../storage}/kvstore/__init__.py | 2 + .../utils => core/storage}/kvstore/config.py | 0 .../utils => core/storage}/kvstore/kvstore.py | 30 ++-- .../storage}/kvstore/mongodb/__init__.py | 0 .../storage}/kvstore/mongodb/mongodb.py | 2 +- .../storage}/kvstore/postgres/__init__.py | 0 .../storage}/kvstore/postgres/postgres.py | 45 +++--- .../storage}/kvstore/redis/__init__.py | 0 .../storage}/kvstore/redis/redis.py | 57 +++++-- .../storage}/kvstore/sqlite/__init__.py | 0 .../storage}/kvstore/sqlite/sqlite.py | 2 +- .../core/storage/sqlstore/__init__.py | 17 +++ .../storage}/sqlstore/authorized_sqlstore.py | 4 +- .../storage}/sqlstore/sqlalchemy_sqlstore.py | 3 +- .../storage}/sqlstore/sqlstore.py | 3 +- src/llama_stack/core/store/registry.py | 2 +- .../distributions/starter/starter.py | 4 +- src/llama_stack/distributions/template.py | 8 +- .../inline/agents/meta_reference/agents.py | 2 +- .../inline/batches/reference/__init__.py | 2 +- .../inline/batches/reference/batches.py | 2 +- .../inline/datasetio/localfs/datasetio.py | 2 +- .../inline/eval/meta_reference/eval.py | 2 +- .../providers/inline/files/localfs/files.py | 6 +- .../providers/inline/vector_io/faiss/faiss.py | 4 +- .../inline/vector_io/sqlite_vec/sqlite_vec.py | 4 +- src/llama_stack/providers/registry/agents.py | 2 +- src/llama_stack/providers/registry/files.py | 2 +- .../datasetio/huggingface/huggingface.py | 2 +- .../providers/remote/files/openai/files.py | 6 +- .../providers/remote/files/s3/files.py | 6 +- .../remote/vector_io/chroma/chroma.py | 4 +- .../remote/vector_io/milvus/milvus.py | 4 +- .../remote/vector_io/pgvector/pgvector.py | 4 +- .../remote/vector_io/qdrant/qdrant.py | 2 +- .../remote/vector_io/weaviate/weaviate.py | 4 +- .../utils/inference/inference_store.py | 7 +- .../providers/utils/kvstore/sqlite/config.py | 20 --- .../utils/memory/openai_vector_store_mixin.py | 2 +- .../utils/responses/responses_store.py | 7 +- .../providers/utils/sqlstore/api.py | 140 ------------------ .../internal}/__init__.py | 4 + .../internal/kvstore.py} | 5 + src/llama_stack_api/internal/sqlstore.py | 79 ++++++++++ tests/integration/files/test_files.py | 
6 +- .../sqlstore/test_authorized_sqlstore.py | 10 +- .../unit/conversations/test_conversations.py | 2 +- tests/unit/files/test_files.py | 2 +- tests/unit/fixtures.py | 4 +- tests/unit/prompts/prompts/conftest.py | 2 +- .../meta_reference/test_openai_responses.py | 2 +- tests/unit/providers/batches/conftest.py | 2 +- tests/unit/providers/files/conftest.py | 2 +- .../providers/files/test_s3_files_auth.py | 16 +- tests/unit/providers/vector_io/conftest.py | 4 +- tests/unit/registry/test_registry.py | 2 +- tests/unit/server/test_quota.py | 2 +- tests/unit/server/test_resolver.py | 4 +- .../utils/inference/test_inference_store.py | 2 +- .../unit/utils/kvstore/test_sqlite_memory.py | 4 +- .../utils/responses/test_responses_store.py | 2 +- tests/unit/utils/sqlstore/test_sqlstore.py | 6 +- tests/unit/utils/test_authorized_sqlstore.py | 14 +- 68 files changed, 302 insertions(+), 309 deletions(-) rename src/llama_stack/{providers/utils => core/storage}/kvstore/__init__.py (78%) rename src/llama_stack/{providers/utils => core/storage}/kvstore/config.py (100%) rename src/llama_stack/{providers/utils => core/storage}/kvstore/kvstore.py (82%) rename src/llama_stack/{providers/utils => core/storage}/kvstore/mongodb/__init__.py (100%) rename src/llama_stack/{providers/utils => core/storage}/kvstore/mongodb/mongodb.py (98%) rename src/llama_stack/{providers/utils => core/storage}/kvstore/postgres/__init__.py (100%) rename src/llama_stack/{providers/utils => core/storage}/kvstore/postgres/postgres.py (73%) rename src/llama_stack/{providers/utils => core/storage}/kvstore/redis/__init__.py (100%) rename src/llama_stack/{providers/utils => core/storage}/kvstore/redis/redis.py (54%) rename src/llama_stack/{providers/utils => core/storage}/kvstore/sqlite/__init__.py (100%) rename src/llama_stack/{providers/utils => core/storage}/kvstore/sqlite/sqlite.py (99%) create mode 100644 src/llama_stack/core/storage/sqlstore/__init__.py rename src/llama_stack/{providers/utils => core/storage}/sqlstore/authorized_sqlstore.py (99%) rename src/llama_stack/{providers/utils => core/storage}/sqlstore/sqlalchemy_sqlstore.py (99%) rename src/llama_stack/{providers/utils => core/storage}/sqlstore/sqlstore.py (98%) delete mode 100644 src/llama_stack/providers/utils/kvstore/sqlite/config.py delete mode 100644 src/llama_stack/providers/utils/sqlstore/api.py rename src/{llama_stack/providers/utils/sqlstore => llama_stack_api/internal}/__init__.py (65%) rename src/{llama_stack/providers/utils/kvstore/api.py => llama_stack_api/internal/kvstore.py} (89%) create mode 100644 src/llama_stack_api/internal/sqlstore.py diff --git a/pyproject.toml b/pyproject.toml index eea515b09..3e16dc08f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -356,6 +356,10 @@ exclude = [ module = [ "yaml", "fire", + "redis.asyncio", + "psycopg2", + "psycopg2.extras", + "psycopg2.extensions", "torchtune.*", "fairscale.*", "torchvision.*", diff --git a/src/llama_stack/core/conversations/conversations.py b/src/llama_stack/core/conversations/conversations.py index 4cf5a82ee..90402439b 100644 --- a/src/llama_stack/core/conversations/conversations.py +++ b/src/llama_stack/core/conversations/conversations.py @@ -11,10 +11,9 @@ from typing import Any, Literal from pydantic import BaseModel, TypeAdapter from llama_stack.core.datatypes import AccessRule, StackRunConfig +from llama_stack.core.storage.sqlstore.authorized_sqlstore import AuthorizedSqlStore +from llama_stack.core.storage.sqlstore.sqlstore import sqlstore_impl from llama_stack.log import get_logger -from 
llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType -from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore -from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl from llama_stack_api import ( Conversation, ConversationDeletedResource, @@ -25,6 +24,7 @@ from llama_stack_api import ( Conversations, Metadata, ) +from llama_stack_api.internal.sqlstore import ColumnDefinition, ColumnType logger = get_logger(name=__name__, category="openai_conversations") diff --git a/src/llama_stack/core/prompts/prompts.py b/src/llama_stack/core/prompts/prompts.py index 9f532c1cd..ff67ad138 100644 --- a/src/llama_stack/core/prompts/prompts.py +++ b/src/llama_stack/core/prompts/prompts.py @@ -10,7 +10,7 @@ from typing import Any from pydantic import BaseModel from llama_stack.core.datatypes import StackRunConfig -from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl +from llama_stack.core.storage.kvstore import KVStore, kvstore_impl from llama_stack_api import ListPromptsResponse, Prompt, Prompts diff --git a/src/llama_stack/core/server/quota.py b/src/llama_stack/core/server/quota.py index 689f0e4c3..d74d3e89d 100644 --- a/src/llama_stack/core/server/quota.py +++ b/src/llama_stack/core/server/quota.py @@ -11,9 +11,9 @@ from datetime import UTC, datetime, timedelta from starlette.types import ASGIApp, Receive, Scope, Send from llama_stack.core.storage.datatypes import KVStoreReference, StorageBackendType +from llama_stack.core.storage.kvstore.kvstore import _KVSTORE_BACKENDS, kvstore_impl from llama_stack.log import get_logger -from llama_stack.providers.utils.kvstore.api import KVStore -from llama_stack.providers.utils.kvstore.kvstore import _KVSTORE_BACKENDS, kvstore_impl +from llama_stack_api.internal.kvstore import KVStore logger = get_logger(name=__name__, category="core::server") diff --git a/src/llama_stack/core/stack.py b/src/llama_stack/core/stack.py index 00d990cb1..8ba1f2afd 100644 --- a/src/llama_stack/core/stack.py +++ b/src/llama_stack/core/stack.py @@ -385,8 +385,8 @@ def _initialize_storage(run_config: StackRunConfig): else: raise ValueError(f"Unknown storage backend type: {type}") - from llama_stack.providers.utils.kvstore.kvstore import register_kvstore_backends - from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends + from llama_stack.core.storage.kvstore.kvstore import register_kvstore_backends + from llama_stack.core.storage.sqlstore.sqlstore import register_sqlstore_backends register_kvstore_backends(kv_backends) register_sqlstore_backends(sql_backends) diff --git a/src/llama_stack/providers/utils/kvstore/__init__.py b/src/llama_stack/core/storage/kvstore/__init__.py similarity index 78% rename from src/llama_stack/providers/utils/kvstore/__init__.py rename to src/llama_stack/core/storage/kvstore/__init__.py index 470a75d2d..2d60f1508 100644 --- a/src/llama_stack/providers/utils/kvstore/__init__.py +++ b/src/llama_stack/core/storage/kvstore/__init__.py @@ -4,4 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
+from llama_stack_api.internal.kvstore import KVStore as KVStore + from .kvstore import * # noqa: F401, F403 diff --git a/src/llama_stack/providers/utils/kvstore/config.py b/src/llama_stack/core/storage/kvstore/config.py similarity index 100% rename from src/llama_stack/providers/utils/kvstore/config.py rename to src/llama_stack/core/storage/kvstore/config.py diff --git a/src/llama_stack/providers/utils/kvstore/kvstore.py b/src/llama_stack/core/storage/kvstore/kvstore.py similarity index 82% rename from src/llama_stack/providers/utils/kvstore/kvstore.py rename to src/llama_stack/core/storage/kvstore/kvstore.py index 5b8d77102..8ea9282fa 100644 --- a/src/llama_stack/providers/utils/kvstore/kvstore.py +++ b/src/llama_stack/core/storage/kvstore/kvstore.py @@ -13,11 +13,19 @@ from __future__ import annotations import asyncio from collections import defaultdict +from datetime import datetime +from typing import cast -from llama_stack.core.storage.datatypes import KVStoreReference, StorageBackendConfig, StorageBackendType +from llama_stack.core.storage.datatypes import KVStoreReference, StorageBackendConfig +from llama_stack_api.internal.kvstore import KVStore -from .api import KVStore -from .config import KVStoreConfig +from .config import ( + KVStoreConfig, + MongoDBKVStoreConfig, + PostgresKVStoreConfig, + RedisKVStoreConfig, + SqliteKVStoreConfig, +) def kvstore_dependencies(): @@ -33,7 +41,7 @@ def kvstore_dependencies(): class InmemoryKVStoreImpl(KVStore): def __init__(self): - self._store = {} + self._store: dict[str, str] = {} async def initialize(self) -> None: pass @@ -41,7 +49,7 @@ class InmemoryKVStoreImpl(KVStore): async def get(self, key: str) -> str | None: return self._store.get(key) - async def set(self, key: str, value: str) -> None: + async def set(self, key: str, value: str, expiration: datetime | None = None) -> None: self._store[key] = value async def values_in_range(self, start_key: str, end_key: str) -> list[str]: @@ -70,7 +78,8 @@ def register_kvstore_backends(backends: dict[str, StorageBackendConfig]) -> None _KVSTORE_INSTANCES.clear() _KVSTORE_LOCKS.clear() for name, cfg in backends.items(): - _KVSTORE_BACKENDS[name] = cfg + typed_cfg = cast(KVStoreConfig, cfg) + _KVSTORE_BACKENDS[name] = typed_cfg async def kvstore_impl(reference: KVStoreReference) -> KVStore: @@ -94,19 +103,20 @@ async def kvstore_impl(reference: KVStoreReference) -> KVStore: config = backend_config.model_copy() config.namespace = reference.namespace - if config.type == StorageBackendType.KV_REDIS.value: + impl: KVStore + if isinstance(config, RedisKVStoreConfig): from .redis import RedisKVStoreImpl impl = RedisKVStoreImpl(config) - elif config.type == StorageBackendType.KV_SQLITE.value: + elif isinstance(config, SqliteKVStoreConfig): from .sqlite import SqliteKVStoreImpl impl = SqliteKVStoreImpl(config) - elif config.type == StorageBackendType.KV_POSTGRES.value: + elif isinstance(config, PostgresKVStoreConfig): from .postgres import PostgresKVStoreImpl impl = PostgresKVStoreImpl(config) - elif config.type == StorageBackendType.KV_MONGODB.value: + elif isinstance(config, MongoDBKVStoreConfig): from .mongodb import MongoDBKVStoreImpl impl = MongoDBKVStoreImpl(config) diff --git a/src/llama_stack/providers/utils/kvstore/mongodb/__init__.py b/src/llama_stack/core/storage/kvstore/mongodb/__init__.py similarity index 100% rename from src/llama_stack/providers/utils/kvstore/mongodb/__init__.py rename to src/llama_stack/core/storage/kvstore/mongodb/__init__.py diff --git 
a/src/llama_stack/providers/utils/kvstore/mongodb/mongodb.py b/src/llama_stack/core/storage/kvstore/mongodb/mongodb.py similarity index 98% rename from src/llama_stack/providers/utils/kvstore/mongodb/mongodb.py rename to src/llama_stack/core/storage/kvstore/mongodb/mongodb.py index 964c45090..673d6038f 100644 --- a/src/llama_stack/providers/utils/kvstore/mongodb/mongodb.py +++ b/src/llama_stack/core/storage/kvstore/mongodb/mongodb.py @@ -9,8 +9,8 @@ from datetime import datetime from pymongo import AsyncMongoClient from pymongo.asynchronous.collection import AsyncCollection +from llama_stack.core.storage.kvstore import KVStore from llama_stack.log import get_logger -from llama_stack.providers.utils.kvstore import KVStore from ..config import MongoDBKVStoreConfig diff --git a/src/llama_stack/providers/utils/kvstore/postgres/__init__.py b/src/llama_stack/core/storage/kvstore/postgres/__init__.py similarity index 100% rename from src/llama_stack/providers/utils/kvstore/postgres/__init__.py rename to src/llama_stack/core/storage/kvstore/postgres/__init__.py diff --git a/src/llama_stack/providers/utils/kvstore/postgres/postgres.py b/src/llama_stack/core/storage/kvstore/postgres/postgres.py similarity index 73% rename from src/llama_stack/providers/utils/kvstore/postgres/postgres.py rename to src/llama_stack/core/storage/kvstore/postgres/postgres.py index 56d6dbb48..39c3fd2e2 100644 --- a/src/llama_stack/providers/utils/kvstore/postgres/postgres.py +++ b/src/llama_stack/core/storage/kvstore/postgres/postgres.py @@ -6,12 +6,13 @@ from datetime import datetime -import psycopg2 -from psycopg2.extras import DictCursor +import psycopg2 # type: ignore[import-not-found] +from psycopg2.extensions import connection as PGConnection # type: ignore[import-not-found] +from psycopg2.extras import DictCursor # type: ignore[import-not-found] from llama_stack.log import get_logger +from llama_stack_api.internal.kvstore import KVStore -from ..api import KVStore from ..config import PostgresKVStoreConfig log = get_logger(name=__name__, category="providers::utils") @@ -20,12 +21,12 @@ log = get_logger(name=__name__, category="providers::utils") class PostgresKVStoreImpl(KVStore): def __init__(self, config: PostgresKVStoreConfig): self.config = config - self.conn = None - self.cursor = None + self._conn: PGConnection | None = None + self._cursor: DictCursor | None = None async def initialize(self) -> None: try: - self.conn = psycopg2.connect( + self._conn = psycopg2.connect( host=self.config.host, port=self.config.port, database=self.config.db, @@ -34,11 +35,11 @@ class PostgresKVStoreImpl(KVStore): sslmode=self.config.ssl_mode, sslrootcert=self.config.ca_cert_path, ) - self.conn.autocommit = True - self.cursor = self.conn.cursor(cursor_factory=DictCursor) + self._conn.autocommit = True + self._cursor = self._conn.cursor(cursor_factory=DictCursor) # Create table if it doesn't exist - self.cursor.execute( + self._cursor.execute( f""" CREATE TABLE IF NOT EXISTS {self.config.table_name} ( key TEXT PRIMARY KEY, @@ -51,6 +52,11 @@ class PostgresKVStoreImpl(KVStore): log.exception("Could not connect to PostgreSQL database server") raise RuntimeError("Could not connect to PostgreSQL database server") from e + def _cursor_or_raise(self) -> DictCursor: + if self._cursor is None: + raise RuntimeError("Postgres client not initialized") + return self._cursor + def _namespaced_key(self, key: str) -> str: if not self.config.namespace: return key @@ -58,7 +64,8 @@ class PostgresKVStoreImpl(KVStore): async def set(self, key: str, 
value: str, expiration: datetime | None = None) -> None: key = self._namespaced_key(key) - self.cursor.execute( + cursor = self._cursor_or_raise() + cursor.execute( f""" INSERT INTO {self.config.table_name} (key, value, expiration) VALUES (%s, %s, %s) @@ -70,7 +77,8 @@ class PostgresKVStoreImpl(KVStore): async def get(self, key: str) -> str | None: key = self._namespaced_key(key) - self.cursor.execute( + cursor = self._cursor_or_raise() + cursor.execute( f""" SELECT value FROM {self.config.table_name} WHERE key = %s @@ -78,12 +86,13 @@ class PostgresKVStoreImpl(KVStore): """, (key,), ) - result = self.cursor.fetchone() + result = cursor.fetchone() return result[0] if result else None async def delete(self, key: str) -> None: key = self._namespaced_key(key) - self.cursor.execute( + cursor = self._cursor_or_raise() + cursor.execute( f"DELETE FROM {self.config.table_name} WHERE key = %s", (key,), ) @@ -92,7 +101,8 @@ class PostgresKVStoreImpl(KVStore): start_key = self._namespaced_key(start_key) end_key = self._namespaced_key(end_key) - self.cursor.execute( + cursor = self._cursor_or_raise() + cursor.execute( f""" SELECT value FROM {self.config.table_name} WHERE key >= %s AND key < %s @@ -101,14 +111,15 @@ class PostgresKVStoreImpl(KVStore): """, (start_key, end_key), ) - return [row[0] for row in self.cursor.fetchall()] + return [row[0] for row in cursor.fetchall()] async def keys_in_range(self, start_key: str, end_key: str) -> list[str]: start_key = self._namespaced_key(start_key) end_key = self._namespaced_key(end_key) - self.cursor.execute( + cursor = self._cursor_or_raise() + cursor.execute( f"SELECT key FROM {self.config.table_name} WHERE key >= %s AND key < %s", (start_key, end_key), ) - return [row[0] for row in self.cursor.fetchall()] + return [row[0] for row in cursor.fetchall()] diff --git a/src/llama_stack/providers/utils/kvstore/redis/__init__.py b/src/llama_stack/core/storage/kvstore/redis/__init__.py similarity index 100% rename from src/llama_stack/providers/utils/kvstore/redis/__init__.py rename to src/llama_stack/core/storage/kvstore/redis/__init__.py diff --git a/src/llama_stack/providers/utils/kvstore/redis/redis.py b/src/llama_stack/core/storage/kvstore/redis/redis.py similarity index 54% rename from src/llama_stack/providers/utils/kvstore/redis/redis.py rename to src/llama_stack/core/storage/kvstore/redis/redis.py index 3d2d956c3..2b35a22e1 100644 --- a/src/llama_stack/providers/utils/kvstore/redis/redis.py +++ b/src/llama_stack/core/storage/kvstore/redis/redis.py @@ -6,18 +6,25 @@ from datetime import datetime -from redis.asyncio import Redis +from redis.asyncio import Redis # type: ignore[import-not-found] + +from llama_stack_api.internal.kvstore import KVStore -from ..api import KVStore from ..config import RedisKVStoreConfig class RedisKVStoreImpl(KVStore): def __init__(self, config: RedisKVStoreConfig): self.config = config + self._redis: Redis | None = None async def initialize(self) -> None: - self.redis = Redis.from_url(self.config.url) + self._redis = Redis.from_url(self.config.url) + + def _client(self) -> Redis: + if self._redis is None: + raise RuntimeError("Redis client not initialized") + return self._redis def _namespaced_key(self, key: str) -> str: if not self.config.namespace: @@ -26,30 +33,37 @@ class RedisKVStoreImpl(KVStore): async def set(self, key: str, value: str, expiration: datetime | None = None) -> None: key = self._namespaced_key(key) - await self.redis.set(key, value) + client = self._client() + await client.set(key, value) if expiration: - 
await self.redis.expireat(key, expiration) + await client.expireat(key, expiration) async def get(self, key: str) -> str | None: key = self._namespaced_key(key) - value = await self.redis.get(key) + client = self._client() + value = await client.get(key) if value is None: return None - await self.redis.ttl(key) - return value + await client.ttl(key) + if isinstance(value, bytes): + return value.decode("utf-8") + if isinstance(value, str): + return value + return str(value) async def delete(self, key: str) -> None: key = self._namespaced_key(key) - await self.redis.delete(key) + await self._client().delete(key) async def values_in_range(self, start_key: str, end_key: str) -> list[str]: start_key = self._namespaced_key(start_key) end_key = self._namespaced_key(end_key) + client = self._client() cursor = 0 pattern = start_key + "*" # Match all keys starting with start_key prefix - matching_keys = [] + matching_keys: list[str | bytes] = [] while True: - cursor, keys = await self.redis.scan(cursor, match=pattern, count=1000) + cursor, keys = await client.scan(cursor, match=pattern, count=1000) for key in keys: key_str = key.decode("utf-8") if isinstance(key, bytes) else key @@ -61,7 +75,7 @@ class RedisKVStoreImpl(KVStore): # Then fetch all values in a single MGET call if matching_keys: - values = await self.redis.mget(matching_keys) + values = await client.mget(matching_keys) return [ value.decode("utf-8") if isinstance(value, bytes) else value for value in values if value is not None ] @@ -70,7 +84,18 @@ class RedisKVStoreImpl(KVStore): async def keys_in_range(self, start_key: str, end_key: str) -> list[str]: """Get all keys in the given range.""" - matching_keys = await self.redis.zrangebylex(self.namespace, f"[{start_key}", f"[{end_key}") - if not matching_keys: - return [] - return [k.decode("utf-8") for k in matching_keys] + start_key = self._namespaced_key(start_key) + end_key = self._namespaced_key(end_key) + client = self._client() + cursor = 0 + pattern = start_key + "*" + result: list[str] = [] + while True: + cursor, keys = await client.scan(cursor, match=pattern, count=1000) + for key in keys: + key_str = key.decode("utf-8") if isinstance(key, bytes) else str(key) + if start_key <= key_str <= end_key: + result.append(key_str) + if cursor == 0: + break + return result diff --git a/src/llama_stack/providers/utils/kvstore/sqlite/__init__.py b/src/llama_stack/core/storage/kvstore/sqlite/__init__.py similarity index 100% rename from src/llama_stack/providers/utils/kvstore/sqlite/__init__.py rename to src/llama_stack/core/storage/kvstore/sqlite/__init__.py diff --git a/src/llama_stack/providers/utils/kvstore/sqlite/sqlite.py b/src/llama_stack/core/storage/kvstore/sqlite/sqlite.py similarity index 99% rename from src/llama_stack/providers/utils/kvstore/sqlite/sqlite.py rename to src/llama_stack/core/storage/kvstore/sqlite/sqlite.py index a9a7a1304..22cf8ac49 100644 --- a/src/llama_stack/providers/utils/kvstore/sqlite/sqlite.py +++ b/src/llama_stack/core/storage/kvstore/sqlite/sqlite.py @@ -10,8 +10,8 @@ from datetime import datetime import aiosqlite from llama_stack.log import get_logger +from llama_stack_api.internal.kvstore import KVStore -from ..api import KVStore from ..config import SqliteKVStoreConfig logger = get_logger(name=__name__, category="providers::utils") diff --git a/src/llama_stack/core/storage/sqlstore/__init__.py b/src/llama_stack/core/storage/sqlstore/__init__.py new file mode 100644 index 000000000..eb843e4ba --- /dev/null +++ 
b/src/llama_stack/core/storage/sqlstore/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_stack_api.internal.sqlstore import ( + ColumnDefinition as ColumnDefinition, +) +from llama_stack_api.internal.sqlstore import ( + ColumnType as ColumnType, +) +from llama_stack_api.internal.sqlstore import ( + SqlStore as SqlStore, +) + +from .sqlstore import * # noqa: F401,F403 diff --git a/src/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py b/src/llama_stack/core/storage/sqlstore/authorized_sqlstore.py similarity index 99% rename from src/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py rename to src/llama_stack/core/storage/sqlstore/authorized_sqlstore.py index ba95dd120..e6cdcc543 100644 --- a/src/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py +++ b/src/llama_stack/core/storage/sqlstore/authorized_sqlstore.py @@ -14,8 +14,8 @@ from llama_stack.core.datatypes import User from llama_stack.core.request_headers import get_authenticated_user from llama_stack.core.storage.datatypes import StorageBackendType from llama_stack.log import get_logger - -from .api import ColumnDefinition, ColumnType, PaginatedResponse, SqlStore +from llama_stack_api import PaginatedResponse +from llama_stack_api.internal.sqlstore import ColumnDefinition, ColumnType, SqlStore logger = get_logger(name=__name__, category="providers::utils") diff --git a/src/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py b/src/llama_stack/core/storage/sqlstore/sqlalchemy_sqlstore.py similarity index 99% rename from src/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py rename to src/llama_stack/core/storage/sqlstore/sqlalchemy_sqlstore.py index 10009d396..01c561443 100644 --- a/src/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py +++ b/src/llama_stack/core/storage/sqlstore/sqlalchemy_sqlstore.py @@ -29,8 +29,7 @@ from sqlalchemy.sql.elements import ColumnElement from llama_stack.core.storage.datatypes import SqlAlchemySqlStoreConfig from llama_stack.log import get_logger from llama_stack_api import PaginatedResponse - -from .api import ColumnDefinition, ColumnType, SqlStore +from llama_stack_api.internal.sqlstore import ColumnDefinition, ColumnType, SqlStore logger = get_logger(name=__name__, category="providers::utils") diff --git a/src/llama_stack/providers/utils/sqlstore/sqlstore.py b/src/llama_stack/core/storage/sqlstore/sqlstore.py similarity index 98% rename from src/llama_stack/providers/utils/sqlstore/sqlstore.py rename to src/llama_stack/core/storage/sqlstore/sqlstore.py index 9409b7d00..fb2c9d279 100644 --- a/src/llama_stack/providers/utils/sqlstore/sqlstore.py +++ b/src/llama_stack/core/storage/sqlstore/sqlstore.py @@ -16,8 +16,7 @@ from llama_stack.core.storage.datatypes import ( StorageBackendConfig, StorageBackendType, ) - -from .api import SqlStore +from llama_stack_api.internal.sqlstore import SqlStore sql_store_pip_packages = ["sqlalchemy[asyncio]", "aiosqlite", "asyncpg"] diff --git a/src/llama_stack/core/store/registry.py b/src/llama_stack/core/store/registry.py index 6ff9e575b..7144a94f7 100644 --- a/src/llama_stack/core/store/registry.py +++ b/src/llama_stack/core/store/registry.py @@ -12,8 +12,8 @@ import pydantic from llama_stack.core.datatypes import RoutableObjectWithProvider from llama_stack.core.storage.datatypes import KVStoreReference +from 
llama_stack.core.storage.kvstore import KVStore, kvstore_impl from llama_stack.log import get_logger -from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl logger = get_logger(__name__, category="core::registry") diff --git a/src/llama_stack/distributions/starter/starter.py b/src/llama_stack/distributions/starter/starter.py index 4c21a8c99..32264eebb 100644 --- a/src/llama_stack/distributions/starter/starter.py +++ b/src/llama_stack/distributions/starter/starter.py @@ -17,6 +17,8 @@ from llama_stack.core.datatypes import ( ToolGroupInput, VectorStoresConfig, ) +from llama_stack.core.storage.kvstore.config import PostgresKVStoreConfig +from llama_stack.core.storage.sqlstore.sqlstore import PostgresSqlStoreConfig from llama_stack.core.utils.dynamic import instantiate_class_type from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig @@ -35,8 +37,6 @@ from llama_stack.providers.remote.vector_io.pgvector.config import ( ) from llama_stack.providers.remote.vector_io.qdrant.config import QdrantVectorIOConfig from llama_stack.providers.remote.vector_io.weaviate.config import WeaviateVectorIOConfig -from llama_stack.providers.utils.kvstore.config import PostgresKVStoreConfig -from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig from llama_stack_api import RemoteProviderSpec diff --git a/src/llama_stack/distributions/template.py b/src/llama_stack/distributions/template.py index 5755a26de..90b458805 100644 --- a/src/llama_stack/distributions/template.py +++ b/src/llama_stack/distributions/template.py @@ -35,13 +35,13 @@ from llama_stack.core.storage.datatypes import ( SqlStoreReference, StorageBackendType, ) +from llama_stack.core.storage.kvstore.config import SqliteKVStoreConfig +from llama_stack.core.storage.kvstore.config import get_pip_packages as get_kv_pip_packages +from llama_stack.core.storage.sqlstore.sqlstore import SqliteSqlStoreConfig +from llama_stack.core.storage.sqlstore.sqlstore import get_pip_packages as get_sql_pip_packages from llama_stack.core.utils.dynamic import instantiate_class_type from llama_stack.core.utils.image_types import LlamaStackImageType from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry -from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig -from llama_stack.providers.utils.kvstore.config import get_pip_packages as get_kv_pip_packages -from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig -from llama_stack.providers.utils.sqlstore.sqlstore import get_pip_packages as get_sql_pip_packages from llama_stack_api import DatasetPurpose, ModelType diff --git a/src/llama_stack/providers/inline/agents/meta_reference/agents.py b/src/llama_stack/providers/inline/agents/meta_reference/agents.py index e47e757be..ba83a9576 100644 --- a/src/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/src/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -6,8 +6,8 @@ from llama_stack.core.datatypes import AccessRule +from llama_stack.core.storage.kvstore import InmemoryKVStoreImpl, kvstore_impl from llama_stack.log import get_logger -from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl from llama_stack.providers.utils.responses.responses_store import ResponsesStore from llama_stack_api import ( Agents, diff --git a/src/llama_stack/providers/inline/batches/reference/__init__.py 
b/src/llama_stack/providers/inline/batches/reference/__init__.py index 11c4b06a9..b48c82864 100644 --- a/src/llama_stack/providers/inline/batches/reference/__init__.py +++ b/src/llama_stack/providers/inline/batches/reference/__init__.py @@ -7,7 +7,7 @@ from typing import Any from llama_stack.core.datatypes import AccessRule, Api -from llama_stack.providers.utils.kvstore import kvstore_impl +from llama_stack.core.storage.kvstore import kvstore_impl from llama_stack_api import Files, Inference, Models from .batches import ReferenceBatchesImpl diff --git a/src/llama_stack/providers/inline/batches/reference/batches.py b/src/llama_stack/providers/inline/batches/reference/batches.py index 73727799d..aaa2c7b22 100644 --- a/src/llama_stack/providers/inline/batches/reference/batches.py +++ b/src/llama_stack/providers/inline/batches/reference/batches.py @@ -16,8 +16,8 @@ from typing import Any, Literal from openai.types.batch import BatchError, Errors from pydantic import BaseModel +from llama_stack.core.storage.kvstore import KVStore from llama_stack.log import get_logger -from llama_stack.providers.utils.kvstore import KVStore from llama_stack_api import ( Batches, BatchObject, diff --git a/src/llama_stack/providers/inline/datasetio/localfs/datasetio.py b/src/llama_stack/providers/inline/datasetio/localfs/datasetio.py index 6ab1a540f..85c7cff3e 100644 --- a/src/llama_stack/providers/inline/datasetio/localfs/datasetio.py +++ b/src/llama_stack/providers/inline/datasetio/localfs/datasetio.py @@ -5,8 +5,8 @@ # the root directory of this source tree. from typing import Any +from llama_stack.core.storage.kvstore import kvstore_impl from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_uri -from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.pagination import paginate_records from llama_stack_api import Dataset, DatasetIO, DatasetsProtocolPrivate, PaginatedResponse diff --git a/src/llama_stack/providers/inline/eval/meta_reference/eval.py b/src/llama_stack/providers/inline/eval/meta_reference/eval.py index d43e569e2..0f0cb84d6 100644 --- a/src/llama_stack/providers/inline/eval/meta_reference/eval.py +++ b/src/llama_stack/providers/inline/eval/meta_reference/eval.py @@ -8,8 +8,8 @@ from typing import Any from tqdm import tqdm +from llama_stack.core.storage.kvstore import kvstore_impl from llama_stack.providers.utils.common.data_schema_validator import ColumnName -from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack_api import ( Agents, Benchmark, diff --git a/src/llama_stack/providers/inline/files/localfs/files.py b/src/llama_stack/providers/inline/files/localfs/files.py index 5fb35a378..2afe2fe5e 100644 --- a/src/llama_stack/providers/inline/files/localfs/files.py +++ b/src/llama_stack/providers/inline/files/localfs/files.py @@ -13,11 +13,10 @@ from fastapi import Depends, File, Form, Response, UploadFile from llama_stack.core.datatypes import AccessRule from llama_stack.core.id_generation import generate_object_id +from llama_stack.core.storage.sqlstore.authorized_sqlstore import AuthorizedSqlStore +from llama_stack.core.storage.sqlstore.sqlstore import sqlstore_impl from llama_stack.log import get_logger from llama_stack.providers.utils.files.form_data import parse_expires_after -from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType -from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore -from llama_stack.providers.utils.sqlstore.sqlstore import 
sqlstore_impl from llama_stack_api import ( ExpiresAfter, Files, @@ -28,6 +27,7 @@ from llama_stack_api import ( Order, ResourceNotFoundError, ) +from llama_stack_api.internal.sqlstore import ColumnDefinition, ColumnType from .config import LocalfsFilesImplConfig diff --git a/src/llama_stack/providers/inline/vector_io/faiss/faiss.py b/src/llama_stack/providers/inline/vector_io/faiss/faiss.py index d52a54e6a..91a17058b 100644 --- a/src/llama_stack/providers/inline/vector_io/faiss/faiss.py +++ b/src/llama_stack/providers/inline/vector_io/faiss/faiss.py @@ -14,9 +14,8 @@ import faiss # type: ignore[import-untyped] import numpy as np from numpy.typing import NDArray +from llama_stack.core.storage.kvstore import kvstore_impl from llama_stack.log import get_logger -from llama_stack.providers.utils.kvstore import kvstore_impl -from llama_stack.providers.utils.kvstore.api import KVStore from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorStoreWithIndex from llama_stack_api import ( @@ -32,6 +31,7 @@ from llama_stack_api import ( VectorStoreNotFoundError, VectorStoresProtocolPrivate, ) +from llama_stack_api.internal.kvstore import KVStore from .config import FaissVectorIOConfig diff --git a/src/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py b/src/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py index 74bc349a5..a384a33dc 100644 --- a/src/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +++ b/src/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py @@ -14,9 +14,8 @@ import numpy as np import sqlite_vec # type: ignore[import-untyped] from numpy.typing import NDArray +from llama_stack.core.storage.kvstore import kvstore_impl from llama_stack.log import get_logger -from llama_stack.providers.utils.kvstore import kvstore_impl -from llama_stack.providers.utils.kvstore.api import KVStore from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin from llama_stack.providers.utils.memory.vector_store import ( RERANKER_TYPE_RRF, @@ -35,6 +34,7 @@ from llama_stack_api import ( VectorStoreNotFoundError, VectorStoresProtocolPrivate, ) +from llama_stack_api.internal.kvstore import KVStore logger = get_logger(name=__name__, category="vector_io") diff --git a/src/llama_stack/providers/registry/agents.py b/src/llama_stack/providers/registry/agents.py index 455be1ae7..2c68750a6 100644 --- a/src/llama_stack/providers/registry/agents.py +++ b/src/llama_stack/providers/registry/agents.py @@ -5,7 +5,7 @@ # the root directory of this source tree. -from llama_stack.providers.utils.kvstore import kvstore_dependencies +from llama_stack.core.storage.kvstore import kvstore_dependencies from llama_stack_api import ( Api, InlineProviderSpec, diff --git a/src/llama_stack/providers/registry/files.py b/src/llama_stack/providers/registry/files.py index 024254b57..8ce8acd91 100644 --- a/src/llama_stack/providers/registry/files.py +++ b/src/llama_stack/providers/registry/files.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from llama_stack.providers.utils.sqlstore.sqlstore import sql_store_pip_packages +from llama_stack.core.storage.sqlstore.sqlstore import sql_store_pip_packages from llama_stack_api import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec diff --git a/src/llama_stack/providers/remote/datasetio/huggingface/huggingface.py b/src/llama_stack/providers/remote/datasetio/huggingface/huggingface.py index 72069f716..26390a63b 100644 --- a/src/llama_stack/providers/remote/datasetio/huggingface/huggingface.py +++ b/src/llama_stack/providers/remote/datasetio/huggingface/huggingface.py @@ -6,7 +6,7 @@ from typing import Any from urllib.parse import parse_qs, urlparse -from llama_stack.providers.utils.kvstore import kvstore_impl +from llama_stack.core.storage.kvstore import kvstore_impl from llama_stack.providers.utils.pagination import paginate_records from llama_stack_api import Dataset, DatasetIO, DatasetsProtocolPrivate, PaginatedResponse diff --git a/src/llama_stack/providers/remote/files/openai/files.py b/src/llama_stack/providers/remote/files/openai/files.py index d2f5a08eb..2cfd44168 100644 --- a/src/llama_stack/providers/remote/files/openai/files.py +++ b/src/llama_stack/providers/remote/files/openai/files.py @@ -10,10 +10,9 @@ from typing import Annotated, Any from fastapi import Depends, File, Form, Response, UploadFile from llama_stack.core.datatypes import AccessRule +from llama_stack.core.storage.sqlstore.authorized_sqlstore import AuthorizedSqlStore +from llama_stack.core.storage.sqlstore.sqlstore import sqlstore_impl from llama_stack.providers.utils.files.form_data import parse_expires_after -from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType -from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore -from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl from llama_stack_api import ( ExpiresAfter, Files, @@ -24,6 +23,7 @@ from llama_stack_api import ( Order, ResourceNotFoundError, ) +from llama_stack_api.internal.sqlstore import ColumnDefinition, ColumnType from openai import OpenAI from .config import OpenAIFilesImplConfig diff --git a/src/llama_stack/providers/remote/files/s3/files.py b/src/llama_stack/providers/remote/files/s3/files.py index 68822eb77..3c1c82fa0 100644 --- a/src/llama_stack/providers/remote/files/s3/files.py +++ b/src/llama_stack/providers/remote/files/s3/files.py @@ -19,10 +19,9 @@ if TYPE_CHECKING: from llama_stack.core.datatypes import AccessRule from llama_stack.core.id_generation import generate_object_id +from llama_stack.core.storage.sqlstore.authorized_sqlstore import AuthorizedSqlStore +from llama_stack.core.storage.sqlstore.sqlstore import sqlstore_impl from llama_stack.providers.utils.files.form_data import parse_expires_after -from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType -from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore -from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl from llama_stack_api import ( ExpiresAfter, Files, @@ -33,6 +32,7 @@ from llama_stack_api import ( Order, ResourceNotFoundError, ) +from llama_stack_api.internal.sqlstore import ColumnDefinition, ColumnType from .config import S3FilesImplConfig diff --git a/src/llama_stack/providers/remote/vector_io/chroma/chroma.py b/src/llama_stack/providers/remote/vector_io/chroma/chroma.py index 645b40661..491db6d4d 100644 --- a/src/llama_stack/providers/remote/vector_io/chroma/chroma.py +++ 
b/src/llama_stack/providers/remote/vector_io/chroma/chroma.py @@ -11,10 +11,9 @@ from urllib.parse import urlparse import chromadb from numpy.typing import NDArray +from llama_stack.core.storage.kvstore import kvstore_impl from llama_stack.log import get_logger from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig -from llama_stack.providers.utils.kvstore import kvstore_impl -from llama_stack.providers.utils.kvstore.api import KVStore from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorStoreWithIndex from llama_stack_api import ( @@ -27,6 +26,7 @@ from llama_stack_api import ( VectorStore, VectorStoresProtocolPrivate, ) +from llama_stack_api.internal.kvstore import KVStore from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig diff --git a/src/llama_stack/providers/remote/vector_io/milvus/milvus.py b/src/llama_stack/providers/remote/vector_io/milvus/milvus.py index aefa20317..044d678fa 100644 --- a/src/llama_stack/providers/remote/vector_io/milvus/milvus.py +++ b/src/llama_stack/providers/remote/vector_io/milvus/milvus.py @@ -11,10 +11,9 @@ from typing import Any from numpy.typing import NDArray from pymilvus import AnnSearchRequest, DataType, Function, FunctionType, MilvusClient, RRFRanker, WeightedRanker +from llama_stack.core.storage.kvstore import kvstore_impl from llama_stack.log import get_logger from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig -from llama_stack.providers.utils.kvstore import kvstore_impl -from llama_stack.providers.utils.kvstore.api import KVStore from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin from llama_stack.providers.utils.memory.vector_store import ( RERANKER_TYPE_WEIGHTED, @@ -34,6 +33,7 @@ from llama_stack_api import ( VectorStoreNotFoundError, VectorStoresProtocolPrivate, ) +from llama_stack_api.internal.kvstore import KVStore from .config import MilvusVectorIOConfig as RemoteMilvusVectorIOConfig diff --git a/src/llama_stack/providers/remote/vector_io/pgvector/pgvector.py b/src/llama_stack/providers/remote/vector_io/pgvector/pgvector.py index 2901bad97..5c86fb08d 100644 --- a/src/llama_stack/providers/remote/vector_io/pgvector/pgvector.py +++ b/src/llama_stack/providers/remote/vector_io/pgvector/pgvector.py @@ -13,10 +13,9 @@ from psycopg2 import sql from psycopg2.extras import Json, execute_values from pydantic import BaseModel, TypeAdapter +from llama_stack.core.storage.kvstore import kvstore_impl from llama_stack.log import get_logger from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str -from llama_stack.providers.utils.kvstore import kvstore_impl -from llama_stack.providers.utils.kvstore.api import KVStore from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorStoreWithIndex from llama_stack.providers.utils.vector_io.vector_utils import WeightedInMemoryAggregator, sanitize_collection_name @@ -31,6 +30,7 @@ from llama_stack_api import ( VectorStoreNotFoundError, VectorStoresProtocolPrivate, ) +from llama_stack_api.internal.kvstore import KVStore from .config import PGVectorVectorIOConfig diff --git a/src/llama_stack/providers/remote/vector_io/qdrant/qdrant.py 
b/src/llama_stack/providers/remote/vector_io/qdrant/qdrant.py index 20ab653d0..4dd78d834 100644 --- a/src/llama_stack/providers/remote/vector_io/qdrant/qdrant.py +++ b/src/llama_stack/providers/remote/vector_io/qdrant/qdrant.py @@ -13,9 +13,9 @@ from numpy.typing import NDArray from qdrant_client import AsyncQdrantClient, models from qdrant_client.models import PointStruct +from llama_stack.core.storage.kvstore import kvstore_impl from llama_stack.log import get_logger from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig -from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorStoreWithIndex from llama_stack_api import ( diff --git a/src/llama_stack/providers/remote/vector_io/weaviate/weaviate.py b/src/llama_stack/providers/remote/vector_io/weaviate/weaviate.py index ba3e6b7ea..c15d5f468 100644 --- a/src/llama_stack/providers/remote/vector_io/weaviate/weaviate.py +++ b/src/llama_stack/providers/remote/vector_io/weaviate/weaviate.py @@ -13,9 +13,8 @@ from weaviate.classes.init import Auth from weaviate.classes.query import Filter, HybridFusion from llama_stack.core.request_headers import NeedsRequestProviderData +from llama_stack.core.storage.kvstore import kvstore_impl from llama_stack.log import get_logger -from llama_stack.providers.utils.kvstore import kvstore_impl -from llama_stack.providers.utils.kvstore.api import KVStore from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin from llama_stack.providers.utils.memory.vector_store import ( RERANKER_TYPE_RRF, @@ -35,6 +34,7 @@ from llama_stack_api import ( VectorStoreNotFoundError, VectorStoresProtocolPrivate, ) +from llama_stack_api.internal.kvstore import KVStore from .config import WeaviateVectorIOConfig diff --git a/src/llama_stack/providers/utils/inference/inference_store.py b/src/llama_stack/providers/utils/inference/inference_store.py index 49e3af7a1..a8a0cace4 100644 --- a/src/llama_stack/providers/utils/inference/inference_store.py +++ b/src/llama_stack/providers/utils/inference/inference_store.py @@ -10,6 +10,8 @@ from sqlalchemy.exc import IntegrityError from llama_stack.core.datatypes import AccessRule from llama_stack.core.storage.datatypes import InferenceStoreReference, StorageBackendType +from llama_stack.core.storage.sqlstore.authorized_sqlstore import AuthorizedSqlStore +from llama_stack.core.storage.sqlstore.sqlstore import _SQLSTORE_BACKENDS, sqlstore_impl from llama_stack.log import get_logger from llama_stack_api import ( ListOpenAIChatCompletionResponse, @@ -18,10 +20,7 @@ from llama_stack_api import ( OpenAIMessageParam, Order, ) - -from ..sqlstore.api import ColumnDefinition, ColumnType -from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore -from ..sqlstore.sqlstore import _SQLSTORE_BACKENDS, sqlstore_impl +from llama_stack_api.internal.sqlstore import ColumnDefinition, ColumnType logger = get_logger(name=__name__, category="inference") diff --git a/src/llama_stack/providers/utils/kvstore/sqlite/config.py b/src/llama_stack/providers/utils/kvstore/sqlite/config.py deleted file mode 100644 index 0f8fa0a95..000000000 --- a/src/llama_stack/providers/utils/kvstore/sqlite/config.py +++ /dev/null @@ -1,20 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. 
-# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from pydantic import BaseModel, Field - -from llama_stack_api import json_schema_type - - -@json_schema_type -class SqliteControlPlaneConfig(BaseModel): - db_path: str = Field( - description="File path for the sqlite database", - ) - table_name: str = Field( - default="llamastack_control_plane", - description="Table into which all the keys will be placed", - ) diff --git a/src/llama_stack/providers/utils/memory/openai_vector_store_mixin.py b/src/llama_stack/providers/utils/memory/openai_vector_store_mixin.py index 540ff5940..bbfd60e25 100644 --- a/src/llama_stack/providers/utils/memory/openai_vector_store_mixin.py +++ b/src/llama_stack/providers/utils/memory/openai_vector_store_mixin.py @@ -17,7 +17,6 @@ from pydantic import TypeAdapter from llama_stack.core.id_generation import generate_object_id from llama_stack.log import get_logger -from llama_stack.providers.utils.kvstore.api import KVStore from llama_stack.providers.utils.memory.vector_store import ( ChunkForDeletion, content_from_data_and_mime_type, @@ -53,6 +52,7 @@ from llama_stack_api import ( VectorStoreSearchResponse, VectorStoreSearchResponsePage, ) +from llama_stack_api.internal.kvstore import KVStore EMBEDDING_DIMENSION = 768 diff --git a/src/llama_stack/providers/utils/responses/responses_store.py b/src/llama_stack/providers/utils/responses/responses_store.py index f6e7c435d..0401db206 100644 --- a/src/llama_stack/providers/utils/responses/responses_store.py +++ b/src/llama_stack/providers/utils/responses/responses_store.py @@ -6,6 +6,8 @@ from llama_stack.core.datatypes import AccessRule from llama_stack.core.storage.datatypes import ResponsesStoreReference, SqlStoreReference +from llama_stack.core.storage.sqlstore.authorized_sqlstore import AuthorizedSqlStore +from llama_stack.core.storage.sqlstore.sqlstore import sqlstore_impl from llama_stack.log import get_logger from llama_stack_api import ( ListOpenAIResponseInputItem, @@ -17,10 +19,7 @@ from llama_stack_api import ( OpenAIResponseObjectWithInput, Order, ) - -from ..sqlstore.api import ColumnDefinition, ColumnType -from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore -from ..sqlstore.sqlstore import sqlstore_impl +from llama_stack_api.internal.sqlstore import ColumnDefinition, ColumnType logger = get_logger(name=__name__, category="openai_responses") diff --git a/src/llama_stack/providers/utils/sqlstore/api.py b/src/llama_stack/providers/utils/sqlstore/api.py deleted file mode 100644 index 708fc7095..000000000 --- a/src/llama_stack/providers/utils/sqlstore/api.py +++ /dev/null @@ -1,140 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from collections.abc import Mapping, Sequence -from enum import Enum -from typing import Any, Literal, Protocol - -from pydantic import BaseModel - -from llama_stack_api import PaginatedResponse - - -class ColumnType(Enum): - INTEGER = "INTEGER" - STRING = "STRING" - TEXT = "TEXT" - FLOAT = "FLOAT" - BOOLEAN = "BOOLEAN" - JSON = "JSON" - DATETIME = "DATETIME" - - -class ColumnDefinition(BaseModel): - type: ColumnType - primary_key: bool = False - nullable: bool = True - default: Any = None - - -class SqlStore(Protocol): - """ - A protocol for a SQL store. 
- """ - - async def create_table(self, table: str, schema: Mapping[str, ColumnType | ColumnDefinition]) -> None: - """ - Create a table. - """ - pass - - async def insert(self, table: str, data: Mapping[str, Any] | Sequence[Mapping[str, Any]]) -> None: - """ - Insert a row or batch of rows into a table. - """ - pass - - async def upsert( - self, - table: str, - data: Mapping[str, Any], - conflict_columns: list[str], - update_columns: list[str] | None = None, - ) -> None: - """ - Insert a row and update specified columns when conflicts occur. - """ - pass - - async def fetch_all( - self, - table: str, - where: Mapping[str, Any] | None = None, - where_sql: str | None = None, - limit: int | None = None, - order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None, - cursor: tuple[str, str] | None = None, - ) -> PaginatedResponse: - """ - Fetch all rows from a table with optional cursor-based pagination. - - :param table: The table name - :param where: Simple key-value WHERE conditions - :param where_sql: Raw SQL WHERE clause for complex queries - :param limit: Maximum number of records to return - :param order_by: List of (column, order) tuples for sorting - :param cursor: Tuple of (key_column, cursor_id) for pagination (None for first page) - Requires order_by with exactly one column when used - :return: PaginatedResult with data and has_more flag - - Note: Cursor pagination only supports single-column ordering for simplicity. - Multi-column ordering is allowed without cursor but will raise an error with cursor. - """ - pass - - async def fetch_one( - self, - table: str, - where: Mapping[str, Any] | None = None, - where_sql: str | None = None, - order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None, - ) -> dict[str, Any] | None: - """ - Fetch one row from a table. - """ - pass - - async def update( - self, - table: str, - data: Mapping[str, Any], - where: Mapping[str, Any], - ) -> None: - """ - Update a row in a table. - """ - pass - - async def delete( - self, - table: str, - where: Mapping[str, Any], - ) -> None: - """ - Delete a row from a table. - """ - pass - - async def add_column_if_not_exists( - self, - table: str, - column_name: str, - column_type: ColumnType, - nullable: bool = True, - ) -> None: - """ - Add a column to an existing table if the column doesn't already exist. - - This is useful for table migrations when adding new functionality. - If the table doesn't exist, this method should do nothing. - If the column already exists, this method should do nothing. - - :param table: Table name - :param column_name: Name of the column to add - :param column_type: Type of the column to add - :param nullable: Whether the column should be nullable (default: True) - """ - pass diff --git a/src/llama_stack/providers/utils/sqlstore/__init__.py b/src/llama_stack_api/internal/__init__.py similarity index 65% rename from src/llama_stack/providers/utils/sqlstore/__init__.py rename to src/llama_stack_api/internal/__init__.py index 756f351d8..bbf7010c3 100644 --- a/src/llama_stack/providers/utils/sqlstore/__init__.py +++ b/src/llama_stack_api/internal/__init__.py @@ -3,3 +3,7 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. + +# Internal subpackage for shared interfaces that are not part of the public API. 
+ +__all__: list[str] = [] diff --git a/src/llama_stack/providers/utils/kvstore/api.py b/src/llama_stack_api/internal/kvstore.py similarity index 89% rename from src/llama_stack/providers/utils/kvstore/api.py rename to src/llama_stack_api/internal/kvstore.py index d17dc66e1..a6d982261 100644 --- a/src/llama_stack/providers/utils/kvstore/api.py +++ b/src/llama_stack_api/internal/kvstore.py @@ -9,6 +9,8 @@ from typing import Protocol class KVStore(Protocol): + """Protocol for simple key/value storage backends.""" + # TODO: make the value type bytes instead of str async def set(self, key: str, value: str, expiration: datetime | None = None) -> None: ... @@ -19,3 +21,6 @@ class KVStore(Protocol): async def values_in_range(self, start_key: str, end_key: str) -> list[str]: ... async def keys_in_range(self, start_key: str, end_key: str) -> list[str]: ... + + +__all__ = ["KVStore"] diff --git a/src/llama_stack_api/internal/sqlstore.py b/src/llama_stack_api/internal/sqlstore.py new file mode 100644 index 000000000..ebb2d8ba2 --- /dev/null +++ b/src/llama_stack_api/internal/sqlstore.py @@ -0,0 +1,79 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from collections.abc import Mapping, Sequence +from enum import Enum +from typing import Any, Literal, Protocol + +from pydantic import BaseModel + +from llama_stack_api import PaginatedResponse + + +class ColumnType(Enum): + INTEGER = "INTEGER" + STRING = "STRING" + TEXT = "TEXT" + FLOAT = "FLOAT" + BOOLEAN = "BOOLEAN" + JSON = "JSON" + DATETIME = "DATETIME" + + +class ColumnDefinition(BaseModel): + type: ColumnType + primary_key: bool = False + nullable: bool = True + default: Any = None + + +class SqlStore(Protocol): + """Protocol for common SQL-store functionality.""" + + async def create_table(self, table: str, schema: Mapping[str, ColumnType | ColumnDefinition]) -> None: ... + + async def insert(self, table: str, data: Mapping[str, Any] | Sequence[Mapping[str, Any]]) -> None: ... + + async def upsert( + self, + table: str, + data: Mapping[str, Any], + conflict_columns: list[str], + update_columns: list[str] | None = None, + ) -> None: ... + + async def fetch_all( + self, + table: str, + where: Mapping[str, Any] | None = None, + where_sql: str | None = None, + limit: int | None = None, + order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None, + cursor: tuple[str, str] | None = None, + ) -> PaginatedResponse: ... + + async def fetch_one( + self, + table: str, + where: Mapping[str, Any] | None = None, + where_sql: str | None = None, + order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None, + ) -> dict[str, Any] | None: ... + + async def update(self, table: str, data: Mapping[str, Any], where: Mapping[str, Any]) -> None: ... + + async def delete(self, table: str, where: Mapping[str, Any]) -> None: ... + + async def add_column_if_not_exists( + self, + table: str, + column_name: str, + column_type: ColumnType, + nullable: bool = True, + ) -> None: ... 
+ + +__all__ = ["ColumnDefinition", "ColumnType", "SqlStore"] diff --git a/tests/integration/files/test_files.py b/tests/integration/files/test_files.py index 1f19c88c5..e8004c95d 100644 --- a/tests/integration/files/test_files.py +++ b/tests/integration/files/test_files.py @@ -175,7 +175,7 @@ def test_expires_after_requests(openai_client): @pytest.mark.xfail(message="User isolation broken for current providers, must be fixed.") -@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") +@patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") def test_files_authentication_isolation(mock_get_authenticated_user, llama_stack_client): """Test that users can only access their own files.""" from llama_stack_client import NotFoundError @@ -275,7 +275,7 @@ def test_files_authentication_isolation(mock_get_authenticated_user, llama_stack raise e -@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") +@patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") def test_files_authentication_shared_attributes( mock_get_authenticated_user, llama_stack_client, provider_type_is_openai ): @@ -335,7 +335,7 @@ def test_files_authentication_shared_attributes( raise e -@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") +@patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") def test_files_authentication_anonymous_access( mock_get_authenticated_user, llama_stack_client, provider_type_is_openai ): diff --git a/tests/integration/providers/utils/sqlstore/test_authorized_sqlstore.py b/tests/integration/providers/utils/sqlstore/test_authorized_sqlstore.py index ad9115756..4f4f4a8dd 100644 --- a/tests/integration/providers/utils/sqlstore/test_authorized_sqlstore.py +++ b/tests/integration/providers/utils/sqlstore/test_authorized_sqlstore.py @@ -13,14 +13,14 @@ import pytest from llama_stack.core.access_control.access_control import default_policy from llama_stack.core.datatypes import User from llama_stack.core.storage.datatypes import SqlStoreReference -from llama_stack.providers.utils.sqlstore.api import ColumnType -from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore -from llama_stack.providers.utils.sqlstore.sqlstore import ( +from llama_stack.core.storage.sqlstore.authorized_sqlstore import AuthorizedSqlStore +from llama_stack.core.storage.sqlstore.sqlstore import ( PostgresSqlStoreConfig, SqliteSqlStoreConfig, register_sqlstore_backends, sqlstore_impl, ) +from llama_stack_api.internal.sqlstore import ColumnType def get_postgres_config(): @@ -96,7 +96,7 @@ async def cleanup_records(sql_store, table_name, record_ids): @pytest.mark.parametrize("backend_config", BACKEND_CONFIGS) -@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") +@patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") async def test_authorized_store_attributes(mock_get_authenticated_user, authorized_store, request): """Test that JSON column comparisons work correctly for both PostgreSQL and SQLite""" backend_name = request.node.callspec.id @@ -190,7 +190,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz @pytest.mark.parametrize("backend_config", BACKEND_CONFIGS) -@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") 
+@patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") async def test_user_ownership_policy(mock_get_authenticated_user, authorized_store, request): """Test that 'user is owner' policies work correctly with record ownership""" from llama_stack.core.access_control.datatypes import AccessRule, Action, Scope diff --git a/tests/unit/conversations/test_conversations.py b/tests/unit/conversations/test_conversations.py index 95c54d379..e8286576b 100644 --- a/tests/unit/conversations/test_conversations.py +++ b/tests/unit/conversations/test_conversations.py @@ -23,7 +23,7 @@ from llama_stack.core.storage.datatypes import ( SqlStoreReference, StorageConfig, ) -from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends +from llama_stack.core.storage.sqlstore.sqlstore import register_sqlstore_backends from llama_stack_api import OpenAIResponseInputMessageContentText, OpenAIResponseMessage diff --git a/tests/unit/files/test_files.py b/tests/unit/files/test_files.py index 793f4edd3..197038349 100644 --- a/tests/unit/files/test_files.py +++ b/tests/unit/files/test_files.py @@ -9,11 +9,11 @@ import pytest from llama_stack.core.access_control.access_control import default_policy from llama_stack.core.storage.datatypes import SqliteSqlStoreConfig, SqlStoreReference +from llama_stack.core.storage.sqlstore.sqlstore import register_sqlstore_backends from llama_stack.providers.inline.files.localfs import ( LocalfsFilesImpl, LocalfsFilesImplConfig, ) -from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends from llama_stack_api import OpenAIFilePurpose, Order, ResourceNotFoundError diff --git a/tests/unit/fixtures.py b/tests/unit/fixtures.py index 443a1d371..9e049f8da 100644 --- a/tests/unit/fixtures.py +++ b/tests/unit/fixtures.py @@ -6,9 +6,9 @@ import pytest +from llama_stack.core.storage.kvstore.config import SqliteKVStoreConfig +from llama_stack.core.storage.kvstore.sqlite import SqliteKVStoreImpl from llama_stack.core.store.registry import CachedDiskDistributionRegistry, DiskDistributionRegistry -from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig -from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl @pytest.fixture(scope="function") diff --git a/tests/unit/prompts/prompts/conftest.py b/tests/unit/prompts/prompts/conftest.py index c876f2041..8bfc1f03c 100644 --- a/tests/unit/prompts/prompts/conftest.py +++ b/tests/unit/prompts/prompts/conftest.py @@ -18,7 +18,7 @@ from llama_stack.core.storage.datatypes import ( SqlStoreReference, StorageConfig, ) -from llama_stack.providers.utils.kvstore import register_kvstore_backends +from llama_stack.core.storage.kvstore import register_kvstore_backends @pytest.fixture diff --git a/tests/unit/providers/agents/meta_reference/test_openai_responses.py b/tests/unit/providers/agents/meta_reference/test_openai_responses.py index 78f0d7cfd..256df6baf 100644 --- a/tests/unit/providers/agents/meta_reference/test_openai_responses.py +++ b/tests/unit/providers/agents/meta_reference/test_openai_responses.py @@ -17,6 +17,7 @@ from openai.types.chat.chat_completion_chunk import ( from llama_stack.core.access_control.access_control import default_policy from llama_stack.core.storage.datatypes import ResponsesStoreReference, SqliteSqlStoreConfig +from llama_stack.core.storage.sqlstore.sqlstore import register_sqlstore_backends from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import ( OpenAIResponsesImpl, ) @@ -24,7 +25,6 
@@ from llama_stack.providers.utils.responses.responses_store import ( ResponsesStore, _OpenAIResponseObjectWithInputAndMessages, ) -from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends from llama_stack_api.agents import Order from llama_stack_api.inference import ( OpenAIAssistantMessageParam, diff --git a/tests/unit/providers/batches/conftest.py b/tests/unit/providers/batches/conftest.py index d161bf976..8ecfa99fb 100644 --- a/tests/unit/providers/batches/conftest.py +++ b/tests/unit/providers/batches/conftest.py @@ -13,9 +13,9 @@ from unittest.mock import AsyncMock import pytest from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig +from llama_stack.core.storage.kvstore import kvstore_impl, register_kvstore_backends from llama_stack.providers.inline.batches.reference.batches import ReferenceBatchesImpl from llama_stack.providers.inline.batches.reference.config import ReferenceBatchesImplConfig -from llama_stack.providers.utils.kvstore import kvstore_impl, register_kvstore_backends @pytest.fixture diff --git a/tests/unit/providers/files/conftest.py b/tests/unit/providers/files/conftest.py index c64ecc3a3..f8959b5b7 100644 --- a/tests/unit/providers/files/conftest.py +++ b/tests/unit/providers/files/conftest.py @@ -9,8 +9,8 @@ import pytest from moto import mock_aws from llama_stack.core.storage.datatypes import SqliteSqlStoreConfig, SqlStoreReference +from llama_stack.core.storage.sqlstore.sqlstore import register_sqlstore_backends from llama_stack.providers.remote.files.s3 import S3FilesImplConfig, get_adapter_impl -from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends class MockUploadFile: diff --git a/tests/unit/providers/files/test_s3_files_auth.py b/tests/unit/providers/files/test_s3_files_auth.py index e113611bd..49b33fd7b 100644 --- a/tests/unit/providers/files/test_s3_files_auth.py +++ b/tests/unit/providers/files/test_s3_files_auth.py @@ -18,11 +18,11 @@ async def test_listing_hides_other_users_file(s3_provider, sample_text_file): user_a = User("user-a", {"roles": ["team-a"]}) user_b = User("user-b", {"roles": ["team-b"]}) - with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user: + with patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user: mock_get_user.return_value = user_a uploaded = await s3_provider.openai_upload_file(file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS) - with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user: + with patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user: mock_get_user.return_value = user_b listed = await s3_provider.openai_list_files() assert all(f.id != uploaded.id for f in listed.data) @@ -41,11 +41,11 @@ async def test_cannot_access_other_user_file(s3_provider, sample_text_file, op): user_a = User("user-a", {"roles": ["team-a"]}) user_b = User("user-b", {"roles": ["team-b"]}) - with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user: + with patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user: mock_get_user.return_value = user_a uploaded = await s3_provider.openai_upload_file(file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS) - with 
patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user: + with patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user: mock_get_user.return_value = user_b with pytest.raises(ResourceNotFoundError): await op(s3_provider, uploaded.id) @@ -56,11 +56,11 @@ async def test_shared_role_allows_listing(s3_provider, sample_text_file): user_a = User("user-a", {"roles": ["shared-role"]}) user_b = User("user-b", {"roles": ["shared-role"]}) - with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user: + with patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user: mock_get_user.return_value = user_a uploaded = await s3_provider.openai_upload_file(file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS) - with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user: + with patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user: mock_get_user.return_value = user_b listed = await s3_provider.openai_list_files() assert any(f.id == uploaded.id for f in listed.data) @@ -79,10 +79,10 @@ async def test_shared_role_allows_access(s3_provider, sample_text_file, op): user_x = User("user-x", {"roles": ["shared-role"]}) user_y = User("user-y", {"roles": ["shared-role"]}) - with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user: + with patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user: mock_get_user.return_value = user_x uploaded = await s3_provider.openai_upload_file(file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS) - with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user: + with patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user: mock_get_user.return_value = user_y await op(s3_provider, uploaded.id) diff --git a/tests/unit/providers/vector_io/conftest.py b/tests/unit/providers/vector_io/conftest.py index 6408e25ab..b4ea77c0a 100644 --- a/tests/unit/providers/vector_io/conftest.py +++ b/tests/unit/providers/vector_io/conftest.py @@ -11,13 +11,13 @@ import numpy as np import pytest from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig +from llama_stack.core.storage.kvstore import register_kvstore_backends from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig from llama_stack.providers.inline.vector_io.faiss.faiss import FaissIndex, FaissVectorIOAdapter from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig from llama_stack.providers.remote.vector_io.pgvector.pgvector import PGVectorIndex, PGVectorVectorIOAdapter -from llama_stack.providers.utils.kvstore import register_kvstore_backends from llama_stack_api import Chunk, ChunkMetadata, QueryChunksResponse, VectorStore EMBEDDING_DIMENSION = 768 @@ -279,7 +279,7 @@ async def pgvector_vec_adapter(unique_kvstore_config, mock_inference_api, embedd ) as mock_check_version: mock_check_version.return_value = "0.5.1" - with 
patch("llama_stack.providers.utils.kvstore.kvstore_impl") as mock_kvstore_impl: + with patch("llama_stack.core.storage.kvstore.kvstore_impl") as mock_kvstore_impl: mock_kvstore = AsyncMock() mock_kvstore_impl.return_value = mock_kvstore diff --git a/tests/unit/registry/test_registry.py b/tests/unit/registry/test_registry.py index 1b5032782..2b32de833 100644 --- a/tests/unit/registry/test_registry.py +++ b/tests/unit/registry/test_registry.py @@ -9,12 +9,12 @@ import pytest from llama_stack.core.datatypes import VectorStoreWithOwner from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig +from llama_stack.core.storage.kvstore import kvstore_impl, register_kvstore_backends from llama_stack.core.store.registry import ( KEY_FORMAT, CachedDiskDistributionRegistry, DiskDistributionRegistry, ) -from llama_stack.providers.utils.kvstore import kvstore_impl, register_kvstore_backends from llama_stack_api import Model, VectorStore diff --git a/tests/unit/server/test_quota.py b/tests/unit/server/test_quota.py index 0939414dd..cd8c38eed 100644 --- a/tests/unit/server/test_quota.py +++ b/tests/unit/server/test_quota.py @@ -15,7 +15,7 @@ from starlette.middleware.base import BaseHTTPMiddleware from llama_stack.core.datatypes import QuotaConfig, QuotaPeriod from llama_stack.core.server.quota import QuotaMiddleware from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig -from llama_stack.providers.utils.kvstore import register_kvstore_backends +from llama_stack.core.storage.kvstore import register_kvstore_backends @pytest.fixture diff --git a/tests/unit/server/test_resolver.py b/tests/unit/server/test_resolver.py index 8f8a61ea7..a1b03f630 100644 --- a/tests/unit/server/test_resolver.py +++ b/tests/unit/server/test_resolver.py @@ -24,8 +24,8 @@ from llama_stack.core.storage.datatypes import ( SqlStoreReference, StorageConfig, ) -from llama_stack.providers.utils.kvstore import register_kvstore_backends -from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends +from llama_stack.core.storage.kvstore import register_kvstore_backends +from llama_stack.core.storage.sqlstore.sqlstore import register_sqlstore_backends from llama_stack_api import Inference, InlineProviderSpec, ProviderSpec diff --git a/tests/unit/utils/inference/test_inference_store.py b/tests/unit/utils/inference/test_inference_store.py index bdcc529ce..22d4ec1e5 100644 --- a/tests/unit/utils/inference/test_inference_store.py +++ b/tests/unit/utils/inference/test_inference_store.py @@ -9,8 +9,8 @@ import time import pytest from llama_stack.core.storage.datatypes import InferenceStoreReference, SqliteSqlStoreConfig +from llama_stack.core.storage.sqlstore.sqlstore import register_sqlstore_backends from llama_stack.providers.utils.inference.inference_store import InferenceStore -from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends from llama_stack_api import ( OpenAIAssistantMessageParam, OpenAIChatCompletion, diff --git a/tests/unit/utils/kvstore/test_sqlite_memory.py b/tests/unit/utils/kvstore/test_sqlite_memory.py index a31377306..1aaf57b44 100644 --- a/tests/unit/utils/kvstore/test_sqlite_memory.py +++ b/tests/unit/utils/kvstore/test_sqlite_memory.py @@ -5,8 +5,8 @@ # the root directory of this source tree. 
-from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig -from llama_stack.providers.utils.kvstore.sqlite.sqlite import SqliteKVStoreImpl +from llama_stack.core.storage.kvstore.config import SqliteKVStoreConfig +from llama_stack.core.storage.kvstore.sqlite.sqlite import SqliteKVStoreImpl async def test_memory_kvstore_persistence_behavior(): diff --git a/tests/unit/utils/responses/test_responses_store.py b/tests/unit/utils/responses/test_responses_store.py index 8c108d9c1..a71fb39f6 100644 --- a/tests/unit/utils/responses/test_responses_store.py +++ b/tests/unit/utils/responses/test_responses_store.py @@ -11,8 +11,8 @@ from uuid import uuid4 import pytest from llama_stack.core.storage.datatypes import ResponsesStoreReference, SqliteSqlStoreConfig +from llama_stack.core.storage.sqlstore.sqlstore import register_sqlstore_backends from llama_stack.providers.utils.responses.responses_store import ResponsesStore -from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends from llama_stack_api import OpenAIMessageParam, OpenAIResponseInput, OpenAIResponseObject, OpenAIUserMessageParam, Order diff --git a/tests/unit/utils/sqlstore/test_sqlstore.py b/tests/unit/utils/sqlstore/test_sqlstore.py index d7ba0dc89..421e3b69d 100644 --- a/tests/unit/utils/sqlstore/test_sqlstore.py +++ b/tests/unit/utils/sqlstore/test_sqlstore.py @@ -9,9 +9,9 @@ from tempfile import TemporaryDirectory import pytest -from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType -from llama_stack.providers.utils.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl -from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig +from llama_stack.core.storage.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl +from llama_stack.core.storage.sqlstore.sqlstore import SqliteSqlStoreConfig +from llama_stack_api.internal.sqlstore import ColumnDefinition, ColumnType async def test_sqlite_sqlstore(): diff --git a/tests/unit/utils/test_authorized_sqlstore.py b/tests/unit/utils/test_authorized_sqlstore.py index d85e784a9..e9a6b511b 100644 --- a/tests/unit/utils/test_authorized_sqlstore.py +++ b/tests/unit/utils/test_authorized_sqlstore.py @@ -10,13 +10,13 @@ from unittest.mock import patch from llama_stack.core.access_control.access_control import default_policy, is_action_allowed from llama_stack.core.access_control.datatypes import Action from llama_stack.core.datatypes import User -from llama_stack.providers.utils.sqlstore.api import ColumnType -from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore, SqlRecord -from llama_stack.providers.utils.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl -from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig +from llama_stack.core.storage.sqlstore.authorized_sqlstore import AuthorizedSqlStore, SqlRecord +from llama_stack.core.storage.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl +from llama_stack.core.storage.sqlstore.sqlstore import SqliteSqlStoreConfig +from llama_stack_api.internal.sqlstore import ColumnType -@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") +@patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") async def test_authorized_fetch_with_where_sql_access_control(mock_get_authenticated_user): """Test that fetch_all works correctly with where_sql for access control""" with TemporaryDirectory() as tmp_dir: @@ -78,7 +78,7 @@ async def 
test_authorized_fetch_with_where_sql_access_control(mock_get_authentic assert row["title"] == "User Document" -@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") +@patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") async def test_sql_policy_consistency(mock_get_authenticated_user): """Test that SQL WHERE clause logic exactly matches is_action_allowed policy logic""" with TemporaryDirectory() as tmp_dir: @@ -164,7 +164,7 @@ async def test_sql_policy_consistency(mock_get_authenticated_user): ) -@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") +@patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") async def test_authorized_store_user_attribute_capture(mock_get_authenticated_user): """Test that user attributes are properly captured during insert""" with TemporaryDirectory() as tmp_dir: From 91f1b352b4ca2c6d9a4624663bffbd2a8d98fb69 Mon Sep 17 00:00:00 2001 From: Charlie Doern Date: Tue, 18 Nov 2025 18:22:26 -0500 Subject: [PATCH 3/3] chore: add storage sane defaults (#4182) # What does this PR do? Since `StackRunConfig` requires certain parts of `StorageConfig`, it makes sense to template in defaults that "just work" for most use cases. Specifically, this introduces `ServerStoresConfig` defaults for inference, metadata, conversations, and prompts; we already funnel in defaults for these sections ad hoc throughout the codebase. It additionally sets `backends` defaults on `StorageConfig`, which removes some awkwardness around `--providers` for run/list-deps and simplifies upcoming work to better align the list-deps/run datatypes. --------- Signed-off-by: Charlie Doern --- src/llama_stack/core/storage/datatypes.py | 27 ++++++++++++++++--- .../unit/conversations/test_conversations.py | 6 +++++ tests/unit/core/test_stack_validation.py | 23 +++++++++++++--- 3 files changed, 49 insertions(+), 7 deletions(-) diff --git a/src/llama_stack/core/storage/datatypes.py b/src/llama_stack/core/storage/datatypes.py index 4b17b9ea9..527c1b828 100644 --- a/src/llama_stack/core/storage/datatypes.py +++ b/src/llama_stack/core/storage/datatypes.py @@ -12,6 +12,8 @@ from typing import Annotated, Literal from pydantic import BaseModel, Field, field_validator +from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR + class StorageBackendType(StrEnum): KV_REDIS = "kv_redis" @@ -256,15 +258,24 @@ class ResponsesStoreReference(InferenceStoreReference): class ServerStoresConfig(BaseModel): metadata: KVStoreReference | None = Field( - default=None, + default=KVStoreReference( + backend="kv_default", + namespace="registry", + ), description="Metadata store configuration (uses KV backend)", ) inference: InferenceStoreReference | None = Field( - default=None, + default=InferenceStoreReference( + backend="sql_default", + table_name="inference_store", + ), description="Inference store configuration (uses SQL backend)", ) conversations: SqlStoreReference | None = Field( - default=None, + default=SqlStoreReference( + backend="sql_default", + table_name="openai_conversations", + ), description="Conversations store configuration (uses SQL backend)", ) responses: ResponsesStoreReference | None = Field( @@ -272,13 +283,21 @@ class ServerStoresConfig(BaseModel): description="Responses store configuration (uses SQL backend)", ) prompts: KVStoreReference | None = Field( - default=None, + default=KVStoreReference(backend="kv_default", namespace="prompts"),
description="Prompts store configuration (uses KV backend)", ) class StorageConfig(BaseModel): backends: dict[str, StorageBackendConfig] = Field( + default={ + "kv_default": SqliteKVStoreConfig( + db_path=f"${{env.SQLITE_STORE_DIR:={DISTRIBS_BASE_DIR}}}/kvstore.db", + ), + "sql_default": SqliteSqlStoreConfig( + db_path=f"${{env.SQLITE_STORE_DIR:={DISTRIBS_BASE_DIR}}}/sql_store.db", + ), + }, description="Named backend configurations (e.g., 'default', 'cache')", ) stores: ServerStoresConfig = Field( diff --git a/tests/unit/conversations/test_conversations.py b/tests/unit/conversations/test_conversations.py index e8286576b..3f9df5fc0 100644 --- a/tests/unit/conversations/test_conversations.py +++ b/tests/unit/conversations/test_conversations.py @@ -38,6 +38,9 @@ async def service(): }, stores=ServerStoresConfig( conversations=SqlStoreReference(backend="sql_test", table_name="openai_conversations"), + metadata=None, + inference=None, + prompts=None, ), ) register_sqlstore_backends({"sql_test": storage.backends["sql_test"]}) @@ -142,6 +145,9 @@ async def test_policy_configuration(): }, stores=ServerStoresConfig( conversations=SqlStoreReference(backend="sql_test", table_name="openai_conversations"), + metadata=None, + inference=None, + prompts=None, ), ) register_sqlstore_backends({"sql_test": storage.backends["sql_test"]}) diff --git a/tests/unit/core/test_stack_validation.py b/tests/unit/core/test_stack_validation.py index 462a25c8b..5f75bc522 100644 --- a/tests/unit/core/test_stack_validation.py +++ b/tests/unit/core/test_stack_validation.py @@ -10,8 +10,9 @@ from unittest.mock import AsyncMock import pytest -from llama_stack.core.datatypes import QualifiedModel, SafetyConfig, StackRunConfig, StorageConfig, VectorStoresConfig +from llama_stack.core.datatypes import QualifiedModel, SafetyConfig, StackRunConfig, VectorStoresConfig from llama_stack.core.stack import validate_safety_config, validate_vector_stores_config +from llama_stack.core.storage.datatypes import ServerStoresConfig, StorageConfig from llama_stack_api import Api, ListModelsResponse, ListShieldsResponse, Model, ModelType, Shield @@ -21,7 +22,15 @@ class TestVectorStoresValidation: run_config = StackRunConfig( image_name="test", providers={}, - storage=StorageConfig(backends={}, stores={}), + storage=StorageConfig( + backends={}, + stores=ServerStoresConfig( + metadata=None, + inference=None, + conversations=None, + prompts=None, + ), + ), vector_stores=VectorStoresConfig( default_provider_id="faiss", default_embedding_model=QualifiedModel( @@ -41,7 +50,15 @@ class TestVectorStoresValidation: run_config = StackRunConfig( image_name="test", providers={}, - storage=StorageConfig(backends={}, stores={}), + storage=StorageConfig( + backends={}, + stores=ServerStoresConfig( + metadata=None, + inference=None, + conversations=None, + prompts=None, + ), + ), vector_stores=VectorStoresConfig( default_provider_id="faiss", default_embedding_model=QualifiedModel(