A bit cleanup to avoid breakages

This commit is contained in:
Ashwin Bharambe 2024-10-02 21:31:09 -07:00
parent 988a9cada3
commit e9f6150588
2 changed files with 17 additions and 32 deletions

View file

@ -114,10 +114,10 @@ class StackBuild(Subcommand):
# save build.yaml spec for building same distribution again
if build_config.image_type == ImageType.docker.value:
# docker needs build file to be in the llama-stack repo dir to be able to copy over to the image
llama_stack_path = Path(os.path.abspath(__file__)).parent.parent.parent.parent
build_dir = (
llama_stack_path / "tmp/configs/"
)
llama_stack_path = Path(
os.path.abspath(__file__)
).parent.parent.parent.parent
build_dir = llama_stack_path / "tmp/configs/"
else:
build_dir = DISTRIBS_BASE_DIR / f"llamastack-{build_config.name}"
@ -173,12 +173,7 @@ class StackBuild(Subcommand):
def _run_stack_build_command(self, args: argparse.Namespace) -> None:
import yaml
from llama_stack.distribution.distribution import (
Api,
get_provider_registry,
builtin_automatically_routed_apis,
)
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.distribution.distribution import get_provider_registry
from prompt_toolkit import prompt
from prompt_toolkit.validation import Validator
from termcolor import cprint
@ -212,7 +207,10 @@ class StackBuild(Subcommand):
if args.name:
maybe_build_config = self._get_build_config_from_name(args)
if maybe_build_config:
cprint(f"Building from existing build config for {args.name} in {str(maybe_build_config)}...", "green")
cprint(
f"Building from existing build config for {args.name} in {str(maybe_build_config)}...",
"green",
)
with open(maybe_build_config, "r") as f:
build_config = BuildConfig(**yaml.safe_load(f))
self._run_stack_build_command_from_build_config(build_config)
@ -240,24 +238,12 @@ class StackBuild(Subcommand):
)
cprint(
f"\n Llama Stack is composed of several APIs working together. Let's configure the providers (implementations) you want to use for these APIs.",
"\n Llama Stack is composed of several APIs working together. Let's configure the providers (implementations) you want to use for these APIs.",
color="green",
)
providers = dict()
all_providers = get_provider_registry()
routing_table_apis = set(
x.routing_table_api for x in builtin_automatically_routed_apis()
)
for api in Api:
if api in routing_table_apis:
continue
if api == Api.inspect:
continue
providers_for_api = all_providers[api]
for api, providers_for_api in get_provider_registry().items():
api_provider = prompt(
"> Enter provider for the {} API: (default=meta-reference): ".format(
api.value

View file

@ -38,17 +38,16 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
]
def get_provider_registry() -> Dict[Api, Dict[str, ProviderSpec]]:
ret = {}
def providable_apis() -> List[Api]:
routing_table_apis = set(
x.routing_table_api for x in builtin_automatically_routed_apis()
)
for api in stack_apis():
if api in routing_table_apis:
continue
if api == Api.inspect:
continue
return [api for api in Api if api not in routing_table_apis and api != Api.inspect]
def get_provider_registry() -> Dict[Api, Dict[str, ProviderSpec]]:
ret = {}
for api in providable_apis():
name = api.name.lower()
module = importlib.import_module(f"llama_stack.providers.registry.{name}")
ret[api] = {