chore: add provider-data-api-key support to openaimixin (#3639)

# What does this PR do?

the LiteLLMOpenAIMixin provides support for reading key from provider
data (headers users send).

this adds the same functionality to the OpenAIMixin.

this is infrastructure for migrating providers.


## Test Plan

ci w/ new tests
This commit is contained in:
Matthew Farrellee 2025-10-01 16:44:59 -04:00 committed by GitHub
parent 28bbbcf2c1
commit 4dbe0593f9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 116 additions and 17 deletions

View file

@ -59,7 +59,7 @@ class LiteLLMOpenAIMixin(
self,
litellm_provider_name: str,
api_key_from_config: str | None,
provider_data_api_key_field: str,
provider_data_api_key_field: str | None = None,
model_entries: list[ProviderModelEntry] | None = None,
openai_compat_api_base: str | None = None,
download_images: bool = False,
@ -70,7 +70,7 @@ class LiteLLMOpenAIMixin(
:param model_entries: The model entries to register.
:param api_key_from_config: The API key to use from the config.
:param provider_data_api_key_field: The field in the provider data that contains the API key.
:param provider_data_api_key_field: The field in the provider data that contains the API key (optional).
:param litellm_provider_name: The name of the provider, used for model lookups.
:param openai_compat_api_base: The base URL for OpenAI compatibility, or None if not using OpenAI compatibility.
:param download_images: Whether to download images and convert to base64 for message conversion.

View file

@ -63,7 +63,7 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
model_entries: list[ProviderModelEntry] | None = None,
allowed_models: list[str] | None = None,
):
self.allowed_models = allowed_models
self.allowed_models = allowed_models if allowed_models else []
self.alias_to_provider_id_map = {}
self.provider_id_to_llama_model_map = {}

View file

@ -24,6 +24,7 @@ from llama_stack.apis.inference import (
OpenAIResponseFormatParam,
)
from llama_stack.apis.models import ModelType
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
@ -32,7 +33,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import localize_image_
logger = get_logger(name=__name__, category="providers::utils")
class OpenAIMixin(ModelRegistryHelper, ABC):
class OpenAIMixin(ModelRegistryHelper, NeedsRequestProviderData, ABC):
"""
Mixin class that provides OpenAI-specific functionality for inference providers.
This class handles direct OpenAI API calls using the AsyncOpenAI client.
@ -69,6 +70,9 @@ class OpenAIMixin(ModelRegistryHelper, ABC):
# List of allowed models for this provider, if empty all models allowed
allowed_models: list[str] = []
# Optional field name in provider data to look for API key, which takes precedence
provider_data_api_key_field: str | None = None
@abstractmethod
def get_api_key(self) -> str:
"""
@ -111,9 +115,28 @@ class OpenAIMixin(ModelRegistryHelper, ABC):
Uses the abstract methods get_api_key() and get_base_url() which must be
implemented by child classes.
Users can also provide the API key via the provider data header, which
is used instead of any config API key.
"""
api_key = self.get_api_key()
if self.provider_data_api_key_field:
provider_data = self.get_request_provider_data()
if provider_data and getattr(provider_data, self.provider_data_api_key_field, None):
api_key = getattr(provider_data, self.provider_data_api_key_field)
if not api_key: # TODO: let get_api_key return None
raise ValueError(
"API key is not set. Please provide a valid API key in the "
"provider data header, e.g. x-llamastack-provider-data: "
f'{{"{self.provider_data_api_key_field}": "<API_KEY>"}}, '
"or in the provider config."
)
return AsyncOpenAI(
api_key=self.get_api_key(),
api_key=api_key,
base_url=self.get_base_url(),
**self.get_extra_client_params(),
)