Redact sensitive information from configs when printing, etc.

Commit: e3f187fb83
Parent: d9f75cc98f
13 changed files with 54 additions and 21 deletions

llama_stack/distribution/library_client.py

```diff
@@ -39,6 +39,7 @@ from llama_stack.distribution.server.endpoints import get_all_api_endpoints
 from llama_stack.distribution.stack import (
     construct_stack,
     get_stack_run_config_from_template,
+    redact_sensitive_fields,
     replace_env_vars,
 )
@@ -273,7 +274,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
 
         console = Console()
         console.print(f"Using config [blue]{self.config_path_or_template_name}[/blue]:")
-        console.print(yaml.dump(self.config.model_dump(), indent=2))
+
+        # Redact sensitive information before printing
+        safe_config = redact_sensitive_fields(self.config.model_dump())
+        console.print(yaml.dump(safe_config, indent=2))
 
         endpoints = get_all_api_endpoints()
         endpoint_impls = {}
```

llama_stack/distribution/server/server.py

```diff
@@ -35,6 +35,7 @@ from llama_stack.distribution.request_headers import set_request_provider_data
 from llama_stack.distribution.resolver import InvalidProviderError
 from llama_stack.distribution.stack import (
     construct_stack,
+    redact_sensitive_fields,
     replace_env_vars,
     validate_env_pair,
 )
@@ -280,7 +281,8 @@ def main():
     config = StackRunConfig(**config)
 
     print("Run configuration:")
-    print(yaml.dump(config.model_dump(), indent=2))
+    safe_config = redact_sensitive_fields(config.model_dump())
+    print(yaml.dump(safe_config, indent=2))
 
     app = FastAPI(lifespan=lifespan)
     app.add_middleware(TracingMiddleware)
```

llama_stack/distribution/stack.py

```diff
@@ -112,6 +112,26 @@ class EnvVarError(Exception):
         )
 
 
+def redact_sensitive_fields(data: Dict[str, Any]) -> Dict[str, Any]:
+    """Redact sensitive information from config before printing."""
+    sensitive_patterns = ["api_key", "api_token", "password", "secret"]
+
+    def _redact_dict(d: Dict[str, Any]) -> Dict[str, Any]:
+        result = {}
+        for k, v in d.items():
+            if isinstance(v, dict):
+                result[k] = _redact_dict(v)
+            elif isinstance(v, list):
+                result[k] = [_redact_dict(i) if isinstance(i, dict) else i for i in v]
+            elif any(pattern in k.lower() for pattern in sensitive_patterns):
+                result[k] = "********"
+            else:
+                result[k] = v
+        return result
+
+    return _redact_dict(data)
+
+
 def replace_env_vars(config: Any, path: str = "") -> Any:
     if isinstance(config, dict):
         result = {}
```

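The helper walks nested dicts and lists of dicts, masking any key whose lowercased name contains one of the patterns. A minimal sketch of its behavior on an invented config dict (illustrative values, not a real template):

```python
from llama_stack.distribution.stack import redact_sensitive_fields

# Invented config, for illustration only.
config = {
    "providers": {
        "inference": [
            {
                "provider_type": "remote::fireworks",
                "config": {"url": "https://api.fireworks.ai/inference/v1", "api_key": "fw-abc123"},
            }
        ]
    },
    "metadata_store": {"type": "sqlite", "db_path": "registry.db"},
}

safe = redact_sensitive_fields(config)
# Keys matching "api_key", "api_token", "password", or "secret" are masked;
# everything else passes through unchanged.
assert safe["providers"]["inference"][0]["config"]["api_key"] == "********"
assert safe["metadata_store"]["db_path"] == "registry.db"
```
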
llama_stack/providers/remote/inference/cerebras/cerebras.py

```diff
@@ -71,7 +71,8 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
         self.formatter = ChatFormat(Tokenizer.get_instance())
 
         self.client = AsyncCerebras(
-            base_url=self.config.base_url, api_key=self.config.api_key
+            base_url=self.config.base_url,
+            api_key=self.config.api_key.get_secret_value(),
         )
 
     async def initialize(self) -> None:
```

llama_stack/providers/remote/inference/cerebras/config.py

```diff
@@ -8,7 +8,7 @@ import os
 from typing import Any, Dict, Optional
 
 from llama_models.schema_utils import json_schema_type
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, SecretStr
 
 DEFAULT_BASE_URL = "https://api.cerebras.ai"
 
@@ -19,7 +19,7 @@ class CerebrasImplConfig(BaseModel):
         default=os.environ.get("CEREBRAS_BASE_URL", DEFAULT_BASE_URL),
         description="Base URL for the Cerebras API",
     )
-    api_key: Optional[str] = Field(
+    api_key: Optional[SecretStr] = Field(
         default=os.environ.get("CEREBRAS_API_KEY"),
         description="Cerebras API Key",
     )
```

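The `Optional[str]` to `Optional[SecretStr]` switch, repeated across the provider configs below, leans on pydantic's masking: a `SecretStr` prints as asterisks and only reveals the raw value through `get_secret_value()`, which is why every adapter call site is updated alongside the config type. A generic sketch of that behavior (standalone demo, not repo code):

```python
from pydantic import BaseModel, SecretStr

class DemoConfig(BaseModel):
    api_key: SecretStr

cfg = DemoConfig(api_key="sk-very-secret")
print(cfg)                              # api_key=SecretStr('**********')
print(cfg.api_key)                      # **********
print(cfg.api_key.get_secret_value())   # sk-very-secret
```
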
llama_stack/providers/remote/inference/fireworks/config.py

```diff
@@ -7,7 +7,7 @@
 from typing import Any, Dict, Optional
 
 from llama_models.schema_utils import json_schema_type
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, SecretStr
 
 
 @json_schema_type
@@ -16,7 +16,7 @@ class FireworksImplConfig(BaseModel):
         default="https://api.fireworks.ai/inference/v1",
         description="The URL for the Fireworks server",
     )
-    api_key: Optional[str] = Field(
+    api_key: Optional[SecretStr] = Field(
         default=None,
         description="The Fireworks.ai API Key",
     )
```

llama_stack/providers/remote/inference/fireworks/fireworks.py

```diff
@@ -113,7 +113,7 @@ class FireworksInferenceAdapter(
 
     def _get_api_key(self) -> str:
         if self.config.api_key is not None:
-            return self.config.api_key
+            return self.config.api_key.get_secret_value()
         else:
             provider_data = self.get_request_provider_data()
             if provider_data is None or not provider_data.fireworks_api_key:
```

llama_stack/providers/remote/inference/nvidia/config.py

```diff
@@ -8,7 +8,7 @@ import os
 from typing import Optional
 
 from llama_models.schema_utils import json_schema_type
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, SecretStr
 
 
 @json_schema_type
@@ -40,7 +40,7 @@ class NVIDIAConfig(BaseModel):
         ),
         description="A base url for accessing the NVIDIA NIM",
     )
-    api_key: Optional[str] = Field(
+    api_key: Optional[SecretStr] = Field(
         default_factory=lambda: os.getenv("NVIDIA_API_KEY"),
         description="The NVIDIA API key, only needed of using the hosted service",
     )
```

llama_stack/providers/remote/inference/nvidia/nvidia.py

```diff
@@ -113,7 +113,11 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         # make sure the client lives longer than any async calls
         self._client = AsyncOpenAI(
             base_url=f"{self._config.url}/v1",
-            api_key=self._config.api_key or "NO KEY",
+            api_key=(
+                self._config.api_key.get_secret_value()
+                if self._config.api_key
+                else "NO KEY"
+            ),
             timeout=self._config.timeout,
         )
 
```

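The NVIDIA hunk shows the pattern for optional secrets: check for presence before unwrapping, since `get_secret_value()` exists on a `SecretStr` but not on `None`. The same guard could be factored into a small helper; this one is hypothetical, not part of the commit:

```python
from typing import Optional
from pydantic import SecretStr

def unwrap_secret(secret: Optional[SecretStr], default: str = "NO KEY") -> str:
    # Unwrapping None would raise AttributeError, so check first.
    return secret.get_secret_value() if secret else default
```
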
llama_stack/providers/remote/inference/tgi/config.py

```diff
@@ -7,7 +7,7 @@
 from typing import Optional
 
 from llama_models.schema_utils import json_schema_type
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, SecretStr
 
 
 @json_schema_type
@@ -15,7 +15,7 @@ class TGIImplConfig(BaseModel):
     url: str = Field(
         description="The URL for the TGI serving endpoint",
     )
-    api_token: Optional[str] = Field(
+    api_token: Optional[SecretStr] = Field(
         default=None,
         description="A bearer token if your TGI endpoint is protected.",
     )
@@ -32,7 +32,7 @@ class InferenceEndpointImplConfig(BaseModel):
     endpoint_name: str = Field(
         description="The name of the Hugging Face Inference Endpoint in the format of '{namespace}/{endpoint_name}' (e.g. 'my-cool-org/meta-llama-3-1-8b-instruct-rce'). Namespace is optional and will default to the user account if not provided.",
     )
-    api_token: Optional[str] = Field(
+    api_token: Optional[SecretStr] = Field(
         default=None,
         description="Your Hugging Face user access token (will default to locally saved token if not provided)",
     )
@@ -55,7 +55,7 @@ class InferenceAPIImplConfig(BaseModel):
     huggingface_repo: str = Field(
         description="The model ID of the model on the Hugging Face Hub (e.g. 'meta-llama/Meta-Llama-3.1-70B-Instruct')",
     )
-    api_token: Optional[str] = Field(
+    api_token: Optional[SecretStr] = Field(
         default=None,
         description="Your Hugging Face user access token (will default to locally saved token if not provided)",
    )
```

llama_stack/providers/remote/inference/tgi/tgi.py

```diff
@@ -290,7 +290,9 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
 class TGIAdapter(_HfAdapter):
     async def initialize(self, config: TGIImplConfig) -> None:
         log.info(f"Initializing TGI client with url={config.url}")
-        self.client = AsyncInferenceClient(model=config.url, token=config.api_token)
+        self.client = AsyncInferenceClient(
+            model=config.url, token=config.api_token.get_secret_value()
+        )
         endpoint_info = await self.client.get_endpoint_info()
         self.max_tokens = endpoint_info["max_total_tokens"]
         self.model_id = endpoint_info["model_id"]
@@ -299,7 +301,7 @@ class TGIAdapter(_HfAdapter):
 class InferenceAPIAdapter(_HfAdapter):
     async def initialize(self, config: InferenceAPIImplConfig) -> None:
         self.client = AsyncInferenceClient(
-            model=config.huggingface_repo, token=config.api_token
+            model=config.huggingface_repo, token=config.api_token.get_secret_value()
         )
         endpoint_info = await self.client.get_endpoint_info()
         self.max_tokens = endpoint_info["max_total_tokens"]
@@ -309,7 +311,7 @@ class InferenceAPIAdapter(_HfAdapter):
 class InferenceEndpointAdapter(_HfAdapter):
     async def initialize(self, config: InferenceEndpointImplConfig) -> None:
         # Get the inference endpoint details
-        api = HfApi(token=config.api_token)
+        api = HfApi(token=config.api_token.get_secret_value())
         endpoint = api.get_inference_endpoint(config.endpoint_name)
 
         # Wait for the endpoint to be ready (if not already)
```

llama_stack/providers/remote/inference/together/config.py

```diff
@@ -7,7 +7,7 @@
 from typing import Any, Dict, Optional
 
 from llama_models.schema_utils import json_schema_type
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, SecretStr
 
 
 @json_schema_type
@@ -16,7 +16,7 @@ class TogetherImplConfig(BaseModel):
         default="https://api.together.xyz/v1",
         description="The URL for the Together AI server",
     )
-    api_key: Optional[str] = Field(
+    api_key: Optional[SecretStr] = Field(
         default=None,
         description="The Together AI API Key",
     )
```

llama_stack/providers/remote/inference/together/together.py

```diff
@@ -130,7 +130,7 @@ class TogetherInferenceAdapter(
     def _get_client(self) -> Together:
         together_api_key = None
         if self.config.api_key is not None:
-            together_api_key = self.config.api_key
+            together_api_key = self.config.api_key.get_secret_value()
         else:
             provider_data = self.get_request_provider_data()
             if provider_data is None or not provider_data.together_api_key:
```

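Net effect at startup: the server and library-client banners print the redacted dump instead of raw credentials. A rough sketch of the printed output, with an invented key:

```python
import yaml
from llama_stack.distribution.stack import redact_sensitive_fields

safe_config = redact_sensitive_fields(
    {"url": "https://api.together.xyz/v1", "api_key": "together-abc123"}
)
print(yaml.dump(safe_config, indent=2))
# api_key: '********'
# url: https://api.together.xyz/v1
```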