mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
Redact sensitive information from configs when printing, etc.
This commit is contained in:
parent
d9f75cc98f
commit
e3f187fb83
13 changed files with 54 additions and 21 deletions
|
@ -39,6 +39,7 @@ from llama_stack.distribution.server.endpoints import get_all_api_endpoints
|
|||
from llama_stack.distribution.stack import (
|
||||
construct_stack,
|
||||
get_stack_run_config_from_template,
|
||||
redact_sensitive_fields,
|
||||
replace_env_vars,
|
||||
)
|
||||
|
||||
|
@ -273,7 +274,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
|
||||
console = Console()
|
||||
console.print(f"Using config [blue]{self.config_path_or_template_name}[/blue]:")
|
||||
console.print(yaml.dump(self.config.model_dump(), indent=2))
|
||||
|
||||
# Redact sensitive information before printing
|
||||
safe_config = redact_sensitive_fields(self.config.model_dump())
|
||||
console.print(yaml.dump(safe_config, indent=2))
|
||||
|
||||
endpoints = get_all_api_endpoints()
|
||||
endpoint_impls = {}
|
||||
|
|
|
@ -35,6 +35,7 @@ from llama_stack.distribution.request_headers import set_request_provider_data
|
|||
from llama_stack.distribution.resolver import InvalidProviderError
|
||||
from llama_stack.distribution.stack import (
|
||||
construct_stack,
|
||||
redact_sensitive_fields,
|
||||
replace_env_vars,
|
||||
validate_env_pair,
|
||||
)
|
||||
|
@ -280,7 +281,8 @@ def main():
|
|||
config = StackRunConfig(**config)
|
||||
|
||||
print("Run configuration:")
|
||||
print(yaml.dump(config.model_dump(), indent=2))
|
||||
safe_config = redact_sensitive_fields(config.model_dump())
|
||||
print(yaml.dump(safe_config, indent=2))
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
app.add_middleware(TracingMiddleware)
|
||||
|
|
|
@ -112,6 +112,26 @@ class EnvVarError(Exception):
|
|||
)
|
||||
|
||||
|
||||
def redact_sensitive_fields(data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Redact sensitive information from config before printing."""
|
||||
sensitive_patterns = ["api_key", "api_token", "password", "secret"]
|
||||
|
||||
def _redact_dict(d: Dict[str, Any]) -> Dict[str, Any]:
|
||||
result = {}
|
||||
for k, v in d.items():
|
||||
if isinstance(v, dict):
|
||||
result[k] = _redact_dict(v)
|
||||
elif isinstance(v, list):
|
||||
result[k] = [_redact_dict(i) if isinstance(i, dict) else i for i in v]
|
||||
elif any(pattern in k.lower() for pattern in sensitive_patterns):
|
||||
result[k] = "********"
|
||||
else:
|
||||
result[k] = v
|
||||
return result
|
||||
|
||||
return _redact_dict(data)
|
||||
|
||||
|
||||
def replace_env_vars(config: Any, path: str = "") -> Any:
|
||||
if isinstance(config, dict):
|
||||
result = {}
|
||||
|
|
|
@ -71,7 +71,8 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
self.formatter = ChatFormat(Tokenizer.get_instance())
|
||||
|
||||
self.client = AsyncCerebras(
|
||||
base_url=self.config.base_url, api_key=self.config.api_key
|
||||
base_url=self.config.base_url,
|
||||
api_key=self.config.api_key.get_secret_value(),
|
||||
)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
|
|
|
@ -8,7 +8,7 @@ import os
|
|||
from typing import Any, Dict, Optional
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
DEFAULT_BASE_URL = "https://api.cerebras.ai"
|
||||
|
||||
|
@ -19,7 +19,7 @@ class CerebrasImplConfig(BaseModel):
|
|||
default=os.environ.get("CEREBRAS_BASE_URL", DEFAULT_BASE_URL),
|
||||
description="Base URL for the Cerebras API",
|
||||
)
|
||||
api_key: Optional[str] = Field(
|
||||
api_key: Optional[SecretStr] = Field(
|
||||
default=os.environ.get("CEREBRAS_API_KEY"),
|
||||
description="Cerebras API Key",
|
||||
)
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
from typing import Any, Dict, Optional
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -16,7 +16,7 @@ class FireworksImplConfig(BaseModel):
|
|||
default="https://api.fireworks.ai/inference/v1",
|
||||
description="The URL for the Fireworks server",
|
||||
)
|
||||
api_key: Optional[str] = Field(
|
||||
api_key: Optional[SecretStr] = Field(
|
||||
default=None,
|
||||
description="The Fireworks.ai API Key",
|
||||
)
|
||||
|
|
|
@ -113,7 +113,7 @@ class FireworksInferenceAdapter(
|
|||
|
||||
def _get_api_key(self) -> str:
|
||||
if self.config.api_key is not None:
|
||||
return self.config.api_key
|
||||
return self.config.api_key.get_secret_value()
|
||||
else:
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data is None or not provider_data.fireworks_api_key:
|
||||
|
|
|
@ -8,7 +8,7 @@ import os
|
|||
from typing import Optional
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -40,7 +40,7 @@ class NVIDIAConfig(BaseModel):
|
|||
),
|
||||
description="A base url for accessing the NVIDIA NIM",
|
||||
)
|
||||
api_key: Optional[str] = Field(
|
||||
api_key: Optional[SecretStr] = Field(
|
||||
default_factory=lambda: os.getenv("NVIDIA_API_KEY"),
|
||||
description="The NVIDIA API key, only needed of using the hosted service",
|
||||
)
|
||||
|
|
|
@ -113,7 +113,11 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
# make sure the client lives longer than any async calls
|
||||
self._client = AsyncOpenAI(
|
||||
base_url=f"{self._config.url}/v1",
|
||||
api_key=self._config.api_key or "NO KEY",
|
||||
api_key=(
|
||||
self._config.api_key.get_secret_value()
|
||||
if self._config.api_key
|
||||
else "NO KEY"
|
||||
),
|
||||
timeout=self._config.timeout,
|
||||
)
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
from typing import Optional
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -15,7 +15,7 @@ class TGIImplConfig(BaseModel):
|
|||
url: str = Field(
|
||||
description="The URL for the TGI serving endpoint",
|
||||
)
|
||||
api_token: Optional[str] = Field(
|
||||
api_token: Optional[SecretStr] = Field(
|
||||
default=None,
|
||||
description="A bearer token if your TGI endpoint is protected.",
|
||||
)
|
||||
|
@ -32,7 +32,7 @@ class InferenceEndpointImplConfig(BaseModel):
|
|||
endpoint_name: str = Field(
|
||||
description="The name of the Hugging Face Inference Endpoint in the format of '{namespace}/{endpoint_name}' (e.g. 'my-cool-org/meta-llama-3-1-8b-instruct-rce'). Namespace is optional and will default to the user account if not provided.",
|
||||
)
|
||||
api_token: Optional[str] = Field(
|
||||
api_token: Optional[SecretStr] = Field(
|
||||
default=None,
|
||||
description="Your Hugging Face user access token (will default to locally saved token if not provided)",
|
||||
)
|
||||
|
@ -55,7 +55,7 @@ class InferenceAPIImplConfig(BaseModel):
|
|||
huggingface_repo: str = Field(
|
||||
description="The model ID of the model on the Hugging Face Hub (e.g. 'meta-llama/Meta-Llama-3.1-70B-Instruct')",
|
||||
)
|
||||
api_token: Optional[str] = Field(
|
||||
api_token: Optional[SecretStr] = Field(
|
||||
default=None,
|
||||
description="Your Hugging Face user access token (will default to locally saved token if not provided)",
|
||||
)
|
||||
|
|
|
@ -290,7 +290,9 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
|
|||
class TGIAdapter(_HfAdapter):
|
||||
async def initialize(self, config: TGIImplConfig) -> None:
|
||||
log.info(f"Initializing TGI client with url={config.url}")
|
||||
self.client = AsyncInferenceClient(model=config.url, token=config.api_token)
|
||||
self.client = AsyncInferenceClient(
|
||||
model=config.url, token=config.api_token.get_secret_value()
|
||||
)
|
||||
endpoint_info = await self.client.get_endpoint_info()
|
||||
self.max_tokens = endpoint_info["max_total_tokens"]
|
||||
self.model_id = endpoint_info["model_id"]
|
||||
|
@ -299,7 +301,7 @@ class TGIAdapter(_HfAdapter):
|
|||
class InferenceAPIAdapter(_HfAdapter):
|
||||
async def initialize(self, config: InferenceAPIImplConfig) -> None:
|
||||
self.client = AsyncInferenceClient(
|
||||
model=config.huggingface_repo, token=config.api_token
|
||||
model=config.huggingface_repo, token=config.api_token.get_secret_value()
|
||||
)
|
||||
endpoint_info = await self.client.get_endpoint_info()
|
||||
self.max_tokens = endpoint_info["max_total_tokens"]
|
||||
|
@ -309,7 +311,7 @@ class InferenceAPIAdapter(_HfAdapter):
|
|||
class InferenceEndpointAdapter(_HfAdapter):
|
||||
async def initialize(self, config: InferenceEndpointImplConfig) -> None:
|
||||
# Get the inference endpoint details
|
||||
api = HfApi(token=config.api_token)
|
||||
api = HfApi(token=config.api_token.get_secret_value())
|
||||
endpoint = api.get_inference_endpoint(config.endpoint_name)
|
||||
|
||||
# Wait for the endpoint to be ready (if not already)
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
from typing import Any, Dict, Optional
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -16,7 +16,7 @@ class TogetherImplConfig(BaseModel):
|
|||
default="https://api.together.xyz/v1",
|
||||
description="The URL for the Together AI server",
|
||||
)
|
||||
api_key: Optional[str] = Field(
|
||||
api_key: Optional[SecretStr] = Field(
|
||||
default=None,
|
||||
description="The Together AI API Key",
|
||||
)
|
||||
|
|
|
@ -130,7 +130,7 @@ class TogetherInferenceAdapter(
|
|||
def _get_client(self) -> Together:
|
||||
together_api_key = None
|
||||
if self.config.api_key is not None:
|
||||
together_api_key = self.config.api_key
|
||||
together_api_key = self.config.api_key.get_secret_value()
|
||||
else:
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data is None or not provider_data.together_api_key:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue