bunch more work to make adapters work
commit c4fe72c3a3 (parent 68f3db62e9)
20 changed files with 461 additions and 173 deletions
@@ -9,6 +9,7 @@ import argparse

from llama_toolchain.cli.subcommand import Subcommand

from .build import ApiBuild
+from .configure import ApiConfigure


class ApiParser(Subcommand):

@@ -24,3 +25,4 @@ class ApiParser(Subcommand):

        # Add sub-commands
        ApiBuild.create(subparsers)
+        ApiConfigure.create(subparsers)

llama_toolchain/cli/api/build.py

@@ -6,22 +6,92 @@

import argparse
import os
import random
import string
import uuid
from pydantic import BaseModel
from datetime import datetime
from enum import Enum
from typing import Dict, List, Optional, Tuple

import pkg_resources
import yaml

-from llama_toolchain.cli.subcommand import Subcommand
-from llama_toolchain.common.config_dirs import DISTRIBS_BASE_DIR
+from termcolor import cprint
+
+from llama_toolchain.cli.subcommand import Subcommand
+from llama_toolchain.common.config_dirs import BUILDS_BASE_DIR
+from llama_toolchain.distribution.datatypes import *  # noqa: F403


def random_string():
    return "".join(random.choices(string.ascii_letters + string.digits, k=8))


class BuildType(Enum):
    container = "container"
    conda_env = "conda_env"


class Dependencies(BaseModel):
    pip_packages: List[str]
    docker_image: Optional[str] = None


def get_dependencies(
    provider: ProviderSpec, dependencies: Dict[str, ProviderSpec]
) -> Dependencies:
    def _deps(provider: ProviderSpec) -> Tuple[List[str], Optional[str]]:
        if isinstance(provider, InlineProviderSpec):
            return provider.pip_packages, provider.docker_image
        else:
            if provider.adapter:
                return provider.adapter.pip_packages, None
            return [], None

    pip_packages, docker_image = _deps(provider)
    for dep in dependencies.values():
        dep_pip_packages, dep_docker_image = _deps(dep)
        if docker_image and dep_docker_image:
            raise ValueError(
                "You can only have the root provider specify a docker image"
            )

        pip_packages.extend(dep_pip_packages)

    return Dependencies(docker_image=docker_image, pip_packages=pip_packages)


def parse_dependencies(
    dependencies: str, parser: argparse.ArgumentParser
) -> Dict[str, ProviderSpec]:
    from llama_toolchain.distribution.distribution import api_providers

    all_providers = api_providers()

    deps = {}
    for dep in dependencies.split(","):
        dep = dep.strip()
        if not dep:
            continue
        api_str, provider = dep.split("=")
        api = Api(api_str)

        provider = provider.strip()
        if provider not in all_providers[api]:
            parser.error(f"Provider `{provider}` is not available for API `{api}`")
            return
        deps[api] = all_providers[api][provider]

    return deps


class ApiBuild(Subcommand):

    def __init__(self, subparsers: argparse._SubParsersAction):
        super().__init__()
        self.parser = subparsers.add_parser(
-            "install",
+            "build",
            prog="llama api build",
            description="Build a Llama stack API provider container",
            formatter_class=argparse.RawTextHelpFormatter,

@@ -36,7 +106,7 @@ class ApiBuild(Subcommand):

        self.parser.add_argument(
            "api",
            choices=allowed_args,
-            help="Stack API (one of: {})".format(", ".join(allowed_args))
+            help="Stack API (one of: {})".format(", ".join(allowed_args)),
        )

        self.parser.add_argument(

@@ -45,73 +115,104 @@ class ApiBuild(Subcommand):

            help="The provider to package into the container",
            required=True,
        )
        self.parser.add_argument(
            "--container-name",
            type=str,
            help="Name of the container (including tag if needed)",
            required=True,
        )
        self.parser.add_argument(
            "--dependencies",
            type=str,
            help="Comma separated list of (downstream_api=provider) dependencies needed for the API",
            required=False,
        )
        self.parser.add_argument(
            "--name",
            type=str,
            help="Name of the build target (image, conda env). Defaults to a random UUID",
            required=False,
        )
        self.parser.add_argument(
            "--type",
            type=str,
            default="container",
            choices=[v.value for v in BuildType],
        )

    def _run_api_build_command(self, args: argparse.Namespace) -> None:
        from llama_toolchain.common.exec import run_with_pty
-        from llama_toolchain.distribution.datatypes import DistributionConfig
-        from llama_toolchain.distribution.distribution import distribution_dependencies
-        from llama_toolchain.distribution.registry import resolve_distribution_spec
+        from llama_toolchain.distribution.distribution import api_providers

-        os.makedirs(DISTRIBS_BASE_DIR, exist_ok=True)
-        script = pkg_resources.resource_filename(
-            "llama_toolchain",
-            "distribution/build_api.sh",
-        )
+        os.makedirs(BUILDS_BASE_DIR, exist_ok=True)
+        all_providers = api_providers()

-        dist = resolve_distribution_spec(args.spec)
-        if dist is None:
-            self.parser.error(f"Could not find distribution {args.spec}")
+        api = Api(args.api)
+        assert api in all_providers

+        providers = all_providers[api]
+        if args.provider not in providers:
+            self.parser.error(
+                f"Provider `{args.provider}` is not available for API `{api}`"
+            )
+            return

-        distrib_dir = DISTRIBS_BASE_DIR / args.name
-        os.makedirs(distrib_dir, exist_ok=True)
-
-        deps = distribution_dependencies(dist)
-        if not args.conda_env:
-            print(f"Using {args.name} as the Conda environment for this distribution")
-
-        conda_env = args.conda_env or args.name
-
-        config_file = distrib_dir / "config.yaml"
-        if config_file.exists():
-            c = DistributionConfig(**yaml.safe_load(config_file.read_text()))
-            if c.spec != dist.spec_id:
-                self.parser.error(
-                    f"already installed distribution with `spec={c.spec}` does not match provided spec `{args.spec}`"
-                )
-                return
-            if c.conda_env != conda_env:
-                self.parser.error(
-                    f"already installed distribution has `conda_env={c.conda_env}` different from provided conda env `{conda_env}`"
-                )
-                return
+        name = args.name or random_string()
+        if args.type == BuildType.container.value:
+            package_name = f"image-{args.provider}-{name}"
+        else:
-            with open(config_file, "w") as f:
-                c = DistributionConfig(
-                    spec=dist.spec_id,
-                    name=args.name,
-                    conda_env=conda_env,
-                )
-                f.write(yaml.dump(c.dict(), sort_keys=False))
+            package_name = f"env-{args.provider}-{name}"
+        package_name = package_name.replace("::", "-")

-        return_code = run_with_pty([script, conda_env, args.name, " ".join(deps)])
+        build_dir = BUILDS_BASE_DIR / args.api
+        os.makedirs(build_dir, exist_ok=True)

        provider_deps = parse_dependencies(args.dependencies or "", self.parser)
        dependencies = get_dependencies(providers[args.provider], provider_deps)

        package_file = build_dir / f"{package_name}.yaml"

        stub_config = {
            api.value: {
                "provider_id": args.provider,
            },
            # provider_deps maps Api -> ProviderSpec; store plain strings in the stub
            **{k.value: {"provider_id": v.provider_id} for k, v in provider_deps.items()},
        }
        with open(package_file, "w") as f:
            c = PackageConfig(
                built_at=datetime.now(),
                package_name=package_name,
                docker_image=(
                    package_name if args.type == BuildType.container.value else None
                ),
                conda_env=(
                    package_name if args.type == BuildType.conda_env.value else None
                ),
                providers=stub_config,
            )
            f.write(yaml.dump(c.dict(), sort_keys=False))

        if args.type == BuildType.container.value:
            script = pkg_resources.resource_filename(
                "llama_toolchain", "distribution/build_container.sh"
            )
            args = [
                script,
                args.api,
                package_name,
                dependencies.docker_image or "python:3.10-slim",
                " ".join(dependencies.pip_packages),
            ]
        else:
            script = pkg_resources.resource_filename(
                "llama_toolchain", "distribution/build_conda_env.sh"
            )
            args = [
                script,
                args.api,
                package_name,
                " ".join(dependencies.pip_packages),
            ]

        return_code = run_with_pty(args)
        assert return_code == 0, cprint(
-            f"Failed to install distribution {dist.spec_id}", color="red"
+            f"Failed to build target {package_name}", color="red"
        )
        cprint(
-            f"Distribution `{args.name}` (with spec {dist.spec_id}) has been installed successfully!",
+            f"Target `{package_name}` built with configuration at {str(package_file)}",
            color="green",
        )

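For reference, the `--dependencies` flag introduced above takes a comma-separated list of `downstream_api=provider` pairs. A standalone sketch of that parsing (hypothetical provider names; the real code additionally resolves each pair against `api_providers()`):

def parse_dependency_string(dependencies: str) -> dict:
    deps = {}
    for dep in dependencies.split(","):
        dep = dep.strip()
        if not dep:
            continue
        api_str, provider = dep.split("=")
        deps[api_str] = provider.strip()
    return deps

print(parse_dependency_string("inference=remote::ollama, safety=meta-reference"))
# -> {'inference': 'remote::ollama', 'safety': 'meta-reference'}
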
llama_toolchain/cli/api/configure.py (new file, 100 lines)
@@ -0,0 +1,100 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import argparse
import json

from pathlib import Path

import yaml

from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.common.config_dirs import BUILDS_BASE_DIR
from llama_toolchain.distribution.datatypes import *  # noqa: F403
from termcolor import cprint


class ApiConfigure(Subcommand):
    """Llama cli for configuring llama toolchain configs"""

    def __init__(self, subparsers: argparse._SubParsersAction):
        super().__init__()
        self.parser = subparsers.add_parser(
            "configure",
            prog="llama api configure",
            description="configure a llama stack API provider",
            formatter_class=argparse.RawTextHelpFormatter,
        )
        self._add_arguments()
        self.parser.set_defaults(func=self._run_api_configure_cmd)

    def _add_arguments(self):
        from llama_toolchain.distribution.distribution import stack_apis

        allowed_args = [a.name for a in stack_apis()]
        self.parser.add_argument(
            "api",
            choices=allowed_args,
            help="Stack API (one of: {})".format(", ".join(allowed_args)),
        )
        self.parser.add_argument(
            "--name",
            type=str,
            help="Name of the provider build to fully configure",
            required=True,
        )

    def _run_api_configure_cmd(self, args: argparse.Namespace) -> None:
        config_file = BUILDS_BASE_DIR / args.api / f"{args.name}.yaml"
        if not config_file.exists():
            self.parser.error(
                f"Could not find {config_file}. Please run `llama api build` first"
            )
            return

        configure_llama_provider(config_file)


def configure_llama_provider(config_file: Path) -> None:
    from llama_toolchain.common.prompt_for_config import prompt_for_config
    from llama_toolchain.common.serialize import EnumEncoder
    from llama_toolchain.distribution.distribution import api_providers
    from llama_toolchain.distribution.dynamic import instantiate_class_type

    with open(config_file, "r") as f:
        config = PackageConfig(**yaml.safe_load(f))

    all_providers = api_providers()

    provider_configs = {}
    for api, stub_config in config.providers.items():
        providers = all_providers[Api(api)]
        provider_id = stub_config["provider_id"]
        if provider_id not in providers:
            raise ValueError(
                f"Unknown provider `{provider_id}` is not available for API `{api}`"
            )

        provider_spec = providers[provider_id]
        cprint(f"Configuring API surface: {api}", "white", attrs=["bold"])
        config_type = instantiate_class_type(provider_spec.config_class)
        print(f"Config type: {config_type}")
        provider_config = prompt_for_config(
            config_type,
        )
        print("")

        # `api` is already the string key loaded from the YAML stub
        provider_configs[api] = {
            "provider_id": provider_id,
            **provider_config.dict(),
        }

    config.providers = provider_configs
    with open(config_file, "w") as fp:
        to_write = json.loads(json.dumps(config.dict(), cls=EnumEncoder))
        fp.write(yaml.dump(to_write, sort_keys=False))

    print(f"YAML configuration has been written to {config_file}")

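The JSON-then-YAML round trip at the end of `configure_llama_provider` exists so that Enum values survive serialization. A minimal self-contained sketch of the same idea (the model and encoder here are stand-ins, not the toolchain's actual classes):

import json
from enum import Enum

import yaml
from pydantic import BaseModel


class Kind(Enum):
    remote = "remote"


class ExampleConfig(BaseModel):
    kind: Kind
    url: str


class EnumEncoderSketch(json.JSONEncoder):
    # Serialize Enum members by their value, as the real EnumEncoder presumably does
    def default(self, obj):
        if isinstance(obj, Enum):
            return obj.value
        return super().default(obj)


cfg = ExampleConfig(kind=Kind.remote, url="http://localhost:5000")
to_write = json.loads(json.dumps(cfg.dict(), cls=EnumEncoderSketch))
print(yaml.dump(to_write, sort_keys=False))
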
llama_toolchain/common/config_dirs.py

@@ -13,3 +13,5 @@ LLAMA_STACK_CONFIG_DIR = Path(os.path.expanduser("~/.llama/"))

DISTRIBS_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "distributions"

DEFAULT_CHECKPOINT_DIR = LLAMA_STACK_CONFIG_DIR / "checkpoints"
+
+BUILDS_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "builds"

llama_toolchain/common/prompt_for_config.py

@@ -71,6 +71,7 @@ def prompt_for_config(
    """
    config_data = {}

+    print(f"Configuring {config_type.__name__}:")
    for field_name, field in config_type.__fields__.items():
        field_type = field.annotation

@@ -85,6 +86,7 @@ def prompt_for_config(
            if not isinstance(field.default, PydanticUndefinedType)
            else None
        )
+        print(f"  {field_name}: {field_type} (default: {default_value})")
        is_required = field.is_required

        # Skip fields with Literal type

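The two added `print()` calls surface the type and default of every field before prompting. Roughly what that produces for a small config model (a sketch assuming the same pydantic version as the surrounding code, where `__fields__` and `field.annotation` are available):

from typing import Optional

from pydantic import BaseModel


class ExampleConfig(BaseModel):
    url: str = "http://localhost:11434"
    timeout: Optional[int] = None


print(f"Configuring {ExampleConfig.__name__}:")
for field_name, field in ExampleConfig.__fields__.items():
    # The real code also handles PydanticUndefinedType; field.default suffices here
    print(f"  {field_name}: {field.annotation} (default: {field.default})")
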
llama_toolchain/distribution/build_conda_env.sh

@@ -10,20 +10,29 @@ LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-}
LLAMA_TOOLCHAIN_DIR=${LLAMA_TOOLCHAIN_DIR:-}
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}

-echo "llama-toolchain-dir=$LLAMA_TOOLCHAIN_DIR"
set -euo pipefail

+if [ "$#" -ne 3 ]; then
+  echo "Usage: $0 <api_or_stack> <environment_name> <pip_dependencies>" >&2
+  echo "Example: $0 [api|stack] conda-env 'numpy pandas scipy'" >&2
+  exit 1
+fi
+
+api_or_stack="$1"
+env_name="$2"
+pip_dependencies="$3"
+
# Define color codes
RED='\033[0;31m'
GREEN='\033[0;32m'
NC='\033[0m' # No Color

error_handler() {
  echo "Error occurred in script at line: ${1}" >&2
  exit 1
}
+# this is set if we actually create a new conda in which case we need to clean up
+ENVNAME=""

# Set up the error trap
trap 'error_handler ${LINENO}' ERR
+SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
+source "$SCRIPT_DIR/common.sh"

ensure_conda_env_python310() {
  local env_name="$1"

@@ -52,6 +61,9 @@ ensure_conda_env_python310() {
  else
    echo "Conda environment '${env_name}' does not exist. Creating with Python ${python_version}..."
    conda create -n "${env_name}" python="${python_version}" -y
+
+    ENVNAME="${env_name}"
+    setup_cleanup_handlers
  fi

  eval "$(conda shell.bash hook)"

@@ -94,19 +106,8 @@ ensure_conda_env_python310() {
  fi
}

-if [ "$#" -ne 3 ]; then
-  echo "Usage: $0 <environment_name> <distribution_name> <pip_dependencies>" >&2
-  echo "Example: $0 my_env local-llama-8b 'numpy pandas scipy'" >&2
-  exit 1
-fi
-
-env_name="$1"
-distribution_name="$2"
-pip_dependencies="$3"
-
ensure_conda_env_python310 "$env_name" "$pip_dependencies"

-echo -e "${GREEN}Successfully setup distribution environment. Configuring...${NC}"
+echo -e "${GREEN}Successfully setup conda environment. Configuring build...${NC}"

which python3
-python3 -m llama_toolchain.cli.llama distribution configure --name "$distribution_name"
+$CONDA_PREFIX/bin/python3 -m llama_toolchain.cli.llama api configure "$api_or_stack" --name "$env_name"

llama_toolchain/distribution/build_image.sh → llama_toolchain/distribution/build_container.sh (renamed, Normal file → Executable file, 0 lines changed)

llama_toolchain/distribution/common.sh (new file, 40 lines)
@@ -0,0 +1,40 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

cleanup() {
  envname="$1"

  set +x
  echo "Cleaning up..."
  conda deactivate
  conda env remove --name "$envname" -y
}

handle_int() {
  # quote $ENVNAME so the -n test is false when it is empty
  if [ -n "$ENVNAME" ]; then
    cleanup "$ENVNAME"
  fi
  exit 1
}

handle_exit() {
  if [ $? -ne 0 ]; then
    echo -e "\033[1;31mABORTING.\033[0m"
    if [ -n "$ENVNAME" ]; then
      cleanup "$ENVNAME"
    fi
  fi
}

setup_cleanup_handlers() {
  trap handle_int INT
  trap handle_exit EXIT

  __conda_setup="$('conda' 'shell.bash' 'hook' 2>/dev/null)"
  eval "$__conda_setup"

  conda deactivate
}

llama_toolchain/distribution/datatypes.py

@@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

+from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional

@@ -66,36 +67,45 @@ Fully-qualified name of the module to import. The module is expected to have:

@json_schema_type
class AdapterSpec(BaseModel):
    adapter_id: str = Field(
        ...,
        description="Unique identifier for this adapter",
    )
    module: str = Field(
        ...,
        description="""
Fully-qualified name of the module to import. The module is expected to have:

 - `get_adapter_impl(config, deps)`: returns the adapter implementation
""",
    )
    pip_packages: List[str] = Field(
        default_factory=list,
        description="The pip dependencies needed for this implementation",
    )
-    config_class: str = Field(
-        ...,
+    config_class: Optional[str] = Field(
+        default=None,
        description="Fully-qualified classname of the config for this provider",
    )


class RemoteProviderConfig(BaseModel):
-    base_url: str = Field(..., description="The base URL for the llama stack provider")
+    url: str = Field(..., description="The URL for the provider")

-    @validator("base_url")
+    @validator("url")
    @classmethod
-    def validate_base_url(cls, base_url: str) -> str:
-        if not base_url.startswith("http"):
-            raise ValueError(f"URL must start with http: {base_url}")
-        return base_url
+    def validate_url(cls, url: str) -> str:
+        if not url.startswith("http"):
+            raise ValueError(f"URL must start with http: {url}")
+        return url


def remote_provider_id(adapter_id: str) -> str:
    return f"remote::{adapter_id}"


@json_schema_type
class RemoteProviderSpec(ProviderSpec):
    module: str = Field(
        ...,
        description="""
Fully-qualified name of the module to import. The module is expected to have:
 - `get_client_impl(base_url)`: returns a client which can be used to call the remote implementation
""",
    )
    adapter: Optional[AdapterSpec] = Field(
        default=None,
        description="""

@@ -107,6 +117,32 @@ as being "Llama Stack compatible"

    config_class: str = "llama_toolchain.distribution.datatypes.RemoteProviderConfig"


# need this wrapper since we don't have Pydantic v2 and that means we don't have
# the @computed_field decorator
def remote_provider_spec(
    api: Api, adapter: Optional[AdapterSpec] = None
) -> RemoteProviderSpec:
    provider_id = (
        remote_provider_id(adapter.adapter_id) if adapter is not None else "remote"
    )
    module = (
        adapter.module if adapter is not None else f"llama_toolchain.{api.value}.client"
    )
    config_class = (
        adapter.config_class
        if adapter and adapter.config_class
        else "llama_toolchain.distribution.datatypes.RemoteProviderConfig"
    )

    return RemoteProviderSpec(
        api=api,
        provider_id=provider_id,
        pip_packages=adapter.pip_packages if adapter is not None else [],
        module=module,
        config_class=config_class,
    )


@json_schema_type
class DistributionSpec(BaseModel):
    spec_id: str

@@ -119,13 +155,28 @@ class DistributionSpec(BaseModel):


@json_schema_type
-class DistributionConfig(BaseModel):
-    """References to a installed / configured DistributionSpec"""
+class PackageConfig(BaseModel):
+    built_at: datetime

-    name: str
-    spec: str
-    conda_env: str
+    package_name: str = Field(
+        ...,
+        description="""
+Reference to the distribution this package refers to. For unregistered (adhoc) packages,
+this could be just a hash
+""",
+    )
+    docker_image: Optional[str] = Field(
+        default=None,
+        description="Reference to the docker image if this package refers to a container",
+    )
+    conda_env: Optional[str] = Field(
+        default=None,
+        description="Reference to the conda environment if this package refers to a conda environment",
+    )
    providers: Dict[str, Any] = Field(
        default_factory=dict,
-        description="Provider configurations for each of the APIs provided by this distribution",
+        description="""
+Provider configurations for each of the APIs provided by this package. This includes configurations for
+the dependencies of these providers as well.
+""",
    )

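To illustrate the new wrapper: with an adapter, `remote_provider_spec` derives the provider id from the adapter; without one, it falls back to the generic `remote` client for that API. A usage sketch (mirroring the ollama registration further down; import path assumed from this diff):

from llama_toolchain.distribution.datatypes import (
    AdapterSpec,
    Api,
    remote_provider_spec,
)

spec = remote_provider_spec(
    api=Api.inference,
    adapter=AdapterSpec(
        adapter_id="ollama",
        pip_packages=[],
        module="llama_toolchain.inference.adapters.ollama",
    ),
)
print(spec.provider_id)  # remote::ollama

# With no adapter, the spec points at the generic passthrough client:
print(remote_provider_spec(api=Api.inference).provider_id)  # remote
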
llama_toolchain/distribution/dynamic.py

@@ -30,7 +30,27 @@ def instantiate_provider(
    return asyncio.run(module.get_provider_impl(config, deps))


-def instantiate_client(provider_spec: RemoteProviderSpec, base_url: str):
+def instantiate_client(
+    provider_spec: RemoteProviderSpec, provider_config: Dict[str, Any]
+):
    module = importlib.import_module(provider_spec.module)

-    return asyncio.run(module.get_client_impl(base_url))
+    adapter = provider_spec.adapter
+    if adapter is not None:
+        if "adapter" not in provider_config:
+            raise ValueError(
+                f"Adapter is specified but not present in provider config: {provider_config}"
+            )
+        adapter_config = provider_config["adapter"]
+
+        config_type = instantiate_class_type(adapter.config_class)
+        if not issubclass(config_type, RemoteProviderConfig):
+            raise ValueError(
+                f"Config class {adapter.config_class} does not inherit from RemoteProviderConfig"
+            )
+
+        config = config_type(**adapter_config)
+    else:
+        config = RemoteProviderConfig(**provider_config)
+
+    return asyncio.run(module.get_adapter_impl(config))

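`instantiate_client` now accepts the full provider config rather than a bare `base_url`. The two shapes it handles, per the branches above (values illustrative):

# Adapter-backed remote provider: the adapter's config class is instantiated
# from the nested "adapter" dict (it must subclass RemoteProviderConfig).
config_with_adapter = {
    "provider_id": "remote::ollama",
    "adapter": {"url": "http://localhost:11434"},
}

# Plain remote provider: the whole dict is parsed as RemoteProviderConfig.
config_plain = {
    "provider_id": "remote",
    "url": "http://localhost:5000",
}
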
llama_toolchain/distribution/registry.py

@@ -7,22 +7,10 @@
from functools import lru_cache
from typing import List, Optional

-from .datatypes import Api, DistributionSpec, RemoteProviderSpec
+from .datatypes import *  # noqa: F403
from .distribution import api_providers


-def client_module(api: Api) -> str:
-    return f"llama_toolchain.{api.value}.client"
-
-
-def remote_spec(api: Api) -> RemoteProviderSpec:
-    return RemoteProviderSpec(
-        api=api,
-        provider_id=f"{api.value}-remote",
-        module=client_module(api),
-    )
-
-
@lru_cache()
def available_distribution_specs() -> List[DistributionSpec]:
    providers = api_providers()

@@ -40,13 +28,14 @@ def available_distribution_specs() -> List[DistributionSpec]:
        DistributionSpec(
            spec_id="remote",
            description="Point to remote services for all llama stack APIs",
-            provider_specs={x: remote_spec(x) for x in providers},
+            provider_specs={x: remote_provider_spec(x) for x in providers},
        ),
        DistributionSpec(
            spec_id="local-ollama",
            description="Like local, but use ollama for running LLM inference",
            provider_specs={
-                Api.inference: providers[Api.inference]["meta-ollama"],
+                # this is ODD; make this easier -- we just need a better function to retrieve registered providers
+                Api.inference: providers[Api.inference][remote_provider_id("ollama")],
                Api.safety: providers[Api.safety]["meta-reference"],
                Api.agentic_system: providers[Api.agentic_system]["meta-reference"],
                Api.memory: providers[Api.memory]["meta-reference-faiss"],

@@ -57,9 +46,9 @@ def available_distribution_specs() -> List[DistributionSpec]:
            description="Test agentic with others as remote",
            provider_specs={
                Api.agentic_system: providers[Api.agentic_system]["meta-reference"],
-                Api.inference: remote_spec(Api.inference),
-                Api.memory: remote_spec(Api.memory),
-                Api.safety: remote_spec(Api.safety),
+                Api.inference: remote_provider_spec(Api.inference),
+                Api.memory: remote_provider_spec(Api.memory),
+                Api.safety: remote_provider_spec(Api.safety),
            },
        ),
        DistributionSpec(

@@ -264,7 +264,8 @@ def resolve_impls(dist: DistributionSpec, config: Dict[str, Any]) -> Dict[Api, Any]:
        provider_config = provider_configs[api.value]
        if isinstance(provider_spec, RemoteProviderSpec):
            impls[api] = instantiate_client(
-                provider_spec, provider_config["base_url"].rstrip("/")
+                provider_spec,
+                provider_config,
            )
        else:
            deps = {api: impls[api] for api in provider_spec.api_dependencies}

llama_toolchain/inference/adapters/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

@@ -4,5 +4,4 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

-from .config import OllamaImplConfig  # noqa
-from .ollama import get_provider_impl  # noqa
+from .ollama import get_adapter_impl  # noqa

@@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

-from typing import AsyncGenerator, Dict
+from typing import AsyncGenerator

import httpx

@@ -14,7 +14,7 @@ from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import resolve_model
from ollama import AsyncClient

-from llama_toolchain.distribution.datatypes import Api, ProviderSpec
+from llama_toolchain.distribution.datatypes import RemoteProviderConfig
from llama_toolchain.inference.api import (
    ChatCompletionRequest,
    ChatCompletionResponse,

@@ -27,7 +27,6 @@ from llama_toolchain.inference.api import (
    ToolCallParseStatus,
)
from llama_toolchain.inference.prepare_messages import prepare_messages
-from .config import OllamaImplConfig

# TODO: Eventually this will move to the llama cli model list command
# mapping of Model SKUs to ollama models

@@ -37,26 +36,21 @@ OLLAMA_SUPPORTED_SKUS = {
}


-async def get_provider_impl(
-    config: OllamaImplConfig, _deps: Dict[Api, ProviderSpec]
-) -> Inference:
-    assert isinstance(
-        config, OllamaImplConfig
-    ), f"Unexpected config type: {type(config)}"
-    impl = OllamaInference(config)
+async def get_adapter_impl(config: RemoteProviderConfig) -> Inference:
+    impl = OllamaInferenceAdapter(config.url)
    await impl.initialize()
    return impl


-class OllamaInference(Inference):
-    def __init__(self, config: OllamaImplConfig) -> None:
-        self.config = config
+class OllamaInferenceAdapter(Inference):
+    def __init__(self, url: str) -> None:
+        self.url = url
        tokenizer = Tokenizer.get_instance()
        self.formatter = ChatFormat(tokenizer)

    @property
    def client(self) -> AsyncClient:
-        return AsyncClient(host=self.config.url)
+        return AsyncClient(host=self.url)

    async def initialize(self) -> None:
        try:

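A hypothetical smoke test for the reworked adapter entry point, assuming an ollama server on the default port (the module path here is the one registered in the providers list below, not confirmed by this diff):

import asyncio

from llama_toolchain.distribution.datatypes import RemoteProviderConfig
from llama_toolchain.inference.adapters.ollama import get_adapter_impl


async def main() -> None:
    config = RemoteProviderConfig(url="http://localhost:11434")
    impl = await get_adapter_impl(config)  # initialize() pings the server
    print(type(impl).__name__)  # OllamaInferenceAdapter


asyncio.run(main())
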
llama_toolchain/inference/client.py

@@ -13,6 +13,8 @@ import httpx
from pydantic import BaseModel
from termcolor import cprint

+from llama_toolchain.distribution.datatypes import RemoteProviderConfig
+
from .api import (
    ChatCompletionRequest,
    ChatCompletionResponse,

@@ -24,8 +26,8 @@ from .api import (
from .event_logger import EventLogger


-async def get_client_impl(base_url: str):
-    return InferenceClient(base_url)
+async def get_adapter_impl(config: RemoteProviderConfig) -> Inference:
+    return InferenceClient(config.url)


def encodable_dict(d: BaseModel):

@@ -34,7 +36,7 @@ def encodable_dict(d: BaseModel):

class InferenceClient(Inference):
    def __init__(self, base_url: str):
-        print(f"Initializing client for {base_url}")
+        print(f"Inference passthrough to -> {base_url}")
        self.base_url = base_url

    async def initialize(self) -> None:

@@ -1,16 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from llama_models.schema_utils import json_schema_type
-from pydantic import BaseModel, Field
-
-
-@json_schema_type
-class OllamaImplConfig(BaseModel):
-    url: str = Field(
-        default="http://localhost:11434",
-        description="The URL for the ollama server",
-    )

@@ -6,7 +6,7 @@

from typing import List

-from llama_toolchain.distribution.datatypes import Api, InlineProviderSpec, ProviderSpec
+from llama_toolchain.distribution.datatypes import *  # noqa: F403


def available_inference_providers() -> List[ProviderSpec]:

@@ -27,13 +27,12 @@ def available_inference_providers() -> List[ProviderSpec]:
            module="llama_toolchain.inference.meta_reference",
            config_class="llama_toolchain.inference.meta_reference.MetaReferenceImplConfig",
        ),
-        InlineProviderSpec(
+        remote_provider_spec(
            api=Api.inference,
-            provider_id="meta-ollama",
-            pip_packages=[
-                "ollama",
-            ],
-            module="llama_toolchain.inference.ollama",
-            config_class="llama_toolchain.inference.ollama.OllamaImplConfig",
+            adapter=AdapterSpec(
+                adapter_id="ollama",
+                pip_packages=[],
+                module="llama_toolchain.inference.adapters.ollama",
+            ),
        ),
    ]

llama_toolchain/memory/client.py

@@ -6,24 +6,23 @@

import asyncio

# import json
from typing import Dict, List, Optional

import fire
import httpx

-# from termcolor import cprint
+from llama_toolchain.distribution.datatypes import RemoteProviderConfig

from .api import *  # noqa: F403


-async def get_client_impl(base_url: str):
-    return MemoryClient(base_url)
+async def get_adapter_impl(config: RemoteProviderConfig) -> Memory:
+    return MemoryClient(config.url)


class MemoryClient(Memory):
    def __init__(self, base_url: str):
-        print(f"Initializing client for {base_url}")
+        print(f"Memory passthrough to -> {base_url}")
        self.base_url = base_url

    async def initialize(self) -> None:

llama_toolchain/safety/client.py

@@ -10,20 +10,17 @@ import fire
import httpx

from llama_models.llama3.api.datatypes import UserMessage

from pydantic import BaseModel
from termcolor import cprint

-from .api import (
-    BuiltinShield,
-    RunShieldRequest,
-    RunShieldResponse,
-    Safety,
-    ShieldDefinition,
-)
+from llama_toolchain.distribution.datatypes import RemoteProviderConfig
+
+from .api import *  # noqa: F403


-async def get_client_impl(base_url: str):
-    return SafetyClient(base_url)
+async def get_adapter_impl(config: RemoteProviderConfig) -> Safety:
+    return SafetyClient(config.url)


def encodable_dict(d: BaseModel):

@@ -32,7 +29,7 @@ def encodable_dict(d: BaseModel):

class SafetyClient(Safety):
    def __init__(self, base_url: str):
-        print(f"Initializing client for {base_url}")
+        print(f"Safety passthrough to -> {base_url}")
        self.base_url = base_url

    async def initialize(self) -> None:
