feat: allow building distro with external providers (#1967)

# What does this PR do?

We can now build a distribution that includes external providers.
Closes: https://github.com/meta-llama/llama-stack/issues/1948

## Test Plan

Build a distro with an external provider following the doc instructions.

[//]: # (## Documentation)

Added.

Rendered:


![Screenshot 2025-04-18 at 11 26
39](https://github.com/user-attachments/assets/afcf3d50-8d30-48c3-8d24-06a4b3662881)

Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
Sébastien Han 2025-04-18 17:18:28 +02:00 committed by GitHub
parent c4570bcb48
commit 94f83382eb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 137 additions and 69 deletions

View file

@ -7,16 +7,16 @@
import importlib.resources
import logging
from pathlib import Path
from typing import Dict, List
from pydantic import BaseModel
from termcolor import cprint
from llama_stack.distribution.datatypes import BuildConfig, Provider
from llama_stack.distribution.datatypes import BuildConfig
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.utils.exec import run_command
from llama_stack.distribution.utils.image_types import LlamaStackImageType
from llama_stack.providers.datatypes import Api
from llama_stack.templates.template import DistributionTemplate
log = logging.getLogger(__name__)
@ -37,19 +37,24 @@ class ApiInput(BaseModel):
def get_provider_dependencies(
config_providers: Dict[str, List[Provider]],
config: BuildConfig | DistributionTemplate,
) -> tuple[list[str], list[str]]:
"""Get normal and special dependencies from provider configuration."""
all_providers = get_provider_registry()
# Extract providers based on config type
if isinstance(config, DistributionTemplate):
providers = config.providers
elif isinstance(config, BuildConfig):
providers = config.distribution_spec.providers
deps = []
registry = get_provider_registry(config)
for api_str, provider_or_providers in config_providers.items():
providers_for_api = all_providers[Api(api_str)]
for api_str, provider_or_providers in providers.items():
providers_for_api = registry[Api(api_str)]
providers = provider_or_providers if isinstance(provider_or_providers, list) else [provider_or_providers]
for provider in providers:
# Providers from BuildConfig and RunConfig are subtly different  not great
# Providers from BuildConfig and RunConfig are subtly different not great
provider_type = provider if isinstance(provider, str) else provider.provider_type
if provider_type not in providers_for_api:
@ -71,8 +76,8 @@ def get_provider_dependencies(
return list(set(normal_deps)), list(set(special_deps))
def print_pip_install_help(providers: Dict[str, List[Provider]]):
normal_deps, special_deps = get_provider_dependencies(providers)
def print_pip_install_help(config: BuildConfig):
normal_deps, special_deps = get_provider_dependencies(config)
cprint(
f"Please install needed dependencies using the following commands:\n\nuv pip install {' '.join(normal_deps)}",
@ -91,7 +96,7 @@ def build_image(
):
container_base = build_config.distribution_spec.container_image or "python:3.10-slim"
normal_deps, special_deps = get_provider_dependencies(build_config.distribution_spec.providers)
normal_deps, special_deps = get_provider_dependencies(build_config)
normal_deps += SERVER_DEPENDENCIES
if build_config.image_type == LlamaStackImageType.CONTAINER.value: