forked from phoenix-oss/llama-stack-mirror
		
	Fixes to the llama stack configure script + inference adapters
				
					
				
			This commit is contained in:
		
							parent
							
								
									4869f2b983
								
							
						
					
					
						commit
						1380d78c19
					
				
					 11 changed files with 124 additions and 37 deletions
				
			
		|  | @ -40,8 +40,7 @@ class StackConfigure(Subcommand): | |||
|         self.parser.add_argument( | ||||
|             "distribution", | ||||
|             type=str, | ||||
|             choices=allowed_ids, | ||||
|             help="Distribution (one of: {})".format(allowed_ids), | ||||
|             help='Distribution ("adhoc" or one of: {})'.format(allowed_ids), | ||||
|         ) | ||||
|         self.parser.add_argument( | ||||
|             "--name", | ||||
|  | @ -79,17 +78,10 @@ class StackConfigure(Subcommand): | |||
| def configure_llama_distribution(config_file: Path) -> None: | ||||
|     from llama_toolchain.common.serialize import EnumEncoder | ||||
|     from llama_toolchain.core.configure import configure_api_providers | ||||
|     from llama_toolchain.core.distribution_registry import resolve_distribution_spec | ||||
| 
 | ||||
|     with open(config_file, "r") as f: | ||||
|         config = PackageConfig(**yaml.safe_load(f)) | ||||
| 
 | ||||
|     dist = resolve_distribution_spec(config.distribution_id) | ||||
|     if dist is None: | ||||
|         raise ValueError( | ||||
|             f"Could not find any registered distribution `{config.distribution_id}`" | ||||
|         ) | ||||
| 
 | ||||
|     if config.providers: | ||||
|         cprint( | ||||
|             f"Configuration already exists for {config.distribution_id}. Will overwrite...", | ||||
|  |  | |||
							
								
								
									
										47
									
								
								llama_toolchain/cli/stack/list_apis.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										47
									
								
								llama_toolchain/cli/stack/list_apis.py
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,47 @@ | |||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
| # | ||||
| # This source code is licensed under the terms described in the LICENSE file in | ||||
| # the root directory of this source tree. | ||||
| 
 | ||||
| import argparse | ||||
| 
 | ||||
| from llama_toolchain.cli.subcommand import Subcommand | ||||
| 
 | ||||
| 
 | ||||
| class StackListApis(Subcommand): | ||||
|     def __init__(self, subparsers: argparse._SubParsersAction): | ||||
|         super().__init__() | ||||
|         self.parser = subparsers.add_parser( | ||||
|             "list-apis", | ||||
|             prog="llama stack list-apis", | ||||
|             description="List APIs part of the Llama Stack implementation", | ||||
|             formatter_class=argparse.RawTextHelpFormatter, | ||||
|         ) | ||||
|         self._add_arguments() | ||||
|         self.parser.set_defaults(func=self._run_apis_list_cmd) | ||||
| 
 | ||||
|     def _add_arguments(self): | ||||
|         pass | ||||
| 
 | ||||
|     def _run_apis_list_cmd(self, args: argparse.Namespace) -> None: | ||||
|         from llama_toolchain.cli.table import print_table | ||||
|         from llama_toolchain.core.distribution import stack_apis | ||||
| 
 | ||||
|         # eventually, this should query a registry at llama.meta.com/llamastack/distributions | ||||
|         headers = [ | ||||
|             "API", | ||||
|         ] | ||||
| 
 | ||||
|         rows = [] | ||||
|         for api in stack_apis(): | ||||
|             rows.append( | ||||
|                 [ | ||||
|                     api.value, | ||||
|                 ] | ||||
|             ) | ||||
|         print_table( | ||||
|             rows, | ||||
|             headers, | ||||
|             separate_rows=True, | ||||
|         ) | ||||
|  | @ -10,7 +10,7 @@ import json | |||
| from llama_toolchain.cli.subcommand import Subcommand | ||||
| 
 | ||||
| 
 | ||||
| class StackList(Subcommand): | ||||
| class StackListDistributions(Subcommand): | ||||
|     def __init__(self, subparsers: argparse._SubParsersAction): | ||||
|         super().__init__() | ||||
|         self.parser = subparsers.add_parser( | ||||
							
								
								
									
										60
									
								
								llama_toolchain/cli/stack/list_providers.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										60
									
								
								llama_toolchain/cli/stack/list_providers.py
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,60 @@ | |||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
| # | ||||
| # This source code is licensed under the terms described in the LICENSE file in | ||||
| # the root directory of this source tree. | ||||
| 
 | ||||
| import argparse | ||||
| 
 | ||||
| from llama_toolchain.cli.subcommand import Subcommand | ||||
| 
 | ||||
| 
 | ||||
| class StackListProviders(Subcommand): | ||||
|     def __init__(self, subparsers: argparse._SubParsersAction): | ||||
|         super().__init__() | ||||
|         self.parser = subparsers.add_parser( | ||||
|             "list-providers", | ||||
|             prog="llama stack list-providers", | ||||
|             description="Show available Llama Stack Providers for an API", | ||||
|             formatter_class=argparse.RawTextHelpFormatter, | ||||
|         ) | ||||
|         self._add_arguments() | ||||
|         self.parser.set_defaults(func=self._run_providers_list_cmd) | ||||
| 
 | ||||
|     def _add_arguments(self): | ||||
|         from llama_toolchain.core.distribution import stack_apis | ||||
| 
 | ||||
|         api_values = [a.value for a in stack_apis()] | ||||
|         self.parser.add_argument( | ||||
|             "api", | ||||
|             type=str, | ||||
|             choices=api_values, | ||||
|             help="API to list providers for (one of: {})".format(api_values), | ||||
|         ) | ||||
| 
 | ||||
|     def _run_providers_list_cmd(self, args: argparse.Namespace) -> None: | ||||
|         from llama_toolchain.cli.table import print_table | ||||
|         from llama_toolchain.core.distribution import Api, api_providers | ||||
| 
 | ||||
|         all_providers = api_providers() | ||||
|         providers_for_api = all_providers[Api(args.api)] | ||||
| 
 | ||||
|         # eventually, this should query a registry at llama.meta.com/llamastack/distributions | ||||
|         headers = [ | ||||
|             "Provider ID", | ||||
|             "PIP Package Dependencies", | ||||
|         ] | ||||
| 
 | ||||
|         rows = [] | ||||
|         for spec in providers_for_api.values(): | ||||
|             rows.append( | ||||
|                 [ | ||||
|                     spec.provider_id, | ||||
|                     ",".join(spec.pip_packages), | ||||
|                 ] | ||||
|             ) | ||||
|         print_table( | ||||
|             rows, | ||||
|             headers, | ||||
|             separate_rows=True, | ||||
|         ) | ||||
|  | @ -10,7 +10,9 @@ from llama_toolchain.cli.subcommand import Subcommand | |||
| 
 | ||||
| from .build import StackBuild | ||||
| from .configure import StackConfigure | ||||
| from .list import StackList | ||||
| from .list_apis import StackListApis | ||||
| from .list_distributions import StackListDistributions | ||||
| from .list_providers import StackListProviders | ||||
| from .run import StackRun | ||||
| 
 | ||||
| 
 | ||||
|  | @ -28,5 +30,7 @@ class StackParser(Subcommand): | |||
|         # Add sub-commands | ||||
|         StackBuild.create(subparsers) | ||||
|         StackConfigure.create(subparsers) | ||||
|         StackList.create(subparsers) | ||||
|         StackListApis.create(subparsers) | ||||
|         StackListDistributions.create(subparsers) | ||||
|         StackListProviders.create(subparsers) | ||||
|         StackRun.create(subparsers) | ||||
|  |  | |||
|  | @ -117,12 +117,4 @@ ensure_conda_env_python310 "$env_name" "$pip_dependencies" | |||
| 
 | ||||
| printf "${GREEN}Successfully setup conda environment. Configuring build...${NC}\n" | ||||
| 
 | ||||
| if [ "$distribution_id" = "adhoc" ]; then | ||||
|   subcommand="api" | ||||
|   target="" | ||||
| else | ||||
|   subcommand="stack" | ||||
|   target="$distribution_id" | ||||
| fi | ||||
| 
 | ||||
| $CONDA_PREFIX/bin/python3 -m llama_toolchain.cli.llama $subcommand configure $target --name "$build_name" --type conda_env | ||||
| $CONDA_PREFIX/bin/python3 -m llama_toolchain.cli.llama stack configure $distribution_id --name "$build_name" --type conda_env | ||||
|  |  | |||
|  | @ -109,12 +109,4 @@ set +x | |||
| printf "${GREEN}Succesfully setup Podman image. Configuring build...${NC}" | ||||
| echo "You can run it with: podman run -p 8000:8000 $image_name" | ||||
| 
 | ||||
| if [ "$distribution_id" = "adhoc" ]; then | ||||
|   subcommand="api" | ||||
|   target="" | ||||
| else | ||||
|   subcommand="stack" | ||||
|   target="$distribution_id" | ||||
| fi | ||||
| 
 | ||||
| $CONDA_PREFIX/bin/python3 -m llama_toolchain.cli.llama $subcommand configure $target --name "$build_name" --type container | ||||
| $CONDA_PREFIX/bin/python3 -m llama_toolchain.cli.llama stack configure $distribution_id --name "$build_name" --type container | ||||
|  |  | |||
|  | @ -7,7 +7,7 @@ | |||
| from .config import FireworksImplConfig | ||||
| 
 | ||||
| 
 | ||||
| async def get_adapter_impl(config: FireworksImplConfig, _deps) -> Inference: | ||||
| async def get_adapter_impl(config: FireworksImplConfig, _deps): | ||||
|     from .fireworks import FireworksInferenceAdapter | ||||
| 
 | ||||
|     assert isinstance( | ||||
|  |  | |||
|  | @ -11,7 +11,7 @@ from pydantic import BaseModel, Field | |||
| @json_schema_type | ||||
| class FireworksImplConfig(BaseModel): | ||||
|     url: str = Field( | ||||
|         default="https://api.fireworks.api/inference", | ||||
|         default="https://api.fireworks.ai/inference", | ||||
|         description="The URL for the Fireworks server", | ||||
|     ) | ||||
|     api_key: str = Field( | ||||
|  |  | |||
|  | @ -7,7 +7,7 @@ | |||
| from .config import TogetherImplConfig | ||||
| 
 | ||||
| 
 | ||||
| async def get_adapter_impl(config: TogetherImplConfig, _deps) -> Inference: | ||||
| async def get_adapter_impl(config: TogetherImplConfig, _deps): | ||||
|     from .together import TogetherInferenceAdapter | ||||
| 
 | ||||
|     assert isinstance( | ||||
|  |  | |||
|  | @ -42,8 +42,8 @@ def available_inference_providers() -> List[ProviderSpec]: | |||
|                 pip_packages=[ | ||||
|                     "fireworks-ai", | ||||
|                 ], | ||||
|                 module="llama_toolchain.inference.fireworks", | ||||
|                 config_class="llama_toolchain.inference.fireworks.FireworksImplConfig", | ||||
|                 module="llama_toolchain.inference.adapters.fireworks", | ||||
|                 config_class="llama_toolchain.inference.adapters.fireworks.FireworksImplConfig", | ||||
|             ), | ||||
|         ), | ||||
|         remote_provider_spec( | ||||
|  | @ -53,8 +53,8 @@ def available_inference_providers() -> List[ProviderSpec]: | |||
|                 pip_packages=[ | ||||
|                     "together", | ||||
|                 ], | ||||
|                 module="llama_toolchain.inference.together", | ||||
|                 config_class="llama_toolchain.inference.together.TogetherImplConfig", | ||||
|                 module="llama_toolchain.inference.adapters.together", | ||||
|                 config_class="llama_toolchain.inference.adapters.together.TogetherImplConfig", | ||||
|             ), | ||||
|         ), | ||||
|     ] | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue