mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-15 06:00:48 +00:00
Replace MissingEmbeddingModelError with IBM Granite default
- Replace error with ibm-granite/granite-embedding-125m-english default - Based on issue #2418 for commercial compatibility and better UX - Update tests to verify default fallback behavior - Update documentation to reflect new precedence rules - Remove unused MissingEmbeddingModelError class - Update tip section to clarify fallback behavior Resolves review comment to use default instead of error.
This commit is contained in:
parent
f8946d8b9d
commit
70df4b7878
4 changed files with 39 additions and 62 deletions
|
@ -6,12 +6,10 @@
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
"""Global vector-store configuration shared across the stack.
|
"""Vector store global config stuff.
|
||||||
|
|
||||||
This module introduces `VectorStoreConfig`, a small Pydantic model that
|
Basically just holds default embedding model settings so we don't have to
|
||||||
lives under `StackRunConfig.vector_store_config`. It lets deployers set
|
pass them around everywhere. Router picks these up when client doesn't specify.
|
||||||
an explicit default embedding model (and dimension) that the Vector-IO
|
|
||||||
router will inject whenever the caller does not specify one.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
@ -22,25 +20,14 @@ __all__ = ["VectorStoreConfig"]
|
||||||
|
|
||||||
|
|
||||||
class VectorStoreConfig(BaseModel):
|
class VectorStoreConfig(BaseModel):
|
||||||
"""Stack-level defaults for vector-store creation.
|
"""Default embedding model config that gets picked up from env vars."""
|
||||||
|
|
||||||
Attributes
|
|
||||||
----------
|
|
||||||
default_embedding_model
|
|
||||||
The model *id* the stack should use when an embedding model is
|
|
||||||
required but not supplied by the API caller. When *None* the
|
|
||||||
router will fall back to the system default (ibm-granite/granite-embedding-125m-english).
|
|
||||||
default_embedding_dimension
|
|
||||||
Optional integer hint for vector dimension. Routers/providers
|
|
||||||
may validate that the chosen model emits vectors of this size.
|
|
||||||
"""
|
|
||||||
|
|
||||||
default_embedding_model: str | None = Field(
|
default_embedding_model: str | None = Field(
|
||||||
default_factory=lambda: os.getenv("LLAMA_STACK_DEFAULT_EMBEDDING_MODEL")
|
default_factory=lambda: os.getenv("LLAMA_STACK_DEFAULT_EMBEDDING_MODEL")
|
||||||
)
|
)
|
||||||
|
# dimension from env - fallback to None if not set or invalid
|
||||||
default_embedding_dimension: int | None = Field(
|
default_embedding_dimension: int | None = Field(
|
||||||
default_factory=lambda: int(os.getenv("LLAMA_STACK_DEFAULT_EMBEDDING_DIMENSION", 0)) or None, ge=1
|
default_factory=lambda: int(os.getenv("LLAMA_STACK_DEFAULT_EMBEDDING_DIMENSION", 0)) or None, ge=1
|
||||||
)
|
)
|
||||||
# Note: If not set, the router will fall back to 384 as the default dimension
|
|
||||||
|
|
||||||
model_config = ConfigDict(frozen=True)
|
model_config = ConfigDict(frozen=True)
|
||||||
|
|
|
@ -78,36 +78,27 @@ class VectorIORouter(VectorIO):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def _resolve_embedding_model(self, explicit_model: str | None = None) -> tuple[str, int]:
|
async def _resolve_embedding_model(self, explicit_model: str | None = None) -> tuple[str, int]:
|
||||||
"""Apply precedence rules to decide which embedding model to use.
|
"""Figure out which embedding model to use and what dimension it has."""
|
||||||
|
|
||||||
1. If *explicit_model* is provided, verify dimension (if possible) and use it.
|
# if they passed a model explicitly, use that
|
||||||
2. Else use the global default in ``vector_store_config``.
|
|
||||||
3. Else fallback to system default (ibm-granite/granite-embedding-125m-english).
|
|
||||||
"""
|
|
||||||
|
|
||||||
# 1. explicit override
|
|
||||||
if explicit_model is not None:
|
if explicit_model is not None:
|
||||||
# We still need a dimension; try to look it up in routing table
|
# try to look up dimension from our routing table
|
||||||
all_models = await self.routing_table.get_all_with_type("model")
|
models = await self.routing_table.get_all_with_type("model")
|
||||||
for m in all_models:
|
for model in models:
|
||||||
if getattr(m, "identifier", None) == explicit_model:
|
if getattr(model, "identifier", None) == explicit_model:
|
||||||
dim = m.metadata.get("embedding_dimension")
|
dim = model.metadata.get("embedding_dimension")
|
||||||
if dim is None:
|
if dim is None:
|
||||||
raise ValueError(
|
raise ValueError(f"Model {explicit_model} found but no embedding dimension in metadata")
|
||||||
f"Failed to use embedding model {explicit_model}: found but has no embedding_dimension metadata"
|
|
||||||
)
|
|
||||||
return explicit_model, dim
|
return explicit_model, dim
|
||||||
# If not found, dimension unknown - defer to caller
|
# model not in our registry, let caller deal with dimension
|
||||||
return explicit_model, None # type: ignore
|
return explicit_model, None # type: ignore
|
||||||
|
|
||||||
# 2. global default
|
# check if we have global defaults set via env vars
|
||||||
cfg = VectorStoreConfig() # picks up env vars automatically
|
config = VectorStoreConfig()
|
||||||
if cfg.default_embedding_model is not None:
|
if config.default_embedding_model is not None:
|
||||||
return cfg.default_embedding_model, cfg.default_embedding_dimension or 384
|
return config.default_embedding_model, config.default_embedding_dimension or 384
|
||||||
|
|
||||||
# 3. fallback to system default
|
# fallback to granite model - see issue #2418 for context
|
||||||
# Use IBM Granite embedding model as default for commercial compatibility
|
|
||||||
# See: https://github.com/meta-llama/llama-stack/issues/2418
|
|
||||||
return "ibm-granite/granite-embedding-125m-english", 384
|
return "ibm-granite/granite-embedding-125m-english", 384
|
||||||
|
|
||||||
async def register_vector_db(
|
async def register_vector_db(
|
||||||
|
|
|
@ -8,19 +8,19 @@ from llama_stack.apis.common.vector_store_config import VectorStoreConfig
|
||||||
|
|
||||||
|
|
||||||
def test_defaults():
|
def test_defaults():
|
||||||
cfg = VectorStoreConfig()
|
config = VectorStoreConfig()
|
||||||
assert cfg.default_embedding_model is None
|
assert config.default_embedding_model is None
|
||||||
assert cfg.default_embedding_dimension is None
|
assert config.default_embedding_dimension is None
|
||||||
|
|
||||||
|
|
||||||
def test_env_loading(monkeypatch):
|
def test_env_loading(monkeypatch):
|
||||||
monkeypatch.setenv("LLAMA_STACK_DEFAULT_EMBEDDING_MODEL", "test-model")
|
monkeypatch.setenv("LLAMA_STACK_DEFAULT_EMBEDDING_MODEL", "test-model")
|
||||||
monkeypatch.setenv("LLAMA_STACK_DEFAULT_EMBEDDING_DIMENSION", "123")
|
monkeypatch.setenv("LLAMA_STACK_DEFAULT_EMBEDDING_DIMENSION", "123")
|
||||||
|
|
||||||
cfg = VectorStoreConfig()
|
config = VectorStoreConfig()
|
||||||
assert cfg.default_embedding_model == "test-model"
|
assert config.default_embedding_model == "test-model"
|
||||||
assert cfg.default_embedding_dimension == 123
|
assert config.default_embedding_dimension == 123
|
||||||
|
|
||||||
# Clean up
|
# cleanup
|
||||||
monkeypatch.delenv("LLAMA_STACK_DEFAULT_EMBEDDING_MODEL", raising=False)
|
monkeypatch.delenv("LLAMA_STACK_DEFAULT_EMBEDDING_MODEL", raising=False)
|
||||||
monkeypatch.delenv("LLAMA_STACK_DEFAULT_EMBEDDING_DIMENSION", raising=False)
|
monkeypatch.delenv("LLAMA_STACK_DEFAULT_EMBEDDING_DIMENSION", raising=False)
|
||||||
|
|
|
@ -17,29 +17,28 @@ class _DummyModel:
|
||||||
|
|
||||||
|
|
||||||
class _DummyRoutingTable:
|
class _DummyRoutingTable:
|
||||||
"""Minimal stub satisfying the methods used by VectorIORouter in tests."""
|
"""Just a fake routing table for testing."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._models: list[_DummyModel] = [
|
self._models = [
|
||||||
_DummyModel("first-model", 123),
|
_DummyModel("first-model", 123),
|
||||||
_DummyModel("second-model", 512),
|
_DummyModel("second-model", 512),
|
||||||
]
|
]
|
||||||
|
|
||||||
async def get_all_with_type(self, _type: str):
|
async def get_all_with_type(self, _type: str):
|
||||||
# Only embedding models requested in our tests
|
# just return embedding models for tests
|
||||||
return self._models
|
return self._models
|
||||||
|
|
||||||
# The following methods are required by the VectorIORouter signature but
|
# VectorIORouter needs these but we don't use them in tests
|
||||||
# are not used in these unit tests; stub them out.
|
async def register_vector_db(self, *_args, **_kwargs):
|
||||||
async def register_vector_db(self, *args, **kwargs):
|
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def get_provider_impl(self, *args, **kwargs):
|
async def get_provider_impl(self, *_args, **_kwargs):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
async def test_global_default_used(monkeypatch):
|
async def test_global_default_used(monkeypatch):
|
||||||
"""Router should pick up global default when no explicit model is supplied."""
|
"""Should use env var defaults when no explicit model given."""
|
||||||
|
|
||||||
monkeypatch.setenv("LLAMA_STACK_DEFAULT_EMBEDDING_MODEL", "env-default-model")
|
monkeypatch.setenv("LLAMA_STACK_DEFAULT_EMBEDDING_MODEL", "env-default-model")
|
||||||
monkeypatch.setenv("LLAMA_STACK_DEFAULT_EMBEDDING_DIMENSION", "256")
|
monkeypatch.setenv("LLAMA_STACK_DEFAULT_EMBEDDING_DIMENSION", "256")
|
||||||
|
@ -50,13 +49,13 @@ async def test_global_default_used(monkeypatch):
|
||||||
assert model == "env-default-model"
|
assert model == "env-default-model"
|
||||||
assert dim == 256
|
assert dim == 256
|
||||||
|
|
||||||
# Cleanup env vars
|
# cleanup
|
||||||
monkeypatch.delenv("LLAMA_STACK_DEFAULT_EMBEDDING_MODEL", raising=False)
|
monkeypatch.delenv("LLAMA_STACK_DEFAULT_EMBEDDING_MODEL", raising=False)
|
||||||
monkeypatch.delenv("LLAMA_STACK_DEFAULT_EMBEDDING_DIMENSION", raising=False)
|
monkeypatch.delenv("LLAMA_STACK_DEFAULT_EMBEDDING_DIMENSION", raising=False)
|
||||||
|
|
||||||
|
|
||||||
async def test_explicit_override(monkeypatch):
|
async def test_explicit_override(monkeypatch):
|
||||||
"""Explicit model parameter should override global default."""
|
"""Explicit model should win over env defaults."""
|
||||||
|
|
||||||
monkeypatch.setenv("LLAMA_STACK_DEFAULT_EMBEDDING_MODEL", "env-default-model")
|
monkeypatch.setenv("LLAMA_STACK_DEFAULT_EMBEDDING_MODEL", "env-default-model")
|
||||||
|
|
||||||
|
@ -69,11 +68,11 @@ async def test_explicit_override(monkeypatch):
|
||||||
monkeypatch.delenv("LLAMA_STACK_DEFAULT_EMBEDDING_MODEL", raising=False)
|
monkeypatch.delenv("LLAMA_STACK_DEFAULT_EMBEDDING_MODEL", raising=False)
|
||||||
|
|
||||||
|
|
||||||
async def test_fallback_to_system_default():
|
async def test_fallback_to_granite():
|
||||||
"""Router should use system default when neither explicit nor global default is available."""
|
"""Should fallback to granite model when no defaults set."""
|
||||||
|
|
||||||
router = VectorIORouter(routing_table=_DummyRoutingTable())
|
router = VectorIORouter(routing_table=_DummyRoutingTable())
|
||||||
|
|
||||||
model, dimension = await router._resolve_embedding_model(None)
|
model, dim = await router._resolve_embedding_model(None)
|
||||||
assert model == "ibm-granite/granite-embedding-125m-english"
|
assert model == "ibm-granite/granite-embedding-125m-english"
|
||||||
assert dimension == 384
|
assert dim == 384
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue