feat: configure vector-io provider with an embedding model

Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>
This commit is contained in:
Mustafa Elbehery 2025-07-31 13:07:03 +02:00
parent 1f0766308d
commit d8f013b35a
29 changed files with 228 additions and 24 deletions

View file

@ -31,6 +31,8 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id
from llama_stack.providers.utils.vector_io.chunk_utils import generate_chunk_id
from llama_stack.providers.utils.vector_io.embedding_config import EmbeddingConfig
# Module-level logger, named after this module via __name__.
log = logging.getLogger(__name__)
@ -39,6 +41,41 @@ RERANKER_TYPE_RRF = "rrf"
# String identifier for the "weighted" reranker type.
# NOTE(review): the code that consumes this constant is outside this view.
RERANKER_TYPE_WEIGHTED = "weighted"
def apply_provider_embedding_defaults(
    vector_db: VectorDB, provider_embedding_config: EmbeddingConfig | None
) -> VectorDB:
    """Fill in missing embedding settings on a VectorDB from provider defaults.

    Providers may declare a default embedding model and dimension so that
    callers creating use-case specific vector stores do not have to supply
    them. Values already present on ``vector_db`` are never overridden;
    only falsy (unset) fields are filled in.

    Args:
        vector_db: The VectorDB whose embedding fields may need defaults.
        provider_embedding_config: The provider's default embedding
            configuration, or ``None`` when the provider declares none.

    Returns:
        ``vector_db`` itself when no provider configuration is given;
        otherwise a new VectorDB validated from the (possibly updated)
        field values.
    """
    # Without provider defaults there is nothing to apply.
    if provider_embedding_config is None:
        return vector_db

    # Work on a dumped dict so the instance passed in is left untouched.
    fields = vector_db.model_dump()

    # Embedding model: only filled when unset and the provider names one.
    if not fields.get("embedding_model"):
        provider_model = provider_embedding_config.model
        if provider_model:
            fields["embedding_model"] = provider_model

    # Embedding dimension: prefer the provider's explicit value, otherwise
    # fall back to whatever get_dimensions_or_default() reports.
    if not fields.get("embedding_dimension"):
        provider_dims = provider_embedding_config.dimensions
        if provider_dims:
            fields["embedding_dimension"] = provider_dims
        else:
            fields["embedding_dimension"] = provider_embedding_config.get_dimensions_or_default()

    return VectorDB.model_validate(fields)
def parse_pdf(data: bytes) -> str:
# For PDF and DOC/DOCX files, we can't reliably convert to string
pdf_bytes = io.BytesIO(data)