From fe190768382019e04b27c5b6603b35e7bfe9f9b8 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Mon, 18 Nov 2024 18:05:05 -0800 Subject: [PATCH] get stack run config based on template name (#477) This PR adds a method in stack to return the stackrunconfig object based on the template name. This will be used to instantiate a direct client without the need for an explicit run.yaml --------- Co-authored-by: Dinesh Yeduguru --- llama_stack/distribution/server/server.py | 78 ++------------------ llama_stack/distribution/stack.py | 90 +++++++++++++++++++++++ 2 files changed, 95 insertions(+), 73 deletions(-) diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index ccd345181..fecc41b5d 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -10,7 +10,6 @@ import functools import inspect import json import os -import re import signal import sys import traceback @@ -41,7 +40,11 @@ from llama_stack.providers.utils.telemetry.tracing import ( from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.request_headers import set_request_provider_data from llama_stack.distribution.resolver import InvalidProviderError -from llama_stack.distribution.stack import construct_stack +from llama_stack.distribution.stack import ( + construct_stack, + replace_env_vars, + validate_env_pair, +) from .endpoints import get_all_api_endpoints @@ -271,77 +274,6 @@ def create_dynamic_typed_route(func: Any, method: str): return endpoint -class EnvVarError(Exception): - def __init__(self, var_name: str, path: str = ""): - self.var_name = var_name - self.path = path - super().__init__( - f"Environment variable '{var_name}' not set or empty{f' at {path}' if path else ''}" - ) - - -def replace_env_vars(config: Any, path: str = "") -> Any: - if isinstance(config, dict): - result = {} - for k, v in config.items(): - try: - result[k] = replace_env_vars(v, f"{path}.{k}" if path else k) - except EnvVarError as e: - raise EnvVarError(e.var_name, e.path) from None - return result - - elif isinstance(config, list): - result = [] - for i, v in enumerate(config): - try: - result.append(replace_env_vars(v, f"{path}[{i}]")) - except EnvVarError as e: - raise EnvVarError(e.var_name, e.path) from None - return result - - elif isinstance(config, str): - pattern = r"\${env\.([A-Z0-9_]+)(?::([^}]*))?}" - - def get_env_var(match): - env_var = match.group(1) - default_val = match.group(2) - - value = os.environ.get(env_var) - if not value: - if default_val is None: - raise EnvVarError(env_var, path) - else: - value = default_val - - # expand "~" from the values - return os.path.expanduser(value) - - try: - return re.sub(pattern, get_env_var, config) - except EnvVarError as e: - raise EnvVarError(e.var_name, e.path) from None - - return config - - -def validate_env_pair(env_pair: str) -> tuple[str, str]: - """Validate and split an environment variable key-value pair.""" - try: - key, value = env_pair.split("=", 1) - key = key.strip() - if not key: - raise ValueError(f"Empty key in environment variable pair: {env_pair}") - if not all(c.isalnum() or c == "_" for c in key): - raise ValueError( - f"Key must contain only alphanumeric characters and underscores: {key}" - ) - return key, value - except ValueError as e: - raise ValueError( - f"Invalid environment variable format '{env_pair}': {str(e)}. Expected format: KEY=value" - ) from e - - def main(): """Start the LlamaStack server.""" parser = argparse.ArgumentParser(description="Start the LlamaStack server.") diff --git a/llama_stack/distribution/stack.py b/llama_stack/distribution/stack.py index 1cffd7749..de196b223 100644 --- a/llama_stack/distribution/stack.py +++ b/llama_stack/distribution/stack.py @@ -4,8 +4,13 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import os +from pathlib import Path from typing import Any, Dict +import pkg_resources +import yaml + from termcolor import colored from llama_models.llama3.api.datatypes import * # noqa: F403 @@ -92,6 +97,77 @@ async def register_resources(run_config: StackRunConfig, impls: Dict[Api, Any]): print("") +class EnvVarError(Exception): + def __init__(self, var_name: str, path: str = ""): + self.var_name = var_name + self.path = path + super().__init__( + f"Environment variable '{var_name}' not set or empty{f' at {path}' if path else ''}" + ) + + +def replace_env_vars(config: Any, path: str = "") -> Any: + if isinstance(config, dict): + result = {} + for k, v in config.items(): + try: + result[k] = replace_env_vars(v, f"{path}.{k}" if path else k) + except EnvVarError as e: + raise EnvVarError(e.var_name, e.path) from None + return result + + elif isinstance(config, list): + result = [] + for i, v in enumerate(config): + try: + result.append(replace_env_vars(v, f"{path}[{i}]")) + except EnvVarError as e: + raise EnvVarError(e.var_name, e.path) from None + return result + + elif isinstance(config, str): + pattern = r"\${env\.([A-Z0-9_]+)(?::([^}]*))?}" + + def get_env_var(match): + env_var = match.group(1) + default_val = match.group(2) + + value = os.environ.get(env_var) + if not value: + if default_val is None: + raise EnvVarError(env_var, path) + else: + value = default_val + + # expand "~" from the values + return os.path.expanduser(value) + + try: + return re.sub(pattern, get_env_var, config) + except EnvVarError as e: + raise EnvVarError(e.var_name, e.path) from None + + return config + + +def validate_env_pair(env_pair: str) -> tuple[str, str]: + """Validate and split an environment variable key-value pair.""" + try: + key, value = env_pair.split("=", 1) + key = key.strip() + if not key: + raise ValueError(f"Empty key in environment variable pair: {env_pair}") + if not all(c.isalnum() or c == "_" for c in key): + raise ValueError( + f"Key must contain only alphanumeric characters and underscores: {key}" + ) + return key, value + except ValueError as e: + raise ValueError( + f"Invalid environment variable format '{env_pair}': {str(e)}. Expected format: KEY=value" + ) from e + + # Produces a stack of providers for the given run config. Not all APIs may be # asked for in the run config. async def construct_stack( @@ -105,3 +181,17 @@ async def construct_stack( ) await register_resources(run_config, impls) return impls + + +def get_stack_run_config_from_template(template: str) -> StackRunConfig: + template_path = pkg_resources.resource_filename( + "llama_stack", f"templates/{template}/run.yaml" + ) + + if not Path(template_path).exists(): + raise ValueError(f"Template '{template}' not found at {template_path}") + + with open(template_path) as f: + run_config = yaml.safe_load(f) + + return StackRunConfig(**replace_env_vars(run_config))