mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-11 19:56:03 +00:00
attempt to finish the implementation started by matt
This commit is contained in:
parent
6b585fac00
commit
fa4a9ece5b
4 changed files with 66 additions and 44 deletions
|
|
@ -7,7 +7,7 @@
|
|||
import os
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
from pydantic import BaseModel, Field, SecretStr, field_validator
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
|
@ -48,7 +48,7 @@ class NVIDIAConfig(RemoteInferenceProviderConfig):
|
|||
description="A base url for accessing the NVIDIA NIM",
|
||||
)
|
||||
api_key: SecretStr | None = Field(
|
||||
default_factory=lambda: SecretStr(os.getenv("NVIDIA_API_KEY")),
|
||||
default=None,
|
||||
description="The NVIDIA API key, only needed of using the hosted service",
|
||||
)
|
||||
timeout: int = Field(
|
||||
|
|
@ -60,6 +60,22 @@ class NVIDIAConfig(RemoteInferenceProviderConfig):
|
|||
description="When set to false, the API version will not be appended to the base_url. By default, it is true.",
|
||||
)
|
||||
|
||||
@field_validator("api_key", mode="before")
|
||||
@classmethod
|
||||
def _default_api_key_from_env(cls, value: SecretStr | str | None) -> SecretStr | None:
|
||||
"""Populate the API key from the NVIDIA_API_KEY environment variable when absent."""
|
||||
if value is None:
|
||||
env_value = os.getenv("NVIDIA_API_KEY")
|
||||
return SecretStr(env_value) if env_value else None
|
||||
|
||||
if isinstance(value, SecretStr):
|
||||
return value
|
||||
|
||||
if isinstance(value, str):
|
||||
return SecretStr(value)
|
||||
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
cls,
|
||||
|
|
|
|||
|
|
@ -23,7 +23,6 @@ class RunpodInferenceAdapter(OpenAIMixin):
|
|||
"""
|
||||
|
||||
config: RunpodImplConfig
|
||||
|
||||
provider_data_api_key_field: str = "runpod_api_token"
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue