This commit is contained in:
Sumanth Kamenani 2025-09-24 09:30:04 +02:00 committed by GitHub
commit 689f1db815
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 284 additions and 8 deletions

View file

@@ -0,0 +1,33 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from __future__ import annotations
"""Vector store global config stuff.
Basically just holds default embedding model settings so we don't have to
pass them around everywhere. Router picks these up when client doesn't specify.
"""
import os
from pydantic import BaseModel, ConfigDict, Field
__all__ = ["VectorStoreConfig"]
def _default_embedding_dimension() -> int | None:
    """Parse LLAMA_STACK_DEFAULT_EMBEDDING_DIMENSION from the environment.

    Returns None when the variable is unset, not a valid integer, or < 1 —
    i.e. an invalid value behaves exactly like "not set" instead of raising.
    """
    raw = os.getenv("LLAMA_STACK_DEFAULT_EMBEDDING_DIMENSION")
    if raw is None:
        return None
    try:
        value = int(raw)
    except ValueError:
        # Previously int() would raise here; honor the documented
        # "fallback to None if not set or invalid" behavior instead.
        return None
    return value if value >= 1 else None


class VectorStoreConfig(BaseModel):
    """Default embedding model config that gets picked up from env vars.

    Frozen (immutable) so a config instance can be shared safely.
    """

    # Embedding model to use when a client does not specify one; None when
    # LLAMA_STACK_DEFAULT_EMBEDDING_MODEL is unset.
    default_embedding_model: str | None = Field(
        default_factory=lambda: os.getenv("LLAMA_STACK_DEFAULT_EMBEDDING_MODEL")
    )
    # dimension from env - fallback to None if not set or invalid
    default_embedding_dimension: int | None = Field(default_factory=_default_embedding_dimension, ge=1)
    model_config = ConfigDict(frozen=True)

View file

@@ -12,6 +12,7 @@ from urllib.parse import urlparse
from pydantic import BaseModel, Field, field_validator, model_validator
from llama_stack.apis.benchmarks import Benchmark, BenchmarkInput
from llama_stack.apis.common.vector_store_config import VectorStoreConfig
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Dataset, DatasetInput
from llama_stack.apis.eval import Eval
@@ -474,6 +475,12 @@ InferenceStoreConfig (with queue tuning parameters) or a SqlStoreConfig (deprecated)
If not specified, a default SQLite store will be used.""",
)
# Global vector-store defaults (embedding model etc.)
# Each VectorStoreConfig field reads its own LLAMA_STACK_* env var via its
# default_factory, so an omitted section still picks up environment defaults.
# NOTE(review): presumably the VectorIO router consults these when a client
# creates a vector store without an explicit embedding model — confirm.
vector_store_config: VectorStoreConfig = Field(
default_factory=VectorStoreConfig,
description="Global defaults for vector-store creation (embedding model, dimension, …)",
)
# registry of "resources" in the distribution
models: list[ModelInput] = Field(default_factory=list)
shields: list[ShieldInput] = Field(default_factory=list)

View file

@@ -11,6 +11,7 @@ from typing import Any
from llama_stack.apis.common.content_types import (
InterleavedContent,
)
from llama_stack.apis.common.vector_store_config import VectorStoreConfig
from llama_stack.apis.models import ModelType
from llama_stack.apis.vector_io import (
Chunk,
@@ -76,6 +77,41 @@ class VectorIORouter(VectorIO):
logger.error(f"Error getting embedding models: {e}")
return None
async def _resolve_embedding_model(self, explicit_model: str | None = None) -> tuple[str, int]:
    """Resolve the embedding model and its dimension for vector-store creation.

    Precedence: an explicitly requested model (must be registered and carry an
    embedding_dimension), then env-var defaults from VectorStoreConfig, then
    the first embedding model available in the system.

    Raises:
        ValueError: if the explicit model is unknown or lacks a dimension, if
            a default model is configured without a dimension, or if nothing
            can be resolved at all.
    """
    # Precedence 1: caller named a model — it must exist in the registry.
    if explicit_model is not None:
        registered = await self.routing_table.get_all_with_type("model")
        match = next((m for m in registered if getattr(m, "identifier", None) == explicit_model), None)
        if match is not None:
            dimension = match.metadata.get("embedding_dimension")
            if dimension is None:
                raise ValueError(f"Model {explicit_model} found but no embedding dimension in metadata")
            return explicit_model, dimension
        raise ValueError(f"Embedding model '{explicit_model}' not found in model registry")

    # Precedence 2: global defaults supplied through environment variables.
    defaults = VectorStoreConfig()
    if defaults.default_embedding_model is not None:
        if defaults.default_embedding_dimension is None:
            raise ValueError(
                f"default_embedding_model '{defaults.default_embedding_model}' is set but default_embedding_dimension is missing"
            )
        return defaults.default_embedding_model, defaults.default_embedding_dimension

    # Precedence 3: legacy compatibility — first embedding model in the system.
    first_available = await self._get_first_embedding_model()
    if first_available is not None:
        return first_available

    raise ValueError(
        "No embedding model specified and no default configured. Either provide an embedding_model parameter or set vector_store_config.default_embedding_model"
    )
async def register_vector_db(
self,
vector_db_id: str,
@@ -102,7 +138,7 @@ class VectorIORouter(VectorIO):
ttl_seconds: int | None = None,
) -> None:
logger.debug(
f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.chunk_id for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
)
provider = await self.routing_table.get_provider_impl(vector_db_id)
return await provider.insert_chunks(vector_db_id, chunks, ttl_seconds)
@@ -131,13 +167,8 @@ class VectorIORouter(VectorIO):
) -> VectorStoreObject:
logger.debug(f"VectorIORouter.openai_create_vector_store: name={name}, provider_id={provider_id}")
# If no embedding model is provided, use the first available one
if embedding_model is None:
embedding_model_info = await self._get_first_embedding_model()
if embedding_model_info is None:
raise ValueError("No embedding model provided and no embedding models available in the system")
embedding_model, embedding_dimension = embedding_model_info
logger.info(f"No embedding model specified, using first available: {embedding_model}")
# Determine which embedding model to use based on new precedence
embedding_model, embedding_dimension = await self._resolve_embedding_model(embedding_model)
vector_db_id = f"vs_{uuid.uuid4()}"
registered_vector_db = await self.routing_table.register_vector_db(

View file

@@ -39,6 +39,9 @@ distribution_spec:
- provider_type: remote::tavily-search
- provider_type: inline::rag-runtime
- provider_type: remote::model-context-protocol
vector_store_config:
default_embedding_model: ${env.LLAMA_STACK_DEFAULT_EMBEDDING_MODEL:=all-MiniLM-L6-v2}
default_embedding_dimension: ${env.LLAMA_STACK_DEFAULT_EMBEDDING_DIMENSION:=384}
image_type: venv
additional_pip_packages:
- sqlalchemy[asyncio]