llama_toolchain/distribution -> llama_toolchain/core

This commit is contained in:
Ashwin Bharambe 2024-08-28 17:39:41 -07:00
parent 81540e6ce8
commit 3cb67f1f58
31 changed files with 49 additions and 45 deletions

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,127 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-}
LLAMA_TOOLCHAIN_DIR=${LLAMA_TOOLCHAIN_DIR:-}
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
if [ -n "$LLAMA_TOOLCHAIN_DIR" ]; then
echo "Using llama-toolchain-dir=$LLAMA_TOOLCHAIN_DIR"
fi
if [ -n "$LLAMA_MODELS_DIR" ]; then
echo "Using llama-models-dir=$LLAMA_MODELS_DIR"
fi
set -euo pipefail
if [ "$#" -ne 3 ]; then
echo "Usage: $0 <api_or_stack> <environment_name> <pip_dependencies>" >&2
echo "Example: $0 [api|stack] conda-env 'numpy pandas scipy'" >&2
exit 1
fi
api_or_stack="$1"
env_name="$2"
pip_dependencies="$3"
# Define color codes
RED='\033[0;31m'
GREEN='\033[0;32m'
NC='\033[0m' # No Color
# this is set if we actually create a new conda in which case we need to clean up
ENVNAME=""
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
source "$SCRIPT_DIR/common.sh"
ensure_conda_env_python310() {
local env_name="$1"
local pip_dependencies="$2"
local python_version="3.10"
# Check if conda command is available
if ! command -v conda &>/dev/null; then
printf "${RED}Error: conda command not found. Is Conda installed and in your PATH?${NC}" >&2
exit 1
fi
# Check if the environment exists
if conda env list | grep -q "^${env_name} "; then
printf "Conda environment '${env_name}' exists. Checking Python version..."
# Check Python version in the environment
current_version=$(conda run -n "${env_name}" python --version 2>&1 | cut -d' ' -f2 | cut -d'.' -f1,2)
if [ "$current_version" = "$python_version" ]; then
printf "Environment '${env_name}' already has Python ${python_version}. No action needed."
else
printf "Updating environment '${env_name}' to Python ${python_version}..."
conda install -n "${env_name}" python="${python_version}" -y
fi
else
printf "Conda environment '${env_name}' does not exist. Creating with Python ${python_version}..."
conda create -n "${env_name}" python="${python_version}" -y
ENVNAME="${env_name}"
# setup_cleanup_handlers
fi
eval "$(conda shell.bash hook)"
conda deactivate && conda activate "${env_name}"
if [ -n "$TEST_PYPI_VERSION" ]; then
# these packages are damaged in test-pypi, so install them first
pip install fastapi libcst
pip install --extra-index-url https://test.pypi.org/simple/ llama-models==$TEST_PYPI_VERSION llama-toolchain==$TEST_PYPI_VERSION $pip_dependencies
else
# Re-installing llama-toolchain in the new conda environment
if [ -n "$LLAMA_TOOLCHAIN_DIR" ]; then
if [ ! -d "$LLAMA_TOOLCHAIN_DIR" ]; then
printf "${RED}Warning: LLAMA_TOOLCHAIN_DIR is set but directory does not exist: $LLAMA_TOOLCHAIN_DIR${NC}" >&2
exit 1
fi
printf "Installing from LLAMA_TOOLCHAIN_DIR: $LLAMA_TOOLCHAIN_DIR"
pip install --no-cache-dir -e "$LLAMA_TOOLCHAIN_DIR"
else
pip install --no-cache-dir llama-toolchain
fi
if [ -n "$LLAMA_MODELS_DIR" ]; then
if [ ! -d "$LLAMA_MODELS_DIR" ]; then
printf "${RED}Warning: LLAMA_MODELS_DIR is set but directory does not exist: $LLAMA_MODELS_DIR${NC}" >&2
exit 1
fi
printf "Installing from LLAMA_MODELS_DIR: $LLAMA_MODELS_DIR"
pip uninstall -y llama-models
pip install --no-cache-dir -e "$LLAMA_MODELS_DIR"
fi
# Install pip dependencies
if [ -n "$pip_dependencies" ]; then
printf "Installing pip dependencies: $pip_dependencies"
pip install $pip_dependencies
fi
fi
}
ensure_conda_env_python310 "$env_name" "$pip_dependencies"
printf "${GREEN}Successfully setup conda environment. Configuring build...${NC}"
if [ "$api_or_stack" = "stack" ]; then
subcommand="stack"
target=""
else
subcommand="api"
target="$api_or_stack"
fi
$CONDA_PREFIX/bin/python3 -m llama_toolchain.cli.llama $subcommand configure $target --build-name "$env_name"

