chore: OpenAIMixin implements ModelsProtocolPrivate (#3662)

# What does this PR do?

add ModelsProtocolPrivate methods to OpenAIMixin

this will allow providers using OpenAIMixin to use a common interface


## Test Plan

ci w/ new tests
This commit is contained in:
Matthew Farrellee 2025-10-03 00:32:02 -04:00 committed by GitHub
parent 14a94e9894
commit 0a41c4ead0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 243 additions and 11 deletions

View file

@ -26,14 +26,14 @@ from llama_stack.apis.inference import (
from llama_stack.apis.models import ModelType
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
from llama_stack.providers.utils.inference.prompt_adapter import localize_image_content
logger = get_logger(name=__name__, category="providers::utils")
class OpenAIMixin(ModelRegistryHelper, NeedsRequestProviderData, ABC):
class OpenAIMixin(ModelsProtocolPrivate, NeedsRequestProviderData, ABC):
"""
Mixin class that provides OpenAI-specific functionality for inference providers.
This class handles direct OpenAI API calls using the AsyncOpenAI client.
@ -73,6 +73,9 @@ class OpenAIMixin(ModelRegistryHelper, NeedsRequestProviderData, ABC):
# Optional field name in provider data to look for API key, which takes precedence
provider_data_api_key_field: str | None = None
# automatically set by the resolver when instantiating the provider
__provider_id__: str
@abstractmethod
def get_api_key(self) -> str:
"""
@ -356,6 +359,24 @@ class OpenAIMixin(ModelRegistryHelper, NeedsRequestProviderData, ABC):
usage=usage,
)
###
# ModelsProtocolPrivate implementation - provide model management functionality
#
# async def register_model(self, model: Model) -> Model: ...
# async def unregister_model(self, model_id: str) -> None: ...
#
# async def list_models(self) -> list[Model] | None: ...
# async def should_refresh_models(self) -> bool: ...
##
async def register_model(self, model: Model) -> Model:
if not await self.check_model_availability(model.provider_model_id):
raise ValueError(f"Model {model.provider_model_id} is not available from provider {self.__provider_id__}")
return model
async def unregister_model(self, model_id: str) -> None:
return None
async def list_models(self) -> list[Model] | None:
"""
List available models from the provider's /v1/models endpoint augmented with static embedding model metadata.
@ -400,5 +421,7 @@ class OpenAIMixin(ModelRegistryHelper, NeedsRequestProviderData, ABC):
"""
if not self._model_cache:
await self.list_models()
return model in self._model_cache
async def should_refresh_models(self) -> bool:
return False