Add configurable embedding models for vector IO providers

This change lets users configure default embedding models at the provider level instead of always relying on system defaults. Each vector store provider can now specify an embedding_model and an optional embedding_dimension in its config.

Key features:
- Auto-dimension lookup for standard models from the registry
- Support for Matryoshka embeddings with custom dimensions
- Three-tier priority: explicit params > provider config > system fallback (see the sketch after this list)
- Full backward compatibility: existing setups work unchanged
- Comprehensive test coverage with 20 test cases
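A minimal sketch of the resolution order described above. The actual helper, get_provider_embedding_model_info in embedding_utils, is not shown in this diff; the function, registry table, and model name below are illustrative assumptions only:

KNOWN_DIMENSIONS = {"all-MiniLM-L6-v2": 384}  # hypothetical stand-in for the model registry lookup


def resolve_embedding_model(
    explicit_model: str | None,
    explicit_dimension: int | None,
    provider_config: object | None,
    system_fallback: tuple[str, int] | None,
) -> tuple[str, int] | None:
    """Return (model_id, dimension) following explicit params > provider config > system fallback."""

    def with_dimension(model: str, dimension: int | None) -> tuple[str, int]:
        # Auto-dimension lookup when no dimension is given; Matryoshka models can
        # instead pass a smaller custom dimension explicitly.
        return model, dimension if dimension is not None else KNOWN_DIMENSIONS[model]

    # Tier 1: explicit parameters from the API call
    if explicit_model is not None:
        return with_dimension(explicit_model, explicit_dimension)

    # Tier 2: provider-level defaults from the provider config
    provider_model = getattr(provider_config, "embedding_model", None)
    if provider_model is not None:
        return with_dimension(provider_model, getattr(provider_config, "embedding_dimension", None))

    # Tier 3: system fallback (e.g. the first registered embedding model)
    return system_fallback

Passing an explicit dimension (tier 1) is how a Matryoshka model can run at a reduced dimension; omitting it falls back to the registry lookup.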

Updated all vector IO providers (FAISS, Chroma, Milvus, Qdrant, etc.) with the new config fields and added detailed documentation with examples.
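As a rough illustration of the added fields on a provider config (the class name is hypothetical and not taken from this diff; real provider configs keep their existing fields):

from pydantic import BaseModel


class ExampleVectorIOConfig(BaseModel):
    """Hypothetical provider config showing the new optional fields."""

    embedding_model: str | None = None      # provider-level default embedding model
    embedding_dimension: int | None = None  # optional dimension override (e.g. Matryoshka)

Leaving both fields unset preserves the previous behavior, which is what keeps existing setups working unchanged.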

Fixes #2729
skamenan7 2025-07-15 16:46:40 -04:00
parent 2298d2473c
commit 474b50b422
28 changed files with 1160 additions and 24 deletions


@@ -7,9 +7,7 @@
import asyncio
from typing import Any
from llama_stack.apis.common.content_types import (
InterleavedContent,
)
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.models import ModelType
from llama_stack.apis.vector_io import (
Chunk,
@@ -28,6 +26,7 @@ from llama_stack.apis.vector_io import (
)
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
from llama_stack.providers.utils.vector_io.embedding_utils import get_provider_embedding_model_info
logger = get_logger(name=__name__, category="core")
@@ -51,10 +50,10 @@ class VectorIORouter(VectorIO):
pass
async def _get_first_embedding_model(self) -> tuple[str, int] | None:
"""Get the first available embedding model identifier."""
"""Get the first available embedding model identifier (DEPRECATED - use embedding_utils instead)."""
try:
# Get all models from the routing table
all_models = await self.routing_table.get_all_with_type("model")
all_models = await self.routing_table.get_all_with_type("model") # type: ignore
# Filter for embedding models
embedding_models = [
@@ -75,6 +74,31 @@ class VectorIORouter(VectorIO):
logger.error(f"Error getting embedding models: {e}")
return None
async def _get_provider_config(self, provider_id: str | None = None) -> Any:
"""Get the provider configuration object for embedding model defaults."""
try:
# If no provider_id specified, get the first available provider
if provider_id is None and hasattr(self.routing_table, "impls_by_provider_id"):
available_providers = list(self.routing_table.impls_by_provider_id.keys()) # type: ignore
if available_providers:
provider_id = available_providers[0]
else:
logger.warning("No vector IO providers available")
return None
if provider_id and hasattr(self.routing_table, "impls_by_provider_id"):
provider_impl = self.routing_table.impls_by_provider_id.get(provider_id) # type: ignore
if provider_impl and hasattr(provider_impl, "__provider_config__"):
return provider_impl.__provider_config__
else:
logger.debug(f"Provider {provider_id} has no config object attached")
return None
return None
except Exception as e:
logger.error(f"Error getting provider config: {e}")
return None
async def register_vector_db(
self,
vector_db_id: str,
@@ -84,7 +108,7 @@ class VectorIORouter(VectorIO):
provider_vector_db_id: str | None = None,
) -> None:
logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
await self.routing_table.register_vector_db(
await self.routing_table.register_vector_db( # type: ignore
vector_db_id,
embedding_model,
embedding_dimension,
@@ -127,13 +151,64 @@ class VectorIORouter(VectorIO):
) -> VectorStoreObject:
logger.debug(f"VectorIORouter.openai_create_vector_store: name={name}, provider_id={provider_id}")
# If no embedding model is provided, use the first available one
if embedding_model is None:
embedding_model_info = await self._get_first_embedding_model()
# Use the new 3-tier priority system for embedding model selection
provider_config = await self._get_provider_config(provider_id)
# Log the resolution context for debugging
logger.debug(f"Resolving embedding model for vector store '{name}' with provider_id={provider_id}")
logger.debug(f"Explicit model: {embedding_model}, explicit dimension: {embedding_dimension}")
logger.debug(
f"Provider config embedding_model: {getattr(provider_config, 'embedding_model', None) if provider_config else None}"
)
logger.debug(
f"Provider config embedding_dimension: {getattr(provider_config, 'embedding_dimension', None) if provider_config else None}"
)
try:
embedding_model_info = await get_provider_embedding_model_info(
routing_table=self.routing_table,
provider_config=provider_config,
explicit_model_id=embedding_model,
explicit_dimension=embedding_dimension,
)
if embedding_model_info is None:
raise ValueError("No embedding model provided and no embedding models available in the system")
embedding_model, embedding_dimension = embedding_model_info
logger.info(f"No embedding model specified, using first available: {embedding_model}")
resolved_model, resolved_dimension = embedding_model_info
# Enhanced logging to show resolution path
if embedding_model is not None:
logger.info(
f"✅ Vector store '{name}': Using EXPLICIT embedding model '{resolved_model}' (dimension: {resolved_dimension})"
)
elif provider_config and getattr(provider_config, "embedding_model", None):
logger.info(
f"✅ Vector store '{name}': Using PROVIDER DEFAULT embedding model '{resolved_model}' (dimension: {resolved_dimension}) from provider '{provider_id}'"
)
if getattr(provider_config, "embedding_dimension", None):
logger.info(f" └── Provider config dimension override: {resolved_dimension}")
else:
logger.info(f" └── Auto-lookup dimension from model registry: {resolved_dimension}")
else:
logger.info(
f"✅ Vector store '{name}': Using SYSTEM DEFAULT embedding model '{resolved_model}' (dimension: {resolved_dimension})"
)
logger.warning(
f"⚠️ Consider configuring a default embedding model for provider '{provider_id}' to avoid fallback behavior"
)
embedding_model, embedding_dimension = resolved_model, resolved_dimension
except Exception as e:
logger.error(
f"❌ Failed to resolve embedding model for vector store '{name}' with provider '{provider_id}': {e}"
)
logger.error(f" Debug info - Explicit: model={embedding_model}, dim={embedding_dimension}")
logger.error(
f" Debug info - Provider: model={getattr(provider_config, 'embedding_model', None) if provider_config else None}, dim={getattr(provider_config, 'embedding_dimension', None) if provider_config else None}"
)
raise ValueError(f"Unable to determine embedding model for vector store '{name}': {e}") from e
vector_db_id = name
registered_vector_db = await self.routing_table.register_vector_db(