From 0883944bc3ec7fe76b4c1f48dc25bd6a7d96776e Mon Sep 17 00:00:00 2001
From: Ben Browning
Date: Thu, 26 Jun 2025 20:59:15 -0400
Subject: [PATCH] fix: Some missed env variable changes from PR 2490 (#2538)

# What does this PR do?

Some templates were still using the old environment variable
substitution syntax instead of the new one and were not getting
substituted properly. Also, some places didn't handle the new `None`
values (instead of the old empty string "") that come from the
conditional environment variable substitution.

This gets the starter and remote-vllm distributions starting again. I
tested various permutations of the starter, as chroma and pgvector
needed some adjustments to their config classes to handle the new
possible `None` values. I also had to tweak our `Provider` class to
handle `None` values, for cases where we disable providers in the
starter config via environment variables.

This may not have caught everything that was missed, but I did grep
around quite a bit to try and find anything lingering.

## Test Plan

The following permutations now all run (or attempt to run, up to the
point of complaining that they can't connect to chroma, vllm, etc.)
where before they failed immediately on startup because of bad
environment variable substitutions:

```
uv run llama stack run llama_stack/templates/starter/run.yaml
ENABLE_SQLITE_VEC=true uv run llama stack run llama_stack/templates/starter/run.yaml
ENABLE_PGVECTOR=true uv run llama stack run llama_stack/templates/starter/run.yaml
ENABLE_CHROMADB=true uv run llama stack run llama_stack/templates/starter/run.yaml
uv run llama stack run llama_stack/templates/remote-vllm/run.yaml
```

Signed-off-by: Ben Browning
Co-authored-by: raghotham
---
 llama_stack/distribution/datatypes.py           |  4 +++-
 llama_stack/distribution/providers.py           |  3 +++
 llama_stack/distribution/resolver.py            |  4 ++++
 .../remote/vector_io/chroma/chroma.py           |  3 +++
 .../remote/vector_io/chroma/config.py           |  2 +-
 .../remote/vector_io/pgvector/config.py         | 10 +++++-----
 .../experimental-post-training/run.yaml         | 18 +++++++++---------
 .../meta-reference-gpu/meta_reference.py        |  4 ++--
 .../meta-reference-gpu/run-with-safety.yaml     |  4 ++--
 .../templates/meta-reference-gpu/run.yaml       |  2 +-
 .../templates/open-benchmark/open_benchmark.py  | 10 +++++-----
 llama_stack/templates/open-benchmark/run.yaml   |  8 ++++----
 .../templates/postgres-demo/postgres_demo.py    |  2 +-
 llama_stack/templates/postgres-demo/run.yaml    |  2 +-
 .../templates/remote-vllm/run-with-safety.yaml  |  2 +-
 llama_stack/templates/remote-vllm/run.yaml      |  2 +-
 llama_stack/templates/remote-vllm/vllm.py       |  2 +-
 llama_stack/templates/starter/run.yaml          |  2 +-
 llama_stack/templates/starter/starter.py        |  2 +-
 19 files changed, 49 insertions(+), 37 deletions(-)

diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py
index e07da001e..5e48ac0ad 100644
--- a/llama_stack/distribution/datatypes.py
+++ b/llama_stack/distribution/datatypes.py
@@ -146,7 +146,9 @@ in the runtime configuration to help route to the correct provider.""",
 
 
 class Provider(BaseModel):
-    provider_id: str
+    # provider_id of None means that the provider is not enabled - this happens
+    # when the provider is enabled via a conditional environment variable
+    provider_id: str | None
     provider_type: str
     config: dict[str, Any]
 
diff --git a/llama_stack/distribution/providers.py b/llama_stack/distribution/providers.py
index f238e3bba..1d9c1f4e9 100644
--- a/llama_stack/distribution/providers.py
+++ b/llama_stack/distribution/providers.py
@@ -48,6 +48,9 @@ class ProviderImpl(Providers):
         ret = []
         for api, providers in safe_config.providers.items():
             for p in providers:
+                # Skip providers that are not enabled
+                if p.provider_id is None:
+                    continue
                 ret.append(
                     ProviderInfo(
                         api=api,
diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py
index 3726bb3a5..46cd1161e 100644
--- a/llama_stack/distribution/resolver.py
+++ b/llama_stack/distribution/resolver.py
@@ -255,6 +255,10 @@ async def instantiate_providers(
     impls: dict[Api, Any] = {}
     inner_impls_by_provider_id: dict[str, dict[str, Any]] = {f"inner-{x.value}": {} for x in router_apis}
     for api_str, provider in sorted_providers:
+        # Skip providers that are not enabled
+        if provider.provider_id is None:
+            continue
+
         deps = {a: impls[a] for a in provider.spec.api_dependencies}
         for a in provider.spec.optional_api_dependencies:
             if a in impls:
diff --git a/llama_stack/providers/remote/vector_io/chroma/chroma.py b/llama_stack/providers/remote/vector_io/chroma/chroma.py
index 06d1786f0..3bef39e9c 100644
--- a/llama_stack/providers/remote/vector_io/chroma/chroma.py
+++ b/llama_stack/providers/remote/vector_io/chroma/chroma.py
@@ -137,6 +137,9 @@ class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
 
     async def initialize(self) -> None:
         if isinstance(self.config, RemoteChromaVectorIOConfig):
+            if not self.config.url:
+                raise ValueError("URL is a required parameter for the remote Chroma provider's config")
+
             log.info(f"Connecting to Chroma server at: {self.config.url}")
             url = self.config.url.rstrip("/")
             parsed = urlparse(url)
diff --git a/llama_stack/providers/remote/vector_io/chroma/config.py b/llama_stack/providers/remote/vector_io/chroma/config.py
index 4e893fab4..bd11d5f8c 100644
--- a/llama_stack/providers/remote/vector_io/chroma/config.py
+++ b/llama_stack/providers/remote/vector_io/chroma/config.py
@@ -10,7 +10,7 @@ from pydantic import BaseModel
 
 
 class ChromaVectorIOConfig(BaseModel):
-    url: str
+    url: str | None
 
     @classmethod
     def sample_run_config(cls, url: str = "${env.CHROMADB_URL}", **kwargs: Any) -> dict[str, Any]:
diff --git a/llama_stack/providers/remote/vector_io/pgvector/config.py b/llama_stack/providers/remote/vector_io/pgvector/config.py
index 041e864ca..92908aa8a 100644
--- a/llama_stack/providers/remote/vector_io/pgvector/config.py
+++ b/llama_stack/providers/remote/vector_io/pgvector/config.py
@@ -13,11 +13,11 @@ from llama_stack.schema_utils import json_schema_type
 
 @json_schema_type
 class PGVectorVectorIOConfig(BaseModel):
-    host: str = Field(default="localhost")
-    port: int = Field(default=5432)
-    db: str = Field(default="postgres")
-    user: str = Field(default="postgres")
-    password: str = Field(default="mysecretpassword")
+    host: str | None = Field(default="localhost")
+    port: int | None = Field(default=5432)
+    db: str | None = Field(default="postgres")
+    user: str | None = Field(default="postgres")
+    password: str | None = Field(default="mysecretpassword")
 
     @classmethod
     def sample_run_config(
diff --git a/llama_stack/templates/experimental-post-training/run.yaml b/llama_stack/templates/experimental-post-training/run.yaml
index 393cba41d..a74aa3647 100644
--- a/llama_stack/templates/experimental-post-training/run.yaml
+++ b/llama_stack/templates/experimental-post-training/run.yaml
@@ -24,7 +24,7 @@ providers:
   - provider_id: ollama
     provider_type: remote::ollama
     config:
-      url: ${env.OLLAMA_URL:http://localhost:11434}
+      url: ${env.OLLAMA_URL:=http://localhost:11434}
   eval:
   - provider_id: meta-reference
     provider_type: inline::meta-reference
@@ -32,7 +32,7 @@ providers:
       kvstore:
         type: sqlite
         namespace: null
-        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/meta_reference_eval.db
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/meta-reference-gpu}/meta_reference_eval.db
   scoring:
   - provider_id: basic
     provider_type: inline::basic
@@ -40,7 +40,7 @@ providers:
   - provider_id: braintrust
     provider_type: inline::braintrust
    config:
-      openai_api_key: ${env.OPENAI_API_KEY:}
+      openai_api_key: ${env.OPENAI_API_KEY:+}
   datasetio:
   - provider_id: localfs
     provider_type: inline::localfs
@@ -48,14 +48,14 @@ providers:
       kvstore:
         type: sqlite
         namespace: null
-        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/experimental-post-training}/localfs_datasetio.db
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/experimental-post-training}/localfs_datasetio.db
   - provider_id: huggingface
     provider_type: remote::huggingface
     config:
       kvstore:
         type: sqlite
         namespace: null
-        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/huggingface}/huggingface_datasetio.db
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/huggingface}/huggingface_datasetio.db
   telemetry:
   - provider_id: meta-reference
     provider_type: inline::meta-reference
@@ -74,7 +74,7 @@ providers:
       persistence_store:
         type: sqlite
         namespace: null
-        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/experimental-post-training}/agents_store.db
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/experimental-post-training}/agents_store.db
   safety:
   - provider_id: llama-guard
     provider_type: inline::llama-guard
@@ -86,19 +86,19 @@ providers:
       kvstore:
         type: sqlite
         namespace: null
-        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/experimental-post-training}/faiss_store.db
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/experimental-post-training}/faiss_store.db
   tool_runtime:
   - provider_id: brave-search
     provider_type: remote::brave-search
     config:
-      api_key: ${env.BRAVE_SEARCH_API_KEY:}
+      api_key: ${env.BRAVE_SEARCH_API_KEY:+}
       max_results: 3
 metadata_store:
   namespace: null
   type: sqlite
-  db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/experimental-post-training}/registry.db
+  db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/experimental-post-training}/registry.db
 models: []
 shields: []
 vector_dbs: []
diff --git a/llama_stack/templates/meta-reference-gpu/meta_reference.py b/llama_stack/templates/meta-reference-gpu/meta_reference.py
index 57fb8f2af..4bfb4e9d8 100644
--- a/llama_stack/templates/meta-reference-gpu/meta_reference.py
+++ b/llama_stack/templates/meta-reference-gpu/meta_reference.py
@@ -46,7 +46,7 @@ def get_distribution_template() -> DistributionTemplate:
         provider_type="inline::meta-reference",
         config=MetaReferenceInferenceConfig.sample_run_config(
             model="${env.INFERENCE_MODEL}",
-            checkpoint_dir="${env.INFERENCE_CHECKPOINT_DIR:null}",
+            checkpoint_dir="${env.INFERENCE_CHECKPOINT_DIR:=null}",
         ),
     )
     embedding_provider = Provider(
@@ -112,7 +112,7 @@ def get_distribution_template() -> DistributionTemplate:
                             provider_type="inline::meta-reference",
                             config=MetaReferenceInferenceConfig.sample_run_config(
                                 model="${env.SAFETY_MODEL}",
-                                checkpoint_dir="${env.SAFETY_CHECKPOINT_DIR:null}",
+                                checkpoint_dir="${env.SAFETY_CHECKPOINT_DIR:=null}",
                             ),
                         ),
                     ],
diff --git a/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml b/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml
index 6b15a1e01..f60f4505f 100644
--- a/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml
+++ b/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml
@@ -16,7 +16,7 @@ providers:
     provider_type: inline::meta-reference
     config:
       model: ${env.INFERENCE_MODEL}
-      checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
+      checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:=null}
       quantization:
         type: ${env.QUANTIZATION_TYPE:=bf16}
       model_parallel_size: ${env.MODEL_PARALLEL_SIZE:=0}
@@ -29,7 +29,7 @@ providers:
     provider_type: inline::meta-reference
     config:
       model: ${env.SAFETY_MODEL}
-      checkpoint_dir: ${env.SAFETY_CHECKPOINT_DIR:null}
+      checkpoint_dir: ${env.SAFETY_CHECKPOINT_DIR:=null}
       quantization:
         type: ${env.QUANTIZATION_TYPE:=bf16}
       model_parallel_size: ${env.MODEL_PARALLEL_SIZE:=0}
diff --git a/llama_stack/templates/meta-reference-gpu/run.yaml b/llama_stack/templates/meta-reference-gpu/run.yaml
index 1b44a0b3e..064b958c8 100644
--- a/llama_stack/templates/meta-reference-gpu/run.yaml
+++ b/llama_stack/templates/meta-reference-gpu/run.yaml
@@ -16,7 +16,7 @@ providers:
     provider_type: inline::meta-reference
     config:
       model: ${env.INFERENCE_MODEL}
-      checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
+      checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:=null}
       quantization:
         type: ${env.QUANTIZATION_TYPE:=bf16}
       model_parallel_size: ${env.MODEL_PARALLEL_SIZE:=0}
diff --git a/llama_stack/templates/open-benchmark/open_benchmark.py b/llama_stack/templates/open-benchmark/open_benchmark.py
index b4cfbdb52..8d7a9dc1e 100644
--- a/llama_stack/templates/open-benchmark/open_benchmark.py
+++ b/llama_stack/templates/open-benchmark/open_benchmark.py
@@ -46,7 +46,7 @@ def get_inference_providers() -> tuple[list[Provider], dict[str, list[ProviderMo
                 model_type=ModelType.llm,
             )
         ],
-        OpenAIConfig.sample_run_config(api_key="${env.OPENAI_API_KEY:}"),
+        OpenAIConfig.sample_run_config(api_key="${env.OPENAI_API_KEY:+}"),
     ),
     (
         "anthropic",
@@ -56,7 +56,7 @@ def get_inference_providers() -> tuple[list[Provider], dict[str, list[ProviderMo
                 model_type=ModelType.llm,
             )
         ],
-        AnthropicConfig.sample_run_config(api_key="${env.ANTHROPIC_API_KEY:}"),
+        AnthropicConfig.sample_run_config(api_key="${env.ANTHROPIC_API_KEY:+}"),
     ),
     (
         "gemini",
@@ -66,17 +66,17 @@ def get_inference_providers() -> tuple[list[Provider], dict[str, list[ProviderMo
                 model_type=ModelType.llm,
             )
         ],
-        GeminiConfig.sample_run_config(api_key="${env.GEMINI_API_KEY:}"),
+        GeminiConfig.sample_run_config(api_key="${env.GEMINI_API_KEY:+}"),
     ),
     (
         "groq",
         [],
-        GroqConfig.sample_run_config(api_key="${env.GROQ_API_KEY:}"),
+        GroqConfig.sample_run_config(api_key="${env.GROQ_API_KEY:+}"),
     ),
     (
         "together",
         [],
-        TogetherImplConfig.sample_run_config(api_key="${env.TOGETHER_API_KEY:}"),
+        TogetherImplConfig.sample_run_config(api_key="${env.TOGETHER_API_KEY:+}"),
     ),
 ]
 inference_providers = []
diff --git a/llama_stack/templates/open-benchmark/run.yaml b/llama_stack/templates/open-benchmark/run.yaml
index 403b0fd3d..653d76bd4 100644
--- a/llama_stack/templates/open-benchmark/run.yaml
+++ b/llama_stack/templates/open-benchmark/run.yaml
@@ -15,20 +15,20 @@ providers:
   - provider_id: openai
     provider_type: remote::openai
     config:
-      api_key: ${env.OPENAI_API_KEY:}
+      api_key: ${env.OPENAI_API_KEY:+}
   - provider_id: anthropic
     provider_type: remote::anthropic
     config:
-      api_key: ${env.ANTHROPIC_API_KEY:}
+      api_key: ${env.ANTHROPIC_API_KEY:+}
   - provider_id: gemini
     provider_type: remote::gemini
     config:
-      api_key: ${env.GEMINI_API_KEY:}
+      api_key: ${env.GEMINI_API_KEY:+}
   - provider_id: groq
     provider_type: remote::groq
     config:
       url: https://api.groq.com
-      api_key: ${env.GROQ_API_KEY:}
+      api_key: ${env.GROQ_API_KEY:+}
   - provider_id: together
     provider_type: remote::together
     config:
diff --git a/llama_stack/templates/postgres-demo/postgres_demo.py b/llama_stack/templates/postgres-demo/postgres_demo.py
index 5d42b8901..5b1a302e3 100644
--- a/llama_stack/templates/postgres-demo/postgres_demo.py
+++ b/llama_stack/templates/postgres-demo/postgres_demo.py
@@ -29,7 +29,7 @@ def get_distribution_template() -> DistributionTemplate:
             provider_id="vllm-inference",
             provider_type="remote::vllm",
             config=VLLMInferenceAdapterConfig.sample_run_config(
-                url="${env.VLLM_URL:http://localhost:8000/v1}",
+                url="${env.VLLM_URL:=http://localhost:8000/v1}",
             ),
         ),
     ]
diff --git a/llama_stack/templates/postgres-demo/run.yaml b/llama_stack/templates/postgres-demo/run.yaml
index 03b7a59fb..66253cbdb 100644
--- a/llama_stack/templates/postgres-demo/run.yaml
+++ b/llama_stack/templates/postgres-demo/run.yaml
@@ -12,7 +12,7 @@ providers:
   - provider_id: vllm-inference
     provider_type: remote::vllm
     config:
-      url: ${env.VLLM_URL:http://localhost:8000/v1}
+      url: ${env.VLLM_URL:=http://localhost:8000/v1}
       max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
       api_token: ${env.VLLM_API_TOKEN:=fake}
       tls_verify: ${env.VLLM_TLS_VERIFY:=true}
diff --git a/llama_stack/templates/remote-vllm/run-with-safety.yaml b/llama_stack/templates/remote-vllm/run-with-safety.yaml
index b297f1489..e306a771b 100644
--- a/llama_stack/templates/remote-vllm/run-with-safety.yaml
+++ b/llama_stack/templates/remote-vllm/run-with-safety.yaml
@@ -15,7 +15,7 @@ providers:
   - provider_id: vllm-inference
     provider_type: remote::vllm
     config:
-      url: ${env.VLLM_URL:http://localhost:8000/v1}
+      url: ${env.VLLM_URL:=http://localhost:8000/v1}
       max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
       api_token: ${env.VLLM_API_TOKEN:=fake}
       tls_verify: ${env.VLLM_TLS_VERIFY:=true}
diff --git a/llama_stack/templates/remote-vllm/run.yaml b/llama_stack/templates/remote-vllm/run.yaml
index 6bd332cc9..1dbef96a2 100644
--- a/llama_stack/templates/remote-vllm/run.yaml
+++ b/llama_stack/templates/remote-vllm/run.yaml
@@ -15,7 +15,7 @@ providers:
   - provider_id: vllm-inference
     provider_type: remote::vllm
     config:
-      url: ${env.VLLM_URL:http://localhost:8000/v1}
+      url: ${env.VLLM_URL:=http://localhost:8000/v1}
       max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
       api_token: ${env.VLLM_API_TOKEN:=fake}
       tls_verify: ${env.VLLM_TLS_VERIFY:=true}
diff --git a/llama_stack/templates/remote-vllm/vllm.py b/llama_stack/templates/remote-vllm/vllm.py
index 94606e9d0..a8e1d9a58 100644
--- a/llama_stack/templates/remote-vllm/vllm.py
+++ b/llama_stack/templates/remote-vllm/vllm.py
@@ -44,7 +44,7 @@ def get_distribution_template() -> DistributionTemplate:
         provider_id="vllm-inference",
         provider_type="remote::vllm",
         config=VLLMInferenceAdapterConfig.sample_run_config(
-            url="${env.VLLM_URL:http://localhost:8000/v1}",
+            url="${env.VLLM_URL:=http://localhost:8000/v1}",
         ),
     )
     embedding_provider = Provider(
diff --git a/llama_stack/templates/starter/run.yaml b/llama_stack/templates/starter/run.yaml
index f7c53170b..00faf029e 100644
--- a/llama_stack/templates/starter/run.yaml
+++ b/llama_stack/templates/starter/run.yaml
@@ -68,7 +68,7 @@ providers:
         type: sqlite
         namespace: null
         db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/faiss_store.db
-  - provider_id: ${env.ENABLE_SQLITE_VEC+sqlite-vec}
+  - provider_id: ${env.ENABLE_SQLITE_VEC:+sqlite-vec}
    provider_type: inline::sqlite-vec
     config:
       db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec.db
diff --git a/llama_stack/templates/starter/starter.py b/llama_stack/templates/starter/starter.py
index df31fed84..c0f2646d7 100644
--- a/llama_stack/templates/starter/starter.py
+++ b/llama_stack/templates/starter/starter.py
@@ -175,7 +175,7 @@ def get_distribution_template() -> DistributionTemplate:
             config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
         ),
         Provider(
-            provider_id="${env.ENABLE_SQLITE_VEC+sqlite-vec}",
+            provider_id="${env.ENABLE_SQLITE_VEC:+sqlite-vec}",
             provider_type="inline::sqlite-vec",
             config=SQLiteVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
         ),
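
For context on the two substitution forms this patch migrates to: `${env.VAR:=default}` resolves to the variable's value, falling back to `default` when the variable is unset, while `${env.VAR:+value}` resolves to `value` only when the variable is set and to `None` otherwise. That `None` is what `Provider.provider_id` and the chroma/pgvector configs now have to tolerate. The snippet below is only a rough, self-contained sketch of those semantics; the `substitute` helper and its regex are illustrative stand-ins, not the actual llama_stack substitution code.

```python
# Illustrative sketch only; not the real llama_stack substitution logic.
# It mimics the bash-style semantics the templates rely on:
#   ${env.VAR:=default} -> value of VAR, or "default" when VAR is unset
#   ${env.VAR:+alt}     -> "alt" when VAR is set, otherwise None
import os
import re

_CONDITIONAL = re.compile(r"\$\{env\.(?P<name>\w+):(?P<op>[=+])(?P<arg>[^}]*)\}")


def substitute(template: str) -> str | None:
    match = _CONDITIONAL.fullmatch(template)
    if match is None:
        return template  # not a conditional substitution, leave untouched
    value = os.environ.get(match.group("name"))
    if match.group("op") == "=":
        # ":=" falls back to the default when the variable is unset
        return value if value is not None else match.group("arg")
    # ":+" yields the alternate value only when the variable is set
    return match.group("arg") if value is not None else None


if __name__ == "__main__":
    # With ENABLE_SQLITE_VEC unset this prints None, so the sqlite-vec
    # provider ends up with provider_id=None and is skipped at startup;
    # the OLLAMA_URL line falls back to the localhost default.
    print(substitute("${env.ENABLE_SQLITE_VEC:+sqlite-vec}"))
    print(substitute("${env.OLLAMA_URL:=http://localhost:11434}"))
```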