Merge branch 'refactor-list-models' into make-openaimix-pydantic

Matthew Farrellee 2025-10-06 06:58:59 -04:00
commit 6fa23d816f
3 changed files with 101 additions and 248 deletions

View file

@@ -4,13 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+from collections.abc import Iterable
 from typing import Any
 
 from databricks.sdk import WorkspaceClient
 
-from llama_stack.apis.inference import (
-    OpenAICompletion,
-)
+from llama_stack.apis.inference import OpenAICompletion
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
@@ -69,3 +68,11 @@ class DatabricksInferenceAdapter(OpenAIMixin):
         suffix: str | None = None,
     ) -> OpenAICompletion:
         raise NotImplementedError()
+
+    async def list_provider_model_ids(self) -> Iterable[str]:
+        return [
+            endpoint.name
+            for endpoint in WorkspaceClient(
+                host=self.config.url, token=self.get_api_key()
+            ).serving_endpoints.list()  # TODO: this is not async
+        ]
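
Note: the override above returns endpoint names via the Databricks SDK, but as the inline TODO says, serving_endpoints.list() is a blocking call inside an async method. A minimal sketch of one way a follow-up could offload it, using only names already present in the diff (this is not part of this commit):

    import asyncio
    from collections.abc import Iterable

    from databricks.sdk import WorkspaceClient


    async def list_provider_model_ids(self) -> Iterable[str]:
        def _list_endpoint_names() -> list[str]:
            # Same SDK call as above, but run in a worker thread so the
            # event loop is not blocked while Databricks is queried.
            client = WorkspaceClient(host=self.config.url, token=self.get_api_key())
            return [endpoint.name for endpoint in client.serving_endpoints.list()]

        return await asyncio.to_thread(_list_endpoint_names)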

View file

@@ -48,7 +48,7 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
     - download_images: If True, downloads images and converts to base64 for providers that require it
     - embedding_model_metadata: A dictionary mapping model IDs to their embedding metadata
     - provider_data_api_key_field: Optional field name in provider data to look for API key
-    - get_models: Method to list available models from the provider
+    - list_provider_model_ids: Method to list available models from the provider
     - get_extra_client_params: Method to provide extra parameters to the AsyncOpenAI client
 
     Expected Dependencies:
@@ -122,7 +122,7 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
         """
         return {}
 
-    async def get_models(self) -> Iterable[str] | None:
+    async def list_provider_model_ids(self) -> Iterable[str]:
         """
         List available models from the provider.
@@ -132,7 +132,7 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
 
         :return: An iterable of model IDs or None if not implemented
         """
-        return None
+        return [m.id async for m in self.client.models.list()]
 
     async def initialize(self) -> None:
         """
@@ -430,46 +430,42 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
         """
         self._model_cache = {}
 
-        # give subclasses a chance to provide custom model listing
-        models_ids = []
         try:
-            if (iterable := await self.get_models()) is not None:  # TODO: handle exceptions from get_models
-                models_ids = list(iterable)
-                logger.info(
-                    f"Using {self.__class__.__name__}.get_models() implementation, received {len(models_ids)} models"
-                )
-                for id_ in models_ids:
-                    if not isinstance(id_, str):
-                        raise ValueError(f"Model ID {id_} from get_models() is not a string")
+            iterable = await self.list_provider_model_ids()
         except Exception as e:
-            logger.error(f"{self.__class__.__name__}.get_models() failed with: {e}")
+            logger.error(f"{self.__class__.__name__}.list_provider_model_ids() failed with: {e}")
             raise
+        if not hasattr(iterable, "__iter__"):
+            raise TypeError(
+                f"Failed to list models: {self.__class__.__name__}.list_provider_model_ids() must return an iterable of "
+                f"strings, but returned {type(iterable).__name__}"
+            )
 
-        if not models_ids:
-            models_ids = [m.id async for m in self.client.models.list()]
+        provider_models_ids = list(iterable)
+        logger.info(f"{self.__class__.__name__}.list_provider_model_ids() returned {len(provider_models_ids)} models")
 
-        for m_id in models_ids:
-            if self.allowed_models and m_id not in self.allowed_models:
-                logger.info(f"Skipping model {m_id} as it is not in the allowed models list")
+        for provider_model_id in provider_models_ids:
+            if not isinstance(provider_model_id, str):
+                raise ValueError(f"Model ID {provider_model_id} from list_provider_model_ids() is not a string")
+            if self.allowed_models and provider_model_id not in self.allowed_models:
+                logger.info(f"Skipping model {provider_model_id} as it is not in the allowed models list")
                 continue
 
-            if metadata := self.embedding_model_metadata.get(m_id):
-                # This is an embedding model - augment with metadata
+            if metadata := self.embedding_model_metadata.get(provider_model_id):
                 model = Model(
                     provider_id=self.__provider_id__,  # type: ignore[attr-defined]
-                    provider_resource_id=m_id,
-                    identifier=m_id,
+                    provider_resource_id=provider_model_id,
+                    identifier=provider_model_id,
                     model_type=ModelType.embedding,
                     metadata=metadata,
                 )
             else:
-                # This is an LLM
                 model = Model(
                     provider_id=self.__provider_id__,  # type: ignore[attr-defined]
-                    provider_resource_id=m_id,
-                    identifier=m_id,
+                    provider_resource_id=provider_model_id,
+                    identifier=provider_model_id,
                     model_type=ModelType.llm,
                 )
-            self._model_cache[m_id] = model
+            self._model_cache[provider_model_id] = model
 
         return list(self._model_cache.values())
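
The rewritten list_models() now fails fast in two distinct ways: a TypeError when the hook returns something without __iter__, raised before the result is materialized, and a ValueError for each model ID that is not a string, raised inside the loop. A sketch of both paths, borrowing the OpenAIMixinImpl and RemoteInferenceProviderConfig helpers from the test module changed below:

    import asyncio

    import pytest


    async def demo_validation() -> None:
        config = RemoteInferenceProviderConfig()

        class NonIterableAdapter(OpenAIMixinImpl):
            async def list_provider_model_ids(self):
                return 42  # no __iter__ -> TypeError before any Model is built

        with pytest.raises(TypeError, match="must return an iterable"):
            await NonIterableAdapter(config=config).list_models()

        class NonStringAdapter(OpenAIMixinImpl):
            async def list_provider_model_ids(self):
                return ["ok-model", 123]  # 123 -> ValueError inside the loop

        with pytest.raises(ValueError, match="is not a string"):
            await NonStringAdapter(config=config).list_models()


    asyncio.run(demo_validation())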

View file

@@ -6,6 +6,7 @@
 import json
 from collections.abc import Iterable
+from typing import Any
 from unittest.mock import AsyncMock, MagicMock, Mock, PropertyMock, patch
 
 import pytest
@@ -502,20 +503,18 @@ class OpenAIMixinWithProviderData(OpenAIMixinImpl):
         return "default-base-url"
 
-class OpenAIMixinWithCustomGetModels(OpenAIMixinImpl):
-    """Test implementation with custom get_models override"""
+class CustomListProviderModelIdsImplementation(OpenAIMixinImpl):
+    """Test implementation with custom list_provider_model_ids override"""
 
-    def __init__(self, config, custom_model_ids):
-        super().__init__(config=config)
-        self._custom_model_ids = custom_model_ids
+    custom_model_ids: Any
 
-    async def get_models(self) -> Iterable[str] | None:
+    async def list_provider_model_ids(self) -> Iterable[str]:
         """Return custom model IDs list"""
-        return self._custom_model_ids
+        return self.custom_model_ids
 
-class TestOpenAIMixinCustomGetModels:
-    """Test cases for custom get_models() implementation functionality"""
+class TestOpenAIMixinCustomListProviderModelIds:
+    """Test cases for custom list_provider_model_ids() implementation functionality"""
 
     @pytest.fixture
     def custom_model_ids_list(self):
@@ -523,42 +522,39 @@ class TestOpenAIMixinCustomGetModels:
         return ["custom-model-1", "custom-model-2", "custom-embedding"]
 
     @pytest.fixture
-    def mixin_with_custom_get_models(self, custom_model_ids_list):
-        """Create mixin instance with custom get_models implementation"""
-        config = RemoteInferenceProviderConfig()
-        mixin = OpenAIMixinWithCustomGetModels(config=config, custom_model_ids=custom_model_ids_list)
-        # Add embedding metadata to test that feature still works
+    def config(self):
+        """Create RemoteInferenceProviderConfig instance"""
+        return RemoteInferenceProviderConfig()
+
+    @pytest.fixture
+    def adapter(self, custom_model_ids_list, config):
+        """Create mixin instance with custom list_provider_model_ids implementation"""
+        mixin = CustomListProviderModelIdsImplementation(config=config, custom_model_ids=custom_model_ids_list)
         mixin.embedding_model_metadata = {"custom-embedding": {"embedding_dimension": 768, "context_length": 512}}
         return mixin
 
-    async def test_custom_get_models_is_used(self, mixin_with_custom_get_models, custom_model_ids_list):
-        """Test that custom get_models() implementation is used instead of client.models.list()"""
-        result = await mixin_with_custom_get_models.list_models()
+    async def test_is_used(self, adapter, custom_model_ids_list):
+        """Test that custom list_provider_model_ids() implementation is used instead of client.models.list()"""
+        result = await adapter.list_models()
 
         assert result is not None
         assert len(result) == 3
-        # Verify all custom models are present
-        identifiers = {m.identifier for m in result}
-        assert "custom-model-1" in identifiers
-        assert "custom-model-2" in identifiers
-        assert "custom-embedding" in identifiers
+        assert set(custom_model_ids_list) == {m.identifier for m in result}
 
-    async def test_custom_get_models_populates_cache(self, mixin_with_custom_get_models):
-        """Test that custom get_models() results are cached"""
-        assert len(mixin_with_custom_get_models._model_cache) == 0
-        await mixin_with_custom_get_models.list_models()
-        assert len(mixin_with_custom_get_models._model_cache) == 3
-        assert "custom-model-1" in mixin_with_custom_get_models._model_cache
-        assert "custom-model-2" in mixin_with_custom_get_models._model_cache
-        assert "custom-embedding" in mixin_with_custom_get_models._model_cache
+    async def test_populates_cache(self, adapter, custom_model_ids_list):
+        """Test that custom list_provider_model_ids() results are cached"""
+        assert len(adapter._model_cache) == 0
+        await adapter.list_models()
+        assert set(custom_model_ids_list) == set(adapter._model_cache.keys())
 
-    async def test_custom_get_models_respects_allowed_models(self):
-        """Test that custom get_models() respects allowed_models filtering"""
-        config = RemoteInferenceProviderConfig()
-        mixin = OpenAIMixinWithCustomGetModels(config=config, custom_model_ids=["model-1", "model-2", "model-3"])
+    async def test_respects_allowed_models(self, config):
+        """Test that custom list_provider_model_ids() respects allowed_models filtering"""
+        mixin = CustomListProviderModelIdsImplementation(
+            config=config, custom_model_ids=["model-1", "model-2", "model-3"]
+        )
         mixin.allowed_models = ["model-1"]
 
         result = await mixin.list_models()
@@ -567,222 +563,76 @@ class TestOpenAIMixinCustomGetModels:
         assert len(result) == 1
         assert result[0].identifier == "model-1"
 
-    async def test_custom_get_models_with_embedding_metadata(self, mixin_with_custom_get_models):
-        """Test that custom get_models() works with embedding_model_metadata"""
-        result = await mixin_with_custom_get_models.list_models()
-
-        # Find the embedding model
-        embedding_model = next((m for m in result if m.identifier == "custom-embedding"), None)
-        assert embedding_model is not None
-        assert embedding_model.model_type == ModelType.embedding
-        assert embedding_model.metadata == {"embedding_dimension": 768, "context_length": 512}
-
-        # Verify LLM models
-        llm_models = [m for m in result if m.model_type == ModelType.llm]
-        assert len(llm_models) == 2
-
-    async def test_custom_get_models_with_empty_list(self, mock_client_with_empty_models, mock_client_context):
-        """Test that custom get_models() handles empty list correctly"""
-        config = RemoteInferenceProviderConfig()
-        mixin = OpenAIMixinWithCustomGetModels(config=config, custom_model_ids=[])
-
-        # Empty list from get_models() falls back to client.models.list()
-        with mock_client_context(mixin, mock_client_with_empty_models):
-            result = await mixin.list_models()
-
-        assert result is not None
-        assert len(result) == 0
-        assert len(mixin._model_cache) == 0
-
-    async def test_default_get_models_returns_none(self, mixin):
-        """Test that default get_models() implementation returns None"""
-        custom_models = await mixin.get_models()
-        assert custom_models is None
-
-    async def test_fallback_to_client_when_get_models_returns_none(
-        self, mixin, mock_client_with_models, mock_client_context
-    ):
-        """Test that when get_models() returns None, falls back to client.models.list()"""
-        # Default get_models() returns None, so should use client
-        with mock_client_context(mixin, mock_client_with_models):
-            result = await mixin.list_models()
-
-        assert result is not None
-        assert len(result) == 3
-        mock_client_with_models.models.list.assert_called_once()
-
-    async def test_custom_get_models_creates_proper_model_objects(self):
-        """Test that custom get_models() model IDs are converted to proper Model objects"""
-        config = RemoteInferenceProviderConfig()
-        model_ids = ["gpt-4", "gpt-3.5-turbo"]
-        mixin = OpenAIMixinWithCustomGetModels(config=config, custom_model_ids=model_ids)
+    async def test_with_empty_list(self, config):
+        """Test that custom list_provider_model_ids() handles empty list correctly"""
+        mixin = CustomListProviderModelIdsImplementation(config=config, custom_model_ids=[])
 
         result = await mixin.list_models()
 
         assert result is not None
-        assert len(result) == 2
+        assert len(result) == 0
+        assert len(mixin._model_cache) == 0
 
-        for model in result:
-            assert isinstance(model, Model)
-            assert model.provider_id == "test-provider"
-            assert model.identifier in model_ids
-            assert model.provider_resource_id in model_ids
-            assert model.model_type == ModelType.llm
-
-    async def test_custom_get_models_bypasses_client(self, mock_client_context):
-        """Test that providing get_models() means client.models.list() is NOT called"""
-        config = RemoteInferenceProviderConfig()
-        mixin = OpenAIMixinWithCustomGetModels(config=config, custom_model_ids=["model-1", "model-2"])
-
-        # Create a mock client that should NOT be called
-        mock_client = MagicMock()
-        mock_client.models.list = MagicMock(side_effect=AssertionError("client.models.list should not be called!"))
-
-        with mock_client_context(mixin, mock_client):
-            result = await mixin.list_models()
-
-        # Should succeed without calling client.models.list
-        assert result is not None
-        assert len(result) == 2
-        mock_client.models.list.assert_not_called()
-
-    async def test_get_models_wrong_type_raises_error(self):
-        """Test that get_models() returning non-string items results in an error"""
-
-        class BadGetModelsAdapter(OpenAIMixinImpl):
-            async def get_models(self) -> Iterable[str] | None:
-                # Return list with non-string items
-                return [["nested", "list"], {"key": "value"}]  # type: ignore
-
-        config = RemoteInferenceProviderConfig()
-        mixin = BadGetModelsAdapter(config=config)
-
-        # Should raise ValueError for non-string model ID
-        with pytest.raises(ValueError, match="Model ID .* from get_models\\(\\) is not a string"):
+    async def test_wrong_type_raises_error(self, config):
+        """Test that list_provider_model_ids() returning unhashable items results in an error"""
+        mixin = CustomListProviderModelIdsImplementation(
+            config=config, custom_model_ids=["valid-model", ["nested", "list"]]
+        )
+        with pytest.raises(Exception, match="is not a string"):
             await mixin.list_models()
 
-    async def test_get_models_non_iterable_raises_error(self):
-        """Test that get_models() returning non-iterable type raises error"""
-
-        class NonIterableGetModelsAdapter(OpenAIMixinImpl):
-            async def get_models(self) -> Iterable[str] | None:
-                # Return non-iterable type
-                return 42  # type: ignore
-
-        config = RemoteInferenceProviderConfig()
-        mixin = NonIterableGetModelsAdapter(config=config)
-
-        # Should raise TypeError when trying to convert to list
-        with pytest.raises(TypeError, match="'int' object is not iterable"):
+        mixin = CustomListProviderModelIdsImplementation(
+            config=config, custom_model_ids=[{"key": "value"}, "valid-model"]
+        )
+        with pytest.raises(Exception, match="is not a string"):
             await mixin.list_models()
 
-    async def test_get_models_with_none_items_raises_error(self):
-        """Test that get_models() returning list with None items causes error"""
-
-        class NoneItemsAdapter(OpenAIMixinImpl):
-            async def get_models(self) -> Iterable[str] | None:
-                # Return list with None items
-                return [None, "valid-model", None]  # type: ignore
-
-        config = RemoteInferenceProviderConfig()
-        mixin = NoneItemsAdapter(config=config)
-
-        # Should raise ValueError for non-string model ID
-        with pytest.raises(ValueError, match="Model ID .* from get_models\\(\\) is not a string"):
+        mixin = CustomListProviderModelIdsImplementation(config=config, custom_model_ids=["valid-model", 42.0])
+        with pytest.raises(Exception, match="is not a string"):
             await mixin.list_models()
 
-    async def test_get_models_with_non_string_items_raises_error(self):
-        """Test that get_models() returning non-string items raises ValueError"""
-
-        class NonStringItemsAdapter(OpenAIMixinImpl):
-            async def get_models(self) -> Iterable[str] | None:
-                # Return list with non-string items (integers)
-                return ["valid-model", 123, "another-model"]  # type: ignore
-
-        config = RemoteInferenceProviderConfig()
-        mixin = NonStringItemsAdapter(config=config)
-
-        # Should raise ValueError for non-string model ID
-        with pytest.raises(ValueError, match="Model ID 123 from get_models\\(\\) is not a string"):
+        mixin = CustomListProviderModelIdsImplementation(config=config, custom_model_ids=[None])
+        with pytest.raises(Exception, match="is not a string"):
             await mixin.list_models()
 
-    async def test_embedding_models_from_custom_get_models_have_correct_type(self, mixin_with_custom_get_models):
-        """Test that embedding models from custom get_models() are properly typed as embedding"""
-        result = await mixin_with_custom_get_models.list_models()
-
-        # Verify we have both LLM and embedding models
-        llm_models = [m for m in result if m.model_type == ModelType.llm]
-        embedding_models = [m for m in result if m.model_type == ModelType.embedding]
-
-        assert len(llm_models) == 2
-        assert len(embedding_models) == 1
-        assert embedding_models[0].identifier == "custom-embedding"
-
-    async def test_llm_models_from_custom_get_models_have_correct_type(self):
-        """Test that LLM models from custom get_models() are properly typed as llm"""
-        config = RemoteInferenceProviderConfig()
-        mixin = OpenAIMixinWithCustomGetModels(config=config, custom_model_ids=["gpt-4", "claude-3"])
-
-        result = await mixin.list_models()
-
-        assert result is not None
-        assert len(result) == 2
-
-        for model in result:
-            assert model.model_type == ModelType.llm
-
-    async def test_get_models_accepts_various_iterables(self):
-        """Test that get_models() accepts tuples, sets, generators, etc."""
-
-        # Test with tuple
-        class TupleGetModelsAdapter(OpenAIMixinImpl):
-            async def get_models(self) -> Iterable[str] | None:
-                return ("model-1", "model-2", "model-3")
-
-        config = RemoteInferenceProviderConfig()
-        mixin = TupleGetModelsAdapter(config=config)
-        result = await mixin.list_models()
+    async def test_non_iterable_raises_error(self, config):
+        """Test that list_provider_model_ids() returning non-iterable type raises error"""
+        mixin = CustomListProviderModelIdsImplementation(config=config, custom_model_ids=42)
+        with pytest.raises(
+            TypeError,
+            match=r"Failed to list models: CustomListProviderModelIdsImplementation\.list_provider_model_ids\(\) must return an iterable.*but returned int",
+        ):
+            await mixin.list_models()
+
+    async def test_accepts_various_iterables(self, config):
+        """Test that list_provider_model_ids() accepts tuples, sets, generators, etc."""
+        tuples = CustomListProviderModelIdsImplementation(
+            config=config, custom_model_ids=("model-1", "model-2", "model-3")
+        )
+        result = await tuples.list_models()
         assert result is not None
         assert len(result) == 3
 
-        # Test with generator
-        class GeneratorGetModelsAdapter(OpenAIMixinImpl):
-            async def get_models(self) -> Iterable[str] | None:
+        class GeneratorAdapter(OpenAIMixinImpl):
+            async def list_provider_model_ids(self) -> Iterable[str]:
                 def gen():
                     yield "gen-model-1"
                     yield "gen-model-2"
 
                 return gen()
 
-        mixin = GeneratorGetModelsAdapter(config=config)
+        mixin = GeneratorAdapter(config=config)
         result = await mixin.list_models()
         assert result is not None
         assert len(result) == 2
 
-        # Test with set (order may vary)
-        class SetGetModelsAdapter(OpenAIMixinImpl):
-            async def get_models(self) -> Iterable[str] | None:
-                return {"set-model-1", "set-model-2"}
-
-        mixin = SetGetModelsAdapter(config=config)
-        result = await mixin.list_models()
+        sets = CustomListProviderModelIdsImplementation(config=config, custom_model_ids={"set-model-1", "set-model-2"})
+        result = await sets.list_models()
         assert result is not None
         assert len(result) == 2
 
-    async def test_get_models_exception_propagates(self):
-        """Test that when get_models() raises an exception, it propagates to the caller"""
-
-        class FailingGetModelsAdapter(OpenAIMixinImpl):
-            async def get_models(self) -> Iterable[str] | None:
-                # Simulate an exception during custom model listing
-                raise RuntimeError("Failed to fetch custom models")
-
-        config = RemoteInferenceProviderConfig()
-        mixin = FailingGetModelsAdapter(config=config)
-
-        # Exception should propagate and not fall back to client.models.list()
-        with pytest.raises(RuntimeError, match="Failed to fetch custom models"):
-            await mixin.list_models()
-
 
 class TestOpenAIMixinProviderDataApiKey:
     """Test cases for provider_data_api_key_field functionality"""