chore: add provider-data-api-key support to openaimixin

This commit is contained in:
Matthew Farrellee 2025-10-01 14:14:24 -04:00
parent f7c5ef4ec0
commit 2be6ec2ad8
5 changed files with 116 additions and 17 deletions

View file

@ -59,7 +59,7 @@ class LiteLLMOpenAIMixin(
self,
litellm_provider_name: str,
api_key_from_config: str | None,
provider_data_api_key_field: str,
provider_data_api_key_field: str | None = None,
model_entries: list[ProviderModelEntry] | None = None,
openai_compat_api_base: str | None = None,
download_images: bool = False,
@ -70,7 +70,7 @@ class LiteLLMOpenAIMixin(
:param model_entries: The model entries to register.
:param api_key_from_config: The API key to use from the config.
:param provider_data_api_key_field: The field in the provider data that contains the API key.
:param provider_data_api_key_field: The field in the provider data that contains the API key (optional).
:param litellm_provider_name: The name of the provider, used for model lookups.
:param openai_compat_api_base: The base URL for OpenAI compatibility, or None if not using OpenAI compatibility.
:param download_images: Whether to download images and convert to base64 for message conversion.

View file

@ -63,7 +63,7 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
model_entries: list[ProviderModelEntry] | None = None,
allowed_models: list[str] | None = None,
):
self.allowed_models = allowed_models
self.allowed_models = allowed_models if allowed_models else []
self.alias_to_provider_id_map = {}
self.provider_id_to_llama_model_map = {}

View file

@ -24,6 +24,7 @@ from llama_stack.apis.inference import (
OpenAIResponseFormatParam,
)
from llama_stack.apis.models import ModelType
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
@ -32,7 +33,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import localize_image_
logger = get_logger(name=__name__, category="providers::utils")
class OpenAIMixin(ModelRegistryHelper, ABC):
class OpenAIMixin(ModelRegistryHelper, NeedsRequestProviderData, ABC):
"""
Mixin class that provides OpenAI-specific functionality for inference providers.
This class handles direct OpenAI API calls using the AsyncOpenAI client.
@ -69,6 +70,9 @@ class OpenAIMixin(ModelRegistryHelper, ABC):
# List of allowed models for this provider, if empty all models allowed
allowed_models: list[str] = []
# Optional field name in provider data to look for API key, which takes precedence
provider_data_api_key_field: str | None = None
@abstractmethod
def get_api_key(self) -> str:
"""
@ -111,9 +115,28 @@ class OpenAIMixin(ModelRegistryHelper, ABC):
Uses the abstract methods get_api_key() and get_base_url() which must be
implemented by child classes.
Users can also provide the API key via the provider data header, which
is used instead of any config API key.
"""
api_key = self.get_api_key()
if self.provider_data_api_key_field:
provider_data = self.get_request_provider_data()
if provider_data and getattr(provider_data, self.provider_data_api_key_field, None):
api_key = getattr(provider_data, self.provider_data_api_key_field)
if not api_key: # TODO: let get_api_key return None
raise ValueError(
"API key is not set. Please provide a valid API key in the "
"provider data header, e.g. x-llamastack-provider-data: "
f'{{"{self.provider_data_api_key_field}": "<API_KEY>"}}, '
"or in the provider config."
)
return AsyncOpenAI(
api_key=self.get_api_key(),
api_key=api_key,
base_url=self.get_base_url(),
**self.get_extra_client_params(),
)