From a765f1c029efd640ba8ed66065fb062eab1d5bc3 Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Mon, 6 Oct 2025 06:37:56 -0400 Subject: [PATCH] get_models -> list_provider_model_ids --- .../remote/inference/databricks/databricks.py | 4 +- .../providers/utils/inference/openai_mixin.py | 42 ++-- .../utils/inference/test_openai_mixin.py | 207 ++++-------------- 3 files changed, 64 insertions(+), 189 deletions(-) diff --git a/llama_stack/providers/remote/inference/databricks/databricks.py b/llama_stack/providers/remote/inference/databricks/databricks.py index 6f7e1ba97..2ab4856b5 100644 --- a/llama_stack/providers/remote/inference/databricks/databricks.py +++ b/llama_stack/providers/remote/inference/databricks/databricks.py @@ -4,13 +4,13 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from collections.abc import Iterable from typing import Any from databricks.sdk import WorkspaceClient from llama_stack.apis.inference import ( Inference, - Model, OpenAICompletion, ) from llama_stack.log import get_logger @@ -71,7 +71,7 @@ class DatabricksInferenceAdapter( ) -> OpenAICompletion: raise NotImplementedError() - async def get_models(self) -> list[Model] | None: + async def list_provider_model_ids(self) -> Iterable[str]: return [ endpoint.name for endpoint in WorkspaceClient( diff --git a/llama_stack/providers/utils/inference/openai_mixin.py b/llama_stack/providers/utils/inference/openai_mixin.py index b8f66cda3..3239d814c 100644 --- a/llama_stack/providers/utils/inference/openai_mixin.py +++ b/llama_stack/providers/utils/inference/openai_mixin.py @@ -111,7 +111,7 @@ class OpenAIMixin(ModelsProtocolPrivate, NeedsRequestProviderData, ABC): """ return {} - async def get_models(self) -> Iterable[str] | None: + async def list_provider_model_ids(self) -> Iterable[str]: """ List available models from the provider. @@ -121,7 +121,7 @@ class OpenAIMixin(ModelsProtocolPrivate, NeedsRequestProviderData, ABC): :return: An iterable of model IDs or None if not implemented """ - return None + return [m.id async for m in self.client.models.list()] @property def client(self) -> AsyncOpenAI: @@ -400,41 +400,35 @@ class OpenAIMixin(ModelsProtocolPrivate, NeedsRequestProviderData, ABC): self._model_cache = {} # give subclasses a chance to provide custom model listing - if (iterable := await self.get_models()) is not None: - if not hasattr(iterable, "__iter__"): - raise TypeError( - f"Failed to list models: {self.__class__.__name__}.get_models() must return an iterable of " - f"strings or None, but returned {type(iterable).__name__}" - ) - models_ids = list(iterable) - logger.info( - f"Using {self.__class__.__name__}.get_models() implementation, received {len(models_ids)} models" + iterable = await self.list_provider_model_ids() + if not hasattr(iterable, "__iter__"): + raise TypeError( + f"Failed to list models: {self.__class__.__name__}.list_provider_model_ids() must return an iterable of " + f"strings or None, but returned {type(iterable).__name__}" ) - else: - models_ids = [m.id async for m in self.client.models.list()] + provider_models_ids = list(iterable) + logger.info(f"{self.__class__.__name__}.list_provider_model_ids() returned {len(provider_models_ids)} models") - for m_id in models_ids: - if self.allowed_models and m_id not in self.allowed_models: - logger.info(f"Skipping model {m_id} as it is not in the allowed models list") + for provider_model_id in provider_models_ids: + if self.allowed_models and provider_model_id not in self.allowed_models: + logger.info(f"Skipping model {provider_model_id} as it is not in the allowed models list") continue - if metadata := self.embedding_model_metadata.get(m_id): - # This is an embedding model - augment with metadata + if metadata := self.embedding_model_metadata.get(provider_model_id): model = Model( provider_id=self.__provider_id__, # type: ignore[attr-defined] - provider_resource_id=m_id, - identifier=m_id, + provider_resource_id=provider_model_id, + identifier=provider_model_id, model_type=ModelType.embedding, metadata=metadata, ) else: - # This is an LLM model = Model( provider_id=self.__provider_id__, # type: ignore[attr-defined] - provider_resource_id=m_id, - identifier=m_id, + provider_resource_id=provider_model_id, + identifier=provider_model_id, model_type=ModelType.llm, ) - self._model_cache[m_id] = model + self._model_cache[provider_model_id] = model return list(self._model_cache.values()) diff --git a/tests/unit/providers/utils/inference/test_openai_mixin.py b/tests/unit/providers/utils/inference/test_openai_mixin.py index 9991a49cd..1110e1843 100644 --- a/tests/unit/providers/utils/inference/test_openai_mixin.py +++ b/tests/unit/providers/utils/inference/test_openai_mixin.py @@ -499,19 +499,19 @@ class OpenAIMixinWithProviderData(OpenAIMixinImpl): return "default-base-url" -class OpenAIMixinWithCustomGetModels(OpenAIMixinImpl): - """Test implementation with custom get_models override""" +class CustomListProviderModelIdsImplementation(OpenAIMixinImpl): + """Test implementation with custom list_provider_model_ids override""" def __init__(self, custom_model_ids): self._custom_model_ids = custom_model_ids - async def get_models(self) -> Iterable[str] | None: + async def list_provider_model_ids(self) -> Iterable[str]: """Return custom model IDs list""" return self._custom_model_ids -class TestOpenAIMixinCustomGetModels: - """Test cases for custom get_models() implementation functionality""" +class TestOpenAIMixinCustomListProviderModelIds: + """Test cases for custom list_provider_model_ids() implementation functionality""" @pytest.fixture def custom_model_ids_list(self): @@ -519,40 +519,32 @@ class TestOpenAIMixinCustomGetModels: return ["custom-model-1", "custom-model-2", "custom-embedding"] @pytest.fixture - def mixin_with_custom_get_models(self, custom_model_ids_list): - """Create mixin instance with custom get_models implementation""" - mixin = OpenAIMixinWithCustomGetModels(custom_model_ids=custom_model_ids_list) - # Add embedding metadata to test that feature still works + def adapter(self, custom_model_ids_list): + """Create mixin instance with custom list_provider_model_ids implementation""" + mixin = CustomListProviderModelIdsImplementation(custom_model_ids=custom_model_ids_list) mixin.embedding_model_metadata = {"custom-embedding": {"embedding_dimension": 768, "context_length": 512}} return mixin - async def test_custom_get_models_is_used(self, mixin_with_custom_get_models, custom_model_ids_list): - """Test that custom get_models() implementation is used instead of client.models.list()""" - result = await mixin_with_custom_get_models.list_models() + async def test_is_used(self, adapter, custom_model_ids_list): + """Test that custom list_provider_model_ids() implementation is used instead of client.models.list()""" + result = await adapter.list_models() assert result is not None assert len(result) == 3 - # Verify all custom models are present - identifiers = {m.identifier for m in result} - assert "custom-model-1" in identifiers - assert "custom-model-2" in identifiers - assert "custom-embedding" in identifiers + assert set(custom_model_ids_list) == {m.identifier for m in result} - async def test_custom_get_models_populates_cache(self, mixin_with_custom_get_models): - """Test that custom get_models() results are cached""" - assert len(mixin_with_custom_get_models._model_cache) == 0 + async def test_populates_cache(self, adapter, custom_model_ids_list): + """Test that custom list_provider_model_ids() results are cached""" + assert len(adapter._model_cache) == 0 - await mixin_with_custom_get_models.list_models() + await adapter.list_models() - assert len(mixin_with_custom_get_models._model_cache) == 3 - assert "custom-model-1" in mixin_with_custom_get_models._model_cache - assert "custom-model-2" in mixin_with_custom_get_models._model_cache - assert "custom-embedding" in mixin_with_custom_get_models._model_cache + assert set(custom_model_ids_list) == set(adapter._model_cache.keys()) - async def test_custom_get_models_respects_allowed_models(self): - """Test that custom get_models() respects allowed_models filtering""" - mixin = OpenAIMixinWithCustomGetModels(custom_model_ids=["model-1", "model-2", "model-3"]) + async def test_respects_allowed_models(self): + """Test that custom list_provider_model_ids() respects allowed_models filtering""" + mixin = CustomListProviderModelIdsImplementation(custom_model_ids=["model-1", "model-2", "model-3"]) mixin.allowed_models = ["model-1"] result = await mixin.list_models() @@ -561,23 +553,9 @@ class TestOpenAIMixinCustomGetModels: assert len(result) == 1 assert result[0].identifier == "model-1" - async def test_custom_get_models_with_embedding_metadata(self, mixin_with_custom_get_models): - """Test that custom get_models() works with embedding_model_metadata""" - result = await mixin_with_custom_get_models.list_models() - - # Find the embedding model - embedding_model = next((m for m in result if m.identifier == "custom-embedding"), None) - assert embedding_model is not None - assert embedding_model.model_type == ModelType.embedding - assert embedding_model.metadata == {"embedding_dimension": 768, "context_length": 512} - - # Verify LLM models - llm_models = [m for m in result if m.model_type == ModelType.llm] - assert len(llm_models) == 2 - - async def test_custom_get_models_with_empty_list(self): - """Test that custom get_models() handles empty list correctly""" - mixin = OpenAIMixinWithCustomGetModels(custom_model_ids=[]) + async def test_with_empty_list(self): + """Test that custom list_provider_model_ids() handles empty list correctly""" + mixin = CustomListProviderModelIdsImplementation(custom_model_ids=[]) result = await mixin.list_models() @@ -585,157 +563,60 @@ class TestOpenAIMixinCustomGetModels: assert len(result) == 0 assert len(mixin._model_cache) == 0 - async def test_default_get_models_returns_none(self, mixin): - """Test that default get_models() implementation returns None""" - custom_models = await mixin.get_models() - assert custom_models is None + async def test_wrong_type_raises_error(self): + """Test that list_provider_model_ids() returning unhashable items results in an error""" + mixin = CustomListProviderModelIdsImplementation(custom_model_ids=[["nested", "list"], {"key": "value"}]) - async def test_fallback_to_client_when_get_models_returns_none( - self, mixin, mock_client_with_models, mock_client_context - ): - """Test that when get_models() returns None, falls back to client.models.list()""" - # Default get_models() returns None, so should use client - with mock_client_context(mixin, mock_client_with_models): - result = await mixin.list_models() - - assert result is not None - assert len(result) == 3 - mock_client_with_models.models.list.assert_called_once() - - async def test_custom_get_models_creates_proper_model_objects(self): - """Test that custom get_models() model IDs are converted to proper Model objects""" - model_ids = ["gpt-4", "gpt-3.5-turbo"] - mixin = OpenAIMixinWithCustomGetModels(custom_model_ids=model_ids) - - result = await mixin.list_models() - - assert result is not None - assert len(result) == 2 - - for model in result: - assert isinstance(model, Model) - assert model.provider_id == "test-provider" - assert model.identifier in model_ids - assert model.provider_resource_id in model_ids - assert model.model_type == ModelType.llm - - async def test_custom_get_models_bypasses_client(self, mock_client_context): - """Test that providing get_models() means client.models.list() is NOT called""" - mixin = OpenAIMixinWithCustomGetModels(custom_model_ids=["model-1", "model-2"]) - - # Create a mock client that should NOT be called - mock_client = MagicMock() - mock_client.models.list = MagicMock(side_effect=AssertionError("client.models.list should not be called!")) - - with mock_client_context(mixin, mock_client): - result = await mixin.list_models() - - # Should succeed without calling client.models.list - assert result is not None - assert len(result) == 2 - mock_client.models.list.assert_not_called() - - async def test_get_models_wrong_type_raises_error(self): - """Test that get_models() returning unhashable items results in an error""" - - class BadGetModelsAdapter(OpenAIMixinImpl): - async def get_models(self) -> Iterable[str] | None: - # Return list with unhashable items - return [["nested", "list"], {"key": "value"}] # type: ignore - - mixin = BadGetModelsAdapter() - - # Should raise TypeError when trying to use unhashable items as dict keys with pytest.raises(TypeError, match="unhashable type"): await mixin.list_models() - async def test_get_models_non_iterable_raises_error(self): - """Test that get_models() returning non-iterable type raises error""" + async def test_non_iterable_raises_error(self): + """Test that list_provider_model_ids() returning non-iterable type raises error""" + mixin = CustomListProviderModelIdsImplementation(custom_model_ids=42) - class NonIterableGetModelsAdapter(OpenAIMixinImpl): - async def get_models(self) -> Iterable[str] | None: - # Return non-iterable type - return 42 # type: ignore - - mixin = NonIterableGetModelsAdapter() - - # Should raise TypeError with custom error message with pytest.raises( TypeError, - match=r"Failed to list models: NonIterableGetModelsAdapter\.get_models\(\) must return an iterable.*but returned int", + match=r"Failed to list models: CustomListProviderModelIdsImplementation\.list_provider_model_ids\(\) must return an iterable.*but returned int", ): await mixin.list_models() - async def test_get_models_with_none_items_raises_error(self): - """Test that get_models() returning list with None items causes error""" + async def test_with_none_items_raises_error(self): + """Test that list_provider_model_ids() returning list with None items causes error""" + mixin = CustomListProviderModelIdsImplementation(custom_model_ids=[None, "valid-model", None]) - class NoneItemsAdapter(OpenAIMixinImpl): - async def get_models(self) -> Iterable[str] | None: - # Return list with None items - return [None, "valid-model", None] # type: ignore - - mixin = NoneItemsAdapter() - - # Should raise ValidationError when creating Model with None identifier with pytest.raises(Exception, match="Input should be a valid string"): await mixin.list_models() - async def test_embedding_models_from_custom_get_models_have_correct_type(self, mixin_with_custom_get_models): - """Test that embedding models from custom get_models() are properly typed as embedding""" - result = await mixin_with_custom_get_models.list_models() + async def test_accepts_various_iterables(self): + """Test that list_provider_model_ids() accepts tuples, sets, generators, etc.""" - # Verify we have both LLM and embedding models - llm_models = [m for m in result if m.model_type == ModelType.llm] - embedding_models = [m for m in result if m.model_type == ModelType.embedding] - - assert len(llm_models) == 2 - assert len(embedding_models) == 1 - assert embedding_models[0].identifier == "custom-embedding" - - async def test_llm_models_from_custom_get_models_have_correct_type(self): - """Test that LLM models from custom get_models() are properly typed as llm""" - mixin = OpenAIMixinWithCustomGetModels(custom_model_ids=["gpt-4", "claude-3"]) - - result = await mixin.list_models() - - assert result is not None - assert len(result) == 2 - for model in result: - assert model.model_type == ModelType.llm - - async def test_get_models_accepts_various_iterables(self): - """Test that get_models() accepts tuples, sets, generators, etc.""" - - # Test with tuple - class TupleGetModelsAdapter(OpenAIMixinImpl): - async def get_models(self) -> Iterable[str] | None: + class TupleAdapter(OpenAIMixinImpl): + async def list_provider_model_ids(self) -> Iterable[str] | None: return ("model-1", "model-2", "model-3") - mixin = TupleGetModelsAdapter() + mixin = TupleAdapter() result = await mixin.list_models() assert result is not None assert len(result) == 3 - # Test with generator - class GeneratorGetModelsAdapter(OpenAIMixinImpl): - async def get_models(self) -> Iterable[str] | None: + class GeneratorAdapter(OpenAIMixinImpl): + async def list_provider_model_ids(self) -> Iterable[str] | None: def gen(): yield "gen-model-1" yield "gen-model-2" return gen() - mixin = GeneratorGetModelsAdapter() + mixin = GeneratorAdapter() result = await mixin.list_models() assert result is not None assert len(result) == 2 - # Test with set (order may vary) - class SetGetModelsAdapter(OpenAIMixinImpl): - async def get_models(self) -> Iterable[str] | None: + class SetAdapter(OpenAIMixinImpl): + async def list_provider_model_ids(self) -> Iterable[str] | None: return {"set-model-1", "set-model-2"} - mixin = SetGetModelsAdapter() + mixin = SetAdapter() result = await mixin.list_models() assert result is not None assert len(result) == 2