mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-11 13:44:38 +00:00
feat: use SecretStr for inference provider auth credentials (#3724)
# What does this PR do? use SecretStr for OpenAIMixin providers - RemoteInferenceProviderConfig now has auth_credential: SecretStr - the default alias is api_key (most common name) - some providers override to use api_token (RunPod, vLLM, Databricks) - some providers exclude it (Ollama, TGI, Vertex AI) addresses #3517 ## Test Plan ci w/ new tests
This commit is contained in:
parent
6d8f61206e
commit
0066d986c5
57 changed files with 158 additions and 149 deletions
|
@ -29,9 +29,6 @@ class AnthropicInferenceAdapter(OpenAIMixin):
|
|||
# "voyage-multimodal-3": {"embedding_dimension": 1024, "context_length": 32000},
|
||||
# }
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
return self.config.api_key or ""
|
||||
|
||||
def get_base_url(self):
|
||||
return "https://api.anthropic.com/v1"
|
||||
|
||||
|
|
|
@ -21,11 +21,6 @@ class AnthropicProviderDataValidator(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class AnthropicConfig(RemoteInferenceProviderConfig):
|
||||
api_key: str | None = Field(
|
||||
default=None,
|
||||
description="API key for Anthropic models",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.ANTHROPIC_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
|
|
|
@ -16,9 +16,6 @@ class AzureInferenceAdapter(OpenAIMixin):
|
|||
|
||||
provider_data_api_key_field: str = "azure_api_key"
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
return self.config.api_key.get_secret_value()
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
"""
|
||||
Get the Azure API base URL.
|
||||
|
|
|
@ -32,9 +32,6 @@ class AzureProviderDataValidator(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class AzureConfig(RemoteInferenceProviderConfig):
|
||||
api_key: SecretStr = Field(
|
||||
description="Azure API key for Azure",
|
||||
)
|
||||
api_base: HttpUrl = Field(
|
||||
description="Azure API base for Azure (e.g., https://your-resource-name.openai.azure.com)",
|
||||
)
|
||||
|
|
|
@ -15,9 +15,6 @@ from .config import CerebrasImplConfig
|
|||
class CerebrasInferenceAdapter(OpenAIMixin):
|
||||
config: CerebrasImplConfig
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
return self.config.api_key.get_secret_value()
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
return urljoin(self.config.base_url, "v1")
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
import os
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field, SecretStr
|
||||
from pydantic import Field
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
@ -21,10 +21,6 @@ class CerebrasImplConfig(RemoteInferenceProviderConfig):
|
|||
default=os.environ.get("CEREBRAS_BASE_URL", DEFAULT_BASE_URL),
|
||||
description="Base URL for the Cerebras API",
|
||||
)
|
||||
api_key: SecretStr = Field(
|
||||
default=SecretStr(os.environ.get("CEREBRAS_API_KEY")), # type: ignore[arg-type]
|
||||
description="Cerebras API Key",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.CEREBRAS_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
||||
|
|
|
@ -18,8 +18,9 @@ class DatabricksImplConfig(RemoteInferenceProviderConfig):
|
|||
default=None,
|
||||
description="The URL for the Databricks model serving endpoint",
|
||||
)
|
||||
api_token: SecretStr = Field(
|
||||
default=SecretStr(None), # type: ignore[arg-type]
|
||||
auth_credential: SecretStr | None = Field(
|
||||
default=None,
|
||||
alias="api_token",
|
||||
description="The Databricks API token",
|
||||
)
|
||||
|
||||
|
|
|
@ -27,9 +27,6 @@ class DatabricksInferenceAdapter(OpenAIMixin):
|
|||
"databricks-bge-large-en": {"embedding_dimension": 1024, "context_length": 512},
|
||||
}
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
return self.config.api_token.get_secret_value()
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
return f"{self.config.url}/serving-endpoints"
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field, SecretStr
|
||||
from pydantic import Field
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
@ -18,10 +18,6 @@ class FireworksImplConfig(RemoteInferenceProviderConfig):
|
|||
default="https://api.fireworks.ai/inference/v1",
|
||||
description="The URL for the Fireworks server",
|
||||
)
|
||||
api_key: SecretStr | None = Field(
|
||||
default=None,
|
||||
description="The Fireworks.ai API Key",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.FIREWORKS_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
||||
|
|
|
@ -23,8 +23,5 @@ class FireworksInferenceAdapter(OpenAIMixin):
|
|||
|
||||
provider_data_api_key_field: str = "fireworks_api_key"
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
return self.config.api_key.get_secret_value() if self.config.api_key else None # type: ignore[return-value]
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
return "https://api.fireworks.ai/inference/v1"
|
||||
|
|
|
@ -21,11 +21,6 @@ class GeminiProviderDataValidator(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class GeminiConfig(RemoteInferenceProviderConfig):
|
||||
api_key: str | None = Field(
|
||||
default=None,
|
||||
description="API key for Gemini models",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.GEMINI_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
|
|
|
@ -17,8 +17,5 @@ class GeminiInferenceAdapter(OpenAIMixin):
|
|||
"text-embedding-004": {"embedding_dimension": 768, "context_length": 2048},
|
||||
}
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
return self.config.api_key or ""
|
||||
|
||||
def get_base_url(self):
|
||||
return "https://generativelanguage.googleapis.com/v1beta/openai/"
|
||||
|
|
|
@ -21,12 +21,6 @@ class GroqProviderDataValidator(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class GroqConfig(RemoteInferenceProviderConfig):
|
||||
api_key: str | None = Field(
|
||||
# The Groq client library loads the GROQ_API_KEY environment variable by default
|
||||
default=None,
|
||||
description="The Groq API key",
|
||||
)
|
||||
|
||||
url: str = Field(
|
||||
default="https://api.groq.com",
|
||||
description="The URL for the Groq AI server",
|
||||
|
|
|
@ -14,8 +14,5 @@ class GroqInferenceAdapter(OpenAIMixin):
|
|||
|
||||
provider_data_api_key_field: str = "groq_api_key"
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
return self.config.api_key or ""
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
return f"{self.config.url}/openai/v1"
|
||||
|
|
|
@ -21,11 +21,6 @@ class LlamaProviderDataValidator(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class LlamaCompatConfig(RemoteInferenceProviderConfig):
|
||||
api_key: str | None = Field(
|
||||
default=None,
|
||||
description="The Llama API key",
|
||||
)
|
||||
|
||||
openai_compat_api_base: str = Field(
|
||||
default="https://api.llama.com/compat/v1/",
|
||||
description="The URL for the Llama API server",
|
||||
|
|
|
@ -21,9 +21,6 @@ class LlamaCompatInferenceAdapter(OpenAIMixin):
|
|||
Llama API Inference Adapter for Llama Stack.
|
||||
"""
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
return self.config.api_key or ""
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
"""
|
||||
Get the base URL for OpenAI mixin.
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
import os
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field, SecretStr
|
||||
from pydantic import Field
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
@ -40,10 +40,6 @@ class NVIDIAConfig(RemoteInferenceProviderConfig):
|
|||
default_factory=lambda: os.getenv("NVIDIA_BASE_URL", "https://integrate.api.nvidia.com"),
|
||||
description="A base url for accessing the NVIDIA NIM",
|
||||
)
|
||||
api_key: SecretStr | None = Field(
|
||||
default_factory=lambda: SecretStr(os.getenv("NVIDIA_API_KEY")),
|
||||
description="The NVIDIA API key, only needed of using the hosted service",
|
||||
)
|
||||
timeout: int = Field(
|
||||
default=60,
|
||||
description="Timeout for the HTTP requests",
|
||||
|
|
|
@ -49,7 +49,7 @@ class NVIDIAInferenceAdapter(OpenAIMixin):
|
|||
logger.info(f"Initializing NVIDIAInferenceAdapter({self.config.url})...")
|
||||
|
||||
if _is_nvidia_hosted(self.config):
|
||||
if not self.config.api_key:
|
||||
if not self.config.auth_credential:
|
||||
raise RuntimeError(
|
||||
"API key is required for hosted NVIDIA NIM. Either provide an API key or use a self-hosted NIM."
|
||||
)
|
||||
|
@ -60,7 +60,13 @@ class NVIDIAInferenceAdapter(OpenAIMixin):
|
|||
|
||||
:return: The NVIDIA API key
|
||||
"""
|
||||
return self.config.api_key.get_secret_value() if self.config.api_key else "NO KEY"
|
||||
if self.config.auth_credential:
|
||||
return self.config.auth_credential.get_secret_value()
|
||||
|
||||
if not _is_nvidia_hosted(self.config):
|
||||
return "NO KEY REQUIRED"
|
||||
|
||||
return None
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
"""
|
||||
|
|
|
@ -6,12 +6,16 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field, SecretStr
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
|
||||
DEFAULT_OLLAMA_URL = "http://localhost:11434"
|
||||
|
||||
|
||||
class OllamaImplConfig(RemoteInferenceProviderConfig):
|
||||
auth_credential: SecretStr | None = Field(default=None, exclude=True)
|
||||
|
||||
url: str = DEFAULT_OLLAMA_URL
|
||||
|
||||
@classmethod
|
||||
|
|
|
@ -59,7 +59,7 @@ class OllamaInferenceAdapter(OpenAIMixin):
|
|||
return self._clients[loop]
|
||||
|
||||
def get_api_key(self):
|
||||
return "NO_KEY"
|
||||
return "NO KEY REQUIRED"
|
||||
|
||||
def get_base_url(self):
|
||||
return self.config.url.rstrip("/") + "/v1"
|
||||
|
|
|
@ -21,10 +21,6 @@ class OpenAIProviderDataValidator(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class OpenAIConfig(RemoteInferenceProviderConfig):
|
||||
api_key: str | None = Field(
|
||||
default=None,
|
||||
description="API key for OpenAI models",
|
||||
)
|
||||
base_url: str = Field(
|
||||
default="https://api.openai.com/v1",
|
||||
description="Base URL for OpenAI API",
|
||||
|
|
|
@ -29,9 +29,6 @@ class OpenAIInferenceAdapter(OpenAIMixin):
|
|||
"text-embedding-3-large": {"embedding_dimension": 3072, "context_length": 8192},
|
||||
}
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
return self.config.api_key or ""
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
"""
|
||||
Get the OpenAI API base URL.
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic import Field, SecretStr
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
@ -18,8 +18,9 @@ class RunpodImplConfig(RemoteInferenceProviderConfig):
|
|||
default=None,
|
||||
description="The URL for the Runpod model serving endpoint",
|
||||
)
|
||||
api_token: str | None = Field(
|
||||
auth_credential: SecretStr | None = Field(
|
||||
default=None,
|
||||
alias="api_token",
|
||||
description="The API token",
|
||||
)
|
||||
|
||||
|
|
|
@ -24,10 +24,6 @@ class RunpodInferenceAdapter(OpenAIMixin):
|
|||
|
||||
config: RunpodImplConfig
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
"""Get API key for OpenAI client."""
|
||||
return self.config.api_token
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
"""Get base URL for OpenAI client."""
|
||||
return self.config.url
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
@ -25,10 +25,6 @@ class SambaNovaImplConfig(RemoteInferenceProviderConfig):
|
|||
default="https://api.sambanova.ai/v1",
|
||||
description="The URL for the SambaNova AI server",
|
||||
)
|
||||
api_key: SecretStr | None = Field(
|
||||
default=None,
|
||||
description="The SambaNova cloud API Key",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
||||
|
|
|
@ -19,9 +19,6 @@ class SambaNovaInferenceAdapter(OpenAIMixin):
|
|||
SambaNova Inference Adapter for Llama Stack.
|
||||
"""
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
return self.config.api_key.get_secret_value() if self.config.api_key else ""
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
"""
|
||||
Get the base URL for OpenAI mixin.
|
||||
|
|
|
@ -13,6 +13,8 @@ from llama_stack.schema_utils import json_schema_type
|
|||
|
||||
@json_schema_type
|
||||
class TGIImplConfig(RemoteInferenceProviderConfig):
|
||||
auth_credential: SecretStr | None = Field(default=None, exclude=True)
|
||||
|
||||
url: str = Field(
|
||||
description="The URL for the TGI serving endpoint",
|
||||
)
|
||||
|
|
|
@ -30,7 +30,7 @@ class _HfAdapter(OpenAIMixin):
|
|||
overwrite_completion_id = True # TGI always returns id=""
|
||||
|
||||
def get_api_key(self):
|
||||
return self.api_key.get_secret_value()
|
||||
return "NO KEY REQUIRED"
|
||||
|
||||
def get_base_url(self):
|
||||
return self.url
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field, SecretStr
|
||||
from pydantic import Field
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
@ -18,10 +18,6 @@ class TogetherImplConfig(RemoteInferenceProviderConfig):
|
|||
default="https://api.together.xyz/v1",
|
||||
description="The URL for the Together AI server",
|
||||
)
|
||||
api_key: SecretStr | None = Field(
|
||||
default=None,
|
||||
description="The Together AI API Key",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
|
||||
|
|
|
@ -39,15 +39,12 @@ class TogetherInferenceAdapter(OpenAIMixin, NeedsRequestProviderData):
|
|||
|
||||
provider_data_api_key_field: str = "together_api_key"
|
||||
|
||||
def get_api_key(self):
|
||||
return self.config.api_key.get_secret_value() if self.config.api_key else None
|
||||
|
||||
def get_base_url(self):
|
||||
return BASE_URL
|
||||
|
||||
def _get_client(self) -> AsyncTogether:
|
||||
together_api_key = None
|
||||
config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
|
||||
config_api_key = self.config.auth_credential.get_secret_value() if self.config.auth_credential else None
|
||||
if config_api_key:
|
||||
together_api_key = config_api_key
|
||||
else:
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
@ -25,6 +25,8 @@ class VertexAIProviderDataValidator(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class VertexAIConfig(RemoteInferenceProviderConfig):
|
||||
auth_credential: SecretStr | None = Field(default=None, exclude=True)
|
||||
|
||||
project: str = Field(
|
||||
description="Google Cloud project ID for Vertex AI",
|
||||
)
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
from pydantic import Field, SecretStr, field_validator
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
@ -22,8 +22,9 @@ class VLLMInferenceAdapterConfig(RemoteInferenceProviderConfig):
|
|||
default=4096,
|
||||
description="Maximum number of tokens to generate.",
|
||||
)
|
||||
api_token: str | None = Field(
|
||||
default="fake",
|
||||
auth_credential: SecretStr | None = Field(
|
||||
default=None,
|
||||
alias="api_token",
|
||||
description="The API token",
|
||||
)
|
||||
tls_verify: bool | str = Field(
|
||||
|
|
|
@ -38,8 +38,10 @@ class VLLMInferenceAdapter(OpenAIMixin):
|
|||
|
||||
provider_data_api_key_field: str = "vllm_api_token"
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
return self.config.api_token or ""
|
||||
def get_api_key(self) -> str | None:
|
||||
if self.config.auth_credential:
|
||||
return self.config.auth_credential.get_secret_value()
|
||||
return "NO KEY REQUIRED"
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
"""Get the base URL from config."""
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
import os
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, SecretStr
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
@ -27,14 +27,6 @@ class WatsonXConfig(RemoteInferenceProviderConfig):
|
|||
default_factory=lambda: os.getenv("WATSONX_BASE_URL", "https://us-south.ml.cloud.ibm.com"),
|
||||
description="A base url for accessing the watsonx.ai",
|
||||
)
|
||||
# This seems like it should be required, but none of the other remote inference
|
||||
# providers require it, so this is optional here too for consistency.
|
||||
# The OpenAIConfig uses default=None instead, so this is following that precedent.
|
||||
api_key: SecretStr | None = Field(
|
||||
default=None,
|
||||
description="The watsonx.ai API key",
|
||||
)
|
||||
# As above, this is optional here too for consistency.
|
||||
project_id: str | None = Field(
|
||||
default=None,
|
||||
description="The watsonx.ai project ID",
|
||||
|
|
|
@ -22,7 +22,7 @@ class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
|
|||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
litellm_provider_name="watsonx",
|
||||
api_key_from_config=config.api_key.get_secret_value() if config.api_key else None,
|
||||
api_key_from_config=config.auth_credential.get_secret_value() if config.auth_credential else None,
|
||||
provider_data_api_key_field="watsonx_api_key",
|
||||
)
|
||||
self.available_models = None
|
||||
|
|
|
@ -12,6 +12,7 @@ from llama_stack.providers.utils.inference.model_registry import RemoteInference
|
|||
|
||||
|
||||
class BedrockBaseConfig(RemoteInferenceProviderConfig):
|
||||
auth_credential: None = Field(default=None, exclude=True)
|
||||
aws_access_key_id: str | None = Field(
|
||||
default_factory=lambda: os.getenv("AWS_ACCESS_KEY_ID"),
|
||||
description="The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID",
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from llama_stack.apis.common.errors import UnsupportedModelError
|
||||
from llama_stack.apis.models import ModelType
|
||||
|
@ -28,6 +28,11 @@ class RemoteInferenceProviderConfig(BaseModel):
|
|||
default=False,
|
||||
description="Whether to refresh models periodically from the provider",
|
||||
)
|
||||
auth_credential: SecretStr | None = Field(
|
||||
default=None,
|
||||
description="Authentication credential for the provider",
|
||||
alias="api_key",
|
||||
)
|
||||
|
||||
|
||||
# TODO: this class is more confusing than useful right now. We need to make it
|
||||
|
|
|
@ -40,7 +40,6 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
|
|||
This class handles direct OpenAI API calls using the AsyncOpenAI client.
|
||||
|
||||
This is an abstract base class that requires child classes to implement:
|
||||
- get_api_key(): Method to retrieve the API key
|
||||
- get_base_url(): Method to retrieve the OpenAI-compatible API base URL
|
||||
|
||||
The behavior of this class can be customized by child classes in the following ways:
|
||||
|
@ -87,17 +86,15 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
|
|||
# Optional field name in provider data to look for API key, which takes precedence
|
||||
provider_data_api_key_field: str | None = None
|
||||
|
||||
@abstractmethod
|
||||
def get_api_key(self) -> str:
|
||||
def get_api_key(self) -> str | None:
|
||||
"""
|
||||
Get the API key.
|
||||
|
||||
This method must be implemented by child classes to provide the API key
|
||||
for authenticating with the OpenAI API or compatible endpoints.
|
||||
|
||||
:return: The API key as a string
|
||||
:return: The API key as a string, or None if not set
|
||||
"""
|
||||
pass
|
||||
if self.config.auth_credential is None:
|
||||
return None
|
||||
return self.config.auth_credential.get_secret_value()
|
||||
|
||||
@abstractmethod
|
||||
def get_base_url(self) -> str:
|
||||
|
@ -176,13 +173,11 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
|
|||
if provider_data and getattr(provider_data, self.provider_data_api_key_field, None):
|
||||
api_key = getattr(provider_data, self.provider_data_api_key_field)
|
||||
|
||||
if not api_key: # TODO: let get_api_key return None
|
||||
raise ValueError(
|
||||
"API key is not set. Please provide a valid API key in the "
|
||||
"provider data header, e.g. x-llamastack-provider-data: "
|
||||
f'{{"{self.provider_data_api_key_field}": "<API_KEY>"}}, '
|
||||
"or in the provider config."
|
||||
)
|
||||
if not api_key:
|
||||
message = "API key not provided."
|
||||
if self.provider_data_api_key_field:
|
||||
message += f' Please provide a valid API key in the provider data header, e.g. x-llamastack-provider-data: {{"{self.provider_data_api_key_field}": "<API_KEY>"}}.'
|
||||
raise ValueError(message)
|
||||
|
||||
return AsyncOpenAI(
|
||||
api_key=api_key,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue