slight upgrade to CLI

This commit is contained in:
Ashwin Bharambe 2024-10-06 18:02:47 -07:00 committed by Ashwin Bharambe
parent 1550187cd8
commit 099a95b614
4 changed files with 23 additions and 21 deletions

View file

@ -105,8 +105,7 @@ class StackBuild(Subcommand):
import yaml
from llama_stack.distribution.build import ApiInput, build_image, ImageType
from llama_stack.distribution.build import build_image, ImageType
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.distribution.utils.serialize import EnumEncoder
from termcolor import cprint
@ -175,9 +174,11 @@ class StackBuild(Subcommand):
)
def _run_stack_build_command(self, args: argparse.Namespace) -> None:
import textwrap
import yaml
from llama_stack.distribution.distribution import get_provider_registry
from prompt_toolkit import prompt
from prompt_toolkit.completion import WordCompleter
from prompt_toolkit.validation import Validator
from termcolor import cprint
@ -240,27 +241,30 @@ class StackBuild(Subcommand):
default="conda",
)
cprint(
"\n Llama Stack is composed of several APIs working together. Let's configure the providers (implementations) you want to use for these APIs.",
color="green",
)
cprint(textwrap.dedent(
"""
Llama Stack is composed of several APIs working together. Let's select
the provider types (implementations) you want to use for these APIs.
""",
),
color="green")
print("Tip: use <TAB> to see options for the providers.\n")
providers = dict()
for api, providers_for_api in get_provider_registry().items():
available_providers = [
x for x in providers_for_api.keys() if x != "remote"
]
api_provider = prompt(
"> Enter provider for the {} API: (default=meta-reference): ".format(
"> Enter provider for API {}: ".format(
api.value
),
completer=WordCompleter(available_providers),
complete_while_typing=True,
validator=Validator.from_callable(
lambda x: x in providers_for_api,
error_message="Invalid provider, please enter one of the following: {}".format(
list(providers_for_api.keys())
),
),
default=(
"meta-reference"
if "meta-reference" in providers_for_api
else list(providers_for_api.keys())[0]
lambda x: x in available_providers,
error_message="Invalid provider, use <TAB> to see options",
),
)

View file

@ -71,9 +71,7 @@ class StackConfigure(Subcommand):
conda_dir = (
Path(os.path.expanduser("~/.conda/envs")) / f"llamastack-{args.config}"
)
output = subprocess.check_output(
["bash", "-c", "conda info --json -a"]
)
output = subprocess.check_output(["bash", "-c", "conda info --json"])
conda_envs = json.loads(output.decode("utf-8"))["envs"]
for x in conda_envs:

View file

@ -34,7 +34,7 @@ class InferenceEndpointImplConfig(BaseModel):
@json_schema_type
class InferenceAPIImplConfig(BaseModel):
model_id: str = Field(
huggingface_repo: str = Field(
description="The model ID of the model on the Hugging Face Hub (e.g. 'meta-llama/Meta-Llama-3.1-70B-Instruct')",
)
api_token: Optional[str] = Field(

View file

@ -243,7 +243,7 @@ class TGIAdapter(_HfAdapter):
class InferenceAPIAdapter(_HfAdapter):
async def initialize(self, config: InferenceAPIImplConfig) -> None:
self.client = AsyncInferenceClient(
model=config.model_id, token=config.api_token
model=config.huggingface_repo, token=config.api_token
)
endpoint_info = await self.client.get_endpoint_info()
self.max_tokens = endpoint_info["max_total_tokens"]