Redact sensitive information from configs when printing, etc.

Ashwin Bharambe 2025-01-02 11:40:48 -08:00
parent d9f75cc98f
commit e3f187fb83
13 changed files with 54 additions and 21 deletions

View file

@@ -39,6 +39,7 @@ from llama_stack.distribution.server.endpoints import get_all_api_endpoints
 from llama_stack.distribution.stack import (
     construct_stack,
     get_stack_run_config_from_template,
+    redact_sensitive_fields,
     replace_env_vars,
 )
@@ -273,7 +274,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         console = Console()
         console.print(f"Using config [blue]{self.config_path_or_template_name}[/blue]:")
-        console.print(yaml.dump(self.config.model_dump(), indent=2))
+        # Redact sensitive information before printing
+        safe_config = redact_sensitive_fields(self.config.model_dump())
+        console.print(yaml.dump(safe_config, indent=2))
 
         endpoints = get_all_api_endpoints()
         endpoint_impls = {}

View file

@@ -35,6 +35,7 @@ from llama_stack.distribution.request_headers import set_request_provider_data
 from llama_stack.distribution.resolver import InvalidProviderError
 from llama_stack.distribution.stack import (
     construct_stack,
+    redact_sensitive_fields,
     replace_env_vars,
     validate_env_pair,
 )
@@ -280,7 +281,8 @@ def main():
     config = StackRunConfig(**config)
 
     print("Run configuration:")
-    print(yaml.dump(config.model_dump(), indent=2))
+    safe_config = redact_sensitive_fields(config.model_dump())
+    print(yaml.dump(safe_config, indent=2))
 
     app = FastAPI(lifespan=lifespan)
     app.add_middleware(TracingMiddleware)

View file

@@ -112,6 +112,26 @@ class EnvVarError(Exception):
         )
 
 
+def redact_sensitive_fields(data: Dict[str, Any]) -> Dict[str, Any]:
+    """Redact sensitive information from config before printing."""
+    sensitive_patterns = ["api_key", "api_token", "password", "secret"]
+
+    def _redact_dict(d: Dict[str, Any]) -> Dict[str, Any]:
+        result = {}
+        for k, v in d.items():
+            if isinstance(v, dict):
+                result[k] = _redact_dict(v)
+            elif isinstance(v, list):
+                result[k] = [_redact_dict(i) if isinstance(i, dict) else i for i in v]
+            elif any(pattern in k.lower() for pattern in sensitive_patterns):
+                result[k] = "********"
+            else:
+                result[k] = v
+        return result
+
+    return _redact_dict(data)
+
+
 def replace_env_vars(config: Any, path: str = "") -> Any:
     if isinstance(config, dict):
         result = {}
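
The helper can be exercised directly; a quick sketch of its behavior, where the sample config and values are invented for illustration:

    import yaml

    from llama_stack.distribution.stack import redact_sensitive_fields

    # Hypothetical nested run config; keys are matched by name, not by value.
    config = {
        "providers": {
            "inference": [
                {"provider_type": "remote::fireworks", "config": {"api_key": "fw-123"}},
            ],
        },
        "metadata_store": {"password": "hunter2", "host": "localhost"},
    }

    # api_key and password become "********"; host is left untouched.
    print(yaml.dump(redact_sensitive_fields(config), indent=2))

Matching is substring-based on lowercased key names ("api_key", "api_token", "password", "secret"), so a field like OPENAI_API_KEY is caught as well, while a sensitive value stored under an unrelated key name would slip through.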

View file

@@ -71,7 +71,8 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
         self.formatter = ChatFormat(Tokenizer.get_instance())
 
         self.client = AsyncCerebras(
-            base_url=self.config.base_url, api_key=self.config.api_key
+            base_url=self.config.base_url,
+            api_key=self.config.api_key.get_secret_value(),
         )
 
     async def initialize(self) -> None:

View file

@@ -8,7 +8,7 @@ import os
 from typing import Any, Dict, Optional
 
 from llama_models.schema_utils import json_schema_type
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, SecretStr
 
 DEFAULT_BASE_URL = "https://api.cerebras.ai"
 
@@ -19,7 +19,7 @@ class CerebrasImplConfig(BaseModel):
         default=os.environ.get("CEREBRAS_BASE_URL", DEFAULT_BASE_URL),
         description="Base URL for the Cerebras API",
     )
-    api_key: Optional[str] = Field(
+    api_key: Optional[SecretStr] = Field(
         default=os.environ.get("CEREBRAS_API_KEY"),
         description="Cerebras API Key",
     )
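
Beyond redaction at print time, the switch from Optional[str] to Optional[SecretStr] adds a second layer of protection: pydantic masks the value whenever the field is stringified, and callers must opt in via get_secret_value() to read it. A minimal sketch of that behavior (the key value below is made up):

    from pydantic import SecretStr

    key = SecretStr("csk-not-a-real-key")
    print(key)                     # **********  (masked in str()/repr())
    print(key.get_secret_value())  # csk-not-a-real-key  (explicit opt-in)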

View file

@@ -7,7 +7,7 @@
 from typing import Any, Dict, Optional
 
 from llama_models.schema_utils import json_schema_type
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, SecretStr
 
 
 @json_schema_type
@@ -16,7 +16,7 @@ class FireworksImplConfig(BaseModel):
         default="https://api.fireworks.ai/inference/v1",
         description="The URL for the Fireworks server",
     )
-    api_key: Optional[str] = Field(
+    api_key: Optional[SecretStr] = Field(
         default=None,
         description="The Fireworks.ai API Key",
     )

View file

@@ -113,7 +113,7 @@ class FireworksInferenceAdapter(
     def _get_api_key(self) -> str:
         if self.config.api_key is not None:
-            return self.config.api_key
+            return self.config.api_key.get_secret_value()
         else:
             provider_data = self.get_request_provider_data()
             if provider_data is None or not provider_data.fireworks_api_key:

View file

@@ -8,7 +8,7 @@ import os
 from typing import Optional
 
 from llama_models.schema_utils import json_schema_type
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, SecretStr
 
 
 @json_schema_type
@@ -40,7 +40,7 @@ class NVIDIAConfig(BaseModel):
         ),
         description="A base url for accessing the NVIDIA NIM",
     )
-    api_key: Optional[str] = Field(
+    api_key: Optional[SecretStr] = Field(
         default_factory=lambda: os.getenv("NVIDIA_API_KEY"),
         description="The NVIDIA API key, only needed if using the hosted service",
     )

View file

@@ -113,7 +113,11 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         # make sure the client lives longer than any async calls
         self._client = AsyncOpenAI(
             base_url=f"{self._config.url}/v1",
-            api_key=self._config.api_key or "NO KEY",
+            api_key=(
+                self._config.api_key.get_secret_value()
+                if self._config.api_key
+                else "NO KEY"
+            ),
             timeout=self._config.timeout,
         )
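
Note why the old 'api_key or "NO KEY"' shortcut had to become an explicit conditional: the field is now a SecretStr, which must be unwrapped with get_secret_value() before being handed to AsyncOpenAI, so truthiness alone can no longer produce the string fallback.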

View file

@@ -7,7 +7,7 @@
 from typing import Optional
 
 from llama_models.schema_utils import json_schema_type
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, SecretStr
 
 
 @json_schema_type
@@ -15,7 +15,7 @@ class TGIImplConfig(BaseModel):
     url: str = Field(
         description="The URL for the TGI serving endpoint",
     )
-    api_token: Optional[str] = Field(
+    api_token: Optional[SecretStr] = Field(
         default=None,
         description="A bearer token if your TGI endpoint is protected.",
     )
@@ -32,7 +32,7 @@ class InferenceEndpointImplConfig(BaseModel):
     endpoint_name: str = Field(
         description="The name of the Hugging Face Inference Endpoint in the format of '{namespace}/{endpoint_name}' (e.g. 'my-cool-org/meta-llama-3-1-8b-instruct-rce'). Namespace is optional and will default to the user account if not provided.",
     )
-    api_token: Optional[str] = Field(
+    api_token: Optional[SecretStr] = Field(
         default=None,
         description="Your Hugging Face user access token (will default to locally saved token if not provided)",
     )
@@ -55,7 +55,7 @@ class InferenceAPIImplConfig(BaseModel):
     huggingface_repo: str = Field(
         description="The model ID of the model on the Hugging Face Hub (e.g. 'meta-llama/Meta-Llama-3.1-70B-Instruct')",
     )
-    api_token: Optional[str] = Field(
+    api_token: Optional[SecretStr] = Field(
         default=None,
         description="Your Hugging Face user access token (will default to locally saved token if not provided)",
     )

View file

@@ -290,7 +290,9 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
 class TGIAdapter(_HfAdapter):
     async def initialize(self, config: TGIImplConfig) -> None:
         log.info(f"Initializing TGI client with url={config.url}")
-        self.client = AsyncInferenceClient(model=config.url, token=config.api_token)
+        self.client = AsyncInferenceClient(
+            model=config.url, token=config.api_token.get_secret_value()
+        )
 
         endpoint_info = await self.client.get_endpoint_info()
         self.max_tokens = endpoint_info["max_total_tokens"]
         self.model_id = endpoint_info["model_id"]
@@ -299,7 +301,7 @@ class TGIAdapter(_HfAdapter):
 class InferenceAPIAdapter(_HfAdapter):
     async def initialize(self, config: InferenceAPIImplConfig) -> None:
         self.client = AsyncInferenceClient(
-            model=config.huggingface_repo, token=config.api_token
+            model=config.huggingface_repo, token=config.api_token.get_secret_value()
         )
         endpoint_info = await self.client.get_endpoint_info()
         self.max_tokens = endpoint_info["max_total_tokens"]
@@ -309,7 +311,7 @@ class InferenceAPIAdapter(_HfAdapter):
 class InferenceEndpointAdapter(_HfAdapter):
     async def initialize(self, config: InferenceEndpointImplConfig) -> None:
         # Get the inference endpoint details
-        api = HfApi(token=config.api_token)
+        api = HfApi(token=config.api_token.get_secret_value())
         endpoint = api.get_inference_endpoint(config.endpoint_name)
 
         # Wait for the endpoint to be ready (if not already)
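
One caveat worth noting: api_token is still Optional[SecretStr] with default=None, so the unconditional config.api_token.get_secret_value() calls above will raise AttributeError when no token is configured (e.g. an unprotected TGI endpoint). A None-safe variant might look like this sketch, which is not part of the commit:

    token = config.api_token.get_secret_value() if config.api_token else None
    self.client = AsyncInferenceClient(model=config.url, token=token)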

View file

@@ -7,7 +7,7 @@
 from typing import Any, Dict, Optional
 
 from llama_models.schema_utils import json_schema_type
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, SecretStr
 
 
 @json_schema_type
@@ -16,7 +16,7 @@ class TogetherImplConfig(BaseModel):
         default="https://api.together.xyz/v1",
         description="The URL for the Together AI server",
     )
-    api_key: Optional[str] = Field(
+    api_key: Optional[SecretStr] = Field(
         default=None,
         description="The Together AI API Key",
     )

View file

@@ -130,7 +130,7 @@ class TogetherInferenceAdapter(
     def _get_client(self) -> Together:
         together_api_key = None
         if self.config.api_key is not None:
-            together_api_key = self.config.api_key
+            together_api_key = self.config.api_key.get_secret_value()
         else:
             provider_data = self.get_request_provider_data()
             if provider_data is None or not provider_data.together_api_key:
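
The Fireworks and Together adapters now share the same resolution order: prefer the statically configured SecretStr, otherwise fall back to per-request provider data. Factored out, the pattern looks roughly like the following sketch (the helper and its parameter names are hypothetical, not part of the commit):

    from typing import Any, Optional

    from pydantic import SecretStr

    def resolve_api_key(  # hypothetical helper, for illustration only
        config_key: Optional[SecretStr],
        provider_data: Any,
        attr: str,
    ) -> str:
        # 1. Prefer the key from the static run config.
        if config_key is not None:
            return config_key.get_secret_value()
        # 2. Fall back to the per-request provider-data header.
        value = getattr(provider_data, attr, None) if provider_data else None
        if not value:
            raise ValueError(f"Pass {attr} in provider data or configure it statically")
        return value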