mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-11 05:38:38 +00:00
feat: use SecretStr for inference provider auth credentials (#3724)
# What does this PR do? use SecretStr for OpenAIMixin providers - RemoteInferenceProviderConfig now has auth_credential: SecretStr - the default alias is api_key (most common name) - some providers override to use api_token (RunPod, vLLM, Databricks) - some providers exclude it (Ollama, TGI, Vertex AI) addresses #3517 ## Test Plan ci w/ new tests
This commit is contained in:
parent
6d8f61206e
commit
0066d986c5
57 changed files with 158 additions and 149 deletions
|
@ -6,7 +6,7 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from llama_stack.apis.common.errors import UnsupportedModelError
|
||||
from llama_stack.apis.models import ModelType
|
||||
|
@ -28,6 +28,11 @@ class RemoteInferenceProviderConfig(BaseModel):
|
|||
default=False,
|
||||
description="Whether to refresh models periodically from the provider",
|
||||
)
|
||||
auth_credential: SecretStr | None = Field(
|
||||
default=None,
|
||||
description="Authentication credential for the provider",
|
||||
alias="api_key",
|
||||
)
|
||||
|
||||
|
||||
# TODO: this class is more confusing than useful right now. We need to make it
|
||||
|
|
|
@ -40,7 +40,6 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
|
|||
This class handles direct OpenAI API calls using the AsyncOpenAI client.
|
||||
|
||||
This is an abstract base class that requires child classes to implement:
|
||||
- get_api_key(): Method to retrieve the API key
|
||||
- get_base_url(): Method to retrieve the OpenAI-compatible API base URL
|
||||
|
||||
The behavior of this class can be customized by child classes in the following ways:
|
||||
|
@ -87,17 +86,15 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
|
|||
# Optional field name in provider data to look for API key, which takes precedence
|
||||
provider_data_api_key_field: str | None = None
|
||||
|
||||
@abstractmethod
|
||||
def get_api_key(self) -> str:
|
||||
def get_api_key(self) -> str | None:
|
||||
"""
|
||||
Get the API key.
|
||||
|
||||
This method must be implemented by child classes to provide the API key
|
||||
for authenticating with the OpenAI API or compatible endpoints.
|
||||
|
||||
:return: The API key as a string
|
||||
:return: The API key as a string, or None if not set
|
||||
"""
|
||||
pass
|
||||
if self.config.auth_credential is None:
|
||||
return None
|
||||
return self.config.auth_credential.get_secret_value()
|
||||
|
||||
@abstractmethod
|
||||
def get_base_url(self) -> str:
|
||||
|
@ -176,13 +173,11 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
|
|||
if provider_data and getattr(provider_data, self.provider_data_api_key_field, None):
|
||||
api_key = getattr(provider_data, self.provider_data_api_key_field)
|
||||
|
||||
if not api_key: # TODO: let get_api_key return None
|
||||
raise ValueError(
|
||||
"API key is not set. Please provide a valid API key in the "
|
||||
"provider data header, e.g. x-llamastack-provider-data: "
|
||||
f'{{"{self.provider_data_api_key_field}": "<API_KEY>"}}, '
|
||||
"or in the provider config."
|
||||
)
|
||||
if not api_key:
|
||||
message = "API key not provided."
|
||||
if self.provider_data_api_key_field:
|
||||
message += f' Please provide a valid API key in the provider data header, e.g. x-llamastack-provider-data: {{"{self.provider_data_api_key_field}": "<API_KEY>"}}.'
|
||||
raise ValueError(message)
|
||||
|
||||
return AsyncOpenAI(
|
||||
api_key=api_key,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue