feat: allow conditionally enabling providers in run.yaml

This commit is contained in:
Ashwin Bharambe 2025-02-27 23:46:35 -08:00
parent ece354eedd
commit 67d34cfbe3
11 changed files with 184 additions and 21 deletions

View file

@ -127,6 +127,9 @@ async def resolve_impls(
specs = {}
for provider in providers:
if not provider.provider_id or provider.provider_id == "__disabled__":
continue
if provider.provider_type not in provider_registry[api]:
raise ValueError(f"Provider `{provider.provider_type}` is not available for API `{api}`")

View file

@ -155,18 +155,34 @@ def replace_env_vars(config: Any, path: str = "") -> Any:
return result
elif isinstance(config, str):
pattern = r"\${env\.([A-Z0-9_]+)(?::([^}]*))?}"
# Updated pattern to support both default values (:) and conditional values (+)
pattern = r"\${env\.([A-Z0-9_]+)(?:([:\+])([^}]*))?}"
def get_env_var(match):
    """Resolve one ``${env.NAME...}`` placeholder to its replacement text.

    The operator captured in group 2 of the match selects the behaviour:

    * ``${env.NAME}`` -- the variable is required and its value is used.
    * ``${env.NAME:default}`` -- ``default`` is used when the variable is
      unset or empty; an empty default (``${env.NAME:}``) is allowed.
    * ``${env.NAME+value}`` -- ``value`` is used when the variable is set
      to a non-empty string, otherwise the empty string.

    Returns:
        The replacement string, with a leading ``~`` expanded.

    Raises:
        EnvVarError: if a required variable (no operator) is unset or empty.
    """
    env_var = match.group(1)
    operator = match.group(2)  # ':' for default, '+' for conditional, None otherwise
    # Group 3 is a (possibly empty) string whenever an operator matched and
    # None only when there is no operator; normalise so it is always a str.
    value_expr = match.group(3) or ""
    env_value = os.environ.get(env_var)

    if operator == ":":
        # Default value syntax: an unset or empty variable falls back to the default.
        value = env_value or value_expr
    elif operator == "+":
        # Conditional value syntax: emit the literal only when the variable is set.
        value = value_expr if env_value else ""
    else:
        # No operator: the variable is mandatory.
        if not env_value:
            raise EnvVarError(env_var, path)
        value = env_value

    # expand "~" from the values
    return os.path.expanduser(value)

View file

@ -20,7 +20,7 @@ class FaissVectorIOConfig(BaseModel):
kvstore: KVStoreConfig
@classmethod
def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]:
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
return {
"kvstore": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,

View file

@ -13,5 +13,5 @@ class ChromaVectorIOConfig(BaseModel):
url: str
@classmethod
def sample_config(cls) -> Dict[str, Any]:
    """Legacy alias for :meth:`sample_run_config` using its default URL.

    The previous body returned ``"{env.CHROMADB_URL}"`` without the leading
    ``$``, which is not a valid ``${env...}`` placeholder; delegating keeps
    the two entry points in agreement.
    """
    return cls.sample_run_config()

@classmethod
def sample_run_config(cls, url: str = "${env.CHROMADB_URL}", **kwargs: Any) -> Dict[str, Any]:
    """Return a sample provider config for a run.yaml template.

    Args:
        url: Value (or ``${env...}`` placeholder) to emit for the ``url`` key.
        **kwargs: Accepted and ignored, so callers can pass the same keyword
            arguments to every provider's ``sample_run_config``.

    Returns:
        A dict with the single key ``"url"``.
    """
    return {"url": url}

View file

@ -4,6 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Union

from pydantic import BaseModel, Field

from llama_stack.schema_utils import json_schema_type
@ -16,3 +18,15 @@ class PGVectorVectorIOConfig(BaseModel):
db: str = Field(default="postgres")
user: str = Field(default="postgres")
password: str = Field(default="mysecretpassword")
@classmethod
def sample_run_config(
    cls,
    host: str = "${env.PGVECTOR_HOST:localhost}",
    port: Union[int, str] = "${env.PGVECTOR_PORT:5432}",
    db: str = "${env.PGVECTOR_DB}",
    user: str = "${env.PGVECTOR_USER}",
    password: str = "${env.PGVECTOR_PASSWORD}",
    **kwargs: Any,
) -> Dict[str, Any]:
    """Return a sample PGVector provider config for a run.yaml template.

    Every argument defaults to a ``${env...}`` placeholder string and is
    passed through unchanged into the returned dict.

    Args:
        host: Value or placeholder for the ``host`` key.
        port: Value or placeholder for the ``port`` key. The default is a
            placeholder string, so both ``int`` and ``str`` are accepted.
        db: Value or placeholder for the ``db`` key.
        user: Value or placeholder for the ``user`` key.
        password: Value or placeholder for the ``password`` key.
        **kwargs: Accepted and ignored, so callers can pass the same keyword
            arguments to every provider's ``sample_run_config``.

    Returns:
        A dict with the keys ``host``, ``port``, ``db``, ``user`` and
        ``password``.
    """
    return {"host": host, "port": port, "db": db, "user": user, "password": password}

View file

@ -27,6 +27,8 @@ from llama_stack.providers.remote.inference.groq.config import GroqConfig
from llama_stack.providers.remote.inference.groq.models import MODEL_ENTRIES as GROQ_MODEL_ENTRIES
from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
from llama_stack.providers.remote.inference.openai.models import MODEL_ENTRIES as OPENAI_MODEL_ENTRIES
from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig
from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings, get_model_registry
@ -94,11 +96,27 @@ def get_distribution_template() -> DistributionTemplate:
}
name = "dev"
vector_io_provider = Provider(
provider_id="sqlite-vec",
provider_type="inline::sqlite-vec",
config=SQLiteVectorIOConfig.sample_run_config(f"distributions/{name}"),
)
vector_io_providers = [
Provider(
provider_id="sqlite-vec",
provider_type="inline::sqlite-vec",
config=SQLiteVectorIOConfig.sample_run_config(f"distributions/{name}"),
),
Provider(
provider_id="${env.ENABLE_CHROMADB+chromadb}",
provider_type="remote::chromadb",
config=ChromaVectorIOConfig.sample_run_config(url="${env.CHROMADB_URL:}"),
),
Provider(
provider_id="${env.ENABLE_PGVECTOR+pgvector}",
provider_type="remote::pgvector",
config=PGVectorVectorIOConfig.sample_run_config(
db="${env.PGVECTOR_DB:}",
user="${env.PGVECTOR_USER:}",
password="${env.PGVECTOR_PASSWORD:}",
),
),
]
embedding_provider = Provider(
provider_id="sentence-transformers",
provider_type="inline::sentence-transformers",
@ -141,7 +159,7 @@ def get_distribution_template() -> DistributionTemplate:
"run.yaml": RunConfigSettings(
provider_overrides={
"inference": inference_providers + [embedding_provider],
"vector_io": [vector_io_provider],
"vector_io": vector_io_providers,
},
default_models=default_models + [embedding_model],
default_tool_groups=default_tool_groups,

View file

@ -42,6 +42,18 @@ providers:
provider_type: inline::sqlite-vec
config:
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/sqlite_vec.db
- provider_id: ${env.ENABLE_CHROMADB+chromadb}
provider_type: remote::chromadb
config:
url: ${env.CHROMADB_URL:}
- provider_id: ${env.ENABLE_PGVECTOR+pgvector}
provider_type: remote::pgvector
config:
host: ${env.PGVECTOR_HOST:localhost}
port: ${env.PGVECTOR_PORT:5432}
db: ${env.PGVECTOR_DB:}
user: ${env.PGVECTOR_USER:}
password: ${env.PGVECTOR_PASSWORD:}
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard

View file

@ -22,12 +22,18 @@ providers:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/faiss_store.db
- provider_id: chromadb
- provider_id: ${env.ENABLE_CHROMADB+chromadb}
provider_type: remote::chromadb
config: {}
- provider_id: pgvector
config:
url: ${env.CHROMADB_URL:}
- provider_id: ${env.ENABLE_PGVECTOR+pgvector}
provider_type: remote::pgvector
config: {}
config:
host: ${env.PGVECTOR_HOST:localhost}
port: ${env.PGVECTOR_PORT:5432}
db: ${env.PGVECTOR_DB:}
user: ${env.PGVECTOR_USER:}
password: ${env.PGVECTOR_PASSWORD:}
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard

View file

@ -11,8 +11,11 @@ from llama_stack.distribution.datatypes import (
ShieldInput,
ToolGroupInput,
)
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.remote.inference.sambanova import SambaNovaImplConfig
from llama_stack.providers.remote.inference.sambanova.models import MODEL_ENTRIES
from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig
from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings, get_model_registry
@ -38,6 +41,30 @@ def get_distribution_template() -> DistributionTemplate:
config=SambaNovaImplConfig.sample_run_config(),
)
vector_io_providers = [
Provider(
provider_id="faiss",
provider_type="inline::faiss",
config=FaissVectorIOConfig.sample_run_config(
__distro_dir__=f"distributions/{name}",
),
),
Provider(
provider_id="${env.ENABLE_CHROMADB+chromadb}",
provider_type="remote::chromadb",
config=ChromaVectorIOConfig.sample_run_config(url="${env.CHROMADB_URL:}"),
),
Provider(
provider_id="${env.ENABLE_PGVECTOR+pgvector}",
provider_type="remote::pgvector",
config=PGVectorVectorIOConfig.sample_run_config(
db="${env.PGVECTOR_DB:}",
user="${env.PGVECTOR_USER:}",
password="${env.PGVECTOR_PASSWORD:}",
),
),
]
available_models = {
name: MODEL_ENTRIES,
}
@ -69,6 +96,7 @@ def get_distribution_template() -> DistributionTemplate:
"run.yaml": RunConfigSettings(
provider_overrides={
"inference": [inference_provider],
"vector_io": vector_io_providers,
},
default_models=default_models,
default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")],

View file

@ -9,7 +9,7 @@ import random
import pytest
# Vector DB provider ids listed for this test module (presumably consumed by
# tests outside this hunk -- confirm). "sqlite_vec" stays commented out until
# it is added to the templates, per the TODO below.
INLINE_VECTOR_DB_PROVIDERS = [
"faiss",
"chromadb",
# TODO: add sqlite_vec to templates
# "sqlite_vec",
]

View file

@ -0,0 +1,66 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import unittest
from llama_stack.distribution.stack import replace_env_vars
class TestReplaceEnvVars(unittest.TestCase):
    """Exercise the ``${env.NAME}``, ``${env.NAME:default}`` and
    ``${env.NAME+value}`` forms handled by ``replace_env_vars``."""

    # Environment each test runs under; ``None`` means "must be absent".
    _MANAGED_ENV = {
        "TEST_VAR": "test_value",
        "EMPTY_VAR": "",
        "ZERO_VAR": "0",
        "NOT_SET": None,
    }

    def setUp(self):
        # Remember what was there so tearDown can put it back; without this
        # the variables leak into every test that runs later in the process.
        self._saved_env = {name: os.environ.get(name) for name in self._MANAGED_ENV}
        for name, value in self._MANAGED_ENV.items():
            if value is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = value

    def tearDown(self):
        # Restore the pre-test environment exactly (re-set or remove).
        for name, previous in self._saved_env.items():
            if previous is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = previous

    def test_simple_replacement(self):
        self.assertEqual(replace_env_vars("${env.TEST_VAR}"), "test_value")

    def test_default_value_when_not_set(self):
        self.assertEqual(replace_env_vars("${env.NOT_SET:default}"), "default")

    def test_default_value_when_set(self):
        self.assertEqual(replace_env_vars("${env.TEST_VAR:default}"), "test_value")

    def test_default_value_when_empty(self):
        self.assertEqual(replace_env_vars("${env.EMPTY_VAR:default}"), "default")

    def test_conditional_value_when_set(self):
        self.assertEqual(replace_env_vars("${env.TEST_VAR+conditional}"), "conditional")

    def test_conditional_value_when_not_set(self):
        self.assertEqual(replace_env_vars("${env.NOT_SET+conditional}"), "")

    def test_conditional_value_when_empty(self):
        self.assertEqual(replace_env_vars("${env.EMPTY_VAR+conditional}"), "")

    def test_conditional_value_with_zero(self):
        # "0" is a non-empty string, so the variable counts as set.
        self.assertEqual(replace_env_vars("${env.ZERO_VAR+conditional}"), "conditional")

    def test_mixed_syntax(self):
        self.assertEqual(replace_env_vars("${env.TEST_VAR:default} and ${env.NOT_SET+conditional}"), "test_value and ")
        self.assertEqual(
            replace_env_vars("${env.NOT_SET:default} and ${env.TEST_VAR+conditional}"), "default and conditional"
        )

    def test_nested_structures(self):
        data = {
            "key1": "${env.TEST_VAR:default}",
            "key2": ["${env.NOT_SET:default}", "${env.TEST_VAR+conditional}"],
            "key3": {"nested": "${env.NOT_SET+conditional}"},
        }
        expected = {"key1": "test_value", "key2": ["default", "conditional"], "key3": {"nested": ""}}
        self.assertEqual(replace_env_vars(data), expected)


if __name__ == "__main__":
    unittest.main()