validator for providers

This commit is contained in:
Xi Yan 2024-09-18 11:19:08 -07:00
parent 714a1703c4
commit b2385cb2f7

View file

@ -82,6 +82,7 @@ class StackBuild(Subcommand):
)
def _run_stack_build_command(self, args: argparse.Namespace) -> None:
from llama_stack.distribution.distribution import Api, api_providers
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.distribution.utils.prompt_for_config import prompt_for_config
@ -105,12 +106,26 @@ class StackBuild(Subcommand):
providers = dict()
for api in Api:
all_providers = api_providers()
providers_for_api = all_providers[api]
api_provider = prompt(
"> Please enter the API provider for the {} API: (default=meta-reference): ".format(
api.value
),
default="meta-reference",
validator=Validator.from_callable(
lambda x: x in providers_for_api,
error_message="Invalid provider, please enter one of the following: {}".format(
providers_for_api.keys()
),
),
default=(
"meta-reference"
if "meta-reference" in providers_for_api
else list(providers_for_api.keys())[0]
),
)
providers[api.value] = api_provider
description = prompt(