mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-03 19:57:35 +00:00
chore: add provider-data-api-key support to openaimixin (#3639)
# What does this PR do? the LiteLLMOpenAIMixin provides support for reading key from provider data (headers users send). this adds the same functionality to the OpenAIMixin. this is infrastructure for migrating providers. ## Test Plan ci w/ new tests
This commit is contained in:
parent
28bbbcf2c1
commit
4dbe0593f9
5 changed files with 116 additions and 17 deletions
|
@ -59,7 +59,7 @@ class LiteLLMOpenAIMixin(
|
|||
self,
|
||||
litellm_provider_name: str,
|
||||
api_key_from_config: str | None,
|
||||
provider_data_api_key_field: str,
|
||||
provider_data_api_key_field: str | None = None,
|
||||
model_entries: list[ProviderModelEntry] | None = None,
|
||||
openai_compat_api_base: str | None = None,
|
||||
download_images: bool = False,
|
||||
|
@ -70,7 +70,7 @@ class LiteLLMOpenAIMixin(
|
|||
|
||||
:param model_entries: The model entries to register.
|
||||
:param api_key_from_config: The API key to use from the config.
|
||||
:param provider_data_api_key_field: The field in the provider data that contains the API key.
|
||||
:param provider_data_api_key_field: The field in the provider data that contains the API key (optional).
|
||||
:param litellm_provider_name: The name of the provider, used for model lookups.
|
||||
:param openai_compat_api_base: The base URL for OpenAI compatibility, or None if not using OpenAI compatibility.
|
||||
:param download_images: Whether to download images and convert to base64 for message conversion.
|
||||
|
|
|
@ -63,7 +63,7 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
|
|||
model_entries: list[ProviderModelEntry] | None = None,
|
||||
allowed_models: list[str] | None = None,
|
||||
):
|
||||
self.allowed_models = allowed_models
|
||||
self.allowed_models = allowed_models if allowed_models else []
|
||||
|
||||
self.alias_to_provider_id_map = {}
|
||||
self.provider_id_to_llama_model_map = {}
|
||||
|
|
|
@ -24,6 +24,7 @@ from llama_stack.apis.inference import (
|
|||
OpenAIResponseFormatParam,
|
||||
)
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.core.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
|
||||
|
@ -32,7 +33,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import localize_image_
|
|||
logger = get_logger(name=__name__, category="providers::utils")
|
||||
|
||||
|
||||
class OpenAIMixin(ModelRegistryHelper, ABC):
|
||||
class OpenAIMixin(ModelRegistryHelper, NeedsRequestProviderData, ABC):
|
||||
"""
|
||||
Mixin class that provides OpenAI-specific functionality for inference providers.
|
||||
This class handles direct OpenAI API calls using the AsyncOpenAI client.
|
||||
|
@ -69,6 +70,9 @@ class OpenAIMixin(ModelRegistryHelper, ABC):
|
|||
# List of allowed models for this provider, if empty all models allowed
|
||||
allowed_models: list[str] = []
|
||||
|
||||
# Optional field name in provider data to look for API key, which takes precedence
|
||||
provider_data_api_key_field: str | None = None
|
||||
|
||||
@abstractmethod
|
||||
def get_api_key(self) -> str:
|
||||
"""
|
||||
|
@ -111,9 +115,28 @@ class OpenAIMixin(ModelRegistryHelper, ABC):
|
|||
|
||||
Uses the abstract methods get_api_key() and get_base_url() which must be
|
||||
implemented by child classes.
|
||||
|
||||
Users can also provide the API key via the provider data header, which
|
||||
is used instead of any config API key.
|
||||
"""
|
||||
|
||||
api_key = self.get_api_key()
|
||||
|
||||
if self.provider_data_api_key_field:
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data and getattr(provider_data, self.provider_data_api_key_field, None):
|
||||
api_key = getattr(provider_data, self.provider_data_api_key_field)
|
||||
|
||||
if not api_key: # TODO: let get_api_key return None
|
||||
raise ValueError(
|
||||
"API key is not set. Please provide a valid API key in the "
|
||||
"provider data header, e.g. x-llamastack-provider-data: "
|
||||
f'{{"{self.provider_data_api_key_field}": "<API_KEY>"}}, '
|
||||
"or in the provider config."
|
||||
)
|
||||
|
||||
return AsyncOpenAI(
|
||||
api_key=self.get_api_key(),
|
||||
api_key=api_key,
|
||||
base_url=self.get_base_url(),
|
||||
**self.get_extra_client_params(),
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue