From 2a1552a5ebb79343a3484d42635b5f7b389e7466 Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Wed, 28 Aug 2024 06:51:07 -0700
Subject: [PATCH] ollama remote adapter works

---
 llama_toolchain/cli/api/api.py                |  2 +
 llama_toolchain/cli/api/build.py              | 25 ++---
 llama_toolchain/cli/api/start.py              | 83 ++++++++++++++++
 llama_toolchain/distribution/datatypes.py     | 95 ++++++++++---------
 llama_toolchain/distribution/dynamic.py       | 39 +++-----
 llama_toolchain/distribution/server.py        | 54 ++++++-----
 ...art_distribution.sh => start_conda_env.sh} | 10 +-
 .../{run_image.sh => start_container.sh}      |  0
 .../inference/adapters/ollama/__init__.py     |  2 +-
 .../inference/adapters/ollama/ollama.py       |  4 +-
 llama_toolchain/inference/client.py           |  2 +-
 llama_toolchain/inference/providers.py        |  4 +-
 llama_toolchain/memory/client.py              |  2 +-
 llama_toolchain/safety/client.py              |  2 +-
 14 files changed, 196 insertions(+), 128 deletions(-)
 create mode 100644 llama_toolchain/cli/api/start.py
 rename llama_toolchain/distribution/{start_distribution.sh => start_conda_env.sh} (69%)
 rename llama_toolchain/distribution/{run_image.sh => start_container.sh} (100%)
 mode change 100644 => 100755

diff --git a/llama_toolchain/cli/api/api.py b/llama_toolchain/cli/api/api.py
index e482355d2..97546aa2a 100644
--- a/llama_toolchain/cli/api/api.py
+++ b/llama_toolchain/cli/api/api.py
@@ -10,6 +10,7 @@ from llama_toolchain.cli.subcommand import Subcommand
 
 from .build import ApiBuild
 from .configure import ApiConfigure
+from .start import ApiStart
 
 
 class ApiParser(Subcommand):
@@ -26,3 +27,4 @@ class ApiParser(Subcommand):
         # Add sub-commands
         ApiBuild.create(subparsers)
         ApiConfigure.create(subparsers)
+        ApiStart.create(subparsers)
diff --git a/llama_toolchain/cli/api/build.py b/llama_toolchain/cli/api/build.py
index 4e8f7cfba..0f07f3a62 100644
--- a/llama_toolchain/cli/api/build.py
+++ b/llama_toolchain/cli/api/build.py
@@ -7,9 +7,6 @@
 import argparse
 import json
 import os
-import random
-import string
-import uuid
 from pydantic import BaseModel
 from datetime import datetime
 from enum import Enum
@@ -25,10 +22,6 @@ from llama_toolchain.common.config_dirs import BUILDS_BASE_DIR
 from llama_toolchain.distribution.datatypes import *  # noqa: F403
 
 
-def random_string():
-    return "".join(random.choices(string.ascii_letters + string.digits, k=8))
-
-
 class BuildType(Enum):
     container = "container"
     conda_env = "conda_env"
@@ -42,6 +35,8 @@ class Dependencies(BaseModel):
 def get_dependencies(
     provider: ProviderSpec, dependencies: Dict[str, ProviderSpec]
 ) -> Dependencies:
+    from llama_toolchain.distribution.distribution import SERVER_DEPENDENCIES
+
     def _deps(provider: ProviderSpec) -> Tuple[List[str], Optional[str]]:
         if isinstance(provider, InlineProviderSpec):
             return provider.pip_packages, provider.docker_image
@@ -60,7 +55,9 @@ def get_dependencies(
         pip_packages.extend(dep_pip_packages)
-    return Dependencies(docker_image=docker_image, pip_packages=pip_packages)
+    return Dependencies(
+        docker_image=docker_image, pip_packages=pip_packages + SERVER_DEPENDENCIES
+    )
 
 
 def parse_dependencies(
     deps: str, parser: argparse.ArgumentParser
 ) -> Dict[str, ProviderSpec]:
@@ -88,7 +85,6 @@
 
 
 class ApiBuild(Subcommand):
-
     def __init__(self, subparsers: argparse._SubParsersAction):
         super().__init__()
         self.parser = subparsers.add_parser(
@@ -125,8 +121,8 @@
         self.parser.add_argument(
             "--name",
             type=str,
-            help="Name of the build target (image, conda env). Defaults to a random UUID",
-            required=False,
+            help="Name of the build target (image, conda env)",
+            required=True,
         )
         self.parser.add_argument(
             "--type",
@@ -153,11 +149,10 @@
             )
             return
 
-        name = args.name or random_string()
         if args.type == BuildType.container.value:
-            package_name = f"image-{args.provider}-{name}"
+            package_name = f"image-{args.provider}-{args.name}"
         else:
-            package_name = f"env-{args.provider}-{name}"
+            package_name = f"env-{args.provider}-{args.name}"
 
         package_name = package_name.replace("::", "-")
         build_dir = BUILDS_BASE_DIR / args.api
@@ -176,7 +171,7 @@
         }
         with open(package_file, "w") as f:
             c = PackageConfig(
-                built_at=str(datetime.now()),
+                built_at=datetime.now(),
                 package_name=package_name,
                 docker_image=(
                     package_name if args.type == BuildType.container.value else None
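
A note on the build flow above: `--name` is now required and `built_at` is stored as a real datetime. For illustration only, a sketch of constructing the same PackageConfig record by hand — the `conda_env` and `providers` fields are assumptions inferred from how `llama api start` and the server consume this file, and `remote::ollama` assumes `remote_provider_id()` prefixes adapter ids with `remote::`:

    from datetime import datetime

    import yaml

    from llama_toolchain.distribution.datatypes import *  # noqa: F403

    # All values hypothetical; `llama api build` computes these itself.
    c = PackageConfig(
        built_at=datetime.now(),
        package_name="env-inference-remote-ollama",
        docker_image=None,  # set only for --type container builds
        conda_env="env-inference-remote-ollama",
        providers={"inference": {"provider_id": "remote::ollama"}},
    )
    print(yaml.dump(c.dict(), sort_keys=False))
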
diff --git a/llama_toolchain/cli/api/start.py b/llama_toolchain/cli/api/start.py
new file mode 100644
index 000000000..e10bafe5e
--- /dev/null
+++ b/llama_toolchain/cli/api/start.py
@@ -0,0 +1,83 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import argparse
+
+from pathlib import Path
+
+import pkg_resources
+import yaml
+
+from llama_toolchain.cli.subcommand import Subcommand
+from llama_toolchain.distribution.datatypes import *  # noqa: F403
+
+
+class ApiStart(Subcommand):
+    def __init__(self, subparsers: argparse._SubParsersAction):
+        super().__init__()
+        self.parser = subparsers.add_parser(
+            "start",
+            prog="llama api start",
+            description="""start the server for a Llama API provider. You should have already built and configured the provider.""",
+            formatter_class=argparse.RawTextHelpFormatter,
+        )
+        self._add_arguments()
+        self.parser.set_defaults(func=self._run_api_start_cmd)
+
+    def _add_arguments(self):
+        self.parser.add_argument(
+            "--yaml-config",
+            type=str,
+            help="Yaml config containing the API build configuration",
+            required=True,
+        )
+        self.parser.add_argument(
+            "--port",
+            type=int,
+            help="Port to run the server on. Defaults to 5000",
+            default=5000,
+        )
+        self.parser.add_argument(
+            "--disable-ipv6",
+            action="store_true",
+            help="Disable IPv6 support",
+            default=False,
+        )
+
+    def _run_api_start_cmd(self, args: argparse.Namespace) -> None:
+        from llama_toolchain.common.exec import run_with_pty
+
+        config_file = Path(args.yaml_config)
+        if not config_file.exists():
+            self.parser.error(
+                f"Could not find {config_file}. Please run `llama api build` first"
+            )
+            return
+
+        with open(config_file, "r") as f:
+            config = PackageConfig(**yaml.safe_load(f))
+
+        if config.docker_image:
+            script = pkg_resources.resource_filename(
+                "llama_toolchain",
+                "distribution/start_container.sh",
+            )
+            run_args = [script, config.docker_image]
+        else:
+            script = pkg_resources.resource_filename(
+                "llama_toolchain",
+                "distribution/start_conda_env.sh",
+            )
+            run_args = [
+                script,
+                config.conda_env,
+            ]
+
+        run_args.extend(["--yaml_config", str(config_file), "--port", str(args.port)])
+        if args.disable_ipv6:
+            run_args.append("--disable-ipv6")
+
+        run_with_pty(run_args)
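
For illustration, the launch decision in `_run_api_start_cmd` boils down to the following sketch (the config path here is hypothetical):

    import yaml

    from llama_toolchain.distribution.datatypes import PackageConfig

    with open("builds/inference/env-inference-remote-ollama.yaml") as f:
        config = PackageConfig(**yaml.safe_load(f))

    # Container builds launch start_container.sh; everything else launches
    # start_conda_env.sh, mirroring the branch in _run_api_start_cmd above.
    script = (
        "distribution/start_container.sh"
        if config.docker_image
        else "distribution/start_conda_env.sh"
    )
    print(script, config.docker_image or config.conda_env)
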
diff --git a/llama_toolchain/distribution/datatypes.py b/llama_toolchain/distribution/datatypes.py
index cb4f06fe1..51f582432 100644
--- a/llama_toolchain/distribution/datatypes.py
+++ b/llama_toolchain/distribution/datatypes.py
@@ -38,6 +38,36 @@ class ProviderSpec(BaseModel):
     )
 
 
+@json_schema_type
+class AdapterSpec(BaseModel):
+    """
+    If some code is needed to convert the remote responses into Llama Stack compatible
+    API responses, specify the adapter here. If not specified, it indicates the remote
+    as being "Llama Stack compatible"
+    """
+
+    adapter_id: str = Field(
+        ...,
+        description="Unique identifier for this adapter",
+    )
+    module: str = Field(
+        ...,
+        description="""
+Fully-qualified name of the module to import. The module is expected to have:
+
+ - `get_adapter_impl(config, deps)`: returns the adapter implementation
+""",
+    )
+    pip_packages: List[str] = Field(
+        default_factory=list,
+        description="The pip dependencies needed for this implementation",
+    )
+    config_class: Optional[str] = Field(
+        default=None,
+        description="Fully-qualified classname of the config for this provider",
+    )
+
+
 @json_schema_type
 class InlineProviderSpec(ProviderSpec):
     pip_packages: List[str] = Field(
@@ -63,30 +93,7 @@ Fully-qualified name of the module to import. The module is expected to have:
         default_factory=list,
         description="Higher-level API surfaces may depend on other providers to provide their functionality",
     )
-
-
-@json_schema_type
-class AdapterSpec(BaseModel):
-    adapter_id: str = Field(
-        ...,
-        description="Unique identifier for this adapter",
-    )
-    module: str = Field(
-        ...,
-        description="""
-Fully-qualified name of the module to import. The module is expected to have:
-
- - `get_adapter_impl(config, deps)`: returns the adapter implementation
-""",
-    )
-    pip_packages: List[str] = Field(
-        default_factory=list,
-        description="The pip dependencies needed for this implementation",
-    )
-    config_class: Optional[str] = Field(
-        default=None,
-        description="Fully-qualified classname of the config for this provider",
-    )
+    is_adapter: bool = False
 
 
 class RemoteProviderConfig(BaseModel):
@@ -106,40 +113,34 @@ def remote_provider_id(adapter_id: str) -> str:
 
 
 @json_schema_type
 class RemoteProviderSpec(ProviderSpec):
-    adapter: Optional[AdapterSpec] = Field(
-        default=None,
-        description="""
-If some code is needed to convert the remote responses into Llama Stack compatible
-API responses, specify the adapter here. If not specified, it indicates the remote
-as being "Llama Stack compatible"
-""",
-    )
+    provider_id: str = "remote"
     config_class: str = "llama_toolchain.distribution.datatypes.RemoteProviderConfig"
 
+    @property
+    def module(self) -> str:
+        return f"llama_toolchain.{self.api.value}.client"
 
-# need this wrapper since we don't have Pydantic v2 and that means we don't have
+
+def remote_provider_spec(api: Api) -> RemoteProviderSpec:
+    return RemoteProviderSpec(api=api)
+
+
+# TODO: use computed_field to avoid this wrapper
 # the @computed_field decorator
-def remote_provider_spec(
-    api: Api, adapter: Optional[AdapterSpec] = None
-) -> RemoteProviderSpec:
-    provider_id = (
-        remote_provider_id(adapter.adapter_id) if adapter is not None else "remote"
-    )
-    module = (
-        adapter.module if adapter is not None else f"llama_toolchain.{api.value}.client"
-    )
+def adapter_provider_spec(api: Api, adapter: AdapterSpec) -> InlineProviderSpec:
     config_class = (
         adapter.config_class
-        if adapter and adapter.config_class
+        if adapter.config_class
         else "llama_toolchain.distribution.datatypes.RemoteProviderConfig"
    )
 
-    return RemoteProviderSpec(
+    return InlineProviderSpec(
         api=api,
-        provider_id=provider_id,
-        pip_packages=adapter.pip_packages if adapter is not None else [],
-        module=module,
+        provider_id=remote_provider_id(adapter.adapter_id),
+        pip_packages=adapter.pip_packages,
+        module=adapter.module,
         config_class=config_class,
+        is_adapter=True,
     )
 
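
With `AdapterSpec` now documented up front and adapters modeled as inline providers (`is_adapter=True`) rather than as an optional field on `RemoteProviderSpec`, registering an adapter looks like this sketch (identifiers hypothetical; the real ollama registration in providers.py below has the same shape):

    from llama_toolchain.distribution.datatypes import *  # noqa: F403

    spec = adapter_provider_spec(
        api=Api.inference,
        adapter=AdapterSpec(
            adapter_id="my_adapter",  # hypothetical id
            pip_packages=["my-client-sdk"],  # hypothetical dependency
            module="my_pkg.adapters.my_adapter",
        ),
    )
    assert spec.is_adapter
    # spec.provider_id comes out as remote_provider_id("my_adapter")
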
diff --git a/llama_toolchain/distribution/dynamic.py b/llama_toolchain/distribution/dynamic.py
index f4057bbae..3135e2aff 100644
--- a/llama_toolchain/distribution/dynamic.py
+++ b/llama_toolchain/distribution/dynamic.py
@@ -8,7 +8,7 @@ import asyncio
 import importlib
 from typing import Any, Dict
 
-from .datatypes import InlineProviderSpec, ProviderSpec, RemoteProviderSpec
+from .datatypes import InlineProviderSpec, ProviderSpec, RemoteProviderConfig
 
 
 def instantiate_class_type(fully_qualified_name):
@@ -19,38 +19,23 @@ def instantiate_class_type(fully_qualified_name):
 
 # returns a class implementing the protocol corresponding to the Api
 def instantiate_provider(
-    provider_spec: InlineProviderSpec,
+    provider_spec: ProviderSpec,
     provider_config: Dict[str, Any],
     deps: Dict[str, ProviderSpec],
 ):
     module = importlib.import_module(provider_spec.module)
 
     config_type = instantiate_class_type(provider_spec.config_class)
+    if isinstance(provider_spec, InlineProviderSpec):
+        if provider_spec.is_adapter:
+            if not issubclass(config_type, RemoteProviderConfig):
+                raise ValueError(
+                    f"Config class {provider_spec.config_class} does not inherit from RemoteProviderConfig"
+                )
     config = config_type(**provider_config)
-    return asyncio.run(module.get_provider_impl(config, deps))
-
-
-def instantiate_client(
-    provider_spec: RemoteProviderSpec, provider_config: Dict[str, Any]
-):
-    module = importlib.import_module(provider_spec.module)
-
-    adapter = provider_spec.adapter
-    if adapter is not None:
-        if "adapter" not in provider_config:
-            raise ValueError(
-                f"Adapter is specified but not present in provider config: {provider_config}"
-            )
-        adapter_config = provider_config["adapter"]
-
-        config_type = instantiate_class_type(adapter.config_class)
-        if not issubclass(config_type, RemoteProviderConfig):
-            raise ValueError(
-                f"Config class {adapter.config_class} does not inherit from RemoteProviderConfig"
-            )
-
-        config = config_type(**adapter_config)
+
+    if isinstance(provider_spec, InlineProviderSpec):
+        args = [config, deps]
     else:
-        config = RemoteProviderConfig(**provider_config)
-
-    return asyncio.run(module.get_adapter_impl(config))
+        args = [config]
+    return asyncio.run(module.get_provider_impl(*args))
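
Every provider module now exposes the single `get_provider_impl` entry point: inline providers (adapters included) are awaited with `(config, deps)`, while plain remote clients receive `(config)` alone. A sketch of that contract, with placeholder names:

    from typing import Any, Dict

    from llama_toolchain.distribution.datatypes import RemoteProviderConfig


    class MyAdapterImpl:  # hypothetical implementation
        def __init__(self, url: str) -> None:
            self.url = url

        async def initialize(self) -> None:
            pass


    async def get_provider_impl(
        config: RemoteProviderConfig, _deps: Dict[str, Any]
    ) -> MyAdapterImpl:
        # Adapters must use a RemoteProviderConfig subclass, which is what
        # the issubclass check in instantiate_provider enforces.
        impl = MyAdapterImpl(config.url)
        await impl.initialize()
        return impl
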
diff --git a/llama_toolchain/distribution/server.py b/llama_toolchain/distribution/server.py
index 271f6e87b..efd761e0e 100644
--- a/llama_toolchain/distribution/server.py
+++ b/llama_toolchain/distribution/server.py
@@ -38,11 +38,9 @@ from pydantic import BaseModel, ValidationError
 from termcolor import cprint
 from typing_extensions import Annotated
 
-from .datatypes import Api, DistributionSpec, ProviderSpec, RemoteProviderSpec
-from .distribution import api_endpoints
-from .dynamic import instantiate_client, instantiate_provider
-
-from .registry import resolve_distribution_spec
+from .datatypes import Api, ProviderSpec, RemoteProviderSpec
+from .distribution import api_endpoints, api_providers
+from .dynamic import instantiate_provider
 
 
 def is_async_iterator_type(typ):
@@ -249,9 +247,11 @@ def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
     return [by_id[x] for x in stack]
 
 
-def resolve_impls(dist: DistributionSpec, config: Dict[str, Any]) -> Dict[Api, Any]:
+def resolve_impls(
+    provider_specs: Dict[str, ProviderSpec], config: Dict[str, Any]
+) -> Dict[Api, Any]:
     provider_configs = config["providers"]
-    provider_specs = topological_sort(dist.provider_specs.values())
+    provider_specs = topological_sort(provider_specs.values())
 
     impls = {}
     for provider_spec in provider_specs:
@@ -261,16 +261,10 @@
                 f"Could not find provider_spec config for {api}. Please add it to the config"
             )
 
+        deps = {api: impls[api] for api in provider_spec.api_dependencies}
         provider_config = provider_configs[api.value]
-        if isinstance(provider_spec, RemoteProviderSpec):
-            impls[api] = instantiate_client(
-                provider_spec,
-                provider_config,
-            )
-        else:
-            deps = {api: impls[api] for api in provider_spec.api_dependencies}
-            impl = instantiate_provider(provider_spec, provider_config, deps)
-            impls[api] = impl
+        impl = instantiate_provider(provider_spec, provider_config, deps)
+        impls[api] = impl
 
     return impls
 
@@ -279,22 +273,34 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
     with open(yaml_config, "r") as fp:
         config = yaml.safe_load(fp)
 
-    spec = config["spec"]
-    dist = resolve_distribution_spec(spec)
-    if dist is None:
-        raise ValueError(f"Could not find distribution specification `{spec}`")
-
     app = FastAPI()
 
     all_endpoints = api_endpoints()
-    impls = resolve_impls(dist, config)
+    all_providers = api_providers()
 
-    for provider_spec in dist.provider_specs.values():
+    provider_specs = {}
+    for api_str, provider_config in config["providers"].items():
+        api = Api(api_str)
+        providers = all_providers[api]
+        provider_id = provider_config["provider_id"]
+        if provider_id not in providers:
+            raise ValueError(
+                f"Unknown provider `{provider_id}` is not available for API `{api}`"
+            )
+
+        provider_specs[api] = providers[provider_id]
+
+    impls = resolve_impls(provider_specs, config)
+
+    for provider_spec in provider_specs.values():
         api = provider_spec.api
         endpoints = all_endpoints[api]
         impl = impls[api]
 
-        if isinstance(provider_spec, RemoteProviderSpec):
+        if (
+            isinstance(provider_spec, RemoteProviderSpec)
+            and provider_spec.adapter is None
+        ):
             for endpoint in endpoints:
                 url = impl.base_url + endpoint.route
                 getattr(app, endpoint.method)(endpoint.route)(
diff --git a/llama_toolchain/distribution/start_distribution.sh b/llama_toolchain/distribution/start_conda_env.sh
similarity index 69%
rename from llama_toolchain/distribution/start_distribution.sh
rename to llama_toolchain/distribution/start_conda_env.sh
index 271919676..2d2bd57f7 100755
--- a/llama_toolchain/distribution/start_distribution.sh
+++ b/llama_toolchain/distribution/start_conda_env.sh
@@ -8,7 +8,6 @@
 
 set -euo pipefail
 
-# Define color codes
 RED='\033[0;31m'
 NC='\033[0m' # No Color
 
@@ -17,20 +16,17 @@ error_handler() {
   exit 1
 }
 
-# Set up the error trap
 trap 'error_handler ${LINENO}' ERR
 
 if [ $# -lt 2 ]; then
-  echo "Usage: $0 "
-  exit 1
+    echo "Usage: $0 "
+    exit 1
 fi
-
 
 env_name="$1"
 shift
 
 eval "$(conda shell.bash hook)"
 conda deactivate && conda activate "$env_name"
 
-python_interp=$(conda run -n "$env_name" which python)
-$python_interp -m llama_toolchain.distribution.server "$@"
+$CONDA_PREFIX/bin/python -m llama_toolchain.distribution.server "$@"
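
Because `main()` no longer looks up a registered `DistributionSpec`, the YAML `providers` map is the sole source of truth. A hypothetical config of the shape it now expects — the adapter's `RemoteProviderConfig` fields (here `url`, pointing at Ollama's default port) sit directly in the provider block, and `remote::ollama` again assumes the `remote::` prefix:

    import yaml

    config = yaml.safe_load("""
    providers:
      inference:
        provider_id: remote::ollama
        url: http://localhost:11434
    """)
    # main() looks the provider_id up in api_providers() and hands the rest
    # of the block to the provider's config class.
    assert config["providers"]["inference"]["provider_id"] == "remote::ollama"
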
diff --git a/llama_toolchain/distribution/run_image.sh b/llama_toolchain/distribution/start_container.sh
old mode 100644
new mode 100755
similarity index 100%
rename from llama_toolchain/distribution/run_image.sh
rename to llama_toolchain/distribution/start_container.sh
diff --git a/llama_toolchain/inference/adapters/ollama/__init__.py b/llama_toolchain/inference/adapters/ollama/__init__.py
index dea3d660f..14bf677cc 100644
--- a/llama_toolchain/inference/adapters/ollama/__init__.py
+++ b/llama_toolchain/inference/adapters/ollama/__init__.py
@@ -4,4 +4,4 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from .ollama import get_adapter_impl  # noqa
+from .ollama import get_provider_impl  # noqa
diff --git a/llama_toolchain/inference/adapters/ollama/ollama.py b/llama_toolchain/inference/adapters/ollama/ollama.py
index 05423fb96..30decd2cd 100644
--- a/llama_toolchain/inference/adapters/ollama/ollama.py
+++ b/llama_toolchain/inference/adapters/ollama/ollama.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import AsyncGenerator
+from typing import Any, AsyncGenerator
 
 import httpx
 
@@ -36,7 +36,7 @@ OLLAMA_SUPPORTED_SKUS = {
 }
 
 
-async def get_adapter_impl(config: RemoteProviderConfig) -> Inference:
+async def get_provider_impl(config: RemoteProviderConfig, _deps: Any) -> Inference:
     impl = OllamaInferenceAdapter(config.url)
     await impl.initialize()
     return impl
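
With the `get_adapter_impl` -> `get_provider_impl` rename, the ollama adapter is constructed through the same entry point as any inline provider. A sketch of standing it up directly — this assumes a local Ollama server on its default port, and that `RemoteProviderConfig` accepts the `url` field the clients read back:

    import asyncio

    from llama_toolchain.distribution.datatypes import RemoteProviderConfig
    from llama_toolchain.inference.adapters.ollama import get_provider_impl

    config = RemoteProviderConfig(url="http://localhost:11434")
    # Adapters take (config, deps); this one ignores its deps.
    impl = asyncio.run(get_provider_impl(config, {}))
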
diff --git a/llama_toolchain/inference/client.py b/llama_toolchain/inference/client.py
index d2ea88670..17bd07406 100644
--- a/llama_toolchain/inference/client.py
+++ b/llama_toolchain/inference/client.py
@@ -26,7 +26,7 @@ from .api import (
 from .event_logger import EventLogger
 
 
-async def get_adapter_impl(config: RemoteProviderConfig) -> Inference:
+async def get_provider_impl(config: RemoteProviderConfig) -> Inference:
     return InferenceClient(config.url)
 
diff --git a/llama_toolchain/inference/providers.py b/llama_toolchain/inference/providers.py
index f50bb759e..c9882bf98 100644
--- a/llama_toolchain/inference/providers.py
+++ b/llama_toolchain/inference/providers.py
@@ -27,11 +27,11 @@ def available_inference_providers() -> List[ProviderSpec]:
             module="llama_toolchain.inference.meta_reference",
             config_class="llama_toolchain.inference.meta_reference.MetaReferenceImplConfig",
         ),
-        remote_provider_spec(
+        adapter_provider_spec(
             api=Api.inference,
             adapter=AdapterSpec(
                 adapter_id="ollama",
-                pip_packages=[],
+                pip_packages=["ollama"],
                 module="llama_toolchain.inference.adapters.ollama",
             ),
         ),
diff --git a/llama_toolchain/memory/client.py b/llama_toolchain/memory/client.py
index cc8b80977..abf6d2910 100644
--- a/llama_toolchain/memory/client.py
+++ b/llama_toolchain/memory/client.py
@@ -16,7 +16,7 @@ from llama_toolchain.distribution.datatypes import RemoteProviderConfig
 from .api import *  # noqa: F403
 
 
-async def get_adapter_impl(config: RemoteProviderConfig) -> Memory:
+async def get_provider_impl(config: RemoteProviderConfig) -> Memory:
     return MemoryClient(config.url)
 
diff --git a/llama_toolchain/safety/client.py b/llama_toolchain/safety/client.py
index e3676d240..79a84eb3a 100644
--- a/llama_toolchain/safety/client.py
+++ b/llama_toolchain/safety/client.py
@@ -19,7 +19,7 @@ from llama_toolchain.distribution.datatypes import RemoteProviderConfig
 from .api import *  # noqa: F403
 
 
-async def get_adapter_impl(config: RemoteProviderConfig) -> Safety:
+async def get_provider_impl(config: RemoteProviderConfig) -> Safety:
     return SafetyClient(config.url)
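
Net effect of the registry change: the inference API now advertises the ollama adapter alongside the meta-reference provider through one list, under the renamed `adapter_provider_spec` wrapper, and the adapter declares its `ollama` pip dependency so builds pull it in. For illustration:

    from llama_toolchain.inference.providers import available_inference_providers

    for spec in available_inference_providers():
        # Exact ids depend on the registry; the adapter's id is whatever
        # remote_provider_id("ollama") produces.
        print(spec.provider_id, getattr(spec, "pip_packages", []))
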