Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-10-07 20:50:52 +00:00
Merge 32930868de into 2f58d87c22
Commit 689f1db815: 8 changed files with 284 additions and 8 deletions
llama_stack/apis/common/vector_store_config.py (new file, 33 lines)

@@ -0,0 +1,33 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+"""Global configuration for vector stores.
+
+Holds the default embedding-model settings so they do not have to be passed
+around everywhere; the router picks them up when the client does not specify
+a model.
+"""
+
+from __future__ import annotations
+
+import os
+
+from pydantic import BaseModel, ConfigDict, Field
+
+__all__ = ["VectorStoreConfig"]
+
+
+class VectorStoreConfig(BaseModel):
+    """Default embedding-model configuration, populated from environment variables."""
+
+    default_embedding_model: str | None = Field(
+        default_factory=lambda: os.getenv("LLAMA_STACK_DEFAULT_EMBEDDING_MODEL")
+    )
+    # dimension from env; falls back to None when the variable is unset or 0
+    default_embedding_dimension: int | None = Field(
+        default_factory=lambda: int(os.getenv("LLAMA_STACK_DEFAULT_EMBEDDING_DIMENSION", "0")) or None, ge=1
+    )
+
+    model_config = ConfigDict(frozen=True)
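As a quick illustration of how these env-driven defaults behave, here is a minimal sketch (the environment values are made up for the example):

import os

from llama_stack.apis.common.vector_store_config import VectorStoreConfig

# Illustrative values; any registered embedding model would do.
os.environ["LLAMA_STACK_DEFAULT_EMBEDDING_MODEL"] = "all-MiniLM-L6-v2"
os.environ["LLAMA_STACK_DEFAULT_EMBEDDING_DIMENSION"] = "384"

config = VectorStoreConfig()
print(config.default_embedding_model)      # all-MiniLM-L6-v2
print(config.default_embedding_dimension)  # 384

Because the model is frozen, the defaults are fixed at construction time; a fresh VectorStoreConfig() must be created to pick up changed environment variables.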
@@ -12,6 +12,7 @@ from urllib.parse import urlparse
 from pydantic import BaseModel, Field, field_validator, model_validator
 
 from llama_stack.apis.benchmarks import Benchmark, BenchmarkInput
+from llama_stack.apis.common.vector_store_config import VectorStoreConfig
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Dataset, DatasetInput
 from llama_stack.apis.eval import Eval
@@ -474,6 +475,12 @@ InferenceStoreConfig (with queue tuning parameters) or a SqlStoreConfig (deprecated)
         If not specified, a default SQLite store will be used.""",
     )
 
+    # Global vector-store defaults (embedding model etc.)
+    vector_store_config: VectorStoreConfig = Field(
+        default_factory=VectorStoreConfig,
+        description="Global defaults for vector-store creation (embedding model, dimension, …)",
+    )
+
     # registry of "resources" in the distribution
     models: list[ModelInput] = Field(default_factory=list)
     shields: list[ShieldInput] = Field(default_factory=list)
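In a run configuration, the new field would then look roughly like this (a sketch; the key names follow the field definition above, the values are illustrative):

vector_store_config:
  default_embedding_model: all-MiniLM-L6-v2
  default_embedding_dimension: 384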
@@ -11,6 +11,7 @@ from typing import Any
 from llama_stack.apis.common.content_types import (
     InterleavedContent,
 )
+from llama_stack.apis.common.vector_store_config import VectorStoreConfig
 from llama_stack.apis.models import ModelType
 from llama_stack.apis.vector_io import (
     Chunk,
@@ -76,6 +77,41 @@ class VectorIORouter(VectorIO):
             logger.error(f"Error getting embedding models: {e}")
             return None
 
+    async def _resolve_embedding_model(self, explicit_model: str | None = None) -> tuple[str, int]:
+        """Resolve the embedding model to use and its dimension, by precedence:
+        explicit argument, then global default, then first available model."""
+
+        # if a model was passed explicitly, use that
+        if explicit_model is not None:
+            # try to look up its dimension in the routing table
+            models = await self.routing_table.get_all_with_type("model")
+            for model in models:
+                if getattr(model, "identifier", None) == explicit_model:
+                    dim = model.metadata.get("embedding_dimension")
+                    if dim is None:
+                        raise ValueError(f"Model {explicit_model} found but no embedding dimension in metadata")
+                    return explicit_model, dim
+            # a model that is not in the registry is an error
+            raise ValueError(f"Embedding model '{explicit_model}' not found in model registry")
+
+        # next, check for global defaults set via env vars
+        config = VectorStoreConfig()
+        if config.default_embedding_model is not None:
+            if config.default_embedding_dimension is None:
+                raise ValueError(
+                    f"default_embedding_model '{config.default_embedding_model}' is set but default_embedding_dimension is missing"
+                )
+            return config.default_embedding_model, config.default_embedding_dimension
+
+        # fall back to the first available embedding model for compatibility
+        fallback = await self._get_first_embedding_model()
+        if fallback is not None:
+            return fallback
+
+        # no models available at all
+        raise ValueError(
+            "No embedding model specified and no default configured. Either provide an embedding_model parameter or set vector_store_config.default_embedding_model"
+        )
+
     async def register_vector_db(
         self,
         vector_db_id: str,
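To make the precedence concrete, here is a standalone sketch of the same resolution order (the registry dict is a hypothetical stand-in for the routing table):

import asyncio

async def resolve(explicit: str | None, default: str | None,
                  default_dim: int | None, registry: dict[str, int]) -> tuple[str, int]:
    # 1. an explicitly passed model wins; its dimension must be known
    if explicit is not None:
        if explicit not in registry:
            raise ValueError(f"'{explicit}' not found in registry")
        return explicit, registry[explicit]
    # 2. the global default (mirrors VectorStoreConfig)
    if default is not None:
        if default_dim is None:
            raise ValueError("default model set without a dimension")
        return default, default_dim
    # 3. the first available embedding model
    for model, dim in registry.items():
        return model, dim
    raise ValueError("no embedding model available")

registry = {"all-MiniLM-L6-v2": 384}  # hypothetical registry contents
print(asyncio.run(resolve(None, None, None, registry)))              # falls back to first model
print(asyncio.run(resolve("all-MiniLM-L6-v2", None, None, registry)))  # explicit model wins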
@@ -102,7 +138,7 @@
         ttl_seconds: int | None = None,
     ) -> None:
         logger.debug(
-            f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
+            f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.chunk_id for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
         )
         provider = await self.routing_table.get_provider_impl(vector_db_id)
         return await provider.insert_chunks(vector_db_id, chunks, ttl_seconds)
@@ -131,13 +167,8 @@
     ) -> VectorStoreObject:
         logger.debug(f"VectorIORouter.openai_create_vector_store: name={name}, provider_id={provider_id}")
 
-        # If no embedding model is provided, use the first available one
-        if embedding_model is None:
-            embedding_model_info = await self._get_first_embedding_model()
-            if embedding_model_info is None:
-                raise ValueError("No embedding model provided and no embedding models available in the system")
-            embedding_model, embedding_dimension = embedding_model_info
-            logger.info(f"No embedding model specified, using first available: {embedding_model}")
+        # Determine which embedding model to use based on the new precedence
+        embedding_model, embedding_dimension = await self._resolve_embedding_model(embedding_model)
 
         vector_db_id = f"vs_{uuid.uuid4()}"
         registered_vector_db = await self.routing_table.register_vector_db(
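With this in place, a caller can omit the embedding model entirely and let the configured default apply. An illustrative sketch (router wiring elided; `router` is assumed to be an initialized VectorIORouter):

import os

# Process-wide defaults, set before the stack starts (values illustrative).
os.environ["LLAMA_STACK_DEFAULT_EMBEDDING_MODEL"] = "all-MiniLM-L6-v2"
os.environ["LLAMA_STACK_DEFAULT_EMBEDDING_DIMENSION"] = "384"

# Inside an async context, no embedding_model argument is needed;
# _resolve_embedding_model supplies the configured default:
store = await router.openai_create_vector_store(name="my_docs")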
@@ -39,6 +39,9 @@ distribution_spec:
     - provider_type: remote::tavily-search
     - provider_type: inline::rag-runtime
     - provider_type: remote::model-context-protocol
+  vector_store_config:
+    default_embedding_model: ${env.LLAMA_STACK_DEFAULT_EMBEDDING_MODEL:=all-MiniLM-L6-v2}
+    default_embedding_dimension: ${env.LLAMA_STACK_DEFAULT_EMBEDDING_DIMENSION:=384}
 image_type: venv
 additional_pip_packages:
 - sqlalchemy[asyncio]
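The `${env.NAME:=default}` form is the stack's environment substitution with a fallback: the variable's value is used when it is set, otherwise the literal after `:=` applies. With no overrides in the environment, the block above therefore resolves to:

vector_store_config:
  default_embedding_model: all-MiniLM-L6-v2
  default_embedding_dimension: 384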