Add pip install helper for test and direct scenarios (#404)

* initial branch commit

* pip install helptext

* remove print

* pre-commit
This commit is contained in:
Dalton Flanagan 2024-11-08 15:18:21 -05:00 committed by GitHub
parent d800a16acd
commit 5625aef48a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 58 additions and 23 deletions

View file

@ -48,18 +48,14 @@ class ApiInput(BaseModel):
provider: str provider: str
def build_image(build_config: BuildConfig, build_file_path: Path): def get_provider_dependencies(
package_deps = Dependencies( config_providers: Dict[str, List[Provider]]
docker_image=build_config.distribution_spec.docker_image or "python:3.10-slim", ) -> tuple[list[str], list[str]]:
pip_packages=SERVER_DEPENDENCIES, """Get normal and special dependencies from provider configuration."""
)
# extend package dependencies based on providers spec
all_providers = get_provider_registry() all_providers = get_provider_registry()
for ( deps = []
api_str,
provider_or_providers, for api_str, provider_or_providers in config_providers.items():
) in build_config.distribution_spec.providers.items():
providers_for_api = all_providers[Api(api_str)] providers_for_api = all_providers[Api(api_str)]
providers = ( providers = (
@ -69,25 +65,55 @@ def build_image(build_config: BuildConfig, build_file_path: Path):
) )
for provider in providers: for provider in providers:
if provider not in providers_for_api: # Providers from BuildConfig and RunConfig are subtly different  not great
provider_type = (
provider if isinstance(provider, str) else provider.provider_type
)
if provider_type not in providers_for_api:
raise ValueError( raise ValueError(
f"Provider `{provider}` is not available for API `{api_str}`" f"Provider `{provider}` is not available for API `{api_str}`"
) )
provider_spec = providers_for_api[provider] provider_spec = providers_for_api[provider_type]
package_deps.pip_packages.extend(provider_spec.pip_packages) deps.extend(provider_spec.pip_packages)
if provider_spec.docker_image: if provider_spec.docker_image:
raise ValueError("A stack's dependencies cannot have a docker image") raise ValueError("A stack's dependencies cannot have a docker image")
normal_deps = []
special_deps = [] special_deps = []
deps = [] for package in deps:
for package in package_deps.pip_packages:
if "--no-deps" in package or "--index-url" in package: if "--no-deps" in package or "--index-url" in package:
special_deps.append(package) special_deps.append(package)
else: else:
deps.append(package) normal_deps.append(package)
deps = list(set(deps))
special_deps = list(set(special_deps)) return list(set(normal_deps)), list(set(special_deps))
def print_pip_install_help(providers: Dict[str, List[Provider]]):
normal_deps, special_deps = get_provider_dependencies(providers)
print(
f"Please install needed dependencies using the following commands:\n\n\tpip install {' '.join(normal_deps)}"
)
for special_dep in special_deps:
print(f"\tpip install {special_dep}")
print()
def build_image(build_config: BuildConfig, build_file_path: Path):
package_deps = Dependencies(
docker_image=build_config.distribution_spec.docker_image or "python:3.10-slim",
pip_packages=SERVER_DEPENDENCIES,
)
# extend package dependencies based on providers spec
normal_deps, special_deps = get_provider_dependencies(
build_config.distribution_spec.providers
)
package_deps.pip_packages.extend(normal_deps)
package_deps.pip_packages.extend(special_deps)
if build_config.image_type == ImageType.docker.value: if build_config.image_type == ImageType.docker.value:
script = pkg_resources.resource_filename( script = pkg_resources.resource_filename(
@ -99,7 +125,7 @@ def build_image(build_config: BuildConfig, build_file_path: Path):
package_deps.docker_image, package_deps.docker_image,
str(build_file_path), str(build_file_path),
str(BUILDS_BASE_DIR / ImageType.docker.value), str(BUILDS_BASE_DIR / ImageType.docker.value),
" ".join(deps), " ".join(normal_deps),
] ]
else: else:
script = pkg_resources.resource_filename( script = pkg_resources.resource_filename(
@ -109,7 +135,7 @@ def build_image(build_config: BuildConfig, build_file_path: Path):
script, script,
build_config.name, build_config.name,
str(build_file_path), str(build_file_path),
" ".join(deps), " ".join(normal_deps),
] ]
if special_deps: if special_deps:

View file

@ -13,6 +13,7 @@ from typing import Any, Dict, List, Optional
import yaml import yaml
from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.build import print_pip_install_help
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.distribution import get_provider_registry from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.request_headers import set_request_provider_data from llama_stack.distribution.request_headers import set_request_provider_data
@ -37,7 +38,11 @@ async def resolve_impls_for_test_v2(
sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db") sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
dist_kvstore = await kvstore_impl(SqliteKVStoreConfig(db_path=sqlite_file.name)) dist_kvstore = await kvstore_impl(SqliteKVStoreConfig(db_path=sqlite_file.name))
dist_registry = CachedDiskDistributionRegistry(dist_kvstore) dist_registry = CachedDiskDistributionRegistry(dist_kvstore)
impls = await resolve_impls(run_config, get_provider_registry(), dist_registry) try:
impls = await resolve_impls(run_config, get_provider_registry(), dist_registry)
except ModuleNotFoundError as e:
print_pip_install_help(providers)
raise e
if provider_data: if provider_data:
set_request_provider_data( set_request_provider_data(
@ -66,7 +71,11 @@ async def resolve_impls_for_test(api: Api, deps: List[Api] = None):
providers=chosen, providers=chosen,
) )
run_config = parse_and_maybe_upgrade_config(run_config) run_config = parse_and_maybe_upgrade_config(run_config)
impls = await resolve_impls(run_config, get_provider_registry()) try:
impls = await resolve_impls(run_config, get_provider_registry())
except ModuleNotFoundError as e:
print_pip_install_help(providers)
raise e
if "provider_data" in config_dict: if "provider_data" in config_dict:
provider_id = chosen[api.value][0].provider_id provider_id = chosen[api.value][0].provider_id