This commit is contained in:
Sumanth Kamenani 2025-09-24 09:30:04 +02:00 committed by GitHub
commit 689f1db815
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 284 additions and 8 deletions

View file

@ -0,0 +1,29 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.vector_store_config import VectorStoreConfig
def test_defaults(monkeypatch):
# ensure env is clean to avoid flaky defaults
monkeypatch.delenv("LLAMA_STACK_DEFAULT_EMBEDDING_MODEL", raising=False)
monkeypatch.delenv("LLAMA_STACK_DEFAULT_EMBEDDING_DIMENSION", raising=False)
config = VectorStoreConfig()
assert config.default_embedding_model is None
assert config.default_embedding_dimension is None
def test_env_loading(monkeypatch):
monkeypatch.setenv("LLAMA_STACK_DEFAULT_EMBEDDING_MODEL", "test-model")
monkeypatch.setenv("LLAMA_STACK_DEFAULT_EMBEDDING_DIMENSION", "123")
config = VectorStoreConfig()
assert config.default_embedding_model == "test-model"
assert config.default_embedding_dimension == 123
# cleanup
monkeypatch.delenv("LLAMA_STACK_DEFAULT_EMBEDDING_MODEL", raising=False)
monkeypatch.delenv("LLAMA_STACK_DEFAULT_EMBEDDING_DIMENSION", raising=False)

View file

@ -0,0 +1,112 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from llama_stack.apis.models import ModelType
from llama_stack.core.routers.vector_io import VectorIORouter
pytestmark = pytest.mark.asyncio
class _DummyModel:
def __init__(self, identifier: str, dim: int):
self.identifier = identifier
self.model_type = ModelType.embedding
self.metadata = {"embedding_dimension": dim}
class _DummyRoutingTable:
"""Just a fake routing table for testing."""
def __init__(self):
self._models = [
_DummyModel("first-model", 123),
_DummyModel("second-model", 512),
]
async def get_all_with_type(self, _type: str):
# just return embedding models for tests
return self._models
# VectorIORouter needs these but we don't use them in tests
async def register_vector_db(self, *_args, **_kwargs):
raise NotImplementedError
async def get_provider_impl(self, *_args, **_kwargs):
raise NotImplementedError
async def test_global_default_used(monkeypatch):
"""Should use env var defaults when no explicit model given."""
monkeypatch.setenv("LLAMA_STACK_DEFAULT_EMBEDDING_MODEL", "env-default-model")
monkeypatch.setenv("LLAMA_STACK_DEFAULT_EMBEDDING_DIMENSION", "256")
router = VectorIORouter(routing_table=_DummyRoutingTable())
model, dim = await router._resolve_embedding_model(None)
assert model == "env-default-model"
assert dim == 256
# cleanup
monkeypatch.delenv("LLAMA_STACK_DEFAULT_EMBEDDING_MODEL", raising=False)
monkeypatch.delenv("LLAMA_STACK_DEFAULT_EMBEDDING_DIMENSION", raising=False)
async def test_explicit_override(monkeypatch):
"""Explicit model should win over env defaults."""
monkeypatch.setenv("LLAMA_STACK_DEFAULT_EMBEDDING_MODEL", "env-default-model")
router = VectorIORouter(routing_table=_DummyRoutingTable())
model, dim = await router._resolve_embedding_model("first-model")
assert model == "first-model"
assert dim == 123
monkeypatch.delenv("LLAMA_STACK_DEFAULT_EMBEDDING_MODEL", raising=False)
async def test_fallback_to_default():
"""Should fallback to first available embedding model when no defaults set."""
router = VectorIORouter(routing_table=_DummyRoutingTable())
model, dim = await router._resolve_embedding_model(None)
assert model == "first-model"
assert dim == 123
async def test_missing_dimension_requirement(monkeypatch):
monkeypatch.setenv("LLAMA_STACK_DEFAULT_EMBEDDING_MODEL", "some-model")
router = VectorIORouter(routing_table=_DummyRoutingTable())
with pytest.raises(ValueError, match="default_embedding_model.*is set but default_embedding_dimension is missing"):
await router._resolve_embedding_model(None)
monkeypatch.delenv("LLAMA_STACK_DEFAULT_EMBEDDING_MODEL", raising=False)
async def test_unregistered_model_error():
router = VectorIORouter(routing_table=_DummyRoutingTable())
with pytest.raises(ValueError, match="Embedding model 'unknown-model' not found in model registry"):
await router._resolve_embedding_model("unknown-model")
class _EmptyRoutingTable:
async def get_all_with_type(self, _type: str):
return []
async def test_no_models_available_error():
router = VectorIORouter(routing_table=_EmptyRoutingTable())
with pytest.raises(ValueError, match="No embedding model specified and no default configured"):
await router._resolve_embedding_model(None)