chore: give OpenAIMixin subclasses a chance to list models without leaking _model_cache details

Matthew Farrellee 2025-10-04 08:34:23 -04:00
parent f176196fba
commit c465472e42
3 changed files with 286 additions and 36 deletions
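
In short: OpenAIMixin.list_models() still owns _model_cache, but a subclass can now feed it model IDs by overriding the new get_models() hook. A minimal sketch of the pattern (the adapter class and model IDs here are hypothetical, not part of this commit):

from collections.abc import Iterable

from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin


class ExampleAdapter(OpenAIMixin):
    """Hypothetical adapter: returns plain model IDs; the mixin's
    list_models() builds the Model objects and fills _model_cache."""

    async def get_models(self) -> Iterable[str] | None:
        # Returning None instead would fall back to the OpenAI-compatible
        # /v1/models listing via the mixin's AsyncOpenAI client.
        return ["example-llm", "example-embedding"]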

llama_stack/providers/remote/inference/databricks/databricks.py

@@ -13,7 +13,6 @@ from llama_stack.apis.inference import (
     Model,
     OpenAICompletion,
 )
-from llama_stack.apis.models import ModelType
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
@@ -72,31 +71,13 @@ class DatabricksInferenceAdapter(
     ) -> OpenAICompletion:
         raise NotImplementedError()

-    async def list_models(self) -> list[Model] | None:
-        self._model_cache = {}  # from OpenAIMixin
-        ws_client = WorkspaceClient(host=self.config.url, token=self.get_api_key())  # TODO: this is not async
-        endpoints = ws_client.serving_endpoints.list()
-        for endpoint in endpoints:
-            model = Model(
-                provider_id=self.__provider_id__,
-                provider_resource_id=endpoint.name,
-                identifier=endpoint.name,
-            )
-            if endpoint.task == "llm/v1/chat":
-                model.model_type = ModelType.llm  # this is redundant, but informative
-            elif endpoint.task == "llm/v1/embeddings":
-                if endpoint.name not in self.embedding_model_metadata:
-                    logger.warning(f"No metadata information available for embedding model {endpoint.name}, skipping.")
-                    continue
-                model.model_type = ModelType.embedding
-                model.metadata = self.embedding_model_metadata[endpoint.name]
-            else:
-                logger.warning(f"Unknown model type, skipping: {endpoint}")
-                continue
-            self._model_cache[endpoint.name] = model
-        return list(self._model_cache.values())
+    async def get_models(self) -> list[Model] | None:
+        return [
+            endpoint.name
+            for endpoint in WorkspaceClient(
+                host=self.config.url, token=self.get_api_key()
+            ).serving_endpoints.list()  # TODO: this is not async
+        ]

     async def should_refresh_models(self) -> bool:
         return False
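
The TODO in the new get_models() flags that WorkspaceClient is synchronous, so the call blocks the event loop. A possible follow-up (a sketch, not part of this commit; the helper name is made up) is to push the blocking SDK call onto a worker thread with asyncio.to_thread:

import asyncio

from databricks.sdk import WorkspaceClient


async def list_serving_endpoint_names(host: str, token: str) -> list[str]:
    # serving_endpoints.list() performs blocking HTTP I/O; run it off the event loop.
    def _list_names() -> list[str]:
        client = WorkspaceClient(host=host, token=token)
        return [endpoint.name for endpoint in client.serving_endpoints.list()]

    return await asyncio.to_thread(_list_names)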

llama_stack/providers/utils/inference/openai_mixin.py

@@ -7,7 +7,7 @@
 import base64
 import uuid
 from abc import ABC, abstractmethod
-from collections.abc import AsyncIterator
+from collections.abc import AsyncIterator, Iterable
 from typing import Any

 from openai import NOT_GIVEN, AsyncOpenAI
@@ -111,6 +111,18 @@ class OpenAIMixin(ModelsProtocolPrivate, NeedsRequestProviderData, ABC):
         """
         return {}

+    async def get_models(self) -> Iterable[str] | None:
+        """
+        List available models from the provider.
+
+        Child classes can override this method to provide a custom implementation
+        for listing models. The default implementation uses the AsyncOpenAI client
+        to list models from the OpenAI-compatible endpoint.
+
+        :return: An iterable of model IDs or None if not implemented
+        """
+        return None
+
     @property
     def client(self) -> AsyncOpenAI:
         """
@@ -387,16 +399,30 @@ class OpenAIMixin(ModelsProtocolPrivate, NeedsRequestProviderData, ABC):
         """
         self._model_cache = {}

-        async for m in self.client.models.list():
-            if self.allowed_models and m.id not in self.allowed_models:
-                logger.info(f"Skipping model {m.id} as it is not in the allowed models list")
+        # give subclasses a chance to provide custom model listing
+        if (iterable := await self.get_models()) is not None:
+            if not hasattr(iterable, "__iter__"):
+                raise TypeError(
+                    f"Failed to list models: {self.__class__.__name__}.get_models() must return an iterable of "
+                    f"strings or None, but returned {type(iterable).__name__}"
+                )
+            models_ids = list(iterable)
+            logger.info(
+                f"Using {self.__class__.__name__}.get_models() implementation, received {len(models_ids)} models"
+            )
+        else:
+            models_ids = [m.id async for m in self.client.models.list()]
+
+        for m_id in models_ids:
+            if self.allowed_models and m_id not in self.allowed_models:
+                logger.info(f"Skipping model {m_id} as it is not in the allowed models list")
                 continue
-            if metadata := self.embedding_model_metadata.get(m.id):
+            if metadata := self.embedding_model_metadata.get(m_id):
                 # This is an embedding model - augment with metadata
                 model = Model(
                     provider_id=self.__provider_id__,  # type: ignore[attr-defined]
-                    provider_resource_id=m.id,
-                    identifier=m.id,
+                    provider_resource_id=m_id,
+                    identifier=m_id,
                     model_type=ModelType.embedding,
                     metadata=metadata,
                 )
@@ -404,11 +430,11 @@ class OpenAIMixin(ModelsProtocolPrivate, NeedsRequestProviderData, ABC):
             else:
                 # This is an LLM
                 model = Model(
                     provider_id=self.__provider_id__,  # type: ignore[attr-defined]
-                    provider_resource_id=m.id,
-                    identifier=m.id,
+                    provider_resource_id=m_id,
+                    identifier=m_id,
                     model_type=ModelType.llm,
                 )
-            self._model_cache[m.id] = model
+            self._model_cache[m_id] = model

         return list(self._model_cache.values())
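
Taken together, list_models() now resolves model IDs in two steps: trust the subclass's get_models() when it returns an iterable (rejecting non-iterables with a TypeError), and otherwise fall back to paging the OpenAI-compatible endpoint. Condensed into a standalone sketch (names simplified from the mixin code above):

async def resolve_model_ids(adapter) -> list[str]:
    # Mirrors the mixin's precedence: subclass hook first, OpenAI client fallback second.
    ids = await adapter.get_models()
    if ids is None:
        # Default get_models() returns None -> page /v1/models via the AsyncOpenAI client.
        return [m.id async for m in adapter.client.models.list()]
    if not hasattr(ids, "__iter__"):
        raise TypeError("get_models() must return an iterable of strings or None")
    return list(ids)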

tests/unit/providers/utils/inference/test_openai_mixin.py

@@ -5,6 +5,7 @@
 # the root directory of this source tree.

 import json
+from collections.abc import Iterable
 from unittest.mock import AsyncMock, MagicMock, Mock, PropertyMock, patch

 import pytest
@@ -498,6 +499,248 @@ class OpenAIMixinWithProviderData(OpenAIMixinImpl):
         return "default-base-url"


+class OpenAIMixinWithCustomGetModels(OpenAIMixinImpl):
+    """Test implementation with custom get_models override"""
+
+    def __init__(self, custom_model_ids):
+        self._custom_model_ids = custom_model_ids
+
+    async def get_models(self) -> Iterable[str] | None:
+        """Return custom model IDs list"""
+        return self._custom_model_ids
+
+
+class TestOpenAIMixinCustomGetModels:
+    """Test cases for custom get_models() implementation functionality"""
+
+    @pytest.fixture
+    def custom_model_ids_list(self):
+        """Create a list of custom model ID strings"""
+        return ["custom-model-1", "custom-model-2", "custom-embedding"]
+
+    @pytest.fixture
+    def mixin_with_custom_get_models(self, custom_model_ids_list):
+        """Create mixin instance with custom get_models implementation"""
+        mixin = OpenAIMixinWithCustomGetModels(custom_model_ids=custom_model_ids_list)
+        # Add embedding metadata to test that feature still works
+        mixin.embedding_model_metadata = {"custom-embedding": {"embedding_dimension": 768, "context_length": 512}}
+        return mixin
+
+    async def test_custom_get_models_is_used(self, mixin_with_custom_get_models, custom_model_ids_list):
+        """Test that custom get_models() implementation is used instead of client.models.list()"""
+        result = await mixin_with_custom_get_models.list_models()
+
+        assert result is not None
+        assert len(result) == 3
+
+        # Verify all custom models are present
+        identifiers = {m.identifier for m in result}
+        assert "custom-model-1" in identifiers
+        assert "custom-model-2" in identifiers
+        assert "custom-embedding" in identifiers
+
+    async def test_custom_get_models_populates_cache(self, mixin_with_custom_get_models):
+        """Test that custom get_models() results are cached"""
+        assert len(mixin_with_custom_get_models._model_cache) == 0
+
+        await mixin_with_custom_get_models.list_models()
+
+        assert len(mixin_with_custom_get_models._model_cache) == 3
+        assert "custom-model-1" in mixin_with_custom_get_models._model_cache
+        assert "custom-model-2" in mixin_with_custom_get_models._model_cache
+        assert "custom-embedding" in mixin_with_custom_get_models._model_cache
+
+    async def test_custom_get_models_respects_allowed_models(self):
+        """Test that custom get_models() respects allowed_models filtering"""
+        mixin = OpenAIMixinWithCustomGetModels(custom_model_ids=["model-1", "model-2", "model-3"])
+        mixin.allowed_models = ["model-1"]
+
+        result = await mixin.list_models()
+
+        assert result is not None
+        assert len(result) == 1
+        assert result[0].identifier == "model-1"
+
+    async def test_custom_get_models_with_embedding_metadata(self, mixin_with_custom_get_models):
+        """Test that custom get_models() works with embedding_model_metadata"""
+        result = await mixin_with_custom_get_models.list_models()
+
+        # Find the embedding model
+        embedding_model = next((m for m in result if m.identifier == "custom-embedding"), None)
+        assert embedding_model is not None
+        assert embedding_model.model_type == ModelType.embedding
+        assert embedding_model.metadata == {"embedding_dimension": 768, "context_length": 512}
+
+        # Verify LLM models
+        llm_models = [m for m in result if m.model_type == ModelType.llm]
+        assert len(llm_models) == 2
+
+    async def test_custom_get_models_with_empty_list(self):
+        """Test that custom get_models() handles empty list correctly"""
+        mixin = OpenAIMixinWithCustomGetModels(custom_model_ids=[])
+
+        result = await mixin.list_models()
+
+        assert result is not None
+        assert len(result) == 0
+        assert len(mixin._model_cache) == 0
+
+    async def test_default_get_models_returns_none(self, mixin):
+        """Test that default get_models() implementation returns None"""
+        custom_models = await mixin.get_models()
+        assert custom_models is None
+
+    async def test_fallback_to_client_when_get_models_returns_none(
+        self, mixin, mock_client_with_models, mock_client_context
+    ):
+        """Test that when get_models() returns None, falls back to client.models.list()"""
+        # Default get_models() returns None, so should use client
+        with mock_client_context(mixin, mock_client_with_models):
+            result = await mixin.list_models()
+
+            assert result is not None
+            assert len(result) == 3
+            mock_client_with_models.models.list.assert_called_once()
+
+    async def test_custom_get_models_creates_proper_model_objects(self):
+        """Test that custom get_models() model IDs are converted to proper Model objects"""
+        model_ids = ["gpt-4", "gpt-3.5-turbo"]
+        mixin = OpenAIMixinWithCustomGetModels(custom_model_ids=model_ids)
+
+        result = await mixin.list_models()
+
+        assert result is not None
+        assert len(result) == 2
+
+        for model in result:
+            assert isinstance(model, Model)
+            assert model.provider_id == "test-provider"
+            assert model.identifier in model_ids
+            assert model.provider_resource_id in model_ids
+            assert model.model_type == ModelType.llm
+
+    async def test_custom_get_models_bypasses_client(self, mock_client_context):
+        """Test that providing get_models() means client.models.list() is NOT called"""
+        mixin = OpenAIMixinWithCustomGetModels(custom_model_ids=["model-1", "model-2"])
+
+        # Create a mock client that should NOT be called
+        mock_client = MagicMock()
+        mock_client.models.list = MagicMock(side_effect=AssertionError("client.models.list should not be called!"))
+
+        with mock_client_context(mixin, mock_client):
+            result = await mixin.list_models()
+
+            # Should succeed without calling client.models.list
+            assert result is not None
+            assert len(result) == 2
+            mock_client.models.list.assert_not_called()
+
+    async def test_get_models_wrong_type_raises_error(self):
+        """Test that get_models() returning unhashable items results in an error"""
+
+        class BadGetModelsAdapter(OpenAIMixinImpl):
+            async def get_models(self) -> Iterable[str] | None:
+                # Return list with unhashable items
+                return [["nested", "list"], {"key": "value"}]  # type: ignore
+
+        mixin = BadGetModelsAdapter()
+
+        # Should raise TypeError when trying to use unhashable items as dict keys
+        with pytest.raises(TypeError, match="unhashable type"):
+            await mixin.list_models()
+
+    async def test_get_models_non_iterable_raises_error(self):
+        """Test that get_models() returning non-iterable type raises error"""
+
+        class NonIterableGetModelsAdapter(OpenAIMixinImpl):
+            async def get_models(self) -> Iterable[str] | None:
+                # Return non-iterable type
+                return 42  # type: ignore
+
+        mixin = NonIterableGetModelsAdapter()
+
+        # Should raise TypeError with custom error message
+        with pytest.raises(
+            TypeError,
+            match=r"Failed to list models: NonIterableGetModelsAdapter\.get_models\(\) must return an iterable.*but returned int",
+        ):
+            await mixin.list_models()
+
+    async def test_get_models_with_none_items_raises_error(self):
+        """Test that get_models() returning list with None items causes error"""
+
+        class NoneItemsAdapter(OpenAIMixinImpl):
+            async def get_models(self) -> Iterable[str] | None:
+                # Return list with None items
+                return [None, "valid-model", None]  # type: ignore
+
+        mixin = NoneItemsAdapter()
+
+        # Should raise ValidationError when creating Model with None identifier
+        with pytest.raises(Exception, match="Input should be a valid string"):
+            await mixin.list_models()
+
+    async def test_embedding_models_from_custom_get_models_have_correct_type(self, mixin_with_custom_get_models):
+        """Test that embedding models from custom get_models() are properly typed as embedding"""
+        result = await mixin_with_custom_get_models.list_models()
+
+        # Verify we have both LLM and embedding models
+        llm_models = [m for m in result if m.model_type == ModelType.llm]
+        embedding_models = [m for m in result if m.model_type == ModelType.embedding]
+
+        assert len(llm_models) == 2
+        assert len(embedding_models) == 1
+        assert embedding_models[0].identifier == "custom-embedding"
+
+    async def test_llm_models_from_custom_get_models_have_correct_type(self):
+        """Test that LLM models from custom get_models() are properly typed as llm"""
+        mixin = OpenAIMixinWithCustomGetModels(custom_model_ids=["gpt-4", "claude-3"])
+
+        result = await mixin.list_models()
+
+        assert result is not None
+        assert len(result) == 2
+        for model in result:
+            assert model.model_type == ModelType.llm
+
+    async def test_get_models_accepts_various_iterables(self):
+        """Test that get_models() accepts tuples, sets, generators, etc."""
+
+        # Test with tuple
+        class TupleGetModelsAdapter(OpenAIMixinImpl):
+            async def get_models(self) -> Iterable[str] | None:
+                return ("model-1", "model-2", "model-3")
+
+        mixin = TupleGetModelsAdapter()
+        result = await mixin.list_models()
+        assert result is not None
+        assert len(result) == 3
+
+        # Test with generator
+        class GeneratorGetModelsAdapter(OpenAIMixinImpl):
+            async def get_models(self) -> Iterable[str] | None:
+                def gen():
+                    yield "gen-model-1"
+                    yield "gen-model-2"
+
+                return gen()
+
+        mixin = GeneratorGetModelsAdapter()
+        result = await mixin.list_models()
+        assert result is not None
+        assert len(result) == 2
+
+        # Test with set (order may vary)
+        class SetGetModelsAdapter(OpenAIMixinImpl):
+            async def get_models(self) -> Iterable[str] | None:
+                return {"set-model-1", "set-model-2"}
+
+        mixin = SetGetModelsAdapter()
+        result = await mixin.list_models()
+        assert result is not None
+        assert len(result) == 2
+
+
 class TestOpenAIMixinProviderDataApiKey:
     """Test cases for provider_data_api_key_field functionality"""