bunch more work to make adapters work

Ashwin Bharambe 2024-08-27 19:15:42 -07:00
parent 68f3db62e9
commit c4fe72c3a3
20 changed files with 461 additions and 173 deletions

View file

@@ -9,6 +9,7 @@ import argparse
from llama_toolchain.cli.subcommand import Subcommand
from .build import ApiBuild
from .configure import ApiConfigure
class ApiParser(Subcommand):
@@ -24,3 +25,4 @@ class ApiParser(Subcommand):
# Add sub-commands
ApiBuild.create(subparsers)
ApiConfigure.create(subparsers)

View file

@@ -6,22 +6,92 @@
import argparse
import os
import random
import string
import uuid
from pydantic import BaseModel
from datetime import datetime
from enum import Enum
from typing import Dict, List, Optional, Tuple
import pkg_resources
import yaml
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.common.config_dirs import DISTRIBS_BASE_DIR
from termcolor import cprint
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.common.config_dirs import BUILDS_BASE_DIR
from llama_toolchain.distribution.datatypes import * # noqa: F403
def random_string():
return "".join(random.choices(string.ascii_letters + string.digits, k=8))
class BuildType(Enum):
container = "container"
conda_env = "conda_env"
class Dependencies(BaseModel):
pip_packages: List[str]
docker_image: Optional[str] = None
def get_dependencies(
provider: ProviderSpec, dependencies: Dict[str, ProviderSpec]
) -> Dependencies:
def _deps(provider: ProviderSpec) -> Tuple[List[str], Optional[str]]:
if isinstance(provider, InlineProviderSpec):
return provider.pip_packages, provider.docker_image
else:
if provider.adapter:
return provider.adapter.pip_packages, None
return [], None
pip_packages, docker_image = _deps(provider)
pip_packages = list(pip_packages)  # copy so we don't mutate the provider spec's own list
for dep in dependencies.values():
dep_pip_packages, dep_docker_image = _deps(dep)
if docker_image and dep_docker_image:
raise ValueError(
"You can only have the root provider specify a docker image"
)
pip_packages.extend(dep_pip_packages)
return Dependencies(docker_image=docker_image, pip_packages=pip_packages)
def parse_dependencies(
dependencies: str, parser: argparse.ArgumentParser
) -> Dict[str, ProviderSpec]:
from llama_toolchain.distribution.distribution import api_providers
all_providers = api_providers()
deps = {}
for dep in dependencies.split(","):
dep = dep.strip()
if not dep:
continue
api_str, provider = dep.split("=")
api = Api(api_str)
provider = provider.strip()
if provider not in all_providers[api]:
parser.error(f"Provider `{provider}` is not available for API `{api}`")
return
deps[api] = all_providers[api][provider]
return deps
class ApiBuild(Subcommand):
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"install",
"build",
prog="llama api build",
description="Build a Llama stack API provider container",
formatter_class=argparse.RawTextHelpFormatter,
@@ -36,7 +106,7 @@ class ApiBuild(Subcommand):
self.parser.add_argument(
"api",
choices=allowed_args,
help="Stack API (one of: {})".format(", ".join(allowed_args))
help="Stack API (one of: {})".format(", ".join(allowed_args)),
)
self.parser.add_argument(
@@ -45,73 +115,104 @@
help="The provider to package into the container",
required=True,
)
self.parser.add_argument(
"--container-name",
type=str,
help="Name of the container (including tag if needed)",
required=True,
)
self.parser.add_argument(
"--dependencies",
type=str,
help="Comma separated list of (downstream_api=provider) dependencies needed for the API",
required=False,
)
self.parser.add_argument(
"--name",
type=str,
help="Name of the build target (image, conda env). Defaults to a random UUID",
required=False,
)
self.parser.add_argument(
"--type",
type=str,
default="container",
choices=[v.value for v in BuildType],
)
def _run_api_build_command(self, args: argparse.Namespace) -> None:
from llama_toolchain.common.exec import run_with_pty
from llama_toolchain.distribution.datatypes import DistributionConfig
from llama_toolchain.distribution.distribution import distribution_dependencies
from llama_toolchain.distribution.registry import resolve_distribution_spec
from llama_toolchain.distribution.distribution import api_providers
os.makedirs(DISTRIBS_BASE_DIR, exist_ok=True)
script = pkg_resources.resource_filename(
"llama_toolchain",
"distribution/build_api.sh",
)
os.makedirs(BUILDS_BASE_DIR, exist_ok=True)
all_providers = api_providers()
dist = resolve_distribution_spec(args.spec)
if dist is None:
self.parser.error(f"Could not find distribution {args.spec}")
return
api = Api(args.api)
assert api in all_providers
distrib_dir = DISTRIBS_BASE_DIR / args.name
os.makedirs(distrib_dir, exist_ok=True)
deps = distribution_dependencies(dist)
if not args.conda_env:
print(f"Using {args.name} as the Conda environment for this distribution")
conda_env = args.conda_env or args.name
config_file = distrib_dir / "config.yaml"
if config_file.exists():
c = DistributionConfig(**yaml.safe_load(config_file.read_text()))
if c.spec != dist.spec_id:
providers = all_providers[api]
if args.provider not in providers:
self.parser.error(
f"already installed distribution with `spec={c.spec}` does not match provided spec `{args.spec}`"
)
return
if c.conda_env != conda_env:
self.parser.error(
f"already installed distribution has `conda_env={c.conda_env}` different from provided conda env `{conda_env}`"
f"Provider `{args.provider}` is not available for API `{api}`"
)
return
name = args.name or random_string()
if args.type == BuildType.container.value:
package_name = f"image-{args.provider}-{name}"
else:
with open(config_file, "w") as f:
c = DistributionConfig(
spec=dist.spec_id,
name=args.name,
conda_env=conda_env,
package_name = f"env-{args.provider}-{name}"
package_name = package_name.replace("::", "-")
build_dir = BUILDS_BASE_DIR / args.api
os.makedirs(build_dir, exist_ok=True)
provider_deps = parse_dependencies(args.dependencies or "", self.parser)
dependencies = get_dependencies(providers[args.provider], provider_deps)
package_file = build_dir / f"{package_name}.yaml"
stub_config = {
api.value: {
"provider_id": args.provider,
},
**{k: {"provider_id": v} for k, v in provider_deps.items()},
}
with open(package_file, "w") as f:
c = PackageConfig(
built_at=datetime.now(),
package_name=package_name,
docker_image=(
package_name if args.type == BuildType.container.value else None
),
conda_env=(
package_name if args.type == BuildType.conda_env.value else None
),
providers=stub_config,
)
f.write(yaml.dump(c.dict(), sort_keys=False))
return_code = run_with_pty([script, conda_env, args.name, " ".join(deps)])
if args.type == BuildType.container.value:
script = pkg_resources.resource_filename(
"llama_toolchain", "distribution/build_container.sh"
)
args = [
script,
args.api,
package_name,
dependencies.docker_image or "python:3.10-slim",
" ".join(dependencies.pip_packages),
]
else:
script = pkg_resources.resource_filename(
"llama_toolchain", "distribution/build_conda_env.sh"
)
args = [
script,
args.api,
package_name,
" ".join(dependencies.pip_packages),
]
return_code = run_with_pty(args)
assert return_code == 0, cprint(
f"Failed to install distribution {dist.spec_id}", color="red"
f"Failed to build target {package_name}", color="red"
)
cprint(
f"Distribution `{args.name}` (with spec {dist.spec_id}) has been installed successfully!",
f"Target `{target_name}` built with configuration at {str(package_file)}",
color="green",
)
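As a side note, the `--dependencies` flag handled above is a comma-separated list of api=provider pairs. A minimal, self-contained sketch of that format as parse_dependencies() sees it (the provider ids here are illustrative):

# Sketch: the shape of a `--dependencies` value (illustrative ids).
deps_arg = "inference=meta-reference,safety=meta-reference"
parsed = {}
for item in deps_arg.split(","):
    item = item.strip()
    if not item:
        continue
    api_str, provider_id = item.split("=")
    parsed[api_str] = provider_id.strip()
print(parsed)  # {'inference': 'meta-reference', 'safety': 'meta-reference'}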

View file

@@ -0,0 +1,100 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import json
from pathlib import Path
import yaml
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.common.config_dirs import BUILDS_BASE_DIR
from llama_toolchain.distribution.datatypes import * # noqa: F403
from termcolor import cprint
class ApiConfigure(Subcommand):
"""Llama cli for configuring llama toolchain configs"""
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"configure",
prog="llama api configure",
description="configure a llama stack API provider",
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_api_configure_cmd)
def _add_arguments(self):
from llama_toolchain.distribution.distribution import stack_apis
allowed_args = [a.name for a in stack_apis()]
self.parser.add_argument(
"api",
choices=allowed_args,
help="Stack API (one of: {})".format(", ".join(allowed_args)),
)
self.parser.add_argument(
"--name",
type=str,
help="Name of the provider build to fully configure",
required=True,
)
def _run_api_configure_cmd(self, args: argparse.Namespace) -> None:
config_file = BUILDS_BASE_DIR / args.api / f"{args.name}.yaml"
if not config_file.exists():
self.parser.error(
f"Could not find {config_file}. Please run `llama api build` first"
)
return
configure_llama_provider(config_file)
def configure_llama_provider(config_file: Path) -> None:
from llama_toolchain.common.prompt_for_config import prompt_for_config
from llama_toolchain.common.serialize import EnumEncoder
from llama_toolchain.distribution.distribution import api_providers
from llama_toolchain.distribution.dynamic import instantiate_class_type
with open(config_file, "r") as f:
config = PackageConfig(**yaml.safe_load(f))
all_providers = api_providers()
provider_configs = {}
for api, stub_config in config.providers.items():
providers = all_providers[Api(api)]
provider_id = stub_config["provider_id"]
if provider_id not in providers:
raise ValueError(
f"Unknown provider `{provider_id}` is not available for API `{api}`"
)
provider_spec = providers[provider_id]
cprint(f"Configuring API surface: {api}", "white", attrs=["bold"])
config_type = instantiate_class_type(provider_spec.config_class)
print(f"Config type: {config_type}")
provider_config = prompt_for_config(
config_type,
)
print("")
provider_configs[api] = {
"provider_id": provider_id,
**provider_config.dict(),
}
config.providers = provider_configs
with open(config_file, "w") as fp:
to_write = json.loads(json.dumps(config.dict(), cls=EnumEncoder))
fp.write(yaml.dump(to_write, sort_keys=False))
print(f"YAML configuration has been written to {config_path}")

View file

@@ -13,3 +13,5 @@ LLAMA_STACK_CONFIG_DIR = Path(os.path.expanduser("~/.llama/"))
DISTRIBS_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "distributions"
DEFAULT_CHECKPOINT_DIR = LLAMA_STACK_CONFIG_DIR / "checkpoints"
BUILDS_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "builds"

View file

@@ -71,6 +71,7 @@ def prompt_for_config(
"""
config_data = {}
print(f"Configuring {config_type.__name__}:")
for field_name, field in config_type.__fields__.items():
field_type = field.annotation
@@ -85,6 +86,7 @@
if not isinstance(field.default, PydanticUndefinedType)
else None
)
print(f" {field_name}: {field_type} (default: {default_value})")
is_required = field.is_required
# Skip fields with Literal type

View file

@@ -10,20 +10,29 @@ LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-}
LLAMA_TOOLCHAIN_DIR=${LLAMA_TOOLCHAIN_DIR:-}
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
echo "llama-toolchain-dir=$LLAMA_TOOLCHAIN_DIR"
set -euo pipefail
if [ "$#" -ne 3 ]; then
echo "Usage: $0 <api_or_stack> <environment_name> <pip_dependencies>" >&2
echo "Example: $0 [api|stack] conda-env 'numpy pandas scipy'" >&2
exit 1
fi
api_or_stack="$1"
env_name="$2"
pip_dependencies="$3"
# Define color codes
RED='\033[0;31m'
GREEN='\033[0;32m'
NC='\033[0m' # No Color
error_handler() {
echo "Error occurred in script at line: ${1}" >&2
exit 1
}
# this is set if we actually create a new conda env, in which case we need to clean it up
ENVNAME=""
# Set up the error trap
trap 'error_handler ${LINENO}' ERR
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
source "$SCRIPT_DIR/common.sh"
ensure_conda_env_python310() {
local env_name="$1"
@@ -52,6 +61,9 @@ ensure_conda_env_python310() {
else
echo "Conda environment '${env_name}' does not exist. Creating with Python ${python_version}..."
conda create -n "${env_name}" python="${python_version}" -y
ENVNAME="${env_name}"
setup_cleanup_handlers
fi
eval "$(conda shell.bash hook)"
@@ -94,19 +106,8 @@ ensure_conda_env_python310() {
fi
}
if [ "$#" -ne 3 ]; then
echo "Usage: $0 <environment_name> <distribution_name> <pip_dependencies>" >&2
echo "Example: $0 my_env local-llama-8b 'numpy pandas scipy'" >&2
exit 1
fi
env_name="$1"
distribution_name="$2"
pip_dependencies="$3"
ensure_conda_env_python310 "$env_name" "$pip_dependencies"
echo -e "${GREEN}Successfully setup distribution environment. Configuring...${NC}"
echo -e "${GREEN}Successfully setup conda environment. Configuring build...${NC}"
which python3
python3 -m llama_toolchain.cli.llama distribution configure --name "$distribution_name"
$CONDA_PREFIX/bin/python3 -m llama_toolchain.cli.llama api configure "$api_or_stack" --name "$env_name"

View file

@@ -0,0 +1,40 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
cleanup() {
envname="$1"
set +x
echo "Cleaning up..."
conda deactivate
conda env remove --name "$envname" -y
}
handle_int() {
if [ -n "$ENVNAME" ]; then
cleanup "$ENVNAME"
fi
exit 1
}
handle_exit() {
if [ $? -ne 0 ]; then
echo -e "\033[1;31mABORTING.\033[0m"
if [ -n "$ENVNAME" ]; then
cleanup "$ENVNAME"
fi
fi
}
setup_cleanup_handlers() {
trap handle_int INT
trap handle_exit EXIT
__conda_setup="$('conda' 'shell.bash' 'hook' 2>/dev/null)"
eval "$__conda_setup"
conda deactivate
}

View file

@@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional
@@ -66,36 +67,45 @@ Fully-qualified name of the module to import. The module is expected to have:
@json_schema_type
class AdapterSpec(BaseModel):
adapter_id: str = Field(
...,
description="Unique identifier for this adapter",
)
module: str = Field(
...,
description="""
Fully-qualified name of the module to import. The module is expected to have:
- `get_adapter_impl(config, deps)`: returns the adapter implementation
""",
)
pip_packages: List[str] = Field(
default_factory=list,
description="The pip dependencies needed for this implementation",
)
config_class: str = Field(
...,
config_class: Optional[str] = Field(
default=None,
description="Fully-qualified classname of the config for this provider",
)
class RemoteProviderConfig(BaseModel):
base_url: str = Field(..., description="The base URL for the llama stack provider")
url: str = Field(..., description="The URL for the provider")
@validator("base_url")
@validator("url")
@classmethod
def validate_base_url(cls, base_url: str) -> str:
if not base_url.startswith("http"):
raise ValueError(f"URL must start with http: {base_url}")
return base_url
def validate_url(cls, url: str) -> str:
if not url.startswith("http"):
raise ValueError(f"URL must start with http: {url}")
return url
def remote_provider_id(adapter_id: str) -> str:
return f"remote::{adapter_id}"
@json_schema_type
class RemoteProviderSpec(ProviderSpec):
module: str = Field(
...,
description="""
Fully-qualified name of the module to import. The module is expected to have:
- `get_client_impl(base_url)`: returns a client which can be used to call the remote implementation
""",
)
adapter: Optional[AdapterSpec] = Field(
default=None,
description="""
@@ -107,6 +117,32 @@ as being "Llama Stack compatible"
config_class: str = "llama_toolchain.distribution.datatypes.RemoteProviderConfig"
# need this wrapper since we don't have Pydantic v2 and that means we don't have
# the @computed_field decorator
def remote_provider_spec(
api: Api, adapter: Optional[AdapterSpec] = None
) -> RemoteProviderSpec:
provider_id = (
remote_provider_id(adapter.adapter_id) if adapter is not None else "remote"
)
module = (
adapter.module if adapter is not None else f"llama_toolchain.{api.value}.client"
)
config_class = (
adapter.config_class
if adapter and adapter.config_class
else "llama_toolchain.distribution.datatypes.RemoteProviderConfig"
)
return RemoteProviderSpec(
api=api,
provider_id=provider_id,
pip_packages=adapter.pip_packages if adapter is not None else [],
module=module,
config_class=config_class,
)
@json_schema_type
class DistributionSpec(BaseModel):
spec_id: str
@@ -119,13 +155,28 @@ class DistributionSpec(BaseModel):
@json_schema_type
class DistributionConfig(BaseModel):
"""References to a installed / configured DistributionSpec"""
class PackageConfig(BaseModel):
built_at: datetime
name: str
spec: str
conda_env: str
package_name: str = Field(
...,
description="""
Reference to the distribution this package refers to. For unregistered (adhoc) packages,
this could be just a hash
""",
)
docker_image: Optional[str] = Field(
default=None,
description="Reference to the docker image if this package refers to a container",
)
conda_env: Optional[str] = Field(
default=None,
description="Reference to the conda environment if this package refers to a conda environment",
)
providers: Dict[str, Any] = Field(
default_factory=dict,
description="Provider configurations for each of the APIs provided by this distribution",
description="""
Provider configurations for each of the APIs provided by this package. This includes configurations for
the dependencies of these providers as well.
""",
)
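Since this file introduces the adapter plumbing, a short sketch of how a remote adapter gets registered through the new remote_provider_spec helper; it mirrors the ollama registration later in this commit and assumes the llama_toolchain package is importable:

# Sketch: registering a remote adapter (mirrors the ollama change below).
from llama_toolchain.distribution.datatypes import *  # noqa: F403

spec = remote_provider_spec(
    api=Api.inference,
    adapter=AdapterSpec(
        adapter_id="ollama",
        pip_packages=[],
        module="llama_toolchain.inference.adapters.ollama",
    ),
)
# spec.provider_id == "remote::ollama"; config_class falls back to
# RemoteProviderConfig because the adapter does not set its own.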

View file

@@ -30,7 +30,27 @@ def instantiate_provider(
return asyncio.run(module.get_provider_impl(config, deps))
def instantiate_client(provider_spec: RemoteProviderSpec, base_url: str):
def instantiate_client(
provider_spec: RemoteProviderSpec, provider_config: Dict[str, Any]
):
module = importlib.import_module(provider_spec.module)
return asyncio.run(module.get_client_impl(base_url))
adapter = provider_spec.adapter
if adapter is not None:
if "adapter" not in provider_config:
raise ValueError(
f"Adapter is specified but not present in provider config: {provider_config}"
)
adapter_config = provider_config["adapter"]
config_type = instantiate_class_type(adapter.config_class)
if not issubclass(config_type, RemoteProviderConfig):
raise ValueError(
f"Config class {adapter.config_class} does not inherit from RemoteProviderConfig"
)
config = config_type(**adapter_config)
else:
config = RemoteProviderConfig(**provider_config)
return asyncio.run(module.get_adapter_impl(config))
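For clarity, a sketch of the two provider_config shapes the reworked instantiate_client() accepts; both URLs are illustrative:

# With an adapter, the remote endpoint config lives under an "adapter" key
# and is validated against the adapter's config_class:
adapter_style = {"adapter": {"url": "http://localhost:11434"}}

# Without one, the whole dict is parsed as RemoteProviderConfig:
plain_style = {"url": "http://localhost:5000"}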

View file

@@ -7,22 +7,10 @@
from functools import lru_cache
from typing import List, Optional
from .datatypes import Api, DistributionSpec, RemoteProviderSpec
from .datatypes import * # noqa: F403
from .distribution import api_providers
def client_module(api: Api) -> str:
return f"llama_toolchain.{api.value}.client"
def remote_spec(api: Api) -> RemoteProviderSpec:
return RemoteProviderSpec(
api=api,
provider_id=f"{api.value}-remote",
module=client_module(api),
)
@lru_cache()
def available_distribution_specs() -> List[DistributionSpec]:
providers = api_providers()
@@ -40,13 +28,14 @@ def available_distribution_specs() -> List[DistributionSpec]:
DistributionSpec(
spec_id="remote",
description="Point to remote services for all llama stack APIs",
provider_specs={x: remote_spec(x) for x in providers},
provider_specs={x: remote_provider_spec(x) for x in providers},
),
DistributionSpec(
spec_id="local-ollama",
description="Like local, but use ollama for running LLM inference",
provider_specs={
Api.inference: providers[Api.inference]["meta-ollama"],
# this is ODD; make this easier -- we just need a better function to retrieve registered providers
Api.inference: providers[Api.inference][remote_provider_id("ollama")],
Api.safety: providers[Api.safety]["meta-reference"],
Api.agentic_system: providers[Api.agentic_system]["meta-reference"],
Api.memory: providers[Api.memory]["meta-reference-faiss"],
@@ -57,9 +46,9 @@ def available_distribution_specs() -> List[DistributionSpec]:
description="Test agentic with others as remote",
provider_specs={
Api.agentic_system: providers[Api.agentic_system]["meta-reference"],
Api.inference: remote_spec(Api.inference),
Api.memory: remote_spec(Api.memory),
Api.safety: remote_spec(Api.safety),
Api.inference: remote_provider_spec(Api.inference),
Api.memory: remote_provider_spec(Api.memory),
Api.safety: remote_provider_spec(Api.safety),
},
),
DistributionSpec(

View file

@@ -264,7 +264,8 @@ def resolve_impls(dist: DistributionSpec, config: Dict[str, Any]) -> Dict[Api, A
provider_config = provider_configs[api.value]
if isinstance(provider_spec, RemoteProviderSpec):
impls[api] = instantiate_client(
provider_spec, provider_config["base_url"].rstrip("/")
provider_spec,
provider_config,
)
else:
deps = {api: impls[api] for api in provider_spec.api_dependencies}

View file

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@@ -4,5 +4,4 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import OllamaImplConfig # noqa
from .ollama import get_provider_impl # noqa
from .ollama import get_adapter_impl # noqa

View file

@@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import AsyncGenerator, Dict
from typing import AsyncGenerator
import httpx
@@ -14,7 +14,7 @@ from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import resolve_model
from ollama import AsyncClient
from llama_toolchain.distribution.datatypes import Api, ProviderSpec
from llama_toolchain.distribution.datatypes import RemoteProviderConfig
from llama_toolchain.inference.api import (
ChatCompletionRequest,
ChatCompletionResponse,
@@ -27,7 +27,6 @@ from llama_toolchain.inference.api import (
ToolCallParseStatus,
)
from llama_toolchain.inference.prepare_messages import prepare_messages
from .config import OllamaImplConfig
# TODO: Eventually this will move to the llama cli model list command
# mapping of Model SKUs to ollama models
@@ -37,26 +36,21 @@ OLLAMA_SUPPORTED_SKUS = {
}
async def get_provider_impl(
config: OllamaImplConfig, _deps: Dict[Api, ProviderSpec]
) -> Inference:
assert isinstance(
config, OllamaImplConfig
), f"Unexpected config type: {type(config)}"
impl = OllamaInference(config)
async def get_adapter_impl(config: RemoteProviderConfig) -> Inference:
impl = OllamaInferenceAdapter(config.url)
await impl.initialize()
return impl
class OllamaInference(Inference):
def __init__(self, config: OllamaImplConfig) -> None:
self.config = config
class OllamaInferenceAdapter(Inference):
def __init__(self, url: str) -> None:
self.url = url
tokenizer = Tokenizer.get_instance()
self.formatter = ChatFormat(tokenizer)
@property
def client(self) -> AsyncClient:
return AsyncClient(host=self.config.url)
return AsyncClient(host=self.url)
async def initialize(self) -> None:
try:

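A hedged usage sketch for the new adapter entry point; the URL is ollama's default per the old OllamaImplConfig removed later in this commit, the module path follows the registry change below, and a reachable ollama server is assumed:

# Sketch: constructing the adapter directly (assumes a running ollama server).
import asyncio
from llama_toolchain.distribution.datatypes import RemoteProviderConfig
from llama_toolchain.inference.adapters.ollama import get_adapter_impl

config = RemoteProviderConfig(url="http://localhost:11434")
inference = asyncio.run(get_adapter_impl(config))  # initialize() pings the server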
View file

@@ -13,6 +13,8 @@ import httpx
from pydantic import BaseModel
from termcolor import cprint
from llama_toolchain.distribution.datatypes import RemoteProviderConfig
from .api import (
ChatCompletionRequest,
ChatCompletionResponse,
@@ -24,8 +26,8 @@ from .api import (
from .event_logger import EventLogger
async def get_client_impl(base_url: str):
return InferenceClient(base_url)
async def get_adapter_impl(config: RemoteProviderConfig) -> Inference:
return InferenceClient(config.url)
def encodable_dict(d: BaseModel):
@@ -34,7 +36,7 @@ def encodable_dict(d: BaseModel):
class InferenceClient(Inference):
def __init__(self, base_url: str):
print(f"Initializing client for {base_url}")
print(f"Inference passthrough to -> {base_url}")
self.base_url = base_url
async def initialize(self) -> None:

View file

@@ -1,16 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
@json_schema_type
class OllamaImplConfig(BaseModel):
url: str = Field(
default="http://localhost:11434",
description="The URL for the ollama server",
)

View file

@@ -6,7 +6,7 @@
from typing import List
from llama_toolchain.distribution.datatypes import Api, InlineProviderSpec, ProviderSpec
from llama_toolchain.distribution.datatypes import * # noqa: F403
def available_inference_providers() -> List[ProviderSpec]:
@@ -27,13 +27,12 @@ def available_inference_providers() -> List[ProviderSpec]:
module="llama_toolchain.inference.meta_reference",
config_class="llama_toolchain.inference.meta_reference.MetaReferenceImplConfig",
),
InlineProviderSpec(
remote_provider_spec(
api=Api.inference,
provider_id="meta-ollama",
pip_packages=[
"ollama",
],
module="llama_toolchain.inference.ollama",
config_class="llama_toolchain.inference.ollama.OllamaImplConfig",
adapter=AdapterSpec(
adapter_id="ollama",
pip_packages=[],
module="llama_toolchain.inference.adapters.ollama",
),
),
]

View file

@@ -6,24 +6,23 @@
import asyncio
# import json
from typing import Dict, List, Optional
import fire
import httpx
# from termcolor import cprint
from llama_toolchain.distribution.datatypes import RemoteProviderConfig
from .api import * # noqa: F403
async def get_client_impl(base_url: str):
return MemoryClient(base_url)
async def get_adapter_impl(config: RemoteProviderConfig) -> Memory:
return MemoryClient(config.url)
class MemoryClient(Memory):
def __init__(self, base_url: str):
print(f"Initializing client for {base_url}")
print(f"Memory passthrough to -> {base_url}")
self.base_url = base_url
async def initialize(self) -> None:

View file

@@ -10,20 +10,17 @@ import fire
import httpx
from llama_models.llama3.api.datatypes import UserMessage
from pydantic import BaseModel
from termcolor import cprint
from .api import (
BuiltinShield,
RunShieldRequest,
RunShieldResponse,
Safety,
ShieldDefinition,
)
from llama_toolchain.distribution.datatypes import RemoteProviderConfig
from .api import * # noqa: F403
async def get_client_impl(base_url: str):
return SafetyClient(base_url)
async def get_adapter_impl(config: RemoteProviderConfig) -> Safety:
return SafetyClient(config.url)
def encodable_dict(d: BaseModel):
@@ -32,7 +29,7 @@ def encodable_dict(d: BaseModel):
class SafetyClient(Safety):
def __init__(self, base_url: str):
print(f"Initializing client for {base_url}")
print(f"Safety passthrough to -> {base_url}")
self.base_url = base_url
async def initialize(self) -> None: