diff --git a/llama_toolchain/cli/distribution/configure.py b/llama_toolchain/cli/distribution/configure.py index c495eb9f9..fa1a42dc0 100644 --- a/llama_toolchain/cli/distribution/configure.py +++ b/llama_toolchain/cli/distribution/configure.py @@ -8,7 +8,7 @@ import argparse import json import shlex -from pathlib import Path +from typing import Any, Dict import yaml from termcolor import cprint @@ -32,83 +32,81 @@ class DistributionConfigure(Subcommand): self.parser.set_defaults(func=self._run_distribution_configure_cmd) def _add_arguments(self): - from llama_toolchain.distribution.registry import available_distributions - self.parser.add_argument( "--name", type=str, help="Name of the distribution to configure", - default="local-source", - choices=[d.name for d in available_distributions()], + required=True, ) def _run_distribution_configure_cmd(self, args: argparse.Namespace) -> None: - from llama_toolchain.distribution.registry import resolve_distribution + from llama_toolchain.distribution.registry import resolve_distribution_spec - dist = resolve_distribution(args.name) - if dist is None: - self.parser.error(f"Could not find distribution {args.name}") + config_file = DISTRIBS_BASE_DIR / args.name / "config.yaml" + if not config_file.exists(): + self.parser.error( + f"Could not find {config_file}. 
Please run `llama distribution install` first" + ) return - env_file = DISTRIBS_BASE_DIR / dist.name / "conda.env" - # read this file to get the conda env name - assert env_file.exists(), f"Could not find conda env file {env_file}" - with open(env_file, "r") as f: - conda_env = f.read().strip() + # we need to find the spec from the name + with open(config_file, "r") as f: + config = yaml.safe_load(f) - configure_llama_distribution(dist, conda_env) + dist = resolve_distribution_spec(config["spec"]) + if dist is None: + raise ValueError(f"Could not find any registered spec `{config['spec']}`") + + configure_llama_distribution(dist, config) -def configure_llama_distribution(dist: "Distribution", conda_env: str): +def configure_llama_distribution(dist: "Distribution", config: Dict[str, Any]): from llama_toolchain.common.exec import run_command from llama_toolchain.common.prompt_for_config import prompt_for_config from llama_toolchain.common.serialize import EnumEncoder - from llama_toolchain.distribution.datatypes import RemoteProviderSpec from llama_toolchain.distribution.dynamic import instantiate_class_type python_exe = run_command(shlex.split("which python")) # simple check + conda_env = config["conda_env"] if conda_env not in python_exe: raise ValueError( f"Please re-run configure by activating the `{conda_env}` conda environment" ) - existing_config = None - config_path = Path(DISTRIBS_BASE_DIR) / dist.name / "config.yaml" - if config_path.exists(): + existing_config = config + if "providers" in existing_config: cprint( - f"Configuration already exists for {dist.name}. Will overwrite...", + f"Configuration already exists for {config['name']}. 
Will overwrite...", "yellow", attrs=["bold"], ) - with open(config_path, "r") as fp: - existing_config = yaml.safe_load(fp) provider_configs = {} for api, provider_spec in dist.provider_specs.items(): - if isinstance(provider_spec, RemoteProviderSpec): - provider_configs[api.value] = provider_spec.dict() - else: - cprint(f"Configuring API surface: {api.value}", "white", attrs=["bold"]) - config_type = instantiate_class_type(provider_spec.config_class) - config = prompt_for_config( - config_type, - ( - config_type(**existing_config["providers"][api.value]) - if existing_config and api.value in existing_config["providers"] - else None - ), - ) - provider_configs[api.value] = { - "provider_id": provider_spec.provider_id, - **config.dict(), - } + cprint(f"Configuring API surface: {api.value}", "white", attrs=["bold"]) + config_type = instantiate_class_type(provider_spec.config_class) + config = prompt_for_config( + config_type, + ( + config_type(**existing_config["providers"][api.value]) + if existing_config + and "providers" in existing_config + and api.value in existing_config["providers"] + else None + ), + ) + provider_configs[api.value] = { + "provider_id": provider_spec.provider_id, + **config.dict(), + } dist_config = { "providers": provider_configs, - "conda_env": conda_env, + **existing_config, } + config_path = DISTRIBS_BASE_DIR / existing_config["name"] / "config.yaml" with open(config_path, "w") as fp: dist_config = json.loads(json.dumps(dist_config, cls=EnumEncoder)) fp.write(yaml.dump(dist_config, sort_keys=False)) diff --git a/llama_toolchain/cli/distribution/create.py b/llama_toolchain/cli/distribution/create.py index bb0d9b42e..140f1027d 100644 --- a/llama_toolchain/cli/distribution/create.py +++ b/llama_toolchain/cli/distribution/create.py @@ -34,9 +34,9 @@ class DistributionCreate(Subcommand): # wants to pick and then ask for their configuration. 
def _run_distribution_create_cmd(self, args: argparse.Namespace) -> None: - from llama_toolchain.distribution.registry import resolve_distribution + from llama_toolchain.distribution.registry import resolve_distribution_spec - dist = resolve_distribution(args.name) + dist = resolve_distribution_spec(args.name) if dist is not None: self.parser.error(f"Distribution with name {args.name} already exists") return diff --git a/llama_toolchain/cli/distribution/install.py b/llama_toolchain/cli/distribution/install.py index 3679f8786..e30e05268 100644 --- a/llama_toolchain/cli/distribution/install.py +++ b/llama_toolchain/cli/distribution/install.py @@ -10,6 +10,7 @@ import shlex import textwrap import pkg_resources +import yaml from termcolor import cprint @@ -32,26 +33,31 @@ class DistributionInstall(Subcommand): self.parser.set_defaults(func=self._run_distribution_install_cmd) def _add_arguments(self): - from llama_toolchain.distribution.registry import available_distributions + from llama_toolchain.distribution.registry import available_distribution_specs + self.parser.add_argument( + "--spec", + type=str, + help="Distribution spec to install (try ollama-inline)", + required=True, + choices=[d.spec_id for d in available_distribution_specs()], + ) self.parser.add_argument( "--name", type=str, - help="Name of the distribution to install -- (try local-ollama)", + help="What should the installation be called locally?", required=True, - choices=[d.name for d in available_distributions()], ) self.parser.add_argument( "--conda-env", type=str, - help="Specify the name of the conda environment you would like to create or update", - required=True, + help="conda env in which this distribution will run (default = distribution name)", ) def _run_distribution_install_cmd(self, args: argparse.Namespace) -> None: from llama_toolchain.common.exec import run_with_pty from llama_toolchain.distribution.distribution import distribution_dependencies - from 
llama_toolchain.distribution.registry import resolve_distribution + from llama_toolchain.distribution.registry import resolve_distribution_spec os.makedirs(DISTRIBS_BASE_DIR, exist_ok=True) script = pkg_resources.resource_filename( @@ -59,25 +65,36 @@ class DistributionInstall(Subcommand): "distribution/install_distribution.sh", ) - dist = resolve_distribution(args.name) + dist = resolve_distribution_spec(args.spec) if dist is None: - self.parser.error(f"Could not find distribution {args.name}") + self.parser.error(f"Could not find distribution {args.spec}") return - os.makedirs(DISTRIBS_BASE_DIR / dist.name, exist_ok=True) + distrib_dir = DISTRIBS_BASE_DIR / args.name + os.makedirs(distrib_dir, exist_ok=True) deps = distribution_dependencies(dist) - return_code = run_with_pty([script, args.conda_env, " ".join(deps)]) + if not args.conda_env: + print(f"Using {args.name} as the Conda environment for this distribution") + + conda_env = args.conda_env or args.name + return_code = run_with_pty([script, conda_env, " ".join(deps)]) assert return_code == 0, cprint( - f"Failed to install distribution {dist.name}", color="red" + f"Failed to install distribution {dist.spec_id}", color="red" ) - with open(DISTRIBS_BASE_DIR / dist.name / "conda.env", "w") as f: - f.write(f"{args.conda_env}\n") + config_file = distrib_dir / "config.yaml" + with open(config_file, "w") as f: + c = { + "conda_env": conda_env, + "spec": dist.spec_id, + "name": args.name, + } + f.write(yaml.dump(c)) cprint( - f"Distribution `{dist.name}` has been installed successfully!", + f"Distribution `{args.name}` (with spec {dist.spec_id}) has been installed successfully!", color="green", ) print( @@ -85,8 +102,7 @@ class DistributionInstall(Subcommand): f""" Update your conda environment and configure this distribution by running: - conda deactivate && conda activate {args.conda_env} - llama distribution configure --name {dist.name} + conda deactivate && conda activate {conda_env} + llama distribution 
configure --name {args.name} """ - ) - ) + )) diff --git a/llama_toolchain/cli/distribution/list.py b/llama_toolchain/cli/distribution/list.py index 3d6b69186..b285f2006 100644 --- a/llama_toolchain/cli/distribution/list.py +++ b/llama_toolchain/cli/distribution/list.py @@ -28,23 +28,23 @@ class DistributionList(Subcommand): def _run_distribution_list_cmd(self, args: argparse.Namespace) -> None: from llama_toolchain.cli.table import print_table - from llama_toolchain.distribution.registry import available_distributions + from llama_toolchain.distribution.registry import available_distribution_specs # eventually, this should query a registry at llama.meta.com/llamastack/distributions headers = [ - "Name", + "Spec ID", "ProviderSpecs", "Description", ] rows = [] - for dist in available_distributions(): - providers = {k.value: v.provider_id for k, v in dist.provider_specs.items()} + for spec in available_distribution_specs(): + providers = {k.value: v.provider_id for k, v in spec.provider_specs.items()} rows.append( [ - dist.name, + spec.spec_id, json.dumps(providers, indent=2), - dist.description, + spec.description, ] ) print_table( diff --git a/llama_toolchain/cli/distribution/start.py b/llama_toolchain/cli/distribution/start.py index a1dbd9438..c106a237c 100644 --- a/llama_toolchain/cli/distribution/start.py +++ b/llama_toolchain/cli/distribution/start.py @@ -6,7 +6,6 @@ import argparse import shlex -from pathlib import Path import yaml @@ -49,22 +48,23 @@ class DistributionStart(Subcommand): def _run_distribution_start_cmd(self, args: argparse.Namespace) -> None: from llama_toolchain.common.exec import run_command - from llama_toolchain.distribution.registry import resolve_distribution + from llama_toolchain.distribution.registry import resolve_distribution_spec from llama_toolchain.distribution.server import main as distribution_server_init - dist = resolve_distribution(args.name) - if dist is None: - self.parser.error(f"Distribution with name {args.name} not 
found") + config_file = DISTRIBS_BASE_DIR / args.name / "config.yaml" + if not config_file.exists(): + self.parser.error( + f"Could not find {config_file}. Please run `llama distribution install` first" + ) return - config_yaml = Path(DISTRIBS_BASE_DIR) / dist.name / "config.yaml" - if not config_yaml.exists(): - raise ValueError( - f"Configuration {config_yaml} does not exist. Please run `llama distribution install` or `llama distribution configure` first" - ) + # we need to find the spec from the name + with open(config_file, "r") as f: + config = yaml.safe_load(f) - with open(config_yaml, "r") as fp: - config = yaml.safe_load(fp) + dist = resolve_distribution_spec(config["spec"]) + if dist is None: + raise ValueError(f"Could not find any registered spec `{config['spec']}`") conda_env = config["conda_env"] @@ -76,8 +76,7 @@ class DistributionStart(Subcommand): ) distribution_server_init( - dist.name, - config_yaml, + config_file, args.port, disable_ipv6=args.disable_ipv6, ) diff --git a/llama_toolchain/distribution/datatypes.py b/llama_toolchain/distribution/datatypes.py index 85dcdae81..00aa07682 100644 --- a/llama_toolchain/distribution/datatypes.py +++ b/llama_toolchain/distribution/datatypes.py @@ -5,7 +5,7 @@ # the root directory of this source tree. from enum import Enum -from typing import Dict, List +from typing import Dict, List, Optional from pydantic import BaseModel, Field from strong_typing.schema import json_schema_type @@ -29,6 +29,10 @@ class ApiEndpoint(BaseModel): class ProviderSpec(BaseModel): api: Api provider_id: str + config_class: str = Field( + ..., + description="Fully-qualified classname of the config for this provider", + ) @json_schema_type @@ -45,23 +49,21 @@ Fully-qualified name of the module to import. 
The module is expected to have: - `get_provider_impl(config, deps)`: returns the local implementation """, ) - config_class: str = Field( - ..., - description="Fully-qualified classname of the config for this provider", - ) api_dependencies: List[Api] = Field( default_factory=list, description="Higher-level API surfaces may depend on other providers to provide their functionality", ) +class RemoteProviderConfig(BaseModel): + base_url: str = Field(..., description="The base URL for the llama stack provider") + api_key: Optional[str] = Field( + ..., description="API key, if needed, for the provider" + ) + + @json_schema_type class RemoteProviderSpec(ProviderSpec): - base_url: str = Field(..., description="The base URL for the llama stack provider") - headers: Dict[str, str] = Field( - default_factory=dict, - description="Headers (e.g., authorization) to send with the request", - ) module: str = Field( ..., description=""" @@ -69,10 +71,12 @@ Fully-qualified name of the module to import. The module is expected to have: - `get_client_impl(base_url)`: returns a client which can be used to call the remote implementation """, ) + config_class: str = "llama_toolchain.distribution.datatypes.RemoteProviderConfig" -class Distribution(BaseModel): - name: str +@json_schema_type +class DistributionSpec(BaseModel): + spec_id: str description: str provider_specs: Dict[Api, ProviderSpec] = Field( @@ -84,3 +88,12 @@ class Distribution(BaseModel): default_factory=list, description="Additional pip packages beyond those required by the providers", ) + + +@json_schema_type +class InstalledDistribution(BaseModel): + """References to an installed / configured DistributionSpec""" + + name: str + spec_id: str + # This is the class which represents the configs written by `configure` diff --git a/llama_toolchain/distribution/distribution.py b/llama_toolchain/distribution/distribution.py index 294f9bd4e..853092f38 100644 --- a/llama_toolchain/distribution/distribution.py +++ 
b/llama_toolchain/distribution/distribution.py @@ -8,13 +8,22 @@ import inspect from typing import Dict, List from llama_toolchain.agentic_system.api.endpoints import AgenticSystem +from llama_toolchain.agentic_system.providers import available_agentic_system_providers from llama_toolchain.inference.api.endpoints import Inference +from llama_toolchain.inference.providers import available_inference_providers from llama_toolchain.safety.api.endpoints import Safety +from llama_toolchain.safety.providers import available_safety_providers -from .datatypes import Api, ApiEndpoint, Distribution, InlineProviderSpec +from .datatypes import ( + Api, + ApiEndpoint, + DistributionSpec, + InlineProviderSpec, + ProviderSpec, +) -def distribution_dependencies(distribution: Distribution) -> List[str]: +def distribution_dependencies(distribution: DistributionSpec) -> List[str]: # only consider InlineProviderSpecs when calculating dependencies return [ dep @@ -51,3 +60,19 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]: apis[api] = endpoints return apis + + +def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]: + inference_providers_by_id = { + a.provider_id: a for a in available_inference_providers() + } + safety_providers_by_id = {a.provider_id: a for a in available_safety_providers()} + agentic_system_providers_by_id = { + a.provider_id: a for a in available_agentic_system_providers() + } + + return { + Api.inference: inference_providers_by_id, + Api.safety: safety_providers_by_id, + Api.agentic_system: agentic_system_providers_by_id, + } diff --git a/llama_toolchain/distribution/registry.py b/llama_toolchain/distribution/registry.py index 54ff74c62..1f3021599 100644 --- a/llama_toolchain/distribution/registry.py +++ b/llama_toolchain/distribution/registry.py @@ -7,12 +7,8 @@ from functools import lru_cache from typing import List, Optional -from llama_toolchain.agentic_system.providers import available_agentic_system_providers - -from 
llama_toolchain.inference.providers import available_inference_providers -from llama_toolchain.safety.providers import available_safety_providers - -from .datatypes import Api, Distribution, RemoteProviderSpec +from .datatypes import Api, DistributionSpec, RemoteProviderSpec +from .distribution import api_providers # This is currently duplicated from `requirements.txt` with a few minor changes # dev-dependencies like "ufmt" etc. are nuked. A few specialized dependencies @@ -49,39 +45,30 @@ def client_module(api: Api) -> str: return f"llama_toolchain.{api.value}.client" -def remote(api: Api, port: int) -> RemoteProviderSpec: +def remote_spec(api: Api) -> RemoteProviderSpec: return RemoteProviderSpec( api=api, provider_id=f"{api.value}-remote", - base_url=f"http://localhost:{port}", module=client_module(api), ) @lru_cache() -def available_distributions() -> List[Distribution]: - inference_providers_by_id = { - a.provider_id: a for a in available_inference_providers() - } - safety_providers_by_id = {a.provider_id: a for a in available_safety_providers()} - agentic_system_providers_by_id = { - a.provider_id: a for a in available_agentic_system_providers() - } - +def available_distribution_specs() -> List[DistributionSpec]: + providers = api_providers() return [ - Distribution( - name="local-inline", + DistributionSpec( + spec_id="inline", description="Use code from `llama_toolchain` itself to serve all llama stack APIs", additional_pip_packages=COMMON_DEPENDENCIES, provider_specs={ - Api.inference: inference_providers_by_id["meta-reference"], - Api.safety: safety_providers_by_id["meta-reference"], - Api.agentic_system: agentic_system_providers_by_id["meta-reference"], + Api.inference: providers[Api.inference]["meta-reference"], + Api.safety: providers[Api.safety]["meta-reference"], + Api.agentic_system: providers[Api.agentic_system]["meta-reference"], }, ), - # NOTE: this hardcodes the ports to which things point to - Distribution( - name="full-remote", + 
DistributionSpec( + spec_id="remote", description="Point to remote services for all llama stack APIs", additional_pip_packages=[ "python-dotenv", @@ -97,28 +84,24 @@ def available_distributions() -> List[Distribution]: "pydantic_core==2.18.2", "uvicorn", ], - provider_specs={ - Api.inference: remote(Api.inference, 5001), - Api.safety: remote(Api.safety, 5001), - Api.agentic_system: remote(Api.agentic_system, 5001), - }, + provider_specs={x: remote_spec(x) for x in providers}, ), - Distribution( - name="local-ollama", + DistributionSpec( + spec_id="ollama-inline", description="Like local-source, but use ollama for running LLM inference", additional_pip_packages=COMMON_DEPENDENCIES, provider_specs={ - Api.inference: inference_providers_by_id["meta-ollama"], - Api.safety: safety_providers_by_id["meta-reference"], - Api.agentic_system: agentic_system_providers_by_id["meta-reference"], + Api.inference: providers[Api.inference]["meta-ollama"], + Api.safety: providers[Api.safety]["meta-reference"], + Api.agentic_system: providers[Api.agentic_system]["meta-reference"], }, ), ] @lru_cache() -def resolve_distribution(name: str) -> Optional[Distribution]: - for dist in available_distributions(): - if dist.name == name: - return dist +def resolve_distribution_spec(spec_id: str) -> Optional[DistributionSpec]: + for spec in available_distribution_specs(): + if spec.spec_id == spec_id: + return spec return None diff --git a/llama_toolchain/distribution/server.py b/llama_toolchain/distribution/server.py index 4c61d1e40..39a6dac80 100644 --- a/llama_toolchain/distribution/server.py +++ b/llama_toolchain/distribution/server.py @@ -36,11 +36,11 @@ from fastapi.routing import APIRoute from pydantic import BaseModel, ValidationError from termcolor import cprint -from .datatypes import Api, Distribution, ProviderSpec, RemoteProviderSpec +from .datatypes import Api, DistributionSpec, ProviderSpec, RemoteProviderSpec from .distribution import api_endpoints from .dynamic import 
instantiate_client, instantiate_provider -from .registry import resolve_distribution +from .registry import resolve_distribution_spec load_dotenv() @@ -250,7 +250,7 @@ def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]: return [by_id[x] for x in stack] -def resolve_impls(dist: Distribution, config: Dict[str, Any]) -> Dict[Api, Any]: +def resolve_impls(dist: DistributionSpec, config: Dict[str, Any]) -> Dict[Api, Any]: provider_configs = config["providers"] provider_specs = topological_sort(dist.provider_specs.values()) @@ -265,7 +265,7 @@ def resolve_impls(dist: Distribution, config: Dict[str, Any]) -> Dict[Api, Any]: provider_config = provider_configs[api.value] if isinstance(provider_spec, RemoteProviderSpec): impls[api] = instantiate_client( - provider_spec, provider_spec.base_url.rstrip("/") + provider_spec, provider_config.base_url.rstrip("/") ) else: deps = {api: impls[api] for api in provider_spec.api_dependencies} @@ -275,16 +275,15 @@ def resolve_impls(dist: Distribution, config: Dict[str, Any]) -> Dict[Api, Any]: return impls -def main( - dist_name: str, yaml_config: str, port: int = 5000, disable_ipv6: bool = False -): - dist = resolve_distribution(dist_name) - if dist is None: - raise ValueError(f"Could not find distribution {dist_name}") - +def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False): with open(yaml_config, "r") as fp: config = yaml.safe_load(fp) + spec = config["spec"] + dist = resolve_distribution_spec(spec) + if dist is None: + raise ValueError(f"Could not find distribution specification `{spec}`") + app = FastAPI() all_endpoints = api_endpoints() @@ -293,14 +292,15 @@ def main( for provider_spec in dist.provider_specs.values(): api = provider_spec.api endpoints = all_endpoints[api] + impl = impls[api] + if isinstance(provider_spec, RemoteProviderSpec): for endpoint in endpoints: - url = provider_spec.base_url.rstrip("/") + endpoint.route + url = impl.base_url + endpoint.route getattr(app, 
endpoint.method)(endpoint.route)( create_dynamic_passthrough(url) ) else: - impl = impls[api] for endpoint in endpoints: if not hasattr(impl, endpoint.name): # ideally this should be a typing violation already