mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-15 06:00:48 +00:00
Replace MissingEmbeddingModelError with IBM Granite default
- Replace error with ibm-granite/granite-embedding-125m-english default - Based on issue #2418 for commercial compatibility and better UX - Update tests to verify default fallback behavior - Update documentation to reflect new precedence rules - Remove unused MissingEmbeddingModelError class - Update tip section to clarify fallback behavior Resolves review comment to use default instead of error.
This commit is contained in:
parent
f8946d8b9d
commit
70df4b7878
4 changed files with 39 additions and 62 deletions
|
@ -6,12 +6,10 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
"""Global vector-store configuration shared across the stack.
|
||||
"""Vector store global config stuff.
|
||||
|
||||
This module introduces `VectorStoreConfig`, a small Pydantic model that
|
||||
lives under `StackRunConfig.vector_store_config`. It lets deployers set
|
||||
an explicit default embedding model (and dimension) that the Vector-IO
|
||||
router will inject whenever the caller does not specify one.
|
||||
Basically just holds default embedding model settings so we don't have to
|
||||
pass them around everywhere. Router picks these up when client doesn't specify.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
@ -22,25 +20,14 @@ __all__ = ["VectorStoreConfig"]
|
|||
|
||||
|
||||
class VectorStoreConfig(BaseModel):
|
||||
"""Stack-level defaults for vector-store creation.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
default_embedding_model
|
||||
The model *id* the stack should use when an embedding model is
|
||||
required but not supplied by the API caller. When *None* the
|
||||
router will fall back to the system default (ibm-granite/granite-embedding-125m-english).
|
||||
default_embedding_dimension
|
||||
Optional integer hint for vector dimension. Routers/providers
|
||||
may validate that the chosen model emits vectors of this size.
|
||||
"""
|
||||
"""Default embedding model config that gets picked up from env vars."""
|
||||
|
||||
default_embedding_model: str | None = Field(
|
||||
default_factory=lambda: os.getenv("LLAMA_STACK_DEFAULT_EMBEDDING_MODEL")
|
||||
)
|
||||
# dimension from env - fallback to None if not set or invalid
|
||||
default_embedding_dimension: int | None = Field(
|
||||
default_factory=lambda: int(os.getenv("LLAMA_STACK_DEFAULT_EMBEDDING_DIMENSION", 0)) or None, ge=1
|
||||
)
|
||||
# Note: If not set, the router will fall back to 384 as the default dimension
|
||||
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
|
|
@ -78,36 +78,27 @@ class VectorIORouter(VectorIO):
|
|||
return None
|
||||
|
||||
async def _resolve_embedding_model(self, explicit_model: str | None = None) -> tuple[str, int]:
|
||||
"""Apply precedence rules to decide which embedding model to use.
|
||||
"""Figure out which embedding model to use and what dimension it has."""
|
||||
|
||||
1. If *explicit_model* is provided, verify dimension (if possible) and use it.
|
||||
2. Else use the global default in ``vector_store_config``.
|
||||
3. Else fallback to system default (ibm-granite/granite-embedding-125m-english).
|
||||
"""
|
||||
|
||||
# 1. explicit override
|
||||
# if they passed a model explicitly, use that
|
||||
if explicit_model is not None:
|
||||
# We still need a dimension; try to look it up in routing table
|
||||
all_models = await self.routing_table.get_all_with_type("model")
|
||||
for m in all_models:
|
||||
if getattr(m, "identifier", None) == explicit_model:
|
||||
dim = m.metadata.get("embedding_dimension")
|
||||
# try to look up dimension from our routing table
|
||||
models = await self.routing_table.get_all_with_type("model")
|
||||
for model in models:
|
||||
if getattr(model, "identifier", None) == explicit_model:
|
||||
dim = model.metadata.get("embedding_dimension")
|
||||
if dim is None:
|
||||
raise ValueError(
|
||||
f"Failed to use embedding model {explicit_model}: found but has no embedding_dimension metadata"
|
||||
)
|
||||
raise ValueError(f"Model {explicit_model} found but no embedding dimension in metadata")
|
||||
return explicit_model, dim
|
||||
# If not found, dimension unknown - defer to caller
|
||||
# model not in our registry, let caller deal with dimension
|
||||
return explicit_model, None # type: ignore
|
||||
|
||||
# 2. global default
|
||||
cfg = VectorStoreConfig() # picks up env vars automatically
|
||||
if cfg.default_embedding_model is not None:
|
||||
return cfg.default_embedding_model, cfg.default_embedding_dimension or 384
|
||||
# check if we have global defaults set via env vars
|
||||
config = VectorStoreConfig()
|
||||
if config.default_embedding_model is not None:
|
||||
return config.default_embedding_model, config.default_embedding_dimension or 384
|
||||
|
||||
# 3. fallback to system default
|
||||
# Use IBM Granite embedding model as default for commercial compatibility
|
||||
# See: https://github.com/meta-llama/llama-stack/issues/2418
|
||||
# fallback to granite model - see issue #2418 for context
|
||||
return "ibm-granite/granite-embedding-125m-english", 384
|
||||
|
||||
async def register_vector_db(
|
||||
|
|
|
@ -8,19 +8,19 @@ from llama_stack.apis.common.vector_store_config import VectorStoreConfig
|
|||
|
||||
|
||||
def test_defaults():
|
||||
cfg = VectorStoreConfig()
|
||||
assert cfg.default_embedding_model is None
|
||||
assert cfg.default_embedding_dimension is None
|
||||
config = VectorStoreConfig()
|
||||
assert config.default_embedding_model is None
|
||||
assert config.default_embedding_dimension is None
|
||||
|
||||
|
||||
def test_env_loading(monkeypatch):
|
||||
monkeypatch.setenv("LLAMA_STACK_DEFAULT_EMBEDDING_MODEL", "test-model")
|
||||
monkeypatch.setenv("LLAMA_STACK_DEFAULT_EMBEDDING_DIMENSION", "123")
|
||||
|
||||
cfg = VectorStoreConfig()
|
||||
assert cfg.default_embedding_model == "test-model"
|
||||
assert cfg.default_embedding_dimension == 123
|
||||
config = VectorStoreConfig()
|
||||
assert config.default_embedding_model == "test-model"
|
||||
assert config.default_embedding_dimension == 123
|
||||
|
||||
# Clean up
|
||||
# cleanup
|
||||
monkeypatch.delenv("LLAMA_STACK_DEFAULT_EMBEDDING_MODEL", raising=False)
|
||||
monkeypatch.delenv("LLAMA_STACK_DEFAULT_EMBEDDING_DIMENSION", raising=False)
|
||||
|
|
|
@ -17,29 +17,28 @@ class _DummyModel:
|
|||
|
||||
|
||||
class _DummyRoutingTable:
|
||||
"""Minimal stub satisfying the methods used by VectorIORouter in tests."""
|
||||
"""Just a fake routing table for testing."""
|
||||
|
||||
def __init__(self):
|
||||
self._models: list[_DummyModel] = [
|
||||
self._models = [
|
||||
_DummyModel("first-model", 123),
|
||||
_DummyModel("second-model", 512),
|
||||
]
|
||||
|
||||
async def get_all_with_type(self, _type: str):
|
||||
# Only embedding models requested in our tests
|
||||
# just return embedding models for tests
|
||||
return self._models
|
||||
|
||||
# The following methods are required by the VectorIORouter signature but
|
||||
# are not used in these unit tests; stub them out.
|
||||
async def register_vector_db(self, *args, **kwargs):
|
||||
# VectorIORouter needs these but we don't use them in tests
|
||||
async def register_vector_db(self, *_args, **_kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_provider_impl(self, *args, **kwargs):
|
||||
async def get_provider_impl(self, *_args, **_kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
async def test_global_default_used(monkeypatch):
|
||||
"""Router should pick up global default when no explicit model is supplied."""
|
||||
"""Should use env var defaults when no explicit model given."""
|
||||
|
||||
monkeypatch.setenv("LLAMA_STACK_DEFAULT_EMBEDDING_MODEL", "env-default-model")
|
||||
monkeypatch.setenv("LLAMA_STACK_DEFAULT_EMBEDDING_DIMENSION", "256")
|
||||
|
@ -50,13 +49,13 @@ async def test_global_default_used(monkeypatch):
|
|||
assert model == "env-default-model"
|
||||
assert dim == 256
|
||||
|
||||
# Cleanup env vars
|
||||
# cleanup
|
||||
monkeypatch.delenv("LLAMA_STACK_DEFAULT_EMBEDDING_MODEL", raising=False)
|
||||
monkeypatch.delenv("LLAMA_STACK_DEFAULT_EMBEDDING_DIMENSION", raising=False)
|
||||
|
||||
|
||||
async def test_explicit_override(monkeypatch):
|
||||
"""Explicit model parameter should override global default."""
|
||||
"""Explicit model should win over env defaults."""
|
||||
|
||||
monkeypatch.setenv("LLAMA_STACK_DEFAULT_EMBEDDING_MODEL", "env-default-model")
|
||||
|
||||
|
@ -69,11 +68,11 @@ async def test_explicit_override(monkeypatch):
|
|||
monkeypatch.delenv("LLAMA_STACK_DEFAULT_EMBEDDING_MODEL", raising=False)
|
||||
|
||||
|
||||
async def test_fallback_to_system_default():
|
||||
"""Router should use system default when neither explicit nor global default is available."""
|
||||
async def test_fallback_to_granite():
|
||||
"""Should fallback to granite model when no defaults set."""
|
||||
|
||||
router = VectorIORouter(routing_table=_DummyRoutingTable())
|
||||
|
||||
model, dimension = await router._resolve_embedding_model(None)
|
||||
model, dim = await router._resolve_embedding_model(None)
|
||||
assert model == "ibm-granite/granite-embedding-125m-english"
|
||||
assert dimension == 384
|
||||
assert dim == 384
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue