mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 12:07:34 +00:00
Includes: require models to exist in the registry before use; make default_embedding_dimension mandatory when setting a default model; use the first available model as the fallback instead of the hardcoded all-MiniLM-L6-v2; add tests for error cases and update the docs.
112 lines
3.6 KiB
Python
112 lines
3.6 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
|
|
import pytest
|
|
|
|
from llama_stack.apis.models import ModelType
|
|
from llama_stack.core.routers.vector_io import VectorIORouter
|
|
|
|
pytestmark = pytest.mark.asyncio
|
|
|
|
|
|
class _DummyModel:
    """Minimal embedding-model stand-in used by the routing-table fakes.

    Exposes only ``identifier``, ``model_type`` and ``metadata``; the
    embedding dimension is stored under ``metadata["embedding_dimension"]``.
    """

    def __init__(self, identifier: str, dim: int):
        # Every dummy is an embedding model; only the name and dimension vary.
        self.model_type = ModelType.embedding
        self.metadata = {"embedding_dimension": dim}
        self.identifier = identifier
|
|
|
|
|
|
class _DummyRoutingTable:
    """In-memory fake routing table holding two fixed embedding models."""

    def __init__(self):
        # (identifier, embedding dimension) pairs, in the order they are listed.
        specs = (("first-model", 123), ("second-model", 512))
        self._models = [_DummyModel(name, dim) for name, dim in specs]

    async def get_all_with_type(self, _type: str):
        """Return every stored model; the requested type is ignored."""
        return self._models

    # VectorIORouter needs the two methods below to exist, but these tests
    # never call them.
    async def register_vector_db(self, *_args, **_kwargs):
        raise NotImplementedError()

    async def get_provider_impl(self, *_args, **_kwargs):
        raise NotImplementedError()
|
|
|
|
|
|
async def test_global_default_used(monkeypatch):
    """Use the environment-configured defaults when no explicit model is given.

    Both ``LLAMA_STACK_DEFAULT_EMBEDDING_MODEL`` and
    ``LLAMA_STACK_DEFAULT_EMBEDDING_DIMENSION`` are set, and resolving with
    ``None`` must return that model name with a dimension equal to ``256``.
    """
    # No manual cleanup is needed: the ``monkeypatch`` fixture restores the
    # environment at teardown, including when an assertion below fails.
    monkeypatch.setenv("LLAMA_STACK_DEFAULT_EMBEDDING_MODEL", "env-default-model")
    monkeypatch.setenv("LLAMA_STACK_DEFAULT_EMBEDDING_DIMENSION", "256")

    router = VectorIORouter(routing_table=_DummyRoutingTable())

    model, dim = await router._resolve_embedding_model(None)
    assert model == "env-default-model"
    assert dim == 256
|
|
|
|
|
|
async def test_explicit_override(monkeypatch):
    """Explicit model should win over env defaults.

    A default model is configured through the environment, yet resolving
    ``"first-model"`` explicitly must return that model together with the
    dimension stored in the fake routing table (``123``).
    """
    # The ``monkeypatch`` fixture undoes this at teardown, so the test does
    # not remove the variable itself.
    monkeypatch.setenv("LLAMA_STACK_DEFAULT_EMBEDDING_MODEL", "env-default-model")

    router = VectorIORouter(routing_table=_DummyRoutingTable())

    model, dim = await router._resolve_embedding_model("first-model")
    assert model == "first-model"
    assert dim == 123
|
|
|
|
|
|
async def test_fallback_to_default():
    """Should fallback to first available embedding model when no defaults set.

    With neither default-embedding variable present, resolving ``None`` must
    return the first model of the fake routing table and its dimension.
    """
    # Clear both variables so a value inherited from the surrounding process
    # cannot divert resolution onto the env-default path. The scoped
    # MonkeyPatch restores the environment on exit and leaves the test
    # signature untouched.
    with pytest.MonkeyPatch.context() as env:
        env.delenv("LLAMA_STACK_DEFAULT_EMBEDDING_MODEL", raising=False)
        env.delenv("LLAMA_STACK_DEFAULT_EMBEDDING_DIMENSION", raising=False)

        router = VectorIORouter(routing_table=_DummyRoutingTable())
        model, dim = await router._resolve_embedding_model(None)

    assert model == "first-model"
    assert dim == 123
|
|
|
|
|
|
async def test_missing_dimension_requirement(monkeypatch):
    """A default model without a default dimension must raise ``ValueError``."""
    monkeypatch.setenv("LLAMA_STACK_DEFAULT_EMBEDDING_MODEL", "some-model")
    # Make sure the dimension really is absent, even if the surrounding
    # process exports it; otherwise the expected error would never be raised.
    monkeypatch.delenv("LLAMA_STACK_DEFAULT_EMBEDDING_DIMENSION", raising=False)

    router = VectorIORouter(routing_table=_DummyRoutingTable())

    # The ``monkeypatch`` fixture restores both variables at teardown.
    with pytest.raises(ValueError, match="default_embedding_model.*is set but default_embedding_dimension is missing"):
        await router._resolve_embedding_model(None)
|
|
|
|
|
|
async def test_unregistered_model_error():
    """An explicit model absent from the fake routing table must raise ``ValueError``."""
    expected_message = "Embedding model 'unknown-model' not found in model registry"
    routing_table = _DummyRoutingTable()
    router = VectorIORouter(routing_table=routing_table)

    with pytest.raises(ValueError, match=expected_message):
        await router._resolve_embedding_model("unknown-model")
|
|
|
|
|
|
class _EmptyRoutingTable:
    """Fake routing table that has no models at all."""

    async def get_all_with_type(self, _type: str):
        """Return a new empty list, whatever type is requested."""
        no_models = []
        return no_models
|
|
|
|
|
|
async def test_no_models_available_error():
    """With no env defaults and an empty routing table, resolution must raise ``ValueError``."""
    # Remove any inherited defaults first: an env-configured model would be
    # picked up before the empty routing table is consulted, and the expected
    # error would not occur. The scoped MonkeyPatch restores the environment
    # on exit and leaves the test signature untouched.
    with pytest.MonkeyPatch.context() as env:
        env.delenv("LLAMA_STACK_DEFAULT_EMBEDDING_MODEL", raising=False)
        env.delenv("LLAMA_STACK_DEFAULT_EMBEDDING_DIMENSION", raising=False)

        router = VectorIORouter(routing_table=_EmptyRoutingTable())

        with pytest.raises(ValueError, match="No embedding model specified and no default configured"):
            await router._resolve_embedding_model(None)
|