# What does this PR do?


## Test Plan
This commit is contained in:
Eric Huang 2025-06-05 11:59:11 -07:00
parent 3251b44d8a
commit 460c67be62
6 changed files with 36 additions and 18 deletions

View file

@ -43,23 +43,12 @@ def get_provider_dependencies(
config: BuildConfig | DistributionTemplate,
) -> tuple[list[str], list[str]]:
"""Get normal and special dependencies from provider configuration."""
# Extract providers based on config type
if isinstance(config, DistributionTemplate):
providers = config.providers
config = config.build_config()
providers = config.distribution_spec.providers
additional_pip_packages = config.additional_pip_packages
# TODO: This is a hack to get the dependencies for internal APIs into build
# We should have a better way to do this by formalizing the concept of "internal" APIs
# and providers, with a way to specify dependencies for them.
run_configs = config.run_configs
additional_pip_packages: list[str] = []
if run_configs:
for run_config in run_configs.values():
run_config_ = run_config.run_config(name="", providers={}, container_image=None)
if run_config_.inference_store:
additional_pip_packages.extend(run_config_.inference_store.pip_packages)
elif isinstance(config, BuildConfig):
providers = config.distribution_spec.providers
additional_pip_packages = config.additional_pip_packages
deps = []
registry = get_provider_registry(config)
for api_str, provider_or_providers in providers.items():
@ -87,8 +76,7 @@ def get_provider_dependencies(
else:
normal_deps.append(package)
if additional_pip_packages:
normal_deps.extend(additional_pip_packages)
normal_deps.extend(additional_pip_packages or [])
return list(set(normal_deps)), list(set(special_deps))