From b2385cb2f72fe0f2cf77c05c24f1e1d9babec0db Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Wed, 18 Sep 2024 11:19:08 -0700 Subject: [PATCH] validator for providers --- llama_stack/cli/stack/build.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/llama_stack/cli/stack/build.py b/llama_stack/cli/stack/build.py index 76988b058..0ca05cac1 100644 --- a/llama_stack/cli/stack/build.py +++ b/llama_stack/cli/stack/build.py @@ -82,6 +82,7 @@ class StackBuild(Subcommand): ) def _run_stack_build_command(self, args: argparse.Namespace) -> None: + from llama_stack.distribution.distribution import Api, api_providers from llama_stack.distribution.utils.dynamic import instantiate_class_type from llama_stack.distribution.utils.prompt_for_config import prompt_for_config @@ -105,12 +106,26 @@ class StackBuild(Subcommand): providers = dict() for api in Api: + all_providers = api_providers() + providers_for_api = all_providers[api] + api_provider = prompt( "> Please enter the API provider for the {} API: (default=meta-reference): ".format( api.value ), - default="meta-reference", + validator=Validator.from_callable( + lambda x: x in providers_for_api, + error_message="Invalid provider, please enter one of the following: {}".format( + providers_for_api.keys() + ), + ), + default=( + "meta-reference" + if "meta-reference" in providers_for_api + else list(providers_for_api.keys())[0] + ), ) + providers[api.value] = api_provider description = prompt(