diff --git a/llama_stack/providers/remote/inference/databricks/databricks.py b/llama_stack/providers/remote/inference/databricks/databricks.py
index 70d6bb278..3b110f21a 100644
--- a/llama_stack/providers/remote/inference/databricks/databricks.py
+++ b/llama_stack/providers/remote/inference/databricks/databricks.py
@@ -4,13 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from collections.abc import Iterable
 from typing import Any
 
 from databricks.sdk import WorkspaceClient
 
-from llama_stack.apis.inference import (
-    OpenAICompletion,
-)
+from llama_stack.apis.inference import OpenAICompletion
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
 
@@ -69,3 +68,11 @@ class DatabricksInferenceAdapter(OpenAIMixin):
         suffix: str | None = None,
     ) -> OpenAICompletion:
         raise NotImplementedError()
+
+    async def list_provider_model_ids(self) -> Iterable[str]:
+        return [
+            endpoint.name
+            for endpoint in WorkspaceClient(
+                host=self.config.url, token=self.get_api_key()
+            ).serving_endpoints.list()  # TODO: this is not async
+        ]
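
The override above issues a blocking SDK call from an async method, which is what the TODO flags. A minimal sketch of one way to address it, assuming the Databricks SDK call remains synchronous; `asyncio.to_thread` is standard library, and the rest mirrors the adapter code in the hunk above:

    import asyncio

    async def list_provider_model_ids(self) -> Iterable[str]:
        def _list_sync() -> list[str]:
            # Blocking Databricks SDK calls run on a worker thread,
            # keeping the event loop free.
            client = WorkspaceClient(host=self.config.url, token=self.get_api_key())
            return [endpoint.name for endpoint in client.serving_endpoints.list()]

        return await asyncio.to_thread(_list_sync)
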
diff --git a/llama_stack/providers/utils/inference/openai_mixin.py b/llama_stack/providers/utils/inference/openai_mixin.py
index 06eba09f4..acca73800 100644
--- a/llama_stack/providers/utils/inference/openai_mixin.py
+++ b/llama_stack/providers/utils/inference/openai_mixin.py
@@ -48,7 +48,7 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
     - download_images: If True, downloads images and converts to base64 for providers that require it
     - embedding_model_metadata: A dictionary mapping model IDs to their embedding metadata
     - provider_data_api_key_field: Optional field name in provider data to look for API key
-    - get_models: Method to list available models from the provider
+    - list_provider_model_ids: Method to list available models from the provider
     - get_extra_client_params: Method to provide extra parameters to the AsyncOpenAI client
 
     Expected Dependencies:
@@ -122,7 +122,7 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
         """
         return {}
 
-    async def get_models(self) -> Iterable[str] | None:
+    async def list_provider_model_ids(self) -> Iterable[str]:
         """
         List available models from the provider.
 
@@ -132,7 +132,7 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
 
-        :return: An iterable of model IDs or None if not implemented
+        :return: An iterable of model IDs
         """
-        return None
+        return [m.id async for m in self.client.models.list()]
 
     async def initialize(self) -> None:
         """
@@ -430,46 +430,42 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
         """
         self._model_cache = {}
 
-        # give subclasses a chance to provide custom model listing
-        models_ids = []
         try:
-            if (iterable := await self.get_models()) is not None:  # TODO: handle exceptions from get_models
-                models_ids = list(iterable)
-                logger.info(
-                    f"Using {self.__class__.__name__}.get_models() implementation, received {len(models_ids)} models"
-                )
-                for id_ in models_ids:
-                    if not isinstance(id_, str):
-                        raise ValueError(f"Model ID {id_} from get_models() is not a string")
+            iterable = await self.list_provider_model_ids()
         except Exception as e:
-            logger.error(f"{self.__class__.__name__}.get_models() failed with: {e}")
+            logger.error(f"{self.__class__.__name__}.list_provider_model_ids() failed with: {e}")
             raise
+        if not hasattr(iterable, "__iter__"):
+            raise TypeError(
+                f"Failed to list models: {self.__class__.__name__}.list_provider_model_ids() must return an iterable of "
+                f"strings, but returned {type(iterable).__name__}"
+            )
 
-        if not models_ids:
-            models_ids = [m.id async for m in self.client.models.list()]
+        provider_model_ids = list(iterable)
+        logger.info(f"{self.__class__.__name__}.list_provider_model_ids() returned {len(provider_model_ids)} models")
 
-        for m_id in models_ids:
-            if self.allowed_models and m_id not in self.allowed_models:
-                logger.info(f"Skipping model {m_id} as it is not in the allowed models list")
+        for provider_model_id in provider_model_ids:
+            if not isinstance(provider_model_id, str):
+                raise ValueError(f"Model ID {provider_model_id} from list_provider_model_ids() is not a string")
+            if self.allowed_models and provider_model_id not in self.allowed_models:
+                logger.info(f"Skipping model {provider_model_id} as it is not in the allowed models list")
                 continue
 
-            if metadata := self.embedding_model_metadata.get(m_id):
-                # This is an embedding model - augment with metadata
+            if metadata := self.embedding_model_metadata.get(provider_model_id):
                 model = Model(
                     provider_id=self.__provider_id__,  # type: ignore[attr-defined]
-                    provider_resource_id=m_id,
-                    identifier=m_id,
+                    provider_resource_id=provider_model_id,
+                    identifier=provider_model_id,
                     model_type=ModelType.embedding,
                     metadata=metadata,
                 )
             else:
-                # This is an LLM
                 model = Model(
                     provider_id=self.__provider_id__,  # type: ignore[attr-defined]
-                    provider_resource_id=m_id,
-                    identifier=m_id,
+                    provider_resource_id=provider_model_id,
+                    identifier=provider_model_id,
                     model_type=ModelType.llm,
                 )
-            self._model_cache[m_id] = model
+            self._model_cache[provider_model_id] = model
 
         return list(self._model_cache.values())
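
For adapter authors, the contract after this change is: override `list_provider_model_ids()` to return any iterable of model ID strings, and let `list_models()` handle type validation, `allowed_models` filtering, and `Model` construction. A minimal sketch of a hypothetical adapter (the class name and returned IDs are illustrative, not part of this patch):

    from collections.abc import Iterable

    from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin


    class ExampleAdapter(OpenAIMixin):
        async def list_provider_model_ids(self) -> Iterable[str]:
            # Any iterable of strings is acceptable: list, tuple, set, or
            # generator. A non-iterable return raises TypeError in
            # list_models(); non-string items raise ValueError.
            return ("model-a", "model-b")

Adapters that do not override the method inherit the new default, which enumerates the provider's `client.models.list()` endpoint.
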
override""" +class CustomListProviderModelIdsImplementation(OpenAIMixinImpl): + """Test implementation with custom list_provider_model_ids override""" - def __init__(self, config, custom_model_ids): - super().__init__(config=config) - self._custom_model_ids = custom_model_ids + custom_model_ids: Any - async def get_models(self) -> Iterable[str] | None: + async def list_provider_model_ids(self) -> Iterable[str]: """Return custom model IDs list""" - return self._custom_model_ids + return self.custom_model_ids -class TestOpenAIMixinCustomGetModels: - """Test cases for custom get_models() implementation functionality""" +class TestOpenAIMixinCustomListProviderModelIds: + """Test cases for custom list_provider_model_ids() implementation functionality""" @pytest.fixture def custom_model_ids_list(self): @@ -523,42 +522,39 @@ class TestOpenAIMixinCustomGetModels: return ["custom-model-1", "custom-model-2", "custom-embedding"] @pytest.fixture - def mixin_with_custom_get_models(self, custom_model_ids_list): - """Create mixin instance with custom get_models implementation""" - config = RemoteInferenceProviderConfig() - mixin = OpenAIMixinWithCustomGetModels(config=config, custom_model_ids=custom_model_ids_list) - # Add embedding metadata to test that feature still works + def config(self): + """Create RemoteInferenceProviderConfig instance""" + return RemoteInferenceProviderConfig() + + @pytest.fixture + def adapter(self, custom_model_ids_list, config): + """Create mixin instance with custom list_provider_model_ids implementation""" + mixin = CustomListProviderModelIdsImplementation(config=config, custom_model_ids=custom_model_ids_list) mixin.embedding_model_metadata = {"custom-embedding": {"embedding_dimension": 768, "context_length": 512}} return mixin - async def test_custom_get_models_is_used(self, mixin_with_custom_get_models, custom_model_ids_list): - """Test that custom get_models() implementation is used instead of client.models.list()""" - result = await mixin_with_custom_get_models.list_models() + async def test_is_used(self, adapter, custom_model_ids_list): + """Test that custom list_provider_model_ids() implementation is used instead of client.models.list()""" + result = await adapter.list_models() assert result is not None assert len(result) == 3 - # Verify all custom models are present - identifiers = {m.identifier for m in result} - assert "custom-model-1" in identifiers - assert "custom-model-2" in identifiers - assert "custom-embedding" in identifiers + assert set(custom_model_ids_list) == {m.identifier for m in result} - async def test_custom_get_models_populates_cache(self, mixin_with_custom_get_models): - """Test that custom get_models() results are cached""" - assert len(mixin_with_custom_get_models._model_cache) == 0 + async def test_populates_cache(self, adapter, custom_model_ids_list): + """Test that custom list_provider_model_ids() results are cached""" + assert len(adapter._model_cache) == 0 - await mixin_with_custom_get_models.list_models() + await adapter.list_models() - assert len(mixin_with_custom_get_models._model_cache) == 3 - assert "custom-model-1" in mixin_with_custom_get_models._model_cache - assert "custom-model-2" in mixin_with_custom_get_models._model_cache - assert "custom-embedding" in mixin_with_custom_get_models._model_cache + assert set(custom_model_ids_list) == set(adapter._model_cache.keys()) - async def test_custom_get_models_respects_allowed_models(self): - """Test that custom get_models() respects allowed_models filtering""" - config = 
-        mixin = OpenAIMixinWithCustomGetModels(config=config, custom_model_ids=["model-1", "model-2", "model-3"])
+    async def test_respects_allowed_models(self, config):
+        """Test that custom list_provider_model_ids() respects allowed_models filtering"""
+        mixin = CustomListProviderModelIdsImplementation(
+            config=config, custom_model_ids=["model-1", "model-2", "model-3"]
+        )
         mixin.allowed_models = ["model-1"]
 
         result = await mixin.list_models()
@@ -567,222 +563,76 @@ class TestOpenAIMixinCustomGetModels:
         assert len(result) == 1
         assert result[0].identifier == "model-1"
 
-    async def test_custom_get_models_with_embedding_metadata(self, mixin_with_custom_get_models):
-        """Test that custom get_models() works with embedding_model_metadata"""
-        result = await mixin_with_custom_get_models.list_models()
-
-        # Find the embedding model
-        embedding_model = next((m for m in result if m.identifier == "custom-embedding"), None)
-        assert embedding_model is not None
-        assert embedding_model.model_type == ModelType.embedding
-        assert embedding_model.metadata == {"embedding_dimension": 768, "context_length": 512}
-
-        # Verify LLM models
-        llm_models = [m for m in result if m.model_type == ModelType.llm]
-        assert len(llm_models) == 2
-
-    async def test_custom_get_models_with_empty_list(self, mock_client_with_empty_models, mock_client_context):
-        """Test that custom get_models() handles empty list correctly"""
-        config = RemoteInferenceProviderConfig()
-        mixin = OpenAIMixinWithCustomGetModels(config=config, custom_model_ids=[])
-
-        # Empty list from get_models() falls back to client.models.list()
-        with mock_client_context(mixin, mock_client_with_empty_models):
-            result = await mixin.list_models()
-
-        assert result is not None
-        assert len(result) == 0
-        assert len(mixin._model_cache) == 0
-
-    async def test_default_get_models_returns_none(self, mixin):
-        """Test that default get_models() implementation returns None"""
-        custom_models = await mixin.get_models()
-        assert custom_models is None
-
-    async def test_fallback_to_client_when_get_models_returns_none(
-        self, mixin, mock_client_with_models, mock_client_context
-    ):
-        """Test that when get_models() returns None, falls back to client.models.list()"""
-        # Default get_models() returns None, so should use client
-        with mock_client_context(mixin, mock_client_with_models):
-            result = await mixin.list_models()
-
-        assert result is not None
-        assert len(result) == 3
-        mock_client_with_models.models.list.assert_called_once()
-
-    async def test_custom_get_models_creates_proper_model_objects(self):
-        """Test that custom get_models() model IDs are converted to proper Model objects"""
-        config = RemoteInferenceProviderConfig()
-        model_ids = ["gpt-4", "gpt-3.5-turbo"]
-        mixin = OpenAIMixinWithCustomGetModels(config=config, custom_model_ids=model_ids)
+    async def test_with_empty_list(self, config):
+        """Test that custom list_provider_model_ids() handles empty list correctly"""
+        mixin = CustomListProviderModelIdsImplementation(config=config, custom_model_ids=[])
 
         result = await mixin.list_models()
 
         assert result is not None
-        assert len(result) == 2
+        assert len(result) == 0
+        assert len(mixin._model_cache) == 0
 
-        for model in result:
-            assert isinstance(model, Model)
-            assert model.provider_id == "test-provider"
-            assert model.identifier in model_ids
-            assert model.provider_resource_id in model_ids
-            assert model.model_type == ModelType.llm
-
-    async def test_custom_get_models_bypasses_client(self, mock_client_context):
-        """Test that providing get_models() means client.models.list() is NOT called"""
-        config = RemoteInferenceProviderConfig()
-        mixin = OpenAIMixinWithCustomGetModels(config=config, custom_model_ids=["model-1", "model-2"])
-
-        # Create a mock client that should NOT be called
-        mock_client = MagicMock()
-        mock_client.models.list = MagicMock(side_effect=AssertionError("client.models.list should not be called!"))
-
-        with mock_client_context(mixin, mock_client):
-            result = await mixin.list_models()
-
-        # Should succeed without calling client.models.list
-        assert result is not None
-        assert len(result) == 2
-        mock_client.models.list.assert_not_called()
-
-    async def test_get_models_wrong_type_raises_error(self):
-        """Test that get_models() returning non-string items results in an error"""
-
-        class BadGetModelsAdapter(OpenAIMixinImpl):
-            async def get_models(self) -> Iterable[str] | None:
-                # Return list with non-string items
-                return [["nested", "list"], {"key": "value"}]  # type: ignore
-
-        config = RemoteInferenceProviderConfig()
-        mixin = BadGetModelsAdapter(config=config)
-
-        # Should raise ValueError for non-string model ID
-        with pytest.raises(ValueError, match="Model ID .* from get_models\\(\\) is not a string"):
+    async def test_wrong_type_raises_error(self, config):
+        """Test that list_provider_model_ids() returning non-string items results in an error"""
+        mixin = CustomListProviderModelIdsImplementation(
+            config=config, custom_model_ids=["valid-model", ["nested", "list"]]
+        )
+        with pytest.raises(Exception, match="is not a string"):
            await mixin.list_models()
 
-    async def test_get_models_non_iterable_raises_error(self):
-        """Test that get_models() returning non-iterable type raises error"""
-
-        class NonIterableGetModelsAdapter(OpenAIMixinImpl):
-            async def get_models(self) -> Iterable[str] | None:
-                # Return non-iterable type
-                return 42  # type: ignore
-
-        config = RemoteInferenceProviderConfig()
-        mixin = NonIterableGetModelsAdapter(config=config)
-
-        # Should raise TypeError when trying to convert to list
-        with pytest.raises(TypeError, match="'int' object is not iterable"):
+        mixin = CustomListProviderModelIdsImplementation(
+            config=config, custom_model_ids=[{"key": "value"}, "valid-model"]
+        )
+        with pytest.raises(Exception, match="is not a string"):
             await mixin.list_models()
 
-    async def test_get_models_with_none_items_raises_error(self):
-        """Test that get_models() returning list with None items causes error"""
-
-        class NoneItemsAdapter(OpenAIMixinImpl):
-            async def get_models(self) -> Iterable[str] | None:
-                # Return list with None items
-                return [None, "valid-model", None]  # type: ignore
-
-        config = RemoteInferenceProviderConfig()
-        mixin = NoneItemsAdapter(config=config)
-
-        # Should raise ValueError for non-string model ID
-        with pytest.raises(ValueError, match="Model ID .* from get_models\\(\\) is not a string"):
+        mixin = CustomListProviderModelIdsImplementation(config=config, custom_model_ids=["valid-model", 42.0])
+        with pytest.raises(Exception, match="is not a string"):
             await mixin.list_models()
 
-    async def test_get_models_with_non_string_items_raises_error(self):
-        """Test that get_models() returning non-string items raises ValueError"""
-
-        class NonStringItemsAdapter(OpenAIMixinImpl):
-            async def get_models(self) -> Iterable[str] | None:
-                # Return list with non-string items (integers)
-                return ["valid-model", 123, "another-model"]  # type: ignore
-
-        config = RemoteInferenceProviderConfig()
-        mixin = NonStringItemsAdapter(config=config)
-
-        # Should raise ValueError for non-string model ID
-        with pytest.raises(ValueError, match="Model ID 123 from get_models\\(\\) is not a string"):
+        mixin = CustomListProviderModelIdsImplementation(config=config, custom_model_ids=[None])
+        with pytest.raises(Exception, match="is not a string"):
             await mixin.list_models()
 
-    async def test_embedding_models_from_custom_get_models_have_correct_type(self, mixin_with_custom_get_models):
-        """Test that embedding models from custom get_models() are properly typed as embedding"""
-        result = await mixin_with_custom_get_models.list_models()
+    async def test_non_iterable_raises_error(self, config):
+        """Test that list_provider_model_ids() returning non-iterable type raises error"""
+        mixin = CustomListProviderModelIdsImplementation(config=config, custom_model_ids=42)
 
-        # Verify we have both LLM and embedding models
-        llm_models = [m for m in result if m.model_type == ModelType.llm]
-        embedding_models = [m for m in result if m.model_type == ModelType.embedding]
+        with pytest.raises(
+            TypeError,
+            match=r"Failed to list models: CustomListProviderModelIdsImplementation\.list_provider_model_ids\(\) must return an iterable.*but returned int",
+        ):
+            await mixin.list_models()
 
-        assert len(llm_models) == 2
-        assert len(embedding_models) == 1
-        assert embedding_models[0].identifier == "custom-embedding"
+    async def test_accepts_various_iterables(self, config):
+        """Test that list_provider_model_ids() accepts tuples, sets, generators, etc."""
 
-    async def test_llm_models_from_custom_get_models_have_correct_type(self):
-        """Test that LLM models from custom get_models() are properly typed as llm"""
-        config = RemoteInferenceProviderConfig()
-        mixin = OpenAIMixinWithCustomGetModels(config=config, custom_model_ids=["gpt-4", "claude-3"])
-
-        result = await mixin.list_models()
-
-        assert result is not None
-        assert len(result) == 2
-        for model in result:
-            assert model.model_type == ModelType.llm
-
-    async def test_get_models_accepts_various_iterables(self):
-        """Test that get_models() accepts tuples, sets, generators, etc."""
-
-        # Test with tuple
-        class TupleGetModelsAdapter(OpenAIMixinImpl):
-            async def get_models(self) -> Iterable[str] | None:
-                return ("model-1", "model-2", "model-3")
-
-        config = RemoteInferenceProviderConfig()
-        mixin = TupleGetModelsAdapter(config=config)
-        result = await mixin.list_models()
+        tuples = CustomListProviderModelIdsImplementation(
+            config=config, custom_model_ids=("model-1", "model-2", "model-3")
+        )
+        result = await tuples.list_models()
         assert result is not None
         assert len(result) == 3
 
-        # Test with generator
-        class GeneratorGetModelsAdapter(OpenAIMixinImpl):
-            async def get_models(self) -> Iterable[str] | None:
+        class GeneratorAdapter(OpenAIMixinImpl):
+            async def list_provider_model_ids(self) -> Iterable[str]:
                 def gen():
                     yield "gen-model-1"
                     yield "gen-model-2"
 
                 return gen()
 
-        mixin = GeneratorGetModelsAdapter(config=config)
+        mixin = GeneratorAdapter(config=config)
         result = await mixin.list_models()
         assert result is not None
         assert len(result) == 2
 
-        # Test with set (order may vary)
-        class SetGetModelsAdapter(OpenAIMixinImpl):
-            async def get_models(self) -> Iterable[str] | None:
-                return {"set-model-1", "set-model-2"}
-
-        mixin = SetGetModelsAdapter(config=config)
-        result = await mixin.list_models()
+        sets = CustomListProviderModelIdsImplementation(config=config, custom_model_ids={"set-model-1", "set-model-2"})
+        result = await sets.list_models()
         assert result is not None
         assert len(result) == 2
 
-    async def test_get_models_exception_propagates(self):
-        """Test that when get_models() raises an exception, it propagates to the caller"""
-
-        class FailingGetModelsAdapter(OpenAIMixinImpl):
-            async def get_models(self) -> Iterable[str] | None:
-                # Simulate an exception during custom model listing
-                raise RuntimeError("Failed to fetch custom models")
-
-        config = RemoteInferenceProviderConfig()
-        mixin = FailingGetModelsAdapter(config=config)
-
-        # Exception should propagate and not fall back to client.models.list()
-        with pytest.raises(RuntimeError, match="Failed to fetch custom models"):
-            await mixin.list_models()
-
 
 class TestOpenAIMixinProviderDataApiKey:
     """Test cases for provider_data_api_key_field functionality"""