mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
chore: add provider-data-api-key support to openaimixin
This commit is contained in:
parent
f7c5ef4ec0
commit
2be6ec2ad8
5 changed files with 116 additions and 17 deletions
|
@ -24,6 +24,7 @@ from llama_stack.apis.inference import (
|
|||
OpenAIResponseFormatParam,
|
||||
)
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.core.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
|
||||
|
@ -32,7 +33,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import localize_image_
|
|||
logger = get_logger(name=__name__, category="providers::utils")
|
||||
|
||||
|
||||
class OpenAIMixin(ModelRegistryHelper, ABC):
|
||||
class OpenAIMixin(ModelRegistryHelper, NeedsRequestProviderData, ABC):
|
||||
"""
|
||||
Mixin class that provides OpenAI-specific functionality for inference providers.
|
||||
This class handles direct OpenAI API calls using the AsyncOpenAI client.
|
||||
|
@ -69,6 +70,9 @@ class OpenAIMixin(ModelRegistryHelper, ABC):
|
|||
# List of allowed models for this provider, if empty all models allowed
|
||||
allowed_models: list[str] = []
|
||||
|
||||
# Optional field name in provider data to look for API key, which takes precedence
|
||||
provider_data_api_key_field: str | None = None
|
||||
|
||||
@abstractmethod
|
||||
def get_api_key(self) -> str:
|
||||
"""
|
||||
|
@ -111,9 +115,28 @@ class OpenAIMixin(ModelRegistryHelper, ABC):
|
|||
|
||||
Uses the abstract methods get_api_key() and get_base_url() which must be
|
||||
implemented by child classes.
|
||||
|
||||
Users can also provide the API key via the provider data header, which
|
||||
is used instead of any config API key.
|
||||
"""
|
||||
|
||||
api_key = self.get_api_key()
|
||||
|
||||
if self.provider_data_api_key_field:
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data and getattr(provider_data, self.provider_data_api_key_field, None):
|
||||
api_key = getattr(provider_data, self.provider_data_api_key_field)
|
||||
|
||||
if not api_key: # TODO: let get_api_key return None
|
||||
raise ValueError(
|
||||
"API key is not set. Please provide a valid API key in the "
|
||||
"provider data header, e.g. x-llamastack-provider-data: "
|
||||
f'{{"{self.provider_data_api_key_field}": "<API_KEY>"}}, '
|
||||
"or in the provider config."
|
||||
)
|
||||
|
||||
return AsyncOpenAI(
|
||||
api_key=self.get_api_key(),
|
||||
api_key=api_key,
|
||||
base_url=self.get_base_url(),
|
||||
**self.get_extra_client_params(),
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue