Add configurable embedding models for vector IO providers

This change lets users configure default embedding models at the provider level instead of always relying on system defaults. Each vector store provider can now specify an embedding_model and optional embedding_dimension in their config.

Key features:
- Auto-dimension lookup for standard models from the registry
- Support for Matryoshka embeddings with custom dimensions
- Three-tier priority: explicit params > provider config > system fallback
- Full backward compatibility - existing setups work unchanged
- Comprehensive test coverage with 20 test cases

Updated all vector IO providers (FAISS, Chroma, Milvus, Qdrant, etc.) with the new config fields and added detailed documentation with examples.

Fixes #2729
skamenan7 2025-07-15 16:46:40 -04:00
parent 2298d2473c
commit 474b50b422
28 changed files with 1160 additions and 24 deletions

@@ -0,0 +1,127 @@
# Sample Vector IO Configuration with Embedding Model Defaults
# This example demonstrates the new provider-level embedding configuration features

# Image and version info
version: 3
image_name: my-embedding-app

# APIs to serve
apis:
  - inference
  - vector_io

# Provider configurations
providers:
  # Inference provider for embedding models
  inference:
    - provider_id: local_inference
      provider_type: remote::ollama
      config:
        url: http://localhost:11434

  # Vector IO providers with embedding model defaults
  vector_io:
    # FAISS for fast local search with lightweight embeddings
    - provider_id: fast_local_search
      provider_type: inline::faiss
      config:
        kvstore:
          provider_type: sqlite
          config:
            db_path: ~/.llama/distributions/my-app/faiss_store.db
        # NEW: Default embedding model for this provider
        embedding_model: "all-MiniLM-L6-v2"
        # Dimension auto-lookup: 384 (from model registry)

    # SQLite Vec for lightweight vector storage with Matryoshka embeddings
    - provider_id: compact_storage
      provider_type: inline::sqlite_vec
      config:
        db_path: ~/.llama/distributions/my-app/sqlite_vec.db
        kvstore:
          provider_type: sqlite
          config:
            db_name: sqlite_vec_registry.db
        # Matryoshka embedding with custom dimension
        embedding_model: "nomic-embed-text"
        embedding_dimension: 256  # Reduced from default 768 for efficiency

    # Chroma for persistent local storage
    - provider_id: persistent_search
      provider_type: inline::chroma
      config:
        db_path: ~/.llama/distributions/my-app/chroma.db
        # High-quality embeddings for semantic search
        embedding_model: "sentence-transformers/all-mpnet-base-v2"
        # Auto-lookup dimension from model registry

    # Qdrant Cloud for production-scale search (when available)
    - provider_id: cloud_search
      provider_type: remote::qdrant
      config:
        api_key: "${env.QDRANT_API_KEY}"
        url: "${env.QDRANT_URL}"
        # Production-grade embedding model
        embedding_model: "text-embedding-3-small"
        embedding_dimension: 512  # Custom dimension for performance

# Model registry - ensure embedding models are properly configured
models:
  # Lightweight embedding model (384 dimensions)
  - model_id: all-MiniLM-L6-v2
    provider_id: local_inference
    provider_model_id: sentence-transformers/all-MiniLM-L6-v2
    model_type: embedding
    metadata:
      embedding_dimension: 384
      description: "Fast, lightweight embeddings for general use"

  # Matryoshka embedding model (variable dimensions)
  - model_id: nomic-embed-text
    provider_id: local_inference
    provider_model_id: nomic-embed-text
    model_type: embedding
    metadata:
      embedding_dimension: 768  # Default, can be overridden
      description: "Flexible Matryoshka embeddings supporting variable dimensions"

  # High-quality embedding model (768 dimensions)
  - model_id: sentence-transformers/all-mpnet-base-v2
    provider_id: local_inference
    provider_model_id: sentence-transformers/all-mpnet-base-v2
    model_type: embedding
    metadata:
      embedding_dimension: 768
      description: "High-quality embeddings for semantic search"

  # OpenAI embedding model (for cloud usage)
  - model_id: text-embedding-3-small
    provider_id: openai_inference  # Would need OpenAI provider configured
    provider_model_id: text-embedding-3-small
    model_type: embedding
    metadata:
      embedding_dimension: 1536  # Default OpenAI dimension
      description: "OpenAI's efficient embedding model"

# Optional: Configure specific vector databases (will use provider defaults)
vector_dbs:
  # Uses fast_local_search provider defaults (all-MiniLM-L6-v2, 384 dims)
  - vector_db_id: general_docs
    provider_id: fast_local_search

  # Uses compact_storage provider defaults (nomic-embed-text, 256 dims)
  - vector_db_id: compressed_knowledge
    provider_id: compact_storage

  # Uses persistent_search provider defaults (all-mpnet-base-v2, 768 dims)
  - vector_db_id: semantic_library
    provider_id: persistent_search

# Server configuration
server:
  host: 0.0.0.0
  port: 5000

# Logging configuration
logging:
  level: INFO

@@ -0,0 +1,302 @@
# Vector IO Embedding Model Configuration
## Overview
Vector IO providers now support configuring default embedding models at the provider level. This allows you to:
- Set a default embedding model for each vector store provider
- Support Matryoshka embeddings with custom dimensions
- Automatically look up embedding dimensions from the model registry
- Maintain backward compatibility with existing configurations
## Configuration Options
### Provider-Level Embedding Configuration
Add `embedding_model` and `embedding_dimension` fields to your vector IO provider configuration:
```yaml
providers:
  vector_io:
    - provider_id: my_faiss_store
      provider_type: inline::faiss
      config:
        kvstore:
          provider_type: sqlite
          config:
            db_path: ~/.llama/distributions/my-app/faiss_store.db
        # NEW: Configure default embedding model
        embedding_model: "all-MiniLM-L6-v2"
        # Optional: Only needed for variable-dimension models
        # embedding_dimension: 384
```
### Embedding Model Selection Priority
The system uses a 3-tier priority system for selecting embedding models:
1. **Explicit API Parameters** (highest priority)

   ```python
   # API call explicitly specifies model - this takes precedence
   await vector_io.openai_create_vector_store(
       name="my-store",
       embedding_model="nomic-embed-text",  # Explicit override
       embedding_dimension=256,
   )
   ```

2. **Provider Config Defaults** (middle priority)

   ```yaml
   # Provider config provides default when no explicit model specified
   config:
     embedding_model: "all-MiniLM-L6-v2"
     embedding_dimension: 384
   ```

3. **System Default** (fallback)

   ```
   # Uses first available embedding model from model registry
   # Maintains backward compatibility
   ```
## Provider Examples
### FAISS with Default Embedding Model
```yaml
providers:
  vector_io:
    - provider_id: faiss_store
      provider_type: inline::faiss
      config:
        kvstore:
          provider_type: sqlite
          config:
            db_path: ~/.llama/distributions/my-app/faiss_store.db
        embedding_model: "all-MiniLM-L6-v2"
        # Dimension auto-lookup: 384 (from model registry)
```
### SQLite Vec with Matryoshka Embedding
```yaml
providers:
  vector_io:
    - provider_id: sqlite_vec_store
      provider_type: inline::sqlite_vec
      config:
        db_path: ~/.llama/distributions/my-app/sqlite_vec.db
        kvstore:
          provider_type: sqlite
          config:
            db_name: sqlite_vec_registry.db
        embedding_model: "nomic-embed-text"
        embedding_dimension: 256  # Override default 768 to 256
```
### Chroma with Provider Default
```yaml
providers:
vector_io:
- provider_id: chroma_store
provider_type: inline::chroma
config:
db_path: ~/.llama/distributions/my-app/chroma.db
embedding_model: "sentence-transformers/all-mpnet-base-v2"
# Auto-lookup dimension from model registry
```
### Remote Qdrant Configuration
```yaml
providers:
  vector_io:
    - provider_id: qdrant_cloud
      provider_type: remote::qdrant
      config:
        api_key: "${env.QDRANT_API_KEY}"
        url: "https://my-cluster.qdrant.tech"
        embedding_model: "text-embedding-3-small"
        embedding_dimension: 512  # Custom dimension for Matryoshka model
```
### Multiple Providers with Different Models
```yaml
providers:
  vector_io:
    # Fast, lightweight embeddings for simple search
    - provider_id: fast_search
      provider_type: inline::faiss
      config:
        kvstore:
          provider_type: sqlite
          config:
            db_path: ~/.llama/fast_search.db
        embedding_model: "all-MiniLM-L6-v2"  # 384 dimensions

    # High-quality embeddings for semantic search
    - provider_id: semantic_search
      provider_type: remote::qdrant
      config:
        api_key: "${env.QDRANT_API_KEY}"
        embedding_model: "text-embedding-3-large"  # 3072 dimensions

    # Flexible Matryoshka embeddings
    - provider_id: flexible_search
      provider_type: inline::chroma
      config:
        db_path: ~/.llama/flexible_search.db
        embedding_model: "nomic-embed-text"
        embedding_dimension: 256  # Reduced from default 768
```
## Model Registry Configuration
Ensure your embedding models are registered in the model registry:
```yaml
models:
  - model_id: all-MiniLM-L6-v2
    provider_id: huggingface
    provider_model_id: sentence-transformers/all-MiniLM-L6-v2
    model_type: embedding
    metadata:
      embedding_dimension: 384

  - model_id: nomic-embed-text
    provider_id: ollama
    provider_model_id: nomic-embed-text
    model_type: embedding
    metadata:
      embedding_dimension: 768  # Default, can be overridden

  - model_id: text-embedding-3-small
    provider_id: openai
    provider_model_id: text-embedding-3-small
    model_type: embedding
    metadata:
      embedding_dimension: 1536  # Default for OpenAI model
```
## API Usage Examples
### Using Provider Defaults
```python
# Uses the embedding model configured in the provider config
vector_store = await vector_io.openai_create_vector_store(
    name="documents",
    provider_id="faiss_store",  # Will use configured embedding_model
)
```
### Explicit Override
```python
# Overrides provider defaults with explicit parameters
vector_store = await vector_io.openai_create_vector_store(
    name="documents",
    embedding_model="text-embedding-3-large",  # Override provider default
    embedding_dimension=1024,  # Custom dimension
    provider_id="faiss_store",
)
```
### Matryoshka Embedding Usage
```python
# Provider configured with nomic-embed-text and dimension 256
vector_store = await vector_io.openai_create_vector_store(
    name="compact_embeddings",
    provider_id="flexible_search",  # Uses Matryoshka config
)

# Or override with different dimension
vector_store = await vector_io.openai_create_vector_store(
    name="full_embeddings",
    embedding_dimension=768,  # Use full dimension
    provider_id="flexible_search",
)
```
## Migration Guide
### Updating Existing Configurations
Your existing configurations will continue to work without changes. To add provider-level defaults:
1. **Add embedding model fields** to your provider configs
2. **Test the configuration** to ensure expected behavior
3. **Remove explicit embedding_model parameters** from API calls if desired
### Before (explicit parameters required):
```python
# Had to specify embedding model every time
await vector_io.openai_create_vector_store(
    name="store1", embedding_model="all-MiniLM-L6-v2"
)
```
### After (provider defaults):
```yaml
# Configure once in provider config
config:
  embedding_model: "all-MiniLM-L6-v2"
```
```python
# No need to specify repeatedly
await vector_io.openai_create_vector_store(name="store1")
await vector_io.openai_create_vector_store(name="store2")
await vector_io.openai_create_vector_store(name="store3")
```
## Best Practices
### 1. Model Selection
- Use **lightweight models** (e.g., `all-MiniLM-L6-v2`) for simple semantic search
- Use **high-quality models** (e.g., `text-embedding-3-large`) for complex retrieval
- Consider **Matryoshka models** (e.g., `nomic-embed-text`) for flexible dimension requirements
### 2. Provider Configuration
- Configure embedding models at the **provider level** for consistency
- Use **environment variables** for API keys and sensitive configuration
- Set up **multiple providers** with different models for different use cases
### 3. Dimension Management
- Let the system **auto-lookup dimensions** when possible
- Only specify `embedding_dimension` for **Matryoshka embeddings** or custom requirements
- Ensure **model registry** has correct dimension metadata
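As a rough illustration of why Matryoshka models tolerate a reduced `embedding_dimension`: their leading coordinates carry most of the signal, so a vector can be truncated to the configured dimension and re-normalized. This is a sketch of the idea, not any provider's actual implementation:

```python
import math


def truncate_matryoshka(vec: list[float], dim: int) -> list[float]:
    """Keep the first `dim` coordinates and re-normalize to unit length."""
    if dim <= 0 or dim > len(vec):
        raise ValueError(f"dim must be in 1..{len(vec)}, got {dim}")
    head = vec[:dim]
    norm = math.sqrt(sum(x * x for x in head))
    return [x / norm for x in head]


# A 768-dim embedding reduced to 256 dims for compact storage
full = [1.0] + [0.01] * 767
compact = truncate_matryoshka(full, 256)
assert len(compact) == 256
```

Non-Matryoshka models are not trained for this, which is why `embedding_dimension` should only be overridden for models that document variable-dimension support.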
### 4. Performance Optimization
- Use **smaller dimensions** for faster search (e.g., 256 instead of 768)
- Consider **multiple vector stores** with different embedding models for different content types
- Test **different embedding models** to find the best balance for your use case
## Troubleshooting
### Common Issues
**Model not found error:**
```
ValueError: Embedding model 'my-model' not found in model registry
```
**Solution:** Ensure the model is registered in your model configuration.
**Missing dimension metadata:**
```
ValueError: Embedding model 'my-model' has no embedding_dimension in metadata
```
**Solution:** Add `embedding_dimension` to the model's metadata in your model registry.
**Invalid dimension override:**
```
ValueError: Override dimension must be positive, got -1
```
**Solution:** Use positive integers for `embedding_dimension` values.
### Debugging Tips
1. **Check model registry:** Verify embedding models are properly registered
2. **Review provider config:** Ensure `embedding_model` matches registry IDs
3. **Test explicit parameters:** Override provider defaults to isolate issues
4. **Check logs:** Look for embedding model selection messages in router logs

@@ -42,6 +42,8 @@ See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introducti
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `db_path` | `<class 'str'>` | No | PydanticUndefined | |
| `embedding_model` | `str \| None` | No | | Optional default embedding model for this provider. If not specified, will use system default. |
| `embedding_dimension` | `int \| None` | No | | Optional embedding dimension override. Only needed for models with variable dimensions (e.g., Matryoshka embeddings). If not specified, will auto-lookup from model registry. |
## Sample Configuration

@@ -38,6 +38,8 @@ more details about Faiss in general.
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | |
| `embedding_model` | `str \| None` | No | | Optional default embedding model for this provider. If not specified, will use system default. |
| `embedding_dimension` | `int \| None` | No | | Optional embedding dimension override. Only needed for models with variable dimensions (e.g., Matryoshka embeddings). If not specified, will auto-lookup from model registry. |
## Sample Configuration

@@ -9,6 +9,8 @@ Meta's reference implementation of a vector database.
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | |
| `embedding_model` | `str \| None` | No | | Optional default embedding model for this provider. If not specified, will use system default. |
| `embedding_dimension` | `int \| None` | No | | Optional embedding dimension override. Only needed for models with variable dimensions (e.g., Matryoshka embeddings). If not specified, will auto-lookup from model registry. |
## Sample Configuration

@@ -13,6 +13,8 @@ Please refer to the remote provider documentation.
| `db_path` | `<class 'str'>` | No | PydanticUndefined | |
| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | Config for KV store backend (SQLite only for now) |
| `consistency_level` | `<class 'str'>` | No | Strong | The consistency level of the Milvus server |
| `embedding_model` | `str \| None` | No | | Optional default embedding model for this provider. If not specified, will use system default. |
| `embedding_dimension` | `int \| None` | No | | Optional embedding dimension override. Only needed for models with variable dimensions (e.g., Matryoshka embeddings). If not specified, will auto-lookup from model registry. |
## Sample Configuration

@@ -51,6 +51,8 @@ See the [Qdrant documentation](https://qdrant.tech/documentation/) for more deta
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `path` | `<class 'str'>` | No | PydanticUndefined | |
| `embedding_model` | `str \| None` | No | | Optional default embedding model for this provider. If not specified, will use system default. |
| `embedding_dimension` | `int \| None` | No | | Optional embedding dimension override. Only needed for models with variable dimensions (e.g., Matryoshka embeddings). If not specified, will auto-lookup from model registry. |
## Sample Configuration

@@ -207,6 +207,8 @@ See [sqlite-vec's GitHub repo](https://github.com/asg017/sqlite-vec/tree/main) f
|-------|------|----------|---------|-------------|
| `db_path` | `<class 'str'>` | No | PydanticUndefined | Path to the SQLite database file |
| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | Config for KV store backend (SQLite only for now) |
| `embedding_model` | `str \| None` | No | | Optional default embedding model for this provider. If not specified, will use system default. |
| `embedding_dimension` | `int \| None` | No | | Optional embedding dimension override. Only needed for models with variable dimensions (e.g., Matryoshka embeddings). If not specified, will auto-lookup from model registry. |
## Sample Configuration

@@ -12,6 +12,8 @@ Please refer to the sqlite-vec provider documentation.
|-------|------|----------|---------|-------------|
| `db_path` | `<class 'str'>` | No | PydanticUndefined | Path to the SQLite database file |
| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | Config for KV store backend (SQLite only for now) |
| `embedding_model` | `str \| None` | No | | Optional default embedding model for this provider. If not specified, will use system default. |
| `embedding_dimension` | `int \| None` | No | | Optional embedding dimension override. Only needed for models with variable dimensions (e.g., Matryoshka embeddings). If not specified, will auto-lookup from model registry. |
## Sample Configuration

@@ -41,6 +41,8 @@ See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introducti
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `url` | `str \| None` | No | PydanticUndefined | |
| `embedding_model` | `str \| None` | No | | Optional default embedding model for this provider. If not specified, will use system default. |
| `embedding_dimension` | `int \| None` | No | | Optional embedding dimension override. Only needed for models with variable dimensions (e.g., Matryoshka embeddings). If not specified, will auto-lookup from model registry. |
## Sample Configuration

@@ -115,6 +115,8 @@ For more details on TLS configuration, refer to the [TLS setup guide](https://mi
| `token` | `str \| None` | No | PydanticUndefined | The token of the Milvus server |
| `consistency_level` | `<class 'str'>` | No | Strong | The consistency level of the Milvus server |
| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig, annotation=NoneType, required=False, default='sqlite', discriminator='type'` | No | | Config for KV store backend (SQLite only for now) |
| `embedding_model` | `str \| None` | No | | Optional default embedding model for this provider. If not specified, will use system default. |
| `embedding_dimension` | `int \| None` | No | | Optional embedding dimension override. Only needed for models with variable dimensions (e.g., Matryoshka embeddings). If not specified, will auto-lookup from model registry. |
| `config` | `dict` | No | {} | This configuration allows additional fields to be passed through to the underlying Milvus client. See the [Milvus](https://milvus.io/docs/install-overview.md) documentation for more details about Milvus in general. |
> **Note**: This configuration class accepts additional fields beyond those listed above. You can pass any additional configuration options that will be forwarded to the underlying provider.

@@ -40,6 +40,8 @@ See [PGVector's documentation](https://github.com/pgvector/pgvector) for more de
| `db` | `str \| None` | No | postgres | |
| `user` | `str \| None` | No | postgres | |
| `password` | `str \| None` | No | mysecretpassword | |
| `embedding_model` | `str \| None` | No | | Optional default embedding model for this provider. If not specified, will use system default. |
| `embedding_dimension` | `int \| None` | No | | Optional embedding dimension override. Only needed for models with variable dimensions (e.g., Matryoshka embeddings). If not specified, will auto-lookup from model registry. |
## Sample Configuration

@@ -20,6 +20,8 @@ Please refer to the inline provider documentation.
| `prefix` | `str \| None` | No | | |
| `timeout` | `int \| None` | No | | |
| `host` | `str \| None` | No | | |
| `embedding_model` | `str \| None` | No | | Optional default embedding model for this provider. If not specified, will use system default. |
| `embedding_dimension` | `int \| None` | No | | Optional embedding dimension override. Only needed for models with variable dimensions (e.g., Matryoshka embeddings). If not specified, will auto-lookup from model registry. |
## Sample Configuration

@@ -33,6 +33,13 @@ To install Weaviate see the [Weaviate quickstart documentation](https://weaviate
See [Weaviate's documentation](https://weaviate.io/developers/weaviate) for more details about Weaviate in general.
## Configuration
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `embedding_model` | `str \| None` | No | | Optional default embedding model for this provider. If not specified, will use system default. |
| `embedding_dimension` | `int \| None` | No | | Optional embedding dimension override. Only needed for models with variable dimensions (e.g., Matryoshka embeddings). If not specified, will auto-lookup from model registry. |
## Sample Configuration
```yaml

@@ -7,9 +7,7 @@
import asyncio
from typing import Any
from llama_stack.apis.common.content_types import (
    InterleavedContent,
)
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.models import ModelType
from llama_stack.apis.vector_io import (
    Chunk,
@@ -28,6 +26,7 @@ from llama_stack.apis.vector_io import (
)
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
from llama_stack.providers.utils.vector_io.embedding_utils import get_provider_embedding_model_info
logger = get_logger(name=__name__, category="core")
@@ -51,10 +50,10 @@ class VectorIORouter(VectorIO):
        pass

    async def _get_first_embedding_model(self) -> tuple[str, int] | None:
        """Get the first available embedding model identifier."""
        """Get the first available embedding model identifier (DEPRECATED - use embedding_utils instead)."""
        try:
            # Get all models from the routing table
            all_models = await self.routing_table.get_all_with_type("model")
            all_models = await self.routing_table.get_all_with_type("model")  # type: ignore

            # Filter for embedding models
            embedding_models = [
@@ -75,6 +74,31 @@ class VectorIORouter(VectorIO):
            logger.error(f"Error getting embedding models: {e}")
            return None

    async def _get_provider_config(self, provider_id: str | None = None) -> Any:
        """Get the provider configuration object for embedding model defaults."""
        try:
            # If no provider_id specified, get the first available provider
            if provider_id is None and hasattr(self.routing_table, "impls_by_provider_id"):
                available_providers = list(self.routing_table.impls_by_provider_id.keys())  # type: ignore
                if available_providers:
                    provider_id = available_providers[0]
                else:
                    logger.warning("No vector IO providers available")
                    return None

            if provider_id and hasattr(self.routing_table, "impls_by_provider_id"):
                provider_impl = self.routing_table.impls_by_provider_id.get(provider_id)  # type: ignore
                if provider_impl and hasattr(provider_impl, "__provider_config__"):
                    return provider_impl.__provider_config__
                else:
                    logger.debug(f"Provider {provider_id} has no config object attached")
                    return None
            return None
        except Exception as e:
            logger.error(f"Error getting provider config: {e}")
            return None

    async def register_vector_db(
        self,
        vector_db_id: str,
@@ -84,7 +108,7 @@ class VectorIORouter(VectorIO):
        provider_vector_db_id: str | None = None,
    ) -> None:
        logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
        await self.routing_table.register_vector_db(
        await self.routing_table.register_vector_db(  # type: ignore
            vector_db_id,
            embedding_model,
            embedding_dimension,
@@ -127,13 +151,64 @@
    ) -> VectorStoreObject:
        logger.debug(f"VectorIORouter.openai_create_vector_store: name={name}, provider_id={provider_id}")
        # If no embedding model is provided, use the first available one
        if embedding_model is None:
            embedding_model_info = await self._get_first_embedding_model()
        # Use the new 3-tier priority system for embedding model selection
        provider_config = await self._get_provider_config(provider_id)

        # Log the resolution context for debugging
        logger.debug(f"Resolving embedding model for vector store '{name}' with provider_id={provider_id}")
        logger.debug(f"Explicit model: {embedding_model}, explicit dimension: {embedding_dimension}")
        logger.debug(
            f"Provider config embedding_model: {getattr(provider_config, 'embedding_model', None) if provider_config else None}"
        )
        logger.debug(
            f"Provider config embedding_dimension: {getattr(provider_config, 'embedding_dimension', None) if provider_config else None}"
        )

        try:
            embedding_model_info = await get_provider_embedding_model_info(
                routing_table=self.routing_table,
                provider_config=provider_config,
                explicit_model_id=embedding_model,
                explicit_dimension=embedding_dimension,
            )

            if embedding_model_info is None:
                raise ValueError("No embedding model provided and no embedding models available in the system")
            embedding_model, embedding_dimension = embedding_model_info
            logger.info(f"No embedding model specified, using first available: {embedding_model}")
            resolved_model, resolved_dimension = embedding_model_info

            # Enhanced logging to show resolution path
            if embedding_model is not None:
                logger.info(
                    f"✅ Vector store '{name}': Using EXPLICIT embedding model '{resolved_model}' (dimension: {resolved_dimension})"
                )
            elif provider_config and getattr(provider_config, "embedding_model", None):
                logger.info(
                    f"✅ Vector store '{name}': Using PROVIDER DEFAULT embedding model '{resolved_model}' (dimension: {resolved_dimension}) from provider '{provider_id}'"
                )
                if getattr(provider_config, "embedding_dimension", None):
                    logger.info(f"   └── Provider config dimension override: {resolved_dimension}")
                else:
                    logger.info(f"   └── Auto-lookup dimension from model registry: {resolved_dimension}")
            else:
                logger.info(
                    f"✅ Vector store '{name}': Using SYSTEM DEFAULT embedding model '{resolved_model}' (dimension: {resolved_dimension})"
                )
                logger.warning(
                    f"⚠️ Consider configuring a default embedding model for provider '{provider_id}' to avoid fallback behavior"
                )
            embedding_model, embedding_dimension = resolved_model, resolved_dimension
        except Exception as e:
            logger.error(
                f"❌ Failed to resolve embedding model for vector store '{name}' with provider '{provider_id}': {e}"
            )
            logger.error(f"   Debug info - Explicit: model={embedding_model}, dim={embedding_dimension}")
            logger.error(
                f"   Debug info - Provider: model={getattr(provider_config, 'embedding_model', None) if provider_config else None}, dim={getattr(provider_config, 'embedding_dimension', None) if provider_config else None}"
            )
            raise ValueError(f"Unable to determine embedding model for vector store '{name}': {e}") from e

        vector_db_id = name
        registered_vector_db = await self.routing_table.register_vector_db(

@@ -6,12 +6,25 @@
from typing import Any

from pydantic import BaseModel
from pydantic import BaseModel, Field


class ChromaVectorIOConfig(BaseModel):
    db_path: str
    embedding_model: str | None = Field(
        default=None,
        description="Optional default embedding model for this provider. If not specified, will use system default.",
    )
    embedding_dimension: int | None = Field(
        default=None,
        description="Optional embedding dimension override. Only needed for models with variable dimensions (e.g., Matryoshka embeddings). If not specified, will auto-lookup from model registry.",
    )

    @classmethod
    def sample_run_config(cls, db_path: str = "${env.CHROMADB_PATH}", **kwargs: Any) -> dict[str, Any]:
        return {"db_path": db_path}
        return {
            "db_path": db_path,
            # Optional: Configure default embedding model for this provider
            # "embedding_model": "all-MiniLM-L6-v2",
            # "embedding_dimension": 384,  # Only needed for variable-dimension models
        }

@@ -6,7 +6,7 @@
from typing import Any

from pydantic import BaseModel
from pydantic import BaseModel, Field

from llama_stack.providers.utils.kvstore.config import (
    KVStoreConfig,
@@ -18,6 +18,14 @@ from llama_stack.schema_utils import json_schema_type

@json_schema_type
class FaissVectorIOConfig(BaseModel):
    kvstore: KVStoreConfig
    embedding_model: str | None = Field(
        default=None,
        description="Optional default embedding model for this provider. If not specified, will use system default.",
    )
    embedding_dimension: int | None = Field(
        default=None,
        description="Optional embedding dimension override. Only needed for models with variable dimensions (e.g., Matryoshka embeddings). If not specified, will auto-lookup from model registry.",
    )

    @classmethod
    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
@@ -25,5 +33,8 @@ class FaissVectorIOConfig(BaseModel):
            "kvstore": SqliteKVStoreConfig.sample_run_config(
                __distro_dir__=__distro_dir__,
                db_name="faiss_store.db",
            )
            ),
            # Optional: Configure default embedding model for this provider
            # "embedding_model": "all-MiniLM-L6-v2",
            # "embedding_dimension": 384,  # Only needed for variable-dimension models
        }

@@ -20,6 +20,14 @@ class MilvusVectorIOConfig(BaseModel):
    db_path: str
    kvstore: KVStoreConfig = Field(description="Config for KV store backend (SQLite only for now)")
    consistency_level: str = Field(description="The consistency level of the Milvus server", default="Strong")
    embedding_model: str | None = Field(
        default=None,
        description="Optional default embedding model for this provider. If not specified, will use system default.",
    )
    embedding_dimension: int | None = Field(
        default=None,
        description="Optional embedding dimension override. Only needed for models with variable dimensions (e.g., Matryoshka embeddings). If not specified, will auto-lookup from model registry.",
    )

    @classmethod
    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
@@ -29,4 +37,7 @@ class MilvusVectorIOConfig(BaseModel):
                __distro_dir__=__distro_dir__,
                db_name="milvus_registry.db",
            ),
            # Optional: Configure default embedding model for this provider
            # "embedding_model": "all-MiniLM-L6-v2",
            # "embedding_dimension": 384,  # Only needed for variable-dimension models
        }

@@ -7,7 +7,7 @@
from typing import Any
from pydantic import BaseModel
from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type
@ -15,9 +15,20 @@ from llama_stack.schema_utils import json_schema_type
@json_schema_type
class QdrantVectorIOConfig(BaseModel):
path: str
embedding_model: str | None = Field(
default=None,
description="Optional default embedding model for this provider. If not specified, will use system default.",
)
embedding_dimension: int | None = Field(
default=None,
description="Optional embedding dimension override. Only needed for models with variable dimensions (e.g., Matryoshka embeddings). If not specified, will auto-lookup from model registry.",
)
@classmethod
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
return {
"path": "${env.QDRANT_PATH:=~/.llama/" + __distro_dir__ + "}/" + "qdrant.db",
# Optional: Configure default embedding model for this provider
# "embedding_model": "all-MiniLM-L6-v2",
# "embedding_dimension": 384, # Only needed for variable-dimension models
}
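
For reference, the provider-level defaults added above map directly onto a run configuration. A hypothetical `run.yaml` fragment (provider id and path here are illustrative, not taken from this change):

```yaml
vector_io:
  - provider_id: my_qdrant              # illustrative id
    provider_type: inline::qdrant
    config:
      path: ~/.llama/my-app/qdrant.db
      embedding_model: all-MiniLM-L6-v2 # provider-level default
      embedding_dimension: 384          # optional; auto-looked-up if omitted
```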

View file

@@ -17,6 +17,14 @@ from llama_stack.providers.utils.kvstore.config import (
class SQLiteVectorIOConfig(BaseModel):
db_path: str = Field(description="Path to the SQLite database file")
kvstore: KVStoreConfig = Field(description="Config for KV store backend (SQLite only for now)")
embedding_model: str | None = Field(
default=None,
description="Optional default embedding model for this provider. If not specified, will use system default.",
)
embedding_dimension: int | None = Field(
default=None,
description="Optional embedding dimension override. Only needed for models with variable dimensions (e.g., Matryoshka embeddings). If not specified, will auto-lookup from model registry.",
)
@classmethod
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
@@ -26,4 +34,7 @@ class SQLiteVectorIOConfig(BaseModel):
__distro_dir__=__distro_dir__,
db_name="sqlite_vec_registry.db",
),
# Optional: Configure default embedding model for this provider
# "embedding_model": "all-MiniLM-L6-v2",
# "embedding_dimension": 384, # Only needed for variable-dimension models
}

View file

@@ -6,12 +6,25 @@
from typing import Any
from pydantic import BaseModel
from pydantic import BaseModel, Field
class ChromaVectorIOConfig(BaseModel):
url: str | None
embedding_model: str | None = Field(
default=None,
description="Optional default embedding model for this provider. If not specified, will use system default.",
)
embedding_dimension: int | None = Field(
default=None,
description="Optional embedding dimension override. Only needed for models with variable dimensions (e.g., Matryoshka embeddings). If not specified, will auto-lookup from model registry.",
)
@classmethod
def sample_run_config(cls, url: str = "${env.CHROMADB_URL}", **kwargs: Any) -> dict[str, Any]:
return {"url": url}
return {
"url": url,
# Optional: Configure default embedding model for this provider
# "embedding_model": "all-MiniLM-L6-v2",
# "embedding_dimension": 384, # Only needed for variable-dimension models
}

View file

@@ -18,6 +18,14 @@ class MilvusVectorIOConfig(BaseModel):
token: str | None = Field(description="The token of the Milvus server")
consistency_level: str = Field(description="The consistency level of the Milvus server", default="Strong")
kvstore: KVStoreConfig | None = Field(description="Config for KV store backend (SQLite only for now)", default=None)
embedding_model: str | None = Field(
default=None,
description="Optional default embedding model for this provider. If not specified, will use system default.",
)
embedding_dimension: int | None = Field(
default=None,
description="Optional embedding dimension override. Only needed for models with variable dimensions (e.g., Matryoshka embeddings). If not specified, will auto-lookup from model registry.",
)
# This configuration allows additional fields to be passed through to the underlying Milvus client.
# See the [Milvus](https://milvus.io/docs/install-overview.md) documentation for more details about Milvus in general.
@@ -25,4 +33,10 @@
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return {"uri": "${env.MILVUS_ENDPOINT}", "token": "${env.MILVUS_TOKEN}"}
return {
"uri": "${env.MILVUS_ENDPOINT}",
"token": "${env.MILVUS_TOKEN}",
# Optional: Configure default embedding model for this provider
# "embedding_model": "all-MiniLM-L6-v2",
# "embedding_dimension": 384, # Only needed for variable-dimension models
}

View file

@@ -18,15 +18,32 @@ class PGVectorVectorIOConfig(BaseModel):
db: str | None = Field(default="postgres")
user: str | None = Field(default="postgres")
password: str | None = Field(default="mysecretpassword")
embedding_model: str | None = Field(
default=None,
description="Optional default embedding model for this provider. If not specified, will use system default.",
)
embedding_dimension: int | None = Field(
default=None,
description="Optional embedding dimension override. Only needed for models with variable dimensions (e.g., Matryoshka embeddings). If not specified, will auto-lookup from model registry.",
)
@classmethod
def sample_run_config(
cls,
host: str = "${env.PGVECTOR_HOST:=localhost}",
port: int = "${env.PGVECTOR_PORT:=5432}",
port: int | str = "${env.PGVECTOR_PORT:=5432}",
db: str = "${env.PGVECTOR_DB}",
user: str = "${env.PGVECTOR_USER}",
password: str = "${env.PGVECTOR_PASSWORD}",
**kwargs: Any,
) -> dict[str, Any]:
return {"host": host, "port": port, "db": db, "user": user, "password": password}
return {
"host": host,
"port": port,
"db": db,
"user": user,
"password": password,
# Optional: Configure default embedding model for this provider
# "embedding_model": "all-MiniLM-L6-v2",
# "embedding_dimension": 384, # Only needed for variable-dimension models
}

View file

@@ -6,7 +6,7 @@
from typing import Any
from pydantic import BaseModel
from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type
@@ -23,9 +23,20 @@ class QdrantVectorIOConfig(BaseModel):
prefix: str | None = None
timeout: int | None = None
host: str | None = None
embedding_model: str | None = Field(
default=None,
description="Optional default embedding model for this provider. If not specified, will use system default.",
)
embedding_dimension: int | None = Field(
default=None,
description="Optional embedding dimension override. Only needed for models with variable dimensions (e.g., Matryoshka embeddings). If not specified, will auto-lookup from model registry.",
)
@classmethod
def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]:
return {
"api_key": "${env.QDRANT_API_KEY}",
# Optional: Configure default embedding model for this provider
# "embedding_model": "all-MiniLM-L6-v2",
# "embedding_dimension": 384, # Only needed for variable-dimension models
}

View file

@@ -6,7 +6,7 @@
from typing import Any
from pydantic import BaseModel
from pydantic import BaseModel, Field
class WeaviateRequestProviderData(BaseModel):
@@ -15,6 +15,19 @@ class WeaviateRequestProviderData(BaseModel):
class WeaviateVectorIOConfig(BaseModel):
embedding_model: str | None = Field(
default=None,
description="Optional default embedding model for this provider. If not specified, will use system default.",
)
embedding_dimension: int | None = Field(
default=None,
description="Optional embedding dimension override. Only needed for models with variable dimensions (e.g., Matryoshka embeddings). If not specified, will auto-lookup from model registry.",
)
@classmethod
def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]:
return {}
return {
# Optional: Configure default embedding model for this provider
# "embedding_model": "all-MiniLM-L6-v2",
# "embedding_dimension": 384, # Only needed for variable-dimension models
}

View file

@@ -0,0 +1,153 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.models import ModelType
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
logger = get_logger(name=__name__, category="core")
async def get_embedding_model_info(
model_id: str, routing_table: RoutingTable, override_dimension: int | None = None
) -> tuple[str, int]:
"""
Get embedding model info with auto-dimension lookup.
This function validates that the specified model is an embedding model
and returns its embedding dimensions, with support for Matryoshka embeddings
through dimension overrides.
Args:
model_id: The embedding model identifier to look up
routing_table: Access to the model registry for validation and dimension lookup
override_dimension: Optional dimension override for Matryoshka models that
support variable dimensions (e.g., nomic-embed-text)
Returns:
tuple: (model_id, embedding_dimension)
Raises:
ValueError: If model not found, not an embedding model, or missing dimension info
"""
try:
# Look up the model in the routing table
model = await routing_table.get_object_by_identifier("model", model_id) # type: ignore
if model is None:
raise ValueError(f"Embedding model '{model_id}' not found in model registry")
# Validate that this is an embedding model
if not hasattr(model, "model_type") or model.model_type != ModelType.embedding:
raise ValueError(
f"Model '{model_id}' is not an embedding model (type: {getattr(model, 'model_type', 'unknown')})"
)
# If override dimension is provided, use it (for Matryoshka embeddings)
if override_dimension is not None:
if override_dimension <= 0:
raise ValueError(f"Override dimension must be positive, got {override_dimension}")
logger.info(f"Using override dimension {override_dimension} for embedding model '{model_id}'")
return model_id, override_dimension
# Extract embedding dimension from model metadata
if not hasattr(model, "metadata") or not model.metadata:
raise ValueError(f"Embedding model '{model_id}' has no metadata")
embedding_dimension = model.metadata.get("embedding_dimension")
if embedding_dimension is None:
raise ValueError(f"Embedding model '{model_id}' has no embedding_dimension in metadata")
if not isinstance(embedding_dimension, int) or embedding_dimension <= 0:
raise ValueError(f"Invalid embedding_dimension for model '{model_id}': {embedding_dimension}")
logger.debug(f"Auto-lookup successful for embedding model '{model_id}': dimension {embedding_dimension}")
return model_id, embedding_dimension
except Exception as e:
logger.error(f"Error looking up embedding model info for '{model_id}': {e}")
raise
async def get_provider_embedding_model_info(
routing_table: RoutingTable,
provider_config,
explicit_model_id: str | None = None,
explicit_dimension: int | None = None,
) -> tuple[str, int] | None:
"""
Get embedding model info with provider-level defaults and explicit overrides.
This function implements the priority order for embedding model selection:
1. Explicit parameters (from API calls)
2. Provider config defaults (NEW - from VectorIOConfig)
3. System default (current fallback behavior)
Args:
routing_table: Access to the model registry
provider_config: The VectorIOConfig object with potential embedding_model defaults
explicit_model_id: Explicit model ID from API call (highest priority)
explicit_dimension: Explicit dimension from API call (highest priority)
Returns:
tuple: (model_id, embedding_dimension) or None if no model available
Raises:
ValueError: If model validation fails
"""
try:
# Priority 1: Explicit parameters (existing behavior)
if explicit_model_id is not None:
logger.debug(f"Using explicit embedding model: {explicit_model_id}")
return await get_embedding_model_info(explicit_model_id, routing_table, explicit_dimension)
# Priority 2: Provider config default (NEW)
if hasattr(provider_config, "embedding_model") and provider_config.embedding_model:
logger.info(f"Using provider config default embedding model: {provider_config.embedding_model}")
override_dim = None
if hasattr(provider_config, "embedding_dimension") and provider_config.embedding_dimension:
override_dim = provider_config.embedding_dimension
logger.info(f"Using provider config dimension override: {override_dim}")
return await get_embedding_model_info(provider_config.embedding_model, routing_table, override_dim)
# Priority 3: System default (existing fallback behavior)
logger.debug("No explicit model or provider default, falling back to system default")
return await _get_first_embedding_model_fallback(routing_table)
except Exception as e:
logger.error(f"Error getting provider embedding model info: {e}")
raise
async def _get_first_embedding_model_fallback(routing_table: RoutingTable) -> tuple[str, int] | None:
"""
Fallback to get the first available embedding model (existing behavior).
This maintains backward compatibility by preserving the original logic
from VectorIORouter._get_first_embedding_model().
"""
try:
# Get all models from the routing table
all_models = await routing_table.get_all_with_type("model") # type: ignore
# Filter for embedding models
embedding_models = [
model for model in all_models if hasattr(model, "model_type") and model.model_type == ModelType.embedding
]
if embedding_models:
dimension = embedding_models[0].metadata.get("embedding_dimension", None)
if dimension is None:
raise ValueError(f"Embedding model {embedding_models[0].identifier} has no embedding dimension")
logger.info(f"System fallback: using first available embedding model {embedding_models[0].identifier}")
return embedding_models[0].identifier, dimension
else:
logger.warning("No embedding models found in the routing table")
return None
except Exception as e:
logger.error(f"Error getting fallback embedding model: {e}")
return None

View file

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@@ -0,0 +1,320 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import AsyncMock
import pytest
from llama_stack.apis.models import ModelType
from llama_stack.providers.utils.vector_io.embedding_utils import (
_get_first_embedding_model_fallback,
get_embedding_model_info,
get_provider_embedding_model_info,
)
class MockModel:
"""Mock model object for testing."""
def __init__(self, identifier: str, model_type: ModelType, metadata: dict | None = None):
self.identifier = identifier
self.model_type = model_type
self.metadata = metadata
class MockConfig:
"""Mock provider config for testing."""
def __init__(self, embedding_model: str | None = None, embedding_dimension: int | None = None):
self.embedding_model = embedding_model
self.embedding_dimension = embedding_dimension
@pytest.fixture
def mock_routing_table():
"""Create a mock routing table for testing."""
routing_table = AsyncMock()
# Mock embedding models
embedding_models = [
MockModel(identifier="all-MiniLM-L6-v2", model_type=ModelType.embedding, metadata={"embedding_dimension": 384}),
MockModel(identifier="nomic-embed-text", model_type=ModelType.embedding, metadata={"embedding_dimension": 768}),
]
# Mock LLM model (should be filtered out)
llm_model = MockModel(identifier="llama-3.1-8b", model_type=ModelType.llm, metadata={})
all_models = embedding_models + [llm_model]
async def mock_get_object_by_identifier(type_name: str, identifier: str):
if type_name == "model":
for model in all_models:
if model.identifier == identifier:
return model
return None
async def mock_get_all_with_type(type_name: str):
if type_name == "model":
return all_models
return []
routing_table.get_object_by_identifier.side_effect = mock_get_object_by_identifier
routing_table.get_all_with_type.side_effect = mock_get_all_with_type
return routing_table
class TestGetEmbeddingModelInfo:
"""Test the core get_embedding_model_info function."""
@pytest.mark.asyncio
async def test_valid_embedding_model(self, mock_routing_table):
"""Test successful lookup of a valid embedding model."""
model_id, dimension = await get_embedding_model_info("all-MiniLM-L6-v2", mock_routing_table)
assert model_id == "all-MiniLM-L6-v2"
assert dimension == 384
@pytest.mark.asyncio
async def test_embedding_model_with_override_dimension(self, mock_routing_table):
"""Test Matryoshka embedding with dimension override."""
model_id, dimension = await get_embedding_model_info(
"nomic-embed-text", mock_routing_table, override_dimension=256
)
assert model_id == "nomic-embed-text"
assert dimension == 256 # Should use override, not default 768
@pytest.mark.asyncio
async def test_model_not_found(self, mock_routing_table):
"""Test error when model doesn't exist."""
with pytest.raises(ValueError, match="not found in model registry"):
await get_embedding_model_info("non-existent-model", mock_routing_table)
@pytest.mark.asyncio
async def test_non_embedding_model(self, mock_routing_table):
"""Test error when model is not an embedding model."""
with pytest.raises(ValueError, match="is not an embedding model"):
await get_embedding_model_info("llama-3.1-8b", mock_routing_table)
@pytest.mark.asyncio
async def test_model_missing_dimension_metadata(self, mock_routing_table):
"""Test error when embedding model has no dimension metadata."""
# Create a model with non-empty metadata dict missing embedding_dimension
bad_model = MockModel(
identifier="bad-embedding-model",
model_type=ModelType.embedding,
metadata={"some_other_field": "value"}, # Non-empty but missing embedding_dimension
)
async def mock_get_bad_model(type_name: str, identifier: str):
if type_name == "model" and identifier == "bad-embedding-model":
return bad_model
return await mock_routing_table.get_object_by_identifier(type_name, identifier)
mock_routing_table.get_object_by_identifier.side_effect = mock_get_bad_model
with pytest.raises(ValueError, match="has no embedding_dimension in metadata"):
await get_embedding_model_info("bad-embedding-model", mock_routing_table)
@pytest.mark.asyncio
async def test_invalid_override_dimension(self, mock_routing_table):
"""Test error with invalid override dimension."""
with pytest.raises(ValueError, match="Override dimension must be positive"):
await get_embedding_model_info("all-MiniLM-L6-v2", mock_routing_table, override_dimension=0)
with pytest.raises(ValueError, match="Override dimension must be positive"):
await get_embedding_model_info("all-MiniLM-L6-v2", mock_routing_table, override_dimension=-10)
class TestGetProviderEmbeddingModelInfo:
"""Test the provider-level embedding model selection with priority system."""
@pytest.mark.asyncio
async def test_priority_1_explicit_parameters(self, mock_routing_table):
"""Test highest priority: explicit parameters."""
config = MockConfig(embedding_model="nomic-embed-text", embedding_dimension=512)
# Explicit parameters should override config
result = await get_provider_embedding_model_info(
routing_table=mock_routing_table,
provider_config=config,
explicit_model_id="all-MiniLM-L6-v2", # Should use this
explicit_dimension=256, # Should use this
)
assert result is not None
model_id, dimension = result
assert model_id == "all-MiniLM-L6-v2"
assert dimension == 256
@pytest.mark.asyncio
async def test_priority_2_provider_config_defaults(self, mock_routing_table):
"""Test middle priority: provider config defaults."""
config = MockConfig(embedding_model="nomic-embed-text", embedding_dimension=512)
# No explicit parameters, should use config
model_id, dimension = await get_provider_embedding_model_info(
routing_table=mock_routing_table, provider_config=config, explicit_model_id=None, explicit_dimension=None
)
assert model_id == "nomic-embed-text"
assert dimension == 512 # Config override
@pytest.mark.asyncio
async def test_priority_2_provider_config_model_only(self, mock_routing_table):
"""Test provider config with model but no dimension override."""
config = MockConfig(embedding_model="all-MiniLM-L6-v2") # No dimension override
model_id, dimension = await get_provider_embedding_model_info(
routing_table=mock_routing_table, provider_config=config, explicit_model_id=None, explicit_dimension=None
)
assert model_id == "all-MiniLM-L6-v2"
assert dimension == 384 # Auto-lookup from model metadata
@pytest.mark.asyncio
async def test_priority_3_system_default(self, mock_routing_table):
"""Test lowest priority: system default fallback."""
config = MockConfig() # No defaults set
model_id, dimension = await get_provider_embedding_model_info(
routing_table=mock_routing_table, provider_config=config, explicit_model_id=None, explicit_dimension=None
)
# Should get first available embedding model
assert model_id == "all-MiniLM-L6-v2"
assert dimension == 384
@pytest.mark.asyncio
async def test_no_provider_config(self, mock_routing_table):
"""Test with None provider config."""
model_id, dimension = await get_provider_embedding_model_info(
routing_table=mock_routing_table, provider_config=None, explicit_model_id=None, explicit_dimension=None
)
# Should fall back to system default
assert model_id == "all-MiniLM-L6-v2"
assert dimension == 384
@pytest.mark.asyncio
async def test_no_embedding_models_available(self, mock_routing_table):
"""Test when no embedding models are available."""
# Mock routing table with no embedding models
async def mock_get_all_empty(type_name: str):
return [] # No models
mock_routing_table.get_all_with_type.side_effect = mock_get_all_empty
config = MockConfig()
result = await get_provider_embedding_model_info(
routing_table=mock_routing_table, provider_config=config, explicit_model_id=None, explicit_dimension=None
)
assert result is None
class TestGetFirstEmbeddingModelFallback:
"""Test the fallback function for system defaults."""
@pytest.mark.asyncio
async def test_successful_fallback(self, mock_routing_table):
"""Test successful fallback to first embedding model."""
model_id, dimension = await _get_first_embedding_model_fallback(mock_routing_table)
assert model_id == "all-MiniLM-L6-v2"
assert dimension == 384
@pytest.mark.asyncio
async def test_no_embedding_models_fallback(self, mock_routing_table):
"""Test fallback when no embedding models exist."""
# Mock empty model list
async def mock_get_all_empty(type_name: str):
return []
mock_routing_table.get_all_with_type.side_effect = mock_get_all_empty
result = await _get_first_embedding_model_fallback(mock_routing_table)
assert result is None
@pytest.mark.asyncio
async def test_embedding_model_missing_dimension_fallback(self, mock_routing_table):
"""Test fallback when embedding model has no dimension - should return None."""
bad_model = MockModel(
identifier="bad-embedding",
model_type=ModelType.embedding,
metadata={}, # Missing dimension
)
async def mock_get_all_bad(type_name: str):
return [bad_model] if type_name == "model" else []
mock_routing_table.get_all_with_type.side_effect = mock_get_all_bad
# The function should return None (not raise) when model has no dimension
result = await _get_first_embedding_model_fallback(mock_routing_table)
assert result is None
class TestBackwardCompatibility:
"""Test that the new system maintains backward compatibility."""
@pytest.mark.asyncio
async def test_explicit_model_still_works(self, mock_routing_table):
"""Test that explicitly specifying embedding model still works as before."""
model_id, dimension = await get_embedding_model_info("all-MiniLM-L6-v2", mock_routing_table)
assert model_id == "all-MiniLM-L6-v2"
assert dimension == 384
@pytest.mark.asyncio
async def test_system_fallback_unchanged(self, mock_routing_table):
"""Test that system fallback behavior is unchanged."""
# This should behave exactly like the old _get_first_embedding_model
model_id, dimension = await _get_first_embedding_model_fallback(mock_routing_table)
assert model_id == "all-MiniLM-L6-v2"
assert dimension == 384
class TestMatryoshkaEmbeddings:
"""Test specific Matryoshka embedding scenarios."""
@pytest.mark.asyncio
async def test_nomic_embed_text_default(self, mock_routing_table):
"""Test nomic-embed-text with default dimension."""
model_id, dimension = await get_embedding_model_info("nomic-embed-text", mock_routing_table)
assert model_id == "nomic-embed-text"
assert dimension == 768 # Default dimension
@pytest.mark.asyncio
async def test_nomic_embed_text_override(self, mock_routing_table):
"""Test nomic-embed-text with dimension override."""
model_id, dimension = await get_embedding_model_info(
"nomic-embed-text", mock_routing_table, override_dimension=256
)
assert model_id == "nomic-embed-text"
assert dimension == 256 # Overridden dimension
@pytest.mark.asyncio
async def test_provider_config_matryoshka_override(self, mock_routing_table):
"""Test provider config with Matryoshka dimension override."""
config = MockConfig(
embedding_model="nomic-embed-text",
embedding_dimension=128, # Custom dimension
)
model_id, dimension = await get_provider_embedding_model_info(
routing_table=mock_routing_table, provider_config=config, explicit_model_id=None, explicit_dimension=None
)
assert model_id == "nomic-embed-text"
assert dimension == 128 # Should use provider config override
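
The Matryoshka scenarios above rely on the property that such models pack the most informative components into the leading dimensions, so a vector can be truncated to a smaller size and re-normalized. A minimal sketch of that truncation, not tied to any specific model or library:

```python
import math

def truncate_matryoshka(vec: list[float], dim: int) -> list[float]:
    """Keep the first `dim` components and re-normalize to unit length."""
    if dim <= 0 or dim > len(vec):
        raise ValueError(f"dim must be in 1..{len(vec)}, got {dim}")
    head = vec[:dim]
    norm = math.sqrt(sum(x * x for x in head))
    # Guard against an all-zero head; real embeddings are non-zero.
    return [x / norm for x in head] if norm else head

v = [0.6, 0.8, 0.0, 0.0]          # a toy 4-d embedding
print(truncate_matryoshka(v, 2))  # -> [0.6, 0.8] (already unit length)
```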