ollama remote adapter works

This commit is contained in:
Ashwin Bharambe 2024-08-28 06:51:07 -07:00
parent 2076d2b6db
commit 2a1552a5eb
14 changed files with 196 additions and 128 deletions

View file

@ -10,6 +10,7 @@ from llama_toolchain.cli.subcommand import Subcommand
from .build import ApiBuild from .build import ApiBuild
from .configure import ApiConfigure from .configure import ApiConfigure
from .start import ApiStart
class ApiParser(Subcommand): class ApiParser(Subcommand):
@ -26,3 +27,4 @@ class ApiParser(Subcommand):
# Add sub-commands # Add sub-commands
ApiBuild.create(subparsers) ApiBuild.create(subparsers)
ApiConfigure.create(subparsers) ApiConfigure.create(subparsers)
ApiStart.create(subparsers)

View file

@ -7,9 +7,6 @@
import argparse import argparse
import json import json
import os import os
import random
import string
import uuid
from pydantic import BaseModel from pydantic import BaseModel
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
@ -25,10 +22,6 @@ from llama_toolchain.common.config_dirs import BUILDS_BASE_DIR
from llama_toolchain.distribution.datatypes import * # noqa: F403 from llama_toolchain.distribution.datatypes import * # noqa: F403
def random_string():
return "".join(random.choices(string.ascii_letters + string.digits, k=8))
class BuildType(Enum): class BuildType(Enum):
container = "container" container = "container"
conda_env = "conda_env" conda_env = "conda_env"
@ -42,6 +35,8 @@ class Dependencies(BaseModel):
def get_dependencies( def get_dependencies(
provider: ProviderSpec, dependencies: Dict[str, ProviderSpec] provider: ProviderSpec, dependencies: Dict[str, ProviderSpec]
) -> Dependencies: ) -> Dependencies:
from llama_toolchain.distribution.distribution import SERVER_DEPENDENCIES
def _deps(provider: ProviderSpec) -> Tuple[List[str], Optional[str]]: def _deps(provider: ProviderSpec) -> Tuple[List[str], Optional[str]]:
if isinstance(provider, InlineProviderSpec): if isinstance(provider, InlineProviderSpec):
return provider.pip_packages, provider.docker_image return provider.pip_packages, provider.docker_image
@ -60,7 +55,9 @@ def get_dependencies(
pip_packages.extend(dep_pip_packages) pip_packages.extend(dep_pip_packages)
return Dependencies(docker_image=docker_image, pip_packages=pip_packages) return Dependencies(
docker_image=docker_image, pip_packages=pip_packages + SERVER_DEPENDENCIES
)
def parse_dependencies( def parse_dependencies(
@ -88,7 +85,6 @@ def parse_dependencies(
class ApiBuild(Subcommand): class ApiBuild(Subcommand):
def __init__(self, subparsers: argparse._SubParsersAction): def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__() super().__init__()
self.parser = subparsers.add_parser( self.parser = subparsers.add_parser(
@ -125,8 +121,8 @@ class ApiBuild(Subcommand):
self.parser.add_argument( self.parser.add_argument(
"--name", "--name",
type=str, type=str,
help="Name of the build target (image, conda env). Defaults to a random UUID", help="Name of the build target (image, conda env)",
required=False, required=True,
) )
self.parser.add_argument( self.parser.add_argument(
"--type", "--type",
@ -153,11 +149,10 @@ class ApiBuild(Subcommand):
) )
return return
name = args.name or random_string()
if args.type == BuildType.container.value: if args.type == BuildType.container.value:
package_name = f"image-{args.provider}-{name}" package_name = f"image-{args.provider}-{args.name}"
else: else:
package_name = f"env-{args.provider}-{name}" package_name = f"env-{args.provider}-{args.name}"
package_name = package_name.replace("::", "-") package_name = package_name.replace("::", "-")
build_dir = BUILDS_BASE_DIR / args.api build_dir = BUILDS_BASE_DIR / args.api
@ -176,7 +171,7 @@ class ApiBuild(Subcommand):
} }
with open(package_file, "w") as f: with open(package_file, "w") as f:
c = PackageConfig( c = PackageConfig(
built_at=str(datetime.now()), built_at=datetime.now(),
package_name=package_name, package_name=package_name,
docker_image=( docker_image=(
package_name if args.type == BuildType.container.value else None package_name if args.type == BuildType.container.value else None

View file

@ -0,0 +1,83 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
from pathlib import Path
import pkg_resources
import yaml
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.distribution.datatypes import * # noqa: F403
class ApiStart(Subcommand):
    """`llama api start`: launch the server for an already-built API provider."""

    def __init__(self, subparsers: argparse._SubParsersAction):
        super().__init__()
        self.parser = subparsers.add_parser(
            "start",
            prog="llama api start",
            description="""start the server for a Llama API provider. You should have already built and configured the provider.""",
            formatter_class=argparse.RawTextHelpFormatter,
        )
        self._add_arguments()
        self.parser.set_defaults(func=self._run_api_start_cmd)

    def _add_arguments(self):
        """Register the flags accepted by `llama api start`."""
        self.parser.add_argument(
            "--yaml-config",
            type=str,
            help="Yaml config containing the API build configuration",
            required=True,
        )
        self.parser.add_argument(
            "--port",
            type=int,
            help="Port to run the server on. Defaults to 5000",
            default=5000,
        )
        self.parser.add_argument(
            "--disable-ipv6",
            action="store_true",
            help="Disable IPv6 support",
            default=False,
        )

    def _run_api_start_cmd(self, args: argparse.Namespace) -> None:
        """Load the build config and exec the matching launcher script via a pty."""
        # Imported lazily so constructing the CLI parser stays cheap.
        from llama_toolchain.common.exec import run_with_pty

        cfg_path = Path(args.yaml_config)
        if not cfg_path.exists():
            self.parser.error(
                f"Could not find {cfg_path}. Please run `llama api build` first"
            )
            return

        with open(cfg_path, "r") as fp:
            config = PackageConfig(**yaml.safe_load(fp))

        # Container builds and conda-env builds are started by different shell
        # scripts; pick the script and its first positional argument here.
        if config.docker_image:
            rel_script, target = "distribution/start_container.sh", config.docker_image
        else:
            rel_script, target = "distribution/start_conda_env.sh", config.conda_env
        launcher = pkg_resources.resource_filename("llama_toolchain", rel_script)

        cmd = [launcher, target]
        cmd.extend(["--yaml_config", str(cfg_path), "--port", str(args.port)])
        if args.disable_ipv6:
            cmd.append("--disable-ipv6")

        run_with_pty(cmd)

View file

@ -38,6 +38,36 @@ class ProviderSpec(BaseModel):
) )
@json_schema_type
class AdapterSpec(BaseModel):
    """
    If some code is needed to convert the remote responses into Llama Stack compatible
    API responses, specify the adapter here. If not specified, it indicates the remote
    as being "Llama Stack compatible"
    """
    # Identifier used to register and look up this adapter (e.g. "ollama").
    adapter_id: str = Field(
        ...,
        description="Unique identifier for this adapter",
    )
    # Import path of the adapter package; must expose get_adapter_impl(config, deps).
    module: str = Field(
        ...,
        description="""
Fully-qualified name of the module to import. The module is expected to have:
 - `get_adapter_impl(config, deps)`: returns the adapter implementation
""",
    )
    # Extra pip requirements pulled in when a build includes this adapter.
    pip_packages: List[str] = Field(
        default_factory=list,
        description="The pip dependencies needed for this implementation",
    )
    # None means the adapter uses the generic RemoteProviderConfig
    # (presumably resolved by the caller — confirm against adapter_provider_spec).
    config_class: Optional[str] = Field(
        default=None,
        description="Fully-qualified classname of the config for this provider",
    )
@json_schema_type @json_schema_type
class InlineProviderSpec(ProviderSpec): class InlineProviderSpec(ProviderSpec):
pip_packages: List[str] = Field( pip_packages: List[str] = Field(
@ -63,30 +93,7 @@ Fully-qualified name of the module to import. The module is expected to have:
default_factory=list, default_factory=list,
description="Higher-level API surfaces may depend on other providers to provide their functionality", description="Higher-level API surfaces may depend on other providers to provide their functionality",
) )
is_adapter: bool = False
@json_schema_type
class AdapterSpec(BaseModel):
adapter_id: str = Field(
...,
description="Unique identifier for this adapter",
)
module: str = Field(
...,
description="""
Fully-qualified name of the module to import. The module is expected to have:
- `get_adapter_impl(config, deps)`: returns the adapter implementation
""",
)
pip_packages: List[str] = Field(
default_factory=list,
description="The pip dependencies needed for this implementation",
)
config_class: Optional[str] = Field(
default=None,
description="Fully-qualified classname of the config for this provider",
)
class RemoteProviderConfig(BaseModel): class RemoteProviderConfig(BaseModel):
@ -106,40 +113,34 @@ def remote_provider_id(adapter_id: str) -> str:
@json_schema_type @json_schema_type
class RemoteProviderSpec(ProviderSpec): class RemoteProviderSpec(ProviderSpec):
adapter: Optional[AdapterSpec] = Field( provider_id: str = "remote"
default=None,
description="""
If some code is needed to convert the remote responses into Llama Stack compatible
API responses, specify the adapter here. If not specified, it indicates the remote
as being "Llama Stack compatible"
""",
)
config_class: str = "llama_toolchain.distribution.datatypes.RemoteProviderConfig" config_class: str = "llama_toolchain.distribution.datatypes.RemoteProviderConfig"
@property
def module(self) -> str:
return f"llama_toolchain.{self.api.value}.client"
# need this wrapper since we don't have Pydantic v2 and that means we don't have
def remote_provider_spec(api: Api) -> RemoteProviderSpec:
    """Return a RemoteProviderSpec for *api* with its default settings."""
    spec = RemoteProviderSpec(api=api)
    return spec
# TODO: use computed_field to avoid this wrapper
# the @computed_field decorator # the @computed_field decorator
def remote_provider_spec( def adapter_provider_spec(api: Api, adapter: AdapterSpec) -> InlineProviderSpec:
api: Api, adapter: Optional[AdapterSpec] = None
) -> RemoteProviderSpec:
provider_id = (
remote_provider_id(adapter.adapter_id) if adapter is not None else "remote"
)
module = (
adapter.module if adapter is not None else f"llama_toolchain.{api.value}.client"
)
config_class = ( config_class = (
adapter.config_class adapter.config_class
if adapter and adapter.config_class if adapter.config_class
else "llama_toolchain.distribution.datatypes.RemoteProviderConfig" else "llama_toolchain.distribution.datatypes.RemoteProviderConfig"
) )
return RemoteProviderSpec( return InlineProviderSpec(
api=api, api=api,
provider_id=provider_id, provider_id=remote_provider_id(adapter.adapter_id),
pip_packages=adapter.pip_packages if adapter is not None else [], pip_packages=adapter.pip_packages,
module=module, module=adapter.module,
config_class=config_class, config_class=config_class,
is_adapter=True,
) )

View file

@ -8,7 +8,7 @@ import asyncio
import importlib import importlib
from typing import Any, Dict from typing import Any, Dict
from .datatypes import InlineProviderSpec, ProviderSpec, RemoteProviderSpec from .datatypes import InlineProviderSpec, ProviderSpec, RemoteProviderConfig
def instantiate_class_type(fully_qualified_name): def instantiate_class_type(fully_qualified_name):
@ -19,38 +19,23 @@ def instantiate_class_type(fully_qualified_name):
# returns a class implementing the protocol corresponding to the Api # returns a class implementing the protocol corresponding to the Api
def instantiate_provider( def instantiate_provider(
provider_spec: InlineProviderSpec, provider_spec: ProviderSpec,
provider_config: Dict[str, Any], provider_config: Dict[str, Any],
deps: Dict[str, ProviderSpec], deps: Dict[str, ProviderSpec],
): ):
module = importlib.import_module(provider_spec.module) module = importlib.import_module(provider_spec.module)
config_type = instantiate_class_type(provider_spec.config_class) config_type = instantiate_class_type(provider_spec.config_class)
if isinstance(provider_spec, InlineProviderSpec):
if provider_spec.is_adapter:
if not issubclass(config_type, RemoteProviderConfig):
raise ValueError(
f"Config class {provider_spec.config_class} does not inherit from RemoteProviderConfig"
)
config = config_type(**provider_config) config = config_type(**provider_config)
return asyncio.run(module.get_provider_impl(config, deps))
if isinstance(provider_spec, InlineProviderSpec):
def instantiate_client( args = [config, deps]
provider_spec: RemoteProviderSpec, provider_config: Dict[str, Any]
):
module = importlib.import_module(provider_spec.module)
adapter = provider_spec.adapter
if adapter is not None:
if "adapter" not in provider_config:
raise ValueError(
f"Adapter is specified but not present in provider config: {provider_config}"
)
adapter_config = provider_config["adapter"]
config_type = instantiate_class_type(adapter.config_class)
if not issubclass(config_type, RemoteProviderConfig):
raise ValueError(
f"Config class {adapter.config_class} does not inherit from RemoteProviderConfig"
)
config = config_type(**adapter_config)
else: else:
config = RemoteProviderConfig(**provider_config) args = [config]
return asyncio.run(module.get_provider_impl(*args))
return asyncio.run(module.get_adapter_impl(config))

View file

@ -38,11 +38,9 @@ from pydantic import BaseModel, ValidationError
from termcolor import cprint from termcolor import cprint
from typing_extensions import Annotated from typing_extensions import Annotated
from .datatypes import Api, DistributionSpec, ProviderSpec, RemoteProviderSpec from .datatypes import Api, ProviderSpec, RemoteProviderSpec
from .distribution import api_endpoints from .distribution import api_endpoints, api_providers
from .dynamic import instantiate_client, instantiate_provider from .dynamic import instantiate_provider
from .registry import resolve_distribution_spec
def is_async_iterator_type(typ): def is_async_iterator_type(typ):
@ -249,9 +247,11 @@ def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
return [by_id[x] for x in stack] return [by_id[x] for x in stack]
def resolve_impls(dist: DistributionSpec, config: Dict[str, Any]) -> Dict[Api, Any]: def resolve_impls(
provider_specs: Dict[str, ProviderSpec], config: Dict[str, Any]
) -> Dict[Api, Any]:
provider_configs = config["providers"] provider_configs = config["providers"]
provider_specs = topological_sort(dist.provider_specs.values()) provider_specs = topological_sort(provider_specs.values())
impls = {} impls = {}
for provider_spec in provider_specs: for provider_spec in provider_specs:
@ -261,16 +261,10 @@ def resolve_impls(dist: DistributionSpec, config: Dict[str, Any]) -> Dict[Api, A
f"Could not find provider_spec config for {api}. Please add it to the config" f"Could not find provider_spec config for {api}. Please add it to the config"
) )
deps = {api: impls[api] for api in provider_spec.api_dependencies}
provider_config = provider_configs[api.value] provider_config = provider_configs[api.value]
if isinstance(provider_spec, RemoteProviderSpec): impl = instantiate_provider(provider_spec, provider_config, deps)
impls[api] = instantiate_client( impls[api] = impl
provider_spec,
provider_config,
)
else:
deps = {api: impls[api] for api in provider_spec.api_dependencies}
impl = instantiate_provider(provider_spec, provider_config, deps)
impls[api] = impl
return impls return impls
@ -279,22 +273,34 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
with open(yaml_config, "r") as fp: with open(yaml_config, "r") as fp:
config = yaml.safe_load(fp) config = yaml.safe_load(fp)
spec = config["spec"]
dist = resolve_distribution_spec(spec)
if dist is None:
raise ValueError(f"Could not find distribution specification `{spec}`")
app = FastAPI() app = FastAPI()
all_endpoints = api_endpoints() all_endpoints = api_endpoints()
impls = resolve_impls(dist, config) all_providers = api_providers()
for provider_spec in dist.provider_specs.values(): provider_specs = {}
for api_str, provider_config in config["providers"].items():
api = Api(api_str)
providers = all_providers[api]
provider_id = provider_config["provider_id"]
if provider_id not in providers:
raise ValueError(
f"Unknown provider `{provider_id}` is not available for API `{api}`"
)
provider_specs[api] = providers[provider_id]
impls = resolve_impls(provider_specs, config)
for provider_spec in provider_specs.values():
api = provider_spec.api api = provider_spec.api
endpoints = all_endpoints[api] endpoints = all_endpoints[api]
impl = impls[api] impl = impls[api]
if isinstance(provider_spec, RemoteProviderSpec): if (
isinstance(provider_spec, RemoteProviderSpec)
and provider_spec.adapter is None
):
for endpoint in endpoints: for endpoint in endpoints:
url = impl.base_url + endpoint.route url = impl.base_url + endpoint.route
getattr(app, endpoint.method)(endpoint.route)( getattr(app, endpoint.method)(endpoint.route)(

View file

@ -8,7 +8,6 @@
set -euo pipefail set -euo pipefail
# Define color codes
RED='\033[0;31m' RED='\033[0;31m'
NC='\033[0m' # No Color NC='\033[0m' # No Color
@ -17,20 +16,17 @@ error_handler() {
exit 1 exit 1
} }
# Set up the error trap
trap 'error_handler ${LINENO}' ERR trap 'error_handler ${LINENO}' ERR
if [ $# -lt 2 ]; then if [ $# -lt 2 ]; then
echo "Usage: $0 <environment_name> <script_args...>" echo "Usage: $0 <environment_name> <script_args...>"
exit 1 exit 1
fi fi
env_name="$1" env_name="$1"
shift shift
eval "$(conda shell.bash hook)" eval "$(conda shell.bash hook)"
conda deactivate && conda activate "$env_name" conda deactivate && conda activate "$env_name"
python_interp=$(conda run -n "$env_name" which python) $CONDA_PREFIX/bin/python -m llama_toolchain.distribution.server "$@"
$python_interp -m llama_toolchain.distribution.server "$@"

View file

@ -4,4 +4,4 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from .ollama import get_adapter_impl # noqa from .ollama import get_provider_impl # noqa

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import AsyncGenerator from typing import Any, AsyncGenerator
import httpx import httpx
@ -36,7 +36,7 @@ OLLAMA_SUPPORTED_SKUS = {
} }
async def get_adapter_impl(config: RemoteProviderConfig) -> Inference: async def get_provider_impl(config: RemoteProviderConfig, _deps: Any) -> Inference:
impl = OllamaInferenceAdapter(config.url) impl = OllamaInferenceAdapter(config.url)
await impl.initialize() await impl.initialize()
return impl return impl

View file

@ -26,7 +26,7 @@ from .api import (
from .event_logger import EventLogger from .event_logger import EventLogger
async def get_adapter_impl(config: RemoteProviderConfig) -> Inference: async def get_provider_impl(config: RemoteProviderConfig) -> Inference:
return InferenceClient(config.url) return InferenceClient(config.url)

View file

@ -27,11 +27,11 @@ def available_inference_providers() -> List[ProviderSpec]:
module="llama_toolchain.inference.meta_reference", module="llama_toolchain.inference.meta_reference",
config_class="llama_toolchain.inference.meta_reference.MetaReferenceImplConfig", config_class="llama_toolchain.inference.meta_reference.MetaReferenceImplConfig",
), ),
remote_provider_spec( adapter_provider_spec(
api=Api.inference, api=Api.inference,
adapter=AdapterSpec( adapter=AdapterSpec(
adapter_id="ollama", adapter_id="ollama",
pip_packages=[], pip_packages=["ollama"],
module="llama_toolchain.inference.adapters.ollama", module="llama_toolchain.inference.adapters.ollama",
), ),
), ),

View file

@ -16,7 +16,7 @@ from llama_toolchain.distribution.datatypes import RemoteProviderConfig
from .api import * # noqa: F403 from .api import * # noqa: F403
async def get_adapter_impl(config: RemoteProviderConfig) -> Memory: async def get_provider_impl(config: RemoteProviderConfig) -> Memory:
return MemoryClient(config.url) return MemoryClient(config.url)

View file

@ -19,7 +19,7 @@ from llama_toolchain.distribution.datatypes import RemoteProviderConfig
from .api import * # noqa: F403 from .api import * # noqa: F403
async def get_adapter_impl(config: RemoteProviderConfig) -> Safety: async def get_provider_impl(config: RemoteProviderConfig) -> Safety:
return SafetyClient(config.url) return SafetyClient(config.url)