diff --git a/distributions/remote-vllm/compose.yaml b/distributions/remote-vllm/compose.yaml index 27d7de4e2..90d58a2af 100644 --- a/distributions/remote-vllm/compose.yaml +++ b/distributions/remote-vllm/compose.yaml @@ -71,6 +71,13 @@ services: - ~/.llama:/root/.llama - ~/local/llama-stack/distributions/remote-vllm/run.yaml:/root/llamastack-run-remote-vllm.yaml # network_mode: "host" + environment: + - LLAMA_INFERENCE_VLLM_URL=${LLAMA_INFERENCE_VLLM_URL:-http://host.docker.internal:5100/v1} + - LLAMA_INFERENCE_MODEL=${LLAMA_INFERENCE_MODEL:-Llama3.1-8B-Instruct} + - MAX_TOKENS=${MAX_TOKENS:-4096} + - SQLITE_STORE_DIR=${SQLITE_STORE_DIR:-$HOME/.llama/distributions/remote-vllm} + - LLAMA_SAFETY_VLLM_URL=${LLAMA_SAFETY_VLLM_URL:-http://host.docker.internal:5101/v1} + - LLAMA_SAFETY_MODEL=${LLAMA_SAFETY_MODEL:-Llama-Guard-3-1B} ports: - "5001:5001" # Hack: wait for vLLM server to start before starting docker diff --git a/distributions/remote-vllm/run.yaml b/distributions/remote-vllm/run.yaml index af02b1ba5..eae5b8a6f 100644 --- a/distributions/remote-vllm/run.yaml +++ b/distributions/remote-vllm/run.yaml @@ -16,16 +16,16 @@ providers: provider_type: remote::vllm config: # NOTE: replace with "localhost" if you are running in "host" network mode - url: http://host.docker.internal:5100/v1 - max_tokens: 4096 + url: ${env.LLAMA_INFERENCE_VLLM_URL:http://host.docker.internal:5100/v1} + max_tokens: ${env.MAX_TOKENS:4096} api_token: fake # serves safety llama_guard model - provider_id: vllm-1 provider_type: remote::vllm config: # NOTE: replace with "localhost" if you are running in "host" network mode - url: http://host.docker.internal:5101/v1 - max_tokens: 4096 + url: ${env.LLAMA_SAFETY_VLLM_URL:http://host.docker.internal:5101/v1} + max_tokens: ${env.MAX_TOKENS:4096} api_token: fake memory: - provider_id: faiss-0 @@ -34,7 +34,7 @@ providers: kvstore: namespace: null type: sqlite - db_path: /home/ashwin/.llama/distributions/remote-vllm/faiss_store.db + db_path: 
class EnvVarError(Exception):
    """Raised when a ``${env.VAR}`` reference has no value and no default.

    Attributes:
        var_name: Name of the missing/empty environment variable.
        path: Location of the reference inside the config tree
            (e.g. ``providers.inference[0].url``), used for error reporting.
    """

    def __init__(self, var_name: str, path: str = ""):
        self.var_name = var_name
        self.path = path
        super().__init__(
            f"Environment variable '{var_name}' not set or empty"
            f"{f' at {path}' if path else ''}"
        )


def replace_env_vars(config: Any, path: str = "") -> Any:
    """Recursively expand ``${env.VAR[:default]}`` references in *config*.

    Walks the dicts, lists, and strings of a parsed YAML run config and
    substitutes each ``${env.VAR}`` occurrence with the value of the
    environment variable ``VAR``. The ``${env.VAR:default}`` form falls
    back to ``default`` when ``VAR`` is unset or empty; without a default,
    a missing/empty variable raises :class:`EnvVarError` that carries the
    config path where the bad reference was found.

    Args:
        config: Parsed YAML value (dict, list, str, or other scalar).
        path: Dotted/indexed location of *config* within the full tree;
            used only to build error messages.

    Returns:
        A new structure of the same shape with all references expanded.
        Non-container, non-string values are returned unchanged.

    Raises:
        EnvVarError: A referenced variable has no value and no default.

    Note: the redundant ``try/except EnvVarError: raise EnvVarError(...)``
    wrappers around each recursive call were removed — the exception raised
    at the point of failure already carries the correct ``var_name`` and
    ``path``, and re-creating it only discarded the original traceback.
    """
    if isinstance(config, dict):
        # Extend the path with the key name ("a.b.c" style).
        return {
            key: replace_env_vars(value, f"{path}.{key}" if path else key)
            for key, value in config.items()
        }

    if isinstance(config, list):
        # Index notation mirrors how the element would be addressed in YAML.
        return [
            replace_env_vars(item, f"{path}[{index}]")
            for index, item in enumerate(config)
        ]

    if isinstance(config, str):
        # Matches ${env.VAR} and ${env.VAR:default}. VAR is conventionally
        # upper-case; the default may be any text not containing '}'.
        pattern = r"\${env\.([A-Z0-9_]+)(?::([^}]*))?}"

        def expand(match: re.Match) -> str:
            var_name = match.group(1)
            default = match.group(2)  # None when no ':default' was given
            value = os.environ.get(var_name)
            # An empty string is treated the same as unset, matching the
            # "not set or empty" wording of the error message.
            if not value:
                if default is None:
                    raise EnvVarError(var_name, path)
                value = default
            return value

        return re.sub(pattern, expand, config)

    # Scalars (int, bool, None, ...) pass through untouched.
    return config