mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-10 05:24:39 +00:00
Merge 32930868de
into 2f58d87c22
This commit is contained in:
commit
689f1db815
8 changed files with 284 additions and 8 deletions
|
@ -12,6 +12,7 @@ from urllib.parse import urlparse
|
|||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
from llama_stack.apis.benchmarks import Benchmark, BenchmarkInput
|
||||
from llama_stack.apis.common.vector_store_config import VectorStoreConfig
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Dataset, DatasetInput
|
||||
from llama_stack.apis.eval import Eval
|
||||
|
@ -474,6 +475,12 @@ InferenceStoreConfig (with queue tuning parameters) or a SqlStoreConfig (depreca
|
|||
If not specified, a default SQLite store will be used.""",
|
||||
)
|
||||
|
||||
# Global vector-store defaults (embedding model etc.)
|
||||
vector_store_config: VectorStoreConfig = Field(
|
||||
default_factory=VectorStoreConfig,
|
||||
description="Global defaults for vector-store creation (embedding model, dimension, …)",
|
||||
)
|
||||
|
||||
# registry of "resources" in the distribution
|
||||
models: list[ModelInput] = Field(default_factory=list)
|
||||
shields: list[ShieldInput] = Field(default_factory=list)
|
||||
|
|
|
@ -11,6 +11,7 @@ from typing import Any
|
|||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
)
|
||||
from llama_stack.apis.common.vector_store_config import VectorStoreConfig
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.apis.vector_io import (
|
||||
Chunk,
|
||||
|
@ -76,6 +77,41 @@ class VectorIORouter(VectorIO):
|
|||
logger.error(f"Error getting embedding models: {e}")
|
||||
return None
|
||||
|
||||
async def _resolve_embedding_model(self, explicit_model: str | None = None) -> tuple[str, int]:
    """Figure out which embedding model to use and what dimension it has."""
    if explicit_model is not None:
        # An explicitly requested model takes top precedence; consult the
        # routing table so we can recover its embedding dimension.
        registered = await self.routing_table.get_all_with_type("model")
        matched = next(
            (entry for entry in registered if getattr(entry, "identifier", None) == explicit_model),
            None,
        )
        if matched is None:
            # The caller named a model we have never registered.
            raise ValueError(f"Embedding model '{explicit_model}' not found in model registry")
        dimension = matched.metadata.get("embedding_dimension")
        if dimension is None:
            raise ValueError(f"Model {explicit_model} found but no embedding dimension in metadata")
        return explicit_model, dimension

    # Second precedence level: global defaults (populated from env vars).
    defaults = VectorStoreConfig()
    if defaults.default_embedding_model is not None:
        if defaults.default_embedding_dimension is None:
            raise ValueError(
                f"default_embedding_model '{defaults.default_embedding_model}' is set but default_embedding_dimension is missing"
            )
        return defaults.default_embedding_model, defaults.default_embedding_dimension

    # Last resort, kept for backwards compatibility: the first embedding
    # model available anywhere in the system.
    first_available = await self._get_first_embedding_model()
    if first_available is not None:
        return first_available

    # Nothing requested, no defaults, no models at all — fail loudly.
    raise ValueError(
        "No embedding model specified and no default configured. Either provide an embedding_model parameter or set vector_store_config.default_embedding_model"
    )
|
||||
|
||||
async def register_vector_db(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
|
@ -102,7 +138,7 @@ class VectorIORouter(VectorIO):
|
|||
ttl_seconds: int | None = None,
|
||||
) -> None:
|
||||
logger.debug(
|
||||
f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
|
||||
f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.chunk_id for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
|
||||
)
|
||||
provider = await self.routing_table.get_provider_impl(vector_db_id)
|
||||
return await provider.insert_chunks(vector_db_id, chunks, ttl_seconds)
|
||||
|
@ -131,13 +167,8 @@ class VectorIORouter(VectorIO):
|
|||
) -> VectorStoreObject:
|
||||
logger.debug(f"VectorIORouter.openai_create_vector_store: name={name}, provider_id={provider_id}")
|
||||
|
||||
# If no embedding model is provided, use the first available one
|
||||
if embedding_model is None:
|
||||
embedding_model_info = await self._get_first_embedding_model()
|
||||
if embedding_model_info is None:
|
||||
raise ValueError("No embedding model provided and no embedding models available in the system")
|
||||
embedding_model, embedding_dimension = embedding_model_info
|
||||
logger.info(f"No embedding model specified, using first available: {embedding_model}")
|
||||
# Determine which embedding model to use based on new precedence
|
||||
embedding_model, embedding_dimension = await self._resolve_embedding_model(embedding_model)
|
||||
|
||||
vector_db_id = f"vs_{uuid.uuid4()}"
|
||||
registered_vector_db = await self.routing_table.register_vector_db(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue