mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-22 16:23:08 +00:00
Fix discriminator ambiguity with context-aware backend parsing
- Both SqliteKVStoreConfig and SqliteSqlStoreConfig use type='sqlite' - Pydantic cannot distinguish them in a union - Solution: Custom validator parses backends based on which stores reference them - Metadata store requires KVStore, inference/conversations require SqlStore - Separate kvstore/sqlstore backends in configs for clarity
This commit is contained in:
parent
088a6ac652
commit
5672e70832
4 changed files with 157 additions and 37 deletions
|
@ -28,8 +28,8 @@ from llama_stack.log import get_logger
|
|||
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
|
||||
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import (
|
||||
PostgresSqlStoreConfig,
|
||||
SqliteSqlStoreConfig,
|
||||
SqlStoreConfig,
|
||||
sqlstore_impl,
|
||||
)
|
||||
|
||||
|
@ -47,7 +47,7 @@ class ConversationServiceConfig(BaseModel):
|
|||
policy: list[AccessRule] = []
|
||||
|
||||
@property
|
||||
def conversations_store(self) -> SqlStoreConfig:
|
||||
def conversations_store(self) -> SqliteSqlStoreConfig | PostgresSqlStoreConfig:
|
||||
"""Resolve conversations store from persistence config."""
|
||||
return resolve_conversations_store_config(self.run_config.persistence)
|
||||
|
||||
|
|
|
@ -27,8 +27,18 @@ from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput
|
|||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.core.access_control.datatypes import AccessRule
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqlStoreConfig
|
||||
from llama_stack.providers.utils.kvstore.config import (
|
||||
KVStoreConfig,
|
||||
MongoDBKVStoreConfig,
|
||||
PostgresKVStoreConfig,
|
||||
RedisKVStoreConfig,
|
||||
SqliteKVStoreConfig,
|
||||
)
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import (
|
||||
PostgresSqlStoreConfig,
|
||||
SqliteSqlStoreConfig,
|
||||
SqlStoreConfig,
|
||||
)
|
||||
|
||||
LLAMA_STACK_BUILD_CONFIG_VERSION = 2
|
||||
LLAMA_STACK_RUN_CONFIG_VERSION = 2
|
||||
|
@ -467,7 +477,7 @@ class StoresConfig(BaseModel):
|
|||
class PersistenceConfig(BaseModel):
|
||||
"""Unified persistence configuration."""
|
||||
|
||||
backends: dict[str, KVStoreConfig | SqlStoreConfig] = Field(
|
||||
backends: dict[str, Any] = Field(
|
||||
description="Named backend configurations (e.g., 'default', 'cache')",
|
||||
)
|
||||
stores: StoresConfig | None = Field(
|
||||
|
@ -475,6 +485,45 @@ class PersistenceConfig(BaseModel):
|
|||
description="Store references to backends",
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def parse_backends(cls, data: Any) -> Any:
|
||||
"""Parse backends intelligently based on which stores reference them."""
|
||||
if not isinstance(data, dict) or "backends" not in data:
|
||||
return data
|
||||
|
||||
backends_raw = data["backends"]
|
||||
stores = data.get("stores", {})
|
||||
|
||||
# Determine which backends are used by which store types
|
||||
metadata_backend = (stores.get("metadata") or {}).get("backend")
|
||||
inference_backend = (stores.get("inference") or {}).get("backend")
|
||||
conversations_backend = (stores.get("conversations") or {}).get("backend")
|
||||
|
||||
# Parse each backend based on usage context
|
||||
parsed_backends = {}
|
||||
for name, config in backends_raw.items():
|
||||
if not isinstance(config, dict):
|
||||
parsed_backends[name] = config
|
||||
continue
|
||||
|
||||
# Determine backend type based on which store uses it
|
||||
if name == metadata_backend:
|
||||
# Metadata requires KVStore
|
||||
parsed_backends[name] = _parse_kvstore_config(config)
|
||||
elif name in (inference_backend, conversations_backend):
|
||||
# Inference/conversations require SqlStore
|
||||
parsed_backends[name] = _parse_sqlstore_config(config)
|
||||
else:
|
||||
# Unknown usage - try SqlStore first (for backward compat), then KVStore
|
||||
try:
|
||||
parsed_backends[name] = _parse_sqlstore_config(config)
|
||||
except:
|
||||
parsed_backends[name] = _parse_kvstore_config(config)
|
||||
|
||||
data["backends"] = parsed_backends
|
||||
return data
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_backend_references(self) -> Self:
|
||||
"""Check all store refs point to defined backends."""
|
||||
|
@ -495,14 +544,45 @@ class PersistenceConfig(BaseModel):
|
|||
return self
|
||||
|
||||
|
||||
def _parse_kvstore_config(config: dict) -> (
|
||||
RedisKVStoreConfig
|
||||
| SqliteKVStoreConfig
|
||||
| PostgresKVStoreConfig
|
||||
| MongoDBKVStoreConfig
|
||||
):
|
||||
"""Parse a KVStore config from dict."""
|
||||
type_val = config.get("type")
|
||||
if type_val == "redis":
|
||||
return RedisKVStoreConfig(**config)
|
||||
elif type_val == "sqlite":
|
||||
return SqliteKVStoreConfig(**config)
|
||||
elif type_val == "postgres":
|
||||
return PostgresKVStoreConfig(**config)
|
||||
elif type_val == "mongodb":
|
||||
return MongoDBKVStoreConfig(**config)
|
||||
else:
|
||||
raise ValueError(f"Unknown KVStore type: {type_val}")
|
||||
|
||||
|
||||
def _parse_sqlstore_config(config: dict) -> SqliteSqlStoreConfig | PostgresSqlStoreConfig:
|
||||
"""Parse a SqlStore config from dict."""
|
||||
type_val = config.get("type")
|
||||
if type_val == "sqlite":
|
||||
return SqliteSqlStoreConfig(**config)
|
||||
elif type_val == "postgres":
|
||||
return PostgresSqlStoreConfig(**config)
|
||||
else:
|
||||
raise ValueError(f"Unknown SqlStore type: {type_val}")
|
||||
|
||||
|
||||
class InferenceStoreConfig(BaseModel):
|
||||
sql_store_config: SqlStoreConfig
|
||||
sql_store_config: SqliteSqlStoreConfig | PostgresSqlStoreConfig
|
||||
max_write_queue_size: int = Field(default=10000, description="Max queued writes for inference store")
|
||||
num_writers: int = Field(default=4, description="Number of concurrent background writers")
|
||||
|
||||
|
||||
class ResponsesStoreConfig(BaseModel):
|
||||
sql_store_config: SqlStoreConfig
|
||||
sql_store_config: SqliteSqlStoreConfig | PostgresSqlStoreConfig
|
||||
max_write_queue_size: int = Field(default=10000, description="Max queued writes for responses store")
|
||||
num_writers: int = Field(default=4, description="Number of concurrent background writers")
|
||||
|
||||
|
|
|
@ -12,13 +12,27 @@ from llama_stack.core.datatypes import (
|
|||
StoreReference,
|
||||
)
|
||||
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR, RUNTIME_BASE_DIR
|
||||
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
|
||||
from llama_stack.providers.utils.kvstore.config import (
|
||||
MongoDBKVStoreConfig,
|
||||
PostgresKVStoreConfig,
|
||||
RedisKVStoreConfig,
|
||||
SqliteKVStoreConfig,
|
||||
)
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import (
|
||||
SqlStoreConfig,
|
||||
PostgresSqlStoreConfig,
|
||||
SqliteSqlStoreConfig,
|
||||
)
|
||||
|
||||
T = TypeVar("T", KVStoreConfig, SqlStoreConfig)
|
||||
# Type aliases for cleaner code
|
||||
KVStoreConfigTypes = (
|
||||
RedisKVStoreConfig,
|
||||
SqliteKVStoreConfig,
|
||||
PostgresKVStoreConfig,
|
||||
MongoDBKVStoreConfig,
|
||||
)
|
||||
SqlStoreConfigTypes = (SqliteSqlStoreConfig, PostgresSqlStoreConfig)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def resolve_backend(
|
||||
|
@ -50,7 +64,7 @@ def resolve_backend(
|
|||
)
|
||||
|
||||
# Clone backend and apply namespace if KVStore
|
||||
if isinstance(backend_config, (KVStoreConfig.__args__)): # type: ignore
|
||||
if isinstance(backend_config, KVStoreConfigTypes):
|
||||
config_dict = backend_config.model_dump()
|
||||
if store_ref.namespace:
|
||||
config_dict["namespace"] = store_ref.namespace
|
||||
|
@ -61,7 +75,7 @@ def resolve_backend(
|
|||
|
||||
def resolve_inference_store_config(
|
||||
persistence: PersistenceConfig | None,
|
||||
) -> tuple[SqlStoreConfig, int, int]:
|
||||
) -> tuple[SqliteSqlStoreConfig | PostgresSqlStoreConfig, int, int]:
|
||||
"""
|
||||
Resolve inference store configuration.
|
||||
|
||||
|
@ -86,7 +100,7 @@ def resolve_inference_store_config(
|
|||
f"not found in persistence.backends"
|
||||
)
|
||||
|
||||
if not isinstance(backend_config, (SqlStoreConfig.__args__)): # type: ignore
|
||||
if not isinstance(backend_config, SqlStoreConfigTypes):
|
||||
raise ValueError(
|
||||
f"Inference store requires SqlStore backend, got {type(backend_config).__name__}"
|
||||
)
|
||||
|
@ -101,7 +115,12 @@ def resolve_inference_store_config(
|
|||
def resolve_metadata_store_config(
|
||||
persistence: PersistenceConfig | None,
|
||||
image_name: str,
|
||||
) -> KVStoreConfig:
|
||||
) -> (
|
||||
RedisKVStoreConfig
|
||||
| SqliteKVStoreConfig
|
||||
| PostgresKVStoreConfig
|
||||
| MongoDBKVStoreConfig
|
||||
):
|
||||
"""
|
||||
Resolve metadata store configuration.
|
||||
|
||||
|
@ -116,21 +135,32 @@ def resolve_metadata_store_config(
|
|||
db_path=(DISTRIBS_BASE_DIR / image_name / "kvstore.db").as_posix()
|
||||
)
|
||||
|
||||
store_ref = None
|
||||
if persistence and persistence.stores:
|
||||
store_ref = persistence.stores.metadata
|
||||
if not persistence or not persistence.stores or not persistence.stores.metadata:
|
||||
return default_config
|
||||
|
||||
return resolve_backend(
|
||||
persistence=persistence,
|
||||
store_ref=store_ref,
|
||||
default_factory=lambda: default_config,
|
||||
store_name="metadata",
|
||||
)
|
||||
metadata_ref = persistence.stores.metadata
|
||||
backend_config = persistence.backends.get(metadata_ref.backend)
|
||||
if not backend_config:
|
||||
raise ValueError(
|
||||
f"Backend '{metadata_ref.backend}' referenced by metadata store "
|
||||
f"not found in persistence.backends"
|
||||
)
|
||||
|
||||
if not isinstance(backend_config, KVStoreConfigTypes):
|
||||
raise ValueError(
|
||||
f"Metadata store requires KVStore backend, got {type(backend_config).__name__}"
|
||||
)
|
||||
|
||||
# Apply namespace if specified
|
||||
config_dict = backend_config.model_dump()
|
||||
if metadata_ref.namespace:
|
||||
config_dict["namespace"] = metadata_ref.namespace
|
||||
return type(backend_config)(**config_dict) # type: ignore
|
||||
|
||||
|
||||
def resolve_conversations_store_config(
|
||||
persistence: PersistenceConfig | None,
|
||||
) -> SqlStoreConfig:
|
||||
) -> SqliteSqlStoreConfig | PostgresSqlStoreConfig:
|
||||
"""
|
||||
Resolve conversations store configuration.
|
||||
|
||||
|
@ -141,13 +171,20 @@ def resolve_conversations_store_config(
|
|||
db_path=(RUNTIME_BASE_DIR / "conversations.db").as_posix()
|
||||
)
|
||||
|
||||
store_ref = None
|
||||
if persistence and persistence.stores:
|
||||
store_ref = persistence.stores.conversations
|
||||
if not persistence or not persistence.stores or not persistence.stores.conversations:
|
||||
return default_config
|
||||
|
||||
return resolve_backend(
|
||||
persistence=persistence,
|
||||
store_ref=store_ref,
|
||||
default_factory=lambda: default_config,
|
||||
store_name="conversations",
|
||||
)
|
||||
conversations_ref = persistence.stores.conversations
|
||||
backend_config = persistence.backends.get(conversations_ref.backend)
|
||||
if not backend_config:
|
||||
raise ValueError(
|
||||
f"Backend '{conversations_ref.backend}' referenced by conversations store "
|
||||
f"not found in persistence.backends"
|
||||
)
|
||||
|
||||
if not isinstance(backend_config, SqlStoreConfigTypes):
|
||||
raise ValueError(
|
||||
f"Conversations store requires SqlStore backend, got {type(backend_config).__name__}"
|
||||
)
|
||||
|
||||
return backend_config # type: ignore
|
||||
|
|
|
@ -220,14 +220,17 @@ providers:
|
|||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/batches.db
|
||||
persistence:
|
||||
backends:
|
||||
default:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/store.db
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/kvstore.db
|
||||
sqlstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlstore.db
|
||||
stores:
|
||||
metadata:
|
||||
backend: default
|
||||
backend: kvstore
|
||||
inference:
|
||||
backend: default
|
||||
backend: sqlstore
|
||||
models: []
|
||||
shields:
|
||||
- shield_id: llama-guard
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue