From e9f615058820ec0a68b4d238b5cdc6d80cde3c36 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Wed, 2 Oct 2024 21:31:09 -0700 Subject: [PATCH] A bit cleanup to avoid breakages --- llama_stack/cli/stack/build.py | 36 ++++++++---------------- llama_stack/distribution/distribution.py | 13 ++++----- 2 files changed, 17 insertions(+), 32 deletions(-) diff --git a/llama_stack/cli/stack/build.py b/llama_stack/cli/stack/build.py index ab6861482..d502e4c84 100644 --- a/llama_stack/cli/stack/build.py +++ b/llama_stack/cli/stack/build.py @@ -114,10 +114,10 @@ class StackBuild(Subcommand): # save build.yaml spec for building same distribution again if build_config.image_type == ImageType.docker.value: # docker needs build file to be in the llama-stack repo dir to be able to copy over to the image - llama_stack_path = Path(os.path.abspath(__file__)).parent.parent.parent.parent - build_dir = ( - llama_stack_path / "tmp/configs/" - ) + llama_stack_path = Path( + os.path.abspath(__file__) + ).parent.parent.parent.parent + build_dir = llama_stack_path / "tmp/configs/" else: build_dir = DISTRIBS_BASE_DIR / f"llamastack-{build_config.name}" @@ -173,12 +173,7 @@ class StackBuild(Subcommand): def _run_stack_build_command(self, args: argparse.Namespace) -> None: import yaml - from llama_stack.distribution.distribution import ( - Api, - get_provider_registry, - builtin_automatically_routed_apis, - ) - from llama_stack.distribution.utils.dynamic import instantiate_class_type + from llama_stack.distribution.distribution import get_provider_registry from prompt_toolkit import prompt from prompt_toolkit.validation import Validator from termcolor import cprint @@ -212,7 +207,10 @@ class StackBuild(Subcommand): if args.name: maybe_build_config = self._get_build_config_from_name(args) if maybe_build_config: - cprint(f"Building from existing build config for {args.name} in {str(maybe_build_config)}...", "green") + cprint( + f"Building from existing build config for {args.name} in {str(maybe_build_config)}...", + "green", + ) with open(maybe_build_config, "r") as f: build_config = BuildConfig(**yaml.safe_load(f)) self._run_stack_build_command_from_build_config(build_config) @@ -240,24 +238,12 @@ class StackBuild(Subcommand): ) cprint( - f"\n Llama Stack is composed of several APIs working together. Let's configure the providers (implementations) you want to use for these APIs.", + "\n Llama Stack is composed of several APIs working together. Let's configure the providers (implementations) you want to use for these APIs.", color="green", ) providers = dict() - all_providers = get_provider_registry() - routing_table_apis = set( - x.routing_table_api for x in builtin_automatically_routed_apis() - ) - - for api in Api: - if api in routing_table_apis: - continue - if api == Api.inspect: - continue - - providers_for_api = all_providers[api] - + for api, providers_for_api in get_provider_registry().items(): api_provider = prompt( "> Enter provider for the {} API: (default=meta-reference): ".format( api.value diff --git a/llama_stack/distribution/distribution.py b/llama_stack/distribution/distribution.py index eea066d1f..999646cc0 100644 --- a/llama_stack/distribution/distribution.py +++ b/llama_stack/distribution/distribution.py @@ -38,17 +38,16 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]: ] -def get_provider_registry() -> Dict[Api, Dict[str, ProviderSpec]]: - ret = {} +def providable_apis() -> List[Api]: routing_table_apis = set( x.routing_table_api for x in builtin_automatically_routed_apis() ) - for api in stack_apis(): - if api in routing_table_apis: - continue - if api == Api.inspect: - continue + return [api for api in Api if api not in routing_table_apis and api != Api.inspect] + +def get_provider_registry() -> Dict[Api, Dict[str, ProviderSpec]]: + ret = {} + for api in providable_apis(): name = api.name.lower() module = importlib.import_module(f"llama_stack.providers.registry.{name}") ret[api] = {