diff --git a/llama_toolchain/cli/stack/configure.py b/llama_toolchain/cli/stack/configure.py
index 510601523..70ff4a7f0 100644
--- a/llama_toolchain/cli/stack/configure.py
+++ b/llama_toolchain/cli/stack/configure.py
@@ -40,8 +40,7 @@ class StackConfigure(Subcommand):
         self.parser.add_argument(
             "distribution",
             type=str,
-            choices=allowed_ids,
-            help="Distribution (one of: {})".format(allowed_ids),
+            help='Distribution ("adhoc" or one of: {})'.format(allowed_ids),
         )
         self.parser.add_argument(
             "--name",
@@ -79,17 +78,10 @@ class StackConfigure(Subcommand):
 def configure_llama_distribution(config_file: Path) -> None:
     from llama_toolchain.common.serialize import EnumEncoder
     from llama_toolchain.core.configure import configure_api_providers
-    from llama_toolchain.core.distribution_registry import resolve_distribution_spec
 
     with open(config_file, "r") as f:
         config = PackageConfig(**yaml.safe_load(f))
 
-    dist = resolve_distribution_spec(config.distribution_id)
-    if dist is None:
-        raise ValueError(
-            f"Could not find any registered distribution `{config.distribution_id}`"
-        )
-
     if config.providers:
         cprint(
             f"Configuration already exists for {config.distribution_id}. Will overwrite...",
diff --git a/llama_toolchain/cli/stack/list_apis.py b/llama_toolchain/cli/stack/list_apis.py
new file mode 100644
index 000000000..f13ecefe9
--- /dev/null
+++ b/llama_toolchain/cli/stack/list_apis.py
@@ -0,0 +1,47 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import argparse
+
+from llama_toolchain.cli.subcommand import Subcommand
+
+
+class StackListApis(Subcommand):
+    def __init__(self, subparsers: argparse._SubParsersAction):
+        super().__init__()
+        self.parser = subparsers.add_parser(
+            "list-apis",
+            prog="llama stack list-apis",
+            description="List APIs part of the Llama Stack implementation",
+            formatter_class=argparse.RawTextHelpFormatter,
+        )
+        self._add_arguments()
+        self.parser.set_defaults(func=self._run_apis_list_cmd)
+
+    def _add_arguments(self):
+        pass
+
+    def _run_apis_list_cmd(self, args: argparse.Namespace) -> None:
+        from llama_toolchain.cli.table import print_table
+        from llama_toolchain.core.distribution import stack_apis
+
+        # eventually, this should query a registry at llama.meta.com/llamastack/distributions
+        headers = [
+            "API",
+        ]
+
+        rows = []
+        for api in stack_apis():
+            rows.append(
+                [
+                    api.value,
+                ]
+            )
+        print_table(
+            rows,
+            headers,
+            separate_rows=True,
+        )
diff --git a/llama_toolchain/cli/stack/list.py b/llama_toolchain/cli/stack/list_distributions.py
similarity index 97%
rename from llama_toolchain/cli/stack/list.py
rename to llama_toolchain/cli/stack/list_distributions.py
index cbd7610f5..c4d529157 100644
--- a/llama_toolchain/cli/stack/list.py
+++ b/llama_toolchain/cli/stack/list_distributions.py
@@ -10,7 +10,7 @@ import json
 from llama_toolchain.cli.subcommand import Subcommand
 
 
-class StackList(Subcommand):
+class StackListDistributions(Subcommand):
     def __init__(self, subparsers: argparse._SubParsersAction):
         super().__init__()
         self.parser = subparsers.add_parser(
diff --git a/llama_toolchain/cli/stack/list_providers.py b/llama_toolchain/cli/stack/list_providers.py
new file mode 100644
index 000000000..29602d889
--- /dev/null
+++ b/llama_toolchain/cli/stack/list_providers.py
@@ -0,0 +1,60 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import argparse
+
+from llama_toolchain.cli.subcommand import Subcommand
+
+
+class StackListProviders(Subcommand):
+    def __init__(self, subparsers: argparse._SubParsersAction):
+        super().__init__()
+        self.parser = subparsers.add_parser(
+            "list-providers",
+            prog="llama stack list-providers",
+            description="Show available Llama Stack Providers for an API",
+            formatter_class=argparse.RawTextHelpFormatter,
+        )
+        self._add_arguments()
+        self.parser.set_defaults(func=self._run_providers_list_cmd)
+
+    def _add_arguments(self):
+        from llama_toolchain.core.distribution import stack_apis
+
+        api_values = [a.value for a in stack_apis()]
+        self.parser.add_argument(
+            "api",
+            type=str,
+            choices=api_values,
+            help="API to list providers for (one of: {})".format(api_values),
+        )
+
+    def _run_providers_list_cmd(self, args: argparse.Namespace) -> None:
+        from llama_toolchain.cli.table import print_table
+        from llama_toolchain.core.distribution import Api, api_providers
+
+        all_providers = api_providers()
+        providers_for_api = all_providers[Api(args.api)]
+
+        # eventually, this should query a registry at llama.meta.com/llamastack/distributions
+        headers = [
+            "Provider ID",
+            "PIP Package Dependencies",
+        ]
+
+        rows = []
+        for spec in providers_for_api.values():
+            rows.append(
+                [
+                    spec.provider_id,
+                    ",".join(spec.pip_packages),
+                ]
+            )
+        print_table(
+            rows,
+            headers,
+            separate_rows=True,
+        )
diff --git a/llama_toolchain/cli/stack/stack.py b/llama_toolchain/cli/stack/stack.py
index cba31e08d..e41f10633 100644
--- a/llama_toolchain/cli/stack/stack.py
+++ b/llama_toolchain/cli/stack/stack.py
@@ -10,7 +10,9 @@ from llama_toolchain.cli.subcommand import Subcommand
 
 from .build import StackBuild
 from .configure import StackConfigure
-from .list import StackList
+from .list_apis import StackListApis
+from .list_distributions import StackListDistributions
+from .list_providers import StackListProviders
 from .run import StackRun
 
 
@@ -28,5 +30,7 @@ class StackParser(Subcommand):
         # Add sub-commands
         StackBuild.create(subparsers)
         StackConfigure.create(subparsers)
-        StackList.create(subparsers)
+        StackListApis.create(subparsers)
+        StackListDistributions.create(subparsers)
+        StackListProviders.create(subparsers)
         StackRun.create(subparsers)
diff --git a/llama_toolchain/core/build_conda_env.sh b/llama_toolchain/core/build_conda_env.sh
index 0a3eaf20a..1e8c002f2 100755
--- a/llama_toolchain/core/build_conda_env.sh
+++ b/llama_toolchain/core/build_conda_env.sh
@@ -117,12 +117,4 @@ ensure_conda_env_python310 "$env_name" "$pip_dependencies"
 
 printf "${GREEN}Successfully setup conda environment. Configuring build...${NC}\n"
 
-if [ "$distribution_id" = "adhoc" ]; then
-  subcommand="api"
-  target=""
-else
-  subcommand="stack"
-  target="$distribution_id"
-fi
-
-$CONDA_PREFIX/bin/python3 -m llama_toolchain.cli.llama $subcommand configure $target --name "$build_name" --type conda_env
+$CONDA_PREFIX/bin/python3 -m llama_toolchain.cli.llama stack configure $distribution_id --name "$build_name" --type conda_env
diff --git a/llama_toolchain/core/build_container.sh b/llama_toolchain/core/build_container.sh
index 5b05f1132..b864e7098 100755
--- a/llama_toolchain/core/build_container.sh
+++ b/llama_toolchain/core/build_container.sh
@@ -109,12 +109,4 @@ set +x
 printf "${GREEN}Succesfully setup Podman image. Configuring build...${NC}"
 echo "You can run it with: podman run -p 8000:8000 $image_name"
 
-if [ "$distribution_id" = "adhoc" ]; then
-  subcommand="api"
-  target=""
-else
-  subcommand="stack"
-  target="$distribution_id"
-fi
-
-$CONDA_PREFIX/bin/python3 -m llama_toolchain.cli.llama $subcommand configure $target --name "$build_name" --type container
+$CONDA_PREFIX/bin/python3 -m llama_toolchain.cli.llama stack configure $distribution_id --name "$build_name" --type container
diff --git a/llama_toolchain/inference/adapters/fireworks/__init__.py b/llama_toolchain/inference/adapters/fireworks/__init__.py
index 6de34833f..a3f5a0bd4 100644
--- a/llama_toolchain/inference/adapters/fireworks/__init__.py
+++ b/llama_toolchain/inference/adapters/fireworks/__init__.py
@@ -7,7 +7,7 @@
 from .config import FireworksImplConfig
 
 
-async def get_adapter_impl(config: FireworksImplConfig, _deps) -> Inference:
+async def get_adapter_impl(config: FireworksImplConfig, _deps):
     from .fireworks import FireworksInferenceAdapter
 
     assert isinstance(
diff --git a/llama_toolchain/inference/adapters/fireworks/config.py b/llama_toolchain/inference/adapters/fireworks/config.py
index 68a0131aa..827bc620f 100644
--- a/llama_toolchain/inference/adapters/fireworks/config.py
+++ b/llama_toolchain/inference/adapters/fireworks/config.py
@@ -11,7 +11,7 @@ from pydantic import BaseModel, Field
 @json_schema_type
 class FireworksImplConfig(BaseModel):
     url: str = Field(
-        default="https://api.fireworks.api/inference",
+        default="https://api.fireworks.ai/inference",
         description="The URL for the Fireworks server",
     )
     api_key: str = Field(
diff --git a/llama_toolchain/inference/adapters/together/__init__.py b/llama_toolchain/inference/adapters/together/__init__.py
index ad8bc2ac1..05ea91e58 100644
--- a/llama_toolchain/inference/adapters/together/__init__.py
+++ b/llama_toolchain/inference/adapters/together/__init__.py
@@ -7,7 +7,7 @@
 from .config import TogetherImplConfig
 
 
-async def get_adapter_impl(config: TogetherImplConfig, _deps) -> Inference:
+async def get_adapter_impl(config: TogetherImplConfig, _deps):
     from .together import TogetherInferenceAdapter
 
     assert isinstance(
diff --git a/llama_toolchain/inference/providers.py b/llama_toolchain/inference/providers.py
index 7514aa724..772114b41 100644
--- a/llama_toolchain/inference/providers.py
+++ b/llama_toolchain/inference/providers.py
@@ -42,8 +42,8 @@ def available_inference_providers() -> List[ProviderSpec]:
                 pip_packages=[
                     "fireworks-ai",
                 ],
-                module="llama_toolchain.inference.fireworks",
-                config_class="llama_toolchain.inference.fireworks.FireworksImplConfig",
+                module="llama_toolchain.inference.adapters.fireworks",
+                config_class="llama_toolchain.inference.adapters.fireworks.FireworksImplConfig",
             ),
         ),
         remote_provider_spec(
@@ -53,8 +53,8 @@ def available_inference_providers() -> List[ProviderSpec]:
                 pip_packages=[
                     "together",
                 ],
-                module="llama_toolchain.inference.together",
-                config_class="llama_toolchain.inference.together.TogetherImplConfig",
+                module="llama_toolchain.inference.adapters.together",
+                config_class="llama_toolchain.inference.adapters.together.TogetherImplConfig",
             ),
         ),
     ]