mirror of https://github.com/meta-llama/llama-stack.git
Merge 418a25aea9 into 81ecaf6221
Commit e6cb8a262c
8 changed files with 243 additions and 8 deletions

@@ -687,3 +687,51 @@ shields:
  provider_shield_id: null
...
```
## Global Vector Store Defaults

You can provide a stack-level default embedding model that will be used whenever a new vector store is created and the caller does not specify an `embedding_model` parameter.

Add a top-level `vector_store_config` block at the root of your build/run YAML, alongside other root-level keys such as `models`, `shields`, `server`, and `metadata_store`:

```yaml
# ... other configuration sections ...
metadata_store:
  namespace: null
  type: sqlite
  db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/registry.db
models:
- metadata: {}
  model_id: ${env.INFERENCE_MODEL}
  provider_id: ollama
  provider_model_id: null
shields: []
server:
  port: 8321
vector_store_config:
  default_embedding_model: ${env.LLAMA_STACK_DEFAULT_EMBEDDING_MODEL:=all-MiniLM-L6-v2}
  # optional - if omitted, defaults to 384
  default_embedding_dimension: ${env.LLAMA_STACK_DEFAULT_EMBEDDING_DIMENSION:=384}
```

Precedence rules at runtime:

1. If `embedding_model` is explicitly passed in an API call, that value is used.
2. Otherwise, the value in `vector_store_config.default_embedding_model` is used.
3. If neither is available, the server falls back to the system default (`all-MiniLM-L6-v2`, 384 dimensions).
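
The same precedence can be sketched in a few lines of Python (a standalone illustration, not the actual router code):

```python
import os


def resolve_embedding_model(explicit_model: str | None = None) -> str:
    # 1. an explicitly passed model always wins
    if explicit_model is not None:
        return explicit_model
    # 2. otherwise the stack-level default applies (vector_store_config reads this env var)
    configured = os.getenv("LLAMA_STACK_DEFAULT_EMBEDDING_MODEL")
    if configured:
        return configured
    # 3. finally, fall back to the built-in system default
    return "all-MiniLM-L6-v2"


print(resolve_embedding_model("custom-model"))  # custom-model
print(resolve_embedding_model())                # env default, else all-MiniLM-L6-v2
```
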
#### Environment variables

| Variable | Purpose | Example |
|----------|---------|---------|
| `LLAMA_STACK_DEFAULT_EMBEDDING_MODEL` | Global default embedding model id | `all-MiniLM-L6-v2` |
| `LLAMA_STACK_DEFAULT_EMBEDDING_DIMENSION` | Dimension for embeddings (optional, defaults to 384) | `384` |

If you include the `${env.…}` placeholder in `vector_store_config`, deployments can override the default without editing YAML:

```bash
export LLAMA_STACK_DEFAULT_EMBEDDING_MODEL="sentence-transformers/all-MiniLM-L6-v2"
llama stack run --config run.yaml
```
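
Once the server is running, a caller that omits `embedding_model` picks up the configured default. A sketch against the OpenAI-compatible vector-store endpoint (the base URL and client surface are assumptions; adjust for your deployment and client version):

```python
from openai import OpenAI

# llama-stack exposes an OpenAI-compatible API; the exact path may differ by version
client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

# no embedding_model argument: the server resolves it from vector_store_config
store = client.vector_stores.create(name="docs")
print(store.id)
```
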

> Tip: If you omit `vector_store_config` entirely and don't set `LLAMA_STACK_DEFAULT_EMBEDDING_MODEL`, the system will fall back to the default `all-MiniLM-L6-v2` model with 384 dimensions for vector store creation.

llama_stack/apis/common/vector_store_config.py (new file, 33 lines)

@@ -0,0 +1,33 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

"""Global configuration for vector stores.

Holds the stack-level default embedding model settings so they do not have to
be passed through every call site. The VectorIORouter picks these defaults up
when a client does not specify an embedding model.
"""

from __future__ import annotations

import os

from pydantic import BaseModel, ConfigDict, Field

__all__ = ["VectorStoreConfig"]


class VectorStoreConfig(BaseModel):
    """Default embedding model configuration, read from environment variables."""

    default_embedding_model: str | None = Field(
        default_factory=lambda: os.getenv("LLAMA_STACK_DEFAULT_EMBEDDING_MODEL")
    )
    # dimension from env; unset or "0" yields None (a non-integer value raises ValueError)
    default_embedding_dimension: int | None = Field(
        default_factory=lambda: int(os.getenv("LLAMA_STACK_DEFAULT_EMBEDDING_DIMENSION", 0)) or None, ge=1
    )

    model_config = ConfigDict(frozen=True)
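
A quick usage sketch (illustrative, not part of the diff): the defaults are captured from the environment when the model is instantiated.

```python
import os

from llama_stack.apis.common.vector_store_config import VectorStoreConfig

os.environ["LLAMA_STACK_DEFAULT_EMBEDDING_MODEL"] = "all-MiniLM-L6-v2"
os.environ["LLAMA_STACK_DEFAULT_EMBEDDING_DIMENSION"] = "384"

config = VectorStoreConfig()
print(config.default_embedding_model)      # all-MiniLM-L6-v2
print(config.default_embedding_dimension)  # 384
```
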
@@ -11,6 +11,7 @@ from typing import Annotated, Any, Literal, Self
from pydantic import BaseModel, Field, field_validator, model_validator

from llama_stack.apis.benchmarks import Benchmark, BenchmarkInput
from llama_stack.apis.common.vector_store_config import VectorStoreConfig
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Dataset, DatasetInput
from llama_stack.apis.eval import Eval

@@ -391,6 +392,12 @@ Configuration for the persistence store used by the inference API. If not specified,
a default SQLite store will be used.""",
    )

    # Global vector-store defaults (embedding model etc.)
    vector_store_config: VectorStoreConfig = Field(
        default_factory=VectorStoreConfig,
        description="Global defaults for vector-store creation (embedding model, dimension, …)",
    )

    # registry of "resources" in the distribution
    models: list[ModelInput] = Field(default_factory=list)
    shields: list[ShieldInput] = Field(default_factory=list)
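
For orientation, a minimal sketch of how the new field behaves once a run config is parsed (`RunConfigSketch` is a hypothetical stand-in for the real config class, which is not shown in full here):

```python
from pydantic import BaseModel, Field

from llama_stack.apis.common.vector_store_config import VectorStoreConfig


class RunConfigSketch(BaseModel):
    # stand-in that models only the new field
    vector_store_config: VectorStoreConfig = Field(default_factory=VectorStoreConfig)


cfg = RunConfigSketch()
# None unless LLAMA_STACK_DEFAULT_EMBEDDING_MODEL is set in the environment
print(cfg.vector_store_config.default_embedding_model)
```
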
@@ -11,6 +11,7 @@ from typing import Any
from llama_stack.apis.common.content_types import (
    InterleavedContent,
)
from llama_stack.apis.common.vector_store_config import VectorStoreConfig
from llama_stack.apis.models import ModelType
from llama_stack.apis.vector_io import (
    Chunk,

@@ -76,6 +77,30 @@ class VectorIORouter(VectorIO):
            logger.error(f"Error getting embedding models: {e}")
            return None

    async def _resolve_embedding_model(self, explicit_model: str | None = None) -> tuple[str, int | None]:
        """Figure out which embedding model to use and what dimension it has."""

        # if a model was passed explicitly, use it
        if explicit_model is not None:
            # try to look up its dimension in the routing table
            models = await self.routing_table.get_all_with_type("model")
            for model in models:
                if getattr(model, "identifier", None) == explicit_model:
                    dim = model.metadata.get("embedding_dimension")
                    if dim is None:
                        raise ValueError(f"Model {explicit_model} found but no embedding dimension in metadata")
                    return explicit_model, dim
            # model not in our registry; let the caller deal with the dimension
            return explicit_model, None

        # check whether global defaults are set via env vars
        config = VectorStoreConfig()
        if config.default_embedding_model is not None:
            return config.default_embedding_model, config.default_embedding_dimension or 384

        # fall back to the existing default model for compatibility
        return "all-MiniLM-L6-v2", 384

    async def register_vector_db(
        self,
        vector_db_id: str,

@@ -102,7 +127,7 @@
        ttl_seconds: int | None = None,
    ) -> None:
        logger.debug(
-           f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
+           f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.chunk_id for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
        )
        provider = await self.routing_table.get_provider_impl(vector_db_id)
        return await provider.insert_chunks(vector_db_id, chunks, ttl_seconds)

@@ -131,13 +156,8 @@
    ) -> VectorStoreObject:
        logger.debug(f"VectorIORouter.openai_create_vector_store: name={name}, provider_id={provider_id}")

-       # If no embedding model is provided, use the first available one
-       if embedding_model is None:
-           embedding_model_info = await self._get_first_embedding_model()
-           if embedding_model_info is None:
-               raise ValueError("No embedding model provided and no embedding models available in the system")
-           embedding_model, embedding_dimension = embedding_model_info
-           logger.info(f"No embedding model specified, using first available: {embedding_model}")
+       # Determine which embedding model to use based on the new precedence
+       embedding_model, embedding_dimension = await self._resolve_embedding_model(embedding_model)

        vector_db_id = f"vs_{uuid.uuid4()}"
        registered_vector_db = await self.routing_table.register_vector_db(

@@ -39,6 +39,9 @@ distribution_spec:
- provider_type: remote::tavily-search
- provider_type: inline::rag-runtime
- provider_type: remote::model-context-protocol
vector_store_config:
  default_embedding_model: ${env.LLAMA_STACK_DEFAULT_EMBEDDING_MODEL:=all-MiniLM-L6-v2}
  default_embedding_dimension: ${env.LLAMA_STACK_DEFAULT_EMBEDDING_DIMENSION:=384}
image_type: venv
additional_pip_packages:
- sqlalchemy[asyncio]

@@ -63,6 +63,19 @@ def pytest_configure(config):
        os.environ["DISABLE_CODE_SANDBOX"] = "1"
        logger.info("Setting DISABLE_CODE_SANDBOX=1 for macOS")

    # After processing CLI --env overrides, ensure global default embedding model is set for vector-store operations
    embedding_model_opt = config.getoption("--embedding-model") or "sentence-transformers/all-MiniLM-L6-v2"
    if embedding_model_opt and not os.getenv("LLAMA_STACK_DEFAULT_EMBEDDING_MODEL"):
        # Use first value in comma-separated list (if any)
        default_model = embedding_model_opt.split(",")[0].strip()
        os.environ["LLAMA_STACK_DEFAULT_EMBEDDING_MODEL"] = default_model
        logger.info(f"Setting LLAMA_STACK_DEFAULT_EMBEDDING_MODEL={default_model}")

    embedding_dim_opt = config.getoption("--embedding-dimension") or 384
    if not os.getenv("LLAMA_STACK_DEFAULT_EMBEDDING_DIMENSION") and embedding_dim_opt:
        os.environ["LLAMA_STACK_DEFAULT_EMBEDDING_DIMENSION"] = str(embedding_dim_opt)
        logger.info(f"Setting LLAMA_STACK_DEFAULT_EMBEDDING_DIMENSION={embedding_dim_opt}")


def pytest_addoption(parser):
    parser.addoption(

tests/unit/common/test_vector_store_config.py (new file, 29 lines)

@@ -0,0 +1,29 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.common.vector_store_config import VectorStoreConfig


def test_defaults(monkeypatch):
    # ensure env is clean to avoid flaky defaults
    monkeypatch.delenv("LLAMA_STACK_DEFAULT_EMBEDDING_MODEL", raising=False)
    monkeypatch.delenv("LLAMA_STACK_DEFAULT_EMBEDDING_DIMENSION", raising=False)
    config = VectorStoreConfig()
    assert config.default_embedding_model is None
    assert config.default_embedding_dimension is None


def test_env_loading(monkeypatch):
    monkeypatch.setenv("LLAMA_STACK_DEFAULT_EMBEDDING_MODEL", "test-model")
    monkeypatch.setenv("LLAMA_STACK_DEFAULT_EMBEDDING_DIMENSION", "123")

    config = VectorStoreConfig()
    assert config.default_embedding_model == "test-model"
    assert config.default_embedding_dimension == 123

    # explicit cleanup is optional; monkeypatch restores the env automatically
    monkeypatch.delenv("LLAMA_STACK_DEFAULT_EMBEDDING_MODEL", raising=False)
    monkeypatch.delenv("LLAMA_STACK_DEFAULT_EMBEDDING_DIMENSION", raising=False)

tests/unit/router/test_embedding_precedence.py (new file, 82 lines)

@@ -0,0 +1,82 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


import pytest

from llama_stack.apis.models import ModelType
from llama_stack.core.routers.vector_io import VectorIORouter

pytestmark = pytest.mark.asyncio


class _DummyModel:
    def __init__(self, identifier: str, dim: int):
        self.identifier = identifier
        self.model_type = ModelType.embedding
        self.metadata = {"embedding_dimension": dim}


class _DummyRoutingTable:
    """A fake routing table for testing."""

    def __init__(self):
        self._models = [
            _DummyModel("first-model", 123),
            _DummyModel("second-model", 512),
        ]

    async def get_all_with_type(self, _type: str):
        # only embedding models are needed for these tests
        return self._models

    # VectorIORouter requires these, but the tests never call them
    async def register_vector_db(self, *_args, **_kwargs):
        raise NotImplementedError

    async def get_provider_impl(self, *_args, **_kwargs):
        raise NotImplementedError


async def test_global_default_used(monkeypatch):
    """Should use the env var defaults when no explicit model is given."""

    monkeypatch.setenv("LLAMA_STACK_DEFAULT_EMBEDDING_MODEL", "env-default-model")
    monkeypatch.setenv("LLAMA_STACK_DEFAULT_EMBEDDING_DIMENSION", "256")

    router = VectorIORouter(routing_table=_DummyRoutingTable())

    model, dim = await router._resolve_embedding_model(None)
    assert model == "env-default-model"
    assert dim == 256

    # explicit cleanup is optional; monkeypatch restores the env automatically
    monkeypatch.delenv("LLAMA_STACK_DEFAULT_EMBEDDING_MODEL", raising=False)
    monkeypatch.delenv("LLAMA_STACK_DEFAULT_EMBEDDING_DIMENSION", raising=False)


async def test_explicit_override(monkeypatch):
    """An explicit model should win over the env defaults."""

    monkeypatch.setenv("LLAMA_STACK_DEFAULT_EMBEDDING_MODEL", "env-default-model")

    router = VectorIORouter(routing_table=_DummyRoutingTable())

    model, dim = await router._resolve_embedding_model("first-model")
    assert model == "first-model"
    assert dim == 123

    monkeypatch.delenv("LLAMA_STACK_DEFAULT_EMBEDDING_MODEL", raising=False)


async def test_fallback_to_default(monkeypatch):
    """Should fall back to all-MiniLM-L6-v2 when no defaults are set."""

    # clear any env defaults so the built-in fallback is actually exercised
    monkeypatch.delenv("LLAMA_STACK_DEFAULT_EMBEDDING_MODEL", raising=False)
    monkeypatch.delenv("LLAMA_STACK_DEFAULT_EMBEDDING_DIMENSION", raising=False)

    router = VectorIORouter(routing_table=_DummyRoutingTable())

    model, dim = await router._resolve_embedding_model(None)
    assert model == "all-MiniLM-L6-v2"
    assert dim == 384