This commit is contained in:
Sumanth Kamenani 2025-09-24 09:30:04 +02:00 committed by GitHub
commit 689f1db815
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 284 additions and 8 deletions

View file

@ -12,6 +12,7 @@ from urllib.parse import urlparse
from pydantic import BaseModel, Field, field_validator, model_validator
from llama_stack.apis.benchmarks import Benchmark, BenchmarkInput
from llama_stack.apis.common.vector_store_config import VectorStoreConfig
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Dataset, DatasetInput
from llama_stack.apis.eval import Eval
@ -474,6 +475,12 @@ InferenceStoreConfig (with queue tuning parameters) or a SqlStoreConfig (depreca
If not specified, a default SQLite store will be used.""",
)
# Global vector-store defaults (embedding model etc.)
vector_store_config: VectorStoreConfig = Field(
default_factory=VectorStoreConfig,
description="Global defaults for vector-store creation (embedding model, dimension, …)",
)
# registry of "resources" in the distribution
models: list[ModelInput] = Field(default_factory=list)
shields: list[ShieldInput] = Field(default_factory=list)

View file

@ -11,6 +11,7 @@ from typing import Any
from llama_stack.apis.common.content_types import (
InterleavedContent,
)
from llama_stack.apis.common.vector_store_config import VectorStoreConfig
from llama_stack.apis.models import ModelType
from llama_stack.apis.vector_io import (
Chunk,
@ -76,6 +77,41 @@ class VectorIORouter(VectorIO):
logger.error(f"Error getting embedding models: {e}")
return None
async def _resolve_embedding_model(self, explicit_model: str | None = None) -> tuple[str, int]:
"""Figure out which embedding model to use and what dimension it has."""
# if they passed a model explicitly, use that
if explicit_model is not None:
# try to look up dimension from our routing table
models = await self.routing_table.get_all_with_type("model")
for model in models:
if getattr(model, "identifier", None) == explicit_model:
dim = model.metadata.get("embedding_dimension")
if dim is None:
raise ValueError(f"Model {explicit_model} found but no embedding dimension in metadata")
return explicit_model, dim
# model not found in registry - this is an error
raise ValueError(f"Embedding model '{explicit_model}' not found in model registry")
# check if we have global defaults set via env vars
config = VectorStoreConfig()
if config.default_embedding_model is not None:
if config.default_embedding_dimension is None:
raise ValueError(
f"default_embedding_model '{config.default_embedding_model}' is set but default_embedding_dimension is missing"
)
return config.default_embedding_model, config.default_embedding_dimension
# fallback to first available embedding model for compatibility
fallback = await self._get_first_embedding_model()
if fallback is not None:
return fallback
# if no models available, raise error
raise ValueError(
"No embedding model specified and no default configured. Either provide an embedding_model parameter or set vector_store_config.default_embedding_model"
)
async def register_vector_db(
self,
vector_db_id: str,
@ -102,7 +138,7 @@ class VectorIORouter(VectorIO):
ttl_seconds: int | None = None,
) -> None:
logger.debug(
f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.chunk_id for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
)
provider = await self.routing_table.get_provider_impl(vector_db_id)
return await provider.insert_chunks(vector_db_id, chunks, ttl_seconds)
@ -131,13 +167,8 @@ class VectorIORouter(VectorIO):
) -> VectorStoreObject:
logger.debug(f"VectorIORouter.openai_create_vector_store: name={name}, provider_id={provider_id}")
# If no embedding model is provided, use the first available one
if embedding_model is None:
embedding_model_info = await self._get_first_embedding_model()
if embedding_model_info is None:
raise ValueError("No embedding model provided and no embedding models available in the system")
embedding_model, embedding_dimension = embedding_model_info
logger.info(f"No embedding model specified, using first available: {embedding_model}")
# Determine which embedding model to use based on new precedence
embedding_model, embedding_dimension = await self._resolve_embedding_model(embedding_model)
vector_db_id = f"vs_{uuid.uuid4()}"
registered_vector_db = await self.routing_table.register_vector_db(