API Updates: fleshing out RAG APIs, introduce "llama stack" CLI command (#51)

* add tools to chat completion request

* use templates for generating system prompts

* Moved ToolPromptFormat and Jinja templates to llama_models.llama3.api

* <WIP> memory changes

- inlined AgenticSystemInstanceConfig so the API feels more ergonomic
- renamed it to AgentConfig, AgentInstance -> Agent
- added a MemoryConfig and `memory` parameter
- added `attachments` to input and `output_attachments` to the response (see the sketch after this list)

- some naming changes
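
A hedged sketch of the reshaped agent surface after these renames; only AgentConfig, the `memory` parameter, and the attachment field names come from the notes above, while every other name and default is an illustrative assumption:

# Hypothetical sketch -- field names beyond AgentConfig, memory,
# attachments, and output_attachments are assumed, not taken from this commit.
from typing import List, Optional
from pydantic import BaseModel

class MemoryConfig(BaseModel):
    memory_bank_id: str              # assumed field
    max_chunks: int = 5              # assumed default

class Attachment(BaseModel):
    content: str                     # assumed; could also be a URL
    mime_type: str = "text/plain"

class AgentConfig(BaseModel):        # formerly AgenticSystemInstanceConfig
    model: str
    instructions: str
    memory: Optional[MemoryConfig] = None

class AgentTurnResponse(BaseModel):  # assumed response shape
    completion: str
    output_attachments: List[Attachment] = []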

* InterleavedTextAttachment -> InterleavedTextMedia, introduce memory tool

* flesh out memory banks API

* agentic loop has a RAG implementation

* faiss provider implementation

* memory client works

* re-work tool definitions, fix FastAPI issues, fix tool regressions

* fix agentic_system utils

* basic RAG seems to work

* small bug fixes for inline attachments

* Refactor custom tool execution utilities

* Bug fix, show memory retrieval steps in EventLogger

* No need for api_key for Remote providers

* add special unicode character ↵ to showcase newlines in model prompt templates

* remove api.endpoints imports

* combine datatypes.py and endpoints.py into api.py

* Attachment / add TTL api

* split batch_inference from inference
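
A hedged sketch of what splitting batch inference into its own API surface could look like as separate protocols; all names below are illustrative assumptions, not the commit's actual definitions:

# Hypothetical illustration of separating batch inference from the
# online Inference API. Every name here is assumed for illustration.
from typing import List, Protocol

class ChatCompletionRequest:   # stand-in for the real request type
    ...

class Inference(Protocol):
    async def chat_completion(self, request: ChatCompletionRequest): ...

class BatchInference(Protocol):
    # batch requests return all results at once instead of streaming,
    # so they get their own API surface and provider wiring
    async def batch_chat_completion(
        self, requests: List[ChatCompletionRequest]
    ) -> List[dict]: ...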

* minor import fixes

* use a single impl for ChatFormat.decode_assistant_message

* use the interleaved_text_media_as_str() utility

* Fix api.datatypes imports

* Add blobfile for tiktoken

* Add ToolPromptFormat to ChatFormat.encode_message so that tools are encoded properly

* templates take optional --format={json,function_tag}
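
A hedged usage sketch tying the two bullets above together. ChatFormat and ToolPromptFormat live in llama_models.llama3.api per the earlier bullet, but the tokenizer setup, the exact encode_message signature, and the pre-built `message` are assumptions:

# Hypothetical usage -- the import path follows the bullet above, while the
# constructor and keyword name are assumed; `message` is a Message carrying
# tool definitions, built elsewhere.
from llama_models.llama3.api import ChatFormat, Tokenizer, ToolPromptFormat

chat_format = ChatFormat(Tokenizer.get_instance())   # assumed constructor

# Tools can be rendered as JSON schemas or <function> tags, matching the
# CLI's --format={json,function_tag} option.
tokens = chat_format.encode_message(
    message,
    tool_prompt_format=ToolPromptFormat.json,  # or ToolPromptFormat.function_tag
)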

* RAG updates

* Add `api build` subcommand -- WIP

* fix

* build + run image seems to work

* <WIP> adapters

* bunch more work to make adapters work

* api build works for conda now

* ollama remote adapter works

* Several smaller fixes to make adapters work

Also, reorganized the pattern of __init__ inside providers so
configuration can stay lightweight

* llama distribution -> llama stack + containers (WIP)

* All the new CLI for api + stack work

* Make Fireworks and Together into the Adapter format

* Some quick fixes to the CLI behavior to make it consistent

* Updated README phew

* Update cli_reference.md

* llama_toolchain/distribution -> llama_toolchain/core

* Add termcolor

* update paths

* Add a log just for consistency

* chmod +x scripts

* Fix api dependencies not getting added to configuration

* missing import lol

* Delete utils.py; move to agentic system

* Support downloading of URLs for attachments for code interpreter

* Simplify and generalize `llama api build` yay

* Update `llama stack configure` to be very simple also

* Fix stack start

* Allow building an "adhoc" distribution

* Remove `llama api []` subcommands

* Fixes to llama stack commands and update docs

* Update documentation again and add error messages to llama stack start

* llama stack start -> llama stack run

* Change name of build for less confusion

* Add pyopenapi fork to the repository, update RFC assets

* Remove conflicting annotation

* Added a "--raw" option for model template printing

---------

Co-authored-by: Hardik Shah <hjshah@fb.com>
Co-authored-by: Ashwin Bharambe <ashwin@meta.com>
Co-authored-by: Dalton Flanagan <6599399+dltn@users.noreply.github.com>

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

llama_toolchain/core/build_conda_env.sh (new executable file)
@@ -0,0 +1,128 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-}
LLAMA_TOOLCHAIN_DIR=${LLAMA_TOOLCHAIN_DIR:-}
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
if [ -n "$LLAMA_TOOLCHAIN_DIR" ]; then
echo "Using llama-toolchain-dir=$LLAMA_TOOLCHAIN_DIR"
fi
if [ -n "$LLAMA_MODELS_DIR" ]; then
echo "Using llama-models-dir=$LLAMA_MODELS_DIR"
fi
set -euo pipefail
if [ "$#" -ne 3 ]; then
echo "Usage: $0 <distribution_id> <build_name> <pip_dependencies>" >&2
echo "Example: $0 <distribution_id> mybuild 'numpy pandas scipy'" >&2
exit 1
fi
distribution_id="$1"
build_name="$2"
env_name="llamastack-$build_name"
pip_dependencies="$3"
# Define color codes
RED='\033[0;31m'
GREEN='\033[0;32m'
NC='\033[0m' # No Color
# ENVNAME is set only when we create a new conda environment, so cleanup knows to remove it
ENVNAME=""
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
source "$SCRIPT_DIR/common.sh"
ensure_conda_env_python310() {
local env_name="$1"
local pip_dependencies="$2"
local python_version="3.10"
# Check if conda command is available
if ! command -v conda &>/dev/null; then
printf "${RED}Error: conda command not found. Is Conda installed and in your PATH?${NC}" >&2
exit 1
fi
# Check if the environment exists
if conda env list | grep -q "^${env_name} "; then
printf "Conda environment '${env_name}' exists. Checking Python version...\n"
# Check Python version in the environment
current_version=$(conda run -n "${env_name}" python --version 2>&1 | cut -d' ' -f2 | cut -d'.' -f1,2)
if [ "$current_version" = "$python_version" ]; then
printf "Environment '${env_name}' already has Python ${python_version}. No action needed.\n"
else
printf "Updating environment '${env_name}' to Python ${python_version}...\n"
conda install -n "${env_name}" python="${python_version}" -y
fi
else
printf "Conda environment '${env_name}' does not exist. Creating with Python ${python_version}...\n"
conda create -n "${env_name}" python="${python_version}" -y
ENVNAME="${env_name}"
# setup_cleanup_handlers
fi
eval "$(conda shell.bash hook)"
conda deactivate && conda activate "${env_name}"
if [ -n "$TEST_PYPI_VERSION" ]; then
# these packages are damaged in test-pypi, so install them first
pip install fastapi libcst
pip install --extra-index-url https://test.pypi.org/simple/ llama-models==$TEST_PYPI_VERSION llama-toolchain==$TEST_PYPI_VERSION $pip_dependencies
else
# Re-installing llama-toolchain in the new conda environment
if [ -n "$LLAMA_TOOLCHAIN_DIR" ]; then
if [ ! -d "$LLAMA_TOOLCHAIN_DIR" ]; then
printf "${RED}Warning: LLAMA_TOOLCHAIN_DIR is set but directory does not exist: $LLAMA_TOOLCHAIN_DIR${NC}\n" >&2
exit 1
fi
printf "Installing from LLAMA_TOOLCHAIN_DIR: $LLAMA_TOOLCHAIN_DIR\n"
pip install --no-cache-dir -e "$LLAMA_TOOLCHAIN_DIR"
else
pip install --no-cache-dir llama-toolchain
fi
if [ -n "$LLAMA_MODELS_DIR" ]; then
if [ ! -d "$LLAMA_MODELS_DIR" ]; then
printf "${RED}Warning: LLAMA_MODELS_DIR is set but directory does not exist: $LLAMA_MODELS_DIR${NC}\n" >&2
exit 1
fi
printf "Installing from LLAMA_MODELS_DIR: $LLAMA_MODELS_DIR\n"
pip uninstall -y llama-models
pip install --no-cache-dir -e "$LLAMA_MODELS_DIR"
fi
# Install pip dependencies
if [ -n "$pip_dependencies" ]; then
printf "Installing pip dependencies: $pip_dependencies\n"
pip install $pip_dependencies
fi
fi
}
ensure_conda_env_python310 "$env_name" "$pip_dependencies"
printf "${GREEN}Successfully setup conda environment. Configuring build...${NC}\n"
if [ "$distribution_id" = "adhoc" ]; then
subcommand="api"
target=""
else
subcommand="stack"
target="$distribution_id"
fi
$CONDA_PREFIX/bin/python3 -m llama_toolchain.cli.llama $subcommand configure $target --name "$build_name" --type conda_env

llama_toolchain/core/build_container.sh (new executable file)
@@ -0,0 +1,120 @@
#!/bin/bash
LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-}
LLAMA_TOOLCHAIN_DIR=${LLAMA_TOOLCHAIN_DIR:-}
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
if [ "$#" -ne 4 ]; then
echo "Usage: $0 <distribution_id> <build_name> <docker_base> <pip_dependencies>
echo "Example: $0 distribution_id my-fastapi-app python:3.9-slim 'fastapi uvicorn'
exit 1
fi
distribution_id=$1
build_name="$2"
image_name="llamastack-$build_name"
docker_base=$3
pip_dependencies=$4
# Define color codes
RED='\033[0;31m'
GREEN='\033[0;32m'
NC='\033[0m' # No Color
set -euo pipefail
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
REPO_DIR=$(dirname $(dirname "$SCRIPT_DIR"))
TEMP_DIR=$(mktemp -d)
add_to_docker() {
  local output_file="$TEMP_DIR/Dockerfile"
if [ -t 0 ]; then
printf '%s\n' "$1" >>"$output_file"
else
# If stdin is not a terminal, read from it (heredoc)
cat >>"$output_file"
fi
}
add_to_docker <<EOF
FROM $docker_base
WORKDIR /app
RUN apt-get update && apt-get install -y \
iputils-ping net-tools iproute2 dnsutils telnet \
  curl wget \
procps psmisc lsof \
traceroute \
&& rm -rf /var/lib/apt/lists/*
EOF
toolchain_mount="/app/llama-toolchain-source"
models_mount="/app/llama-models-source"
if [ -n "$LLAMA_TOOLCHAIN_DIR" ]; then
if [ ! -d "$LLAMA_TOOLCHAIN_DIR" ]; then
echo "${RED}Warning: LLAMA_TOOLCHAIN_DIR is set but directory does not exist: $LLAMA_TOOLCHAIN_DIR${NC}" >&2
exit 1
fi
add_to_docker "RUN pip install $toolchain_mount"
else
add_to_docker "RUN pip install llama-toolchain"
fi
if [ -n "$LLAMA_MODELS_DIR" ]; then
if [ ! -d "$LLAMA_MODELS_DIR" ]; then
echo "${RED}Warning: LLAMA_MODELS_DIR is set but directory does not exist: $LLAMA_MODELS_DIR${NC}" >&2
exit 1
fi
add_to_docker <<EOF
RUN pip uninstall -y llama-models
RUN pip install $models_mount
EOF
fi
if [ -n "$pip_dependencies" ]; then
add_to_docker "RUN pip install $pip_dependencies"
fi
add_to_docker <<EOF
# This would be good in production but for debugging flexibility lets not add it right now
# We need a more solid production ready entrypoint.sh anyway
#
# ENTRYPOINT ["python", "-m", "llama_toolchain.core.server"]
EOF
printf "Dockerfile created successfully in $TEMP_DIR/Dockerfile"
cat $TEMP_DIR/Dockerfile
printf "\n"
mounts=""
if [ -n "$LLAMA_TOOLCHAIN_DIR" ]; then
mounts="$mounts -v $(readlink -f $LLAMA_TOOLCHAIN_DIR):$toolchain_mount"
fi
if [ -n "$LLAMA_MODELS_DIR" ]; then
mounts="$mounts -v $(readlink -f $LLAMA_MODELS_DIR):$models_mount"
fi
set -x
podman build -t $image_name -f "$TEMP_DIR/Dockerfile" "$REPO_DIR" $mounts
set +x
printf "${GREEN}Succesfully setup Podman image. Configuring build...${NC}"
echo "You can run it with: podman run -p 8000:8000 $image_name"
if [ "$distribution_id" = "adhoc" ]; then
subcommand="api"
target=""
else
subcommand="stack"
target="$distribution_id"
fi
$CONDA_PREFIX/bin/python3 -m llama_toolchain.cli.llama $subcommand configure $target --name "$build_name" --type container

llama_toolchain/core/common.sh (new executable file, 40 lines)
@@ -0,0 +1,40 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
cleanup() {
envname="$1"
set +x
echo "Cleaning up..."
conda deactivate
  conda env remove --name "$envname" -y
}
handle_int() {
  if [ -n "$ENVNAME" ]; then
cleanup $ENVNAME
fi
exit 1
}
handle_exit() {
if [ $? -ne 0 ]; then
echo -e "\033[1;31mABORTING.\033[0m"
    if [ -n "$ENVNAME" ]; then
cleanup $ENVNAME
fi
fi
}
setup_cleanup_handlers() {
trap handle_int INT
trap handle_exit EXIT
__conda_setup="$('conda' 'shell.bash' 'hook' 2>/dev/null)"
eval "$__conda_setup"
conda deactivate
}

@@ -0,0 +1,50 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from llama_toolchain.core.datatypes import * # noqa: F403
from termcolor import cprint
from llama_toolchain.common.prompt_for_config import prompt_for_config
from llama_toolchain.core.distribution import api_providers
from llama_toolchain.core.dynamic import instantiate_class_type
def configure_api_providers(existing_configs: Dict[str, Any]) -> Dict[str, Any]:
all_providers = api_providers()
provider_configs = {}
for api_str, stub_config in existing_configs.items():
api = Api(api_str)
providers = all_providers[api]
provider_id = stub_config["provider_id"]
if provider_id not in providers:
raise ValueError(
f"Unknown provider `{provider_id}` is not available for API `{api_str}`"
)
provider_spec = providers[provider_id]
cprint(f"Configuring API: {api_str} ({provider_id})", "white", attrs=["bold"])
config_type = instantiate_class_type(provider_spec.config_class)
try:
existing_provider_config = config_type(**stub_config)
except Exception:
existing_provider_config = None
provider_config = prompt_for_config(
config_type,
existing_provider_config,
)
print("")
provider_configs[api_str] = {
"provider_id": provider_id,
**provider_config.dict(),
}
return provider_configs
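
A hedged sketch of driving the function above; the stub shape ({api: {"provider_id": ...}}) mirrors what the package builder writes into the build YAML, and the call prompts interactively for each provider's settings:

# Usage sketch for configure_api_providers(). Running this prompts
# interactively (via prompt_for_config) for each provider's config fields.
existing = {
    "inference": {"provider_id": "meta-reference"},
    "memory": {"provider_id": "meta-reference-faiss"},
}
provider_configs = configure_api_providers(existing)
# -> {"inference": {"provider_id": "meta-reference", ...}, "memory": {...}}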

llama_toolchain/core/datatypes.py (new file)
@@ -0,0 +1,190 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field, validator
@json_schema_type
class Api(Enum):
inference = "inference"
safety = "safety"
agentic_system = "agentic_system"
memory = "memory"
@json_schema_type
class ApiEndpoint(BaseModel):
route: str
method: str
name: str
@json_schema_type
class ProviderSpec(BaseModel):
api: Api
provider_id: str
config_class: str = Field(
...,
description="Fully-qualified classname of the config for this provider",
)
api_dependencies: List[Api] = Field(
default_factory=list,
description="Higher-level API surfaces may depend on other providers to provide their functionality",
)
@json_schema_type
class AdapterSpec(BaseModel):
adapter_id: str = Field(
...,
description="Unique identifier for this adapter",
)
module: str = Field(
...,
description="""
Fully-qualified name of the module to import. The module is expected to have:
- `get_adapter_impl(config, deps)`: returns the adapter implementation
""",
)
pip_packages: List[str] = Field(
default_factory=list,
description="The pip dependencies needed for this implementation",
)
config_class: Optional[str] = Field(
default=None,
description="Fully-qualified classname of the config for this provider",
)
@json_schema_type
class InlineProviderSpec(ProviderSpec):
pip_packages: List[str] = Field(
default_factory=list,
description="The pip dependencies needed for this implementation",
)
docker_image: Optional[str] = Field(
default=None,
description="""
The docker image to use for this implementation. If one is provided, pip_packages will be ignored.
If a provider depends on other providers, the dependencies MUST NOT specify a docker image.
""",
)
module: str = Field(
...,
description="""
Fully-qualified name of the module to import. The module is expected to have:
- `get_provider_impl(config, deps)`: returns the local implementation
""",
)
class RemoteProviderConfig(BaseModel):
url: str = Field(..., description="The URL for the provider")
@validator("url")
@classmethod
def validate_url(cls, url: str) -> str:
if not url.startswith("http"):
raise ValueError(f"URL must start with http: {url}")
return url.rstrip("/")
def remote_provider_id(adapter_id: str) -> str:
return f"remote::{adapter_id}"
@json_schema_type
class RemoteProviderSpec(ProviderSpec):
adapter: Optional[AdapterSpec] = Field(
default=None,
description="""
If some code is needed to convert the remote responses into Llama Stack compatible
API responses, specify the adapter here. If not specified, it indicates the remote
as being "Llama Stack compatible"
""",
)
@property
def docker_image(self) -> Optional[str]:
return None
@property
def module(self) -> str:
if self.adapter:
return self.adapter.module
return f"llama_toolchain.{self.api.value}.client"
@property
def pip_packages(self) -> List[str]:
if self.adapter:
return self.adapter.pip_packages
return []
# Can avoid this by using Pydantic computed_field
def remote_provider_spec(
api: Api, adapter: Optional[AdapterSpec] = None
) -> RemoteProviderSpec:
config_class = (
adapter.config_class
if adapter and adapter.config_class
else "llama_toolchain.core.datatypes.RemoteProviderConfig"
)
provider_id = remote_provider_id(adapter.adapter_id) if adapter else "remote"
return RemoteProviderSpec(
api=api, provider_id=provider_id, config_class=config_class, adapter=adapter
)
@json_schema_type
class DistributionSpec(BaseModel):
distribution_id: str
description: str
docker_image: Optional[str] = None
providers: Dict[Api, str] = Field(
default_factory=dict,
description="Provider IDs for each of the APIs provided by this distribution",
)
@json_schema_type
class PackageConfig(BaseModel):
built_at: datetime
package_name: str = Field(
...,
description="""
Reference to the distribution this package refers to. For unregistered (adhoc) packages,
this could be just a hash
""",
)
distribution_id: Optional[str] = None
docker_image: Optional[str] = Field(
default=None,
description="Reference to the docker image if this package refers to a container",
)
conda_env: Optional[str] = Field(
default=None,
description="Reference to the conda environment if this package refers to a conda environment",
)
providers: Dict[str, Any] = Field(
default_factory=dict,
description="""
Provider configurations for each of the APIs provided by this package. This includes configurations for
the dependencies of these providers as well.
""",
)
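
A usage sketch of the provider datatypes above, grounded in this file's own definitions; the ollama adapter module path and the URL are illustrative assumptions:

# Usage sketch for the datatypes defined above.
from llama_toolchain.core.datatypes import (
    AdapterSpec, Api, RemoteProviderConfig, remote_provider_spec,
)

# A plain remote provider: the remote is assumed Llama Stack compatible,
# so no adapter is needed and provider_id is just "remote".
spec = remote_provider_spec(Api.inference)
assert spec.provider_id == "remote"
assert spec.module == "llama_toolchain.inference.client"

# An adapter-backed remote provider: provider_id becomes "remote::ollama"
# and module/pip_packages come from the adapter (module path assumed here).
ollama = remote_provider_spec(
    Api.inference,
    AdapterSpec(adapter_id="ollama", module="llama_toolchain.inference.adapters.ollama"),
)
assert ollama.provider_id == "remote::ollama"

# RemoteProviderConfig validates the scheme and strips trailing slashes.
config = RemoteProviderConfig(url="http://localhost:11434/")
assert config.url == "http://localhost:11434"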

llama_toolchain/core/distribution.py (new file)
@@ -0,0 +1,101 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import inspect
from typing import Dict, List
from llama_toolchain.agentic_system.api import AgenticSystem
from llama_toolchain.agentic_system.providers import available_agentic_system_providers
from llama_toolchain.inference.api import Inference
from llama_toolchain.inference.providers import available_inference_providers
from llama_toolchain.memory.api import Memory
from llama_toolchain.memory.providers import available_memory_providers
from llama_toolchain.safety.api import Safety
from llama_toolchain.safety.providers import available_safety_providers
from .datatypes import (
Api,
ApiEndpoint,
DistributionSpec,
InlineProviderSpec,
ProviderSpec,
remote_provider_spec,
)
# These are the dependencies needed by the distribution server.
# `llama-toolchain` is automatically installed by the installation script.
SERVER_DEPENDENCIES = [
"fastapi",
"uvicorn",
]
def distribution_dependencies(distribution: DistributionSpec) -> List[str]:
    # resolve each provider id to its spec; only InlineProviderSpecs
    # contribute pip dependencies
    all_providers = api_providers()
    deps = []
    for api, provider_id in distribution.providers.items():
        provider_spec = all_providers[api][provider_id]
        if isinstance(provider_spec, InlineProviderSpec):
            deps.extend(provider_spec.pip_packages)
    return deps + SERVER_DEPENDENCIES
def stack_apis() -> List[Api]:
return [Api.inference, Api.safety, Api.agentic_system, Api.memory]
def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
apis = {}
protocols = {
Api.inference: Inference,
Api.safety: Safety,
Api.agentic_system: AgenticSystem,
Api.memory: Memory,
}
for api, protocol in protocols.items():
endpoints = []
protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
        for name, method in protocol_methods:
            if not hasattr(method, "__webmethod__"):
                continue
            webmethod = method.__webmethod__
            route = webmethod.route
            if webmethod.method == "GET":
                http_method = "get"
            elif webmethod.method == "DELETE":
                http_method = "delete"
            else:
                http_method = "post"
            endpoints.append(ApiEndpoint(route=route, method=http_method, name=name))
apis[api] = endpoints
return apis
def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]:
inference_providers_by_id = {
a.provider_id: a for a in available_inference_providers()
}
safety_providers_by_id = {a.provider_id: a for a in available_safety_providers()}
agentic_system_providers_by_id = {
a.provider_id: a for a in available_agentic_system_providers()
}
ret = {
Api.inference: inference_providers_by_id,
Api.safety: safety_providers_by_id,
Api.agentic_system: agentic_system_providers_by_id,
Api.memory: {a.provider_id: a for a in available_memory_providers()},
}
for k, v in ret.items():
v["remote"] = remote_provider_spec(k)
return ret
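
A quick sketch of inspecting the registries above, assuming llama_toolchain and the provider modules it imports are installed:

# Sketch: enumerate known providers and endpoints.
from llama_toolchain.core.distribution import api_endpoints, api_providers
from llama_toolchain.core.datatypes import Api

for api, providers in api_providers().items():
    # every API gets an auto-registered "remote" passthrough provider
    print(api.value, "->", sorted(providers))

for endpoint in api_endpoints()[Api.inference]:
    print(endpoint.method.upper(), endpoint.route, endpoint.name)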

@@ -0,0 +1,69 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from functools import lru_cache
from typing import List, Optional
from .datatypes import * # noqa: F403
@lru_cache()
def available_distribution_specs() -> List[DistributionSpec]:
return [
DistributionSpec(
distribution_id="local",
description="Use code from `llama_toolchain` itself to serve all llama stack APIs",
providers={
Api.inference: "meta-reference",
Api.memory: "meta-reference-faiss",
Api.safety: "meta-reference",
Api.agentic_system: "meta-reference",
},
),
DistributionSpec(
distribution_id="remote",
description="Point to remote services for all llama stack APIs",
providers={x: "remote" for x in Api},
),
DistributionSpec(
distribution_id="local-ollama",
description="Like local, but use ollama for running LLM inference",
providers={
Api.inference: remote_provider_id("ollama"),
Api.safety: "meta-reference",
Api.agentic_system: "meta-reference",
Api.memory: "meta-reference-faiss",
},
),
DistributionSpec(
distribution_id="local-plus-fireworks-inference",
description="Use Fireworks.ai for running LLM inference",
providers={
Api.inference: remote_provider_id("fireworks"),
Api.safety: "meta-reference",
Api.agentic_system: "meta-reference",
Api.memory: "meta-reference-faiss",
},
),
DistributionSpec(
distribution_id="local-plus-together-inference",
description="Use Together.ai for running LLM inference",
providers={
Api.inference: remote_provider_id("together"),
Api.safety: "meta-reference",
Api.agentic_system: "meta-reference",
Api.memory: "meta-reference-faiss",
},
),
]
@lru_cache()
def resolve_distribution_spec(distribution_id: str) -> Optional[DistributionSpec]:
for spec in available_distribution_specs():
if spec.distribution_id == distribution_id:
return spec
return None
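
Resolving one of the built-in specs, as a usage sketch run in this module's namespace:

# Usage sketch: look up a built-in distribution spec by id.
spec = resolve_distribution_spec("local-ollama")
assert spec is not None
print(spec.description)               # "Like local, but use ollama for running LLM inference"
print(spec.providers[Api.inference])  # "remote::ollama"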

llama_toolchain/core/dynamic.py (new file)
@@ -0,0 +1,42 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import importlib
from typing import Any, Dict
from .datatypes import ProviderSpec, RemoteProviderSpec
def instantiate_class_type(fully_qualified_name):
module_name, class_name = fully_qualified_name.rsplit(".", 1)
module = importlib.import_module(module_name)
return getattr(module, class_name)
# returns a class implementing the protocol corresponding to the Api
def instantiate_provider(
provider_spec: ProviderSpec,
provider_config: Dict[str, Any],
deps: Dict[str, ProviderSpec],
):
module = importlib.import_module(provider_spec.module)
config_type = instantiate_class_type(provider_spec.config_class)
if isinstance(provider_spec, RemoteProviderSpec):
if provider_spec.adapter:
method = "get_adapter_impl"
else:
method = "get_client_impl"
else:
method = "get_provider_impl"
config = config_type(**provider_config)
fn = getattr(module, method)
impl = asyncio.run(fn(config, deps))
impl.__provider_spec__ = provider_spec
impl.__provider_config__ = config
return impl
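
instantiate_class_type is plain importlib plumbing; a small demo using a class defined elsewhere in this commit:

# Usage sketch: resolve a dotted path to a class object and instantiate it.
config_cls = instantiate_class_type(
    "llama_toolchain.core.datatypes.RemoteProviderConfig"
)
config = config_cls(url="http://localhost:5000")  # validation runs in the class itself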

@@ -0,0 +1,149 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import os
from datetime import datetime
from enum import Enum
from typing import List, Optional
import pkg_resources
import yaml
from pydantic import BaseModel
from termcolor import cprint
from llama_toolchain.common.config_dirs import BUILDS_BASE_DIR
from llama_toolchain.common.exec import run_with_pty
from llama_toolchain.common.serialize import EnumEncoder
from llama_toolchain.core.datatypes import * # noqa: F403
from llama_toolchain.core.distribution import api_providers, SERVER_DEPENDENCIES
class BuildType(Enum):
container = "container"
conda_env = "conda_env"
def descriptor(self) -> str:
return "docker" if self == self.container else "conda"
class Dependencies(BaseModel):
pip_packages: List[str]
docker_image: Optional[str] = None
class ApiInput(BaseModel):
api: Api
provider: str
def build_package(
api_inputs: List[ApiInput],
build_type: BuildType,
name: str,
distribution_id: Optional[str] = None,
docker_image: Optional[str] = None,
):
if not distribution_id:
distribution_id = "adhoc"
build_dir = BUILDS_BASE_DIR / distribution_id / build_type.descriptor()
os.makedirs(build_dir, exist_ok=True)
package_name = name.replace("::", "-")
package_file = build_dir / f"{package_name}.yaml"
all_providers = api_providers()
package_deps = Dependencies(
docker_image=docker_image or "python:3.10-slim",
pip_packages=SERVER_DEPENDENCIES,
)
stub_config = {}
for api_input in api_inputs:
api = api_input.api
providers_for_api = all_providers[api]
if api_input.provider not in providers_for_api:
raise ValueError(
f"Provider `{api_input.provider}` is not available for API `{api}`"
)
provider = providers_for_api[api_input.provider]
package_deps.pip_packages.extend(provider.pip_packages)
if provider.docker_image:
raise ValueError("A stack's dependencies cannot have a docker image")
stub_config[api.value] = {"provider_id": api_input.provider}
if package_file.exists():
cprint(
f"Build `{package_name}` exists; will reconfigure",
color="yellow",
)
c = PackageConfig(**yaml.safe_load(package_file.read_text()))
for api_str, new_config in stub_config.items():
if api_str not in c.providers:
c.providers[api_str] = new_config
else:
existing_config = c.providers[api_str]
if existing_config["provider_id"] != new_config["provider_id"]:
cprint(
f"Provider `{api_str}` has changed from `{existing_config}` to `{new_config}`",
color="yellow",
)
c.providers[api_str] = new_config
else:
c = PackageConfig(
built_at=datetime.now(),
package_name=package_name,
providers=stub_config,
)
c.distribution_id = distribution_id
c.docker_image = package_name if build_type == BuildType.container else None
c.conda_env = package_name if build_type == BuildType.conda_env else None
with open(package_file, "w") as f:
to_write = json.loads(json.dumps(c.dict(), cls=EnumEncoder))
f.write(yaml.dump(to_write, sort_keys=False))
if build_type == BuildType.container:
script = pkg_resources.resource_filename(
"llama_toolchain", "core/build_container.sh"
)
args = [
script,
distribution_id,
package_name,
package_deps.docker_image,
" ".join(package_deps.pip_packages),
]
else:
script = pkg_resources.resource_filename(
"llama_toolchain", "core/build_conda_env.sh"
)
args = [
script,
distribution_id,
package_name,
" ".join(package_deps.pip_packages),
]
return_code = run_with_pty(args)
if return_code != 0:
cprint(
f"Failed to build target {package_name} with return code {return_code}",
color="red",
)
return
cprint(
f"Target `{package_name}` built with configuration at {str(package_file)}",
color="green",
)
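
A hedged sketch of calling build_package directly; the `llama api build` / `llama stack build` CLI normally assembles these inputs, and the import path for this module is an assumption:

# Sketch: build a conda-env package for an adhoc two-API stack. The module
# path below is assumed; the CLI normally drives this function.
from llama_toolchain.core.package import ApiInput, BuildType, build_package
from llama_toolchain.core.datatypes import Api

build_package(
    api_inputs=[
        ApiInput(api=Api.inference, provider="meta-reference"),
        ApiInput(api=Api.safety, provider="meta-reference"),
    ],
    build_type=BuildType.conda_env,
    name="my-local-stack",
)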

llama_toolchain/core/server.py (new file)
@@ -0,0 +1,345 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import inspect
import json
import signal
import traceback
from collections.abc import (
AsyncGenerator as AsyncGeneratorABC,
AsyncIterator as AsyncIteratorABC,
)
from contextlib import asynccontextmanager
from ssl import SSLError
from typing import (
Any,
AsyncGenerator,
AsyncIterator,
Dict,
get_type_hints,
List,
Optional,
Set,
)
import fire
import httpx
import yaml
from fastapi import Body, FastAPI, HTTPException, Request, Response
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.routing import APIRoute
from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated
from .datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
from .distribution import api_endpoints, api_providers
from .dynamic import instantiate_provider
def is_async_iterator_type(typ):
if hasattr(typ, "__origin__"):
origin = typ.__origin__
if isinstance(origin, type):
return issubclass(
origin,
(AsyncIterator, AsyncGenerator, AsyncIteratorABC, AsyncGeneratorABC),
)
return False
return isinstance(
typ, (AsyncIterator, AsyncGenerator, AsyncIteratorABC, AsyncGeneratorABC)
)
def create_sse_event(data: Any) -> str:
if isinstance(data, BaseModel):
data = data.json()
else:
data = json.dumps(data)
return f"data: {data}\n\n"
async def global_exception_handler(request: Request, exc: Exception):
traceback.print_exception(exc)
http_exc = translate_exception(exc)
return JSONResponse(
status_code=http_exc.status_code, content={"error": {"detail": http_exc.detail}}
)
def translate_exception(exc: Exception) -> HTTPException:
if isinstance(exc, ValidationError):
return RequestValidationError(exc.raw_errors)
# Add more custom exception translations here
return HTTPException(status_code=500, detail="Internal server error")
async def passthrough(
request: Request,
downstream_url: str,
downstream_headers: Optional[Dict[str, str]] = None,
):
headers = dict(request.headers)
headers.pop("host", None)
headers.update(downstream_headers or {})
content = await request.body()
client = httpx.AsyncClient()
try:
req = client.build_request(
method=request.method,
url=downstream_url,
headers=headers,
content=content,
params=request.query_params,
)
response = await client.send(req, stream=True)
async def stream_response():
async for chunk in response.aiter_raw(chunk_size=64):
yield chunk
await response.aclose()
await client.aclose()
return StreamingResponse(
stream_response(),
status_code=response.status_code,
headers=dict(response.headers),
media_type=response.headers.get("content-type"),
)
except httpx.ReadTimeout:
return Response(content="Downstream server timed out", status_code=504)
except httpx.NetworkError as e:
return Response(content=f"Network error: {str(e)}", status_code=502)
except httpx.TooManyRedirects:
return Response(content="Too many redirects", status_code=502)
except SSLError as e:
return Response(content=f"SSL error: {str(e)}", status_code=502)
except httpx.HTTPStatusError as e:
return Response(content=str(e), status_code=e.response.status_code)
except Exception as e:
return Response(content=f"Unexpected error: {str(e)}", status_code=500)
def handle_sigint(*args, **kwargs):
print("SIGINT or CTRL-C detected. Exiting gracefully...")
loop = asyncio.get_event_loop()
for task in asyncio.all_tasks(loop):
task.cancel()
loop.stop()
@asynccontextmanager
async def lifespan(app: FastAPI):
print("Starting up")
yield
print("Shutting down")
def create_dynamic_passthrough(
downstream_url: str, downstream_headers: Optional[Dict[str, str]] = None
):
async def endpoint(request: Request):
return await passthrough(request, downstream_url, downstream_headers)
return endpoint
def create_dynamic_typed_route(func: Any, method: str):
hints = get_type_hints(func)
response_model = hints["return"]
# NOTE: I think it is better to just add a method within each Api
# "Protocol" / adapter-impl to tell what sort of a response this request
# is going to produce. /chat_completion can produce a streaming or
# non-streaming response depending on if request.stream is True / False.
is_streaming = is_async_iterator_type(response_model)
if is_streaming:
async def endpoint(**kwargs):
async def sse_generator(event_gen):
try:
async for item in event_gen:
yield create_sse_event(item)
await asyncio.sleep(0.01)
except asyncio.CancelledError:
print("Generator cancelled")
await event_gen.aclose()
except Exception as e:
traceback.print_exception(e)
yield create_sse_event(
{
"error": {
"message": str(translate_exception(e)),
},
}
)
return StreamingResponse(
sse_generator(func(**kwargs)), media_type="text/event-stream"
)
else:
async def endpoint(**kwargs):
try:
return (
await func(**kwargs)
if asyncio.iscoroutinefunction(func)
else func(**kwargs)
)
except Exception as e:
traceback.print_exception(e)
raise translate_exception(e) from e
sig = inspect.signature(func)
if method == "post":
        # make sure every parameter is annotated with Body() so FastAPI doesn't
# do anything too intelligent and ask for some parameters in the query
# and some in the body
endpoint.__signature__ = sig.replace(
parameters=[
param.replace(
annotation=Annotated[param.annotation, Body(..., embed=True)]
)
for param in sig.parameters.values()
]
)
else:
endpoint.__signature__ = sig
return endpoint
def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
by_id = {x.api: x for x in providers}
def dfs(a: ProviderSpec, visited: Set[Api], stack: List[Api]):
visited.add(a.api)
for api in a.api_dependencies:
if api not in visited:
dfs(by_id[api], visited, stack)
stack.append(a.api)
visited = set()
stack = []
for a in providers:
if a.api not in visited:
dfs(a, visited, stack)
return [by_id[x] for x in stack]
def resolve_impls(
provider_specs: Dict[str, ProviderSpec], config: Dict[str, Any]
) -> Dict[Api, Any]:
provider_configs = config["providers"]
provider_specs = topological_sort(provider_specs.values())
impls = {}
for provider_spec in provider_specs:
api = provider_spec.api
if api.value not in provider_configs:
raise ValueError(
f"Could not find provider_spec config for {api}. Please add it to the config"
)
if isinstance(provider_spec, InlineProviderSpec):
deps = {api: impls[api] for api in provider_spec.api_dependencies}
else:
deps = {}
provider_config = provider_configs[api.value]
impl = instantiate_provider(provider_spec, provider_config, deps)
impls[api] = impl
return impls
def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
with open(yaml_config, "r") as fp:
config = yaml.safe_load(fp)
app = FastAPI()
all_endpoints = api_endpoints()
all_providers = api_providers()
provider_specs = {}
for api_str, provider_config in config["providers"].items():
api = Api(api_str)
providers = all_providers[api]
provider_id = provider_config["provider_id"]
if provider_id not in providers:
raise ValueError(
f"Unknown provider `{provider_id}` is not available for API `{api}`"
)
provider_specs[api] = providers[provider_id]
impls = resolve_impls(provider_specs, config)
for provider_spec in provider_specs.values():
api = provider_spec.api
endpoints = all_endpoints[api]
impl = impls[api]
if (
isinstance(provider_spec, RemoteProviderSpec)
and provider_spec.adapter is None
):
for endpoint in endpoints:
url = impl.__provider_config__.url.rstrip("/") + endpoint.route
getattr(app, endpoint.method)(endpoint.route)(
create_dynamic_passthrough(url)
)
else:
for endpoint in endpoints:
if not hasattr(impl, endpoint.name):
# ideally this should be a typing violation already
raise ValueError(
f"Could not find method {endpoint.name} on {impl}!!"
)
impl_method = getattr(impl, endpoint.name)
getattr(app, endpoint.method)(endpoint.route, response_model=None)(
create_dynamic_typed_route(impl_method, endpoint.method)
)
for route in app.routes:
if isinstance(route, APIRoute):
cprint(
f"Serving {next(iter(route.methods))} {route.path}",
"white",
attrs=["bold"],
)
app.exception_handler(RequestValidationError)(global_exception_handler)
app.exception_handler(Exception)(global_exception_handler)
signal.signal(signal.SIGINT, handle_sigint)
import uvicorn
# FYI this does not do hot-reloads
listen_host = "::" if not disable_ipv6 else "0.0.0.0"
print(f"Listening on {listen_host}:{port}")
uvicorn.run(app, host=listen_host, port=port)
if __name__ == "__main__":
fire.Fire(main)
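
The SSE and type-inspection helpers at the top of this server are self-contained; a quick behavioral sketch:

# Sketch: create_sse_event frames a payload for text/event-stream, and
# is_async_iterator_type drives the streaming-vs-plain endpoint choice.
from typing import AsyncIterator

assert create_sse_event({"delta": "hi"}) == 'data: {"delta": "hi"}\n\n'
assert is_async_iterator_type(AsyncIterator[str])
assert not is_async_iterator_type(dict)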

@@ -0,0 +1,42 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
set -euo pipefail
RED='\033[0;31m'
NC='\033[0m' # No Color
error_handler() {
echo "Error occurred in script at line: ${1}" >&2
exit 1
}
trap 'error_handler ${LINENO}' ERR
if [ $# -lt 3 ]; then
echo "Usage: $0 <build_name> <yaml_config> <port> <script_args...>"
exit 1
fi
build_name="$1"
env_name="llamastack-$build_name"
shift
yaml_config="$1"
shift
port="$1"
shift
eval "$(conda shell.bash hook)"
conda deactivate && conda activate "$env_name"
$CONDA_PREFIX/bin/python \
-m llama_toolchain.core.server \
--yaml_config "$yaml_config" \
--port "$port" "$@"

@@ -0,0 +1,43 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
set -euo pipefail
RED='\033[0;31m'
NC='\033[0m' # No Color
error_handler() {
echo "Error occurred in script at line: ${1}" >&2
exit 1
}
trap 'error_handler ${LINENO}' ERR
if [ $# -lt 3 ]; then
echo "Usage: $0 <build_name> <yaml_config> <port> <other_args...>"
exit 1
fi
build_name="$1"
docker_image="llamastack-$build_name"
shift
yaml_config="$1"
shift
port="$1"
shift
set -x
podman run -it \
-p $port:$port \
-v "$yaml_config:/app/config.yaml" \
$docker_image \
python -m llama_toolchain.core.server \
--yaml_config /app/config.yaml \
--port $port "$@"