mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
chore: OpenAIMixin implements ModelsProtocolPrivate
This commit is contained in:
parent
ceca3c056f
commit
ad24a2c463
8 changed files with 243 additions and 11 deletions
|
@ -26,14 +26,14 @@ from llama_stack.apis.inference import (
|
|||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.core.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import localize_image_content
|
||||
|
||||
logger = get_logger(name=__name__, category="providers::utils")
|
||||
|
||||
|
||||
class OpenAIMixin(ModelRegistryHelper, NeedsRequestProviderData, ABC):
|
||||
class OpenAIMixin(ModelsProtocolPrivate, NeedsRequestProviderData, ABC):
|
||||
"""
|
||||
Mixin class that provides OpenAI-specific functionality for inference providers.
|
||||
This class handles direct OpenAI API calls using the AsyncOpenAI client.
|
||||
|
@ -73,6 +73,9 @@ class OpenAIMixin(ModelRegistryHelper, NeedsRequestProviderData, ABC):
|
|||
# Optional field name in provider data to look for API key, which takes precedence
|
||||
provider_data_api_key_field: str | None = None
|
||||
|
||||
# automatically set by the resolver when instantiating the provider
|
||||
__provider_id__: str
|
||||
|
||||
@abstractmethod
|
||||
def get_api_key(self) -> str:
|
||||
"""
|
||||
|
@ -356,6 +359,24 @@ class OpenAIMixin(ModelRegistryHelper, NeedsRequestProviderData, ABC):
|
|||
usage=usage,
|
||||
)
|
||||
|
||||
###
|
||||
# ModelsProtocolPrivate implementation - provide model management functionality
|
||||
#
|
||||
# async def register_model(self, model: Model) -> Model: ...
|
||||
# async def unregister_model(self, model_id: str) -> None: ...
|
||||
#
|
||||
# async def list_models(self) -> list[Model] | None: ...
|
||||
# async def should_refresh_models(self) -> bool: ...
|
||||
##
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
if not await self.check_model_availability(model.provider_model_id):
|
||||
raise ValueError(f"Model {model.provider_model_id} is not available from provider {self.__provider_id__}")
|
||||
return model
|
||||
|
||||
async def unregister_model(self, model_id: str) -> None:
|
||||
return None
|
||||
|
||||
async def list_models(self) -> list[Model] | None:
|
||||
"""
|
||||
List available models from the provider's /v1/models endpoint augmented with static embedding model metadata.
|
||||
|
@ -400,5 +421,7 @@ class OpenAIMixin(ModelRegistryHelper, NeedsRequestProviderData, ABC):
|
|||
"""
|
||||
if not self._model_cache:
|
||||
await self.list_models()
|
||||
|
||||
return model in self._model_cache
|
||||
|
||||
async def should_refresh_models(self) -> bool:
|
||||
return False
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue