mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-24 12:30:01 +00:00
Add configurable embedding models for vector IO providers
This change lets users configure default embedding models at the provider level instead of always relying on system defaults. Each vector store provider can now specify an embedding_model and optional embedding_dimension in their config. Key features: - Auto-dimension lookup for standard models from the registry - Support for Matryoshka embeddings with custom dimensions - Three-tier priority: explicit params > provider config > system fallback - Full backward compatibility - existing setups work unchanged - Comprehensive test coverage with 20 test cases Updated all vector IO providers (FAISS, Chroma, Milvus, Qdrant, etc.) with the new config fields and added detailed documentation with examples. Fixes #2729
This commit is contained in:
parent
2298d2473c
commit
474b50b422
28 changed files with 1160 additions and 24 deletions
|
|
@ -7,9 +7,7 @@
|
|||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
)
|
||||
from llama_stack.apis.common.content_types import InterleavedContent
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.apis.vector_io import (
|
||||
Chunk,
|
||||
|
|
@ -28,6 +26,7 @@ from llama_stack.apis.vector_io import (
|
|||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
|
||||
from llama_stack.providers.utils.vector_io.embedding_utils import get_provider_embedding_model_info
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
|
|
@ -51,10 +50,10 @@ class VectorIORouter(VectorIO):
|
|||
pass
|
||||
|
||||
async def _get_first_embedding_model(self) -> tuple[str, int] | None:
|
||||
"""Get the first available embedding model identifier."""
|
||||
"""Get the first available embedding model identifier (DEPRECATED - use embedding_utils instead)."""
|
||||
try:
|
||||
# Get all models from the routing table
|
||||
all_models = await self.routing_table.get_all_with_type("model")
|
||||
all_models = await self.routing_table.get_all_with_type("model") # type: ignore
|
||||
|
||||
# Filter for embedding models
|
||||
embedding_models = [
|
||||
|
|
@ -75,6 +74,31 @@ class VectorIORouter(VectorIO):
|
|||
logger.error(f"Error getting embedding models: {e}")
|
||||
return None
|
||||
|
||||
async def _get_provider_config(self, provider_id: str | None = None) -> Any:
|
||||
"""Get the provider configuration object for embedding model defaults."""
|
||||
try:
|
||||
# If no provider_id specified, get the first available provider
|
||||
if provider_id is None and hasattr(self.routing_table, "impls_by_provider_id"):
|
||||
available_providers = list(self.routing_table.impls_by_provider_id.keys()) # type: ignore
|
||||
if available_providers:
|
||||
provider_id = available_providers[0]
|
||||
else:
|
||||
logger.warning("No vector IO providers available")
|
||||
return None
|
||||
|
||||
if provider_id and hasattr(self.routing_table, "impls_by_provider_id"):
|
||||
provider_impl = self.routing_table.impls_by_provider_id.get(provider_id) # type: ignore
|
||||
if provider_impl and hasattr(provider_impl, "__provider_config__"):
|
||||
return provider_impl.__provider_config__
|
||||
else:
|
||||
logger.debug(f"Provider {provider_id} has no config object attached")
|
||||
return None
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting provider config: {e}")
|
||||
return None
|
||||
|
||||
async def register_vector_db(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
|
|
@ -84,7 +108,7 @@ class VectorIORouter(VectorIO):
|
|||
provider_vector_db_id: str | None = None,
|
||||
) -> None:
|
||||
logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
|
||||
await self.routing_table.register_vector_db(
|
||||
await self.routing_table.register_vector_db( # type: ignore
|
||||
vector_db_id,
|
||||
embedding_model,
|
||||
embedding_dimension,
|
||||
|
|
@ -127,13 +151,64 @@ class VectorIORouter(VectorIO):
|
|||
) -> VectorStoreObject:
|
||||
logger.debug(f"VectorIORouter.openai_create_vector_store: name={name}, provider_id={provider_id}")
|
||||
|
||||
# If no embedding model is provided, use the first available one
|
||||
if embedding_model is None:
|
||||
embedding_model_info = await self._get_first_embedding_model()
|
||||
# Use the new 3-tier priority system for embedding model selection
|
||||
provider_config = await self._get_provider_config(provider_id)
|
||||
|
||||
# Log the resolution context for debugging
|
||||
logger.debug(f"Resolving embedding model for vector store '{name}' with provider_id={provider_id}")
|
||||
logger.debug(f"Explicit model: {embedding_model}, explicit dimension: {embedding_dimension}")
|
||||
logger.debug(
|
||||
f"Provider config embedding_model: {getattr(provider_config, 'embedding_model', None) if provider_config else None}"
|
||||
)
|
||||
logger.debug(
|
||||
f"Provider config embedding_dimension: {getattr(provider_config, 'embedding_dimension', None) if provider_config else None}"
|
||||
)
|
||||
|
||||
try:
|
||||
embedding_model_info = await get_provider_embedding_model_info(
|
||||
routing_table=self.routing_table,
|
||||
provider_config=provider_config,
|
||||
explicit_model_id=embedding_model,
|
||||
explicit_dimension=embedding_dimension,
|
||||
)
|
||||
|
||||
if embedding_model_info is None:
|
||||
raise ValueError("No embedding model provided and no embedding models available in the system")
|
||||
embedding_model, embedding_dimension = embedding_model_info
|
||||
logger.info(f"No embedding model specified, using first available: {embedding_model}")
|
||||
|
||||
resolved_model, resolved_dimension = embedding_model_info
|
||||
|
||||
# Enhanced logging to show resolution path
|
||||
if embedding_model is not None:
|
||||
logger.info(
|
||||
f"✅ Vector store '{name}': Using EXPLICIT embedding model '{resolved_model}' (dimension: {resolved_dimension})"
|
||||
)
|
||||
elif provider_config and getattr(provider_config, "embedding_model", None):
|
||||
logger.info(
|
||||
f"✅ Vector store '{name}': Using PROVIDER DEFAULT embedding model '{resolved_model}' (dimension: {resolved_dimension}) from provider '{provider_id}'"
|
||||
)
|
||||
if getattr(provider_config, "embedding_dimension", None):
|
||||
logger.info(f" └── Provider config dimension override: {resolved_dimension}")
|
||||
else:
|
||||
logger.info(f" └── Auto-lookup dimension from model registry: {resolved_dimension}")
|
||||
else:
|
||||
logger.info(
|
||||
f"✅ Vector store '{name}': Using SYSTEM DEFAULT embedding model '{resolved_model}' (dimension: {resolved_dimension})"
|
||||
)
|
||||
logger.warning(
|
||||
f"⚠️ Consider configuring a default embedding model for provider '{provider_id}' to avoid fallback behavior"
|
||||
)
|
||||
|
||||
embedding_model, embedding_dimension = resolved_model, resolved_dimension
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"❌ Failed to resolve embedding model for vector store '{name}' with provider '{provider_id}': {e}"
|
||||
)
|
||||
logger.error(f" Debug info - Explicit: model={embedding_model}, dim={embedding_dimension}")
|
||||
logger.error(
|
||||
f" Debug info - Provider: model={getattr(provider_config, 'embedding_model', None) if provider_config else None}, dim={getattr(provider_config, 'embedding_dimension', None) if provider_config else None}"
|
||||
)
|
||||
raise ValueError(f"Unable to determine embedding model for vector store '{name}': {e}") from e
|
||||
|
||||
vector_db_id = name
|
||||
registered_vector_db = await self.routing_table.register_vector_db(
|
||||
|
|
|
|||
|
|
@ -6,12 +6,25 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ChromaVectorIOConfig(BaseModel):
|
||||
db_path: str
|
||||
embedding_model: str | None = Field(
|
||||
default=None,
|
||||
description="Optional default embedding model for this provider. If not specified, will use system default.",
|
||||
)
|
||||
embedding_dimension: int | None = Field(
|
||||
default=None,
|
||||
description="Optional embedding dimension override. Only needed for models with variable dimensions (e.g., Matryoshka embeddings). If not specified, will auto-lookup from model registry.",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, db_path: str = "${env.CHROMADB_PATH}", **kwargs: Any) -> dict[str, Any]:
|
||||
return {"db_path": db_path}
|
||||
return {
|
||||
"db_path": db_path,
|
||||
# Optional: Configure default embedding model for this provider
|
||||
# "embedding_model": "all-MiniLM-L6-v2",
|
||||
# "embedding_dimension": 384, # Only needed for variable-dimension models
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.kvstore.config import (
|
||||
KVStoreConfig,
|
||||
|
|
@ -18,6 +18,14 @@ from llama_stack.schema_utils import json_schema_type
|
|||
@json_schema_type
|
||||
class FaissVectorIOConfig(BaseModel):
|
||||
kvstore: KVStoreConfig
|
||||
embedding_model: str | None = Field(
|
||||
default=None,
|
||||
description="Optional default embedding model for this provider. If not specified, will use system default.",
|
||||
)
|
||||
embedding_dimension: int | None = Field(
|
||||
default=None,
|
||||
description="Optional embedding dimension override. Only needed for models with variable dimensions (e.g., Matryoshka embeddings). If not specified, will auto-lookup from model registry.",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
|
||||
|
|
@ -25,5 +33,8 @@ class FaissVectorIOConfig(BaseModel):
|
|||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="faiss_store.db",
|
||||
)
|
||||
),
|
||||
# Optional: Configure default embedding model for this provider
|
||||
# "embedding_model": "all-MiniLM-L6-v2",
|
||||
# "embedding_dimension": 384, # Only needed for variable-dimension models
|
||||
}
|
||||
|
|
|
|||
|
|
@ -20,6 +20,14 @@ class MilvusVectorIOConfig(BaseModel):
|
|||
db_path: str
|
||||
kvstore: KVStoreConfig = Field(description="Config for KV store backend (SQLite only for now)")
|
||||
consistency_level: str = Field(description="The consistency level of the Milvus server", default="Strong")
|
||||
embedding_model: str | None = Field(
|
||||
default=None,
|
||||
description="Optional default embedding model for this provider. If not specified, will use system default.",
|
||||
)
|
||||
embedding_dimension: int | None = Field(
|
||||
default=None,
|
||||
description="Optional embedding dimension override. Only needed for models with variable dimensions (e.g., Matryoshka embeddings). If not specified, will auto-lookup from model registry.",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
|
||||
|
|
@ -29,4 +37,7 @@ class MilvusVectorIOConfig(BaseModel):
|
|||
__distro_dir__=__distro_dir__,
|
||||
db_name="milvus_registry.db",
|
||||
),
|
||||
# Optional: Configure default embedding model for this provider
|
||||
# "embedding_model": "all-MiniLM-L6-v2",
|
||||
# "embedding_dimension": 384, # Only needed for variable-dimension models
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
|
@ -15,9 +15,20 @@ from llama_stack.schema_utils import json_schema_type
|
|||
@json_schema_type
|
||||
class QdrantVectorIOConfig(BaseModel):
|
||||
path: str
|
||||
embedding_model: str | None = Field(
|
||||
default=None,
|
||||
description="Optional default embedding model for this provider. If not specified, will use system default.",
|
||||
)
|
||||
embedding_dimension: int | None = Field(
|
||||
default=None,
|
||||
description="Optional embedding dimension override. Only needed for models with variable dimensions (e.g., Matryoshka embeddings). If not specified, will auto-lookup from model registry.",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
|
||||
return {
|
||||
"path": "${env.QDRANT_PATH:=~/.llama/" + __distro_dir__ + "}/" + "qdrant.db",
|
||||
# Optional: Configure default embedding model for this provider
|
||||
# "embedding_model": "all-MiniLM-L6-v2",
|
||||
# "embedding_dimension": 384, # Only needed for variable-dimension models
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,6 +17,14 @@ from llama_stack.providers.utils.kvstore.config import (
|
|||
class SQLiteVectorIOConfig(BaseModel):
|
||||
db_path: str = Field(description="Path to the SQLite database file")
|
||||
kvstore: KVStoreConfig = Field(description="Config for KV store backend (SQLite only for now)")
|
||||
embedding_model: str | None = Field(
|
||||
default=None,
|
||||
description="Optional default embedding model for this provider. If not specified, will use system default.",
|
||||
)
|
||||
embedding_dimension: int | None = Field(
|
||||
default=None,
|
||||
description="Optional embedding dimension override. Only needed for models with variable dimensions (e.g., Matryoshka embeddings). If not specified, will auto-lookup from model registry.",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
|
||||
|
|
@ -26,4 +34,7 @@ class SQLiteVectorIOConfig(BaseModel):
|
|||
__distro_dir__=__distro_dir__,
|
||||
db_name="sqlite_vec_registry.db",
|
||||
),
|
||||
# Optional: Configure default embedding model for this provider
|
||||
# "embedding_model": "all-MiniLM-L6-v2",
|
||||
# "embedding_dimension": 384, # Only needed for variable-dimension models
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,12 +6,25 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ChromaVectorIOConfig(BaseModel):
|
||||
url: str | None
|
||||
embedding_model: str | None = Field(
|
||||
default=None,
|
||||
description="Optional default embedding model for this provider. If not specified, will use system default.",
|
||||
)
|
||||
embedding_dimension: int | None = Field(
|
||||
default=None,
|
||||
description="Optional embedding dimension override. Only needed for models with variable dimensions (e.g., Matryoshka embeddings). If not specified, will auto-lookup from model registry.",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, url: str = "${env.CHROMADB_URL}", **kwargs: Any) -> dict[str, Any]:
|
||||
return {"url": url}
|
||||
return {
|
||||
"url": url,
|
||||
# Optional: Configure default embedding model for this provider
|
||||
# "embedding_model": "all-MiniLM-L6-v2",
|
||||
# "embedding_dimension": 384, # Only needed for variable-dimension models
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,6 +18,14 @@ class MilvusVectorIOConfig(BaseModel):
|
|||
token: str | None = Field(description="The token of the Milvus server")
|
||||
consistency_level: str = Field(description="The consistency level of the Milvus server", default="Strong")
|
||||
kvstore: KVStoreConfig | None = Field(description="Config for KV store backend (SQLite only for now)", default=None)
|
||||
embedding_model: str | None = Field(
|
||||
default=None,
|
||||
description="Optional default embedding model for this provider. If not specified, will use system default.",
|
||||
)
|
||||
embedding_dimension: int | None = Field(
|
||||
default=None,
|
||||
description="Optional embedding dimension override. Only needed for models with variable dimensions (e.g., Matryoshka embeddings). If not specified, will auto-lookup from model registry.",
|
||||
)
|
||||
|
||||
# This configuration allows additional fields to be passed through to the underlying Milvus client.
|
||||
# See the [Milvus](https://milvus.io/docs/install-overview.md) documentation for more details about Milvus in general.
|
||||
|
|
@ -25,4 +33,10 @@ class MilvusVectorIOConfig(BaseModel):
|
|||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
|
||||
return {"uri": "${env.MILVUS_ENDPOINT}", "token": "${env.MILVUS_TOKEN}"}
|
||||
return {
|
||||
"uri": "${env.MILVUS_ENDPOINT}",
|
||||
"token": "${env.MILVUS_TOKEN}",
|
||||
# Optional: Configure default embedding model for this provider
|
||||
# "embedding_model": "all-MiniLM-L6-v2",
|
||||
# "embedding_dimension": 384, # Only needed for variable-dimension models
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,15 +18,32 @@ class PGVectorVectorIOConfig(BaseModel):
|
|||
db: str | None = Field(default="postgres")
|
||||
user: str | None = Field(default="postgres")
|
||||
password: str | None = Field(default="mysecretpassword")
|
||||
embedding_model: str | None = Field(
|
||||
default=None,
|
||||
description="Optional default embedding model for this provider. If not specified, will use system default.",
|
||||
)
|
||||
embedding_dimension: int | None = Field(
|
||||
default=None,
|
||||
description="Optional embedding dimension override. Only needed for models with variable dimensions (e.g., Matryoshka embeddings). If not specified, will auto-lookup from model registry.",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
cls,
|
||||
host: str = "${env.PGVECTOR_HOST:=localhost}",
|
||||
port: int = "${env.PGVECTOR_PORT:=5432}",
|
||||
port: int | str = "${env.PGVECTOR_PORT:=5432}",
|
||||
db: str = "${env.PGVECTOR_DB}",
|
||||
user: str = "${env.PGVECTOR_USER}",
|
||||
password: str = "${env.PGVECTOR_PASSWORD}",
|
||||
**kwargs: Any,
|
||||
) -> dict[str, Any]:
|
||||
return {"host": host, "port": port, "db": db, "user": user, "password": password}
|
||||
return {
|
||||
"host": host,
|
||||
"port": port,
|
||||
"db": db,
|
||||
"user": user,
|
||||
"password": password,
|
||||
# Optional: Configure default embedding model for this provider
|
||||
# "embedding_model": "all-MiniLM-L6-v2",
|
||||
# "embedding_dimension": 384, # Only needed for variable-dimension models
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
|
@ -23,9 +23,20 @@ class QdrantVectorIOConfig(BaseModel):
|
|||
prefix: str | None = None
|
||||
timeout: int | None = None
|
||||
host: str | None = None
|
||||
embedding_model: str | None = Field(
|
||||
default=None,
|
||||
description="Optional default embedding model for this provider. If not specified, will use system default.",
|
||||
)
|
||||
embedding_dimension: int | None = Field(
|
||||
default=None,
|
||||
description="Optional embedding dimension override. Only needed for models with variable dimensions (e.g., Matryoshka embeddings). If not specified, will auto-lookup from model registry.",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]:
|
||||
return {
|
||||
"api_key": "${env.QDRANT_API_KEY}",
|
||||
# Optional: Configure default embedding model for this provider
|
||||
# "embedding_model": "all-MiniLM-L6-v2",
|
||||
# "embedding_dimension": 384, # Only needed for variable-dimension models
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class WeaviateRequestProviderData(BaseModel):
|
||||
|
|
@ -15,6 +15,19 @@ class WeaviateRequestProviderData(BaseModel):
|
|||
|
||||
|
||||
class WeaviateVectorIOConfig(BaseModel):
|
||||
embedding_model: str | None = Field(
|
||||
default=None,
|
||||
description="Optional default embedding model for this provider. If not specified, will use system default.",
|
||||
)
|
||||
embedding_dimension: int | None = Field(
|
||||
default=None,
|
||||
description="Optional embedding dimension override. Only needed for models with variable dimensions (e.g., Matryoshka embeddings). If not specified, will auto-lookup from model registry.",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]:
|
||||
return {}
|
||||
return {
|
||||
# Optional: Configure default embedding model for this provider
|
||||
# "embedding_model": "all-MiniLM-L6-v2",
|
||||
# "embedding_dimension": 384, # Only needed for variable-dimension models
|
||||
}
|
||||
|
|
|
|||
153
llama_stack/providers/utils/vector_io/embedding_utils.py
Normal file
153
llama_stack/providers/utils/vector_io/embedding_utils.py
Normal file
|
|
@ -0,0 +1,153 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import RoutingTable
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
async def get_embedding_model_info(
|
||||
model_id: str, routing_table: RoutingTable, override_dimension: int | None = None
|
||||
) -> tuple[str, int]:
|
||||
"""
|
||||
Get embedding model info with auto-dimension lookup.
|
||||
|
||||
This function validates that the specified model is an embedding model
|
||||
and returns its embedding dimensions, with support for Matryoshka embeddings
|
||||
through dimension overrides.
|
||||
|
||||
Args:
|
||||
model_id: The embedding model identifier to look up
|
||||
routing_table: Access to the model registry for validation and dimension lookup
|
||||
override_dimension: Optional dimension override for Matryoshka models that
|
||||
support variable dimensions (e.g., nomic-embed-text)
|
||||
|
||||
Returns:
|
||||
tuple: (model_id, embedding_dimension)
|
||||
|
||||
Raises:
|
||||
ValueError: If model not found, not an embedding model, or missing dimension info
|
||||
"""
|
||||
try:
|
||||
# Look up the model in the routing table
|
||||
model = await routing_table.get_object_by_identifier("model", model_id) # type: ignore
|
||||
if model is None:
|
||||
raise ValueError(f"Embedding model '{model_id}' not found in model registry")
|
||||
|
||||
# Validate that this is an embedding model
|
||||
if not hasattr(model, "model_type") or model.model_type != ModelType.embedding:
|
||||
raise ValueError(
|
||||
f"Model '{model_id}' is not an embedding model (type: {getattr(model, 'model_type', 'unknown')})"
|
||||
)
|
||||
|
||||
# If override dimension is provided, use it (for Matryoshka embeddings)
|
||||
if override_dimension is not None:
|
||||
if override_dimension <= 0:
|
||||
raise ValueError(f"Override dimension must be positive, got {override_dimension}")
|
||||
logger.info(f"Using override dimension {override_dimension} for embedding model '{model_id}'")
|
||||
return model_id, override_dimension
|
||||
|
||||
# Extract embedding dimension from model metadata
|
||||
if not hasattr(model, "metadata") or not model.metadata:
|
||||
raise ValueError(f"Embedding model '{model_id}' has no metadata")
|
||||
|
||||
embedding_dimension = model.metadata.get("embedding_dimension")
|
||||
if embedding_dimension is None:
|
||||
raise ValueError(f"Embedding model '{model_id}' has no embedding_dimension in metadata")
|
||||
|
||||
if not isinstance(embedding_dimension, int) or embedding_dimension <= 0:
|
||||
raise ValueError(f"Invalid embedding_dimension for model '{model_id}': {embedding_dimension}")
|
||||
|
||||
logger.debug(f"Auto-lookup successful for embedding model '{model_id}': dimension {embedding_dimension}")
|
||||
return model_id, embedding_dimension
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error looking up embedding model info for '{model_id}': {e}")
|
||||
raise
|
||||
|
||||
|
||||
async def get_provider_embedding_model_info(
|
||||
routing_table: RoutingTable,
|
||||
provider_config,
|
||||
explicit_model_id: str | None = None,
|
||||
explicit_dimension: int | None = None,
|
||||
) -> tuple[str, int] | None:
|
||||
"""
|
||||
Get embedding model info with provider-level defaults and explicit overrides.
|
||||
|
||||
This function implements the priority order for embedding model selection:
|
||||
1. Explicit parameters (from API calls)
|
||||
2. Provider config defaults (NEW - from VectorIOConfig)
|
||||
3. System default (current fallback behavior)
|
||||
|
||||
Args:
|
||||
routing_table: Access to the model registry
|
||||
provider_config: The VectorIOConfig object with potential embedding_model defaults
|
||||
explicit_model_id: Explicit model ID from API call (highest priority)
|
||||
explicit_dimension: Explicit dimension from API call (highest priority)
|
||||
|
||||
Returns:
|
||||
tuple: (model_id, embedding_dimension) or None if no model available
|
||||
|
||||
Raises:
|
||||
ValueError: If model validation fails
|
||||
"""
|
||||
try:
|
||||
# Priority 1: Explicit parameters (existing behavior)
|
||||
if explicit_model_id is not None:
|
||||
logger.debug(f"Using explicit embedding model: {explicit_model_id}")
|
||||
return await get_embedding_model_info(explicit_model_id, routing_table, explicit_dimension)
|
||||
|
||||
# Priority 2: Provider config default (NEW)
|
||||
if hasattr(provider_config, "embedding_model") and provider_config.embedding_model:
|
||||
logger.info(f"Using provider config default embedding model: {provider_config.embedding_model}")
|
||||
override_dim = None
|
||||
if hasattr(provider_config, "embedding_dimension") and provider_config.embedding_dimension:
|
||||
override_dim = provider_config.embedding_dimension
|
||||
logger.info(f"Using provider config dimension override: {override_dim}")
|
||||
|
||||
return await get_embedding_model_info(provider_config.embedding_model, routing_table, override_dim)
|
||||
|
||||
# Priority 3: System default (existing fallback behavior)
|
||||
logger.debug("No explicit model or provider default, falling back to system default")
|
||||
return await _get_first_embedding_model_fallback(routing_table)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting provider embedding model info: {e}")
|
||||
raise
|
||||
|
||||
|
||||
async def _get_first_embedding_model_fallback(routing_table: RoutingTable) -> tuple[str, int] | None:
|
||||
"""
|
||||
Fallback to get the first available embedding model (existing behavior).
|
||||
|
||||
This maintains backward compatibility by preserving the original logic
|
||||
from VectorIORouter._get_first_embedding_model().
|
||||
"""
|
||||
try:
|
||||
# Get all models from the routing table
|
||||
all_models = await routing_table.get_all_with_type("model") # type: ignore
|
||||
|
||||
# Filter for embedding models
|
||||
embedding_models = [
|
||||
model for model in all_models if hasattr(model, "model_type") and model.model_type == ModelType.embedding
|
||||
]
|
||||
|
||||
if embedding_models:
|
||||
dimension = embedding_models[0].metadata.get("embedding_dimension", None)
|
||||
if dimension is None:
|
||||
raise ValueError(f"Embedding model {embedding_models[0].identifier} has no embedding dimension")
|
||||
|
||||
logger.info(f"System fallback: using first available embedding model {embedding_models[0].identifier}")
|
||||
return embedding_models[0].identifier, dimension
|
||||
else:
|
||||
logger.warning("No embedding models found in the routing table")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting fallback embedding model: {e}")
|
||||
return None
|
||||
Loading…
Add table
Add a link
Reference in a new issue