mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-03 19:57:35 +00:00
- Remove incorrect 'Llama-Stack v2' version reference from documentation
- Move MissingEmbeddingModelError to llama_stack/apis/common/errors.py
- Update docstring references to point to correct exception location
- Clarify default_embedding_dimension behavior (defaults to 384)
- Update test imports and exception handling
81 lines
2.7 KiB
Python
81 lines
2.7 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
|
|
import pytest
|
|
|
|
from llama_stack.apis.common.errors import MissingEmbeddingModelError
|
|
from llama_stack.apis.models import ModelType
|
|
from llama_stack.distribution.routers.vector_io import VectorIORouter
|
|
|
|
|
|
class _DummyModel:
    """Stand-in for a registered embedding model.

    Exposes only the three attributes the tests rely on: ``identifier``,
    ``model_type`` (always ``ModelType.embedding``) and ``metadata`` carrying
    the model's ``embedding_dimension``.
    """

    def __init__(self, identifier: str, dim: int):
        self.metadata = {"embedding_dimension": dim}
        self.model_type = ModelType.embedding
        self.identifier = identifier
|
|
|
|
|
|
class _DummyRoutingTable:
    """Routing-table test double for ``VectorIORouter``.

    Only ``get_all_with_type`` has real behaviour: it hands back two embedding
    models (``first-model`` with dimension 123 and ``second-model`` with
    dimension 512). The remaining coroutines exist so the object matches the
    interface the router expects; they raise ``NotImplementedError`` because
    these unit tests never call them.
    """

    def __init__(self):
        specs = (("first-model", 123), ("second-model", 512))
        self._models: list[_DummyModel] = [_DummyModel(name, size) for name, size in specs]

    async def get_all_with_type(self, _type: str):
        # The requested type is ignored: these tests only ever ask for embedding models.
        return self._models

    async def register_vector_db(self, *args, **kwargs):
        # Interface placeholder; not exercised here.
        raise NotImplementedError

    async def get_provider_impl(self, *args, **kwargs):
        # Interface placeholder; not exercised here.
        raise NotImplementedError
|
|
|
|
|
|
async def test_global_default_used(monkeypatch):
    """Router should pick up global default when no explicit model is supplied.

    Both environment variables are set through ``monkeypatch``, which restores
    the original environment automatically at teardown (even when an assertion
    fails), so no manual cleanup is required.
    """
    monkeypatch.setenv("LLAMA_STACK_DEFAULT_EMBEDDING_MODEL", "env-default-model")
    monkeypatch.setenv("LLAMA_STACK_DEFAULT_EMBEDDING_DIMENSION", "256")

    router = VectorIORouter(routing_table=_DummyRoutingTable())

    model, dim = await router._resolve_embedding_model(None)
    assert model == "env-default-model"
    # The env var holds the string "256"; the router must hand back an int.
    assert dim == 256
|
|
|
|
|
|
async def test_explicit_override(monkeypatch):
    """Explicit model parameter should override global default.

    The expected dimension (123) is the ``embedding_dimension`` registered for
    ``first-model`` in ``_DummyRoutingTable``, not any environment value.
    """
    monkeypatch.setenv("LLAMA_STACK_DEFAULT_EMBEDDING_MODEL", "env-default-model")
    # Don't let a dimension inherited from the outer environment leak into the test.
    monkeypatch.delenv("LLAMA_STACK_DEFAULT_EMBEDDING_DIMENSION", raising=False)

    router = VectorIORouter(routing_table=_DummyRoutingTable())

    model, dim = await router._resolve_embedding_model("first-model")
    assert model == "first-model"
    assert dim == 123
    # No manual cleanup: monkeypatch restores the environment at teardown.
|
|
|
|
|
|
async def test_error_when_no_default(monkeypatch):
    """Router should raise when neither explicit nor global default is available."""
    # Clear any defaults inherited from the developer's or CI's environment.
    # When LLAMA_STACK_DEFAULT_EMBEDDING_MODEL is set, the router resolves that
    # model (see test_global_default_used), so this test would fail spuriously.
    monkeypatch.delenv("LLAMA_STACK_DEFAULT_EMBEDDING_MODEL", raising=False)
    monkeypatch.delenv("LLAMA_STACK_DEFAULT_EMBEDDING_DIMENSION", raising=False)

    router = VectorIORouter(routing_table=_DummyRoutingTable())

    with pytest.raises(MissingEmbeddingModelError):
        await router._resolve_embedding_model(None)
|