feat: add deps dynamically based on metastore config (#2405)

# What does this PR do?


## Test Plan
changed metastore in one of the templates, rerun distro gen, observe
change in build.yaml
This commit is contained in:
ehhuang 2025-06-05 14:07:25 -07:00 committed by GitHub
parent 92b59a3377
commit 446893f791
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 36 additions and 18 deletions

View file

@ -43,23 +43,12 @@ def get_provider_dependencies(
config: BuildConfig | DistributionTemplate,
) -> tuple[list[str], list[str]]:
"""Get normal and special dependencies from provider configuration."""
# Extract providers based on config type
if isinstance(config, DistributionTemplate):
providers = config.providers
config = config.build_config()
providers = config.distribution_spec.providers
additional_pip_packages = config.additional_pip_packages
# TODO: This is a hack to get the dependencies for internal APIs into build
# We should have a better way to do this by formalizing the concept of "internal" APIs
# and providers, with a way to specify dependencies for them.
run_configs = config.run_configs
additional_pip_packages: list[str] = []
if run_configs:
for run_config in run_configs.values():
run_config_ = run_config.run_config(name="", providers={}, container_image=None)
if run_config_.inference_store:
additional_pip_packages.extend(run_config_.inference_store.pip_packages)
elif isinstance(config, BuildConfig):
providers = config.distribution_spec.providers
additional_pip_packages = config.additional_pip_packages
deps = []
registry = get_provider_registry(config)
for api_str, provider_or_providers in providers.items():
@ -87,8 +76,7 @@ def get_provider_dependencies(
else:
normal_deps.append(package)
if additional_pip_packages:
normal_deps.extend(additional_pip_packages)
normal_deps.extend(additional_pip_packages or [])
return list(set(normal_deps)), list(set(special_deps))