chore: give OpenAIMixin subclasses a chance to list models without leaking _model_cache details (#3682)

# What does this PR do?

Close the `_model_cache` abstraction leak: subclasses now override `list_provider_model_ids()` and return plain model ID strings, while `OpenAIMixin.list_models()` builds the `Model` objects and owns `_model_cache`, instead of each adapter rebuilding the cache by hand.
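
For illustration, a minimal sketch of the new contract (adapter name and model IDs are hypothetical; a real adapter also implements the mixin's abstract configuration hooks):

```python
from collections.abc import Iterable

from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin


class ExampleAdapter(OpenAIMixin):
    """Hypothetical adapter for illustration only."""

    async def list_provider_model_ids(self) -> Iterable[str]:
        # Return plain IDs; the mixin turns them into Model objects and
        # caches them. Any iterable works: list, tuple, set, or generator.
        return ["example-llm", "example-embedding"]
```

Adapters that don't need custom listing simply inherit the default implementation, which queries the OpenAI-compatible models endpoint via the `AsyncOpenAI` client.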

## Test Plan

CI w/ new tests
Matthew Farrellee 2025-10-06 09:44:33 -04:00 committed by GitHub
parent f00bcd9561
commit 724dac498c
3 changed files with 164 additions and 39 deletions

Databricks inference adapter (DatabricksInferenceAdapter)

@@ -4,16 +4,15 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+from collections.abc import Iterable
 from typing import Any

 from databricks.sdk import WorkspaceClient

 from llama_stack.apis.inference import (
     Inference,
-    Model,
     OpenAICompletion,
 )
-from llama_stack.apis.models import ModelType
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin

@@ -72,31 +71,13 @@ class DatabricksInferenceAdapter(
     ) -> OpenAICompletion:
         raise NotImplementedError()

-    async def list_models(self) -> list[Model] | None:
-        self._model_cache = {}  # from OpenAIMixin
-        ws_client = WorkspaceClient(host=self.config.url, token=self.get_api_key())  # TODO: this is not async
-        endpoints = ws_client.serving_endpoints.list()
-        for endpoint in endpoints:
-            model = Model(
-                provider_id=self.__provider_id__,
-                provider_resource_id=endpoint.name,
-                identifier=endpoint.name,
-            )
-            if endpoint.task == "llm/v1/chat":
-                model.model_type = ModelType.llm  # this is redundant, but informative
-            elif endpoint.task == "llm/v1/embeddings":
-                if endpoint.name not in self.embedding_model_metadata:
-                    logger.warning(f"No metadata information available for embedding model {endpoint.name}, skipping.")
-                    continue
-                model.model_type = ModelType.embedding
-                model.metadata = self.embedding_model_metadata[endpoint.name]
-            else:
-                logger.warning(f"Unknown model type, skipping: {endpoint}")
-                continue
-            self._model_cache[endpoint.name] = model
-        return list(self._model_cache.values())
+    async def list_provider_model_ids(self) -> Iterable[str]:
+        return [
+            endpoint.name
+            for endpoint in WorkspaceClient(
+                host=self.config.url, token=self.get_api_key()
+            ).serving_endpoints.list()  # TODO: this is not async
+        ]

     async def should_refresh_models(self) -> bool:
         return False
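
The inline TODO flags that `WorkspaceClient(...).serving_endpoints.list()` is a synchronous SDK call inside an async method. One possible follow-up, not part of this commit, would move the call onto a worker thread; a sketch assuming the same config fields and helpers as the adapter above:

```python
import asyncio
from collections.abc import Iterable

from databricks.sdk import WorkspaceClient


# Hypothetical method body for DatabricksInferenceAdapter.
async def list_provider_model_ids(self) -> Iterable[str]:
    def _list_endpoint_names() -> list[str]:
        # Blocking SDK calls run here, off the event loop.
        client = WorkspaceClient(host=self.config.url, token=self.get_api_key())
        return [endpoint.name for endpoint in client.serving_endpoints.list()]

    # asyncio.to_thread keeps the event loop free while the SDK blocks.
    return await asyncio.to_thread(_list_endpoint_names)
```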

llama_stack/providers/utils/inference/openai_mixin.py

@@ -7,7 +7,7 @@
 import base64
 import uuid
 from abc import ABC, abstractmethod
-from collections.abc import AsyncIterator
+from collections.abc import AsyncIterator, Iterable
 from typing import Any

 from openai import NOT_GIVEN, AsyncOpenAI

@@ -111,6 +111,18 @@ class OpenAIMixin(ModelsProtocolPrivate, NeedsRequestProviderData, ABC):
         """
         return {}

+    async def list_provider_model_ids(self) -> Iterable[str]:
+        """
+        List available models from the provider.
+
+        Child classes can override this method to provide a custom implementation
+        for listing models. The default implementation uses the AsyncOpenAI client
+        to list models from the OpenAI-compatible endpoint.
+
+        :return: An iterable of model IDs or None if not implemented
+        """
+        return [m.id async for m in self.client.models.list()]
+
     @property
     def client(self) -> AsyncOpenAI:
         """

@@ -387,28 +399,36 @@
         """
         self._model_cache = {}

-        async for m in self.client.models.list():
-            if self.allowed_models and m.id not in self.allowed_models:
-                logger.info(f"Skipping model {m.id} as it is not in the allowed models list")
+        # give subclasses a chance to provide custom model listing
+        iterable = await self.list_provider_model_ids()
+        if not hasattr(iterable, "__iter__"):
+            raise TypeError(
+                f"Failed to list models: {self.__class__.__name__}.list_provider_model_ids() must return an iterable of "
+                f"strings or None, but returned {type(iterable).__name__}"
+            )
+        provider_models_ids = list(iterable)
+        logger.info(f"{self.__class__.__name__}.list_provider_model_ids() returned {len(provider_models_ids)} models")
+
+        for provider_model_id in provider_models_ids:
+            if self.allowed_models and provider_model_id not in self.allowed_models:
+                logger.info(f"Skipping model {provider_model_id} as it is not in the allowed models list")
                 continue
-            if metadata := self.embedding_model_metadata.get(m.id):
-                # This is an embedding model - augment with metadata
+            if metadata := self.embedding_model_metadata.get(provider_model_id):
                 model = Model(
                     provider_id=self.__provider_id__,  # type: ignore[attr-defined]
-                    provider_resource_id=m.id,
-                    identifier=m.id,
+                    provider_resource_id=provider_model_id,
+                    identifier=provider_model_id,
                     model_type=ModelType.embedding,
                     metadata=metadata,
                 )
             else:
-                # This is an LLM
                 model = Model(
                     provider_id=self.__provider_id__,  # type: ignore[attr-defined]
-                    provider_resource_id=m.id,
-                    identifier=m.id,
+                    provider_resource_id=provider_model_id,
+                    identifier=provider_model_id,
                     model_type=ModelType.llm,
                 )
-            self._model_cache[m.id] = model
+            self._model_cache[provider_model_id] = model

         return list(self._model_cache.values())
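
Note how `list_models()` now derives the model type: IDs present in `embedding_model_metadata` are registered as `ModelType.embedding` with that metadata attached, and every other ID falls through to `ModelType.llm`. A sketch of an adapter opting one model into the embedding path (adapter name, IDs, and metadata values are hypothetical, mirroring the test fixture below):

```python
from collections.abc import Iterable

from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin


class ExampleEmbeddingAdapter(OpenAIMixin):
    """Hypothetical adapter for illustration only."""

    # IDs listed here become embedding models with this metadata attached;
    # any other returned ID is registered as an LLM.
    embedding_model_metadata = {
        "example-embedding": {"embedding_dimension": 768, "context_length": 512},
    }

    async def list_provider_model_ids(self) -> Iterable[str]:
        return ["example-llm", "example-embedding"]
```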

Unit tests for OpenAIMixin

@@ -5,6 +5,7 @@
 # the root directory of this source tree.

 import json
+from collections.abc import Iterable
 from unittest.mock import AsyncMock, MagicMock, Mock, PropertyMock, patch

 import pytest

@@ -498,6 +499,129 @@ class OpenAIMixinWithProviderData(OpenAIMixinImpl):
         return "default-base-url"


+class CustomListProviderModelIdsImplementation(OpenAIMixinImpl):
+    """Test implementation with custom list_provider_model_ids override"""
+
+    def __init__(self, custom_model_ids):
+        self._custom_model_ids = custom_model_ids
+
+    async def list_provider_model_ids(self) -> Iterable[str]:
+        """Return custom model IDs list"""
+        return self._custom_model_ids
+
+
+class TestOpenAIMixinCustomListProviderModelIds:
+    """Test cases for custom list_provider_model_ids() implementation functionality"""
+
+    @pytest.fixture
+    def custom_model_ids_list(self):
+        """Create a list of custom model ID strings"""
+        return ["custom-model-1", "custom-model-2", "custom-embedding"]
+
+    @pytest.fixture
+    def adapter(self, custom_model_ids_list):
+        """Create mixin instance with custom list_provider_model_ids implementation"""
+        mixin = CustomListProviderModelIdsImplementation(custom_model_ids=custom_model_ids_list)
+        mixin.embedding_model_metadata = {"custom-embedding": {"embedding_dimension": 768, "context_length": 512}}
+        return mixin
+
+    async def test_is_used(self, adapter, custom_model_ids_list):
+        """Test that custom list_provider_model_ids() implementation is used instead of client.models.list()"""
+        result = await adapter.list_models()
+
+        assert result is not None
+        assert len(result) == 3
+        assert set(custom_model_ids_list) == {m.identifier for m in result}
+
+    async def test_populates_cache(self, adapter, custom_model_ids_list):
+        """Test that custom list_provider_model_ids() results are cached"""
+        assert len(adapter._model_cache) == 0
+
+        await adapter.list_models()
+
+        assert set(custom_model_ids_list) == set(adapter._model_cache.keys())
+
+    async def test_respects_allowed_models(self):
+        """Test that custom list_provider_model_ids() respects allowed_models filtering"""
+        mixin = CustomListProviderModelIdsImplementation(custom_model_ids=["model-1", "model-2", "model-3"])
+        mixin.allowed_models = ["model-1"]
+
+        result = await mixin.list_models()
+
+        assert result is not None
+        assert len(result) == 1
+        assert result[0].identifier == "model-1"
+
+    async def test_with_empty_list(self):
+        """Test that custom list_provider_model_ids() handles empty list correctly"""
+        mixin = CustomListProviderModelIdsImplementation(custom_model_ids=[])
+
+        result = await mixin.list_models()
+
+        assert result is not None
+        assert len(result) == 0
+        assert len(mixin._model_cache) == 0
+
+    async def test_wrong_type_raises_error(self):
+        """Test that list_provider_model_ids() returning unhashable items results in an error"""
+        mixin = CustomListProviderModelIdsImplementation(custom_model_ids=[["nested", "list"], {"key": "value"}])
+
+        with pytest.raises(TypeError, match="unhashable type"):
+            await mixin.list_models()
+
+    async def test_non_iterable_raises_error(self):
+        """Test that list_provider_model_ids() returning non-iterable type raises error"""
+        mixin = CustomListProviderModelIdsImplementation(custom_model_ids=42)
+
+        with pytest.raises(
+            TypeError,
+            match=r"Failed to list models: CustomListProviderModelIdsImplementation\.list_provider_model_ids\(\) must return an iterable.*but returned int",
+        ):
+            await mixin.list_models()
+
+    async def test_with_none_items_raises_error(self):
+        """Test that list_provider_model_ids() returning list with None items causes error"""
+        mixin = CustomListProviderModelIdsImplementation(custom_model_ids=[None, "valid-model", None])
+
+        with pytest.raises(Exception, match="Input should be a valid string"):
+            await mixin.list_models()
+
+    async def test_accepts_various_iterables(self):
+        """Test that list_provider_model_ids() accepts tuples, sets, generators, etc."""
+
+        class TupleAdapter(OpenAIMixinImpl):
+            async def list_provider_model_ids(self) -> Iterable[str] | None:
+                return ("model-1", "model-2", "model-3")
+
+        mixin = TupleAdapter()
+        result = await mixin.list_models()
+        assert result is not None
+        assert len(result) == 3
+
+        class GeneratorAdapter(OpenAIMixinImpl):
+            async def list_provider_model_ids(self) -> Iterable[str] | None:
+                def gen():
+                    yield "gen-model-1"
+                    yield "gen-model-2"
+
+                return gen()
+
+        mixin = GeneratorAdapter()
+        result = await mixin.list_models()
+        assert result is not None
+        assert len(result) == 2
+
+        class SetAdapter(OpenAIMixinImpl):
+            async def list_provider_model_ids(self) -> Iterable[str] | None:
+                return {"set-model-1", "set-model-2"}
+
+        mixin = SetAdapter()
+        result = await mixin.list_models()
+        assert result is not None
+        assert len(result) == 2
+
+
 class TestOpenAIMixinProviderDataApiKey:
     """Test cases for provider_data_api_key_field functionality"""