A bit cleanup to avoid breakages

This commit is contained in:
Ashwin Bharambe 2024-10-02 21:31:09 -07:00
parent 988a9cada3
commit e9f6150588
2 changed files with 17 additions and 32 deletions

View file

@ -114,10 +114,10 @@ class StackBuild(Subcommand):
# save build.yaml spec for building same distribution again # save build.yaml spec for building same distribution again
if build_config.image_type == ImageType.docker.value: if build_config.image_type == ImageType.docker.value:
# docker needs build file to be in the llama-stack repo dir to be able to copy over to the image # docker needs build file to be in the llama-stack repo dir to be able to copy over to the image
llama_stack_path = Path(os.path.abspath(__file__)).parent.parent.parent.parent llama_stack_path = Path(
build_dir = ( os.path.abspath(__file__)
llama_stack_path / "tmp/configs/" ).parent.parent.parent.parent
) build_dir = llama_stack_path / "tmp/configs/"
else: else:
build_dir = DISTRIBS_BASE_DIR / f"llamastack-{build_config.name}" build_dir = DISTRIBS_BASE_DIR / f"llamastack-{build_config.name}"
@ -173,12 +173,7 @@ class StackBuild(Subcommand):
def _run_stack_build_command(self, args: argparse.Namespace) -> None: def _run_stack_build_command(self, args: argparse.Namespace) -> None:
import yaml import yaml
from llama_stack.distribution.distribution import ( from llama_stack.distribution.distribution import get_provider_registry
Api,
get_provider_registry,
builtin_automatically_routed_apis,
)
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from prompt_toolkit import prompt from prompt_toolkit import prompt
from prompt_toolkit.validation import Validator from prompt_toolkit.validation import Validator
from termcolor import cprint from termcolor import cprint
@ -212,7 +207,10 @@ class StackBuild(Subcommand):
if args.name: if args.name:
maybe_build_config = self._get_build_config_from_name(args) maybe_build_config = self._get_build_config_from_name(args)
if maybe_build_config: if maybe_build_config:
cprint(f"Building from existing build config for {args.name} in {str(maybe_build_config)}...", "green") cprint(
f"Building from existing build config for {args.name} in {str(maybe_build_config)}...",
"green",
)
with open(maybe_build_config, "r") as f: with open(maybe_build_config, "r") as f:
build_config = BuildConfig(**yaml.safe_load(f)) build_config = BuildConfig(**yaml.safe_load(f))
self._run_stack_build_command_from_build_config(build_config) self._run_stack_build_command_from_build_config(build_config)
@ -240,24 +238,12 @@ class StackBuild(Subcommand):
) )
cprint( cprint(
f"\n Llama Stack is composed of several APIs working together. Let's configure the providers (implementations) you want to use for these APIs.", "\n Llama Stack is composed of several APIs working together. Let's configure the providers (implementations) you want to use for these APIs.",
color="green", color="green",
) )
providers = dict() providers = dict()
all_providers = get_provider_registry() for api, providers_for_api in get_provider_registry().items():
routing_table_apis = set(
x.routing_table_api for x in builtin_automatically_routed_apis()
)
for api in Api:
if api in routing_table_apis:
continue
if api == Api.inspect:
continue
providers_for_api = all_providers[api]
api_provider = prompt( api_provider = prompt(
"> Enter provider for the {} API: (default=meta-reference): ".format( "> Enter provider for the {} API: (default=meta-reference): ".format(
api.value api.value

View file

@ -38,17 +38,16 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
] ]
def get_provider_registry() -> Dict[Api, Dict[str, ProviderSpec]]: def providable_apis() -> List[Api]:
ret = {}
routing_table_apis = set( routing_table_apis = set(
x.routing_table_api for x in builtin_automatically_routed_apis() x.routing_table_api for x in builtin_automatically_routed_apis()
) )
for api in stack_apis(): return [api for api in Api if api not in routing_table_apis and api != Api.inspect]
if api in routing_table_apis:
continue
if api == Api.inspect:
continue
def get_provider_registry() -> Dict[Api, Dict[str, ProviderSpec]]:
ret = {}
for api in providable_apis():
name = api.name.lower() name = api.name.lower()
module = importlib.import_module(f"llama_stack.providers.registry.{name}") module = importlib.import_module(f"llama_stack.providers.registry.{name}")
ret[api] = { ret[api] = {