From 5625aef48a44c8533c78b97607e09851c4b7266a Mon Sep 17 00:00:00 2001 From: Dalton Flanagan <6599399+dltn@users.noreply.github.com> Date: Fri, 8 Nov 2024 15:18:21 -0500 Subject: [PATCH] Add pip install helper for test and direct scenarios (#404) * initial branch commit * pip install helptext * remove print * pre-commit --- llama_stack/distribution/build.py | 68 +++++++++++++++++-------- llama_stack/providers/tests/resolver.py | 13 ++++- 2 files changed, 58 insertions(+), 23 deletions(-) diff --git a/llama_stack/distribution/build.py b/llama_stack/distribution/build.py index 0a989d2e4..34e953656 100644 --- a/llama_stack/distribution/build.py +++ b/llama_stack/distribution/build.py @@ -48,18 +48,14 @@ class ApiInput(BaseModel): provider: str -def build_image(build_config: BuildConfig, build_file_path: Path): - package_deps = Dependencies( - docker_image=build_config.distribution_spec.docker_image or "python:3.10-slim", - pip_packages=SERVER_DEPENDENCIES, - ) - - # extend package dependencies based on providers spec +def get_provider_dependencies( + config_providers: Dict[str, List[Provider]] +) -> tuple[list[str], list[str]]: + """Get normal and special dependencies from provider configuration.""" all_providers = get_provider_registry() - for ( - api_str, - provider_or_providers, - ) in build_config.distribution_spec.providers.items(): + deps = [] + + for api_str, provider_or_providers in config_providers.items(): providers_for_api = all_providers[Api(api_str)] providers = ( @@ -69,25 +65,55 @@ def build_image(build_config: BuildConfig, build_file_path: Path): ) for provider in providers: - if provider not in providers_for_api: + # Providers from BuildConfig and RunConfig are subtly different – not great + provider_type = ( + provider if isinstance(provider, str) else provider.provider_type + ) + + if provider_type not in providers_for_api: raise ValueError( f"Provider `{provider}` is not available for API `{api_str}`" ) - provider_spec = providers_for_api[provider] - package_deps.pip_packages.extend(provider_spec.pip_packages) + provider_spec = providers_for_api[provider_type] + deps.extend(provider_spec.pip_packages) if provider_spec.docker_image: raise ValueError("A stack's dependencies cannot have a docker image") + normal_deps = [] special_deps = [] - deps = [] - for package in package_deps.pip_packages: + for package in deps: if "--no-deps" in package or "--index-url" in package: special_deps.append(package) else: - deps.append(package) - deps = list(set(deps)) - special_deps = list(set(special_deps)) + normal_deps.append(package) + + return list(set(normal_deps)), list(set(special_deps)) + + +def print_pip_install_help(providers: Dict[str, List[Provider]]): + normal_deps, special_deps = get_provider_dependencies(providers) + + print( + f"Please install needed dependencies using the following commands:\n\n\tpip install {' '.join(normal_deps)}" + ) + for special_dep in special_deps: + print(f"\tpip install {special_dep}") + print() + + +def build_image(build_config: BuildConfig, build_file_path: Path): + package_deps = Dependencies( + docker_image=build_config.distribution_spec.docker_image or "python:3.10-slim", + pip_packages=SERVER_DEPENDENCIES, + ) + + # extend package dependencies based on providers spec + normal_deps, special_deps = get_provider_dependencies( + build_config.distribution_spec.providers + ) + package_deps.pip_packages.extend(normal_deps) + package_deps.pip_packages.extend(special_deps) if build_config.image_type == ImageType.docker.value: script = pkg_resources.resource_filename( @@ -99,7 +125,7 @@ def build_image(build_config: BuildConfig, build_file_path: Path): package_deps.docker_image, str(build_file_path), str(BUILDS_BASE_DIR / ImageType.docker.value), - " ".join(deps), + " ".join(normal_deps), ] else: script = pkg_resources.resource_filename( @@ -109,7 +135,7 @@ def build_image(build_config: BuildConfig, build_file_path: Path): script, build_config.name, str(build_file_path), - " ".join(deps), + " ".join(normal_deps), ] if special_deps: diff --git a/llama_stack/providers/tests/resolver.py b/llama_stack/providers/tests/resolver.py index 16c2a32af..09d879c80 100644 --- a/llama_stack/providers/tests/resolver.py +++ b/llama_stack/providers/tests/resolver.py @@ -13,6 +13,7 @@ from typing import Any, Dict, List, Optional import yaml from llama_stack.distribution.datatypes import * # noqa: F403 +from llama_stack.distribution.build import print_pip_install_help from llama_stack.distribution.configure import parse_and_maybe_upgrade_config from llama_stack.distribution.distribution import get_provider_registry from llama_stack.distribution.request_headers import set_request_provider_data @@ -37,7 +38,11 @@ async def resolve_impls_for_test_v2( sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db") dist_kvstore = await kvstore_impl(SqliteKVStoreConfig(db_path=sqlite_file.name)) dist_registry = CachedDiskDistributionRegistry(dist_kvstore) - impls = await resolve_impls(run_config, get_provider_registry(), dist_registry) + try: + impls = await resolve_impls(run_config, get_provider_registry(), dist_registry) + except ModuleNotFoundError as e: + print_pip_install_help(providers) + raise e if provider_data: set_request_provider_data( @@ -66,7 +71,11 @@ async def resolve_impls_for_test(api: Api, deps: List[Api] = None): providers=chosen, ) run_config = parse_and_maybe_upgrade_config(run_config) - impls = await resolve_impls(run_config, get_provider_registry()) + try: + impls = await resolve_impls(run_config, get_provider_registry()) + except ModuleNotFoundError as e: + print_pip_install_help(providers) + raise e if "provider_data" in config_dict: provider_id = chosen[api.value][0].provider_id