From 6609d4ada49f68a67dec92c04603608b91d216de Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Sat, 1 Mar 2025 11:19:14 -0800 Subject: [PATCH] feat: allow conditionally enabling providers in run.yaml (#1321) # What does this PR do? We want to bundle a bunch of (typically remote) providers in a distro template and be able to configure them "on the fly" via environment variables. So far, we have been able to do this with simple env var replacements. However, sometimes you want to only conditionally enable providers (because the relevant remote services may not be alive, or relevant). This was not possible until now. To aid this, we add a simple (bash-like) env var replacement enhancement: `${env.FOO+bar}` evaluates to `bar` if the variable is set to a non-empty value, and evaluates to the empty string otherwise (an unset variable and a variable set to the empty string are treated the same). On top of that, we update our main resolver to skip any provider whose ID is empty (or is the literal `__disabled__`). This allows using the distro like this: ```bash llama stack run dev --env CHROMADB_URL=http://localhost:6001 --env ENABLE_CHROMADB=1 ``` when only Chroma is up. This disables the other `pgvector` provider in the run configuration. 
## Test Plan Hard code `chromadb` as the vector io provider inside `test_vector_io.py` and run: ```bash LLAMA_STACK_BASE_URL=http://localhost:8321 pytest -s -v tests/client-sdk/vector_io/ --embedding-model all-MiniLM-L6-v2 ``` --- llama_stack/distribution/resolver.py | 4 ++ llama_stack/distribution/stack.py | 30 +++++++-- .../inline/vector_io/faiss/config.py | 2 +- .../remote/vector_io/chroma/config.py | 4 +- .../remote/vector_io/pgvector/config.py | 14 ++++ llama_stack/templates/dev/dev.py | 30 +++++++-- llama_stack/templates/dev/run.yaml | 12 ++++ llama_stack/templates/sambanova/run.yaml | 14 ++-- llama_stack/templates/sambanova/sambanova.py | 28 ++++++++ tests/test_replace_env_vars.py | 66 +++++++++++++++++++ 10 files changed, 184 insertions(+), 20 deletions(-) create mode 100644 tests/test_replace_env_vars.py diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index 69a096e97..3abcc3772 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -127,6 +127,10 @@ async def resolve_impls( specs = {} for provider in providers: + if not provider.provider_id or provider.provider_id == "__disabled__": + log.warning(f"Provider `{provider.provider_type}` for API `{api}` is disabled") + continue + if provider.provider_type not in provider_registry[api]: raise ValueError(f"Provider `{provider.provider_type}` is not available for API `{api}`") diff --git a/llama_stack/distribution/stack.py b/llama_stack/distribution/stack.py index 1328c88ef..8f895e170 100644 --- a/llama_stack/distribution/stack.py +++ b/llama_stack/distribution/stack.py @@ -155,18 +155,34 @@ def replace_env_vars(config: Any, path: str = "") -> Any: return result elif isinstance(config, str): - pattern = r"\${env\.([A-Z0-9_]+)(?::([^}]*))?}" + # Updated pattern to support both default values (:) and conditional values (+) + pattern = r"\${env\.([A-Z0-9_]+)(?:([:\+])([^}]*))?}" def get_env_var(match): env_var = match.group(1) - 
default_val = match.group(2) + operator = match.group(2) # ':' for default, '+' for conditional + value_expr = match.group(3) - value = os.environ.get(env_var) - if not value: - if default_val is None: - raise EnvVarError(env_var, path) + env_value = os.environ.get(env_var) + + if operator == ":": # Default value syntax: ${env.FOO:default} + if not env_value: + if value_expr is None: + raise EnvVarError(env_var, path) + else: + value = value_expr else: - value = default_val + value = env_value + elif operator == "+": # Conditional value syntax: ${env.FOO+value_if_set} + if env_value: + value = value_expr + else: + # If env var is not set, return empty string for the conditional case + value = "" + else: # No operator case: ${env.FOO} + if not env_value: + raise EnvVarError(env_var, path) + value = env_value # expand "~" from the values return os.path.expanduser(value) diff --git a/llama_stack/providers/inline/vector_io/faiss/config.py b/llama_stack/providers/inline/vector_io/faiss/config.py index 9eae9ed67..fa6e5bede 100644 --- a/llama_stack/providers/inline/vector_io/faiss/config.py +++ b/llama_stack/providers/inline/vector_io/faiss/config.py @@ -20,7 +20,7 @@ class FaissVectorIOConfig(BaseModel): kvstore: KVStoreConfig @classmethod - def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]: + def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]: return { "kvstore": SqliteKVStoreConfig.sample_run_config( __distro_dir__=__distro_dir__, diff --git a/llama_stack/providers/remote/vector_io/chroma/config.py b/llama_stack/providers/remote/vector_io/chroma/config.py index cbbfa9de3..3e2463252 100644 --- a/llama_stack/providers/remote/vector_io/chroma/config.py +++ b/llama_stack/providers/remote/vector_io/chroma/config.py @@ -13,5 +13,5 @@ class ChromaVectorIOConfig(BaseModel): url: str @classmethod - def sample_config(cls) -> Dict[str, Any]: - return {"url": "{env.CHROMADB_URL}"} + def sample_run_config(cls, url: str = 
"${env.CHROMADB_URL}", **kwargs: Any) -> Dict[str, Any]: + return {"url": url} diff --git a/llama_stack/providers/remote/vector_io/pgvector/config.py b/llama_stack/providers/remote/vector_io/pgvector/config.py index 7811de1ca..e9eb0f12d 100644 --- a/llama_stack/providers/remote/vector_io/pgvector/config.py +++ b/llama_stack/providers/remote/vector_io/pgvector/config.py @@ -4,6 +4,8 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from typing import Any, Dict + from pydantic import BaseModel, Field from llama_stack.schema_utils import json_schema_type @@ -16,3 +18,15 @@ class PGVectorVectorIOConfig(BaseModel): db: str = Field(default="postgres") user: str = Field(default="postgres") password: str = Field(default="mysecretpassword") + + @classmethod + def sample_run_config( + cls, + host: str = "${env.PGVECTOR_HOST:localhost}", + port: int = "${env.PGVECTOR_PORT:5432}", + db: str = "${env.PGVECTOR_DB}", + user: str = "${env.PGVECTOR_USER}", + password: str = "${env.PGVECTOR_PASSWORD}", + **kwargs: Any, + ) -> Dict[str, Any]: + return {"host": host, "port": port, "db": db, "user": user, "password": password} diff --git a/llama_stack/templates/dev/dev.py b/llama_stack/templates/dev/dev.py index 694913119..e8aa31a7e 100644 --- a/llama_stack/templates/dev/dev.py +++ b/llama_stack/templates/dev/dev.py @@ -27,6 +27,8 @@ from llama_stack.providers.remote.inference.groq.config import GroqConfig from llama_stack.providers.remote.inference.groq.models import MODEL_ENTRIES as GROQ_MODEL_ENTRIES from llama_stack.providers.remote.inference.openai.config import OpenAIConfig from llama_stack.providers.remote.inference.openai.models import MODEL_ENTRIES as OPENAI_MODEL_ENTRIES +from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig +from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig from llama_stack.templates.template import 
DistributionTemplate, RunConfigSettings, get_model_registry @@ -94,11 +96,27 @@ def get_distribution_template() -> DistributionTemplate: } name = "dev" - vector_io_provider = Provider( - provider_id="sqlite-vec", - provider_type="inline::sqlite-vec", - config=SQLiteVectorIOConfig.sample_run_config(f"distributions/{name}"), - ) + vector_io_providers = [ + Provider( + provider_id="sqlite-vec", + provider_type="inline::sqlite-vec", + config=SQLiteVectorIOConfig.sample_run_config(f"distributions/{name}"), + ), + Provider( + provider_id="${env.ENABLE_CHROMADB+chromadb}", + provider_type="remote::chromadb", + config=ChromaVectorIOConfig.sample_run_config(url="${env.CHROMADB_URL:}"), + ), + Provider( + provider_id="${env.ENABLE_PGVECTOR+pgvector}", + provider_type="remote::pgvector", + config=PGVectorVectorIOConfig.sample_run_config( + db="${env.PGVECTOR_DB:}", + user="${env.PGVECTOR_USER:}", + password="${env.PGVECTOR_PASSWORD:}", + ), + ), + ] embedding_provider = Provider( provider_id="sentence-transformers", provider_type="inline::sentence-transformers", @@ -141,7 +159,7 @@ def get_distribution_template() -> DistributionTemplate: "run.yaml": RunConfigSettings( provider_overrides={ "inference": inference_providers + [embedding_provider], - "vector_io": [vector_io_provider], + "vector_io": vector_io_providers, }, default_models=default_models + [embedding_model], default_tool_groups=default_tool_groups, diff --git a/llama_stack/templates/dev/run.yaml b/llama_stack/templates/dev/run.yaml index f1d72d572..71fbcb353 100644 --- a/llama_stack/templates/dev/run.yaml +++ b/llama_stack/templates/dev/run.yaml @@ -42,6 +42,18 @@ providers: provider_type: inline::sqlite-vec config: db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/sqlite_vec.db + - provider_id: ${env.ENABLE_CHROMADB+chromadb} + provider_type: remote::chromadb + config: + url: ${env.CHROMADB_URL:} + - provider_id: ${env.ENABLE_PGVECTOR+pgvector} + provider_type: remote::pgvector + config: + host: 
${env.PGVECTOR_HOST:localhost} + port: ${env.PGVECTOR_PORT:5432} + db: ${env.PGVECTOR_DB:} + user: ${env.PGVECTOR_USER:} + password: ${env.PGVECTOR_PASSWORD:} safety: - provider_id: llama-guard provider_type: inline::llama-guard diff --git a/llama_stack/templates/sambanova/run.yaml b/llama_stack/templates/sambanova/run.yaml index 124d11baf..cfa0cc194 100644 --- a/llama_stack/templates/sambanova/run.yaml +++ b/llama_stack/templates/sambanova/run.yaml @@ -22,12 +22,18 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/faiss_store.db - - provider_id: chromadb + - provider_id: ${env.ENABLE_CHROMADB+chromadb} provider_type: remote::chromadb - config: {} - - provider_id: pgvector + config: + url: ${env.CHROMADB_URL:} + - provider_id: ${env.ENABLE_PGVECTOR+pgvector} provider_type: remote::pgvector - config: {} + config: + host: ${env.PGVECTOR_HOST:localhost} + port: ${env.PGVECTOR_PORT:5432} + db: ${env.PGVECTOR_DB:} + user: ${env.PGVECTOR_USER:} + password: ${env.PGVECTOR_PASSWORD:} safety: - provider_id: llama-guard provider_type: inline::llama-guard diff --git a/llama_stack/templates/sambanova/sambanova.py b/llama_stack/templates/sambanova/sambanova.py index 0a0b6bd7e..08c3a54cc 100644 --- a/llama_stack/templates/sambanova/sambanova.py +++ b/llama_stack/templates/sambanova/sambanova.py @@ -11,8 +11,11 @@ from llama_stack.distribution.datatypes import ( ShieldInput, ToolGroupInput, ) +from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig from llama_stack.providers.remote.inference.sambanova import SambaNovaImplConfig from llama_stack.providers.remote.inference.sambanova.models import MODEL_ENTRIES +from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig +from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig from llama_stack.templates.template import DistributionTemplate, RunConfigSettings, get_model_registry @@ -38,6 
+41,30 @@ def get_distribution_template() -> DistributionTemplate: config=SambaNovaImplConfig.sample_run_config(), ) + vector_io_providers = [ + Provider( + provider_id="faiss", + provider_type="inline::faiss", + config=FaissVectorIOConfig.sample_run_config( + __distro_dir__=f"distributions/{name}", + ), + ), + Provider( + provider_id="${env.ENABLE_CHROMADB+chromadb}", + provider_type="remote::chromadb", + config=ChromaVectorIOConfig.sample_run_config(url="${env.CHROMADB_URL:}"), + ), + Provider( + provider_id="${env.ENABLE_PGVECTOR+pgvector}", + provider_type="remote::pgvector", + config=PGVectorVectorIOConfig.sample_run_config( + db="${env.PGVECTOR_DB:}", + user="${env.PGVECTOR_USER:}", + password="${env.PGVECTOR_PASSWORD:}", + ), + ), + ] + available_models = { name: MODEL_ENTRIES, } @@ -69,6 +96,7 @@ def get_distribution_template() -> DistributionTemplate: "run.yaml": RunConfigSettings( provider_overrides={ "inference": [inference_provider], + "vector_io": vector_io_providers, }, default_models=default_models, default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")], diff --git a/tests/test_replace_env_vars.py b/tests/test_replace_env_vars.py new file mode 100644 index 000000000..7fcbbfde9 --- /dev/null +++ b/tests/test_replace_env_vars.py @@ -0,0 +1,66 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +import os +import unittest + +from llama_stack.distribution.stack import replace_env_vars + + +class TestReplaceEnvVars(unittest.TestCase): + def setUp(self): + # Clear any existing environment variables we'll use in tests + for var in ["TEST_VAR", "EMPTY_VAR", "ZERO_VAR"]: + if var in os.environ: + del os.environ[var] + + # Set up test environment variables + os.environ["TEST_VAR"] = "test_value" + os.environ["EMPTY_VAR"] = "" + os.environ["ZERO_VAR"] = "0" + + def test_simple_replacement(self): + self.assertEqual(replace_env_vars("${env.TEST_VAR}"), "test_value") + + def test_default_value_when_not_set(self): + self.assertEqual(replace_env_vars("${env.NOT_SET:default}"), "default") + + def test_default_value_when_set(self): + self.assertEqual(replace_env_vars("${env.TEST_VAR:default}"), "test_value") + + def test_default_value_when_empty(self): + self.assertEqual(replace_env_vars("${env.EMPTY_VAR:default}"), "default") + + def test_conditional_value_when_set(self): + self.assertEqual(replace_env_vars("${env.TEST_VAR+conditional}"), "conditional") + + def test_conditional_value_when_not_set(self): + self.assertEqual(replace_env_vars("${env.NOT_SET+conditional}"), "") + + def test_conditional_value_when_empty(self): + self.assertEqual(replace_env_vars("${env.EMPTY_VAR+conditional}"), "") + + def test_conditional_value_with_zero(self): + self.assertEqual(replace_env_vars("${env.ZERO_VAR+conditional}"), "conditional") + + def test_mixed_syntax(self): + self.assertEqual(replace_env_vars("${env.TEST_VAR:default} and ${env.NOT_SET+conditional}"), "test_value and ") + self.assertEqual( + replace_env_vars("${env.NOT_SET:default} and ${env.TEST_VAR+conditional}"), "default and conditional" + ) + + def test_nested_structures(self): + data = { + "key1": "${env.TEST_VAR:default}", + "key2": ["${env.NOT_SET:default}", "${env.TEST_VAR+conditional}"], + "key3": {"nested": "${env.NOT_SET+conditional}"}, + } + expected = {"key1": "test_value", "key2": ["default", 
"conditional"], "key3": {"nested": ""}} + self.assertEqual(replace_env_vars(data), expected) + + +if __name__ == "__main__": + unittest.main()