forked from phoenix-oss/llama-stack-mirror
feat: allow building distro with external providers (#1967)
# What does this PR do? We can now build a distribution that includes external providers. Closes: https://github.com/meta-llama/llama-stack/issues/1948 ## Test Plan Build a distro with an external provider following the doc instructions. [//]: # (## Documentation) Added. Rendered:  Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
parent
c4570bcb48
commit
94f83382eb
11 changed files with 137 additions and 69 deletions
|
@ -210,16 +210,9 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
|||
)
|
||||
sys.exit(1)
|
||||
|
||||
if build_config.image_type == LlamaStackImageType.CONTAINER.value and not args.image_name:
|
||||
cprint(
|
||||
"Please specify --image-name when building a container from a config file",
|
||||
color="red",
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
if args.print_deps_only:
|
||||
print(f"# Dependencies for {args.template or args.config or image_name}")
|
||||
normal_deps, special_deps = get_provider_dependencies(build_config.distribution_spec.providers)
|
||||
normal_deps, special_deps = get_provider_dependencies(build_config)
|
||||
normal_deps += SERVER_DEPENDENCIES
|
||||
print(f"uv pip install {' '.join(normal_deps)}")
|
||||
for special_dep in special_deps:
|
||||
|
@ -274,9 +267,10 @@ def _generate_run_config(
|
|||
image_name=image_name,
|
||||
apis=apis,
|
||||
providers={},
|
||||
external_providers_dir=build_config.external_providers_dir if build_config.external_providers_dir else None,
|
||||
)
|
||||
# build providers dict
|
||||
provider_registry = get_provider_registry()
|
||||
provider_registry = get_provider_registry(build_config)
|
||||
for api in apis:
|
||||
run_config.providers[api] = []
|
||||
provider_types = build_config.distribution_spec.providers[api]
|
||||
|
@ -290,8 +284,22 @@ def _generate_run_config(
|
|||
if p.deprecation_error:
|
||||
raise InvalidProviderError(p.deprecation_error)
|
||||
|
||||
config_type = instantiate_class_type(provider_registry[Api(api)][provider_type].config_class)
|
||||
if hasattr(config_type, "sample_run_config"):
|
||||
try:
|
||||
config_type = instantiate_class_type(provider_registry[Api(api)][provider_type].config_class)
|
||||
except ModuleNotFoundError:
|
||||
# HACK ALERT:
|
||||
# This code executes after building is done, the import cannot work since the
|
||||
# package is either available in the venv or container - not available on the host.
|
||||
# TODO: use a "is_external" flag in ProviderSpec to check if the provider is
|
||||
# external
|
||||
cprint(
|
||||
f"Failed to import provider {provider_type} for API {api} - assuming it's external, skipping",
|
||||
color="yellow",
|
||||
)
|
||||
# Set config_type to None to avoid UnboundLocalError
|
||||
config_type = None
|
||||
|
||||
if config_type is not None and hasattr(config_type, "sample_run_config"):
|
||||
config = config_type.sample_run_config(__distro_dir__=f"~/.llama/distributions/{image_name}")
|
||||
else:
|
||||
config = {}
|
||||
|
@ -323,6 +331,7 @@ def _run_stack_build_command_from_build_config(
|
|||
template_name: Optional[str] = None,
|
||||
config_path: Optional[str] = None,
|
||||
) -> str:
|
||||
image_name = image_name or build_config.image_name
|
||||
if build_config.image_type == LlamaStackImageType.CONTAINER.value:
|
||||
if template_name:
|
||||
image_name = f"distribution-{template_name}"
|
||||
|
|
|
@ -7,16 +7,16 @@
|
|||
import importlib.resources
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
from pydantic import BaseModel
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.distribution.datatypes import BuildConfig, Provider
|
||||
from llama_stack.distribution.datatypes import BuildConfig
|
||||
from llama_stack.distribution.distribution import get_provider_registry
|
||||
from llama_stack.distribution.utils.exec import run_command
|
||||
from llama_stack.distribution.utils.image_types import LlamaStackImageType
|
||||
from llama_stack.providers.datatypes import Api
|
||||
from llama_stack.templates.template import DistributionTemplate
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
@ -37,19 +37,24 @@ class ApiInput(BaseModel):
|
|||
|
||||
|
||||
def get_provider_dependencies(
|
||||
config_providers: Dict[str, List[Provider]],
|
||||
config: BuildConfig | DistributionTemplate,
|
||||
) -> tuple[list[str], list[str]]:
|
||||
"""Get normal and special dependencies from provider configuration."""
|
||||
all_providers = get_provider_registry()
|
||||
# Extract providers based on config type
|
||||
if isinstance(config, DistributionTemplate):
|
||||
providers = config.providers
|
||||
elif isinstance(config, BuildConfig):
|
||||
providers = config.distribution_spec.providers
|
||||
deps = []
|
||||
registry = get_provider_registry(config)
|
||||
|
||||
for api_str, provider_or_providers in config_providers.items():
|
||||
providers_for_api = all_providers[Api(api_str)]
|
||||
for api_str, provider_or_providers in providers.items():
|
||||
providers_for_api = registry[Api(api_str)]
|
||||
|
||||
providers = provider_or_providers if isinstance(provider_or_providers, list) else [provider_or_providers]
|
||||
|
||||
for provider in providers:
|
||||
# Providers from BuildConfig and RunConfig are subtly different – not great
|
||||
# Providers from BuildConfig and RunConfig are subtly different – not great
|
||||
provider_type = provider if isinstance(provider, str) else provider.provider_type
|
||||
|
||||
if provider_type not in providers_for_api:
|
||||
|
@ -71,8 +76,8 @@ def get_provider_dependencies(
|
|||
return list(set(normal_deps)), list(set(special_deps))
|
||||
|
||||
|
||||
def print_pip_install_help(providers: Dict[str, List[Provider]]):
|
||||
normal_deps, special_deps = get_provider_dependencies(providers)
|
||||
def print_pip_install_help(config: BuildConfig):
|
||||
normal_deps, special_deps = get_provider_dependencies(config)
|
||||
|
||||
cprint(
|
||||
f"Please install needed dependencies using the following commands:\n\nuv pip install {' '.join(normal_deps)}",
|
||||
|
@ -91,7 +96,7 @@ def build_image(
|
|||
):
|
||||
container_base = build_config.distribution_spec.container_image or "python:3.10-slim"
|
||||
|
||||
normal_deps, special_deps = get_provider_dependencies(build_config.distribution_spec.providers)
|
||||
normal_deps, special_deps = get_provider_dependencies(build_config)
|
||||
normal_deps += SERVER_DEPENDENCIES
|
||||
|
||||
if build_config.image_type == LlamaStackImageType.CONTAINER.value:
|
||||
|
|
|
@ -90,7 +90,7 @@ WORKDIR /app
|
|||
|
||||
RUN apt-get update && apt-get install -y \
|
||||
iputils-ping net-tools iproute2 dnsutils telnet \
|
||||
curl wget telnet \
|
||||
curl wget telnet git\
|
||||
procps psmisc lsof \
|
||||
traceroute \
|
||||
bubblewrap \
|
||||
|
|
|
@ -326,3 +326,12 @@ class BuildConfig(BaseModel):
|
|||
default="conda",
|
||||
description="Type of package to build (conda | container | venv)",
|
||||
)
|
||||
image_name: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Name of the distribution to build",
|
||||
)
|
||||
external_providers_dir: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Path to directory containing external provider implementations. The providers packages will be resolved from this directory. "
|
||||
"pip_packages MUST contain the provider package name.",
|
||||
)
|
||||
|
|
|
@ -12,7 +12,6 @@ from typing import Any, Dict, List
|
|||
import yaml
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.distribution.datatypes import StackRunConfig
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import (
|
||||
AdapterSpec,
|
||||
|
@ -97,7 +96,9 @@ def _load_inline_provider_spec(spec_data: Dict[str, Any], api: Api, provider_nam
|
|||
return spec
|
||||
|
||||
|
||||
def get_provider_registry(config: StackRunConfig | None = None) -> Dict[Api, Dict[str, ProviderSpec]]:
|
||||
def get_provider_registry(
|
||||
config=None,
|
||||
) -> Dict[Api, Dict[str, ProviderSpec]]:
|
||||
"""Get the provider registry, optionally including external providers.
|
||||
|
||||
This function loads both built-in providers and external providers from YAML files.
|
||||
|
@ -122,7 +123,7 @@ def get_provider_registry(config: StackRunConfig | None = None) -> Dict[Api, Dic
|
|||
llama-guard.yaml
|
||||
|
||||
Args:
|
||||
config: Optional StackRunConfig containing the external providers directory path
|
||||
config: Optional object containing the external providers directory path
|
||||
|
||||
Returns:
|
||||
A dictionary mapping APIs to their available providers
|
||||
|
@ -142,7 +143,8 @@ def get_provider_registry(config: StackRunConfig | None = None) -> Dict[Api, Dic
|
|||
except ImportError as e:
|
||||
logger.warning(f"Failed to import module {name}: {e}")
|
||||
|
||||
if config and config.external_providers_dir:
|
||||
# Check if config has the external_providers_dir attribute
|
||||
if config and hasattr(config, "external_providers_dir") and config.external_providers_dir:
|
||||
external_providers_dir = os.path.abspath(config.external_providers_dir)
|
||||
if not os.path.exists(external_providers_dir):
|
||||
raise FileNotFoundError(f"External providers directory not found: {external_providers_dir}")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue