diff --git a/llama_stack/cli/stack/build.py b/llama_stack/cli/stack/build.py
index 0cedbe901..f6821c8df 100644
--- a/llama_stack/cli/stack/build.py
+++ b/llama_stack/cli/stack/build.py
@@ -105,8 +105,7 @@ class StackBuild(Subcommand):
         import yaml

-        from llama_stack.distribution.build import ApiInput, build_image, ImageType
-
+        from llama_stack.distribution.build import build_image, ImageType
         from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
         from llama_stack.distribution.utils.serialize import EnumEncoder
         from termcolor import cprint

@@ -175,9 +174,11 @@ class StackBuild(Subcommand):
         )

     def _run_stack_build_command(self, args: argparse.Namespace) -> None:
+        import textwrap
         import yaml
         from llama_stack.distribution.distribution import get_provider_registry
         from prompt_toolkit import prompt
+        from prompt_toolkit.completion import WordCompleter
         from prompt_toolkit.validation import Validator
         from termcolor import cprint

@@ -240,27 +241,30 @@ class StackBuild(Subcommand):
             default="conda",
         )

-        cprint(
-            "\n Llama Stack is composed of several APIs working together. Let's configure the providers (implementations) you want to use for these APIs.",
-            color="green",
-        )
+        cprint(textwrap.dedent(
+            """
+            Llama Stack is composed of several APIs working together. Let's select
+            the provider types (implementations) you want to use for these APIs.
+            """,
+        ),
+            color="green")
+
+        print("Tip: use <TAB> to see options for the providers.\n")

         providers = dict()
         for api, providers_for_api in get_provider_registry().items():
+            available_providers = [
+                x for x in providers_for_api.keys() if x != "remote"
+            ]
             api_provider = prompt(
-                "> Enter provider for the {} API: (default=meta-reference): ".format(
+                "> Enter provider for API {}: ".format(
                     api.value
                 ),
+                completer=WordCompleter(available_providers),
+                complete_while_typing=True,
                 validator=Validator.from_callable(
-                    lambda x: x in providers_for_api,
-                    error_message="Invalid provider, please enter one of the following: {}".format(
-                        list(providers_for_api.keys())
-                    ),
-                ),
-                default=(
-                    "meta-reference"
-                    if "meta-reference" in providers_for_api
-                    else list(providers_for_api.keys())[0]
+                    lambda x: x in available_providers,
+                    error_message="Invalid provider, use <TAB> to see options",
                 ),
             )

diff --git a/llama_stack/cli/stack/configure.py b/llama_stack/cli/stack/configure.py
index 9aa7e2f6e..13899715b 100644
--- a/llama_stack/cli/stack/configure.py
+++ b/llama_stack/cli/stack/configure.py
@@ -71,9 +71,7 @@ class StackConfigure(Subcommand):
         conda_dir = (
             Path(os.path.expanduser("~/.conda/envs")) / f"llamastack-{args.config}"
         )
-        output = subprocess.check_output(
-            ["bash", "-c", "conda info --json -a"]
-        )
+        output = subprocess.check_output(["bash", "-c", "conda info --json"])
         conda_envs = json.loads(output.decode("utf-8"))["envs"]

         for x in conda_envs:
diff --git a/llama_stack/providers/adapters/inference/tgi/config.py b/llama_stack/providers/adapters/inference/tgi/config.py
index 233205066..6ce2b9dc6 100644
--- a/llama_stack/providers/adapters/inference/tgi/config.py
+++ b/llama_stack/providers/adapters/inference/tgi/config.py
@@ -34,7 +34,7 @@ class InferenceEndpointImplConfig(BaseModel):

 @json_schema_type
 class InferenceAPIImplConfig(BaseModel):
-    model_id: str = Field(
+    huggingface_repo: str = Field(
         description="The model ID of the model on the Hugging Face Hub (e.g. 'meta-llama/Meta-Llama-3.1-70B-Instruct')",
     )
     api_token: Optional[str] = Field(
diff --git a/llama_stack/providers/adapters/inference/tgi/tgi.py b/llama_stack/providers/adapters/inference/tgi/tgi.py
index 538c11ec7..24b664068 100644
--- a/llama_stack/providers/adapters/inference/tgi/tgi.py
+++ b/llama_stack/providers/adapters/inference/tgi/tgi.py
@@ -243,7 +243,7 @@ class TGIAdapter(_HfAdapter):
 class InferenceAPIAdapter(_HfAdapter):
     async def initialize(self, config: InferenceAPIImplConfig) -> None:
         self.client = AsyncInferenceClient(
-            model=config.huggingface_repo, token=config.api_token
+            model=config.huggingface_repo, token=config.api_token
         )
         endpoint_info = await self.client.get_endpoint_info()
         self.max_tokens = endpoint_info["max_total_tokens"]
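For reviewers who want to try the interaction outside the CLI, here is a minimal, self-contained sketch of the prompt_toolkit tab-completion pattern that `_run_stack_build_command` now uses. The provider names below are placeholders for illustration only; in the actual command the list is built from `get_provider_registry()` with the `remote` entry filtered out.

```python
# Standalone sketch of the new completion + validation flow (demo values only).
from prompt_toolkit import prompt
from prompt_toolkit.completion import WordCompleter
from prompt_toolkit.validation import Validator

# Hypothetical provider list; the real one comes from get_provider_registry().
available_providers = ["meta-reference", "remote::tgi"]

api_provider = prompt(
    "> Enter provider for API inference: ",
    # Pressing <TAB> pops up the matching provider names.
    completer=WordCompleter(available_providers),
    complete_while_typing=True,
    # Reject anything that is not a known provider.
    validator=Validator.from_callable(
        lambda x: x in available_providers,
        error_message="Invalid provider, use <TAB> to see options",
    ),
)
print(f"Selected provider: {api_provider}")
```

Running this in a terminal shows why `complete_while_typing=True` pairs well with the validator: the completion menu narrows as you type, and invalid input is rejected before any build config would be written.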