View file

@ -0,0 +1,121 @@
#!/bin/bash
LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-}
LLAMA_TOOLCHAIN_DIR=${LLAMA_TOOLCHAIN_DIR:-}
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
if [ "$#" -ne 4 ]; then
echo "Usage: $0 [api|stack] <image_name> <docker_base> <pip_dependencies>
echo "Example: $0 agentic_system my-fastapi-app python:3.9-slim 'fastapi uvicorn'
exit 1
fi
api_or_stack=$1
image_name=$2
docker_base=$3
pip_dependencies=$4
# Define color codes
RED='\033[0;31m'
GREEN='\033[0;32m'
NC='\033[0m' # No Color
set -euo pipefail
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
REPO_DIR=$(dirname $(dirname "$SCRIPT_DIR"))
TEMP_DIR=$(mktemp -d)
add_to_docker() {
local input
output_file="$TEMP_DIR/Dockerfile"
if [ -t 0 ]; then
printf '%s\n' "$1" >> "$output_file"
else
# If stdin is not a terminal, read from it (heredoc)
cat >> "$output_file"
fi
}
add_to_docker <<EOF
FROM $docker_base
WORKDIR /app
RUN apt-get update && apt-get install -y \
iputils-ping net-tools iproute2 dnsutils telnet \
curl wget telnet \
procps psmisc lsof \
traceroute \
&& rm -rf /var/lib/apt/lists/*
EOF
toolchain_mount="/app/llama-toolchain-source"
models_mount="/app/llama-models-source"
if [ -n "$LLAMA_TOOLCHAIN_DIR" ]; then
if [ ! -d "$LLAMA_TOOLCHAIN_DIR" ]; then
echo "${RED}Warning: LLAMA_TOOLCHAIN_DIR is set but directory does not exist: $LLAMA_TOOLCHAIN_DIR${NC}" >&2
exit 1
fi
add_to_docker "RUN pip install $toolchain_mount"
else
add_to_docker "RUN pip install llama-toolchain"
fi
if [ -n "$LLAMA_MODELS_DIR" ]; then
if [ ! -d "$LLAMA_MODELS_DIR" ]; then
echo "${RED}Warning: LLAMA_MODELS_DIR is set but directory does not exist: $LLAMA_MODELS_DIR${NC}" >&2
exit 1
fi
add_to_docker <<EOF
RUN pip uninstall -y llama-models
RUN pip install $models_mount
EOF
fi
if [ -n "$pip_dependencies" ]; then
add_to_docker "RUN pip install $pip_dependencies"
fi
add_to_docker <<EOF
# This would be good in production but for debugging flexibility lets not add it right now
# We need a more solid production ready entrypoint.sh anyway
#
# ENTRYPOINT ["python", "-m", "llama_toolchain.core.server"]
EOF
printf "Dockerfile created successfully in $TEMP_DIR/Dockerfile"
cat $TEMP_DIR/Dockerfile
printf "\n"
mounts=""
if [ -n "$LLAMA_TOOLCHAIN_DIR" ]; then
mounts="$mounts -v $(readlink -f $LLAMA_TOOLCHAIN_DIR):$toolchain_mount"
fi
if [ -n "$LLAMA_MODELS_DIR" ]; then
mounts="$mounts -v $(readlink -f $LLAMA_MODELS_DIR):$models_mount"
fi
set -x
podman build -t $image_name -f "$TEMP_DIR/Dockerfile" "$REPO_DIR" $mounts
set +x
printf "${GREEN}Succesfully setup Podman image. Configuring build...${NC}"
echo "You can run it with: podman run -p 8000:8000 $image_name"
if [ "$api_or_stack" = "stack" ]; then
subcommand="stack"
target=""
else
subcommand="api"
target="$api_or_stack"
fi
$CONDA_PREFIX/bin/python3 -m llama_toolchain.cli.llama $subcommand configure $target --build-name "$image_name"

View file

@ -0,0 +1,40 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
cleanup() {
envname="$1"
set +x
echo "Cleaning up..."
conda deactivate
conda env remove --name $envname -y
}
handle_int() {
if [ -n $ENVNAME ]; then
cleanup $ENVNAME
fi
exit 1
}
handle_exit() {
if [ $? -ne 0 ]; then
echo -e "\033[1;31mABORTING.\033[0m"
if [ -n $ENVNAME ]; then
cleanup $ENVNAME
fi
fi
}
setup_cleanup_handlers() {
trap handle_int INT
trap handle_exit EXIT
__conda_setup="$('conda' 'shell.bash' 'hook' 2>/dev/null)"
eval "$__conda_setup"
conda deactivate
}

View file

@ -0,0 +1,50 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from llama_toolchain.core.datatypes import * # noqa: F403
from termcolor import cprint
from llama_toolchain.common.prompt_for_config import prompt_for_config
from llama_toolchain.core.distribution import api_providers
from llama_toolchain.core.dynamic import instantiate_class_type
def configure_api_providers(existing_configs: Dict[str, Any]) -> None:
all_providers = api_providers()
provider_configs = {}
for api_str, stub_config in existing_configs.items():
api = Api(api_str)
providers = all_providers[api]
provider_id = stub_config["provider_id"]
if provider_id not in providers:
raise ValueError(
f"Unknown provider `{provider_id}` is not available for API `{api_str}`"
)
provider_spec = providers[provider_id]
cprint(f"Configuring API: {api_str} ({provider_id})", "white", attrs=["bold"])
config_type = instantiate_class_type(provider_spec.config_class)
try:
existing_provider_config = config_type(**stub_config)
except Exception:
existing_provider_config = None
provider_config = prompt_for_config(
config_type,
existing_provider_config,
)
print("")
provider_configs[api_str] = {
"provider_id": provider_id,
**provider_config.dict(),
}
return provider_configs

View file

@ -0,0 +1,190 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field, validator
@json_schema_type
class Api(Enum):
inference = "inference"
safety = "safety"
agentic_system = "agentic_system"
memory = "memory"
@json_schema_type
class ApiEndpoint(BaseModel):
route: str
method: str
name: str
@json_schema_type
class ProviderSpec(BaseModel):
api: Api
provider_id: str
config_class: str = Field(
...,
description="Fully-qualified classname of the config for this provider",
)
api_dependencies: List[Api] = Field(
default_factory=list,
description="Higher-level API surfaces may depend on other providers to provide their functionality",
)
@json_schema_type
class AdapterSpec(BaseModel):
adapter_id: str = Field(
...,
description="Unique identifier for this adapter",
)
module: str = Field(
...,
description="""
Fully-qualified name of the module to import. The module is expected to have:
- `get_adapter_impl(config, deps)`: returns the adapter implementation
""",
)
pip_packages: List[str] = Field(
default_factory=list,
description="The pip dependencies needed for this implementation",
)
config_class: Optional[str] = Field(
default=None,
description="Fully-qualified classname of the config for this provider",
)
@json_schema_type
class InlineProviderSpec(ProviderSpec):
pip_packages: List[str] = Field(
default_factory=list,
description="The pip dependencies needed for this implementation",
)
docker_image: Optional[str] = Field(
default=None,
description="""
The docker image to use for this implementation. If one is provided, pip_packages will be ignored.
If a provider depends on other providers, the dependencies MUST NOT specify a docker image.
""",
)
module: str = Field(
...,
description="""
Fully-qualified name of the module to import. The module is expected to have:
- `get_provider_impl(config, deps)`: returns the local implementation
""",
)
class RemoteProviderConfig(BaseModel):
url: str = Field(..., description="The URL for the provider")
@validator("url")
@classmethod
def validate_url(cls, url: str) -> str:
if not url.startswith("http"):
raise ValueError(f"URL must start with http: {url}")
return url.rstrip("/")
def remote_provider_id(adapter_id: str) -> str:
return f"remote::{adapter_id}"
@json_schema_type
class RemoteProviderSpec(ProviderSpec):
adapter: Optional[AdapterSpec] = Field(
default=None,
description="""
If some code is needed to convert the remote responses into Llama Stack compatible
API responses, specify the adapter here. If not specified, it indicates the remote
as being "Llama Stack compatible"
""",
)
@property
def docker_image(self) -> Optional[str]:
return None
@property
def module(self) -> str:
if self.adapter:
return self.adapter.module
return f"llama_toolchain.{self.api.value}.client"
@property
def pip_packages(self) -> List[str]:
if self.adapter:
return self.adapter.pip_packages
return []
# Can avoid this by using Pydantic computed_field
def remote_provider_spec(
api: Api, adapter: Optional[AdapterSpec] = None
) -> RemoteProviderSpec:
config_class = (
adapter.config_class
if adapter and adapter.config_class
else "llama_toolchain.core.datatypes.RemoteProviderConfig"
)
provider_id = remote_provider_id(adapter.adapter_id) if adapter else "remote"
return RemoteProviderSpec(
api=api, provider_id=provider_id, config_class=config_class, adapter=adapter
)
@json_schema_type
class DistributionSpec(BaseModel):
distribution_id: str
description: str
docker_image: Optional[str] = None
providers: Dict[Api, str] = Field(
default_factory=dict,
description="Provider IDs for each of the APIs provided by this distribution",
)
@json_schema_type
class PackageConfig(BaseModel):
built_at: datetime
package_name: str = Field(
...,
description="""
Reference to the distribution this package refers to. For unregistered (adhoc) packages,
this could be just a hash
""",
)
distribution_id: Optional[str] = None
docker_image: Optional[str] = Field(
default=None,
description="Reference to the docker image if this package refers to a container",
)
conda_env: Optional[str] = Field(
default=None,
description="Reference to the conda environment if this package refers to a conda environment",
)
providers: Dict[str, Any] = Field(
default_factory=dict,
description="""
Provider configurations for each of the APIs provided by this package. This includes configurations for
the dependencies of these providers as well.
""",
)

View file

@ -0,0 +1,101 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import inspect
from typing import Dict, List
from llama_toolchain.agentic_system.api import AgenticSystem
from llama_toolchain.agentic_system.providers import available_agentic_system_providers
from llama_toolchain.inference.api import Inference
from llama_toolchain.inference.providers import available_inference_providers
from llama_toolchain.memory.api import Memory
from llama_toolchain.memory.providers import available_memory_providers
from llama_toolchain.safety.api import Safety
from llama_toolchain.safety.providers import available_safety_providers
from .datatypes import (
Api,
ApiEndpoint,
DistributionSpec,
InlineProviderSpec,
ProviderSpec,
remote_provider_spec,
)
# These are the dependencies needed by the distribution server.
# `llama-toolchain` is automatically installed by the installation script.
SERVER_DEPENDENCIES = [
"fastapi",
"uvicorn",
]
def distribution_dependencies(distribution: DistributionSpec) -> List[str]:
# only consider InlineProviderSpecs when calculating dependencies
return [
dep
for provider_spec in distribution.provider_specs.values()
if isinstance(provider_spec, InlineProviderSpec)
for dep in provider_spec.pip_packages
] + SERVER_DEPENDENCIES
def stack_apis() -> List[Api]:
return [Api.inference, Api.safety, Api.agentic_system, Api.memory]
def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
apis = {}
protocols = {
Api.inference: Inference,
Api.safety: Safety,
Api.agentic_system: AgenticSystem,
Api.memory: Memory,
}
for api, protocol in protocols.items():
endpoints = []
protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
for name, method in protocol_methods:
if not hasattr(method, "__webmethod__"):
continue
webmethod = method.__webmethod__
route = webmethod.route
if webmethod.method == "GET":
method = "get"
elif webmethod.method == "DELETE":
method = "delete"
else:
method = "post"
endpoints.append(ApiEndpoint(route=route, method=method, name=name))
apis[api] = endpoints
return apis
def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]:
inference_providers_by_id = {
a.provider_id: a for a in available_inference_providers()
}
safety_providers_by_id = {a.provider_id: a for a in available_safety_providers()}
agentic_system_providers_by_id = {
a.provider_id: a for a in available_agentic_system_providers()
}
ret = {
Api.inference: inference_providers_by_id,
Api.safety: safety_providers_by_id,
Api.agentic_system: agentic_system_providers_by_id,
Api.memory: {a.provider_id: a for a in available_memory_providers()},
}
for k, v in ret.items():
v["remote"] = remote_provider_spec(k)
return ret

View file

@ -0,0 +1,69 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from functools import lru_cache
from typing import List, Optional
from .datatypes import * # noqa: F403
@lru_cache()
def available_distribution_specs() -> List[DistributionSpec]:
return [
DistributionSpec(
distribution_id="local",
description="Use code from `llama_toolchain` itself to serve all llama stack APIs",
providers={
Api.inference: "meta-reference",
Api.memory: "meta-reference-faiss",
Api.safety: "meta-reference",
Api.agentic_system: "meta-reference",
},
),
DistributionSpec(
distribution_id="remote",
description="Point to remote services for all llama stack APIs",
providers={x: "remote" for x in Api},
),
DistributionSpec(
distribution_id="local-ollama",
description="Like local, but use ollama for running LLM inference",
providers={
Api.inference: remote_provider_id("ollama"),
Api.safety: "meta-reference",
Api.agentic_system: "meta-reference",
Api.memory: "meta-reference-faiss",
},
),
DistributionSpec(
distribution_id="local-plus-fireworks-inference",
description="Use Fireworks.ai for running LLM inference",
providers={
Api.inference: remote_provider_id("fireworks"),
Api.safety: "meta-reference",
Api.agentic_system: "meta-reference",
Api.memory: "meta-reference-faiss",
},
),
DistributionSpec(
distribution_id="local-plus-together-inference",
description="Use Together.ai for running LLM inference",
providers={
Api.inference: remote_provider_id("together"),
Api.safety: "meta-reference",
Api.agentic_system: "meta-reference",
Api.memory: "meta-reference-faiss",
},
),
]
@lru_cache()
def resolve_distribution_spec(distribution_id: str) -> Optional[DistributionSpec]:
for spec in available_distribution_specs():
if spec.distribution_id == distribution_id:
return spec
return None

View file

@ -0,0 +1,42 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import importlib
from typing import Any, Dict
from .datatypes import ProviderSpec, RemoteProviderSpec
def instantiate_class_type(fully_qualified_name):
module_name, class_name = fully_qualified_name.rsplit(".", 1)
module = importlib.import_module(module_name)
return getattr(module, class_name)
# returns a class implementing the protocol corresponding to the Api
def instantiate_provider(
provider_spec: ProviderSpec,
provider_config: Dict[str, Any],
deps: Dict[str, ProviderSpec],
):
module = importlib.import_module(provider_spec.module)
config_type = instantiate_class_type(provider_spec.config_class)
if isinstance(provider_spec, RemoteProviderSpec):
if provider_spec.adapter:
method = "get_adapter_impl"
else:
method = "get_client_impl"
else:
method = "get_provider_impl"
config = config_type(**provider_config)
fn = getattr(module, method)
impl = asyncio.run(fn(config, deps))
impl.__provider_spec__ = provider_spec
impl.__provider_config__ = config
return impl

View file

@ -0,0 +1,181 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import os
from datetime import datetime
from enum import Enum
from typing import Dict, List, Optional
import pkg_resources
import yaml
from pydantic import BaseModel
from termcolor import cprint
from llama_toolchain.common.config_dirs import BUILDS_BASE_DIR
from llama_toolchain.core.datatypes import * # noqa: F403
from llama_toolchain.common.exec import run_with_pty
from llama_toolchain.common.serialize import EnumEncoder
from llama_toolchain.core.distribution import api_providers
class BuildType(Enum):
container = "container"
conda_env = "conda_env"
def descriptor(self) -> str:
return "image" if self == self.container else "env"
class Dependencies(BaseModel):
pip_packages: List[str]
docker_image: Optional[str] = None
def get_dependencies(
provider: ProviderSpec, dependencies: Dict[str, ProviderSpec]
) -> Dependencies:
from llama_toolchain.core.distribution import SERVER_DEPENDENCIES
pip_packages = provider.pip_packages
for dep in dependencies.values():
if dep.docker_image:
raise ValueError(
"You can only have the root provider specify a docker image"
)
pip_packages.extend(dep.pip_packages)
return Dependencies(
docker_image=provider.docker_image,
pip_packages=pip_packages + SERVER_DEPENDENCIES,
)
class ApiInput(BaseModel):
api: Api
provider: str
dependencies: Dict[str, ProviderSpec]
def build_package(
api_inputs: List[ApiInput],
build_type: BuildType,
name: str,
distribution_id: Optional[str] = None,
docker_image: Optional[str] = None,
):
is_stack = len(api_inputs) > 1
if is_stack:
if not distribution_id:
raise ValueError(
"You must specify a distribution name when building the Llama Stack"
)
api1 = api_inputs[0]
provider = distribution_id if is_stack else api1.provider
api_or_stack = "stack" if is_stack else api1.api.value
build_dir = BUILDS_BASE_DIR / api_or_stack
os.makedirs(build_dir, exist_ok=True)
package_name = f"{build_type.descriptor()}-{provider}-{name}"
package_name = package_name.replace("::", "-")
package_file = build_dir / f"{package_name}.yaml"
all_providers = api_providers()
package_deps = Dependencies(
docker_image=docker_image or "python:3.10-slim",
pip_packages=[],
)
stub_config = {}
for api_input in api_inputs:
api = api_input.api
providers_for_api = all_providers[api]
if api_input.provider not in providers_for_api:
raise ValueError(
f"Provider `{api_input.provider}` is not available for API `{api}`"
)
deps = get_dependencies(
providers_for_api[api_input.provider],
api_input.dependencies,
)
if deps.docker_image:
raise ValueError("A stack's dependencies cannot have a docker image")
package_deps.pip_packages.extend(deps.pip_packages)
stub_config[api.value] = {"provider_id": api_input.provider}
if package_file.exists():
cprint(
f"Build `{package_name}` exists; will reconfigure",
color="yellow",
)
c = PackageConfig(**yaml.safe_load(package_file.read_text()))
for api_str, new_config in stub_config.items():
if api_str not in c.providers:
c.providers[api_str] = new_config
else:
existing_config = c.providers[api_str]
if existing_config["provider_id"] != new_config["provider_id"]:
cprint(
f"Provider `{api_str}` has changed from `{existing_config}` to `{new_config}`",
color="yellow",
)
c.providers[api_str] = new_config
else:
c = PackageConfig(
built_at=datetime.now(),
package_name=package_name,
providers=stub_config,
)
c.distribution_id = distribution_id
c.docker_image = package_name if build_type == BuildType.container else None
c.conda_env = package_name if build_type == BuildType.conda_env else None
with open(package_file, "w") as f:
to_write = json.loads(json.dumps(c.dict(), cls=EnumEncoder))
f.write(yaml.dump(to_write, sort_keys=False))
if build_type == BuildType.container:
script = pkg_resources.resource_filename(
"llama_toolchain", "distribution/build_container.sh"
)
args = [
script,
api_or_stack,
package_name,
package_deps.docker_image,
" ".join(package_deps.pip_packages),
]
else:
script = pkg_resources.resource_filename(
"llama_toolchain", "distribution/build_conda_env.sh"
)
args = [
script,
api_or_stack,
package_name,
" ".join(package_deps.pip_packages),
]
return_code = run_with_pty(args)
if return_code != 0:
cprint(
f"Failed to build target {package_name} with return code {return_code}",
color="red",
)
return
cprint(
f"Target `{package_name}` built with configuration at {str(package_file)}",
color="green",
)

View file

@ -0,0 +1,345 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import inspect
import json
import signal
import traceback
from collections.abc import (
AsyncGenerator as AsyncGeneratorABC,
AsyncIterator as AsyncIteratorABC,
)
from contextlib import asynccontextmanager
from ssl import SSLError
from typing import (
Any,
AsyncGenerator,
AsyncIterator,
Dict,
get_type_hints,
List,
Optional,
Set,
)
import fire
import httpx
import yaml
from fastapi import Body, FastAPI, HTTPException, Request, Response
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.routing import APIRoute
from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated
from .datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
from .distribution import api_endpoints, api_providers
from .dynamic import instantiate_provider
def is_async_iterator_type(typ):
if hasattr(typ, "__origin__"):
origin = typ.__origin__
if isinstance(origin, type):
return issubclass(
origin,
(AsyncIterator, AsyncGenerator, AsyncIteratorABC, AsyncGeneratorABC),
)
return False
return isinstance(
typ, (AsyncIterator, AsyncGenerator, AsyncIteratorABC, AsyncGeneratorABC)
)
def create_sse_event(data: Any) -> str:
if isinstance(data, BaseModel):
data = data.json()
else:
data = json.dumps(data)
return f"data: {data}\n\n"
async def global_exception_handler(request: Request, exc: Exception):
traceback.print_exception(exc)
http_exc = translate_exception(exc)
return JSONResponse(
status_code=http_exc.status_code, content={"error": {"detail": http_exc.detail}}
)
def translate_exception(exc: Exception) -> HTTPException:
if isinstance(exc, ValidationError):
return RequestValidationError(exc.raw_errors)
# Add more custom exception translations here
return HTTPException(status_code=500, detail="Internal server error")
async def passthrough(
request: Request,
downstream_url: str,
downstream_headers: Optional[Dict[str, str]] = None,
):
headers = dict(request.headers)
headers.pop("host", None)
headers.update(downstream_headers or {})
content = await request.body()
client = httpx.AsyncClient()
try:
req = client.build_request(
method=request.method,
url=downstream_url,
headers=headers,
content=content,
params=request.query_params,
)
response = await client.send(req, stream=True)
async def stream_response():
async for chunk in response.aiter_raw(chunk_size=64):
yield chunk
await response.aclose()
await client.aclose()
return StreamingResponse(
stream_response(),
status_code=response.status_code,
headers=dict(response.headers),
media_type=response.headers.get("content-type"),
)
except httpx.ReadTimeout:
return Response(content="Downstream server timed out", status_code=504)
except httpx.NetworkError as e:
return Response(content=f"Network error: {str(e)}", status_code=502)
except httpx.TooManyRedirects:
return Response(content="Too many redirects", status_code=502)
except SSLError as e:
return Response(content=f"SSL error: {str(e)}", status_code=502)
except httpx.HTTPStatusError as e:
return Response(content=str(e), status_code=e.response.status_code)
except Exception as e:
return Response(content=f"Unexpected error: {str(e)}", status_code=500)
def handle_sigint(*args, **kwargs):
print("SIGINT or CTRL-C detected. Exiting gracefully...")
loop = asyncio.get_event_loop()
for task in asyncio.all_tasks(loop):
task.cancel()
loop.stop()
@asynccontextmanager
async def lifespan(app: FastAPI):
print("Starting up")
yield
print("Shutting down")
def create_dynamic_passthrough(
downstream_url: str, downstream_headers: Optional[Dict[str, str]] = None
):
async def endpoint(request: Request):
return await passthrough(request, downstream_url, downstream_headers)
return endpoint
def create_dynamic_typed_route(func: Any, method: str):
hints = get_type_hints(func)
response_model = hints["return"]
# NOTE: I think it is better to just add a method within each Api
# "Protocol" / adapter-impl to tell what sort of a response this request
# is going to produce. /chat_completion can produce a streaming or
# non-streaming response depending on if request.stream is True / False.
is_streaming = is_async_iterator_type(response_model)
if is_streaming:
async def endpoint(**kwargs):
async def sse_generator(event_gen):
try:
async for item in event_gen:
yield create_sse_event(item)
await asyncio.sleep(0.01)
except asyncio.CancelledError:
print("Generator cancelled")
await event_gen.aclose()
except Exception as e:
traceback.print_exception(e)
yield create_sse_event(
{
"error": {
"message": str(translate_exception(e)),
},
}
)
return StreamingResponse(
sse_generator(func(**kwargs)), media_type="text/event-stream"
)
else:
async def endpoint(**kwargs):
try:
return (
await func(**kwargs)
if asyncio.iscoroutinefunction(func)
else func(**kwargs)
)
except Exception as e:
traceback.print_exception(e)
raise translate_exception(e) from e
sig = inspect.signature(func)
if method == "post":
# make sure every parameter is annotated with Body() so FASTAPI doesn't
# do anything too intelligent and ask for some parameters in the query
# and some in the body
endpoint.__signature__ = sig.replace(
parameters=[
param.replace(
annotation=Annotated[param.annotation, Body(..., embed=True)]
)
for param in sig.parameters.values()
]
)
else:
endpoint.__signature__ = sig
return endpoint
def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
by_id = {x.api: x for x in providers}
def dfs(a: ProviderSpec, visited: Set[Api], stack: List[Api]):
visited.add(a.api)
for api in a.api_dependencies:
if api not in visited:
dfs(by_id[api], visited, stack)
stack.append(a.api)
visited = set()
stack = []
for a in providers:
if a.api not in visited:
dfs(a, visited, stack)
return [by_id[x] for x in stack]
def resolve_impls(
provider_specs: Dict[str, ProviderSpec], config: Dict[str, Any]
) -> Dict[Api, Any]:
provider_configs = config["providers"]
provider_specs = topological_sort(provider_specs.values())
impls = {}
for provider_spec in provider_specs:
api = provider_spec.api
if api.value not in provider_configs:
raise ValueError(
f"Could not find provider_spec config for {api}. Please add it to the config"
)
if isinstance(provider_spec, InlineProviderSpec):
deps = {api: impls[api] for api in provider_spec.api_dependencies}
else:
deps = {}
provider_config = provider_configs[api.value]
impl = instantiate_provider(provider_spec, provider_config, deps)
impls[api] = impl
return impls
def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
with open(yaml_config, "r") as fp:
config = yaml.safe_load(fp)
app = FastAPI()
all_endpoints = api_endpoints()
all_providers = api_providers()
provider_specs = {}
for api_str, provider_config in config["providers"].items():
api = Api(api_str)
providers = all_providers[api]
provider_id = provider_config["provider_id"]
if provider_id not in providers:
raise ValueError(
f"Unknown provider `{provider_id}` is not available for API `{api}`"
)
provider_specs[api] = providers[provider_id]
impls = resolve_impls(provider_specs, config)
for provider_spec in provider_specs.values():
api = provider_spec.api
endpoints = all_endpoints[api]
impl = impls[api]
if (
isinstance(provider_spec, RemoteProviderSpec)
and provider_spec.adapter is None
):
for endpoint in endpoints:
url = impl.__provider_config__.url
getattr(app, endpoint.method)(endpoint.route)(
create_dynamic_passthrough(url)
)
else:
for endpoint in endpoints:
if not hasattr(impl, endpoint.name):
# ideally this should be a typing violation already
raise ValueError(
f"Could not find method {endpoint.name} on {impl}!!"
)
impl_method = getattr(impl, endpoint.name)
getattr(app, endpoint.method)(endpoint.route, response_model=None)(
create_dynamic_typed_route(impl_method, endpoint.method)
)
for route in app.routes:
if isinstance(route, APIRoute):
cprint(
f"Serving {next(iter(route.methods))} {route.path}",
"white",
attrs=["bold"],
)
app.exception_handler(RequestValidationError)(global_exception_handler)
app.exception_handler(Exception)(global_exception_handler)
signal.signal(signal.SIGINT, handle_sigint)
import uvicorn
# FYI this does not do hot-reloads
listen_host = "::" if not disable_ipv6 else "0.0.0.0"
print(f"Listening on {listen_host}:{port}")
uvicorn.run(app, host=listen_host, port=port)
if __name__ == "__main__":
fire.Fire(main)

View file

@ -0,0 +1,41 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
set -euo pipefail
RED='\033[0;31m'
NC='\033[0m' # No Color
error_handler() {
echo "Error occurred in script at line: ${1}" >&2
exit 1
}
trap 'error_handler ${LINENO}' ERR
if [ $# -lt 3 ]; then
echo "Usage: $0 <env_name> <yaml_config> <port> <script_args...>"
exit 1
fi
env_name="$1"
shift
yaml_config="$1"
shift
port="$1"
shift
eval "$(conda shell.bash hook)"
conda deactivate && conda activate "$env_name"
$CONDA_PREFIX/bin/python \
-m llama_toolchain.core.server \
--yaml_config "$yaml_config" \
--port "$port" "$@"

View file

@ -0,0 +1,42 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
set -euo pipefail
RED='\033[0;31m'
NC='\033[0m' # No Color
error_handler() {
echo "Error occurred in script at line: ${1}" >&2
exit 1
}
trap 'error_handler ${LINENO}' ERR
if [ $# -lt 3 ]; then
echo "Usage: $0 <docker_image> <yaml_config> <port> <other_args...>"
exit 1
fi
docker_image="$1"
shift
yaml_config="$1"
shift
port="$1"
shift
set -x
podman run -it \
-p $port:$port \
-v "$yaml_config:/app/config.yaml" \
$docker_image \
python -m llama_toolchain.core.server \
--yaml_config /app/config.yaml \
--port $port "$@"