pre-commit

This commit is contained in:
Dalton Flanagan 2024-11-08 12:51:24 -05:00
parent c07919aa36
commit 731644e111
2 changed files with 11 additions and 4 deletions

View file

@ -47,7 +47,10 @@ class ApiInput(BaseModel):
api: Api api: Api
provider: str provider: str
def get_provider_dependencies(config_providers: Dict[str, List[Provider]]) -> tuple[list[str], list[str]]:
def get_provider_dependencies(
config_providers: Dict[str, List[Provider]]
) -> tuple[list[str], list[str]]:
"""Get normal and special dependencies from provider configuration.""" """Get normal and special dependencies from provider configuration."""
all_providers = get_provider_registry() all_providers = get_provider_registry()
deps = [] deps = []
@ -63,7 +66,9 @@ def get_provider_dependencies(config_providers: Dict[str, List[Provider]]) -> tu
for provider in providers: for provider in providers:
# Providers from BuildConfig and RunConfig are subtly different  not great # Providers from BuildConfig and RunConfig are subtly different  not great
provider_type = provider if isinstance(provider, str) else provider.provider_type provider_type = (
provider if isinstance(provider, str) else provider.provider_type
)
if provider_type not in providers_for_api: if provider_type not in providers_for_api:
raise ValueError( raise ValueError(
@ -104,7 +109,9 @@ def build_image(build_config: BuildConfig, build_file_path: Path):
) )
# extend package dependencies based on providers spec # extend package dependencies based on providers spec
normal_deps, special_deps = get_provider_dependencies(build_config.distribution_spec.providers) normal_deps, special_deps = get_provider_dependencies(
build_config.distribution_spec.providers
)
package_deps.pip_packages.extend(normal_deps) package_deps.pip_packages.extend(normal_deps)
package_deps.pip_packages.extend(special_deps) package_deps.pip_packages.extend(special_deps)

View file

@ -13,11 +13,11 @@ from typing import Any, Dict, List, Optional
import yaml import yaml
from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.build import print_pip_install_help
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.distribution import get_provider_registry from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.request_headers import set_request_provider_data from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.resolver import resolve_impls from llama_stack.distribution.resolver import resolve_impls
from llama_stack.distribution.build import print_pip_install_help
from llama_stack.distribution.store import CachedDiskDistributionRegistry from llama_stack.distribution.store import CachedDiskDistributionRegistry
from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig