diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index 69a096e97..3abcc3772 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -127,6 +127,10 @@ async def resolve_impls( specs = {} for provider in providers: + if not provider.provider_id or provider.provider_id == "__disabled__": + log.warning(f"Provider `{provider.provider_type}` for API `{api}` is disabled") + continue + if provider.provider_type not in provider_registry[api]: raise ValueError(f"Provider `{provider.provider_type}` is not available for API `{api}`") diff --git a/llama_stack/distribution/stack.py b/llama_stack/distribution/stack.py index 1328c88ef..8f895e170 100644 --- a/llama_stack/distribution/stack.py +++ b/llama_stack/distribution/stack.py @@ -155,18 +155,34 @@ def replace_env_vars(config: Any, path: str = "") -> Any: return result elif isinstance(config, str): - pattern = r"\${env\.([A-Z0-9_]+)(?::([^}]*))?}" + # Updated pattern to support both default values (:) and conditional values (+) + pattern = r"\${env\.([A-Z0-9_]+)(?:([:\+])([^}]*))?}" def get_env_var(match): env_var = match.group(1) - default_val = match.group(2) + operator = match.group(2) # ':' for default, '+' for conditional + value_expr = match.group(3) - value = os.environ.get(env_var) - if not value: - if default_val is None: - raise EnvVarError(env_var, path) + env_value = os.environ.get(env_var) + + if operator == ":": # Default value syntax: ${env.FOO:default} + if not env_value: + if value_expr is None: + raise EnvVarError(env_var, path) + else: + value = value_expr else: - value = default_val + value = env_value + elif operator == "+": # Conditional value syntax: ${env.FOO+value_if_set} + if env_value: + value = value_expr + else: + # If env var is not set, return empty string for the conditional case + value = "" + else: # No operator case: ${env.FOO} + if not env_value: + raise EnvVarError(env_var, path) + value = env_value # expand "~" from the values return os.path.expanduser(value) diff --git a/llama_stack/providers/inline/vector_io/faiss/config.py b/llama_stack/providers/inline/vector_io/faiss/config.py index 9eae9ed67..fa6e5bede 100644 --- a/llama_stack/providers/inline/vector_io/faiss/config.py +++ b/llama_stack/providers/inline/vector_io/faiss/config.py @@ -20,7 +20,7 @@ class FaissVectorIOConfig(BaseModel): kvstore: KVStoreConfig @classmethod - def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]: + def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]: return { "kvstore": SqliteKVStoreConfig.sample_run_config( __distro_dir__=__distro_dir__, diff --git a/llama_stack/providers/remote/vector_io/chroma/config.py b/llama_stack/providers/remote/vector_io/chroma/config.py index cbbfa9de3..3e2463252 100644 --- a/llama_stack/providers/remote/vector_io/chroma/config.py +++ b/llama_stack/providers/remote/vector_io/chroma/config.py @@ -13,5 +13,5 @@ class ChromaVectorIOConfig(BaseModel): url: str @classmethod - def sample_config(cls) -> Dict[str, Any]: - return {"url": "{env.CHROMADB_URL}"} + def sample_run_config(cls, url: str = "${env.CHROMADB_URL}", **kwargs: Any) -> Dict[str, Any]: + return {"url": url} diff --git a/llama_stack/providers/remote/vector_io/pgvector/config.py b/llama_stack/providers/remote/vector_io/pgvector/config.py index 7811de1ca..e9eb0f12d 100644 --- a/llama_stack/providers/remote/vector_io/pgvector/config.py +++ b/llama_stack/providers/remote/vector_io/pgvector/config.py @@ -4,6 +4,8 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from typing import Any, Dict + from pydantic import BaseModel, Field from llama_stack.schema_utils import json_schema_type @@ -16,3 +18,15 @@ class PGVectorVectorIOConfig(BaseModel): db: str = Field(default="postgres") user: str = Field(default="postgres") password: str = Field(default="mysecretpassword") + + @classmethod + def sample_run_config( + cls, + host: str = "${env.PGVECTOR_HOST:localhost}", + port: int = "${env.PGVECTOR_PORT:5432}", + db: str = "${env.PGVECTOR_DB}", + user: str = "${env.PGVECTOR_USER}", + password: str = "${env.PGVECTOR_PASSWORD}", + **kwargs: Any, + ) -> Dict[str, Any]: + return {"host": host, "port": port, "db": db, "user": user, "password": password} diff --git a/llama_stack/templates/dev/dev.py b/llama_stack/templates/dev/dev.py index 694913119..e8aa31a7e 100644 --- a/llama_stack/templates/dev/dev.py +++ b/llama_stack/templates/dev/dev.py @@ -27,6 +27,8 @@ from llama_stack.providers.remote.inference.groq.config import GroqConfig from llama_stack.providers.remote.inference.groq.models import MODEL_ENTRIES as GROQ_MODEL_ENTRIES from llama_stack.providers.remote.inference.openai.config import OpenAIConfig from llama_stack.providers.remote.inference.openai.models import MODEL_ENTRIES as OPENAI_MODEL_ENTRIES +from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig +from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig from llama_stack.templates.template import DistributionTemplate, RunConfigSettings, get_model_registry @@ -94,11 +96,27 @@ def get_distribution_template() -> DistributionTemplate: } name = "dev" - vector_io_provider = Provider( - provider_id="sqlite-vec", - provider_type="inline::sqlite-vec", - config=SQLiteVectorIOConfig.sample_run_config(f"distributions/{name}"), - ) + vector_io_providers = [ + Provider( + provider_id="sqlite-vec", + provider_type="inline::sqlite-vec", + config=SQLiteVectorIOConfig.sample_run_config(f"distributions/{name}"), + ), + Provider( + provider_id="${env.ENABLE_CHROMADB+chromadb}", + provider_type="remote::chromadb", + config=ChromaVectorIOConfig.sample_run_config(url="${env.CHROMADB_URL:}"), + ), + Provider( + provider_id="${env.ENABLE_PGVECTOR+pgvector}", + provider_type="remote::pgvector", + config=PGVectorVectorIOConfig.sample_run_config( + db="${env.PGVECTOR_DB:}", + user="${env.PGVECTOR_USER:}", + password="${env.PGVECTOR_PASSWORD:}", + ), + ), + ] embedding_provider = Provider( provider_id="sentence-transformers", provider_type="inline::sentence-transformers", @@ -141,7 +159,7 @@ def get_distribution_template() -> DistributionTemplate: "run.yaml": RunConfigSettings( provider_overrides={ "inference": inference_providers + [embedding_provider], - "vector_io": [vector_io_provider], + "vector_io": vector_io_providers, }, default_models=default_models + [embedding_model], default_tool_groups=default_tool_groups, diff --git a/llama_stack/templates/dev/run.yaml b/llama_stack/templates/dev/run.yaml index f1d72d572..71fbcb353 100644 --- a/llama_stack/templates/dev/run.yaml +++ b/llama_stack/templates/dev/run.yaml @@ -42,6 +42,18 @@ providers: provider_type: inline::sqlite-vec config: db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/sqlite_vec.db + - provider_id: ${env.ENABLE_CHROMADB+chromadb} + provider_type: remote::chromadb + config: + url: ${env.CHROMADB_URL:} + - provider_id: ${env.ENABLE_PGVECTOR+pgvector} + provider_type: remote::pgvector + config: + host: ${env.PGVECTOR_HOST:localhost} + port: ${env.PGVECTOR_PORT:5432} + db: ${env.PGVECTOR_DB:} + user: ${env.PGVECTOR_USER:} + password: ${env.PGVECTOR_PASSWORD:} safety: - provider_id: llama-guard provider_type: inline::llama-guard diff --git a/llama_stack/templates/sambanova/run.yaml b/llama_stack/templates/sambanova/run.yaml index 124d11baf..cfa0cc194 100644 --- a/llama_stack/templates/sambanova/run.yaml +++ b/llama_stack/templates/sambanova/run.yaml @@ -22,12 +22,18 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/faiss_store.db - - provider_id: chromadb + - provider_id: ${env.ENABLE_CHROMADB+chromadb} provider_type: remote::chromadb - config: {} - - provider_id: pgvector + config: + url: ${env.CHROMADB_URL:} + - provider_id: ${env.ENABLE_PGVECTOR+pgvector} provider_type: remote::pgvector - config: {} + config: + host: ${env.PGVECTOR_HOST:localhost} + port: ${env.PGVECTOR_PORT:5432} + db: ${env.PGVECTOR_DB:} + user: ${env.PGVECTOR_USER:} + password: ${env.PGVECTOR_PASSWORD:} safety: - provider_id: llama-guard provider_type: inline::llama-guard diff --git a/llama_stack/templates/sambanova/sambanova.py b/llama_stack/templates/sambanova/sambanova.py index 0a0b6bd7e..08c3a54cc 100644 --- a/llama_stack/templates/sambanova/sambanova.py +++ b/llama_stack/templates/sambanova/sambanova.py @@ -11,8 +11,11 @@ from llama_stack.distribution.datatypes import ( ShieldInput, ToolGroupInput, ) +from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig from llama_stack.providers.remote.inference.sambanova import SambaNovaImplConfig from llama_stack.providers.remote.inference.sambanova.models import MODEL_ENTRIES +from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig +from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig from llama_stack.templates.template import DistributionTemplate, RunConfigSettings, get_model_registry @@ -38,6 +41,30 @@ def get_distribution_template() -> DistributionTemplate: config=SambaNovaImplConfig.sample_run_config(), ) + vector_io_providers = [ + Provider( + provider_id="faiss", + provider_type="inline::faiss", + config=FaissVectorIOConfig.sample_run_config( + __distro_dir__=f"distributions/{name}", + ), + ), + Provider( + provider_id="${env.ENABLE_CHROMADB+chromadb}", + provider_type="remote::chromadb", + config=ChromaVectorIOConfig.sample_run_config(url="${env.CHROMADB_URL:}"), + ), + Provider( + provider_id="${env.ENABLE_PGVECTOR+pgvector}", + provider_type="remote::pgvector", + config=PGVectorVectorIOConfig.sample_run_config( + db="${env.PGVECTOR_DB:}", + user="${env.PGVECTOR_USER:}", + password="${env.PGVECTOR_PASSWORD:}", + ), + ), + ] + available_models = { name: MODEL_ENTRIES, } @@ -69,6 +96,7 @@ def get_distribution_template() -> DistributionTemplate: "run.yaml": RunConfigSettings( provider_overrides={ "inference": [inference_provider], + "vector_io": vector_io_providers, }, default_models=default_models, default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")], diff --git a/tests/test_replace_env_vars.py b/tests/test_replace_env_vars.py new file mode 100644 index 000000000..7fcbbfde9 --- /dev/null +++ b/tests/test_replace_env_vars.py @@ -0,0 +1,66 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import os +import unittest + +from llama_stack.distribution.stack import replace_env_vars + + +class TestReplaceEnvVars(unittest.TestCase): + def setUp(self): + # Clear any existing environment variables we'll use in tests + for var in ["TEST_VAR", "EMPTY_VAR", "ZERO_VAR"]: + if var in os.environ: + del os.environ[var] + + # Set up test environment variables + os.environ["TEST_VAR"] = "test_value" + os.environ["EMPTY_VAR"] = "" + os.environ["ZERO_VAR"] = "0" + + def test_simple_replacement(self): + self.assertEqual(replace_env_vars("${env.TEST_VAR}"), "test_value") + + def test_default_value_when_not_set(self): + self.assertEqual(replace_env_vars("${env.NOT_SET:default}"), "default") + + def test_default_value_when_set(self): + self.assertEqual(replace_env_vars("${env.TEST_VAR:default}"), "test_value") + + def test_default_value_when_empty(self): + self.assertEqual(replace_env_vars("${env.EMPTY_VAR:default}"), "default") + + def test_conditional_value_when_set(self): + self.assertEqual(replace_env_vars("${env.TEST_VAR+conditional}"), "conditional") + + def test_conditional_value_when_not_set(self): + self.assertEqual(replace_env_vars("${env.NOT_SET+conditional}"), "") + + def test_conditional_value_when_empty(self): + self.assertEqual(replace_env_vars("${env.EMPTY_VAR+conditional}"), "") + + def test_conditional_value_with_zero(self): + self.assertEqual(replace_env_vars("${env.ZERO_VAR+conditional}"), "conditional") + + def test_mixed_syntax(self): + self.assertEqual(replace_env_vars("${env.TEST_VAR:default} and ${env.NOT_SET+conditional}"), "test_value and ") + self.assertEqual( + replace_env_vars("${env.NOT_SET:default} and ${env.TEST_VAR+conditional}"), "default and conditional" + ) + + def test_nested_structures(self): + data = { + "key1": "${env.TEST_VAR:default}", + "key2": ["${env.NOT_SET:default}", "${env.TEST_VAR+conditional}"], + "key3": {"nested": "${env.NOT_SET+conditional}"}, + } + expected = {"key1": "test_value", "key2": ["default", "conditional"], "key3": {"nested": ""}} + self.assertEqual(replace_env_vars(data), expected) + + +if __name__ == "__main__": + unittest.main()