Merge branch 'refs/heads/main' into preprocessors

# Conflicts:
#	llama_stack/distribution/routers/routers.py
#	llama_stack/templates/ollama/build.yaml
#	llama_stack/templates/ollama/run-with-safety.yaml
#	llama_stack/templates/ollama/run.yaml
#	llama_stack/templates/remote-vllm/build.yaml
#	llama_stack/templates/remote-vllm/run-with-safety.yaml
#	llama_stack/templates/remote-vllm/run.yaml
#	llama_stack/templates/together/build.yaml
#	llama_stack/templates/together/run-with-safety.yaml
#	llama_stack/templates/together/run.yaml
This commit is contained in:
ilya-kolchinsky 2025-03-07 16:20:30 +01:00
commit 6b9f673fdb
313 changed files with 181388 additions and 7064 deletions

View file

@ -5,14 +5,15 @@
# the root directory of this source tree.
import importlib.resources
import logging
import os
import re
import tempfile
from typing import Any, Dict, Optional
import yaml
from termcolor import colored
from llama_stack import logcat
from llama_stack.apis.agents import Agents
from llama_stack.apis.batch_inference import BatchInference
from llama_stack.apis.benchmarks import Benchmarks
@ -35,14 +36,13 @@ from llama_stack.apis.telemetry import Telemetry
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
from llama_stack.apis.vector_dbs import VectorDBs
from llama_stack.apis.vector_io import VectorIO
from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.distribution.datatypes import Provider, StackRunConfig
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
from llama_stack.distribution.store.registry import create_dist_registry
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.providers.datatypes import Api
log = logging.getLogger(__name__)
class LlamaStack(
VectorDBs,
@ -106,12 +106,11 @@ async def register_resources(run_config: StackRunConfig, impls: Dict[Api, Any]):
objects_to_process = response.data if hasattr(response, "data") else response
for obj in objects_to_process:
log.info(
logcat.debug(
"core",
f"{rsrc.capitalize()}: {colored(obj.identifier, 'white', attrs=['bold'])} served by {colored(obj.provider_id, 'white', attrs=['bold'])}",
)
log.info("")
class EnvVarError(Exception):
def __init__(self, var_name: str, path: str = ""):
@ -160,18 +159,34 @@ def replace_env_vars(config: Any, path: str = "") -> Any:
return result
elif isinstance(config, str):
pattern = r"\${env\.([A-Z0-9_]+)(?::([^}]*))?}"
# Updated pattern to support both default values (:) and conditional values (+)
pattern = r"\${env\.([A-Z0-9_]+)(?:([:\+])([^}]*))?}"
def get_env_var(match):
env_var = match.group(1)
default_val = match.group(2)
operator = match.group(2) # ':' for default, '+' for conditional
value_expr = match.group(3)
value = os.environ.get(env_var)
if not value:
if default_val is None:
raise EnvVarError(env_var, path)
env_value = os.environ.get(env_var)
if operator == ":": # Default value syntax: ${env.FOO:default}
if not env_value:
if value_expr is None:
raise EnvVarError(env_var, path)
else:
value = value_expr
else:
value = default_val
value = env_value
elif operator == "+": # Conditional value syntax: ${env.FOO+value_if_set}
if env_value:
value = value_expr
else:
# If env var is not set, return empty string for the conditional case
value = ""
else: # No operator case: ${env.FOO}
if not env_value:
raise EnvVarError(env_var, path)
value = env_value
# expand "~" from the values
return os.path.expanduser(value)
@ -220,3 +235,53 @@ def get_stack_run_config_from_template(template: str) -> StackRunConfig:
run_config = yaml.safe_load(path.open())
return StackRunConfig(**replace_env_vars(run_config))
def run_config_from_adhoc_config_spec(
adhoc_config_spec: str, provider_registry: Optional[ProviderRegistry] = None
) -> StackRunConfig:
"""
Create an adhoc distribution from a list of API providers.
The list should be of the form "api=provider", e.g. "inference=fireworks". If you have
multiple pairs, separate them with commas or semicolons, e.g. "inference=fireworks,safety=llama-guard,agents=meta-reference"
"""
api_providers = adhoc_config_spec.replace(";", ",").split(",")
provider_registry = provider_registry or get_provider_registry()
distro_dir = tempfile.mkdtemp()
provider_configs_by_api = {}
for api_provider in api_providers:
api_str, provider = api_provider.split("=")
api = Api(api_str)
providers_by_type = provider_registry[api]
provider_spec = providers_by_type.get(provider)
if not provider_spec:
provider_spec = providers_by_type.get(f"inline::{provider}")
if not provider_spec:
provider_spec = providers_by_type.get(f"remote::{provider}")
if not provider_spec:
raise ValueError(
f"Provider {provider} (or remote::{provider} or inline::{provider}) not found for API {api}"
)
# call method "sample_run_config" on the provider spec config class
provider_config_type = instantiate_class_type(provider_spec.config_class)
provider_config = replace_env_vars(provider_config_type.sample_run_config(__distro_dir__=distro_dir))
provider_configs_by_api[api_str] = [
Provider(
provider_id=provider,
provider_type=provider_spec.provider_type,
config=provider_config,
)
]
config = StackRunConfig(
image_name="distro-test",
apis=list(provider_configs_by_api.keys()),
providers=provider_configs_by_api,
)
return config