ollama remote adapter works

This commit is contained in:
Ashwin Bharambe 2024-08-28 06:51:07 -07:00
parent 2076d2b6db
commit 2a1552a5eb
14 changed files with 196 additions and 128 deletions

View file

@@ -10,6 +1,7 @@ from llama_toolchain.cli.subcommand import Subcommand
from .build import ApiBuild
from .configure import ApiConfigure
from .start import ApiStart
class ApiParser(Subcommand):
@@ -26,3 +27,4 @@ class ApiParser(Subcommand):
# Add sub-commands
ApiBuild.create(subparsers)
ApiConfigure.create(subparsers)
ApiStart.create(subparsers)

View file

@@ -7,9 +7,6 @@
import argparse
import json
import os
import random
import string
import uuid
from pydantic import BaseModel
from datetime import datetime
from enum import Enum
@@ -25,10 +22,6 @@ from llama_toolchain.common.config_dirs import BUILDS_BASE_DIR
from llama_toolchain.distribution.datatypes import * # noqa: F403
def random_string():
return "".join(random.choices(string.ascii_letters + string.digits, k=8))
class BuildType(Enum):
container = "container"
conda_env = "conda_env"
@@ -42,6 +35,8 @@ class Dependencies(BaseModel):
def get_dependencies(
provider: ProviderSpec, dependencies: Dict[str, ProviderSpec]
) -> Dependencies:
from llama_toolchain.distribution.distribution import SERVER_DEPENDENCIES
def _deps(provider: ProviderSpec) -> Tuple[List[str], Optional[str]]:
if isinstance(provider, InlineProviderSpec):
return provider.pip_packages, provider.docker_image
@@ -60,7 +55,9 @@ def get_dependencies(
pip_packages.extend(dep_pip_packages)
return Dependencies(docker_image=docker_image, pip_packages=pip_packages)
return Dependencies(
docker_image=docker_image, pip_packages=pip_packages + SERVER_DEPENDENCIES
)
def parse_dependencies(
@@ -88,7 +85,6 @@ def parse_dependencies(
class ApiBuild(Subcommand):
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
@@ -125,8 +121,8 @@ class ApiBuild(Subcommand):
self.parser.add_argument(
"--name",
type=str,
help="Name of the build target (image, conda env). Defaults to a random UUID",
required=False,
help="Name of the build target (image, conda env)",
required=True,
)
self.parser.add_argument(
"--type",
@@ -153,11 +149,10 @@
)
return
name = args.name or random_string()
if args.type == BuildType.container.value:
package_name = f"image-{args.provider}-{name}"
package_name = f"image-{args.provider}-{args.name}"
else:
package_name = f"env-{args.provider}-{name}"
package_name = f"env-{args.provider}-{args.name}"
package_name = package_name.replace("::", "-")
build_dir = BUILDS_BASE_DIR / args.api
@@ -176,7 +171,7 @@
}
with open(package_file, "w") as f:
c = PackageConfig(
built_at=str(datetime.now()),
built_at=datetime.now(),
package_name=package_name,
docker_image=(
package_name if args.type == BuildType.container.value else None

View file

@@ -0,0 +1,83 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
from pathlib import Path
import pkg_resources
import yaml
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.distribution.datatypes import * # noqa: F403
class ApiStart(Subcommand):
    """`llama api start` subcommand: launch the server for an API provider
    package previously produced by `llama api build` / `llama api configure`."""

    def __init__(self, subparsers: argparse._SubParsersAction):
        super().__init__()
        self.parser = subparsers.add_parser(
            "start",
            prog="llama api start",
            description="""start the server for a Llama API provider. You should have already built and configured the provider.""",
            formatter_class=argparse.RawTextHelpFormatter,
        )
        self._add_arguments()
        # Dispatch to _run_api_start_cmd when this subcommand is selected.
        self.parser.set_defaults(func=self._run_api_start_cmd)

    def _add_arguments(self):
        # Path to the PackageConfig YAML written by `llama api build`.
        self.parser.add_argument(
            "--yaml-config",
            type=str,
            help="Yaml config containing the API build configuration",
            required=True,
        )
        self.parser.add_argument(
            "--port",
            type=int,
            help="Port to run the server on. Defaults to 5000",
            default=5000,
        )
        self.parser.add_argument(
            "--disable-ipv6",
            action="store_true",
            help="Disable IPv6 support",
            default=False,
        )

    def _run_api_start_cmd(self, args: argparse.Namespace) -> None:
        """Load the build config and exec the matching start script
        (container vs. conda env) under a pseudo-terminal via run_with_pty."""
        from llama_toolchain.common.exec import run_with_pty

        config_file = Path(args.yaml_config)
        if not config_file.exists():
            self.parser.error(
                f"Could not find {config_file}. Please run `llama api build` first"
            )
            return

        with open(config_file, "r") as f:
            config = PackageConfig(**yaml.safe_load(f))

        # A docker_image in the config means the build produced a container;
        # otherwise the build targeted a conda environment.
        if config.docker_image:
            script = pkg_resources.resource_filename(
                "llama_toolchain",
                "distribution/start_container.sh",
            )
            run_args = [script, config.docker_image]
        else:
            script = pkg_resources.resource_filename(
                "llama_toolchain",
                "distribution/start_conda_env.sh",
            )
            run_args = [
                script,
                config.conda_env,
            ]

        # Forward the server options to the start script.
        run_args.extend(["--yaml_config", str(config_file), "--port", str(args.port)])
        if args.disable_ipv6:
            run_args.append("--disable-ipv6")

        run_with_pty(run_args)

View file

@@ -38,6 +38,36 @@ class ProviderSpec(BaseModel):
)
@json_schema_type
class AdapterSpec(BaseModel):
    """
    If some code is needed to convert the remote responses into Llama Stack compatible
    API responses, specify the adapter here. If not specified, it indicates the remote
    as being "Llama Stack compatible"
    """

    # Unique key; used to derive the provider_id for adapter-backed providers
    # (see remote_provider_id / adapter_provider_spec later in this file).
    adapter_id: str = Field(
        ...,
        description="Unique identifier for this adapter",
    )
    # NOTE(review): the description below names `get_adapter_impl`, but the
    # adapter modules touched elsewhere in this commit expose
    # `get_provider_impl(config, deps)` — confirm which entry point is canonical.
    module: str = Field(
        ...,
        description="""
Fully-qualified name of the module to import. The module is expected to have:
 - `get_adapter_impl(config, deps)`: returns the adapter implementation
""",
    )
    # Extra pip dependencies installed when this adapter is selected.
    pip_packages: List[str] = Field(
        default_factory=list,
        description="The pip dependencies needed for this implementation",
    )
    # When None, callers fall back to RemoteProviderConfig (see
    # adapter_provider_spec in this file).
    config_class: Optional[str] = Field(
        default=None,
        description="Fully-qualified classname of the config for this provider",
    )
@json_schema_type
class InlineProviderSpec(ProviderSpec):
pip_packages: List[str] = Field(
@@ -63,30 +93,7 @@ Fully-qualified name of the module to import. The module is expected to have:
default_factory=list,
description="Higher-level API surfaces may depend on other providers to provide their functionality",
)
@json_schema_type
class AdapterSpec(BaseModel):
adapter_id: str = Field(
...,
description="Unique identifier for this adapter",
)
module: str = Field(
...,
description="""
Fully-qualified name of the module to import. The module is expected to have:
- `get_adapter_impl(config, deps)`: returns the adapter implementation
""",
)
pip_packages: List[str] = Field(
default_factory=list,
description="The pip dependencies needed for this implementation",
)
config_class: Optional[str] = Field(
default=None,
description="Fully-qualified classname of the config for this provider",
)
is_adapter: bool = False
class RemoteProviderConfig(BaseModel):
@@ -106,40 +113,34 @@ def remote_provider_id(adapter_id: str) -> str:
@json_schema_type
class RemoteProviderSpec(ProviderSpec):
adapter: Optional[AdapterSpec] = Field(
default=None,
description="""
If some code is needed to convert the remote responses into Llama Stack compatible
API responses, specify the adapter here. If not specified, it indicates the remote
as being "Llama Stack compatible"
""",
)
provider_id: str = "remote"
config_class: str = "llama_toolchain.distribution.datatypes.RemoteProviderConfig"
@property
def module(self) -> str:
return f"llama_toolchain.{self.api.value}.client"
# need this wrapper since we don't have Pydantic v2 and that means we don't have
def remote_provider_spec(api: Api) -> RemoteProviderSpec:
return RemoteProviderSpec(api=api)
# TODO: use computed_field to avoid this wrapper
# the @computed_field decorator
def remote_provider_spec(
api: Api, adapter: Optional[AdapterSpec] = None
) -> RemoteProviderSpec:
provider_id = (
remote_provider_id(adapter.adapter_id) if adapter is not None else "remote"
)
module = (
adapter.module if adapter is not None else f"llama_toolchain.{api.value}.client"
)
def adapter_provider_spec(api: Api, adapter: AdapterSpec) -> InlineProviderSpec:
config_class = (
adapter.config_class
if adapter and adapter.config_class
if adapter.config_class
else "llama_toolchain.distribution.datatypes.RemoteProviderConfig"
)
return RemoteProviderSpec(
return InlineProviderSpec(
api=api,
provider_id=provider_id,
pip_packages=adapter.pip_packages if adapter is not None else [],
module=module,
provider_id=remote_provider_id(adapter.adapter_id),
pip_packages=adapter.pip_packages,
module=adapter.module,
config_class=config_class,
is_adapter=True,
)

View file

@@ -8,7 +8,7 @@ import asyncio
import importlib
from typing import Any, Dict
from .datatypes import InlineProviderSpec, ProviderSpec, RemoteProviderSpec
from .datatypes import InlineProviderSpec, ProviderSpec, RemoteProviderConfig
def instantiate_class_type(fully_qualified_name):
@@ -19,38 +19,23 @@ def instantiate_class_type(fully_qualified_name):
# returns a class implementing the protocol corresponding to the Api
def instantiate_provider(
provider_spec: InlineProviderSpec,
provider_spec: ProviderSpec,
provider_config: Dict[str, Any],
deps: Dict[str, ProviderSpec],
):
module = importlib.import_module(provider_spec.module)
config_type = instantiate_class_type(provider_spec.config_class)
if isinstance(provider_spec, InlineProviderSpec):
if provider_spec.is_adapter:
if not issubclass(config_type, RemoteProviderConfig):
raise ValueError(
f"Config class {provider_spec.config_class} does not inherit from RemoteProviderConfig"
)
config = config_type(**provider_config)
return asyncio.run(module.get_provider_impl(config, deps))
def instantiate_client(
provider_spec: RemoteProviderSpec, provider_config: Dict[str, Any]
):
module = importlib.import_module(provider_spec.module)
adapter = provider_spec.adapter
if adapter is not None:
if "adapter" not in provider_config:
raise ValueError(
f"Adapter is specified but not present in provider config: {provider_config}"
)
adapter_config = provider_config["adapter"]
config_type = instantiate_class_type(adapter.config_class)
if not issubclass(config_type, RemoteProviderConfig):
raise ValueError(
f"Config class {adapter.config_class} does not inherit from RemoteProviderConfig"
)
config = config_type(**adapter_config)
if isinstance(provider_spec, InlineProviderSpec):
args = [config, deps]
else:
config = RemoteProviderConfig(**provider_config)
return asyncio.run(module.get_adapter_impl(config))
args = [config]
return asyncio.run(module.get_provider_impl(*args))

View file

@@ -38,11 +38,9 @@ from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated
from .datatypes import Api, DistributionSpec, ProviderSpec, RemoteProviderSpec
from .distribution import api_endpoints
from .dynamic import instantiate_client, instantiate_provider
from .registry import resolve_distribution_spec
from .datatypes import Api, ProviderSpec, RemoteProviderSpec
from .distribution import api_endpoints, api_providers
from .dynamic import instantiate_provider
def is_async_iterator_type(typ):
@@ -249,9 +247,11 @@ def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
return [by_id[x] for x in stack]
def resolve_impls(dist: DistributionSpec, config: Dict[str, Any]) -> Dict[Api, Any]:
def resolve_impls(
provider_specs: Dict[str, ProviderSpec], config: Dict[str, Any]
) -> Dict[Api, Any]:
provider_configs = config["providers"]
provider_specs = topological_sort(dist.provider_specs.values())
provider_specs = topological_sort(provider_specs.values())
impls = {}
for provider_spec in provider_specs:
@@ -261,16 +261,10 @@
f"Could not find provider_spec config for {api}. Please add it to the config"
)
deps = {api: impls[api] for api in provider_spec.api_dependencies}
provider_config = provider_configs[api.value]
if isinstance(provider_spec, RemoteProviderSpec):
impls[api] = instantiate_client(
provider_spec,
provider_config,
)
else:
deps = {api: impls[api] for api in provider_spec.api_dependencies}
impl = instantiate_provider(provider_spec, provider_config, deps)
impls[api] = impl
impl = instantiate_provider(provider_spec, provider_config, deps)
impls[api] = impl
return impls
@@ -279,22 +273,34 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
with open(yaml_config, "r") as fp:
config = yaml.safe_load(fp)
spec = config["spec"]
dist = resolve_distribution_spec(spec)
if dist is None:
raise ValueError(f"Could not find distribution specification `{spec}`")
app = FastAPI()
all_endpoints = api_endpoints()
impls = resolve_impls(dist, config)
all_providers = api_providers()
for provider_spec in dist.provider_specs.values():
provider_specs = {}
for api_str, provider_config in config["providers"].items():
api = Api(api_str)
providers = all_providers[api]
provider_id = provider_config["provider_id"]
if provider_id not in providers:
raise ValueError(
f"Unknown provider `{provider_id}` is not available for API `{api}`"
)
provider_specs[api] = providers[provider_id]
impls = resolve_impls(provider_specs, config)
for provider_spec in provider_specs.values():
api = provider_spec.api
endpoints = all_endpoints[api]
impl = impls[api]
if isinstance(provider_spec, RemoteProviderSpec):
if (
isinstance(provider_spec, RemoteProviderSpec)
and provider_spec.adapter is None
):
for endpoint in endpoints:
url = impl.base_url + endpoint.route
getattr(app, endpoint.method)(endpoint.route)(

View file

@@ -8,7 +8,6 @@
set -euo pipefail
# Define color codes
RED='\033[0;31m'
NC='\033[0m' # No Color
@@ -17,20 +16,17 @@ error_handler() {
exit 1
}
# Set up the error trap
trap 'error_handler ${LINENO}' ERR
if [ $# -lt 2 ]; then
echo "Usage: $0 <environment_name> <script_args...>"
exit 1
echo "Usage: $0 <environment_name> <script_args...>"
exit 1
fi
env_name="$1"
shift
eval "$(conda shell.bash hook)"
conda deactivate && conda activate "$env_name"
python_interp=$(conda run -n "$env_name" which python)
$python_interp -m llama_toolchain.distribution.server "$@"
$CONDA_PREFIX/bin/python -m llama_toolchain.distribution.server "$@"

View file

@@ -4,4 +4,4 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .ollama import get_adapter_impl # noqa
from .ollama import get_provider_impl # noqa

View file

@@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import AsyncGenerator
from typing import Any, AsyncGenerator
import httpx
@@ -36,7 +36,7 @@ OLLAMA_SUPPORTED_SKUS = {
}
async def get_adapter_impl(config: RemoteProviderConfig) -> Inference:
async def get_provider_impl(config: RemoteProviderConfig, _deps: Any) -> Inference:
impl = OllamaInferenceAdapter(config.url)
await impl.initialize()
return impl

View file

@@ -26,7 +26,7 @@ from .api import (
from .event_logger import EventLogger
async def get_adapter_impl(config: RemoteProviderConfig) -> Inference:
async def get_provider_impl(config: RemoteProviderConfig) -> Inference:
return InferenceClient(config.url)

View file

@@ -27,11 +27,11 @@ def available_inference_providers() -> List[ProviderSpec]:
module="llama_toolchain.inference.meta_reference",
config_class="llama_toolchain.inference.meta_reference.MetaReferenceImplConfig",
),
remote_provider_spec(
adapter_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
adapter_id="ollama",
pip_packages=[],
pip_packages=["ollama"],
module="llama_toolchain.inference.adapters.ollama",
),
),

View file

@@ -16,7 +16,7 @@ from llama_toolchain.distribution.datatypes import RemoteProviderConfig
from .api import * # noqa: F403
async def get_adapter_impl(config: RemoteProviderConfig) -> Memory:
async def get_provider_impl(config: RemoteProviderConfig) -> Memory:
return MemoryClient(config.url)

View file

@@ -19,7 +19,7 @@ from llama_toolchain.distribution.datatypes import RemoteProviderConfig
from .api import * # noqa: F403
async def get_adapter_impl(config: RemoteProviderConfig) -> Safety:
async def get_provider_impl(config: RemoteProviderConfig) -> Safety:
return SafetyClient(config.url)