Repository: https://github.com/meta-llama/llama-stack.git
Commit: 099a95b614 (parent: 1550187cd8)

    slight upgrade to CLI

4 changed files with 23 additions and 21 deletions
@@ -105,8 +105,7 @@ class StackBuild(Subcommand):
         import yaml
 
-        from llama_stack.distribution.build import ApiInput, build_image, ImageType
+        from llama_stack.distribution.build import build_image, ImageType
         from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
         from llama_stack.distribution.utils.serialize import EnumEncoder
         from termcolor import cprint
@@ -175,9 +174,11 @@ class StackBuild(Subcommand):
         )
 
     def _run_stack_build_command(self, args: argparse.Namespace) -> None:
+        import textwrap
         import yaml
         from llama_stack.distribution.distribution import get_provider_registry
         from prompt_toolkit import prompt
+        from prompt_toolkit.completion import WordCompleter
         from prompt_toolkit.validation import Validator
         from termcolor import cprint
 
@@ -240,27 +241,30 @@ class StackBuild(Subcommand):
                 default="conda",
             )
 
-            cprint(
-                "\n Llama Stack is composed of several APIs working together. Let's configure the providers (implementations) you want to use for these APIs.",
-                color="green",
-            )
+            cprint(textwrap.dedent(
+                """
+                Llama Stack is composed of several APIs working together. Let's select
+                the provider types (implementations) you want to use for these APIs.
+                """,
+            ),
+                color="green")
+
+            print("Tip: use <TAB> to see options for the providers.\n")
 
             providers = dict()
             for api, providers_for_api in get_provider_registry().items():
+                available_providers = [
+                    x for x in providers_for_api.keys() if x != "remote"
+                ]
                 api_provider = prompt(
-                    "> Enter provider for the {} API: (default=meta-reference): ".format(
+                    "> Enter provider for API {}: ".format(
                         api.value
                     ),
+                    completer=WordCompleter(available_providers),
+                    complete_while_typing=True,
                     validator=Validator.from_callable(
-                        lambda x: x in providers_for_api,
-                        error_message="Invalid provider, please enter one of the following: {}".format(
-                            list(providers_for_api.keys())
-                        ),
-                    ),
-                    default=(
-                        "meta-reference"
-                        if "meta-reference" in providers_for_api
-                        else list(providers_for_api.keys())[0]
+                        lambda x: x in available_providers,
+                        error_message="Invalid provider, use <TAB> to see options",
                     ),
                 )
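For readers unfamiliar with prompt_toolkit, here is a minimal, self-contained sketch of the interaction pattern the new code wires up: textwrap.dedent un-indents the triple-quoted banner, WordCompleter drives <TAB> completion, and Validator.from_callable blocks invalid input. The provider list below is hypothetical, standing in for the get_provider_registry() lookup.

    import textwrap

    from prompt_toolkit import prompt
    from prompt_toolkit.completion import WordCompleter
    from prompt_toolkit.validation import Validator
    from termcolor import cprint

    # dedent strips the common leading whitespace, so the banner can be
    # indented to match the surrounding code without printing that indent.
    cprint(
        textwrap.dedent(
            """
            Llama Stack is composed of several APIs working together. Let's select
            the provider types (implementations) you want to use for these APIs.
            """
        ),
        color="green",
    )
    print("Tip: use <TAB> to see options for the providers.\n")

    # Hypothetical candidates; the CLI derives these from the provider registry.
    available_providers = ["meta-reference", "tgi", "ollama"]

    api_provider = prompt(
        "> Enter provider for API inference: ",
        completer=WordCompleter(available_providers),  # <TAB> lists these
        complete_while_typing=True,                    # suggest on each keystroke
        validator=Validator.from_callable(             # <Enter> rejected until valid
            lambda x: x in available_providers,
            error_message="Invalid provider, use <TAB> to see options",
        ),
    )
    print(f"Selected: {api_provider}")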
@@ -71,9 +71,7 @@ class StackConfigure(Subcommand):
         conda_dir = (
             Path(os.path.expanduser("~/.conda/envs")) / f"llamastack-{args.config}"
         )
-        output = subprocess.check_output(
-            ["bash", "-c", "conda info --json -a"]
-        )
+        output = subprocess.check_output(["bash", "-c", "conda info --json"])
         conda_envs = json.loads(output.decode("utf-8"))["envs"]
 
         for x in conda_envs:
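For context on what this hunk consumes: `conda info --json` prints a JSON object whose "envs" key is a list of absolute environment paths (the dropped `-a` flag only added extra sections this code never read). A small sketch of the same lookup, assuming conda and bash are on PATH; the `llamastack-` prefix matches the env naming used by `llama stack build`.

    import json
    import subprocess
    from pathlib import Path

    # Machine-readable conda metadata; "envs" lists absolute env paths.
    output = subprocess.check_output(["bash", "-c", "conda info --json"])
    conda_envs = json.loads(output.decode("utf-8"))["envs"]

    # Pick out environments created by `llama stack build`.
    for env_path in conda_envs:
        name = Path(env_path).name
        if name.startswith("llamastack-"):
            print(f"found stack environment: {name} ({env_path})")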
@@ -34,7 +34,7 @@ class InferenceEndpointImplConfig(BaseModel):
 
 @json_schema_type
 class InferenceAPIImplConfig(BaseModel):
-    model_id: str = Field(
+    huggingface_repo: str = Field(
         description="The model ID of the model on the Hugging Face Hub (e.g. 'meta-llama/Meta-Llama-3.1-70B-Instruct')",
     )
     api_token: Optional[str] = Field(
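A trimmed sketch of the model after the rename, to show what callers construct now. The repo's `@json_schema_type` decorator is omitted here, and the `api_token` description is a guess since the diff truncates that field.

    from typing import Optional

    from pydantic import BaseModel, Field


    class InferenceAPIImplConfig(BaseModel):
        huggingface_repo: str = Field(
            description="The model ID of the model on the Hugging Face Hub (e.g. 'meta-llama/Meta-Llama-3.1-70B-Instruct')",
        )
        # Description guessed for illustration; the diff cuts off after Field(.
        api_token: Optional[str] = Field(
            default=None,
            description="Hugging Face API token, if the repo requires auth",
        )


    # Old callers passed model_id=...; they must now pass huggingface_repo=...
    config = InferenceAPIImplConfig(
        huggingface_repo="meta-llama/Meta-Llama-3.1-70B-Instruct"
    )
    print(config)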
@@ -243,7 +243,7 @@ class TGIAdapter(_HfAdapter):
 class InferenceAPIAdapter(_HfAdapter):
     async def initialize(self, config: InferenceAPIImplConfig) -> None:
         self.client = AsyncInferenceClient(
-            model=config.model_id, token=config.api_token
+            model=config.huggingface_repo, token=config.api_token
         )
         endpoint_info = await self.client.get_endpoint_info()
         self.max_tokens = endpoint_info["max_total_tokens"]
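And the consumer side of the rename, reduced to a rough, runnable sketch that mirrors InferenceAPIAdapter.initialize. `AsyncInferenceClient` and `get_endpoint_info()` come from `huggingface_hub`, as in the adapter above; the repo id and token below are placeholders, and the call only succeeds against a TGI-backed endpoint.

    import asyncio

    from huggingface_hub import AsyncInferenceClient


    async def main() -> None:
        # Placeholders; in llama-stack these come from InferenceAPIImplConfig.
        client = AsyncInferenceClient(
            model="meta-llama/Meta-Llama-3.1-70B-Instruct",
            token="hf_...",  # or None for public repos
        )
        # The adapter sizes its context window from the endpoint metadata.
        endpoint_info = await client.get_endpoint_info()
        print(endpoint_info["max_total_tokens"])


    asyncio.run(main())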