mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-12 04:50:39 +00:00
feat: allow conditionally enabling providers in run.yaml
This commit is contained in:
parent
ece354eedd
commit
67d34cfbe3
11 changed files with 184 additions and 21 deletions
|
@ -127,6 +127,9 @@ async def resolve_impls(
|
||||||
|
|
||||||
specs = {}
|
specs = {}
|
||||||
for provider in providers:
|
for provider in providers:
|
||||||
|
if not provider.provider_id or provider.provider_id == "__disabled__":
|
||||||
|
continue
|
||||||
|
|
||||||
if provider.provider_type not in provider_registry[api]:
|
if provider.provider_type not in provider_registry[api]:
|
||||||
raise ValueError(f"Provider `{provider.provider_type}` is not available for API `{api}`")
|
raise ValueError(f"Provider `{provider.provider_type}` is not available for API `{api}`")
|
||||||
|
|
||||||
|
|
|
@ -155,18 +155,34 @@ def replace_env_vars(config: Any, path: str = "") -> Any:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
elif isinstance(config, str):
|
elif isinstance(config, str):
|
||||||
pattern = r"\${env\.([A-Z0-9_]+)(?::([^}]*))?}"
|
# Updated pattern to support both default values (:) and conditional values (+)
|
||||||
|
pattern = r"\${env\.([A-Z0-9_]+)(?:([:\+])([^}]*))?}"
|
||||||
|
|
||||||
def get_env_var(match):
|
def get_env_var(match):
|
||||||
env_var = match.group(1)
|
env_var = match.group(1)
|
||||||
default_val = match.group(2)
|
operator = match.group(2) # ':' for default, '+' for conditional
|
||||||
|
value_expr = match.group(3)
|
||||||
|
|
||||||
value = os.environ.get(env_var)
|
env_value = os.environ.get(env_var)
|
||||||
if not value:
|
|
||||||
if default_val is None:
|
if operator == ":": # Default value syntax: ${env.FOO:default}
|
||||||
raise EnvVarError(env_var, path)
|
if not env_value:
|
||||||
|
if value_expr is None:
|
||||||
|
raise EnvVarError(env_var, path)
|
||||||
|
else:
|
||||||
|
value = value_expr
|
||||||
else:
|
else:
|
||||||
value = default_val
|
value = env_value
|
||||||
|
elif operator == "+": # Conditional value syntax: ${env.FOO+value_if_set}
|
||||||
|
if env_value:
|
||||||
|
value = value_expr
|
||||||
|
else:
|
||||||
|
# If env var is not set, return empty string for the conditional case
|
||||||
|
value = ""
|
||||||
|
else: # No operator case: ${env.FOO}
|
||||||
|
if not env_value:
|
||||||
|
raise EnvVarError(env_var, path)
|
||||||
|
value = env_value
|
||||||
|
|
||||||
# expand "~" from the values
|
# expand "~" from the values
|
||||||
return os.path.expanduser(value)
|
return os.path.expanduser(value)
|
||||||
|
|
|
@ -20,7 +20,7 @@ class FaissVectorIOConfig(BaseModel):
|
||||||
kvstore: KVStoreConfig
|
kvstore: KVStoreConfig
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]:
|
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||||
__distro_dir__=__distro_dir__,
|
__distro_dir__=__distro_dir__,
|
||||||
|
|
|
@ -13,5 +13,5 @@ class ChromaVectorIOConfig(BaseModel):
|
||||||
url: str
|
url: str
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def sample_config(cls) -> Dict[str, Any]:
|
def sample_run_config(cls, url: str = "${env.CHROMADB_URL}", **kwargs: Any) -> Dict[str, Any]:
|
||||||
return {"url": "{env.CHROMADB_URL}"}
|
return {"url": url}
|
||||||
|
|
|
@ -4,6 +4,8 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from llama_stack.schema_utils import json_schema_type
|
from llama_stack.schema_utils import json_schema_type
|
||||||
|
@ -16,3 +18,15 @@ class PGVectorVectorIOConfig(BaseModel):
|
||||||
db: str = Field(default="postgres")
|
db: str = Field(default="postgres")
|
||||||
user: str = Field(default="postgres")
|
user: str = Field(default="postgres")
|
||||||
password: str = Field(default="mysecretpassword")
|
password: str = Field(default="mysecretpassword")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def sample_run_config(
|
||||||
|
cls,
|
||||||
|
host: str = "${env.PGVECTOR_HOST:localhost}",
|
||||||
|
port: int = "${env.PGVECTOR_PORT:5432}",
|
||||||
|
db: str = "${env.PGVECTOR_DB}",
|
||||||
|
user: str = "${env.PGVECTOR_USER}",
|
||||||
|
password: str = "${env.PGVECTOR_PASSWORD}",
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
return {"host": host, "port": port, "db": db, "user": user, "password": password}
|
||||||
|
|
|
@ -27,6 +27,8 @@ from llama_stack.providers.remote.inference.groq.config import GroqConfig
|
||||||
from llama_stack.providers.remote.inference.groq.models import MODEL_ENTRIES as GROQ_MODEL_ENTRIES
|
from llama_stack.providers.remote.inference.groq.models import MODEL_ENTRIES as GROQ_MODEL_ENTRIES
|
||||||
from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
|
from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
|
||||||
from llama_stack.providers.remote.inference.openai.models import MODEL_ENTRIES as OPENAI_MODEL_ENTRIES
|
from llama_stack.providers.remote.inference.openai.models import MODEL_ENTRIES as OPENAI_MODEL_ENTRIES
|
||||||
|
from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig
|
||||||
|
from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig
|
||||||
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings, get_model_registry
|
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings, get_model_registry
|
||||||
|
|
||||||
|
|
||||||
|
@ -94,11 +96,27 @@ def get_distribution_template() -> DistributionTemplate:
|
||||||
}
|
}
|
||||||
name = "dev"
|
name = "dev"
|
||||||
|
|
||||||
vector_io_provider = Provider(
|
vector_io_providers = [
|
||||||
provider_id="sqlite-vec",
|
Provider(
|
||||||
provider_type="inline::sqlite-vec",
|
provider_id="sqlite-vec",
|
||||||
config=SQLiteVectorIOConfig.sample_run_config(f"distributions/{name}"),
|
provider_type="inline::sqlite-vec",
|
||||||
)
|
config=SQLiteVectorIOConfig.sample_run_config(f"distributions/{name}"),
|
||||||
|
),
|
||||||
|
Provider(
|
||||||
|
provider_id="${env.ENABLE_CHROMADB+chromadb}",
|
||||||
|
provider_type="remote::chromadb",
|
||||||
|
config=ChromaVectorIOConfig.sample_run_config(url="${env.CHROMADB_URL:}"),
|
||||||
|
),
|
||||||
|
Provider(
|
||||||
|
provider_id="${env.ENABLE_PGVECTOR+pgvector}",
|
||||||
|
provider_type="remote::pgvector",
|
||||||
|
config=PGVectorVectorIOConfig.sample_run_config(
|
||||||
|
db="${env.PGVECTOR_DB:}",
|
||||||
|
user="${env.PGVECTOR_USER:}",
|
||||||
|
password="${env.PGVECTOR_PASSWORD:}",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
||||||
embedding_provider = Provider(
|
embedding_provider = Provider(
|
||||||
provider_id="sentence-transformers",
|
provider_id="sentence-transformers",
|
||||||
provider_type="inline::sentence-transformers",
|
provider_type="inline::sentence-transformers",
|
||||||
|
@ -141,7 +159,7 @@ def get_distribution_template() -> DistributionTemplate:
|
||||||
"run.yaml": RunConfigSettings(
|
"run.yaml": RunConfigSettings(
|
||||||
provider_overrides={
|
provider_overrides={
|
||||||
"inference": inference_providers + [embedding_provider],
|
"inference": inference_providers + [embedding_provider],
|
||||||
"vector_io": [vector_io_provider],
|
"vector_io": vector_io_providers,
|
||||||
},
|
},
|
||||||
default_models=default_models + [embedding_model],
|
default_models=default_models + [embedding_model],
|
||||||
default_tool_groups=default_tool_groups,
|
default_tool_groups=default_tool_groups,
|
||||||
|
|
|
@ -42,6 +42,18 @@ providers:
|
||||||
provider_type: inline::sqlite-vec
|
provider_type: inline::sqlite-vec
|
||||||
config:
|
config:
|
||||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/sqlite_vec.db
|
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/sqlite_vec.db
|
||||||
|
- provider_id: ${env.ENABLE_CHROMADB+chromadb}
|
||||||
|
provider_type: remote::chromadb
|
||||||
|
config:
|
||||||
|
url: ${env.CHROMADB_URL:}
|
||||||
|
- provider_id: ${env.ENABLE_PGVECTOR+pgvector}
|
||||||
|
provider_type: remote::pgvector
|
||||||
|
config:
|
||||||
|
host: ${env.PGVECTOR_HOST:localhost}
|
||||||
|
port: ${env.PGVECTOR_PORT:5432}
|
||||||
|
db: ${env.PGVECTOR_DB:}
|
||||||
|
user: ${env.PGVECTOR_USER:}
|
||||||
|
password: ${env.PGVECTOR_PASSWORD:}
|
||||||
safety:
|
safety:
|
||||||
- provider_id: llama-guard
|
- provider_id: llama-guard
|
||||||
provider_type: inline::llama-guard
|
provider_type: inline::llama-guard
|
||||||
|
|
|
@ -22,12 +22,18 @@ providers:
|
||||||
type: sqlite
|
type: sqlite
|
||||||
namespace: null
|
namespace: null
|
||||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/faiss_store.db
|
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/faiss_store.db
|
||||||
- provider_id: chromadb
|
- provider_id: ${env.ENABLE_CHROMADB+chromadb}
|
||||||
provider_type: remote::chromadb
|
provider_type: remote::chromadb
|
||||||
config: {}
|
config:
|
||||||
- provider_id: pgvector
|
url: ${env.CHROMADB_URL:}
|
||||||
|
- provider_id: ${env.ENABLE_PGVECTOR+pgvector}
|
||||||
provider_type: remote::pgvector
|
provider_type: remote::pgvector
|
||||||
config: {}
|
config:
|
||||||
|
host: ${env.PGVECTOR_HOST:localhost}
|
||||||
|
port: ${env.PGVECTOR_PORT:5432}
|
||||||
|
db: ${env.PGVECTOR_DB:}
|
||||||
|
user: ${env.PGVECTOR_USER:}
|
||||||
|
password: ${env.PGVECTOR_PASSWORD:}
|
||||||
safety:
|
safety:
|
||||||
- provider_id: llama-guard
|
- provider_id: llama-guard
|
||||||
provider_type: inline::llama-guard
|
provider_type: inline::llama-guard
|
||||||
|
|
|
@ -11,8 +11,11 @@ from llama_stack.distribution.datatypes import (
|
||||||
ShieldInput,
|
ShieldInput,
|
||||||
ToolGroupInput,
|
ToolGroupInput,
|
||||||
)
|
)
|
||||||
|
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
|
||||||
from llama_stack.providers.remote.inference.sambanova import SambaNovaImplConfig
|
from llama_stack.providers.remote.inference.sambanova import SambaNovaImplConfig
|
||||||
from llama_stack.providers.remote.inference.sambanova.models import MODEL_ENTRIES
|
from llama_stack.providers.remote.inference.sambanova.models import MODEL_ENTRIES
|
||||||
|
from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig
|
||||||
|
from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig
|
||||||
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings, get_model_registry
|
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings, get_model_registry
|
||||||
|
|
||||||
|
|
||||||
|
@ -38,6 +41,30 @@ def get_distribution_template() -> DistributionTemplate:
|
||||||
config=SambaNovaImplConfig.sample_run_config(),
|
config=SambaNovaImplConfig.sample_run_config(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
vector_io_providers = [
|
||||||
|
Provider(
|
||||||
|
provider_id="faiss",
|
||||||
|
provider_type="inline::faiss",
|
||||||
|
config=FaissVectorIOConfig.sample_run_config(
|
||||||
|
__distro_dir__=f"distributions/{name}",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
Provider(
|
||||||
|
provider_id="${env.ENABLE_CHROMADB+chromadb}",
|
||||||
|
provider_type="remote::chromadb",
|
||||||
|
config=ChromaVectorIOConfig.sample_run_config(url="${env.CHROMADB_URL:}"),
|
||||||
|
),
|
||||||
|
Provider(
|
||||||
|
provider_id="${env.ENABLE_PGVECTOR+pgvector}",
|
||||||
|
provider_type="remote::pgvector",
|
||||||
|
config=PGVectorVectorIOConfig.sample_run_config(
|
||||||
|
db="${env.PGVECTOR_DB:}",
|
||||||
|
user="${env.PGVECTOR_USER:}",
|
||||||
|
password="${env.PGVECTOR_PASSWORD:}",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
available_models = {
|
available_models = {
|
||||||
name: MODEL_ENTRIES,
|
name: MODEL_ENTRIES,
|
||||||
}
|
}
|
||||||
|
@ -69,6 +96,7 @@ def get_distribution_template() -> DistributionTemplate:
|
||||||
"run.yaml": RunConfigSettings(
|
"run.yaml": RunConfigSettings(
|
||||||
provider_overrides={
|
provider_overrides={
|
||||||
"inference": [inference_provider],
|
"inference": [inference_provider],
|
||||||
|
"vector_io": vector_io_providers,
|
||||||
},
|
},
|
||||||
default_models=default_models,
|
default_models=default_models,
|
||||||
default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")],
|
default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")],
|
||||||
|
|
|
@ -9,7 +9,7 @@ import random
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
INLINE_VECTOR_DB_PROVIDERS = [
|
INLINE_VECTOR_DB_PROVIDERS = [
|
||||||
"faiss",
|
"chromadb",
|
||||||
# TODO: add sqlite_vec to templates
|
# TODO: add sqlite_vec to templates
|
||||||
# "sqlite_vec",
|
# "sqlite_vec",
|
||||||
]
|
]
|
||||||
|
|
66
tests/test_replace_env_vars.py
Normal file
66
tests/test_replace_env_vars.py
Normal file
|
@ -0,0 +1,66 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from llama_stack.distribution.stack import replace_env_vars
|
||||||
|
|
||||||
|
|
||||||
|
class TestReplaceEnvVars(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
# Clear any existing environment variables we'll use in tests
|
||||||
|
for var in ["TEST_VAR", "EMPTY_VAR", "ZERO_VAR"]:
|
||||||
|
if var in os.environ:
|
||||||
|
del os.environ[var]
|
||||||
|
|
||||||
|
# Set up test environment variables
|
||||||
|
os.environ["TEST_VAR"] = "test_value"
|
||||||
|
os.environ["EMPTY_VAR"] = ""
|
||||||
|
os.environ["ZERO_VAR"] = "0"
|
||||||
|
|
||||||
|
def test_simple_replacement(self):
|
||||||
|
self.assertEqual(replace_env_vars("${env.TEST_VAR}"), "test_value")
|
||||||
|
|
||||||
|
def test_default_value_when_not_set(self):
|
||||||
|
self.assertEqual(replace_env_vars("${env.NOT_SET:default}"), "default")
|
||||||
|
|
||||||
|
def test_default_value_when_set(self):
|
||||||
|
self.assertEqual(replace_env_vars("${env.TEST_VAR:default}"), "test_value")
|
||||||
|
|
||||||
|
def test_default_value_when_empty(self):
|
||||||
|
self.assertEqual(replace_env_vars("${env.EMPTY_VAR:default}"), "default")
|
||||||
|
|
||||||
|
def test_conditional_value_when_set(self):
|
||||||
|
self.assertEqual(replace_env_vars("${env.TEST_VAR+conditional}"), "conditional")
|
||||||
|
|
||||||
|
def test_conditional_value_when_not_set(self):
|
||||||
|
self.assertEqual(replace_env_vars("${env.NOT_SET+conditional}"), "")
|
||||||
|
|
||||||
|
def test_conditional_value_when_empty(self):
|
||||||
|
self.assertEqual(replace_env_vars("${env.EMPTY_VAR+conditional}"), "")
|
||||||
|
|
||||||
|
def test_conditional_value_with_zero(self):
|
||||||
|
self.assertEqual(replace_env_vars("${env.ZERO_VAR+conditional}"), "conditional")
|
||||||
|
|
||||||
|
def test_mixed_syntax(self):
|
||||||
|
self.assertEqual(replace_env_vars("${env.TEST_VAR:default} and ${env.NOT_SET+conditional}"), "test_value and ")
|
||||||
|
self.assertEqual(
|
||||||
|
replace_env_vars("${env.NOT_SET:default} and ${env.TEST_VAR+conditional}"), "default and conditional"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_nested_structures(self):
|
||||||
|
data = {
|
||||||
|
"key1": "${env.TEST_VAR:default}",
|
||||||
|
"key2": ["${env.NOT_SET:default}", "${env.TEST_VAR+conditional}"],
|
||||||
|
"key3": {"nested": "${env.NOT_SET+conditional}"},
|
||||||
|
}
|
||||||
|
expected = {"key1": "test_value", "key2": ["default", "conditional"], "key3": {"nested": ""}}
|
||||||
|
self.assertEqual(replace_env_vars(data), expected)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
Loading…
Add table
Add a link
Reference in a new issue