From 2a478fb1d5502d9c8d268cfc627723bc47de8276 Mon Sep 17 00:00:00 2001 From: melonkernel Date: Thu, 4 Sep 2025 12:45:58 +0300 Subject: [PATCH 1/2] Use model type from ProviderModelEntry when listing models --- .../utils/inference/model_registry.py | 2 +- .../inference/test_model_registry.py | 57 +++++++++++++++++++ 2 files changed, 58 insertions(+), 1 deletion(-) create mode 100644 tests/unit/providers/inference/test_model_registry.py diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py index 44add8f9e..13f234075 100644 --- a/llama_stack/providers/utils/inference/model_registry.py +++ b/llama_stack/providers/utils/inference/model_registry.py @@ -103,7 +103,7 @@ class ModelRegistryHelper(ModelsProtocolPrivate): Model( identifier=id, provider_resource_id=entry.provider_model_id, - model_type=ModelType.llm, + model_type=entry.model_type or ModelType.llm, metadata=entry.metadata, provider_id=self.__provider_id__, ) diff --git a/tests/unit/providers/inference/test_model_registry.py b/tests/unit/providers/inference/test_model_registry.py new file mode 100644 index 000000000..b53cfa688 --- /dev/null +++ b/tests/unit/providers/inference/test_model_registry.py @@ -0,0 +1,57 @@ +import json +from unittest.mock import MagicMock + +import pytest +from pydantic import BaseModel, Field + +from llama_stack.core.request_headers import request_provider_data_context +from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin +from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry, ModelType + +class TestConfig(BaseModel): + api_key: str | None = Field(default=None) + +class TestProviderDataValidator(BaseModel): + test_api_key: str | None = Field(default=None) + +MODEL_ENTRIES_WITHOUT_ALIASES = [ + ProviderModelEntry(model_type=ModelType.llm, provider_model_id="test-llm-model", aliases=[]), + ProviderModelEntry(model_type=ModelType.embedding, 
provider_model_id="test-text-embedding-model", aliases=[], metadata={"embedding_dimension": 1536, "context_length": 8192}), + ] + +class TestLiteLLMAdapterWithModelEntries(LiteLLMOpenAIMixin): + def __init__(self, config: TestConfig): + super().__init__( + model_entries=MODEL_ENTRIES_WITHOUT_ALIASES, + litellm_provider_name="test", + api_key_from_config=config.api_key, + provider_data_api_key_field="test_api_key", + openai_compat_api_base=None, + ) + +@pytest.fixture +def adapter_with_model_entries(): + """Fixture to create adapter with API key in config""" + config = TestConfig() + adapter = TestLiteLLMAdapterWithModelEntries(config) + adapter.__provider_id__ = "test-provider" + + return adapter + + +async def test_model_types_are_correct(adapter_with_model_entries): + """Test that model types are correct""" + model_entries = adapter_with_model_entries.model_entries + llm_model_entries = [m for m in model_entries if m.model_type == ModelType.llm] + assert len(llm_model_entries) == 1 + + embedding_model_entries = [m for m in model_entries if m.model_type == ModelType.embedding] + assert len(embedding_model_entries) == 1 + + models = await adapter_with_model_entries.list_models() + llm_models = [m for m in models if m.model_type == ModelType.llm] + assert len(llm_models) == len(llm_model_entries) + + embedding_models = [m for m in models if m.model_type == ModelType.embedding] + assert len(embedding_models) == len(embedding_model_entries) + From 27a69188347f26d5480bf9f1d69cd324806da0a6 Mon Sep 17 00:00:00 2001 From: melonkernel Date: Thu, 4 Sep 2025 14:01:51 +0300 Subject: [PATCH 2/2] ensure embedding models have dimensions --- .../utils/inference/model_registry.py | 5 ++++ .../inference/test_model_registry.py | 23 +++++++++++++++---- 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py index 13f234075..35cc4cf1f 100644 --- 
a/llama_stack/providers/utils/inference/model_registry.py +++ b/llama_stack/providers/utils/inference/model_registry.py @@ -35,6 +35,11 @@ class ProviderModelEntry(BaseModel): llama_model: str | None = None model_type: ModelType = ModelType.llm metadata: dict[str, Any] = Field(default_factory=dict) + + def __init__(self, **data): + super().__init__(**data) + if self.model_type == ModelType.embedding and "embedding_dimension" not in self.metadata: + raise ValueError("Embedding models must specify 'embedding_dimension' in metadata") def get_huggingface_repo(model_descriptor: str) -> str | None: diff --git a/tests/unit/providers/inference/test_model_registry.py b/tests/unit/providers/inference/test_model_registry.py index b53cfa688..c10b27ef8 100644 --- a/tests/unit/providers/inference/test_model_registry.py +++ b/tests/unit/providers/inference/test_model_registry.py @@ -15,9 +15,9 @@ class TestProviderDataValidator(BaseModel): test_api_key: str | None = Field(default=None) MODEL_ENTRIES_WITHOUT_ALIASES = [ - ProviderModelEntry(model_type=ModelType.llm, provider_model_id="test-llm-model", aliases=[]), - ProviderModelEntry(model_type=ModelType.embedding, provider_model_id="test-text-embedding-model", aliases=[], metadata={"embedding_dimension": 1536, "context_length": 8192}), - ] + ProviderModelEntry(model_type=ModelType.llm, provider_model_id="test-llm-model", aliases=[]), + ProviderModelEntry(model_type=ModelType.embedding, provider_model_id="test-text-embedding-model", aliases=[], metadata={"embedding_dimension": 1536, "context_length": 8192}), +] class TestLiteLLMAdapterWithModelEntries(LiteLLMOpenAIMixin): def __init__(self, config: TestConfig): @@ -38,7 +38,6 @@ def adapter_with_model_entries(): return adapter - async def test_model_types_are_correct(adapter_with_model_entries): """Test that model types are correct""" model_entries = adapter_with_model_entries.model_entries @@ -54,4 +53,20 @@ async def test_model_types_are_correct(adapter_with_model_entries): 
     embedding_models = [m for m in models if m.model_type == ModelType.embedding]
     assert len(embedding_models) == len(embedding_model_entries)
+
+def test_embedding_metadata_is_required():
+    with pytest.raises(ValueError):
+        ProviderModelEntry(
+            model_type=ModelType.embedding,
+            provider_model_id="test-text-embedding-model",
+            aliases=[],
+            metadata={}
+        )
+    entry = ProviderModelEntry(
+        model_type=ModelType.embedding,
+        provider_model_id="test-text-embedding-model",
+        aliases=[],
+        metadata={"embedding_dimension": 1536}
+    )
+    assert entry.metadata["embedding_dimension"] == 1536
 
 