Use model type from ProviderModelEntry when listing models

This commit is contained in:
melonkernel 2025-09-04 12:45:58 +03:00
parent eed25fc6e4
commit 2a478fb1d5
2 changed files with 58 additions and 1 deletions

View file

@ -103,7 +103,7 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
Model(
identifier=id,
provider_resource_id=entry.provider_model_id,
model_type=ModelType.llm,
model_type=entry.model_type or ModelType.llm,
metadata=entry.metadata,
provider_id=self.__provider_id__,
)

View file

@ -0,0 +1,57 @@
import json
from unittest.mock import MagicMock
import pytest
from pydantic import BaseModel, Field
from llama_stack.core.request_headers import request_provider_data_context
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry, ModelType
class TestConfig(BaseModel):
api_key: str | None = Field(default=None)
class TestProviderDataValidator(BaseModel):
test_api_key: str | None = Field(default=None)
MODEL_ENTRIES_WITHOUT_ALIASES = [
ProviderModelEntry(model_type=ModelType.llm, provider_model_id="test-llm-model", aliases=[]),
ProviderModelEntry(model_type=ModelType.embedding, provider_model_id="test-text-embedding-model", aliases=[], metadata={"embedding_dimension": 1536, "context_length": 8192}),
]
class TestLiteLLMAdapterWithModelEntries(LiteLLMOpenAIMixin):
def __init__(self, config: TestConfig):
super().__init__(
model_entries=MODEL_ENTRIES_WITHOUT_ALIASES,
litellm_provider_name="test",
api_key_from_config=config.api_key,
provider_data_api_key_field="test_api_key",
openai_compat_api_base=None,
)
@pytest.fixture
def adapter_with_model_entries():
"""Fixture to create adapter with API key in config"""
config = TestConfig()
adapter = TestLiteLLMAdapterWithModelEntries(config)
adapter.__provider_id__ = "test-provider"
return adapter
async def test_model_types_are_correct(adapter_with_model_entries):
"""Test that model types are correct"""
model_entries = adapter_with_model_entries.model_entries
llm_model_entries = [m for m in model_entries if m.model_type == ModelType.llm]
assert len(llm_model_entries) == 1
embedding_model_entries = [m for m in model_entries if m.model_type == ModelType.embedding]
assert len(embedding_model_entries) == 1
models = await adapter_with_model_entries.list_models()
llm_models = [m for m in models if m.model_type == ModelType.llm]
assert len(llm_models) == len(llm_model_entries)
embedding_models = [m for m in models if m.model_type == ModelType.embedding]
assert len(embedding_models) == len(embedding_model_entries)