From 96e7ef646fd2e54d9e0bab498e1ab4db64256965 Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Wed, 13 Nov 2024 11:25:58 -0800
Subject: [PATCH] add support for ${env.FOO_BAR} placeholders in run.yaml files (#439)

# What does this PR do?

We'd like our docker steps to require _ZERO EDITS_ to a YAML file in order to get going. This is often not possible because, depending on the provider, we do need some configuration input from the user. Environment variables are the best way to obtain this information.

This PR allows our run.yaml to contain `${env.FOO_BAR}` placeholders, which can be filled in using `docker run -e FOO_BAR=baz` (or the equivalent `environment:` entries in `docker compose`). A placeholder may also carry a default, as in `${env.MAX_TOKENS:4096}`, which is used when the variable is unset. A short behavioral sketch follows the Test Plan below.

## Test Plan

For remote-vllm, an example `run.yaml` snippet looks like this:

```yaml
providers:
  inference:
  # serves main inference model
  - provider_id: vllm-0
    provider_type: remote::vllm
    config:
      # NOTE: replace with "localhost" if you are running in "host" network mode
      url: ${env.LLAMA_INFERENCE_VLLM_URL:http://host.docker.internal:5100/v1}
      max_tokens: ${env.MAX_TOKENS:4096}
      api_token: fake
  # serves safety llama_guard model
  - provider_id: vllm-1
    provider_type: remote::vllm
    config:
      # NOTE: replace with "localhost" if you are running in "host" network mode
      url: ${env.LLAMA_SAFETY_VLLM_URL:http://host.docker.internal:5101/v1}
      max_tokens: ${env.MAX_TOKENS:4096}
      api_token: fake
```

The corresponding `compose.yaml` snippet looks like this:

```yaml
llamastack:
  depends_on:
  - vllm-0
  - vllm-1
  # image: llamastack/distribution-remote-vllm
  image: llamastack/distribution-remote-vllm:test-0.0.52rc3
  volumes:
  - ~/.llama:/root/.llama
  - ~/local/llama-stack/distributions/remote-vllm/run.yaml:/root/llamastack-run-remote-vllm.yaml
  # network_mode: "host"
  environment:
  - LLAMA_INFERENCE_VLLM_URL=${LLAMA_INFERENCE_VLLM_URL:-http://host.docker.internal:5100/v1}
  - LLAMA_INFERENCE_MODEL=${LLAMA_INFERENCE_MODEL:-Llama3.1-8B-Instruct}
  - MAX_TOKENS=${MAX_TOKENS:-4096}
  - SQLITE_STORE_DIR=${SQLITE_STORE_DIR:-$HOME/.llama/distributions/remote-vllm}
  - LLAMA_SAFETY_VLLM_URL=${LLAMA_SAFETY_VLLM_URL:-http://host.docker.internal:5101/v1}
  - LLAMA_SAFETY_MODEL=${LLAMA_SAFETY_MODEL:-Llama-Guard-3-1B}
```
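To make the substitution rules concrete, here is a minimal sketch (not part of the patch) of how the `replace_env_vars` helper added to `server.py` in the diff below is expected to behave. The config keys, variable names, and values are illustrative only.

```python
import os

# Assumes the helper introduced in llama_stack/distribution/server/server.py below.
from llama_stack.distribution.server.server import replace_env_vars

# Illustrative config fragment using the new placeholder syntax.
config = {
    "url": "${env.LLAMA_INFERENCE_VLLM_URL:http://host.docker.internal:5100/v1}",
    "max_tokens": "${env.MAX_TOKENS:4096}",
}

# With no environment overrides, each placeholder falls back to its default.
os.environ.pop("LLAMA_INFERENCE_VLLM_URL", None)
os.environ.pop("MAX_TOKENS", None)
print(replace_env_vars(config))
# {'url': 'http://host.docker.internal:5100/v1', 'max_tokens': '4096'}

# Simulating `docker run -e MAX_TOKENS=2048 ...`: the override wins over the default.
os.environ["MAX_TOKENS"] = "2048"
print(replace_env_vars(config))
# {'url': 'http://host.docker.internal:5100/v1', 'max_tokens': '2048'}
```

Note that substitution happens on raw string values, so a numeric default such as `4096` comes back as the string `'4096'`.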
---
 distributions/remote-vllm/compose.yaml    |  7 +++
 distributions/remote-vllm/run.yaml        | 20 ++++----
 llama_stack/distribution/server/server.py | 57 ++++++++++++++++++++++-
 3 files changed, 73 insertions(+), 11 deletions(-)

diff --git a/distributions/remote-vllm/compose.yaml b/distributions/remote-vllm/compose.yaml
index 27d7de4e2..90d58a2af 100644
--- a/distributions/remote-vllm/compose.yaml
+++ b/distributions/remote-vllm/compose.yaml
@@ -71,6 +71,13 @@ services:
       - ~/.llama:/root/.llama
       - ~/local/llama-stack/distributions/remote-vllm/run.yaml:/root/llamastack-run-remote-vllm.yaml
     # network_mode: "host"
+    environment:
+      - LLAMA_INFERENCE_VLLM_URL=${LLAMA_INFERENCE_VLLM_URL:-http://host.docker.internal:5100/v1}
+      - LLAMA_INFERENCE_MODEL=${LLAMA_INFERENCE_MODEL:-Llama3.1-8B-Instruct}
+      - MAX_TOKENS=${MAX_TOKENS:-4096}
+      - SQLITE_STORE_DIR=${SQLITE_STORE_DIR:-$HOME/.llama/distributions/remote-vllm}
+      - LLAMA_SAFETY_VLLM_URL=${LLAMA_SAFETY_VLLM_URL:-http://host.docker.internal:5101/v1}
+      - LLAMA_SAFETY_MODEL=${LLAMA_SAFETY_MODEL:-Llama-Guard-3-1B}
     ports:
       - "5001:5001"
     # Hack: wait for vLLM server to start before starting docker
diff --git a/distributions/remote-vllm/run.yaml b/distributions/remote-vllm/run.yaml
index af02b1ba5..eae5b8a6f 100644
--- a/distributions/remote-vllm/run.yaml
+++ b/distributions/remote-vllm/run.yaml
@@ -16,16 +16,16 @@ providers:
     provider_type: remote::vllm
     config:
       # NOTE: replace with "localhost" if you are running in "host" network mode
-      url: http://host.docker.internal:5100/v1
-      max_tokens: 4096
+      url: ${env.LLAMA_INFERENCE_VLLM_URL:http://host.docker.internal:5100/v1}
+      max_tokens: ${env.MAX_TOKENS:4096}
       api_token: fake
   # serves safety llama_guard model
   - provider_id: vllm-1
     provider_type: remote::vllm
     config:
      # NOTE: replace with "localhost" if you are running in "host" network mode
-      url: http://host.docker.internal:5101/v1
-      max_tokens: 4096
+      url: ${env.LLAMA_SAFETY_VLLM_URL:http://host.docker.internal:5101/v1}
+      max_tokens: ${env.MAX_TOKENS:4096}
       api_token: fake
   memory:
   - provider_id: faiss-0
@@ -34,7 +34,7 @@ providers:
       kvstore:
         namespace: null
         type: sqlite
-        db_path: /home/ashwin/.llama/distributions/remote-vllm/faiss_store.db
+        db_path: "${env.SQLITE_STORE_DIR:/home/ashwin/.llama/distributions/remote-vllm}/faiss_store.db"
   safety:
   - provider_id: llama-guard
     provider_type: inline::llama-guard
@@ -50,7 +50,7 @@ providers:
       persistence_store:
         namespace: null
         type: sqlite
-        db_path: /home/ashwin/.llama/distributions/remote-vllm/agents_store.db
+        db_path: "${env.SQLITE_STORE_DIR:/home/ashwin/.llama/distributions/remote-vllm}/agents_store.db"
   telemetry:
   - provider_id: meta0
     provider_type: inline::meta-reference
@@ -58,11 +58,11 @@ providers:
 metadata_store:
   namespace: null
   type: sqlite
-  db_path: /home/ashwin/.llama/distributions/remote-vllm/registry.db
+  db_path: "${env.SQLITE_STORE_DIR:/home/ashwin/.llama/distributions/remote-vllm}/registry.db"
 models:
-- model_id: Llama3.1-8B-Instruct
+- model_id: ${env.LLAMA_INFERENCE_MODEL:Llama3.1-8B-Instruct}
   provider_id: vllm-0
-- model_id: Llama-Guard-3-1B
+- model_id: ${env.LLAMA_SAFETY_MODEL:Llama-Guard-3-1B}
   provider_id: vllm-1
 shields:
-- shield_id: Llama-Guard-3-1B
+- shield_id: ${env.LLAMA_SAFETY_MODEL:Llama-Guard-3-1B}
diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py
index 05927eef5..518f9dd7c 100644
--- a/llama_stack/distribution/server/server.py
+++ b/llama_stack/distribution/server/server.py
@@ -8,6 +8,8 @@ import asyncio
 import functools
 import inspect
 import json
+import os
+import re
 import signal
 import sys
 import traceback
@@ -258,13 +260,66 @@ def create_dynamic_typed_route(func: Any, method: str):
     return endpoint
 
 
+class EnvVarError(Exception):
+    def __init__(self, var_name: str, path: str = ""):
+        self.var_name = var_name
+        self.path = path
+        super().__init__(
+            f"Environment variable '{var_name}' not set or empty{f' at {path}' if path else ''}"
+        )
+
+
+def replace_env_vars(config: Any, path: str = "") -> Any:
+    if isinstance(config, dict):
+        result = {}
+        for k, v in config.items():
+            try:
+                result[k] = replace_env_vars(v, f"{path}.{k}" if path else k)
+            except EnvVarError as e:
+                raise EnvVarError(e.var_name, e.path) from None
+        return result
+
+    elif isinstance(config, list):
+        result = []
+        for i, v in enumerate(config):
+            try:
+                result.append(replace_env_vars(v, f"{path}[{i}]"))
+            except EnvVarError as e:
+                raise EnvVarError(e.var_name, e.path) from None
+        return result
+
+    elif isinstance(config, str):
+        pattern = r"\${env\.([A-Z0-9_]+)(?::([^}]*))?}"
+
+        def get_env_var(match):
+            env_var = match.group(1)
+            default_val = match.group(2)
+
+            value = os.environ.get(env_var)
+            if not value:
+                if default_val is None:
+                    raise EnvVarError(env_var, path)
+                else:
+                    value = default_val
+
+            return value
+
+        try:
+            return re.sub(pattern, get_env_var, config)
+        except EnvVarError as e:
+            raise EnvVarError(e.var_name, e.path) from None
+
+    return config
+
+
 def main(
     yaml_config: str = "llamastack-run.yaml",
     port: int = 5000,
     disable_ipv6: bool = False,
 ):
     with open(yaml_config, "r") as fp:
-        config = StackRunConfig(**yaml.safe_load(fp))
+        config = replace_env_vars(yaml.safe_load(fp))
+        config = StackRunConfig(**config)
 
     app = FastAPI()
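As a companion to the sketch above, here is an equally illustrative look at the failure mode: a placeholder without a default now causes `main()` to fail fast with an `EnvVarError` that names both the missing variable and the config path where it was referenced. The variable and key names below are made up for the example.

```python
import os

# Assumes the helpers added to server.py in the diff above.
from llama_stack.distribution.server.server import EnvVarError, replace_env_vars

# Hypothetical config fragment with a required (defaultless) placeholder.
config = {"providers": {"inference": [{"url": "${env.REQUIRED_VLLM_URL}"}]}}

os.environ.pop("REQUIRED_VLLM_URL", None)  # ensure the variable is unset
try:
    replace_env_vars(config)
except EnvVarError as e:
    print(e)
    # Environment variable 'REQUIRED_VLLM_URL' not set or empty at providers.inference[0].url
```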