From c465472e42c9f16b75136ec83131b1429a583c62 Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Sat, 4 Oct 2025 08:34:23 -0400 Subject: [PATCH] chore: give OpenAIMixin subcalsses a change to list models without leaking _model_cache details --- .../remote/inference/databricks/databricks.py | 33 +-- .../providers/utils/inference/openai_mixin.py | 46 +++- .../utils/inference/test_openai_mixin.py | 243 ++++++++++++++++++ 3 files changed, 286 insertions(+), 36 deletions(-) diff --git a/llama_stack/providers/remote/inference/databricks/databricks.py b/llama_stack/providers/remote/inference/databricks/databricks.py index a2621b81e..6f7e1ba97 100644 --- a/llama_stack/providers/remote/inference/databricks/databricks.py +++ b/llama_stack/providers/remote/inference/databricks/databricks.py @@ -13,7 +13,6 @@ from llama_stack.apis.inference import ( Model, OpenAICompletion, ) -from llama_stack.apis.models import ModelType from llama_stack.log import get_logger from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin @@ -72,31 +71,13 @@ class DatabricksInferenceAdapter( ) -> OpenAICompletion: raise NotImplementedError() - async def list_models(self) -> list[Model] | None: - self._model_cache = {} # from OpenAIMixin - ws_client = WorkspaceClient(host=self.config.url, token=self.get_api_key()) # TODO: this is not async - endpoints = ws_client.serving_endpoints.list() - for endpoint in endpoints: - model = Model( - provider_id=self.__provider_id__, - provider_resource_id=endpoint.name, - identifier=endpoint.name, - ) - if endpoint.task == "llm/v1/chat": - model.model_type = ModelType.llm # this is redundant, but informative - elif endpoint.task == "llm/v1/embeddings": - if endpoint.name not in self.embedding_model_metadata: - logger.warning(f"No metadata information available for embedding model {endpoint.name}, skipping.") - continue - model.model_type = ModelType.embedding - model.metadata = self.embedding_model_metadata[endpoint.name] - else: - logger.warning(f"Unknown model type, skipping: {endpoint}") - continue - - self._model_cache[endpoint.name] = model - - return list(self._model_cache.values()) + async def get_models(self) -> list[Model] | None: + return [ + endpoint.name + for endpoint in WorkspaceClient( + host=self.config.url, token=self.get_api_key() + ).serving_endpoints.list() # TODO: this is not async + ] async def should_refresh_models(self) -> bool: return False diff --git a/llama_stack/providers/utils/inference/openai_mixin.py b/llama_stack/providers/utils/inference/openai_mixin.py index 4354b067e..b8f66cda3 100644 --- a/llama_stack/providers/utils/inference/openai_mixin.py +++ b/llama_stack/providers/utils/inference/openai_mixin.py @@ -7,7 +7,7 @@ import base64 import uuid from abc import ABC, abstractmethod -from collections.abc import AsyncIterator +from collections.abc import AsyncIterator, Iterable from typing import Any from openai import NOT_GIVEN, AsyncOpenAI @@ -111,6 +111,18 @@ class OpenAIMixin(ModelsProtocolPrivate, NeedsRequestProviderData, ABC): """ return {} + async def get_models(self) -> Iterable[str] | None: + """ + List available models from the provider. + + Child classes can override this method to provide a custom implementation + for listing models. The default implementation uses the AsyncOpenAI client + to list models from the OpenAI-compatible endpoint. + + :return: An iterable of model IDs or None if not implemented + """ + return None + @property def client(self) -> AsyncOpenAI: """ @@ -387,16 +399,30 @@ class OpenAIMixin(ModelsProtocolPrivate, NeedsRequestProviderData, ABC): """ self._model_cache = {} - async for m in self.client.models.list(): - if self.allowed_models and m.id not in self.allowed_models: - logger.info(f"Skipping model {m.id} as it is not in the allowed models list") + # give subclasses a chance to provide custom model listing + if (iterable := await self.get_models()) is not None: + if not hasattr(iterable, "__iter__"): + raise TypeError( + f"Failed to list models: {self.__class__.__name__}.get_models() must return an iterable of " + f"strings or None, but returned {type(iterable).__name__}" + ) + models_ids = list(iterable) + logger.info( + f"Using {self.__class__.__name__}.get_models() implementation, received {len(models_ids)} models" + ) + else: + models_ids = [m.id async for m in self.client.models.list()] + + for m_id in models_ids: + if self.allowed_models and m_id not in self.allowed_models: + logger.info(f"Skipping model {m_id} as it is not in the allowed models list") continue - if metadata := self.embedding_model_metadata.get(m.id): + if metadata := self.embedding_model_metadata.get(m_id): # This is an embedding model - augment with metadata model = Model( provider_id=self.__provider_id__, # type: ignore[attr-defined] - provider_resource_id=m.id, - identifier=m.id, + provider_resource_id=m_id, + identifier=m_id, model_type=ModelType.embedding, metadata=metadata, ) @@ -404,11 +430,11 @@ class OpenAIMixin(ModelsProtocolPrivate, NeedsRequestProviderData, ABC): # This is an LLM model = Model( provider_id=self.__provider_id__, # type: ignore[attr-defined] - provider_resource_id=m.id, - identifier=m.id, + provider_resource_id=m_id, + identifier=m_id, model_type=ModelType.llm, ) - self._model_cache[m.id] = model + self._model_cache[m_id] = model return list(self._model_cache.values()) diff --git a/tests/unit/providers/utils/inference/test_openai_mixin.py b/tests/unit/providers/utils/inference/test_openai_mixin.py index 4856f510b..9991a49cd 100644 --- a/tests/unit/providers/utils/inference/test_openai_mixin.py +++ b/tests/unit/providers/utils/inference/test_openai_mixin.py @@ -5,6 +5,7 @@ # the root directory of this source tree. import json +from collections.abc import Iterable from unittest.mock import AsyncMock, MagicMock, Mock, PropertyMock, patch import pytest @@ -498,6 +499,248 @@ class OpenAIMixinWithProviderData(OpenAIMixinImpl): return "default-base-url" +class OpenAIMixinWithCustomGetModels(OpenAIMixinImpl): + """Test implementation with custom get_models override""" + + def __init__(self, custom_model_ids): + self._custom_model_ids = custom_model_ids + + async def get_models(self) -> Iterable[str] | None: + """Return custom model IDs list""" + return self._custom_model_ids + + +class TestOpenAIMixinCustomGetModels: + """Test cases for custom get_models() implementation functionality""" + + @pytest.fixture + def custom_model_ids_list(self): + """Create a list of custom model ID strings""" + return ["custom-model-1", "custom-model-2", "custom-embedding"] + + @pytest.fixture + def mixin_with_custom_get_models(self, custom_model_ids_list): + """Create mixin instance with custom get_models implementation""" + mixin = OpenAIMixinWithCustomGetModels(custom_model_ids=custom_model_ids_list) + # Add embedding metadata to test that feature still works + mixin.embedding_model_metadata = {"custom-embedding": {"embedding_dimension": 768, "context_length": 512}} + return mixin + + async def test_custom_get_models_is_used(self, mixin_with_custom_get_models, custom_model_ids_list): + """Test that custom get_models() implementation is used instead of client.models.list()""" + result = await mixin_with_custom_get_models.list_models() + + assert result is not None + assert len(result) == 3 + + # Verify all custom models are present + identifiers = {m.identifier for m in result} + assert "custom-model-1" in identifiers + assert "custom-model-2" in identifiers + assert "custom-embedding" in identifiers + + async def test_custom_get_models_populates_cache(self, mixin_with_custom_get_models): + """Test that custom get_models() results are cached""" + assert len(mixin_with_custom_get_models._model_cache) == 0 + + await mixin_with_custom_get_models.list_models() + + assert len(mixin_with_custom_get_models._model_cache) == 3 + assert "custom-model-1" in mixin_with_custom_get_models._model_cache + assert "custom-model-2" in mixin_with_custom_get_models._model_cache + assert "custom-embedding" in mixin_with_custom_get_models._model_cache + + async def test_custom_get_models_respects_allowed_models(self): + """Test that custom get_models() respects allowed_models filtering""" + mixin = OpenAIMixinWithCustomGetModels(custom_model_ids=["model-1", "model-2", "model-3"]) + mixin.allowed_models = ["model-1"] + + result = await mixin.list_models() + + assert result is not None + assert len(result) == 1 + assert result[0].identifier == "model-1" + + async def test_custom_get_models_with_embedding_metadata(self, mixin_with_custom_get_models): + """Test that custom get_models() works with embedding_model_metadata""" + result = await mixin_with_custom_get_models.list_models() + + # Find the embedding model + embedding_model = next((m for m in result if m.identifier == "custom-embedding"), None) + assert embedding_model is not None + assert embedding_model.model_type == ModelType.embedding + assert embedding_model.metadata == {"embedding_dimension": 768, "context_length": 512} + + # Verify LLM models + llm_models = [m for m in result if m.model_type == ModelType.llm] + assert len(llm_models) == 2 + + async def test_custom_get_models_with_empty_list(self): + """Test that custom get_models() handles empty list correctly""" + mixin = OpenAIMixinWithCustomGetModels(custom_model_ids=[]) + + result = await mixin.list_models() + + assert result is not None + assert len(result) == 0 + assert len(mixin._model_cache) == 0 + + async def test_default_get_models_returns_none(self, mixin): + """Test that default get_models() implementation returns None""" + custom_models = await mixin.get_models() + assert custom_models is None + + async def test_fallback_to_client_when_get_models_returns_none( + self, mixin, mock_client_with_models, mock_client_context + ): + """Test that when get_models() returns None, falls back to client.models.list()""" + # Default get_models() returns None, so should use client + with mock_client_context(mixin, mock_client_with_models): + result = await mixin.list_models() + + assert result is not None + assert len(result) == 3 + mock_client_with_models.models.list.assert_called_once() + + async def test_custom_get_models_creates_proper_model_objects(self): + """Test that custom get_models() model IDs are converted to proper Model objects""" + model_ids = ["gpt-4", "gpt-3.5-turbo"] + mixin = OpenAIMixinWithCustomGetModels(custom_model_ids=model_ids) + + result = await mixin.list_models() + + assert result is not None + assert len(result) == 2 + + for model in result: + assert isinstance(model, Model) + assert model.provider_id == "test-provider" + assert model.identifier in model_ids + assert model.provider_resource_id in model_ids + assert model.model_type == ModelType.llm + + async def test_custom_get_models_bypasses_client(self, mock_client_context): + """Test that providing get_models() means client.models.list() is NOT called""" + mixin = OpenAIMixinWithCustomGetModels(custom_model_ids=["model-1", "model-2"]) + + # Create a mock client that should NOT be called + mock_client = MagicMock() + mock_client.models.list = MagicMock(side_effect=AssertionError("client.models.list should not be called!")) + + with mock_client_context(mixin, mock_client): + result = await mixin.list_models() + + # Should succeed without calling client.models.list + assert result is not None + assert len(result) == 2 + mock_client.models.list.assert_not_called() + + async def test_get_models_wrong_type_raises_error(self): + """Test that get_models() returning unhashable items results in an error""" + + class BadGetModelsAdapter(OpenAIMixinImpl): + async def get_models(self) -> Iterable[str] | None: + # Return list with unhashable items + return [["nested", "list"], {"key": "value"}] # type: ignore + + mixin = BadGetModelsAdapter() + + # Should raise TypeError when trying to use unhashable items as dict keys + with pytest.raises(TypeError, match="unhashable type"): + await mixin.list_models() + + async def test_get_models_non_iterable_raises_error(self): + """Test that get_models() returning non-iterable type raises error""" + + class NonIterableGetModelsAdapter(OpenAIMixinImpl): + async def get_models(self) -> Iterable[str] | None: + # Return non-iterable type + return 42 # type: ignore + + mixin = NonIterableGetModelsAdapter() + + # Should raise TypeError with custom error message + with pytest.raises( + TypeError, + match=r"Failed to list models: NonIterableGetModelsAdapter\.get_models\(\) must return an iterable.*but returned int", + ): + await mixin.list_models() + + async def test_get_models_with_none_items_raises_error(self): + """Test that get_models() returning list with None items causes error""" + + class NoneItemsAdapter(OpenAIMixinImpl): + async def get_models(self) -> Iterable[str] | None: + # Return list with None items + return [None, "valid-model", None] # type: ignore + + mixin = NoneItemsAdapter() + + # Should raise ValidationError when creating Model with None identifier + with pytest.raises(Exception, match="Input should be a valid string"): + await mixin.list_models() + + async def test_embedding_models_from_custom_get_models_have_correct_type(self, mixin_with_custom_get_models): + """Test that embedding models from custom get_models() are properly typed as embedding""" + result = await mixin_with_custom_get_models.list_models() + + # Verify we have both LLM and embedding models + llm_models = [m for m in result if m.model_type == ModelType.llm] + embedding_models = [m for m in result if m.model_type == ModelType.embedding] + + assert len(llm_models) == 2 + assert len(embedding_models) == 1 + assert embedding_models[0].identifier == "custom-embedding" + + async def test_llm_models_from_custom_get_models_have_correct_type(self): + """Test that LLM models from custom get_models() are properly typed as llm""" + mixin = OpenAIMixinWithCustomGetModels(custom_model_ids=["gpt-4", "claude-3"]) + + result = await mixin.list_models() + + assert result is not None + assert len(result) == 2 + for model in result: + assert model.model_type == ModelType.llm + + async def test_get_models_accepts_various_iterables(self): + """Test that get_models() accepts tuples, sets, generators, etc.""" + + # Test with tuple + class TupleGetModelsAdapter(OpenAIMixinImpl): + async def get_models(self) -> Iterable[str] | None: + return ("model-1", "model-2", "model-3") + + mixin = TupleGetModelsAdapter() + result = await mixin.list_models() + assert result is not None + assert len(result) == 3 + + # Test with generator + class GeneratorGetModelsAdapter(OpenAIMixinImpl): + async def get_models(self) -> Iterable[str] | None: + def gen(): + yield "gen-model-1" + yield "gen-model-2" + + return gen() + + mixin = GeneratorGetModelsAdapter() + result = await mixin.list_models() + assert result is not None + assert len(result) == 2 + + # Test with set (order may vary) + class SetGetModelsAdapter(OpenAIMixinImpl): + async def get_models(self) -> Iterable[str] | None: + return {"set-model-1", "set-model-2"} + + mixin = SetGetModelsAdapter() + result = await mixin.list_models() + assert result is not None + assert len(result) == 2 + + class TestOpenAIMixinProviderDataApiKey: """Test cases for provider_data_api_key_field functionality"""