mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00

bunch more work to make adapters work

This commit is contained in:
parent 68f3db62e9
commit c4fe72c3a3

20 changed files with 461 additions and 173 deletions
@@ -9,6 +9,7 @@ import argparse
 from llama_toolchain.cli.subcommand import Subcommand

 from .build import ApiBuild
+from .configure import ApiConfigure


 class ApiParser(Subcommand):
@@ -24,3 +25,4 @@ class ApiParser(Subcommand):

         # Add sub-commands
         ApiBuild.create(subparsers)
+        ApiConfigure.create(subparsers)
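The hunk above wires the new `ApiConfigure` subcommand into the parser alongside `ApiBuild`. As a minimal standalone sketch of the same argparse pattern (the handler functions here are hypothetical stand-ins for the Subcommand classes):

```python
import argparse

# Hypothetical stand-ins: each subcommand registers its own parser and
# binds a handler via set_defaults(func=...), just as create(subparsers) does.
def build_handler(args):
    print(f"building api={args.api}")

def configure_handler(args):
    print(f"configuring api={args.api}")

parser = argparse.ArgumentParser(prog="llama api")
subparsers = parser.add_subparsers()

build = subparsers.add_parser("build")
build.add_argument("api")
build.set_defaults(func=build_handler)

configure = subparsers.add_parser("configure")
configure.add_argument("api")
configure.set_defaults(func=configure_handler)

args = parser.parse_args(["configure", "inference"])
args.func(args)  # dispatches to configure_handler
```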
@@ -6,22 +6,92 @@

 import argparse
 import os
+import random
+import string
+import uuid
+from pydantic import BaseModel
+from datetime import datetime
+from enum import Enum
+from typing import Dict, List, Optional, Tuple

 import pkg_resources
 import yaml

-from llama_toolchain.cli.subcommand import Subcommand
-from llama_toolchain.common.config_dirs import DISTRIBS_BASE_DIR

 from termcolor import cprint

+from llama_toolchain.cli.subcommand import Subcommand
+from llama_toolchain.common.config_dirs import BUILDS_BASE_DIR
+from llama_toolchain.distribution.datatypes import *  # noqa: F403
+
+
+def random_string():
+    return "".join(random.choices(string.ascii_letters + string.digits, k=8))
+
+
+class BuildType(Enum):
+    container = "container"
+    conda_env = "conda_env"
+
+
+class Dependencies(BaseModel):
+    pip_packages: List[str]
+    docker_image: Optional[str] = None
+
+
+def get_dependencies(
+    provider: ProviderSpec, dependencies: Dict[str, ProviderSpec]
+) -> Dependencies:
+    def _deps(provider: ProviderSpec) -> Tuple[List[str], Optional[str]]:
+        if isinstance(provider, InlineProviderSpec):
+            return provider.pip_packages, provider.docker_image
+        else:
+            if provider.adapter:
+                return provider.adapter.pip_packages, None
+            return [], None
+
+    pip_packages, docker_image = _deps(provider)
+    for dep in dependencies.values():
+        dep_pip_packages, dep_docker_image = _deps(dep)
+        if docker_image and dep_docker_image:
+            raise ValueError(
+                "You can only have the root provider specify a docker image"
+            )
+
+        pip_packages.extend(dep_pip_packages)
+
+    return Dependencies(docker_image=docker_image, pip_packages=pip_packages)
+
+
+def parse_dependencies(
+    dependencies: str, parser: argparse.ArgumentParser
+) -> Dict[str, ProviderSpec]:
+    from llama_toolchain.distribution.distribution import api_providers
+
+    all_providers = api_providers()
+
+    deps = {}
+    for dep in dependencies.split(","):
+        dep = dep.strip()
+        if not dep:
+            continue
+        api_str, provider = dep.split("=")
+        api = Api(api_str)
+
+        provider = provider.strip()
+        if provider not in all_providers[api]:
+            parser.error(f"Provider `{provider}` is not available for API `{api}`")
+            return
+        deps[api] = all_providers[api][provider]
+
+    return deps
+
+
 class ApiBuild(Subcommand):

     def __init__(self, subparsers: argparse._SubParsersAction):
         super().__init__()
         self.parser = subparsers.add_parser(
-            "install",
+            "build",
             prog="llama api build",
             description="Build a Llama stack API provider container",
             formatter_class=argparse.RawTextHelpFormatter,
@@ -36,7 +106,7 @@ class ApiBuild(Subcommand):
         self.parser.add_argument(
             "api",
             choices=allowed_args,
-            help="Stack API (one of: {})".format(", ".join(allowed_args))
+            help="Stack API (one of: {})".format(", ".join(allowed_args)),
         )

         self.parser.add_argument(
@@ -45,73 +115,104 @@ class ApiBuild(Subcommand):
             help="The provider to package into the container",
             required=True,
         )
-        self.parser.add_argument(
-            "--container-name",
-            type=str,
-            help="Name of the container (including tag if needed)",
-            required=True,
-        )
         self.parser.add_argument(
             "--dependencies",
             type=str,
             help="Comma separated list of (downstream_api=provider) dependencies needed for the API",
             required=False,
         )
+        self.parser.add_argument(
+            "--name",
+            type=str,
+            help="Name of the build target (image, conda env). Defaults to a random UUID",
+            required=False,
+        )
+        self.parser.add_argument(
+            "--type",
+            type=str,
+            default="container",
+            choices=[v.value for v in BuildType],
+        )

     def _run_api_build_command(self, args: argparse.Namespace) -> None:
         from llama_toolchain.common.exec import run_with_pty
-        from llama_toolchain.distribution.datatypes import DistributionConfig
-        from llama_toolchain.distribution.distribution import distribution_dependencies
-        from llama_toolchain.distribution.registry import resolve_distribution_spec
+        from llama_toolchain.distribution.distribution import api_providers

-        os.makedirs(DISTRIBS_BASE_DIR, exist_ok=True)
-        script = pkg_resources.resource_filename(
-            "llama_toolchain",
-            "distribution/build_api.sh",
-        )
+        os.makedirs(BUILDS_BASE_DIR, exist_ok=True)
+        all_providers = api_providers()

-        dist = resolve_distribution_spec(args.spec)
-        if dist is None:
-            self.parser.error(f"Could not find distribution {args.spec}")
+        api = Api(args.api)
+        assert api in all_providers
+
+        providers = all_providers[api]
+        if args.provider not in providers:
+            self.parser.error(
+                f"Provider `{args.provider}` is not available for API `{api}`"
+            )
             return

-        distrib_dir = DISTRIBS_BASE_DIR / args.name
-        os.makedirs(distrib_dir, exist_ok=True)
-
-        deps = distribution_dependencies(dist)
-        if not args.conda_env:
-            print(f"Using {args.name} as the Conda environment for this distribution")
-
-        conda_env = args.conda_env or args.name
-
-        config_file = distrib_dir / "config.yaml"
-        if config_file.exists():
-            c = DistributionConfig(**yaml.safe_load(config_file.read_text()))
-            if c.spec != dist.spec_id:
-                self.parser.error(
-                    f"already installed distribution with `spec={c.spec}` does not match provided spec `{args.spec}`"
-                )
-                return
-            if c.conda_env != conda_env:
-                self.parser.error(
-                    f"already installed distribution has `conda_env={c.conda_env}` different from provided conda env `{conda_env}`"
-                )
-                return
+        name = args.name or random_string()
+        if args.type == BuildType.container.value:
+            package_name = f"image-{args.provider}-{name}"
         else:
-            with open(config_file, "w") as f:
-                c = DistributionConfig(
-                    spec=dist.spec_id,
-                    name=args.name,
-                    conda_env=conda_env,
-                )
-                f.write(yaml.dump(c.dict(), sort_keys=False))
+            package_name = f"env-{args.provider}-{name}"
+        package_name = package_name.replace("::", "-")

-        return_code = run_with_pty([script, conda_env, args.name, " ".join(deps)])
+        build_dir = BUILDS_BASE_DIR / args.api
+        os.makedirs(build_dir, exist_ok=True)
+
+        provider_deps = parse_dependencies(args.dependencies or "", self.parser)
+        dependencies = get_dependencies(providers[args.provider], provider_deps)
+
+        package_file = build_dir / f"{package_name}.yaml"
+
+        stub_config = {
+            api.value: {
+                "provider_id": args.provider,
+            },
+            **{k: {"provider_id": v} for k, v in provider_deps.items()},
+        }
+        with open(package_file, "w") as f:
+            c = PackageConfig(
+                built_at=datetime.now(),
+                package_name=package_name,
+                docker_image=(
+                    package_name if args.type == BuildType.container.value else None
+                ),
+                conda_env=(
+                    package_name if args.type == BuildType.conda_env.value else None
+                ),
+                providers=stub_config,
+            )
+            f.write(yaml.dump(c.dict(), sort_keys=False))
+
+        if args.type == BuildType.container.value:
+            script = pkg_resources.resource_filename(
+                "llama_toolchain", "distribution/build_container.sh"
+            )
+            args = [
+                script,
+                args.api,
+                package_name,
+                dependencies.docker_image or "python:3.10-slim",
+                " ".join(dependencies.pip_packages),
+            ]
+        else:
+            script = pkg_resources.resource_filename(
+                "llama_toolchain", "distribution/build_conda_env.sh"
+            )
+            args = [
+                script,
+                args.api,
+                package_name,
+                " ".join(dependencies.pip_packages),
+            ]
+
+        return_code = run_with_pty(args)
         assert return_code == 0, cprint(
-            f"Failed to install distribution {dist.spec_id}", color="red"
+            f"Failed to build target {package_name}", color="red"
         )
         cprint(
-            f"Distribution `{args.name}` (with spec {dist.spec_id}) has been installed successfully!",
+            f"Target `{target_name}` built with configuration at {str(package_file)}",
             color="green",
         )
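The rewritten build command takes a single API plus a `--dependencies` string such as `safety=meta-reference`, and `get_dependencies` then merges pip packages across the root provider and its dependencies, with only the root allowed to pin a docker image. A standalone sketch of that merge rule with simplified stand-in spec types (package and image names are illustrative only):

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple

@dataclass
class SpecStub:
    """Simplified stand-in for a ProviderSpec (real specs carry much more)."""
    pip_packages: List[str] = field(default_factory=list)
    docker_image: Optional[str] = None

def merge_dependencies(
    root: SpecStub, deps: Dict[str, SpecStub]
) -> Tuple[List[str], Optional[str]]:
    pip_packages = list(root.pip_packages)
    for dep in deps.values():
        # mirrors get_dependencies: only the root provider may pin a docker image
        if root.docker_image and dep.docker_image:
            raise ValueError("You can only have the root provider specify a docker image")
        pip_packages.extend(dep.pip_packages)
    return pip_packages, root.docker_image

# hypothetical package and image names, for illustration only
pkgs, image = merge_dependencies(
    SpecStub(pip_packages=["pkg-a"], docker_image="base-image:latest"),
    {"safety": SpecStub(pip_packages=["pkg-b"])},
)
print(pkgs, image)  # ['pkg-a', 'pkg-b'] base-image:latest
```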

llama_toolchain/cli/api/configure.py  (new file, +100)
@@ -0,0 +1,100 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import argparse
+import json
+
+from pathlib import Path
+
+import yaml
+
+from llama_toolchain.cli.subcommand import Subcommand
+from llama_toolchain.common.config_dirs import BUILDS_BASE_DIR
+from llama_toolchain.distribution.datatypes import *  # noqa: F403
+from termcolor import cprint
+
+
+class ApiConfigure(Subcommand):
+    """Llama cli for configuring llama toolchain configs"""
+
+    def __init__(self, subparsers: argparse._SubParsersAction):
+        super().__init__()
+        self.parser = subparsers.add_parser(
+            "configure",
+            prog="llama api configure",
+            description="configure a llama stack API provider",
+            formatter_class=argparse.RawTextHelpFormatter,
+        )
+        self._add_arguments()
+        self.parser.set_defaults(func=self._run_api_configure_cmd)
+
+    def _add_arguments(self):
+        from llama_toolchain.distribution.distribution import stack_apis
+
+        allowed_args = [a.name for a in stack_apis()]
+        self.parser.add_argument(
+            "api",
+            choices=allowed_args,
+            help="Stack API (one of: {})".format(", ".join(allowed_args)),
+        )
+        self.parser.add_argument(
+            "--name",
+            type=str,
+            help="Name of the provider build to fully configure",
+            required=True,
+        )
+
+    def _run_api_configure_cmd(self, args: argparse.Namespace) -> None:
+        config_file = BUILDS_BASE_DIR / args.api / f"{args.name}.yaml"
+        if not config_file.exists():
+            self.parser.error(
+                f"Could not find {config_file}. Please run `llama api build` first"
+            )
+            return
+
+        configure_llama_provider(config_file)
+
+
+def configure_llama_provider(config_file: Path) -> None:
+    from llama_toolchain.common.prompt_for_config import prompt_for_config
+    from llama_toolchain.common.serialize import EnumEncoder
+    from llama_toolchain.distribution.distribution import api_providers
+    from llama_toolchain.distribution.dynamic import instantiate_class_type
+
+    with open(config_file, "r") as f:
+        config = PackageConfig(**yaml.safe_load(f))
+
+    all_providers = api_providers()
+
+    provider_configs = {}
+    for api, stub_config in config.providers.items():
+        providers = all_providers[Api(api)]
+        provider_id = stub_config["provider_id"]
+        if provider_id not in providers:
+            raise ValueError(
+                f"Unknown provider `{provider_id}` is not available for API `{api}`"
+            )
+
+        provider_spec = providers[provider_id]
+        cprint(f"Configuring API surface: {api}", "white", attrs=["bold"])
+        config_type = instantiate_class_type(provider_spec.config_class)
+        print(f"Config type: {config_type}")
+        provider_config = prompt_for_config(
+            config_type,
+        )
+        print("")
+
+        provider_configs[api.value] = {
+            "provider_id": provider_id,
+            **provider_config.dict(),
+        }
+
+    config.providers = provider_configs
+    with open(config_file, "w") as fp:
+        to_write = json.loads(json.dumps(config.dict(), cls=EnumEncoder))
+        fp.write(yaml.dump(to_write, sort_keys=False))
+
+    print(f"YAML configuration has been written to {config_path}")
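`llama api build` writes a stub `providers` mapping with only `provider_id` entries; `llama api configure` then prompts for the provider's config class and merges the answers in next to `provider_id`. A sketch of the before/after shape (the `url` value is illustrative, matching the old ollama default rather than anything this file guarantees):

```python
import yaml

# Stub written by `llama api build`: one entry per API, provider_id only.
stub = {
    "inference": {"provider_id": "remote::ollama"},
}

# After `llama api configure`, prompted config fields are merged in flat,
# next to provider_id (provider_config.dict() is splatted into the entry).
configured = {
    "inference": {
        "provider_id": "remote::ollama",
        "url": "http://localhost:11434",  # illustrative value
    },
}

print(yaml.dump(configured, sort_keys=False))
```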
@@ -13,3 +13,5 @@ LLAMA_STACK_CONFIG_DIR = Path(os.path.expanduser("~/.llama/"))
 DISTRIBS_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "distributions"

 DEFAULT_CHECKPOINT_DIR = LLAMA_STACK_CONFIG_DIR / "checkpoints"
+
+BUILDS_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "builds"
@@ -71,6 +71,7 @@ def prompt_for_config(
     """
     config_data = {}

+    print(f"Configuring {config_type.__name__}:")
     for field_name, field in config_type.__fields__.items():
         field_type = field.annotation

@@ -85,6 +86,7 @@ def prompt_for_config(
             if not isinstance(field.default, PydanticUndefinedType)
             else None
         )
+        print(f"  {field_name}: {field_type} (default: {default_value})")
         is_required = field.is_required

         # Skip fields with Literal type
@@ -10,20 +10,29 @@ LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-}
 LLAMA_TOOLCHAIN_DIR=${LLAMA_TOOLCHAIN_DIR:-}
 TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}

+echo "llama-toolchain-dir=$LLAMA_TOOLCHAIN_DIR"
 set -euo pipefail

+if [ "$#" -ne 3 ]; then
+  echo "Usage: $0 <api_or_stack> <environment_name> <pip_dependencies>" >&2
+  echo "Example: $0 [api|stack] conda-env 'numpy pandas scipy'" >&2
+  exit 1
+fi
+
+api_or_stack="$1"
+env_name="$2"
+pip_dependencies="$3"
+
 # Define color codes
 RED='\033[0;31m'
 GREEN='\033[0;32m'
 NC='\033[0m' # No Color

-error_handler() {
-  echo "Error occurred in script at line: ${1}" >&2
-  exit 1
-}
+# this is set if we actually create a new conda in which case we need to clean up
+ENVNAME=""

-# Set up the error trap
-trap 'error_handler ${LINENO}' ERR
+SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
+source "$SCRIPT_DIR/common.sh"

 ensure_conda_env_python310() {
   local env_name="$1"
@@ -52,6 +61,9 @@ ensure_conda_env_python310() {
   else
     echo "Conda environment '${env_name}' does not exist. Creating with Python ${python_version}..."
     conda create -n "${env_name}" python="${python_version}" -y
+
+    ENVNAME="${env_name}"
+    setup_cleanup_handlers
   fi

   eval "$(conda shell.bash hook)"
@@ -94,19 +106,8 @@ ensure_conda_env_python310() {
   fi
 }

-if [ "$#" -ne 3 ]; then
-  echo "Usage: $0 <environment_name> <distribution_name> <pip_dependencies>" >&2
-  echo "Example: $0 my_env local-llama-8b 'numpy pandas scipy'" >&2
-  exit 1
-fi
-
-env_name="$1"
-distribution_name="$2"
-pip_dependencies="$3"
-
 ensure_conda_env_python310 "$env_name" "$pip_dependencies"

-echo -e "${GREEN}Successfully setup distribution environment. Configuring...${NC}"
+echo -e "${GREEN}Successfully setup conda environment. Configuring build...${NC}"

-which python3
-python3 -m llama_toolchain.cli.llama distribution configure --name "$distribution_name"
+$CONDA_PREFIX/bin/python3 -m llama_toolchain.cli.llama api configure "$api_or_stack" --name "$env_name"

llama_toolchain/distribution/build_image.sh → llama_toolchain/distribution/build_container.sh  (renamed, mode: Normal file → Executable file)

llama_toolchain/distribution/common.sh  (new file, +40)
@@ -0,0 +1,40 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+cleanup() {
+  envname="$1"
+
+  set +x
+  echo "Cleaning up..."
+  conda deactivate
+  conda env remove --name $envname -y
+}
+
+handle_int() {
+  if [ -n $ENVNAME ]; then
+    cleanup $ENVNAME
+  fi
+  exit 1
+}
+
+handle_exit() {
+  if [ $? -ne 0 ]; then
+    echo -e "\033[1;31mABORTING.\033[0m"
+    if [ -n $ENVNAME ]; then
+      cleanup $ENVNAME
+    fi
+  fi
+}
+
+setup_cleanup_handlers() {
+  trap handle_int INT
+  trap handle_exit EXIT
+
+  __conda_setup="$('conda' 'shell.bash' 'hook' 2>/dev/null)"
+  eval "$__conda_setup"
+
+  conda deactivate
+}
@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+from datetime import datetime
 from enum import Enum
 from typing import Any, Dict, List, Optional

@@ -66,36 +67,45 @@ Fully-qualified name of the module to import. The module is expected to have:

 @json_schema_type
 class AdapterSpec(BaseModel):
+    adapter_id: str = Field(
+        ...,
+        description="Unique identifier for this adapter",
+    )
+    module: str = Field(
+        ...,
+        description="""
+Fully-qualified name of the module to import. The module is expected to have:
+
+ - `get_adapter_impl(config, deps)`: returns the adapter implementation
+""",
+    )
     pip_packages: List[str] = Field(
         default_factory=list,
         description="The pip dependencies needed for this implementation",
     )
-    config_class: str = Field(
-        ...,
+    config_class: Optional[str] = Field(
+        default=None,
         description="Fully-qualified classname of the config for this provider",
     )


 class RemoteProviderConfig(BaseModel):
-    base_url: str = Field(..., description="The base URL for the llama stack provider")
+    url: str = Field(..., description="The URL for the provider")

-    @validator("base_url")
+    @validator("url")
     @classmethod
-    def validate_base_url(cls, base_url: str) -> str:
-        if not base_url.startswith("http"):
-            raise ValueError(f"URL must start with http: {base_url}")
-        return base_url
+    def validate_url(cls, url: str) -> str:
+        if not url.startswith("http"):
+            raise ValueError(f"URL must start with http: {url}")
+        return url


+def remote_provider_id(adapter_id: str) -> str:
+    return f"remote::{adapter_id}"
+
+
 @json_schema_type
 class RemoteProviderSpec(ProviderSpec):
-    module: str = Field(
-        ...,
-        description="""
-Fully-qualified name of the module to import. The module is expected to have:
- - `get_client_impl(base_url)`: returns a client which can be used to call the remote implementation
-""",
-    )
     adapter: Optional[AdapterSpec] = Field(
         default=None,
         description="""
@@ -107,6 +117,32 @@ as being "Llama Stack compatible"
     config_class: str = "llama_toolchain.distribution.datatypes.RemoteProviderConfig"


+# need this wrapper since we don't have Pydantic v2 and that means we don't have
+# the @computed_field decorator
+def remote_provider_spec(
+    api: Api, adapter: Optional[AdapterSpec] = None
+) -> RemoteProviderSpec:
+    provider_id = (
+        remote_provider_id(adapter.adapter_id) if adapter is not None else "remote"
+    )
+    module = (
+        adapter.module if adapter is not None else f"llama_toolchain.{api.value}.client"
+    )
+    config_class = (
+        adapter.config_class
+        if adapter and adapter.config_class
+        else "llama_toolchain.distribution.datatypes.RemoteProviderConfig"
+    )
+
+    return RemoteProviderSpec(
+        api=api,
+        provider_id=provider_id,
+        pip_packages=adapter.pip_packages if adapter is not None else [],
+        module=module,
+        config_class=config_class,
+    )
+
+
 @json_schema_type
 class DistributionSpec(BaseModel):
     spec_id: str
@@ -119,13 +155,28 @@ class DistributionSpec(BaseModel):


 @json_schema_type
-class DistributionConfig(BaseModel):
-    """References to a installed / configured DistributionSpec"""
-
-    name: str
-    spec: str
-    conda_env: str
+class PackageConfig(BaseModel):
+    built_at: datetime
+
+    package_name: str = Field(
+        ...,
+        description="""
+Reference to the distribution this package refers to. For unregistered (adhoc) packages,
+this could be just a hash
+""",
+    )
+    docker_image: Optional[str] = Field(
+        default=None,
+        description="Reference to the docker image if this package refers to a container",
+    )
+    conda_env: Optional[str] = Field(
+        default=None,
+        description="Reference to the conda environment if this package refers to a conda environment",
+    )
     providers: Dict[str, Any] = Field(
         default_factory=dict,
-        description="Provider configurations for each of the APIs provided by this distribution",
+        description="""
+Provider configurations for each of the APIs provided by this package. This includes configurations for
+the dependencies of these providers as well.
+""",
     )
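The `remote_provider_id` / `remote_provider_spec` helpers above fix the naming convention: plain `remote` with the generic per-API client module when no adapter is given, `remote::<adapter_id>` with the adapter's own module otherwise. A standalone sketch of just that resolution (using `SimpleNamespace` as a stand-in for `AdapterSpec`):

```python
from types import SimpleNamespace
from typing import Optional

def resolve_remote(api_value: str, adapter: Optional[SimpleNamespace] = None) -> dict:
    # mirrors remote_provider_spec: adapter-less remotes fall back to the
    # generic per-API client module; adapters bring their own module/config
    return {
        "provider_id": f"remote::{adapter.adapter_id}" if adapter else "remote",
        "module": adapter.module if adapter else f"llama_toolchain.{api_value}.client",
        "config_class": (adapter.config_class if adapter and adapter.config_class
                         else "llama_toolchain.distribution.datatypes.RemoteProviderConfig"),
    }

print(resolve_remote("inference"))
print(resolve_remote(
    "inference",
    SimpleNamespace(adapter_id="ollama",
                    module="llama_toolchain.inference.adapters.ollama",
                    config_class=None),
))  # provider_id becomes "remote::ollama"
```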
@@ -30,7 +30,27 @@ def instantiate_provider(
     return asyncio.run(module.get_provider_impl(config, deps))


-def instantiate_client(provider_spec: RemoteProviderSpec, base_url: str):
+def instantiate_client(
+    provider_spec: RemoteProviderSpec, provider_config: Dict[str, Any]
+):
     module = importlib.import_module(provider_spec.module)

-    return asyncio.run(module.get_client_impl(base_url))
+    adapter = provider_spec.adapter
+    if adapter is not None:
+        if "adapter" not in provider_config:
+            raise ValueError(
+                f"Adapter is specified but not present in provider config: {provider_config}"
+            )
+        adapter_config = provider_config["adapter"]
+
+        config_type = instantiate_class_type(adapter.config_class)
+        if not issubclass(config_type, RemoteProviderConfig):
+            raise ValueError(
+                f"Config class {adapter.config_class} does not inherit from RemoteProviderConfig"
+            )
+
+        config = config_type(**adapter_config)
+    else:
+        config = RemoteProviderConfig(**provider_config)
+
+    return asyncio.run(module.get_adapter_impl(config))
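`instantiate_client` now receives the whole provider config dict rather than a bare `base_url`. Per the code above, two shapes are accepted: adapter-less remotes validate the top-level dict as `RemoteProviderConfig`, while adapter-backed remotes expect their config nested under an `"adapter"` key. A sketch of that dispatch with a stand-in config class:

```python
from typing import Any, Dict

class ConfigStub:  # stand-in for RemoteProviderConfig and its subclasses
    def __init__(self, url: str):
        if not url.startswith("http"):
            raise ValueError(f"URL must start with http: {url}")
        self.url = url

def pick_adapter_config(has_adapter: bool, provider_config: Dict[str, Any]) -> ConfigStub:
    # mirrors instantiate_client: adapter configs live under the "adapter" key
    if has_adapter:
        if "adapter" not in provider_config:
            raise ValueError(
                f"Adapter is specified but not present in provider config: {provider_config}"
            )
        return ConfigStub(**provider_config["adapter"])
    return ConfigStub(**provider_config)

print(pick_adapter_config(False, {"url": "http://localhost:5000"}).url)
print(pick_adapter_config(True, {"adapter": {"url": "http://localhost:11434"}}).url)
```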
@ -7,22 +7,10 @@
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from .datatypes import Api, DistributionSpec, RemoteProviderSpec
|
from .datatypes import * # noqa: F403
|
||||||
from .distribution import api_providers
|
from .distribution import api_providers
|
||||||
|
|
||||||
|
|
||||||
def client_module(api: Api) -> str:
|
|
||||||
return f"llama_toolchain.{api.value}.client"
|
|
||||||
|
|
||||||
|
|
||||||
def remote_spec(api: Api) -> RemoteProviderSpec:
|
|
||||||
return RemoteProviderSpec(
|
|
||||||
api=api,
|
|
||||||
provider_id=f"{api.value}-remote",
|
|
||||||
module=client_module(api),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def available_distribution_specs() -> List[DistributionSpec]:
|
def available_distribution_specs() -> List[DistributionSpec]:
|
||||||
providers = api_providers()
|
providers = api_providers()
|
||||||
|
@ -40,13 +28,14 @@ def available_distribution_specs() -> List[DistributionSpec]:
|
||||||
DistributionSpec(
|
DistributionSpec(
|
||||||
spec_id="remote",
|
spec_id="remote",
|
||||||
description="Point to remote services for all llama stack APIs",
|
description="Point to remote services for all llama stack APIs",
|
||||||
provider_specs={x: remote_spec(x) for x in providers},
|
provider_specs={x: remote_provider_spec(x) for x in providers},
|
||||||
),
|
),
|
||||||
DistributionSpec(
|
DistributionSpec(
|
||||||
spec_id="local-ollama",
|
spec_id="local-ollama",
|
||||||
description="Like local, but use ollama for running LLM inference",
|
description="Like local, but use ollama for running LLM inference",
|
||||||
provider_specs={
|
provider_specs={
|
||||||
Api.inference: providers[Api.inference]["meta-ollama"],
|
# this is ODD; make this easier -- we just need a better function to retrieve registered providers
|
||||||
|
Api.inference: providers[Api.inference][remote_provider_id("ollama")],
|
||||||
Api.safety: providers[Api.safety]["meta-reference"],
|
Api.safety: providers[Api.safety]["meta-reference"],
|
||||||
Api.agentic_system: providers[Api.agentic_system]["meta-reference"],
|
Api.agentic_system: providers[Api.agentic_system]["meta-reference"],
|
||||||
Api.memory: providers[Api.memory]["meta-reference-faiss"],
|
Api.memory: providers[Api.memory]["meta-reference-faiss"],
|
||||||
|
@ -57,9 +46,9 @@ def available_distribution_specs() -> List[DistributionSpec]:
|
||||||
description="Test agentic with others as remote",
|
description="Test agentic with others as remote",
|
||||||
provider_specs={
|
provider_specs={
|
||||||
Api.agentic_system: providers[Api.agentic_system]["meta-reference"],
|
Api.agentic_system: providers[Api.agentic_system]["meta-reference"],
|
||||||
Api.inference: remote_spec(Api.inference),
|
Api.inference: remote_provider_spec(Api.inference),
|
||||||
Api.memory: remote_spec(Api.memory),
|
Api.memory: remote_provider_spec(Api.memory),
|
||||||
Api.safety: remote_spec(Api.safety),
|
Api.safety: remote_provider_spec(Api.safety),
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
DistributionSpec(
|
DistributionSpec(
|
||||||
|
|
|
@ -264,7 +264,8 @@ def resolve_impls(dist: DistributionSpec, config: Dict[str, Any]) -> Dict[Api, A
|
||||||
provider_config = provider_configs[api.value]
|
provider_config = provider_configs[api.value]
|
||||||
if isinstance(provider_spec, RemoteProviderSpec):
|
if isinstance(provider_spec, RemoteProviderSpec):
|
||||||
impls[api] = instantiate_client(
|
impls[api] = instantiate_client(
|
||||||
provider_spec, provider_config["base_url"].rstrip("/")
|
provider_spec,
|
||||||
|
provider_config,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
deps = {api: impls[api] for api in provider_spec.api_dependencies}
|
deps = {api: impls[api] for api in provider_spec.api_dependencies}
|
||||||
|
|

llama_toolchain/inference/adapters/__init__.py  (new file, +5)
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
@@ -4,5 +4,4 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from .config import OllamaImplConfig  # noqa
-from .ollama import get_provider_impl  # noqa
+from .ollama import get_adapter_impl  # noqa

@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import AsyncGenerator, Dict
+from typing import AsyncGenerator

 import httpx

@@ -14,7 +14,7 @@ from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import resolve_model
 from ollama import AsyncClient

-from llama_toolchain.distribution.datatypes import Api, ProviderSpec
+from llama_toolchain.distribution.datatypes import RemoteProviderConfig
 from llama_toolchain.inference.api import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -27,7 +27,6 @@ from llama_toolchain.inference.api import (
     ToolCallParseStatus,
 )
 from llama_toolchain.inference.prepare_messages import prepare_messages
-from .config import OllamaImplConfig

 # TODO: Eventually this will move to the llama cli model list command
 # mapping of Model SKUs to ollama models
@@ -37,26 +36,21 @@ OLLAMA_SUPPORTED_SKUS = {
 }


-async def get_provider_impl(
-    config: OllamaImplConfig, _deps: Dict[Api, ProviderSpec]
-) -> Inference:
-    assert isinstance(
-        config, OllamaImplConfig
-    ), f"Unexpected config type: {type(config)}"
-    impl = OllamaInference(config)
+async def get_adapter_impl(config: RemoteProviderConfig) -> Inference:
+    impl = OllamaInferenceAdapter(config.url)
     await impl.initialize()
     return impl


-class OllamaInference(Inference):
-    def __init__(self, config: OllamaImplConfig) -> None:
-        self.config = config
+class OllamaInferenceAdapter(Inference):
+    def __init__(self, url: str) -> None:
+        self.url = url
         tokenizer = Tokenizer.get_instance()
         self.formatter = ChatFormat(tokenizer)

     @property
     def client(self) -> AsyncClient:
-        return AsyncClient(host=self.config.url)
+        return AsyncClient(host=self.url)

     async def initialize(self) -> None:
         try:
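This rewrite establishes the adapter entry point used throughout the commit: a module exposes `async def get_adapter_impl(config)`, builds the implementation from `config.url`, and awaits `initialize()` before returning it. A minimal standalone sketch of that contract (class and field names are stand-ins, not the toolchain's):

```python
import asyncio
from dataclasses import dataclass

@dataclass
class RemoteConfigStub:  # stand-in for RemoteProviderConfig
    url: str

class EchoAdapter:  # stand-in adapter implementation
    def __init__(self, url: str) -> None:
        self.url = url

    async def initialize(self) -> None:
        # a real adapter would probe the remote service here
        print(f"connected to {self.url}")

async def get_adapter_impl(config: RemoteConfigStub) -> EchoAdapter:
    impl = EchoAdapter(config.url)
    await impl.initialize()
    return impl

impl = asyncio.run(get_adapter_impl(RemoteConfigStub(url="http://localhost:11434")))
```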
@@ -13,6 +13,8 @@ import httpx
 from pydantic import BaseModel
 from termcolor import cprint

+from llama_toolchain.distribution.datatypes import RemoteProviderConfig
+
 from .api import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -24,8 +26,8 @@ from .api import (
 from .event_logger import EventLogger


-async def get_client_impl(base_url: str):
-    return InferenceClient(base_url)
+async def get_adapter_impl(config: RemoteProviderConfig) -> Inference:
+    return InferenceClient(config.url)


 def encodable_dict(d: BaseModel):
@@ -34,7 +36,7 @@ def encodable_dict(d: BaseModel):

 class InferenceClient(Inference):
     def __init__(self, base_url: str):
-        print(f"Initializing client for {base_url}")
+        print(f"Inference passthrough to -> {base_url}")
         self.base_url = base_url

     async def initialize(self) -> None:
@@ -1,16 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from llama_models.schema_utils import json_schema_type
-from pydantic import BaseModel, Field
-
-
-@json_schema_type
-class OllamaImplConfig(BaseModel):
-    url: str = Field(
-        default="http://localhost:11434",
-        description="The URL for the ollama server",
-    )
@@ -6,7 +6,7 @@

 from typing import List

-from llama_toolchain.distribution.datatypes import Api, InlineProviderSpec, ProviderSpec
+from llama_toolchain.distribution.datatypes import *  # noqa: F403


 def available_inference_providers() -> List[ProviderSpec]:
@@ -27,13 +27,12 @@ def available_inference_providers() -> List[ProviderSpec]:
             module="llama_toolchain.inference.meta_reference",
             config_class="llama_toolchain.inference.meta_reference.MetaReferenceImplConfig",
         ),
-        InlineProviderSpec(
+        remote_provider_spec(
             api=Api.inference,
-            provider_id="meta-ollama",
-            pip_packages=[
-                "ollama",
-            ],
-            module="llama_toolchain.inference.ollama",
-            config_class="llama_toolchain.inference.ollama.OllamaImplConfig",
+            adapter=AdapterSpec(
+                adapter_id="ollama",
+                pip_packages=[],
+                module="llama_toolchain.inference.adapters.ollama",
+            ),
         ),
     ]
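With this change ollama is no longer the inline `meta-ollama` provider but an adapter-backed remote one, so the registry keys it as `remote::ollama`, which is exactly what `registry.py` above looks up via `remote_provider_id("ollama")`:

```python
def remote_provider_id(adapter_id: str) -> str:
    # same helper added in datatypes.py above
    return f"remote::{adapter_id}"

# the providers mapping for Api.inference is keyed by provider_id
providers = {remote_provider_id("ollama"): "<RemoteProviderSpec for the ollama adapter>"}
assert "remote::ollama" in providers
```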
@@ -6,24 +6,23 @@

 import asyncio

-# import json
 from typing import Dict, List, Optional

 import fire
 import httpx

-# from termcolor import cprint
+from llama_toolchain.distribution.datatypes import RemoteProviderConfig

 from .api import *  # noqa: F403


-async def get_client_impl(base_url: str):
-    return MemoryClient(base_url)
+async def get_adapter_impl(config: RemoteProviderConfig) -> Memory:
+    return MemoryClient(config.url)


 class MemoryClient(Memory):
     def __init__(self, base_url: str):
-        print(f"Initializing client for {base_url}")
+        print(f"Memory passthrough to -> {base_url}")
         self.base_url = base_url

     async def initialize(self) -> None:
@@ -10,20 +10,17 @@ import fire
 import httpx

 from llama_models.llama3.api.datatypes import UserMessage

 from pydantic import BaseModel
 from termcolor import cprint

-from .api import (
-    BuiltinShield,
-    RunShieldRequest,
-    RunShieldResponse,
-    Safety,
-    ShieldDefinition,
-)
+from llama_toolchain.distribution.datatypes import RemoteProviderConfig
+
+from .api import *  # noqa: F403


-async def get_client_impl(base_url: str):
-    return SafetyClient(base_url)
+async def get_adapter_impl(config: RemoteProviderConfig) -> Safety:
+    return SafetyClient(config.url)


 def encodable_dict(d: BaseModel):
@@ -32,7 +29,7 @@ def encodable_dict(d: BaseModel):

 class SafetyClient(Safety):
     def __init__(self, base_url: str):
-        print(f"Initializing client for {base_url}")
+        print(f"Safety passthrough to -> {base_url}")
         self.base_url = base_url

     async def initialize(self) -> None: