diff --git a/llama_stack/providers/remote/inference/databricks/databricks.py b/llama_stack/providers/remote/inference/databricks/databricks.py index a2621b81e..2ab4856b5 100644 --- a/llama_stack/providers/remote/inference/databricks/databricks.py +++ b/llama_stack/providers/remote/inference/databricks/databricks.py @@ -4,16 +4,15 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from collections.abc import Iterable from typing import Any from databricks.sdk import WorkspaceClient from llama_stack.apis.inference import ( Inference, - Model, OpenAICompletion, ) -from llama_stack.apis.models import ModelType from llama_stack.log import get_logger from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin @@ -72,31 +71,13 @@ class DatabricksInferenceAdapter( ) -> OpenAICompletion: raise NotImplementedError() - async def list_models(self) -> list[Model] | None: - self._model_cache = {} # from OpenAIMixin - ws_client = WorkspaceClient(host=self.config.url, token=self.get_api_key()) # TODO: this is not async - endpoints = ws_client.serving_endpoints.list() - for endpoint in endpoints: - model = Model( - provider_id=self.__provider_id__, - provider_resource_id=endpoint.name, - identifier=endpoint.name, - ) - if endpoint.task == "llm/v1/chat": - model.model_type = ModelType.llm # this is redundant, but informative - elif endpoint.task == "llm/v1/embeddings": - if endpoint.name not in self.embedding_model_metadata: - logger.warning(f"No metadata information available for embedding model {endpoint.name}, skipping.") - continue - model.model_type = ModelType.embedding - model.metadata = self.embedding_model_metadata[endpoint.name] - else: - logger.warning(f"Unknown model type, skipping: {endpoint}") - continue - - self._model_cache[endpoint.name] = model - - return list(self._model_cache.values()) + async def list_provider_model_ids(self) -> Iterable[str]: + return [ + endpoint.name + for endpoint in WorkspaceClient( + host=self.config.url, token=self.get_api_key() + ).serving_endpoints.list() # TODO: this is not async + ] async def should_refresh_models(self) -> bool: return False diff --git a/llama_stack/providers/utils/inference/openai_mixin.py b/llama_stack/providers/utils/inference/openai_mixin.py index 4354b067e..3239d814c 100644 --- a/llama_stack/providers/utils/inference/openai_mixin.py +++ b/llama_stack/providers/utils/inference/openai_mixin.py @@ -7,7 +7,7 @@ import base64 import uuid from abc import ABC, abstractmethod -from collections.abc import AsyncIterator +from collections.abc import AsyncIterator, Iterable from typing import Any from openai import NOT_GIVEN, AsyncOpenAI @@ -111,6 +111,18 @@ class OpenAIMixin(ModelsProtocolPrivate, NeedsRequestProviderData, ABC): """ return {} + async def list_provider_model_ids(self) -> Iterable[str]: + """ + List available models from the provider. + + Child classes can override this method to provide a custom implementation + for listing models. The default implementation uses the AsyncOpenAI client + to list models from the OpenAI-compatible endpoint. + + :return: An iterable of model IDs or None if not implemented + """ + return [m.id async for m in self.client.models.list()] + @property def client(self) -> AsyncOpenAI: """ @@ -387,28 +399,36 @@ class OpenAIMixin(ModelsProtocolPrivate, NeedsRequestProviderData, ABC): """ self._model_cache = {} - async for m in self.client.models.list(): - if self.allowed_models and m.id not in self.allowed_models: - logger.info(f"Skipping model {m.id} as it is not in the allowed models list") + # give subclasses a chance to provide custom model listing + iterable = await self.list_provider_model_ids() + if not hasattr(iterable, "__iter__"): + raise TypeError( + f"Failed to list models: {self.__class__.__name__}.list_provider_model_ids() must return an iterable of " + f"strings or None, but returned {type(iterable).__name__}" + ) + provider_models_ids = list(iterable) + logger.info(f"{self.__class__.__name__}.list_provider_model_ids() returned {len(provider_models_ids)} models") + + for provider_model_id in provider_models_ids: + if self.allowed_models and provider_model_id not in self.allowed_models: + logger.info(f"Skipping model {provider_model_id} as it is not in the allowed models list") continue - if metadata := self.embedding_model_metadata.get(m.id): - # This is an embedding model - augment with metadata + if metadata := self.embedding_model_metadata.get(provider_model_id): model = Model( provider_id=self.__provider_id__, # type: ignore[attr-defined] - provider_resource_id=m.id, - identifier=m.id, + provider_resource_id=provider_model_id, + identifier=provider_model_id, model_type=ModelType.embedding, metadata=metadata, ) else: - # This is an LLM model = Model( provider_id=self.__provider_id__, # type: ignore[attr-defined] - provider_resource_id=m.id, - identifier=m.id, + provider_resource_id=provider_model_id, + identifier=provider_model_id, model_type=ModelType.llm, ) - self._model_cache[m.id] = model + self._model_cache[provider_model_id] = model return list(self._model_cache.values()) diff --git a/tests/unit/providers/utils/inference/test_openai_mixin.py b/tests/unit/providers/utils/inference/test_openai_mixin.py index 4856f510b..1110e1843 100644 --- a/tests/unit/providers/utils/inference/test_openai_mixin.py +++ b/tests/unit/providers/utils/inference/test_openai_mixin.py @@ -5,6 +5,7 @@ # the root directory of this source tree. import json +from collections.abc import Iterable from unittest.mock import AsyncMock, MagicMock, Mock, PropertyMock, patch import pytest @@ -498,6 +499,129 @@ class OpenAIMixinWithProviderData(OpenAIMixinImpl): return "default-base-url" +class CustomListProviderModelIdsImplementation(OpenAIMixinImpl): + """Test implementation with custom list_provider_model_ids override""" + + def __init__(self, custom_model_ids): + self._custom_model_ids = custom_model_ids + + async def list_provider_model_ids(self) -> Iterable[str]: + """Return custom model IDs list""" + return self._custom_model_ids + + +class TestOpenAIMixinCustomListProviderModelIds: + """Test cases for custom list_provider_model_ids() implementation functionality""" + + @pytest.fixture + def custom_model_ids_list(self): + """Create a list of custom model ID strings""" + return ["custom-model-1", "custom-model-2", "custom-embedding"] + + @pytest.fixture + def adapter(self, custom_model_ids_list): + """Create mixin instance with custom list_provider_model_ids implementation""" + mixin = CustomListProviderModelIdsImplementation(custom_model_ids=custom_model_ids_list) + mixin.embedding_model_metadata = {"custom-embedding": {"embedding_dimension": 768, "context_length": 512}} + return mixin + + async def test_is_used(self, adapter, custom_model_ids_list): + """Test that custom list_provider_model_ids() implementation is used instead of client.models.list()""" + result = await adapter.list_models() + + assert result is not None + assert len(result) == 3 + + assert set(custom_model_ids_list) == {m.identifier for m in result} + + async def test_populates_cache(self, adapter, custom_model_ids_list): + """Test that custom list_provider_model_ids() results are cached""" + assert len(adapter._model_cache) == 0 + + await adapter.list_models() + + assert set(custom_model_ids_list) == set(adapter._model_cache.keys()) + + async def test_respects_allowed_models(self): + """Test that custom list_provider_model_ids() respects allowed_models filtering""" + mixin = CustomListProviderModelIdsImplementation(custom_model_ids=["model-1", "model-2", "model-3"]) + mixin.allowed_models = ["model-1"] + + result = await mixin.list_models() + + assert result is not None + assert len(result) == 1 + assert result[0].identifier == "model-1" + + async def test_with_empty_list(self): + """Test that custom list_provider_model_ids() handles empty list correctly""" + mixin = CustomListProviderModelIdsImplementation(custom_model_ids=[]) + + result = await mixin.list_models() + + assert result is not None + assert len(result) == 0 + assert len(mixin._model_cache) == 0 + + async def test_wrong_type_raises_error(self): + """Test that list_provider_model_ids() returning unhashable items results in an error""" + mixin = CustomListProviderModelIdsImplementation(custom_model_ids=[["nested", "list"], {"key": "value"}]) + + with pytest.raises(TypeError, match="unhashable type"): + await mixin.list_models() + + async def test_non_iterable_raises_error(self): + """Test that list_provider_model_ids() returning non-iterable type raises error""" + mixin = CustomListProviderModelIdsImplementation(custom_model_ids=42) + + with pytest.raises( + TypeError, + match=r"Failed to list models: CustomListProviderModelIdsImplementation\.list_provider_model_ids\(\) must return an iterable.*but returned int", + ): + await mixin.list_models() + + async def test_with_none_items_raises_error(self): + """Test that list_provider_model_ids() returning list with None items causes error""" + mixin = CustomListProviderModelIdsImplementation(custom_model_ids=[None, "valid-model", None]) + + with pytest.raises(Exception, match="Input should be a valid string"): + await mixin.list_models() + + async def test_accepts_various_iterables(self): + """Test that list_provider_model_ids() accepts tuples, sets, generators, etc.""" + + class TupleAdapter(OpenAIMixinImpl): + async def list_provider_model_ids(self) -> Iterable[str] | None: + return ("model-1", "model-2", "model-3") + + mixin = TupleAdapter() + result = await mixin.list_models() + assert result is not None + assert len(result) == 3 + + class GeneratorAdapter(OpenAIMixinImpl): + async def list_provider_model_ids(self) -> Iterable[str] | None: + def gen(): + yield "gen-model-1" + yield "gen-model-2" + + return gen() + + mixin = GeneratorAdapter() + result = await mixin.list_models() + assert result is not None + assert len(result) == 2 + + class SetAdapter(OpenAIMixinImpl): + async def list_provider_model_ids(self) -> Iterable[str] | None: + return {"set-model-1", "set-model-2"} + + mixin = SetAdapter() + result = await mixin.list_models() + assert result is not None + assert len(result) == 2 + + class TestOpenAIMixinProviderDataApiKey: """Test cases for provider_data_api_key_field functionality"""