diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py
index e07da001e..5e48ac0ad 100644
--- a/llama_stack/distribution/datatypes.py
+++ b/llama_stack/distribution/datatypes.py
@@ -146,7 +146,9 @@ in the runtime configuration to help route to the correct provider.""",
 
 
 class Provider(BaseModel):
-    provider_id: str
+    # provider_id of None means that the provider is not enabled - this happens
+    # when the provider is enabled via a conditional environment variable
+    provider_id: str | None
     provider_type: str
     config: dict[str, Any]
diff --git a/llama_stack/distribution/providers.py b/llama_stack/distribution/providers.py
index f238e3bba..1d9c1f4e9 100644
--- a/llama_stack/distribution/providers.py
+++ b/llama_stack/distribution/providers.py
@@ -48,6 +48,9 @@ class ProviderImpl(Providers):
         ret = []
         for api, providers in safe_config.providers.items():
             for p in providers:
+                # Skip providers that are not enabled
+                if p.provider_id is None:
+                    continue
                 ret.append(
                     ProviderInfo(
                         api=api,
diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py
index 3726bb3a5..46cd1161e 100644
--- a/llama_stack/distribution/resolver.py
+++ b/llama_stack/distribution/resolver.py
@@ -255,6 +255,10 @@ async def instantiate_providers(
     impls: dict[Api, Any] = {}
     inner_impls_by_provider_id: dict[str, dict[str, Any]] = {f"inner-{x.value}": {} for x in router_apis}
     for api_str, provider in sorted_providers:
+        # Skip providers that are not enabled
+        if provider.provider_id is None:
+            continue
+
         deps = {a: impls[a] for a in provider.spec.api_dependencies}
         for a in provider.spec.optional_api_dependencies:
             if a in impls:
diff --git a/llama_stack/providers/remote/vector_io/chroma/chroma.py b/llama_stack/providers/remote/vector_io/chroma/chroma.py
index 06d1786f0..3bef39e9c 100644
--- a/llama_stack/providers/remote/vector_io/chroma/chroma.py
+++ b/llama_stack/providers/remote/vector_io/chroma/chroma.py
@@ -137,6 +137,9 @@ class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
 
     async def initialize(self) -> None:
         if isinstance(self.config, RemoteChromaVectorIOConfig):
+            if not self.config.url:
+                raise ValueError("URL is a required parameter for the remote Chroma provider's config")
+
             log.info(f"Connecting to Chroma server at: {self.config.url}")
             url = self.config.url.rstrip("/")
             parsed = urlparse(url)
diff --git a/llama_stack/providers/remote/vector_io/chroma/config.py b/llama_stack/providers/remote/vector_io/chroma/config.py
index 4e893fab4..bd11d5f8c 100644
--- a/llama_stack/providers/remote/vector_io/chroma/config.py
+++ b/llama_stack/providers/remote/vector_io/chroma/config.py
@@ -10,7 +10,7 @@ from pydantic import BaseModel
 
 
 class ChromaVectorIOConfig(BaseModel):
-    url: str
+    url: str | None
 
     @classmethod
     def sample_run_config(cls, url: str = "${env.CHROMADB_URL}", **kwargs: Any) -> dict[str, Any]:
diff --git a/llama_stack/providers/remote/vector_io/pgvector/config.py b/llama_stack/providers/remote/vector_io/pgvector/config.py
index 041e864ca..92908aa8a 100644
--- a/llama_stack/providers/remote/vector_io/pgvector/config.py
+++ b/llama_stack/providers/remote/vector_io/pgvector/config.py
@@ -13,11 +13,11 @@ from llama_stack.schema_utils import json_schema_type
 
 @json_schema_type
 class PGVectorVectorIOConfig(BaseModel):
-    host: str = Field(default="localhost")
-    port: int = Field(default=5432)
-    db: str = Field(default="postgres")
-    user: str = Field(default="postgres")
-    password: str = Field(default="mysecretpassword")
+    host: str | None = Field(default="localhost")
+    port: int | None = Field(default=5432)
+    db: str | None = Field(default="postgres")
+    user: str | None = Field(default="postgres")
+    password: str | None = Field(default="mysecretpassword")
 
     @classmethod
     def sample_run_config(
diff --git a/llama_stack/templates/experimental-post-training/run.yaml b/llama_stack/templates/experimental-post-training/run.yaml
index 393cba41d..a74aa3647 100644
--- a/llama_stack/templates/experimental-post-training/run.yaml
+++ b/llama_stack/templates/experimental-post-training/run.yaml
@@ -24,7 +24,7 @@ providers:
   - provider_id: ollama
     provider_type: remote::ollama
     config:
-      url: ${env.OLLAMA_URL:http://localhost:11434}
+      url: ${env.OLLAMA_URL:=http://localhost:11434}
   eval:
   - provider_id: meta-reference
     provider_type: inline::meta-reference
@@ -32,7 +32,7 @@ providers:
       kvstore:
         type: sqlite
        namespace: null
-        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/meta_reference_eval.db
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/meta-reference-gpu}/meta_reference_eval.db
   scoring:
   - provider_id: basic
     provider_type: inline::basic
@@ -40,7 +40,7 @@ providers:
   - provider_id: braintrust
     provider_type: inline::braintrust
     config:
-      openai_api_key: ${env.OPENAI_API_KEY:}
+      openai_api_key: ${env.OPENAI_API_KEY:+}
   datasetio:
   - provider_id: localfs
     provider_type: inline::localfs
@@ -48,14 +48,14 @@ providers:
       kvstore:
         type: sqlite
         namespace: null
-        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/experimental-post-training}/localfs_datasetio.db
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/experimental-post-training}/localfs_datasetio.db
   - provider_id: huggingface
     provider_type: remote::huggingface
     config:
       kvstore:
         type: sqlite
         namespace: null
-        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/huggingface}/huggingface_datasetio.db
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/huggingface}/huggingface_datasetio.db
   telemetry:
   - provider_id: meta-reference
     provider_type: inline::meta-reference
@@ -74,7 +74,7 @@ providers:
       persistence_store:
         type: sqlite
         namespace: null
-        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/experimental-post-training}/agents_store.db
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/experimental-post-training}/agents_store.db
   safety:
   - provider_id: llama-guard
     provider_type: inline::llama-guard
@@ -86,19 +86,19 @@ providers:
       kvstore:
         type: sqlite
         namespace: null
-        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/experimental-post-training}/faiss_store.db
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/experimental-post-training}/faiss_store.db
   tool_runtime:
   - provider_id: brave-search
     provider_type: remote::brave-search
     config:
-      api_key: ${env.BRAVE_SEARCH_API_KEY:}
+      api_key: ${env.BRAVE_SEARCH_API_KEY:+}
       max_results: 3
 metadata_store:
   namespace: null
   type: sqlite
-  db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/experimental-post-training}/registry.db
+  db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/experimental-post-training}/registry.db
 models: []
 shields: []
 vector_dbs: []
diff --git a/llama_stack/templates/meta-reference-gpu/meta_reference.py b/llama_stack/templates/meta-reference-gpu/meta_reference.py
index 57fb8f2af..4bfb4e9d8 100644
--- a/llama_stack/templates/meta-reference-gpu/meta_reference.py
+++ b/llama_stack/templates/meta-reference-gpu/meta_reference.py
@@ -46,7 +46,7 @@ def get_distribution_template() -> DistributionTemplate:
         provider_type="inline::meta-reference",
         config=MetaReferenceInferenceConfig.sample_run_config(
             model="${env.INFERENCE_MODEL}",
-            checkpoint_dir="${env.INFERENCE_CHECKPOINT_DIR:null}",
+            checkpoint_dir="${env.INFERENCE_CHECKPOINT_DIR:=null}",
         ),
     )
     embedding_provider = Provider(
@@ -112,7 +112,7 @@ def get_distribution_template() -> DistributionTemplate:
                         provider_type="inline::meta-reference",
                         config=MetaReferenceInferenceConfig.sample_run_config(
                             model="${env.SAFETY_MODEL}",
-                            checkpoint_dir="${env.SAFETY_CHECKPOINT_DIR:null}",
+                            checkpoint_dir="${env.SAFETY_CHECKPOINT_DIR:=null}",
                         ),
                     ),
                 ],
diff --git a/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml b/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml
index 6b15a1e01..f60f4505f 100644
--- a/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml
+++ b/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml
@@ -16,7 +16,7 @@ providers:
     provider_type: inline::meta-reference
     config:
       model: ${env.INFERENCE_MODEL}
-      checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
+      checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:=null}
       quantization:
         type: ${env.QUANTIZATION_TYPE:=bf16}
       model_parallel_size: ${env.MODEL_PARALLEL_SIZE:=0}
@@ -29,7 +29,7 @@ providers:
     provider_type: inline::meta-reference
     config:
       model: ${env.SAFETY_MODEL}
-      checkpoint_dir: ${env.SAFETY_CHECKPOINT_DIR:null}
+      checkpoint_dir: ${env.SAFETY_CHECKPOINT_DIR:=null}
       quantization:
         type: ${env.QUANTIZATION_TYPE:=bf16}
       model_parallel_size: ${env.MODEL_PARALLEL_SIZE:=0}
diff --git a/llama_stack/templates/meta-reference-gpu/run.yaml b/llama_stack/templates/meta-reference-gpu/run.yaml
index 1b44a0b3e..064b958c8 100644
--- a/llama_stack/templates/meta-reference-gpu/run.yaml
+++ b/llama_stack/templates/meta-reference-gpu/run.yaml
@@ -16,7 +16,7 @@ providers:
     provider_type: inline::meta-reference
     config:
       model: ${env.INFERENCE_MODEL}
-      checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
+      checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:=null}
       quantization:
         type: ${env.QUANTIZATION_TYPE:=bf16}
       model_parallel_size: ${env.MODEL_PARALLEL_SIZE:=0}
diff --git a/llama_stack/templates/open-benchmark/open_benchmark.py b/llama_stack/templates/open-benchmark/open_benchmark.py
index b4cfbdb52..8d7a9dc1e 100644
--- a/llama_stack/templates/open-benchmark/open_benchmark.py
+++ b/llama_stack/templates/open-benchmark/open_benchmark.py
@@ -46,7 +46,7 @@ def get_inference_providers() -> tuple[list[Provider], dict[str, list[ProviderMo
                     model_type=ModelType.llm,
                 )
             ],
-            OpenAIConfig.sample_run_config(api_key="${env.OPENAI_API_KEY:}"),
+            OpenAIConfig.sample_run_config(api_key="${env.OPENAI_API_KEY:+}"),
         ),
         (
             "anthropic",
@@ -56,7 +56,7 @@ def get_inference_providers() -> tuple[list[Provider], dict[str, list[ProviderMo
                     model_type=ModelType.llm,
                 )
             ],
-            AnthropicConfig.sample_run_config(api_key="${env.ANTHROPIC_API_KEY:}"),
+            AnthropicConfig.sample_run_config(api_key="${env.ANTHROPIC_API_KEY:+}"),
         ),
         (
             "gemini",
@@ -66,17 +66,17 @@ def get_inference_providers() -> tuple[list[Provider], dict[str, list[ProviderMo
                     model_type=ModelType.llm,
                 )
             ],
-            GeminiConfig.sample_run_config(api_key="${env.GEMINI_API_KEY:}"),
+            GeminiConfig.sample_run_config(api_key="${env.GEMINI_API_KEY:+}"),
         ),
         (
             "groq",
             [],
-            GroqConfig.sample_run_config(api_key="${env.GROQ_API_KEY:}"),
+            GroqConfig.sample_run_config(api_key="${env.GROQ_API_KEY:+}"),
         ),
         (
             "together",
             [],
-            TogetherImplConfig.sample_run_config(api_key="${env.TOGETHER_API_KEY:}"),
+            TogetherImplConfig.sample_run_config(api_key="${env.TOGETHER_API_KEY:+}"),
         ),
     ]
     inference_providers = []
diff --git a/llama_stack/templates/open-benchmark/run.yaml b/llama_stack/templates/open-benchmark/run.yaml
index 403b0fd3d..653d76bd4 100644
--- a/llama_stack/templates/open-benchmark/run.yaml
+++ b/llama_stack/templates/open-benchmark/run.yaml
@@ -15,20 +15,20 @@ providers:
   - provider_id: openai
     provider_type: remote::openai
     config:
-      api_key: ${env.OPENAI_API_KEY:}
+      api_key: ${env.OPENAI_API_KEY:+}
   - provider_id: anthropic
     provider_type: remote::anthropic
     config:
-      api_key: ${env.ANTHROPIC_API_KEY:}
+      api_key: ${env.ANTHROPIC_API_KEY:+}
   - provider_id: gemini
     provider_type: remote::gemini
     config:
-      api_key: ${env.GEMINI_API_KEY:}
+      api_key: ${env.GEMINI_API_KEY:+}
   - provider_id: groq
     provider_type: remote::groq
     config:
       url: https://api.groq.com
-      api_key: ${env.GROQ_API_KEY:}
+      api_key: ${env.GROQ_API_KEY:+}
   - provider_id: together
     provider_type: remote::together
     config:
diff --git a/llama_stack/templates/postgres-demo/postgres_demo.py b/llama_stack/templates/postgres-demo/postgres_demo.py
index 5d42b8901..5b1a302e3 100644
--- a/llama_stack/templates/postgres-demo/postgres_demo.py
+++ b/llama_stack/templates/postgres-demo/postgres_demo.py
@@ -29,7 +29,7 @@ def get_distribution_template() -> DistributionTemplate:
             provider_id="vllm-inference",
             provider_type="remote::vllm",
             config=VLLMInferenceAdapterConfig.sample_run_config(
-                url="${env.VLLM_URL:http://localhost:8000/v1}",
+                url="${env.VLLM_URL:=http://localhost:8000/v1}",
             ),
         ),
     ]
diff --git a/llama_stack/templates/postgres-demo/run.yaml b/llama_stack/templates/postgres-demo/run.yaml
index 03b7a59fb..66253cbdb 100644
--- a/llama_stack/templates/postgres-demo/run.yaml
+++ b/llama_stack/templates/postgres-demo/run.yaml
@@ -12,7 +12,7 @@ providers:
   - provider_id: vllm-inference
     provider_type: remote::vllm
     config:
-      url: ${env.VLLM_URL:http://localhost:8000/v1}
+      url: ${env.VLLM_URL:=http://localhost:8000/v1}
       max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
       api_token: ${env.VLLM_API_TOKEN:=fake}
       tls_verify: ${env.VLLM_TLS_VERIFY:=true}
diff --git a/llama_stack/templates/remote-vllm/run-with-safety.yaml b/llama_stack/templates/remote-vllm/run-with-safety.yaml
index b297f1489..e306a771b 100644
--- a/llama_stack/templates/remote-vllm/run-with-safety.yaml
+++ b/llama_stack/templates/remote-vllm/run-with-safety.yaml
@@ -15,7 +15,7 @@ providers:
   - provider_id: vllm-inference
     provider_type: remote::vllm
     config:
-      url: ${env.VLLM_URL:http://localhost:8000/v1}
+      url: ${env.VLLM_URL:=http://localhost:8000/v1}
       max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
       api_token: ${env.VLLM_API_TOKEN:=fake}
       tls_verify: ${env.VLLM_TLS_VERIFY:=true}
diff --git a/llama_stack/templates/remote-vllm/run.yaml b/llama_stack/templates/remote-vllm/run.yaml
index 6bd332cc9..1dbef96a2 100644
--- a/llama_stack/templates/remote-vllm/run.yaml
+++ b/llama_stack/templates/remote-vllm/run.yaml
@@ -15,7 +15,7 @@ providers:
   - provider_id: vllm-inference
     provider_type: remote::vllm
     config:
-      url: ${env.VLLM_URL:http://localhost:8000/v1}
+      url: ${env.VLLM_URL:=http://localhost:8000/v1}
       max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
       api_token: ${env.VLLM_API_TOKEN:=fake}
       tls_verify: ${env.VLLM_TLS_VERIFY:=true}
diff --git a/llama_stack/templates/remote-vllm/vllm.py b/llama_stack/templates/remote-vllm/vllm.py
index 94606e9d0..a8e1d9a58 100644
--- a/llama_stack/templates/remote-vllm/vllm.py
+++ b/llama_stack/templates/remote-vllm/vllm.py
@@ -44,7 +44,7 @@ def get_distribution_template() -> DistributionTemplate:
         provider_id="vllm-inference",
         provider_type="remote::vllm",
         config=VLLMInferenceAdapterConfig.sample_run_config(
-            url="${env.VLLM_URL:http://localhost:8000/v1}",
+            url="${env.VLLM_URL:=http://localhost:8000/v1}",
         ),
     )
     embedding_provider = Provider(
diff --git a/llama_stack/templates/starter/run.yaml b/llama_stack/templates/starter/run.yaml
index f7c53170b..00faf029e 100644
--- a/llama_stack/templates/starter/run.yaml
+++ b/llama_stack/templates/starter/run.yaml
@@ -68,7 +68,7 @@ providers:
         type: sqlite
         namespace: null
         db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/faiss_store.db
-  - provider_id: ${env.ENABLE_SQLITE_VEC+sqlite-vec}
+  - provider_id: ${env.ENABLE_SQLITE_VEC:+sqlite-vec}
     provider_type: inline::sqlite-vec
     config:
       db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec.db
diff --git a/llama_stack/templates/starter/starter.py b/llama_stack/templates/starter/starter.py
index df31fed84..c0f2646d7 100644
--- a/llama_stack/templates/starter/starter.py
+++ b/llama_stack/templates/starter/starter.py
@@ -175,7 +175,7 @@ def get_distribution_template() -> DistributionTemplate:
             config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
         ),
         Provider(
-            provider_id="${env.ENABLE_SQLITE_VEC+sqlite-vec}",
+            provider_id="${env.ENABLE_SQLITE_VEC:+sqlite-vec}",
             provider_type="inline::sqlite-vec",
             config=SQLiteVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
         ),
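
Note (editorial, not part of the patch): the substitutions above mirror bash parameter expansion. `${env.VAR:=default}` resolves to `default` when `VAR` is unset, while `${env.VAR:+value}` resolves to `value` only when `VAR` is set, and to nothing otherwise; per the new comment in `datatypes.py`, a conditionally enabled provider such as `sqlite-vec` then ends up with a `provider_id` of `None` and is skipped by `instantiate_providers`. Below is a minimal sketch of these semantics, assuming bash-like behavior; the `substitute` helper is hypothetical and is not the stack's actual resolver code.

```python
import os
import re

# Hypothetical sketch; NOT llama-stack's actual resolver. Assumed semantics,
# mirroring bash parameter expansion:
#   ${env.VAR}           -> value of VAR (empty when unset)
#   ${env.VAR:=default}  -> value of VAR, or the literal default when unset
#   ${env.VAR:+alt}      -> the literal alt when VAR is set, empty when unset
_ENV_REF = re.compile(r"\$\{env\.(?P<name>\w+)(?::(?P<op>[=+])(?P<val>[^}]*))?\}")


def substitute(text: str) -> str:
    def repl(m: re.Match) -> str:
        name, op, val = m.group("name", "op", "val")
        current = os.environ.get(name)
        if op == "=":  # default form: fall back to the literal when unset
            return current if current is not None else val
        if op == "+":  # conditional form: emit the literal only when set
            return val if current is not None else ""
        return current or ""

    return _ENV_REF.sub(repl, text)


# Assuming neither variable is set in the environment:
os.environ.pop("ENABLE_SQLITE_VEC", None)
os.environ.pop("VLLM_URL", None)
# Conditional provider_id resolves to empty (treated as None and skipped):
assert substitute("${env.ENABLE_SQLITE_VEC:+sqlite-vec}") == ""
# Default form falls back to the literal:
assert substitute("${env.VLLM_URL:=http://localhost:8000/v1}") == "http://localhost:8000/v1"
```

Under this reading, the `:` to `:=` renames make existing defaults explicit, while the `:` to `:+` renames turn formerly empty-string values (e.g. unset API keys) into genuinely absent ones, which is what the new `str | None` config fields and the `provider_id is None` skips accommodate.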