mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-08 21:04:39 +00:00
Merge 2a34226727
into ea15f2a270
This commit is contained in:
commit
79ced0c85b
94 changed files with 341 additions and 209 deletions
|
@ -216,7 +216,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
|||
with open(args.config) as f:
|
||||
try:
|
||||
contents = yaml.safe_load(f)
|
||||
contents = replace_env_vars(contents)
|
||||
contents = replace_env_vars(contents, provider_registry=get_provider_registry())
|
||||
build_config = BuildConfig(**contents)
|
||||
if args.image_type:
|
||||
build_config.image_type = args.image_type
|
||||
|
|
|
@ -165,7 +165,7 @@ def upgrade_from_routing_table(
|
|||
def parse_and_maybe_upgrade_config(config_dict: dict[str, Any]) -> StackRunConfig:
|
||||
version = config_dict.get("version", None)
|
||||
if version == LLAMA_STACK_RUN_CONFIG_VERSION:
|
||||
processed_config_dict = replace_env_vars(config_dict)
|
||||
processed_config_dict = replace_env_vars(config_dict, provider_registry=get_provider_registry())
|
||||
return StackRunConfig(**cast_image_name_to_string(processed_config_dict))
|
||||
|
||||
if "routing_table" in config_dict:
|
||||
|
@ -177,5 +177,5 @@ def parse_and_maybe_upgrade_config(config_dict: dict[str, Any]) -> StackRunConfi
|
|||
if not config_dict.get("external_providers_dir", None):
|
||||
config_dict["external_providers_dir"] = EXTERNAL_PROVIDERS_DIR
|
||||
|
||||
processed_config_dict = replace_env_vars(config_dict)
|
||||
processed_config_dict = replace_env_vars(config_dict, provider_registry=get_provider_registry())
|
||||
return StackRunConfig(**cast_image_name_to_string(processed_config_dict))
|
||||
|
|
|
@ -33,6 +33,7 @@ from termcolor import cprint
|
|||
from llama_stack.core.build import print_pip_install_help
|
||||
from llama_stack.core.configure import parse_and_maybe_upgrade_config
|
||||
from llama_stack.core.datatypes import Api, BuildConfig, BuildProvider, DistributionSpec
|
||||
from llama_stack.core.distribution import get_provider_registry
|
||||
from llama_stack.core.request_headers import (
|
||||
PROVIDER_DATA_VAR,
|
||||
request_provider_data_context,
|
||||
|
@ -220,7 +221,9 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
config_path = Path(config_path_or_distro_name)
|
||||
if not config_path.exists():
|
||||
raise ValueError(f"Config file {config_path} does not exist")
|
||||
config_dict = replace_env_vars(yaml.safe_load(config_path.read_text()))
|
||||
config_dict = replace_env_vars(
|
||||
yaml.safe_load(config_path.read_text()), provider_registry=get_provider_registry()
|
||||
)
|
||||
config = parse_and_maybe_upgrade_config(config_dict)
|
||||
else:
|
||||
# distribution
|
||||
|
|
|
@ -43,7 +43,7 @@ from llama_stack.core.datatypes import (
|
|||
StackRunConfig,
|
||||
process_cors_config,
|
||||
)
|
||||
from llama_stack.core.distribution import builtin_automatically_routed_apis
|
||||
from llama_stack.core.distribution import builtin_automatically_routed_apis, get_provider_registry
|
||||
from llama_stack.core.external import load_external_apis
|
||||
from llama_stack.core.request_headers import (
|
||||
PROVIDER_DATA_VAR,
|
||||
|
@ -371,7 +371,7 @@ def create_app(
|
|||
logger.error(f"Error: {str(e)}")
|
||||
raise ValueError(f"Invalid environment variable format: {env_pair}") from e
|
||||
|
||||
config = replace_env_vars(config_contents)
|
||||
config = replace_env_vars(config_contents, provider_registry=get_provider_registry())
|
||||
config = StackRunConfig(**cast_image_name_to_string(config))
|
||||
|
||||
_log_run_config(run_config=config)
|
||||
|
@ -524,7 +524,10 @@ def main(args: argparse.Namespace | None = None):
|
|||
env_vars=args.env,
|
||||
)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
logger.error(f"Error creating app: {str(e)}")
|
||||
logger.error(f"Stack trace:\n{traceback.format_exc()}")
|
||||
sys.exit(1)
|
||||
|
||||
config_file = resolve_config_or_distro(config_or_distro, Mode.RUN)
|
||||
|
@ -534,7 +537,9 @@ def main(args: argparse.Namespace | None = None):
|
|||
logger_config = LoggingConfig(**cfg)
|
||||
else:
|
||||
logger_config = None
|
||||
config = StackRunConfig(**cast_image_name_to_string(replace_env_vars(config_contents)))
|
||||
config = StackRunConfig(
|
||||
**cast_image_name_to_string(replace_env_vars(config_contents, provider_registry=get_provider_registry()))
|
||||
)
|
||||
|
||||
import uvicorn
|
||||
|
||||
|
|
|
@ -141,12 +141,19 @@ class EnvVarError(Exception):
|
|||
)
|
||||
|
||||
|
||||
def replace_env_vars(config: Any, path: str = "") -> Any:
|
||||
def replace_env_vars(
|
||||
config: Any,
|
||||
path: str = "",
|
||||
provider_registry: dict[Api, dict[str, Any]] | None = None,
|
||||
current_provider_context: dict[str, Any] | None = None,
|
||||
) -> Any:
|
||||
if isinstance(config, dict):
|
||||
result = {}
|
||||
for k, v in config.items():
|
||||
try:
|
||||
result[k] = replace_env_vars(v, f"{path}.{k}" if path else k)
|
||||
result[k] = replace_env_vars(
|
||||
v, f"{path}.{k}" if path else k, provider_registry, current_provider_context
|
||||
)
|
||||
except EnvVarError as e:
|
||||
raise EnvVarError(e.var_name, e.path) from None
|
||||
return result
|
||||
|
@ -159,7 +166,9 @@ def replace_env_vars(config: Any, path: str = "") -> Any:
|
|||
# is disabled so that we can skip config env variable expansion and avoid validation errors
|
||||
if isinstance(v, dict) and "provider_id" in v:
|
||||
try:
|
||||
resolved_provider_id = replace_env_vars(v["provider_id"], f"{path}[{i}].provider_id")
|
||||
resolved_provider_id = replace_env_vars(
|
||||
v["provider_id"], f"{path}[{i}].provider_id", provider_registry, current_provider_context
|
||||
)
|
||||
if resolved_provider_id == "__disabled__":
|
||||
logger.debug(
|
||||
f"Skipping config env variable expansion for disabled provider: {v.get('provider_id', '')}"
|
||||
|
@ -167,13 +176,19 @@ def replace_env_vars(config: Any, path: str = "") -> Any:
|
|||
# Create a copy with resolved provider_id but original config
|
||||
disabled_provider = v.copy()
|
||||
disabled_provider["provider_id"] = resolved_provider_id
|
||||
result.append(disabled_provider)
|
||||
continue
|
||||
except EnvVarError:
|
||||
# If we can't resolve the provider_id, continue with normal processing
|
||||
pass
|
||||
|
||||
# Set up provider context for config processing
|
||||
provider_context = current_provider_context
|
||||
if isinstance(v, dict) and "provider_id" in v and "provider_type" in v and provider_registry:
|
||||
provider_context = _get_provider_context(v, provider_registry)
|
||||
|
||||
# Normal processing for non-disabled providers
|
||||
result.append(replace_env_vars(v, f"{path}[{i}]"))
|
||||
result.append(replace_env_vars(v, f"{path}[{i}]", provider_registry, provider_context))
|
||||
except EnvVarError as e:
|
||||
raise EnvVarError(e.var_name, e.path) from None
|
||||
return result
|
||||
|
@ -228,7 +243,7 @@ def replace_env_vars(config: Any, path: str = "") -> Any:
|
|||
result = re.sub(pattern, get_env_var, config)
|
||||
# Only apply type conversion if substitution actually happened
|
||||
if result != config:
|
||||
return _convert_string_to_proper_type(result)
|
||||
return _convert_string_to_proper_type_with_config(result, path, current_provider_context)
|
||||
return result
|
||||
except EnvVarError as e:
|
||||
raise EnvVarError(e.var_name, e.path) from None
|
||||
|
@ -236,12 +251,113 @@ def replace_env_vars(config: Any, path: str = "") -> Any:
|
|||
return config
|
||||
|
||||
|
||||
def _get_provider_context(
|
||||
provider_dict: dict[str, Any], provider_registry: dict[Api, dict[str, Any]]
|
||||
) -> dict[str, Any] | None:
|
||||
"""Get provider context information including config class for type conversion."""
|
||||
try:
|
||||
provider_type = provider_dict.get("provider_type")
|
||||
if not provider_type:
|
||||
return None
|
||||
|
||||
for api, providers in provider_registry.items():
|
||||
if provider_type in providers:
|
||||
provider_spec = providers[provider_type]
|
||||
|
||||
config_class = instantiate_class_type(provider_spec.config_class)
|
||||
|
||||
return {
|
||||
"api": api,
|
||||
"provider_type": provider_type,
|
||||
"config_class": config_class,
|
||||
"provider_spec": provider_spec,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to get provider context: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def _convert_string_to_proper_type_with_config(value: str, path: str, provider_context: dict[str, Any] | None) -> Any:
|
||||
"""Convert string to proper type using provider config class field information."""
|
||||
if not provider_context or not provider_context.get("config_class"):
|
||||
# best effort conversion if we don't have the config class
|
||||
return _convert_string_to_proper_type(value)
|
||||
|
||||
try:
|
||||
# Extract field name from path (e.g., "providers.inference[0].config.api_key" -> "api_key")
|
||||
field_name = path.split(".")[-1] if "." in path else path
|
||||
|
||||
config_class = provider_context["config_class"]
|
||||
# Only instantiate if the class hasn't been instantiated already
|
||||
# This handles the case we entered replace_env_vars() with a dict, which
|
||||
# could happen if we use a sample_run_config() method that returns a dict. Our unit tests do
|
||||
# this on the adhoc config spec creation.
|
||||
if isinstance(config_class, str):
|
||||
config_class = instantiate_class_type(config_class)
|
||||
|
||||
if hasattr(config_class, "model_fields") and field_name in config_class.model_fields:
|
||||
field_info = config_class.model_fields[field_name]
|
||||
field_type = field_info.annotation
|
||||
return _convert_value_by_field_type(value, field_type)
|
||||
else:
|
||||
return _convert_string_to_proper_type(value)
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to convert using config class: {e}")
|
||||
return _convert_string_to_proper_type(value)
|
||||
|
||||
|
||||
def _convert_value_by_field_type(value: str, field_type: Any) -> Any:
|
||||
"""Convert string value based on Pydantic field type annotation."""
|
||||
import typing
|
||||
from typing import get_args, get_origin
|
||||
|
||||
if value == "":
|
||||
if field_type is None or (hasattr(typing, "get_origin") and get_origin(field_type) is type(None)):
|
||||
return None
|
||||
if hasattr(typing, "get_origin") and get_origin(field_type) is typing.Union:
|
||||
args = get_args(field_type)
|
||||
if type(None) in args:
|
||||
return None
|
||||
return ""
|
||||
|
||||
if field_type is bool or (hasattr(typing, "get_origin") and get_origin(field_type) is bool):
|
||||
lowered = value.lower()
|
||||
if lowered == "true":
|
||||
return True
|
||||
elif lowered == "false":
|
||||
return False
|
||||
else:
|
||||
return value
|
||||
|
||||
if field_type is int or (hasattr(typing, "get_origin") and get_origin(field_type) is int):
|
||||
try:
|
||||
return int(value)
|
||||
except ValueError:
|
||||
return value
|
||||
|
||||
if field_type is float or (hasattr(typing, "get_origin") and get_origin(field_type) is float):
|
||||
try:
|
||||
return float(value)
|
||||
except ValueError:
|
||||
return value
|
||||
|
||||
if hasattr(typing, "get_origin") and get_origin(field_type) is typing.Union:
|
||||
args = get_args(field_type)
|
||||
# Try to convert to the first non-None type
|
||||
for arg in args:
|
||||
if arg is not type(None):
|
||||
try:
|
||||
return _convert_value_by_field_type(value, arg)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def _convert_string_to_proper_type(value: str) -> Any:
|
||||
# This might be tricky depending on what the config type is, if 'str | None' we are
|
||||
# good, if 'str' we need to keep the empty string... 'str | None' is more common and
|
||||
# providers config should be typed this way.
|
||||
# TODO: we could try to load the config class and see if the config has a field with type 'str | None'
|
||||
# and then convert the empty string to None or not
|
||||
# Fallback function for when provider config class is not available
|
||||
# The main type conversion logic is now in _convert_string_to_proper_type_with_config
|
||||
if value == "":
|
||||
return None
|
||||
|
||||
|
@ -416,7 +532,7 @@ def get_stack_run_config_from_distro(distro: str) -> StackRunConfig:
|
|||
raise ValueError(f"Distribution '{distro}' not found at {distro_path}")
|
||||
run_config = yaml.safe_load(path.open())
|
||||
|
||||
return StackRunConfig(**replace_env_vars(run_config))
|
||||
return StackRunConfig(**replace_env_vars(run_config, provider_registry=get_provider_registry()))
|
||||
|
||||
|
||||
def run_config_from_adhoc_config_spec(
|
||||
|
@ -452,7 +568,11 @@ def run_config_from_adhoc_config_spec(
|
|||
|
||||
# call method "sample_run_config" on the provider spec config class
|
||||
provider_config_type = instantiate_class_type(provider_spec.config_class)
|
||||
provider_config = replace_env_vars(provider_config_type.sample_run_config(__distro_dir__=distro_dir))
|
||||
provider_config = replace_env_vars(
|
||||
provider_config_type.sample_run_config(__distro_dir__=distro_dir),
|
||||
provider_registry=provider_registry,
|
||||
current_provider_context=provider_spec.model_dump(),
|
||||
)
|
||||
|
||||
provider_configs_by_api[api_str] = [
|
||||
Provider(
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, SecretStr
|
||||
|
||||
from llama_stack.core.datatypes import Api
|
||||
|
||||
|
@ -13,7 +13,7 @@ from .config import BraintrustScoringConfig
|
|||
|
||||
|
||||
class BraintrustProviderDataValidator(BaseModel):
|
||||
openai_api_key: str
|
||||
openai_api_key: SecretStr
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
|
|
|
@ -17,7 +17,7 @@ from autoevals.ragas import (
|
|||
ContextRelevancy,
|
||||
Faithfulness,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, SecretStr
|
||||
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
|
@ -152,9 +152,9 @@ class BraintrustScoringImpl(
|
|||
raise ValueError(
|
||||
'Pass OpenAI API Key in the header X-LlamaStack-Provider-Data as { "openai_api_key": <your api key>}'
|
||||
)
|
||||
self.config.openai_api_key = provider_data.openai_api_key
|
||||
self.config.openai_api_key = SecretStr(provider_data.openai_api_key)
|
||||
|
||||
os.environ["OPENAI_API_KEY"] = self.config.openai_api_key
|
||||
os.environ["OPENAI_API_KEY"] = self.config.openai_api_key.get_secret_value()
|
||||
|
||||
async def score_batch(
|
||||
self,
|
||||
|
|
|
@ -5,12 +5,11 @@
|
|||
# the root directory of this source tree.
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
|
||||
class BraintrustScoringConfig(BaseModel):
|
||||
openai_api_key: str | None = Field(
|
||||
default=None,
|
||||
openai_api_key: SecretStr = Field(
|
||||
description="The OpenAI API Key",
|
||||
)
|
||||
|
||||
|
|
|
@ -64,6 +64,7 @@ class ConsoleSpanProcessor(SpanProcessor):
|
|||
for key, value in event.attributes.items():
|
||||
if key.startswith("__") or key in ["message", "severity"]:
|
||||
continue
|
||||
|
||||
logger.info(f"[dim]{key}[/dim]: {value}")
|
||||
|
||||
def shutdown(self) -> None:
|
||||
|
|
|
@ -8,14 +8,14 @@ import os
|
|||
import warnings
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
|
||||
class NvidiaDatasetIOConfig(BaseModel):
|
||||
"""Configuration for NVIDIA DatasetIO implementation."""
|
||||
|
||||
api_key: str | None = Field(
|
||||
default_factory=lambda: os.getenv("NVIDIA_API_KEY"),
|
||||
api_key: SecretStr = Field(
|
||||
default_factory=lambda: SecretStr(os.getenv("NVIDIA_API_KEY", "")),
|
||||
description="The NVIDIA API key.",
|
||||
)
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig
|
||||
|
||||
|
@ -17,9 +17,7 @@ class S3FilesImplConfig(BaseModel):
|
|||
bucket_name: str = Field(description="S3 bucket name to store files")
|
||||
region: str = Field(default="us-east-1", description="AWS region where the bucket is located")
|
||||
aws_access_key_id: str | None = Field(default=None, description="AWS access key ID (optional if using IAM roles)")
|
||||
aws_secret_access_key: str | None = Field(
|
||||
default=None, description="AWS secret access key (optional if using IAM roles)"
|
||||
)
|
||||
aws_secret_access_key: SecretStr = Field(description="AWS secret access key (optional if using IAM roles)")
|
||||
endpoint_url: str | None = Field(default=None, description="Custom S3 endpoint URL (for MinIO, LocalStack, etc.)")
|
||||
auto_create_bucket: bool = Field(
|
||||
default=False, description="Automatically create the S3 bucket if it doesn't exist"
|
||||
|
|
|
@ -47,7 +47,7 @@ def _create_s3_client(config: S3FilesImplConfig) -> boto3.client:
|
|||
s3_config.update(
|
||||
{
|
||||
"aws_access_key_id": config.aws_access_key_id,
|
||||
"aws_secret_access_key": config.aws_secret_access_key,
|
||||
"aws_secret_access_key": config.aws_secret_access_key.get_secret_value(),
|
||||
}
|
||||
)
|
||||
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
|
|
|
@ -6,22 +6,20 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class AnthropicProviderDataValidator(BaseModel):
|
||||
anthropic_api_key: str | None = Field(
|
||||
default=None,
|
||||
anthropic_api_key: SecretStr = Field(
|
||||
description="API key for Anthropic models",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AnthropicConfig(BaseModel):
|
||||
api_key: str | None = Field(
|
||||
default=None,
|
||||
api_key: SecretStr = Field(
|
||||
description="API key for Anthropic models",
|
||||
)
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@ class AzureInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
|||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
litellm_provider_name="azure",
|
||||
api_key_from_config=config.api_key.get_secret_value(),
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="azure_api_key",
|
||||
openai_compat_api_base=str(config.api_base),
|
||||
)
|
||||
|
|
|
@ -18,7 +18,6 @@ class DatabricksImplConfig(BaseModel):
|
|||
description="The URL for the Databricks model serving endpoint",
|
||||
)
|
||||
api_token: SecretStr = Field(
|
||||
default=SecretStr(None),
|
||||
description="The Databricks API token",
|
||||
)
|
||||
|
||||
|
|
|
@ -18,8 +18,7 @@ class FireworksImplConfig(RemoteInferenceProviderConfig):
|
|||
default="https://api.fireworks.ai/inference/v1",
|
||||
description="The URL for the Fireworks server",
|
||||
)
|
||||
api_key: SecretStr | None = Field(
|
||||
default=None,
|
||||
api_key: SecretStr = Field(
|
||||
description="The Fireworks.ai API Key",
|
||||
)
|
||||
|
||||
|
|
|
@ -6,22 +6,20 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class GeminiProviderDataValidator(BaseModel):
|
||||
gemini_api_key: str | None = Field(
|
||||
default=None,
|
||||
gemini_api_key: SecretStr = Field(
|
||||
description="API key for Gemini models",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class GeminiConfig(BaseModel):
|
||||
api_key: str | None = Field(
|
||||
default=None,
|
||||
api_key: SecretStr = Field(
|
||||
description="API key for Gemini models",
|
||||
)
|
||||
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
|
|
|
@ -6,23 +6,21 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class GroqProviderDataValidator(BaseModel):
|
||||
groq_api_key: str | None = Field(
|
||||
default=None,
|
||||
groq_api_key: SecretStr = Field(
|
||||
description="API key for Groq models",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class GroqConfig(BaseModel):
|
||||
api_key: str | None = Field(
|
||||
api_key: SecretStr = Field(
|
||||
# The Groq client library loads the GROQ_API_KEY environment variable by default
|
||||
default=None,
|
||||
description="The Groq API key",
|
||||
)
|
||||
|
||||
|
|
|
@ -6,22 +6,20 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class LlamaProviderDataValidator(BaseModel):
|
||||
llama_api_key: str | None = Field(
|
||||
default=None,
|
||||
llama_api_key: SecretStr = Field(
|
||||
description="API key for api.llama models",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class LlamaCompatConfig(BaseModel):
|
||||
api_key: str | None = Field(
|
||||
default=None,
|
||||
api_key: SecretStr = Field(
|
||||
description="The Llama API key",
|
||||
)
|
||||
|
||||
|
|
|
@ -39,8 +39,8 @@ class NVIDIAConfig(BaseModel):
|
|||
default_factory=lambda: os.getenv("NVIDIA_BASE_URL", "https://integrate.api.nvidia.com"),
|
||||
description="A base url for accessing the NVIDIA NIM",
|
||||
)
|
||||
api_key: SecretStr | None = Field(
|
||||
default_factory=lambda: SecretStr(os.getenv("NVIDIA_API_KEY")),
|
||||
api_key: SecretStr = Field(
|
||||
default_factory=lambda: SecretStr(os.getenv("NVIDIA_API_KEY", "")),
|
||||
description="The NVIDIA API key, only needed of using the hosted service",
|
||||
)
|
||||
timeout: int = Field(
|
||||
|
|
|
@ -6,22 +6,20 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class OpenAIProviderDataValidator(BaseModel):
|
||||
openai_api_key: str | None = Field(
|
||||
default=None,
|
||||
openai_api_key: SecretStr = Field(
|
||||
description="API key for OpenAI models",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIConfig(BaseModel):
|
||||
api_key: str | None = Field(
|
||||
default=None,
|
||||
api_key: SecretStr = Field(
|
||||
description="API key for OpenAI models",
|
||||
)
|
||||
base_url: str = Field(
|
||||
|
|
|
@ -18,8 +18,7 @@ class PassthroughImplConfig(BaseModel):
|
|||
description="The URL for the passthrough endpoint",
|
||||
)
|
||||
|
||||
api_key: SecretStr | None = Field(
|
||||
default=None,
|
||||
api_key: SecretStr = Field(
|
||||
description="API Key for the passthrouth endpoint",
|
||||
)
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
@ -17,8 +17,7 @@ class RunpodImplConfig(BaseModel):
|
|||
default=None,
|
||||
description="The URL for the Runpod model serving endpoint",
|
||||
)
|
||||
api_token: str | None = Field(
|
||||
default=None,
|
||||
api_token: SecretStr = Field(
|
||||
description="The API token",
|
||||
)
|
||||
|
||||
|
|
|
@ -103,7 +103,10 @@ class RunpodInferenceAdapter(
|
|||
tool_config=tool_config,
|
||||
)
|
||||
|
||||
client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
|
||||
client = OpenAI(
|
||||
base_url=self.config.url,
|
||||
api_key=self.config.api_token.get_secret_value() if self.config.api_token else None,
|
||||
)
|
||||
if stream:
|
||||
return self._stream_chat_completion(request, client)
|
||||
else:
|
||||
|
|
|
@ -12,8 +12,7 @@ from llama_stack.schema_utils import json_schema_type
|
|||
|
||||
|
||||
class SambaNovaProviderDataValidator(BaseModel):
|
||||
sambanova_api_key: str | None = Field(
|
||||
default=None,
|
||||
sambanova_api_key: SecretStr = Field(
|
||||
description="Sambanova Cloud API key",
|
||||
)
|
||||
|
||||
|
@ -24,8 +23,7 @@ class SambaNovaImplConfig(BaseModel):
|
|||
default="https://api.sambanova.ai/v1",
|
||||
description="The URL for the SambaNova AI server",
|
||||
)
|
||||
api_key: SecretStr | None = Field(
|
||||
default=None,
|
||||
api_key: SecretStr = Field(
|
||||
description="The SambaNova cloud API Key",
|
||||
)
|
||||
|
||||
|
|
|
@ -29,7 +29,7 @@ class SambaNovaInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
|||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
litellm_provider_name="sambanova",
|
||||
api_key_from_config=self.config.api_key.get_secret_value() if self.config.api_key else None,
|
||||
api_key_from_config=self.config.api_key,
|
||||
provider_data_api_key_field="sambanova_api_key",
|
||||
openai_compat_api_base=self.config.url,
|
||||
download_images=True, # SambaNova requires base64 image encoding
|
||||
|
|
|
@ -32,8 +32,7 @@ class InferenceEndpointImplConfig(BaseModel):
|
|||
endpoint_name: str = Field(
|
||||
description="The name of the Hugging Face Inference Endpoint in the format of '{namespace}/{endpoint_name}' (e.g. 'my-cool-org/meta-llama-3-1-8b-instruct-rce'). Namespace is optional and will default to the user account if not provided.",
|
||||
)
|
||||
api_token: SecretStr | None = Field(
|
||||
default=None,
|
||||
api_token: SecretStr = Field(
|
||||
description="Your Hugging Face user access token (will default to locally saved token if not provided)",
|
||||
)
|
||||
|
||||
|
@ -55,8 +54,7 @@ class InferenceAPIImplConfig(BaseModel):
|
|||
huggingface_repo: str = Field(
|
||||
description="The model ID of the model on the Hugging Face Hub (e.g. 'meta-llama/Meta-Llama-3.1-70B-Instruct')",
|
||||
)
|
||||
api_token: SecretStr | None = Field(
|
||||
default=None,
|
||||
api_token: SecretStr = Field(
|
||||
description="Your Hugging Face user access token (will default to locally saved token if not provided)",
|
||||
)
|
||||
|
||||
|
|
|
@ -18,8 +18,7 @@ class TogetherImplConfig(RemoteInferenceProviderConfig):
|
|||
default="https://api.together.xyz/v1",
|
||||
description="The URL for the Together AI server",
|
||||
)
|
||||
api_key: SecretStr | None = Field(
|
||||
default=None,
|
||||
api_key: SecretStr = Field(
|
||||
description="The Together AI API Key",
|
||||
)
|
||||
|
||||
|
|
|
@ -8,6 +8,7 @@ from typing import Any
|
|||
|
||||
import google.auth.transport.requests
|
||||
from google.auth import default
|
||||
from pydantic import SecretStr
|
||||
|
||||
from llama_stack.apis.inference import ChatCompletionRequest
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import (
|
||||
|
@ -23,12 +24,12 @@ class VertexAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
|||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
litellm_provider_name="vertex_ai",
|
||||
api_key_from_config=None, # Vertex AI uses ADC, not API keys
|
||||
api_key_from_config=SecretStr(""), # Vertex AI uses ADC, not API keys
|
||||
provider_data_api_key_field="vertex_project", # Use project for validation
|
||||
)
|
||||
self.config = config
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
def get_api_key(self) -> SecretStr:
|
||||
"""
|
||||
Get an access token for Vertex AI using Application Default Credentials.
|
||||
|
||||
|
@ -39,11 +40,11 @@ class VertexAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
|||
# Get default credentials - will read from GOOGLE_APPLICATION_CREDENTIALS
|
||||
credentials, _ = default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
|
||||
credentials.refresh(google.auth.transport.requests.Request())
|
||||
return str(credentials.token)
|
||||
return SecretStr(credentials.token)
|
||||
except Exception:
|
||||
# If we can't get credentials, return empty string to let LiteLLM handle it
|
||||
# This allows the LiteLLM mixin to work with ADC directly
|
||||
return ""
|
||||
return SecretStr("")
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
"""
|
||||
|
|
|
@ -4,13 +4,15 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from .config import VLLMInferenceAdapterConfig
|
||||
|
||||
|
||||
class VLLMProviderDataValidator(BaseModel):
|
||||
vllm_api_token: str | None = None
|
||||
vllm_api_token: SecretStr = Field(
|
||||
description="API token for vLLM models",
|
||||
)
|
||||
|
||||
|
||||
async def get_adapter_impl(config: VLLMInferenceAdapterConfig, _deps):
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import BaseModel, Field, SecretStr, field_validator
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
@ -21,8 +21,8 @@ class VLLMInferenceAdapterConfig(BaseModel):
|
|||
default=4096,
|
||||
description="Maximum number of tokens to generate.",
|
||||
)
|
||||
api_token: str | None = Field(
|
||||
default="fake",
|
||||
api_token: SecretStr = Field(
|
||||
default=SecretStr("fake"),
|
||||
description="The API token",
|
||||
)
|
||||
tls_verify: bool | str = Field(
|
||||
|
|
|
@ -24,8 +24,8 @@ class WatsonXConfig(BaseModel):
|
|||
default_factory=lambda: os.getenv("WATSONX_BASE_URL", "https://us-south.ml.cloud.ibm.com"),
|
||||
description="A base url for accessing the watsonx.ai",
|
||||
)
|
||||
api_key: SecretStr | None = Field(
|
||||
default_factory=lambda: os.getenv("WATSONX_API_KEY"),
|
||||
api_key: SecretStr = Field(
|
||||
default_factory=lambda: SecretStr(os.getenv("WATSONX_API_KEY", "")),
|
||||
description="The watsonx API key",
|
||||
)
|
||||
project_id: str | None = Field(
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
import os
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
# TODO: add default values for all fields
|
||||
|
||||
|
@ -15,8 +15,8 @@ from pydantic import BaseModel, Field
|
|||
class NvidiaPostTrainingConfig(BaseModel):
|
||||
"""Configuration for NVIDIA Post Training implementation."""
|
||||
|
||||
api_key: str | None = Field(
|
||||
default_factory=lambda: os.getenv("NVIDIA_API_KEY"),
|
||||
api_key: SecretStr = Field(
|
||||
default_factory=lambda: SecretStr(os.getenv("NVIDIA_API_KEY", "")),
|
||||
description="The NVIDIA API key.",
|
||||
)
|
||||
|
||||
|
|
|
@ -12,8 +12,7 @@ from llama_stack.schema_utils import json_schema_type
|
|||
|
||||
|
||||
class SambaNovaProviderDataValidator(BaseModel):
|
||||
sambanova_api_key: str | None = Field(
|
||||
default=None,
|
||||
sambanova_api_key: SecretStr = Field(
|
||||
description="Sambanova Cloud API key",
|
||||
)
|
||||
|
||||
|
@ -24,8 +23,7 @@ class SambaNovaSafetyConfig(BaseModel):
|
|||
default="https://api.sambanova.ai/v1",
|
||||
description="The URL for the SambaNova AI server",
|
||||
)
|
||||
api_key: SecretStr | None = Field(
|
||||
default=None,
|
||||
api_key: SecretStr = Field(
|
||||
description="The SambaNova cloud API Key",
|
||||
)
|
||||
|
||||
|
|
|
@ -40,7 +40,7 @@ class BingSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsReq
|
|||
|
||||
def _get_api_key(self) -> str:
|
||||
if self.config.api_key:
|
||||
return self.config.api_key
|
||||
return self.config.api_key.get_secret_value()
|
||||
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data is None or not provider_data.bing_search_api_key:
|
||||
|
|
|
@ -6,13 +6,15 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
|
||||
class BingSearchToolConfig(BaseModel):
|
||||
"""Configuration for Bing Search Tool Runtime"""
|
||||
|
||||
api_key: str | None = None
|
||||
api_key: SecretStr = Field(
|
||||
description="The Bing API key",
|
||||
)
|
||||
top_k: int = 3
|
||||
|
||||
@classmethod
|
||||
|
|
|
@ -39,7 +39,7 @@ class BraveSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRe
|
|||
|
||||
def _get_api_key(self) -> str:
|
||||
if self.config.api_key:
|
||||
return self.config.api_key
|
||||
return self.config.api_key.get_secret_value()
|
||||
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data is None or not provider_data.brave_search_api_key:
|
||||
|
|
|
@ -6,12 +6,11 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
|
||||
class BraveSearchToolConfig(BaseModel):
|
||||
api_key: str | None = Field(
|
||||
default=None,
|
||||
api_key: SecretStr = Field(
|
||||
description="The Brave Search API Key",
|
||||
)
|
||||
max_results: int = Field(
|
||||
|
|
|
@ -6,12 +6,11 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
|
||||
class TavilySearchToolConfig(BaseModel):
|
||||
api_key: str | None = Field(
|
||||
default=None,
|
||||
api_key: SecretStr = Field(
|
||||
description="The Tavily Search API Key",
|
||||
)
|
||||
max_results: int = Field(
|
||||
|
|
|
@ -39,7 +39,7 @@ class TavilySearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsR
|
|||
|
||||
def _get_api_key(self) -> str:
|
||||
if self.config.api_key:
|
||||
return self.config.api_key
|
||||
return self.config.api_key.get_secret_value()
|
||||
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data is None or not provider_data.tavily_search_api_key:
|
||||
|
|
|
@ -6,13 +6,15 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
|
||||
class WolframAlphaToolConfig(BaseModel):
|
||||
"""Configuration for WolframAlpha Tool Runtime"""
|
||||
|
||||
api_key: str | None = None
|
||||
api_key: SecretStr = Field(
|
||||
description="The WolframAlpha API Key",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
|
||||
|
|
|
@ -40,7 +40,7 @@ class WolframAlphaToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsR
|
|||
|
||||
def _get_api_key(self) -> str:
|
||||
if self.config.api_key:
|
||||
return self.config.api_key
|
||||
return self.config.api_key.get_secret_value()
|
||||
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data is None or not provider_data.wolfram_alpha_api_key:
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field, SecretStr
|
||||
|
||||
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
@ -15,7 +15,7 @@ from llama_stack.schema_utils import json_schema_type
|
|||
@json_schema_type
|
||||
class MilvusVectorIOConfig(BaseModel):
|
||||
uri: str = Field(description="The URI of the Milvus server")
|
||||
token: str | None = Field(description="The token of the Milvus server")
|
||||
token: SecretStr = Field(description="The token of the Milvus server")
|
||||
consistency_level: str = Field(description="The consistency level of the Milvus server", default="Strong")
|
||||
kvstore: KVStoreConfig = Field(description="Config for KV store backend")
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from llama_stack.providers.utils.kvstore.config import (
|
||||
KVStoreConfig,
|
||||
|
@ -21,7 +21,7 @@ class PGVectorVectorIOConfig(BaseModel):
|
|||
port: int | None = Field(default=5432)
|
||||
db: str | None = Field(default="postgres")
|
||||
user: str | None = Field(default="postgres")
|
||||
password: str | None = Field(default="mysecretpassword")
|
||||
password: SecretStr = Field(default=SecretStr("mysecretpassword"))
|
||||
kvstore: KVStoreConfig | None = Field(description="Config for KV store backend (SQLite only for now)", default=None)
|
||||
|
||||
@classmethod
|
||||
|
|
|
@ -366,7 +366,7 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
|
|||
port=self.config.port,
|
||||
database=self.config.db,
|
||||
user=self.config.user,
|
||||
password=self.config.password,
|
||||
password=self.config.password.get_secret_value(),
|
||||
)
|
||||
self.conn.autocommit = True
|
||||
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from llama_stack.providers.utils.kvstore.config import (
|
||||
KVStoreConfig,
|
||||
|
@ -23,7 +23,9 @@ class QdrantVectorIOConfig(BaseModel):
|
|||
grpc_port: int = 6334
|
||||
prefer_grpc: bool = False
|
||||
https: bool | None = None
|
||||
api_key: str | None = None
|
||||
api_key: SecretStr = Field(
|
||||
description="The API key for the Qdrant instance",
|
||||
)
|
||||
prefix: str | None = None
|
||||
timeout: int | None = None
|
||||
host: str | None = None
|
||||
|
|
|
@ -173,7 +173,7 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
self._qdrant_lock = asyncio.Lock()
|
||||
|
||||
async def initialize(self) -> None:
|
||||
client_config = self.config.model_dump(exclude_none=True, exclude={"kvstore"})
|
||||
client_config = self.config.model_dump(exclude_none=True, exclude={"kvstore"}, mode="json")
|
||||
self.client = AsyncQdrantClient(**client_config)
|
||||
self.kvstore = await kvstore_impl(self.config.kvstore)
|
||||
|
||||
|
|
|
@ -50,8 +50,8 @@ def create_bedrock_client(config: BedrockBaseConfig, service_name: str = "bedroc
|
|||
|
||||
session_args = {
|
||||
"aws_access_key_id": config.aws_access_key_id,
|
||||
"aws_secret_access_key": config.aws_secret_access_key,
|
||||
"aws_session_token": config.aws_session_token,
|
||||
"aws_secret_access_key": config.aws_secret_access_key.get_secret_value(),
|
||||
"aws_session_token": config.aws_session_token.get_secret_value(),
|
||||
"region_name": config.region_name,
|
||||
"profile_name": config.profile_name,
|
||||
"session_ttl": config.session_ttl,
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
import os
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
|
||||
class BedrockBaseConfig(BaseModel):
|
||||
|
@ -14,12 +14,12 @@ class BedrockBaseConfig(BaseModel):
|
|||
default_factory=lambda: os.getenv("AWS_ACCESS_KEY_ID"),
|
||||
description="The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID",
|
||||
)
|
||||
aws_secret_access_key: str | None = Field(
|
||||
default_factory=lambda: os.getenv("AWS_SECRET_ACCESS_KEY"),
|
||||
aws_secret_access_key: SecretStr = Field(
|
||||
default_factory=lambda: SecretStr(os.getenv("AWS_SECRET_ACCESS_KEY", "")),
|
||||
description="The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY",
|
||||
)
|
||||
aws_session_token: str | None = Field(
|
||||
default_factory=lambda: os.getenv("AWS_SESSION_TOKEN"),
|
||||
aws_session_token: SecretStr = Field(
|
||||
default_factory=lambda: SecretStr(os.getenv("AWS_SESSION_TOKEN", "")),
|
||||
description="The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN",
|
||||
)
|
||||
region_name: str | None = Field(
|
||||
|
|
|
@ -8,6 +8,7 @@ from collections.abc import AsyncGenerator, AsyncIterator
|
|||
from typing import Any
|
||||
|
||||
import litellm
|
||||
from pydantic import SecretStr
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
|
@ -61,7 +62,7 @@ class LiteLLMOpenAIMixin(
|
|||
def __init__(
|
||||
self,
|
||||
litellm_provider_name: str,
|
||||
api_key_from_config: str | None,
|
||||
api_key_from_config: SecretStr,
|
||||
provider_data_api_key_field: str,
|
||||
model_entries: list[ProviderModelEntry] | None = None,
|
||||
openai_compat_api_base: str | None = None,
|
||||
|
@ -240,14 +241,14 @@ class LiteLLMOpenAIMixin(
|
|||
|
||||
return {
|
||||
"model": request.model,
|
||||
"api_key": self.get_api_key(),
|
||||
"api_key": self.get_api_key().get_secret_value(),
|
||||
"api_base": self.api_base,
|
||||
**input_dict,
|
||||
"stream": request.stream,
|
||||
**get_sampling_options(request.sampling_params),
|
||||
}
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
def get_api_key(self) -> SecretStr:
|
||||
provider_data = self.get_request_provider_data()
|
||||
key_field = self.provider_data_api_key_field
|
||||
if provider_data and getattr(provider_data, key_field, None):
|
||||
|
@ -280,7 +281,7 @@ class LiteLLMOpenAIMixin(
|
|||
response = litellm.embedding(
|
||||
model=self.get_litellm_model_name(model_obj.provider_resource_id),
|
||||
input=input_list,
|
||||
api_key=self.get_api_key(),
|
||||
api_key=self.get_api_key().get_secret_value(),
|
||||
api_base=self.api_base,
|
||||
dimensions=dimensions,
|
||||
)
|
||||
|
@ -343,7 +344,7 @@ class LiteLLMOpenAIMixin(
|
|||
user=user,
|
||||
guided_choice=guided_choice,
|
||||
prompt_logprobs=prompt_logprobs,
|
||||
api_key=self.get_api_key(),
|
||||
api_key=self.get_api_key().get_secret_value(),
|
||||
api_base=self.api_base,
|
||||
)
|
||||
return await litellm.atext_completion(**params)
|
||||
|
@ -407,7 +408,7 @@ class LiteLLMOpenAIMixin(
|
|||
top_logprobs=top_logprobs,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
api_key=self.get_api_key(),
|
||||
api_key=self.get_api_key().get_secret_value(),
|
||||
api_base=self.api_base,
|
||||
)
|
||||
return await litellm.acompletion(**params)
|
||||
|
|
|
@ -11,6 +11,7 @@ from collections.abc import AsyncIterator
|
|||
from typing import Any
|
||||
|
||||
from openai import NOT_GIVEN, AsyncOpenAI
|
||||
from pydantic import SecretStr
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
Model,
|
||||
|
@ -70,14 +71,14 @@ class OpenAIMixin(ModelRegistryHelper, ABC):
|
|||
allowed_models: list[str] = []
|
||||
|
||||
@abstractmethod
|
||||
def get_api_key(self) -> str:
|
||||
def get_api_key(self) -> SecretStr:
|
||||
"""
|
||||
Get the API key.
|
||||
|
||||
This method must be implemented by child classes to provide the API key
|
||||
for authenticating with the OpenAI API or compatible endpoints.
|
||||
|
||||
:return: The API key as a string
|
||||
:return: The API key as a SecretStr
|
||||
"""
|
||||
pass
|
||||
|
||||
|
@ -113,7 +114,7 @@ class OpenAIMixin(ModelRegistryHelper, ABC):
|
|||
implemented by child classes.
|
||||
"""
|
||||
return AsyncOpenAI(
|
||||
api_key=self.get_api_key(),
|
||||
api_key=self.get_api_key().get_secret_value(),
|
||||
base_url=self.get_base_url(),
|
||||
**self.get_extra_client_params(),
|
||||
)
|
||||
|
|
|
@ -8,7 +8,7 @@ import re
|
|||
from enum import Enum
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import BaseModel, Field, SecretStr, field_validator
|
||||
|
||||
from llama_stack.core.utils.config_dirs import RUNTIME_BASE_DIR
|
||||
|
||||
|
@ -74,7 +74,7 @@ class PostgresKVStoreConfig(CommonConfig):
|
|||
port: int = 5432
|
||||
db: str = "llamastack"
|
||||
user: str
|
||||
password: str | None = None
|
||||
password: SecretStr = SecretStr("")
|
||||
ssl_mode: str | None = None
|
||||
ca_cert_path: str | None = None
|
||||
table_name: str = "llamastack_kvstore"
|
||||
|
@ -118,7 +118,7 @@ class MongoDBKVStoreConfig(CommonConfig):
|
|||
port: int = 27017
|
||||
db: str = "llamastack"
|
||||
user: str | None = None
|
||||
password: str | None = None
|
||||
password: SecretStr = SecretStr("")
|
||||
collection_name: str = "llamastack_kvstore"
|
||||
|
||||
@classmethod
|
||||
|
|
|
@ -34,7 +34,7 @@ class MongoDBKVStoreImpl(KVStore):
|
|||
"host": self.config.host,
|
||||
"port": self.config.port,
|
||||
"username": self.config.user,
|
||||
"password": self.config.password,
|
||||
"password": self.config.password.get_secret_value(),
|
||||
}
|
||||
conn_creds = {k: v for k, v in conn_creds.items() if v is not None}
|
||||
self.conn = AsyncMongoClient(**conn_creds)
|
||||
|
|
|
@ -30,7 +30,7 @@ class PostgresKVStoreImpl(KVStore):
|
|||
port=self.config.port,
|
||||
database=self.config.db,
|
||||
user=self.config.user,
|
||||
password=self.config.password,
|
||||
password=self.config.password.get_secret_value(),
|
||||
sslmode=self.config.ssl_mode,
|
||||
sslrootcert=self.config.ca_cert_path,
|
||||
)
|
||||
|
|
|
@ -9,7 +9,7 @@ from enum import StrEnum
|
|||
from pathlib import Path
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from llama_stack.core.utils.config_dirs import RUNTIME_BASE_DIR
|
||||
|
||||
|
@ -63,11 +63,11 @@ class PostgresSqlStoreConfig(SqlAlchemySqlStoreConfig):
|
|||
port: int = 5432
|
||||
db: str = "llamastack"
|
||||
user: str
|
||||
password: str | None = None
|
||||
password: SecretStr = SecretStr("")
|
||||
|
||||
@property
|
||||
def engine_str(self) -> str:
|
||||
return f"postgresql+asyncpg://{self.user}:{self.password}@{self.host}:{self.port}/{self.db}"
|
||||
return f"postgresql+asyncpg://{self.user}:{self.password.get_secret_value()}@{self.host}:{self.port}/{self.db}"
|
||||
|
||||
@classmethod
|
||||
def pip_packages(cls) -> list[str]:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue