validator for providers

This commit is contained in:
Xi Yan 2024-09-18 11:19:08 -07:00
parent 714a1703c4
commit b2385cb2f7

View file

@ -82,6 +82,7 @@ class StackBuild(Subcommand):
) )
def _run_stack_build_command(self, args: argparse.Namespace) -> None: def _run_stack_build_command(self, args: argparse.Namespace) -> None:
from llama_stack.distribution.distribution import Api, api_providers
from llama_stack.distribution.utils.dynamic import instantiate_class_type from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.distribution.utils.prompt_for_config import prompt_for_config from llama_stack.distribution.utils.prompt_for_config import prompt_for_config
@ -105,12 +106,26 @@ class StackBuild(Subcommand):
providers = dict() providers = dict()
for api in Api: for api in Api:
all_providers = api_providers()
providers_for_api = all_providers[api]
api_provider = prompt( api_provider = prompt(
"> Please enter the API provider for the {} API: (default=meta-reference): ".format( "> Please enter the API provider for the {} API: (default=meta-reference): ".format(
api.value api.value
), ),
default="meta-reference", validator=Validator.from_callable(
lambda x: x in providers_for_api,
error_message="Invalid provider, please enter one of the following: {}".format(
providers_for_api.keys()
),
),
default=(
"meta-reference"
if "meta-reference" in providers_for_api
else list(providers_for_api.keys())[0]
),
) )
providers[api.value] = api_provider providers[api.value] = api_provider
description = prompt( description = prompt(