diff --git a/llama_toolchain/cli/api/api.py b/llama_toolchain/cli/api/api.py index e6d7e0a1e..e482355d2 100644 --- a/llama_toolchain/cli/api/api.py +++ b/llama_toolchain/cli/api/api.py @@ -9,6 +9,7 @@ import argparse from llama_toolchain.cli.subcommand import Subcommand from .build import ApiBuild +from .configure import ApiConfigure class ApiParser(Subcommand): @@ -24,3 +25,4 @@ class ApiParser(Subcommand): # Add sub-commands ApiBuild.create(subparsers) + ApiConfigure.create(subparsers) diff --git a/llama_toolchain/cli/api/build.py b/llama_toolchain/cli/api/build.py index f3f620b1f..0b589cdcb 100644 --- a/llama_toolchain/cli/api/build.py +++ b/llama_toolchain/cli/api/build.py @@ -6,22 +6,92 @@ import argparse import os +import random +import string +import uuid +from pydantic import BaseModel +from datetime import datetime +from enum import Enum +from typing import Dict, List, Optional, Tuple import pkg_resources import yaml -from llama_toolchain.cli.subcommand import Subcommand -from llama_toolchain.common.config_dirs import DISTRIBS_BASE_DIR - from termcolor import cprint +from llama_toolchain.cli.subcommand import Subcommand +from llama_toolchain.common.config_dirs import BUILDS_BASE_DIR +from llama_toolchain.distribution.datatypes import * # noqa: F403 + + +def random_string(): + return "".join(random.choices(string.ascii_letters + string.digits, k=8)) + + +class BuildType(Enum): + container = "container" + conda_env = "conda_env" + + +class Dependencies(BaseModel): + pip_packages: List[str] + docker_image: Optional[str] = None + + +def get_dependencies( + provider: ProviderSpec, dependencies: Dict[str, ProviderSpec] +) -> Dependencies: + def _deps(provider: ProviderSpec) -> Tuple[List[str], Optional[str]]: + if isinstance(provider, InlineProviderSpec): + return provider.pip_packages, provider.docker_image + else: + if provider.adapter: + return provider.adapter.pip_packages, None + return [], None + + pip_packages, docker_image = _deps(provider) + for dep in dependencies.values(): + dep_pip_packages, dep_docker_image = _deps(dep) + if docker_image and dep_docker_image: + raise ValueError( + "You can only have the root provider specify a docker image" + ) + + pip_packages.extend(dep_pip_packages) + + return Dependencies(docker_image=docker_image, pip_packages=pip_packages) + + +def parse_dependencies( + dependencies: str, parser: argparse.ArgumentParser +) -> Dict[str, ProviderSpec]: + from llama_toolchain.distribution.distribution import api_providers + + all_providers = api_providers() + + deps = {} + for dep in dependencies.split(","): + dep = dep.strip() + if not dep: + continue + api_str, provider = dep.split("=") + api = Api(api_str) + + provider = provider.strip() + if provider not in all_providers[api]: + parser.error(f"Provider `{provider}` is not available for API `{api}`") + return + deps[api] = all_providers[api][provider] + + return deps + class ApiBuild(Subcommand): def __init__(self, subparsers: argparse._SubParsersAction): super().__init__() self.parser = subparsers.add_parser( - "install", + "build", prog="llama api build", description="Build a Llama stack API provider container", formatter_class=argparse.RawTextHelpFormatter, @@ -36,7 +106,7 @@ class ApiBuild(Subcommand): self.parser.add_argument( "api", choices=allowed_args, - help="Stack API (one of: {})".format(", ".join(allowed_args)) + help="Stack API (one of: {})".format(", ".join(allowed_args)), ) self.parser.add_argument( @@ -45,73 +115,104 @@ class ApiBuild(Subcommand): help="The provider to package into the 
container", required=True, ) - self.parser.add_argument( - "--container-name", - type=str, - help="Name of the container (including tag if needed)", - required=True, - ) self.parser.add_argument( "--dependencies", type=str, help="Comma separated list of (downstream_api=provider) dependencies needed for the API", required=False, ) + self.parser.add_argument( + "--name", + type=str, + help="Name of the build target (image, conda env). Defaults to a random UUID", + required=False, + ) + self.parser.add_argument( + "--type", + type=str, + default="container", + choices=[v.value for v in BuildType], + ) def _run_api_build_command(self, args: argparse.Namespace) -> None: from llama_toolchain.common.exec import run_with_pty - from llama_toolchain.distribution.datatypes import DistributionConfig - from llama_toolchain.distribution.distribution import distribution_dependencies - from llama_toolchain.distribution.registry import resolve_distribution_spec + from llama_toolchain.distribution.distribution import api_providers - os.makedirs(DISTRIBS_BASE_DIR, exist_ok=True) - script = pkg_resources.resource_filename( - "llama_toolchain", - "distribution/build_api.sh", - ) + os.makedirs(BUILDS_BASE_DIR, exist_ok=True) + all_providers = api_providers() - dist = resolve_distribution_spec(args.spec) - if dist is None: - self.parser.error(f"Could not find distribution {args.spec}") + api = Api(args.api) + assert api in all_providers + + providers = all_providers[api] + if args.provider not in providers: + self.parser.error( + f"Provider `{args.provider}` is not available for API `{api}`" + ) return - distrib_dir = DISTRIBS_BASE_DIR / args.name - os.makedirs(distrib_dir, exist_ok=True) - - deps = distribution_dependencies(dist) - if not args.conda_env: - print(f"Using {args.name} as the Conda environment for this distribution") - - conda_env = args.conda_env or args.name - - config_file = distrib_dir / "config.yaml" - if config_file.exists(): - c = DistributionConfig(**yaml.safe_load(config_file.read_text())) - if c.spec != dist.spec_id: - self.parser.error( - f"already installed distribution with `spec={c.spec}` does not match provided spec `{args.spec}`" - ) - return - if c.conda_env != conda_env: - self.parser.error( - f"already installed distribution has `conda_env={c.conda_env}` different from provided conda env `{conda_env}`" - ) - return + name = args.name or random_string() + if args.type == BuildType.container.value: + package_name = f"image-{args.provider}-{name}" else: - with open(config_file, "w") as f: - c = DistributionConfig( - spec=dist.spec_id, - name=args.name, - conda_env=conda_env, - ) - f.write(yaml.dump(c.dict(), sort_keys=False)) + package_name = f"env-{args.provider}-{name}" + package_name = package_name.replace("::", "-") - return_code = run_with_pty([script, conda_env, args.name, " ".join(deps)]) + build_dir = BUILDS_BASE_DIR / args.api + os.makedirs(build_dir, exist_ok=True) + provider_deps = parse_dependencies(args.dependencies or "", self.parser) + dependencies = get_dependencies(providers[args.provider], provider_deps) + + package_file = build_dir / f"{package_name}.yaml" + + stub_config = { + api.value: { + "provider_id": args.provider, + }, + **{k: {"provider_id": v} for k, v in provider_deps.items()}, + } + with open(package_file, "w") as f: + c = PackageConfig( + built_at=datetime.now(), + package_name=package_name, + docker_image=( + package_name if args.type == BuildType.container.value else None + ), + conda_env=( + package_name if args.type == BuildType.conda_env.value 
else None + ), + providers=stub_config, + ) + f.write(yaml.dump(c.dict(), sort_keys=False)) + + if args.type == BuildType.container.value: + script = pkg_resources.resource_filename( + "llama_toolchain", "distribution/build_container.sh" + ) + args = [ + script, + args.api, + package_name, + dependencies.docker_image or "python:3.10-slim", + " ".join(dependencies.pip_packages), + ] + else: + script = pkg_resources.resource_filename( + "llama_toolchain", "distribution/build_conda_env.sh" + ) + args = [ + script, + args.api, + package_name, + " ".join(dependencies.pip_packages), + ] + + return_code = run_with_pty(args) assert return_code == 0, cprint( - f"Failed to install distribution {dist.spec_id}", color="red" + f"Failed to build target {package_name}", color="red" ) cprint( - f"Distribution `{args.name}` (with spec {dist.spec_id}) has been installed successfully!", + f"Target `{package_name}` built with configuration at {str(package_file)}", color="green", ) diff --git a/llama_toolchain/cli/api/configure.py b/llama_toolchain/cli/api/configure.py new file mode 100644 index 000000000..c2cb8b16f --- /dev/null +++ b/llama_toolchain/cli/api/configure.py @@ -0,0 +1,100 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import argparse +import json + +from pathlib import Path + +import yaml + +from llama_toolchain.cli.subcommand import Subcommand +from llama_toolchain.common.config_dirs import BUILDS_BASE_DIR +from llama_toolchain.distribution.datatypes import * # noqa: F403 +from termcolor import cprint + + +class ApiConfigure(Subcommand): + """Llama cli for configuring llama toolchain configs""" + + def __init__(self, subparsers: argparse._SubParsersAction): + super().__init__() + self.parser = subparsers.add_parser( + "configure", + prog="llama api configure", + description="configure a llama stack API provider", + formatter_class=argparse.RawTextHelpFormatter, + ) + self._add_arguments() + self.parser.set_defaults(func=self._run_api_configure_cmd) + + def _add_arguments(self): + from llama_toolchain.distribution.distribution import stack_apis + + allowed_args = [a.name for a in stack_apis()] + self.parser.add_argument( + "api", + choices=allowed_args, + help="Stack API (one of: {})".format(", ".join(allowed_args)), + ) + self.parser.add_argument( + "--name", + type=str, + help="Name of the provider build to fully configure", + required=True, + ) + + def _run_api_configure_cmd(self, args: argparse.Namespace) -> None: + config_file = BUILDS_BASE_DIR / args.api / f"{args.name}.yaml" + if not config_file.exists(): + self.parser.error( + f"Could not find {config_file}. 
Please run `llama api build` first" + ) + return + + configure_llama_provider(config_file) + + +def configure_llama_provider(config_file: Path) -> None: + from llama_toolchain.common.prompt_for_config import prompt_for_config + from llama_toolchain.common.serialize import EnumEncoder + from llama_toolchain.distribution.distribution import api_providers + from llama_toolchain.distribution.dynamic import instantiate_class_type + + with open(config_file, "r") as f: + config = PackageConfig(**yaml.safe_load(f)) + + all_providers = api_providers() + + provider_configs = {} + for api, stub_config in config.providers.items(): + providers = all_providers[Api(api)] + provider_id = stub_config["provider_id"] + if provider_id not in providers: + raise ValueError( + f"Unknown provider `{provider_id}` is not available for API `{api}`" + ) + + provider_spec = providers[provider_id] + cprint(f"Configuring API surface: {api}", "white", attrs=["bold"]) + config_type = instantiate_class_type(provider_spec.config_class) + print(f"Config type: {config_type}") + provider_config = prompt_for_config( + config_type, + ) + print("") + + provider_configs[api.value] = { + "provider_id": provider_id, + **provider_config.dict(), + } + + config.providers = provider_configs + with open(config_file, "w") as fp: + to_write = json.loads(json.dumps(config.dict(), cls=EnumEncoder)) + fp.write(yaml.dump(to_write, sort_keys=False)) + + print(f"YAML configuration has been written to {config_path}") diff --git a/llama_toolchain/common/config_dirs.py b/llama_toolchain/common/config_dirs.py index e625234ab..adf3876a3 100644 --- a/llama_toolchain/common/config_dirs.py +++ b/llama_toolchain/common/config_dirs.py @@ -13,3 +13,5 @@ LLAMA_STACK_CONFIG_DIR = Path(os.path.expanduser("~/.llama/")) DISTRIBS_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "distributions" DEFAULT_CHECKPOINT_DIR = LLAMA_STACK_CONFIG_DIR / "checkpoints" + +BUILDS_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "builds" diff --git a/llama_toolchain/common/prompt_for_config.py b/llama_toolchain/common/prompt_for_config.py index 6c53477d8..c87716750 100644 --- a/llama_toolchain/common/prompt_for_config.py +++ b/llama_toolchain/common/prompt_for_config.py @@ -71,6 +71,7 @@ def prompt_for_config( """ config_data = {} + print(f"Configuring {config_type.__name__}:") for field_name, field in config_type.__fields__.items(): field_type = field.annotation @@ -85,6 +86,7 @@ def prompt_for_config( if not isinstance(field.default, PydanticUndefinedType) else None ) + print(f" {field_name}: {field_type} (default: {default_value})") is_required = field.is_required # Skip fields with Literal type diff --git a/llama_toolchain/distribution/install_distribution.sh b/llama_toolchain/distribution/build_conda_env.sh similarity index 83% rename from llama_toolchain/distribution/install_distribution.sh rename to llama_toolchain/distribution/build_conda_env.sh index 7cb343cfb..0b45edf09 100755 --- a/llama_toolchain/distribution/install_distribution.sh +++ b/llama_toolchain/distribution/build_conda_env.sh @@ -10,20 +10,29 @@ LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-} LLAMA_TOOLCHAIN_DIR=${LLAMA_TOOLCHAIN_DIR:-} TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-} +echo "llama-toolchain-dir=$LLAMA_TOOLCHAIN_DIR" set -euo pipefail +if [ "$#" -ne 3 ]; then + echo "Usage: $0 " >&2 + echo "Example: $0 [api|stack] conda-env 'numpy pandas scipy'" >&2 + exit 1 +fi + +api_or_stack="$1" +env_name="$2" +pip_dependencies="$3" + # Define color codes RED='\033[0;31m' GREEN='\033[0;32m' NC='\033[0m' # No Color -error_handler() { - echo 
"Error occurred in script at line: ${1}" >&2 - exit 1 -} +# this is set if we actually create a new conda in which case we need to clean up +ENVNAME="" -# Set up the error trap -trap 'error_handler ${LINENO}' ERR +SCRIPT_DIR=$(dirname "$(readlink -f "$0")") +source "$SCRIPT_DIR/common.sh" ensure_conda_env_python310() { local env_name="$1" @@ -52,6 +61,9 @@ ensure_conda_env_python310() { else echo "Conda environment '${env_name}' does not exist. Creating with Python ${python_version}..." conda create -n "${env_name}" python="${python_version}" -y + + ENVNAME="${env_name}" + setup_cleanup_handlers fi eval "$(conda shell.bash hook)" @@ -94,19 +106,8 @@ ensure_conda_env_python310() { fi } -if [ "$#" -ne 3 ]; then - echo "Usage: $0 " >&2 - echo "Example: $0 my_env local-llama-8b 'numpy pandas scipy'" >&2 - exit 1 -fi - -env_name="$1" -distribution_name="$2" -pip_dependencies="$3" - ensure_conda_env_python310 "$env_name" "$pip_dependencies" -echo -e "${GREEN}Successfully setup distribution environment. Configuring...${NC}" +echo -e "${GREEN}Successfully setup conda environment. Configuring build...${NC}" -which python3 -python3 -m llama_toolchain.cli.llama distribution configure --name "$distribution_name" +$CONDA_PREFIX/bin/python3 -m llama_toolchain.cli.llama api configure "$api_or_stack" --name "$env_name" diff --git a/llama_toolchain/distribution/build_image.sh b/llama_toolchain/distribution/build_container.sh old mode 100644 new mode 100755 similarity index 100% rename from llama_toolchain/distribution/build_image.sh rename to llama_toolchain/distribution/build_container.sh diff --git a/llama_toolchain/distribution/common.sh b/llama_toolchain/distribution/common.sh new file mode 100644 index 000000000..963eb395b --- /dev/null +++ b/llama_toolchain/distribution/common.sh @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +cleanup() { + envname="$1" + + set +x + echo "Cleaning up..." + conda deactivate + conda env remove --name $envname -y +} + +handle_int() { + if [ -n $ENVNAME ]; then + cleanup $ENVNAME + fi + exit 1 +} + +handle_exit() { + if [ $? -ne 0 ]; then + echo -e "\033[1;31mABORTING.\033[0m" + if [ -n $ENVNAME ]; then + cleanup $ENVNAME + fi + fi +} + +setup_cleanup_handlers() { + trap handle_int INT + trap handle_exit EXIT + + __conda_setup="$('conda' 'shell.bash' 'hook' 2>/dev/null)" + eval "$__conda_setup" + + conda deactivate +} diff --git a/llama_toolchain/distribution/datatypes.py b/llama_toolchain/distribution/datatypes.py index fbfc5aaed..cb4f06fe1 100644 --- a/llama_toolchain/distribution/datatypes.py +++ b/llama_toolchain/distribution/datatypes.py @@ -4,6 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from datetime import datetime from enum import Enum from typing import Any, Dict, List, Optional @@ -66,36 +67,45 @@ Fully-qualified name of the module to import. The module is expected to have: @json_schema_type class AdapterSpec(BaseModel): + adapter_id: str = Field( + ..., + description="Unique identifier for this adapter", + ) + module: str = Field( + ..., + description=""" +Fully-qualified name of the module to import. 
The module is expected to have: + + - `get_adapter_impl(config, deps)`: returns the adapter implementation +""", + ) pip_packages: List[str] = Field( default_factory=list, description="The pip dependencies needed for this implementation", ) - config_class: str = Field( - ..., + config_class: Optional[str] = Field( + default=None, description="Fully-qualified classname of the config for this provider", ) class RemoteProviderConfig(BaseModel): - base_url: str = Field(..., description="The base URL for the llama stack provider") + url: str = Field(..., description="The URL for the provider") - @validator("base_url") + @validator("url") @classmethod - def validate_base_url(cls, base_url: str) -> str: - if not base_url.startswith("http"): - raise ValueError(f"URL must start with http: {base_url}") - return base_url + def validate_url(cls, url: str) -> str: + if not url.startswith("http"): + raise ValueError(f"URL must start with http: {url}") + return url + + +def remote_provider_id(adapter_id: str) -> str: + return f"remote::{adapter_id}" @json_schema_type class RemoteProviderSpec(ProviderSpec): - module: str = Field( - ..., - description=""" -Fully-qualified name of the module to import. The module is expected to have: - - `get_client_impl(base_url)`: returns a client which can be used to call the remote implementation -""", - ) adapter: Optional[AdapterSpec] = Field( default=None, description=""" @@ -107,6 +117,32 @@ as being "Llama Stack compatible" config_class: str = "llama_toolchain.distribution.datatypes.RemoteProviderConfig" +# need this wrapper since we don't have Pydantic v2 and that means we don't have +# the @computed_field decorator +def remote_provider_spec( + api: Api, adapter: Optional[AdapterSpec] = None +) -> RemoteProviderSpec: + provider_id = ( + remote_provider_id(adapter.adapter_id) if adapter is not None else "remote" + ) + module = ( + adapter.module if adapter is not None else f"llama_toolchain.{api.value}.client" + ) + config_class = ( + adapter.config_class + if adapter and adapter.config_class + else "llama_toolchain.distribution.datatypes.RemoteProviderConfig" + ) + + return RemoteProviderSpec( + api=api, + provider_id=provider_id, + pip_packages=adapter.pip_packages if adapter is not None else [], + module=module, + config_class=config_class, + ) + + @json_schema_type class DistributionSpec(BaseModel): spec_id: str @@ -119,13 +155,28 @@ class DistributionSpec(BaseModel): @json_schema_type -class DistributionConfig(BaseModel): - """References to a installed / configured DistributionSpec""" +class PackageConfig(BaseModel): + built_at: datetime - name: str - spec: str - conda_env: str + package_name: str = Field( + ..., + description=""" +Reference to the distribution this package refers to. For unregistered (adhoc) packages, +this could be just a hash +""", + ) + docker_image: Optional[str] = Field( + default=None, + description="Reference to the docker image if this package refers to a container", + ) + conda_env: Optional[str] = Field( + default=None, + description="Reference to the conda environment if this package refers to a conda environment", + ) providers: Dict[str, Any] = Field( default_factory=dict, - description="Provider configurations for each of the APIs provided by this distribution", + description=""" +Provider configurations for each of the APIs provided by this package. This includes configurations for +the dependencies of these providers as well. 
+""", ) diff --git a/llama_toolchain/distribution/dynamic.py b/llama_toolchain/distribution/dynamic.py index 20fa038bf..f4057bbae 100644 --- a/llama_toolchain/distribution/dynamic.py +++ b/llama_toolchain/distribution/dynamic.py @@ -30,7 +30,27 @@ def instantiate_provider( return asyncio.run(module.get_provider_impl(config, deps)) -def instantiate_client(provider_spec: RemoteProviderSpec, base_url: str): +def instantiate_client( + provider_spec: RemoteProviderSpec, provider_config: Dict[str, Any] +): module = importlib.import_module(provider_spec.module) - return asyncio.run(module.get_client_impl(base_url)) + adapter = provider_spec.adapter + if adapter is not None: + if "adapter" not in provider_config: + raise ValueError( + f"Adapter is specified but not present in provider config: {provider_config}" + ) + adapter_config = provider_config["adapter"] + + config_type = instantiate_class_type(adapter.config_class) + if not issubclass(config_type, RemoteProviderConfig): + raise ValueError( + f"Config class {adapter.config_class} does not inherit from RemoteProviderConfig" + ) + + config = config_type(**adapter_config) + else: + config = RemoteProviderConfig(**provider_config) + + return asyncio.run(module.get_adapter_impl(config)) diff --git a/llama_toolchain/distribution/registry.py b/llama_toolchain/distribution/registry.py index bce702472..e1d49eb05 100644 --- a/llama_toolchain/distribution/registry.py +++ b/llama_toolchain/distribution/registry.py @@ -7,22 +7,10 @@ from functools import lru_cache from typing import List, Optional -from .datatypes import Api, DistributionSpec, RemoteProviderSpec +from .datatypes import * # noqa: F403 from .distribution import api_providers -def client_module(api: Api) -> str: - return f"llama_toolchain.{api.value}.client" - - -def remote_spec(api: Api) -> RemoteProviderSpec: - return RemoteProviderSpec( - api=api, - provider_id=f"{api.value}-remote", - module=client_module(api), - ) - - @lru_cache() def available_distribution_specs() -> List[DistributionSpec]: providers = api_providers() @@ -40,13 +28,14 @@ def available_distribution_specs() -> List[DistributionSpec]: DistributionSpec( spec_id="remote", description="Point to remote services for all llama stack APIs", - provider_specs={x: remote_spec(x) for x in providers}, + provider_specs={x: remote_provider_spec(x) for x in providers}, ), DistributionSpec( spec_id="local-ollama", description="Like local, but use ollama for running LLM inference", provider_specs={ - Api.inference: providers[Api.inference]["meta-ollama"], + # this is ODD; make this easier -- we just need a better function to retrieve registered providers + Api.inference: providers[Api.inference][remote_provider_id("ollama")], Api.safety: providers[Api.safety]["meta-reference"], Api.agentic_system: providers[Api.agentic_system]["meta-reference"], Api.memory: providers[Api.memory]["meta-reference-faiss"], @@ -57,9 +46,9 @@ def available_distribution_specs() -> List[DistributionSpec]: description="Test agentic with others as remote", provider_specs={ Api.agentic_system: providers[Api.agentic_system]["meta-reference"], - Api.inference: remote_spec(Api.inference), - Api.memory: remote_spec(Api.memory), - Api.safety: remote_spec(Api.safety), + Api.inference: remote_provider_spec(Api.inference), + Api.memory: remote_provider_spec(Api.memory), + Api.safety: remote_provider_spec(Api.safety), }, ), DistributionSpec( diff --git a/llama_toolchain/distribution/server.py b/llama_toolchain/distribution/server.py index dd92fd43e..271f6e87b 100644 --- 
a/llama_toolchain/distribution/server.py +++ b/llama_toolchain/distribution/server.py @@ -264,7 +264,8 @@ def resolve_impls(dist: DistributionSpec, config: Dict[str, Any]) -> Dict[Api, A provider_config = provider_configs[api.value] if isinstance(provider_spec, RemoteProviderSpec): impls[api] = instantiate_client( - provider_spec, provider_config["base_url"].rstrip("/") + provider_spec, + provider_config, ) else: deps = {api: impls[api] for api in provider_spec.api_dependencies} diff --git a/llama_toolchain/inference/adapters/__init__.py b/llama_toolchain/inference/adapters/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_toolchain/inference/adapters/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_toolchain/inference/ollama/__init__.py b/llama_toolchain/inference/adapters/ollama/__init__.py similarity index 68% rename from llama_toolchain/inference/ollama/__init__.py rename to llama_toolchain/inference/adapters/ollama/__init__.py index 40d79618a..dea3d660f 100644 --- a/llama_toolchain/inference/ollama/__init__.py +++ b/llama_toolchain/inference/adapters/ollama/__init__.py @@ -4,5 +4,4 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from .config import OllamaImplConfig # noqa -from .ollama import get_provider_impl # noqa +from .ollama import get_adapter_impl # noqa diff --git a/llama_toolchain/inference/ollama/ollama.py b/llama_toolchain/inference/adapters/ollama/ollama.py similarity index 94% rename from llama_toolchain/inference/ollama/ollama.py rename to llama_toolchain/inference/adapters/ollama/ollama.py index b1e1ca09c..05423fb96 100644 --- a/llama_toolchain/inference/ollama/ollama.py +++ b/llama_toolchain/inference/adapters/ollama/ollama.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from typing import AsyncGenerator, Dict +from typing import AsyncGenerator import httpx @@ -14,7 +14,7 @@ from llama_models.llama3.api.tokenizer import Tokenizer from llama_models.sku_list import resolve_model from ollama import AsyncClient -from llama_toolchain.distribution.datatypes import Api, ProviderSpec +from llama_toolchain.distribution.datatypes import RemoteProviderConfig from llama_toolchain.inference.api import ( ChatCompletionRequest, ChatCompletionResponse, @@ -27,7 +27,6 @@ from llama_toolchain.inference.api import ( ToolCallParseStatus, ) from llama_toolchain.inference.prepare_messages import prepare_messages -from .config import OllamaImplConfig # TODO: Eventually this will move to the llama cli model list command # mapping of Model SKUs to ollama models @@ -37,26 +36,21 @@ OLLAMA_SUPPORTED_SKUS = { } -async def get_provider_impl( - config: OllamaImplConfig, _deps: Dict[Api, ProviderSpec] -) -> Inference: - assert isinstance( - config, OllamaImplConfig - ), f"Unexpected config type: {type(config)}" - impl = OllamaInference(config) +async def get_adapter_impl(config: RemoteProviderConfig) -> Inference: + impl = OllamaInferenceAdapter(config.url) await impl.initialize() return impl -class OllamaInference(Inference): - def __init__(self, config: OllamaImplConfig) -> None: - self.config = config +class OllamaInferenceAdapter(Inference): + def __init__(self, url: str) -> None: + self.url = url tokenizer = Tokenizer.get_instance() self.formatter = ChatFormat(tokenizer) @property def client(self) -> AsyncClient: - return AsyncClient(host=self.config.url) + return AsyncClient(host=self.url) async def initialize(self) -> None: try: diff --git a/llama_toolchain/inference/client.py b/llama_toolchain/inference/client.py index ec7ed859b..d2ea88670 100644 --- a/llama_toolchain/inference/client.py +++ b/llama_toolchain/inference/client.py @@ -13,6 +13,8 @@ import httpx from pydantic import BaseModel from termcolor import cprint +from llama_toolchain.distribution.datatypes import RemoteProviderConfig + from .api import ( ChatCompletionRequest, ChatCompletionResponse, @@ -24,8 +26,8 @@ from .api import ( from .event_logger import EventLogger -async def get_client_impl(base_url: str): - return InferenceClient(base_url) +async def get_adapter_impl(config: RemoteProviderConfig) -> Inference: + return InferenceClient(config.url) def encodable_dict(d: BaseModel): @@ -34,7 +36,7 @@ def encodable_dict(d: BaseModel): class InferenceClient(Inference): def __init__(self, base_url: str): - print(f"Initializing client for {base_url}") + print(f"Inference passthrough to -> {base_url}") self.base_url = base_url async def initialize(self) -> None: diff --git a/llama_toolchain/inference/ollama/config.py b/llama_toolchain/inference/ollama/config.py deleted file mode 100644 index 10d109822..000000000 --- a/llama_toolchain/inference/ollama/config.py +++ /dev/null @@ -1,16 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
- -from llama_models.schema_utils import json_schema_type -from pydantic import BaseModel, Field - - -@json_schema_type -class OllamaImplConfig(BaseModel): - url: str = Field( - default="http://localhost:11434", - description="The URL for the ollama server", - ) diff --git a/llama_toolchain/inference/providers.py b/llama_toolchain/inference/providers.py index 1b1eb05a4..f50bb759e 100644 --- a/llama_toolchain/inference/providers.py +++ b/llama_toolchain/inference/providers.py @@ -6,7 +6,7 @@ from typing import List -from llama_toolchain.distribution.datatypes import Api, InlineProviderSpec, ProviderSpec +from llama_toolchain.distribution.datatypes import * # noqa: F403 def available_inference_providers() -> List[ProviderSpec]: @@ -27,13 +27,12 @@ def available_inference_providers() -> List[ProviderSpec]: module="llama_toolchain.inference.meta_reference", config_class="llama_toolchain.inference.meta_reference.MetaReferenceImplConfig", ), - InlineProviderSpec( + remote_provider_spec( api=Api.inference, - provider_id="meta-ollama", - pip_packages=[ - "ollama", - ], - module="llama_toolchain.inference.ollama", - config_class="llama_toolchain.inference.ollama.OllamaImplConfig", + adapter=AdapterSpec( + adapter_id="ollama", + pip_packages=[], + module="llama_toolchain.inference.adapters.ollama", + ), ), ] diff --git a/llama_toolchain/memory/client.py b/llama_toolchain/memory/client.py index d4009a190..cc8b80977 100644 --- a/llama_toolchain/memory/client.py +++ b/llama_toolchain/memory/client.py @@ -6,24 +6,23 @@ import asyncio -# import json from typing import Dict, List, Optional import fire import httpx -# from termcolor import cprint +from llama_toolchain.distribution.datatypes import RemoteProviderConfig from .api import * # noqa: F403 -async def get_client_impl(base_url: str): - return MemoryClient(base_url) +async def get_adapter_impl(config: RemoteProviderConfig) -> Memory: + return MemoryClient(config.url) class MemoryClient(Memory): def __init__(self, base_url: str): - print(f"Initializing client for {base_url}") + print(f"Memory passthrough to -> {base_url}") self.base_url = base_url async def initialize(self) -> None: diff --git a/llama_toolchain/safety/client.py b/llama_toolchain/safety/client.py index 0fbc4c7c0..e3676d240 100644 --- a/llama_toolchain/safety/client.py +++ b/llama_toolchain/safety/client.py @@ -10,20 +10,17 @@ import fire import httpx from llama_models.llama3.api.datatypes import UserMessage + from pydantic import BaseModel from termcolor import cprint -from .api import ( - BuiltinShield, - RunShieldRequest, - RunShieldResponse, - Safety, - ShieldDefinition, -) +from llama_toolchain.distribution.datatypes import RemoteProviderConfig + +from .api import * # noqa: F403 -async def get_client_impl(base_url: str): - return SafetyClient(base_url) +async def get_adapter_impl(config: RemoteProviderConfig) -> Safety: + return SafetyClient(config.url) def encodable_dict(d: BaseModel): @@ -32,7 +29,7 @@ def encodable_dict(d: BaseModel): class SafetyClient(Safety): def __init__(self, base_url: str): - print(f"Initializing client for {base_url}") + print(f"Safety passthrough to -> {base_url}") self.base_url = base_url async def initialize(self) -> None:
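With this change, `llama api build` writes a stub PackageConfig YAML under ~/.llama/builds/<api>/ and `llama api configure` later prompts for the full provider settings and rewrites the same file. Below is a minimal Python sketch of the kind of stub a conda_env build would produce, assuming a hypothetical `llama api build inference --provider meta-reference --type conda_env --name dev` invocation; it mirrors the PackageConfig fields added in this diff using plain dicts and PyYAML, and the concrete names and values are illustrative assumptions rather than captured CLI output.

from datetime import datetime

import yaml

# Hypothetical build target name: f"env-{provider}-{name}" for --type conda_env
package_name = "env-meta-reference-dev"

# Stub written by `llama api build`; `llama api configure` later replaces the
# bare provider_id entries with full provider configurations.
stub = {
    "built_at": datetime.now(),
    "package_name": package_name,
    "docker_image": None,  # set instead of conda_env when --type container
    "conda_env": package_name,
    "providers": {
        "inference": {"provider_id": "meta-reference"},
    },
}

# Roughly what ~/.llama/builds/inference/env-meta-reference-dev.yaml would contain
print(yaml.dump(stub, sort_keys=False))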