Redact sensitive information from configs when printing, etc.
commit e3f187fb83 (parent d9f75cc98f)
13 changed files with 54 additions and 21 deletions
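Every hunk below follows the same pattern: credential fields in provider configs move from plain str to pydantic's SecretStr, and each use site unwraps the value explicitly with .get_secret_value(). A minimal sketch of the behavior this buys when configs are printed or serialized (pydantic v2 semantics assumed; the MyConfig model and example key are illustrative, not from the diff):

    # A minimal sketch of SecretStr redaction (pydantic v2 assumed).
    # The MyConfig model and the example key are illustrative only.
    from typing import Optional

    from pydantic import BaseModel, SecretStr

    class MyConfig(BaseModel):
        api_key: Optional[SecretStr] = None

    config = MyConfig(api_key="sk-live-abc123")  # validated input is coerced to SecretStr

    # Printing and JSON serialization mask the secret ...
    print(config)                    # api_key=SecretStr('**********')
    print(config.model_dump_json())  # {"api_key":"**********"}

    # ... while call sites that need the raw value unwrap it explicitly.
    assert config.api_key.get_secret_value() == "sk-live-abc123"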
@@ -71,7 +71,8 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
         self.formatter = ChatFormat(Tokenizer.get_instance())
 
         self.client = AsyncCerebras(
-            base_url=self.config.base_url, api_key=self.config.api_key
+            base_url=self.config.base_url,
+            api_key=self.config.api_key.get_secret_value(),
         )
 
     async def initialize(self) -> None:
@@ -8,7 +8,7 @@ import os
 from typing import Any, Dict, Optional
 
 from llama_models.schema_utils import json_schema_type
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, SecretStr
 
 DEFAULT_BASE_URL = "https://api.cerebras.ai"
 
@@ -19,7 +19,7 @@ class CerebrasImplConfig(BaseModel):
         default=os.environ.get("CEREBRAS_BASE_URL", DEFAULT_BASE_URL),
         description="Base URL for the Cerebras API",
     )
-    api_key: Optional[str] = Field(
+    api_key: Optional[SecretStr] = Field(
         default=os.environ.get("CEREBRAS_API_KEY"),
         description="Cerebras API Key",
     )
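One subtlety with env-sourced defaults like the one above: pydantic v2 does not validate field defaults unless validate_default=True, so a default pulled straight from os.environ stays a plain str rather than being coerced to SecretStr; only values that pass through normal model validation are wrapped. A small illustration (the Cfg model and EXAMPLE_API_KEY variable are hypothetical):

    # Sketch: defaults bypass validation in pydantic v2 unless
    # validate_default=True. Cfg and EXAMPLE_API_KEY are hypothetical.
    import os
    from typing import Optional

    from pydantic import BaseModel, Field, SecretStr

    os.environ["EXAMPLE_API_KEY"] = "plaintext-key"

    class Cfg(BaseModel):
        api_key: Optional[SecretStr] = Field(
            default=os.environ.get("EXAMPLE_API_KEY"),
        )

    print(type(Cfg(api_key="abc").api_key))  # SecretStr: validated input is coerced
    print(type(Cfg().api_key))               # str: the default was never validated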
@@ -7,7 +7,7 @@
 from typing import Any, Dict, Optional
 
 from llama_models.schema_utils import json_schema_type
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, SecretStr
 
 
 @json_schema_type
@@ -16,7 +16,7 @@ class FireworksImplConfig(BaseModel):
         default="https://api.fireworks.ai/inference/v1",
         description="The URL for the Fireworks server",
     )
-    api_key: Optional[str] = Field(
+    api_key: Optional[SecretStr] = Field(
         default=None,
         description="The Fireworks.ai API Key",
     )
@@ -113,7 +113,7 @@ class FireworksInferenceAdapter(
 
     def _get_api_key(self) -> str:
         if self.config.api_key is not None:
-            return self.config.api_key
+            return self.config.api_key.get_secret_value()
         else:
             provider_data = self.get_request_provider_data()
             if provider_data is None or not provider_data.fireworks_api_key:
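The Fireworks adapter keeps its two-tier key lookup: a SecretStr from the config wins, otherwise the adapter falls back to a per-request key from provider data. A hedged sketch of that resolution order (resolve_api_key and its parameters are hypothetical names, not the adapter's API):

    # Hypothetical helper mirroring _get_api_key's resolution order.
    from typing import Optional

    from pydantic import SecretStr

    def resolve_api_key(
        config_key: Optional[SecretStr], request_key: Optional[str]
    ) -> str:
        if config_key is not None:
            # Config-held secrets must be unwrapped before use.
            return config_key.get_secret_value()
        if request_key:
            # Per-request keys arrive as plain text from provider data.
            return request_key
        raise ValueError("no API key in config or request provider data")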
@@ -8,7 +8,7 @@ import os
 from typing import Optional
 
 from llama_models.schema_utils import json_schema_type
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, SecretStr
 
 
 @json_schema_type
@@ -40,7 +40,7 @@ class NVIDIAConfig(BaseModel):
         ),
         description="A base url for accessing the NVIDIA NIM",
     )
-    api_key: Optional[str] = Field(
+    api_key: Optional[SecretStr] = Field(
         default_factory=lambda: os.getenv("NVIDIA_API_KEY"),
         description="The NVIDIA API key, only needed of using the hosted service",
     )
@@ -113,7 +113,11 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         # make sure the client lives longer than any async calls
         self._client = AsyncOpenAI(
             base_url=f"{self._config.url}/v1",
-            api_key=self._config.api_key or "NO KEY",
+            api_key=(
+                self._config.api_key.get_secret_value()
+                if self._config.api_key
+                else "NO KEY"
+            ),
             timeout=self._config.timeout,
         )
 
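Because api_key is Optional[SecretStr], the NVIDIA adapter unwraps it only when present and keeps the "NO KEY" placeholder for locally hosted NIMs that need no credential. The same guard can be factored into a one-line helper (unwrap_or is an illustrative name, not part of the commit):

    # Illustrative helper equivalent to the inline conditional above.
    from typing import Optional

    from pydantic import SecretStr

    def unwrap_or(secret: Optional[SecretStr], fallback: str) -> str:
        return secret.get_secret_value() if secret else fallback

    # e.g. api_key=unwrap_or(self._config.api_key, "NO KEY")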
@@ -7,7 +7,7 @@
 from typing import Optional
 
 from llama_models.schema_utils import json_schema_type
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, SecretStr
 
 
 @json_schema_type
@@ -15,7 +15,7 @@ class TGIImplConfig(BaseModel):
     url: str = Field(
         description="The URL for the TGI serving endpoint",
     )
-    api_token: Optional[str] = Field(
+    api_token: Optional[SecretStr] = Field(
         default=None,
         description="A bearer token if your TGI endpoint is protected.",
     )
@@ -32,7 +32,7 @@ class InferenceEndpointImplConfig(BaseModel):
     endpoint_name: str = Field(
         description="The name of the Hugging Face Inference Endpoint in the format of '{namespace}/{endpoint_name}' (e.g. 'my-cool-org/meta-llama-3-1-8b-instruct-rce'). Namespace is optional and will default to the user account if not provided.",
     )
-    api_token: Optional[str] = Field(
+    api_token: Optional[SecretStr] = Field(
         default=None,
         description="Your Hugging Face user access token (will default to locally saved token if not provided)",
     )
@@ -55,7 +55,7 @@ class InferenceAPIImplConfig(BaseModel):
     huggingface_repo: str = Field(
         description="The model ID of the model on the Hugging Face Hub (e.g. 'meta-llama/Meta-Llama-3.1-70B-Instruct')",
     )
-    api_token: Optional[str] = Field(
+    api_token: Optional[SecretStr] = Field(
         default=None,
         description="Your Hugging Face user access token (will default to locally saved token if not provided)",
     )
@@ -290,7 +290,9 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
 class TGIAdapter(_HfAdapter):
     async def initialize(self, config: TGIImplConfig) -> None:
         log.info(f"Initializing TGI client with url={config.url}")
-        self.client = AsyncInferenceClient(model=config.url, token=config.api_token)
+        self.client = AsyncInferenceClient(
+            model=config.url, token=config.api_token.get_secret_value()
+        )
         endpoint_info = await self.client.get_endpoint_info()
         self.max_tokens = endpoint_info["max_total_tokens"]
         self.model_id = endpoint_info["model_id"]
@@ -299,7 +301,7 @@ class TGIAdapter(_HfAdapter):
 class InferenceAPIAdapter(_HfAdapter):
     async def initialize(self, config: InferenceAPIImplConfig) -> None:
         self.client = AsyncInferenceClient(
-            model=config.huggingface_repo, token=config.api_token
+            model=config.huggingface_repo, token=config.api_token.get_secret_value()
         )
         endpoint_info = await self.client.get_endpoint_info()
         self.max_tokens = endpoint_info["max_total_tokens"]
@@ -309,7 +311,7 @@ class InferenceAPIAdapter(_HfAdapter):
 class InferenceEndpointAdapter(_HfAdapter):
     async def initialize(self, config: InferenceEndpointImplConfig) -> None:
         # Get the inference endpoint details
-        api = HfApi(token=config.api_token)
+        api = HfApi(token=config.api_token.get_secret_value())
         endpoint = api.get_inference_endpoint(config.endpoint_name)
 
         # Wait for the endpoint to be ready (if not already)
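Note that unlike the NVIDIA hunk, the three Hugging Face adapters above call .get_secret_value() unconditionally even though api_token is declared Optional[SecretStr] with default=None, so an unset token would raise AttributeError. If a guard were wanted, it could look like this (a sketch; optional_token is a hypothetical helper, not the commit's code):

    # Sketch: unwrap the token only when one is configured, so an
    # unprotected endpoint keeps working with token=None.
    from typing import Optional

    from pydantic import SecretStr

    def optional_token(token: Optional[SecretStr]) -> Optional[str]:
        return token.get_secret_value() if token is not None else None

    # e.g. AsyncInferenceClient(model=config.url,
    #                           token=optional_token(config.api_token))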
@@ -7,7 +7,7 @@
 from typing import Any, Dict, Optional
 
 from llama_models.schema_utils import json_schema_type
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, SecretStr
 
 
 @json_schema_type
@@ -16,7 +16,7 @@ class TogetherImplConfig(BaseModel):
         default="https://api.together.xyz/v1",
         description="The URL for the Together AI server",
     )
-    api_key: Optional[str] = Field(
+    api_key: Optional[SecretStr] = Field(
         default=None,
         description="The Together AI API Key",
     )
@@ -130,7 +130,7 @@ class TogetherInferenceAdapter(
     def _get_client(self) -> Together:
        together_api_key = None
        if self.config.api_key is not None:
-            together_api_key = self.config.api_key
+            together_api_key = self.config.api_key.get_secret_value()
        else:
            provider_data = self.get_request_provider_data()
            if provider_data is None or not provider_data.together_api_key: