feat: allow building distro with external providers (#1967)

# What does this PR do?

We can now build a distribution that includes external providers.
Closes: https://github.com/meta-llama/llama-stack/issues/1948

## Test Plan

Build a distro with an external provider following the doc instructions.

[//]: # (## Documentation)

Added.

Rendered:


![Screenshot 2025-04-18 at 11 26
39](https://github.com/user-attachments/assets/afcf3d50-8d30-48c3-8d24-06a4b3662881)

Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
Sébastien Han 2025-04-18 17:18:28 +02:00 committed by GitHub
parent c4570bcb48
commit 94f83382eb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 137 additions and 69 deletions

View file

@ -210,16 +210,9 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
)
sys.exit(1)
if build_config.image_type == LlamaStackImageType.CONTAINER.value and not args.image_name:
cprint(
"Please specify --image-name when building a container from a config file",
color="red",
)
sys.exit(1)
if args.print_deps_only:
print(f"# Dependencies for {args.template or args.config or image_name}")
normal_deps, special_deps = get_provider_dependencies(build_config.distribution_spec.providers)
normal_deps, special_deps = get_provider_dependencies(build_config)
normal_deps += SERVER_DEPENDENCIES
print(f"uv pip install {' '.join(normal_deps)}")
for special_dep in special_deps:
@ -274,9 +267,10 @@ def _generate_run_config(
image_name=image_name,
apis=apis,
providers={},
external_providers_dir=build_config.external_providers_dir if build_config.external_providers_dir else None,
)
# build providers dict
provider_registry = get_provider_registry()
provider_registry = get_provider_registry(build_config)
for api in apis:
run_config.providers[api] = []
provider_types = build_config.distribution_spec.providers[api]
@ -290,8 +284,22 @@ def _generate_run_config(
if p.deprecation_error:
raise InvalidProviderError(p.deprecation_error)
config_type = instantiate_class_type(provider_registry[Api(api)][provider_type].config_class)
if hasattr(config_type, "sample_run_config"):
try:
config_type = instantiate_class_type(provider_registry[Api(api)][provider_type].config_class)
except ModuleNotFoundError:
# HACK ALERT:
# This code executes after building is done, the import cannot work since the
# package is either available in the venv or container - not available on the host.
# TODO: use a "is_external" flag in ProviderSpec to check if the provider is
# external
cprint(
f"Failed to import provider {provider_type} for API {api} - assuming it's external, skipping",
color="yellow",
)
# Set config_type to None to avoid UnboundLocalError
config_type = None
if config_type is not None and hasattr(config_type, "sample_run_config"):
config = config_type.sample_run_config(__distro_dir__=f"~/.llama/distributions/{image_name}")
else:
config = {}
@ -323,6 +331,7 @@ def _run_stack_build_command_from_build_config(
template_name: Optional[str] = None,
config_path: Optional[str] = None,
) -> str:
image_name = image_name or build_config.image_name
if build_config.image_type == LlamaStackImageType.CONTAINER.value:
if template_name:
image_name = f"distribution-{template_name}"

View file

@ -7,16 +7,16 @@
import importlib.resources
import logging
from pathlib import Path
from typing import Dict, List
from pydantic import BaseModel
from termcolor import cprint
from llama_stack.distribution.datatypes import BuildConfig, Provider
from llama_stack.distribution.datatypes import BuildConfig
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.utils.exec import run_command
from llama_stack.distribution.utils.image_types import LlamaStackImageType
from llama_stack.providers.datatypes import Api
from llama_stack.templates.template import DistributionTemplate
log = logging.getLogger(__name__)
@ -37,19 +37,24 @@ class ApiInput(BaseModel):
def get_provider_dependencies(
config_providers: Dict[str, List[Provider]],
config: BuildConfig | DistributionTemplate,
) -> tuple[list[str], list[str]]:
"""Get normal and special dependencies from provider configuration."""
all_providers = get_provider_registry()
# Extract providers based on config type
if isinstance(config, DistributionTemplate):
providers = config.providers
elif isinstance(config, BuildConfig):
providers = config.distribution_spec.providers
deps = []
registry = get_provider_registry(config)
for api_str, provider_or_providers in config_providers.items():
providers_for_api = all_providers[Api(api_str)]
for api_str, provider_or_providers in providers.items():
providers_for_api = registry[Api(api_str)]
providers = provider_or_providers if isinstance(provider_or_providers, list) else [provider_or_providers]
for provider in providers:
# Providers from BuildConfig and RunConfig are subtly different  not great
# Providers from BuildConfig and RunConfig are subtly different not great
provider_type = provider if isinstance(provider, str) else provider.provider_type
if provider_type not in providers_for_api:
@ -71,8 +76,8 @@ def get_provider_dependencies(
return list(set(normal_deps)), list(set(special_deps))
def print_pip_install_help(providers: Dict[str, List[Provider]]):
normal_deps, special_deps = get_provider_dependencies(providers)
def print_pip_install_help(config: BuildConfig):
normal_deps, special_deps = get_provider_dependencies(config)
cprint(
f"Please install needed dependencies using the following commands:\n\nuv pip install {' '.join(normal_deps)}",
@ -91,7 +96,7 @@ def build_image(
):
container_base = build_config.distribution_spec.container_image or "python:3.10-slim"
normal_deps, special_deps = get_provider_dependencies(build_config.distribution_spec.providers)
normal_deps, special_deps = get_provider_dependencies(build_config)
normal_deps += SERVER_DEPENDENCIES
if build_config.image_type == LlamaStackImageType.CONTAINER.value:

View file

@ -90,7 +90,7 @@ WORKDIR /app
RUN apt-get update && apt-get install -y \
iputils-ping net-tools iproute2 dnsutils telnet \
curl wget telnet \
curl wget telnet git\
procps psmisc lsof \
traceroute \
bubblewrap \

View file

@ -326,3 +326,12 @@ class BuildConfig(BaseModel):
default="conda",
description="Type of package to build (conda | container | venv)",
)
image_name: Optional[str] = Field(
default=None,
description="Name of the distribution to build",
)
external_providers_dir: Optional[str] = Field(
default=None,
description="Path to directory containing external provider implementations. The providers packages will be resolved from this directory. "
"pip_packages MUST contain the provider package name.",
)

View file

@ -12,7 +12,6 @@ from typing import Any, Dict, List
import yaml
from pydantic import BaseModel
from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import (
AdapterSpec,
@ -97,7 +96,9 @@ def _load_inline_provider_spec(spec_data: Dict[str, Any], api: Api, provider_nam
return spec
def get_provider_registry(config: StackRunConfig | None = None) -> Dict[Api, Dict[str, ProviderSpec]]:
def get_provider_registry(
config=None,
) -> Dict[Api, Dict[str, ProviderSpec]]:
"""Get the provider registry, optionally including external providers.
This function loads both built-in providers and external providers from YAML files.
@ -122,7 +123,7 @@ def get_provider_registry(config: StackRunConfig | None = None) -> Dict[Api, Dic
llama-guard.yaml
Args:
config: Optional StackRunConfig containing the external providers directory path
config: Optional object containing the external providers directory path
Returns:
A dictionary mapping APIs to their available providers
@ -142,7 +143,8 @@ def get_provider_registry(config: StackRunConfig | None = None) -> Dict[Api, Dic
except ImportError as e:
logger.warning(f"Failed to import module {name}: {e}")
if config and config.external_providers_dir:
# Check if config has the external_providers_dir attribute
if config and hasattr(config, "external_providers_dir") and config.external_providers_dir:
external_providers_dir = os.path.abspath(config.external_providers_dir)
if not os.path.exists(external_providers_dir):
raise FileNotFoundError(f"External providers directory not found: {external_providers_dir}")