Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-21 03:59:42 +00:00)
API Updates: fleshing out RAG APIs, introduce "llama stack" CLI command (#51)
* add tools to chat completion request
* use templates for generating system prompts
* Moved ToolPromptFormat and jinja templates to llama_models.llama3.api
* <WIP> memory changes
  - inlined AgenticSystemInstanceConfig so API feels more ergonomic
  - renamed it to AgentConfig, AgentInstance -> Agent
  - added a MemoryConfig and `memory` parameter
  - added `attachments` to input and `output_attachments` to the response
  - some naming changes
* InterleavedTextAttachment -> InterleavedTextMedia, introduce memory tool
* flesh out memory banks API
* agentic loop has a RAG implementation
* faiss provider implementation
* memory client works
* re-work tool definitions, fix FastAPI issues, fix tool regressions
* fix agentic_system utils
* basic RAG seems to work
* small bug fixes for inline attachments
* Refactor custom tool execution utilities
* Bug fix, show memory retrieval steps in EventLogger
* No need for api_key for Remote providers
* add special unicode character ↵ to showcase newlines in model prompt templates
* remove api.endpoints imports
* combine datatypes.py and endpoints.py into api.py
* Attachment / add TTL api
* split batch_inference from inference
* minor import fixes
* use a single impl for ChatFormat.decode_assistant_mesage
* use interleaved_text_media_as_str() utility
* Fix api.datatypes imports
* Add blobfile for tiktoken
* Add ToolPromptFormat to ChatFormat.encode_message so that tools are encoded properly
* templates take optional --format={json,function_tag}
* RAG updates
* Add `api build` subcommand -- WIP
* fix
* build + run image seems to work
* <WIP> adapters
* bunch more work to make adapters work
* api build works for conda now
* ollama remote adapter works
* Several smaller fixes to make adapters work. Also, reorganized the pattern of __init__ inside providers so configuration can stay lightweight
* llama distribution -> llama stack + containers (WIP)
* All the new CLI for api + stack work
* Make Fireworks and Together into the Adapter format
* Some quick fixes to the CLI behavior to make it consistent
* Updated README phew
* Update cli_reference.md
* llama_toolchain/distribution -> llama_toolchain/core
* Add termcolor
* update paths
* Add a log just for consistency
* chmod +x scripts
* Fix api dependencies not getting added to configuration
* missing import lol
* Delete utils.py; move to agentic system
* Support downloading of URLs for attachments for code interpreter
* Simplify and generalize `llama api build` yay
* Update `llama stack configure` to be very simple also
* Fix stack start
* Allow building an "adhoc" distribution
* Remove `llama api []` subcommands
* Fixes to llama stack commands and update docs
* Update documentation again and add error messages to llama stack start
* llama stack start -> llama stack run
* Change name of build for less confusion
* Add pyopenapi fork to the repository, update RFC assets
* Remove conflicting annotation
* Added a "--raw" option for model template printing

---------

Co-authored-by: Hardik Shah <hjshah@fb.com>
Co-authored-by: Ashwin Bharambe <ashwin@meta.com>
Co-authored-by: Dalton Flanagan <6599399+dltn@users.noreply.github.com>
Parent: 35093c0b6f · Commit: 7bc7785b0d
141 changed files with 8252 additions and 4032 deletions
llama_toolchain/core/__init__.py (Normal file, +5)
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
llama_toolchain/core/build_conda_env.sh (Executable file, +128)
@@ -0,0 +1,128 @@
#!/bin/bash

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-}
LLAMA_TOOLCHAIN_DIR=${LLAMA_TOOLCHAIN_DIR:-}
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}

if [ -n "$LLAMA_TOOLCHAIN_DIR" ]; then
  echo "Using llama-toolchain-dir=$LLAMA_TOOLCHAIN_DIR"
fi
if [ -n "$LLAMA_MODELS_DIR" ]; then
  echo "Using llama-models-dir=$LLAMA_MODELS_DIR"
fi

set -euo pipefail

if [ "$#" -ne 3 ]; then
  echo "Usage: $0 <distribution_id> <build_name> <pip_dependencies>" >&2
  echo "Example: $0 <distribution_id> mybuild 'numpy pandas scipy'" >&2
  exit 1
fi

distribution_id="$1"
build_name="$2"
env_name="llamastack-$build_name"
pip_dependencies="$3"

# Define color codes
RED='\033[0;31m'
GREEN='\033[0;32m'
NC='\033[0m' # No Color

# this is set if we actually create a new conda in which case we need to clean up
ENVNAME=""

SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
source "$SCRIPT_DIR/common.sh"

ensure_conda_env_python310() {
  local env_name="$1"
  local pip_dependencies="$2"
  local python_version="3.10"

  # Check if conda command is available
  if ! command -v conda &>/dev/null; then
    printf "${RED}Error: conda command not found. Is Conda installed and in your PATH?${NC}" >&2
    exit 1
  fi

  # Check if the environment exists
  if conda env list | grep -q "^${env_name} "; then
    printf "Conda environment '${env_name}' exists. Checking Python version...\n"

    # Check Python version in the environment
    current_version=$(conda run -n "${env_name}" python --version 2>&1 | cut -d' ' -f2 | cut -d'.' -f1,2)

    if [ "$current_version" = "$python_version" ]; then
      printf "Environment '${env_name}' already has Python ${python_version}. No action needed.\n"
    else
      printf "Updating environment '${env_name}' to Python ${python_version}...\n"
      conda install -n "${env_name}" python="${python_version}" -y
    fi
  else
    printf "Conda environment '${env_name}' does not exist. Creating with Python ${python_version}...\n"
    conda create -n "${env_name}" python="${python_version}" -y

    ENVNAME="${env_name}"
    # setup_cleanup_handlers
  fi

  eval "$(conda shell.bash hook)"
  conda deactivate && conda activate "${env_name}"

  if [ -n "$TEST_PYPI_VERSION" ]; then
    # these packages are damaged in test-pypi, so install them first
    pip install fastapi libcst
    pip install --extra-index-url https://test.pypi.org/simple/ llama-models==$TEST_PYPI_VERSION llama-toolchain==$TEST_PYPI_VERSION $pip_dependencies
  else
    # Re-installing llama-toolchain in the new conda environment
    if [ -n "$LLAMA_TOOLCHAIN_DIR" ]; then
      if [ ! -d "$LLAMA_TOOLCHAIN_DIR" ]; then
        printf "${RED}Warning: LLAMA_TOOLCHAIN_DIR is set but directory does not exist: $LLAMA_TOOLCHAIN_DIR${NC}\n" >&2
        exit 1
      fi

      printf "Installing from LLAMA_TOOLCHAIN_DIR: $LLAMA_TOOLCHAIN_DIR\n"
      pip install --no-cache-dir -e "$LLAMA_TOOLCHAIN_DIR"
    else
      pip install --no-cache-dir llama-toolchain
    fi

    if [ -n "$LLAMA_MODELS_DIR" ]; then
      if [ ! -d "$LLAMA_MODELS_DIR" ]; then
        printf "${RED}Warning: LLAMA_MODELS_DIR is set but directory does not exist: $LLAMA_MODELS_DIR${NC}\n" >&2
        exit 1
      fi

      printf "Installing from LLAMA_MODELS_DIR: $LLAMA_MODELS_DIR\n"
      pip uninstall -y llama-models
      pip install --no-cache-dir -e "$LLAMA_MODELS_DIR"
    fi

    # Install pip dependencies
    if [ -n "$pip_dependencies" ]; then
      printf "Installing pip dependencies: $pip_dependencies\n"
      pip install $pip_dependencies
    fi
  fi
}

ensure_conda_env_python310 "$env_name" "$pip_dependencies"

printf "${GREEN}Successfully setup conda environment. Configuring build...${NC}\n"

if [ "$distribution_id" = "adhoc" ]; then
  subcommand="api"
  target=""
else
  subcommand="stack"
  target="$distribution_id"
fi

$CONDA_PREFIX/bin/python3 -m llama_toolchain.cli.llama $subcommand configure $target --name "$build_name" --type conda_env
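Usage note: the script takes exactly three positional arguments and two optional environment variables. A minimal sketch of driving it from Python; the checkout path and dependency list are illustrative, not part of this diff:

import os
import subprocess

# LLAMA_TOOLCHAIN_DIR is optional; when set, the script pip-installs that
# checkout editably instead of pulling llama-toolchain from PyPI.
env = {**os.environ, "LLAMA_TOOLCHAIN_DIR": "/home/me/src/llama-toolchain"}  # hypothetical path

subprocess.run(
    [
        "llama_toolchain/core/build_conda_env.sh",
        "local",            # distribution_id ("adhoc" switches configure to the `api` subcommand)
        "mybuild",          # build_name -> conda env "llamastack-mybuild"
        "fastapi uvicorn",  # pip_dependencies, passed as one string
    ],
    env=env,
    check=True,
)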
llama_toolchain/core/build_container.sh (Executable file, +120)
@@ -0,0 +1,120 @@
#!/bin/bash

LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-}
LLAMA_TOOLCHAIN_DIR=${LLAMA_TOOLCHAIN_DIR:-}
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}

if [ "$#" -ne 4 ]; then
  echo "Usage: $0 <distribution_id> <build_name> <docker_base> <pip_dependencies>"
  echo "Example: $0 distribution_id my-fastapi-app python:3.9-slim 'fastapi uvicorn'"
  exit 1
fi

distribution_id=$1
build_name="$2"
image_name="llamastack-$build_name"
docker_base=$3
pip_dependencies=$4

# Define color codes
RED='\033[0;31m'
GREEN='\033[0;32m'
NC='\033[0m' # No Color

set -euo pipefail

SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
REPO_DIR=$(dirname $(dirname "$SCRIPT_DIR"))

TEMP_DIR=$(mktemp -d)

add_to_docker() {
  local input
  output_file="$TEMP_DIR/Dockerfile"
  if [ -t 0 ]; then
    printf '%s\n' "$1" >>"$output_file"
  else
    # If stdin is not a terminal, read from it (heredoc)
    cat >>"$output_file"
  fi
}

add_to_docker <<EOF
FROM $docker_base
WORKDIR /app

RUN apt-get update && apt-get install -y \
       iputils-ping net-tools iproute2 dnsutils telnet \
       curl wget telnet \
       procps psmisc lsof \
       traceroute \
       && rm -rf /var/lib/apt/lists/*

EOF

toolchain_mount="/app/llama-toolchain-source"
models_mount="/app/llama-models-source"

if [ -n "$LLAMA_TOOLCHAIN_DIR" ]; then
  if [ ! -d "$LLAMA_TOOLCHAIN_DIR" ]; then
    echo "${RED}Warning: LLAMA_TOOLCHAIN_DIR is set but directory does not exist: $LLAMA_TOOLCHAIN_DIR${NC}" >&2
    exit 1
  fi
  add_to_docker "RUN pip install $toolchain_mount"
else
  add_to_docker "RUN pip install llama-toolchain"
fi

if [ -n "$LLAMA_MODELS_DIR" ]; then
  if [ ! -d "$LLAMA_MODELS_DIR" ]; then
    echo "${RED}Warning: LLAMA_MODELS_DIR is set but directory does not exist: $LLAMA_MODELS_DIR${NC}" >&2
    exit 1
  fi

  add_to_docker <<EOF
RUN pip uninstall -y llama-models
RUN pip install $models_mount

EOF
fi

if [ -n "$pip_dependencies" ]; then
  add_to_docker "RUN pip install $pip_dependencies"
fi

add_to_docker <<EOF

# This would be good in production but for debugging flexibility lets not add it right now
# We need a more solid production ready entrypoint.sh anyway
#
# ENTRYPOINT ["python", "-m", "llama_toolchain.core.server"]

EOF

printf "Dockerfile created successfully in $TEMP_DIR/Dockerfile"
cat $TEMP_DIR/Dockerfile
printf "\n"

mounts=""
if [ -n "$LLAMA_TOOLCHAIN_DIR" ]; then
  mounts="$mounts -v $(readlink -f $LLAMA_TOOLCHAIN_DIR):$toolchain_mount"
fi
if [ -n "$LLAMA_MODELS_DIR" ]; then
  mounts="$mounts -v $(readlink -f $LLAMA_MODELS_DIR):$models_mount"
fi
set -x
podman build -t $image_name -f "$TEMP_DIR/Dockerfile" "$REPO_DIR" $mounts
set +x

printf "${GREEN}Successfully setup Podman image. Configuring build...${NC}"
echo "You can run it with: podman run -p 8000:8000 $image_name"

if [ "$distribution_id" = "adhoc" ]; then
  subcommand="api"
  target=""
else
  subcommand="stack"
  target="$distribution_id"
fi

$CONDA_PREFIX/bin/python3 -m llama_toolchain.cli.llama $subcommand configure $target --name "$build_name" --type container
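The add_to_docker helper assembles the Dockerfile incrementally: a string argument when stdin is a terminal, a heredoc otherwise. The same accumulation pattern, sketched in Python purely for illustration (base image and dependencies are placeholders):

import tempfile
from pathlib import Path

dockerfile = Path(tempfile.mkdtemp()) / "Dockerfile"

def add_to_docker(text: str) -> None:
    # Append one fragment to the Dockerfile, mirroring the shell helper.
    with dockerfile.open("a") as f:
        f.write(text + "\n")

add_to_docker("FROM python:3.10-slim")            # $docker_base in the script
add_to_docker("WORKDIR /app")
add_to_docker("RUN pip install llama-toolchain")
add_to_docker("RUN pip install fastapi uvicorn")  # $pip_dependencies
print(dockerfile.read_text())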
llama_toolchain/core/common.sh (Executable file, +40)
@@ -0,0 +1,40 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

cleanup() {
  envname="$1"

  set +x
  echo "Cleaning up..."
  conda deactivate
  conda env remove --name "$envname" -y
}

handle_int() {
  if [ -n "$ENVNAME" ]; then
    cleanup "$ENVNAME"
  fi
  exit 1
}

handle_exit() {
  if [ $? -ne 0 ]; then
    echo -e "\033[1;31mABORTING.\033[0m"
    if [ -n "$ENVNAME" ]; then
      cleanup "$ENVNAME"
    fi
  fi
}

setup_cleanup_handlers() {
  trap handle_int INT
  trap handle_exit EXIT

  __conda_setup="$('conda' 'shell.bash' 'hook' 2>/dev/null)"
  eval "$__conda_setup"

  conda deactivate
}
llama_toolchain/core/configure.py (Normal file, +50)
@@ -0,0 +1,50 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, Dict

from llama_toolchain.core.datatypes import *  # noqa: F403
from termcolor import cprint

from llama_toolchain.common.prompt_for_config import prompt_for_config
from llama_toolchain.core.distribution import api_providers
from llama_toolchain.core.dynamic import instantiate_class_type


def configure_api_providers(existing_configs: Dict[str, Any]) -> Dict[str, Any]:
    all_providers = api_providers()

    provider_configs = {}
    for api_str, stub_config in existing_configs.items():
        api = Api(api_str)
        providers = all_providers[api]
        provider_id = stub_config["provider_id"]
        if provider_id not in providers:
            raise ValueError(
                f"Unknown provider `{provider_id}` is not available for API `{api_str}`"
            )

        provider_spec = providers[provider_id]
        cprint(f"Configuring API: {api_str} ({provider_id})", "white", attrs=["bold"])
        config_type = instantiate_class_type(provider_spec.config_class)

        try:
            existing_provider_config = config_type(**stub_config)
        except Exception:
            existing_provider_config = None

        provider_config = prompt_for_config(
            config_type,
            existing_provider_config,
        )
        print("")

        provider_configs[api_str] = {
            "provider_id": provider_id,
            **provider_config.dict(),
        }

    return provider_configs
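For orientation, configure_api_providers maps a stub config keyed by API name to a fully prompted-out provider config. A sketch with hypothetical values (prompt_for_config fills the remaining fields interactively):

from llama_toolchain.core.configure import configure_api_providers

stub = {
    "inference": {"provider_id": "meta-reference"},
    "safety": {"provider_id": "meta-reference"},
}

# Each returned entry keeps its provider_id and gains the provider's
# own config fields, e.g. {"inference": {"provider_id": "meta-reference", ...}}.
configs = configure_api_providers(stub)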
llama_toolchain/core/datatypes.py (Normal file, +190)
@@ -0,0 +1,190 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional

from llama_models.schema_utils import json_schema_type

from pydantic import BaseModel, Field, validator


@json_schema_type
class Api(Enum):
    inference = "inference"
    safety = "safety"
    agentic_system = "agentic_system"
    memory = "memory"


@json_schema_type
class ApiEndpoint(BaseModel):
    route: str
    method: str
    name: str


@json_schema_type
class ProviderSpec(BaseModel):
    api: Api
    provider_id: str
    config_class: str = Field(
        ...,
        description="Fully-qualified classname of the config for this provider",
    )
    api_dependencies: List[Api] = Field(
        default_factory=list,
        description="Higher-level API surfaces may depend on other providers to provide their functionality",
    )


@json_schema_type
class AdapterSpec(BaseModel):
    adapter_id: str = Field(
        ...,
        description="Unique identifier for this adapter",
    )
    module: str = Field(
        ...,
        description="""
Fully-qualified name of the module to import. The module is expected to have:

 - `get_adapter_impl(config, deps)`: returns the adapter implementation
""",
    )
    pip_packages: List[str] = Field(
        default_factory=list,
        description="The pip dependencies needed for this implementation",
    )
    config_class: Optional[str] = Field(
        default=None,
        description="Fully-qualified classname of the config for this provider",
    )


@json_schema_type
class InlineProviderSpec(ProviderSpec):
    pip_packages: List[str] = Field(
        default_factory=list,
        description="The pip dependencies needed for this implementation",
    )
    docker_image: Optional[str] = Field(
        default=None,
        description="""
The docker image to use for this implementation. If one is provided, pip_packages will be ignored.
If a provider depends on other providers, the dependencies MUST NOT specify a docker image.
""",
    )
    module: str = Field(
        ...,
        description="""
Fully-qualified name of the module to import. The module is expected to have:

 - `get_provider_impl(config, deps)`: returns the local implementation
""",
    )


class RemoteProviderConfig(BaseModel):
    url: str = Field(..., description="The URL for the provider")

    @validator("url")
    @classmethod
    def validate_url(cls, url: str) -> str:
        if not url.startswith("http"):
            raise ValueError(f"URL must start with http: {url}")
        return url.rstrip("/")


def remote_provider_id(adapter_id: str) -> str:
    return f"remote::{adapter_id}"


@json_schema_type
class RemoteProviderSpec(ProviderSpec):
    adapter: Optional[AdapterSpec] = Field(
        default=None,
        description="""
If some code is needed to convert the remote responses into Llama Stack compatible
API responses, specify the adapter here. If not specified, it indicates the remote
as being "Llama Stack compatible"
""",
    )

    @property
    def docker_image(self) -> Optional[str]:
        return None

    @property
    def module(self) -> str:
        if self.adapter:
            return self.adapter.module
        return f"llama_toolchain.{self.api.value}.client"

    @property
    def pip_packages(self) -> List[str]:
        if self.adapter:
            return self.adapter.pip_packages
        return []


# Can avoid this by using Pydantic computed_field
def remote_provider_spec(
    api: Api, adapter: Optional[AdapterSpec] = None
) -> RemoteProviderSpec:
    config_class = (
        adapter.config_class
        if adapter and adapter.config_class
        else "llama_toolchain.core.datatypes.RemoteProviderConfig"
    )
    provider_id = remote_provider_id(adapter.adapter_id) if adapter else "remote"

    return RemoteProviderSpec(
        api=api, provider_id=provider_id, config_class=config_class, adapter=adapter
    )


@json_schema_type
class DistributionSpec(BaseModel):
    distribution_id: str
    description: str

    docker_image: Optional[str] = None
    providers: Dict[Api, str] = Field(
        default_factory=dict,
        description="Provider IDs for each of the APIs provided by this distribution",
    )


@json_schema_type
class PackageConfig(BaseModel):
    built_at: datetime

    package_name: str = Field(
        ...,
        description="""
Reference to the distribution this package refers to. For unregistered (adhoc) packages,
this could be just a hash
""",
    )
    distribution_id: Optional[str] = None

    docker_image: Optional[str] = Field(
        default=None,
        description="Reference to the docker image if this package refers to a container",
    )
    conda_env: Optional[str] = Field(
        default=None,
        description="Reference to the conda environment if this package refers to a conda environment",
    )
    providers: Dict[str, Any] = Field(
        default_factory=dict,
        description="""
Provider configurations for each of the APIs provided by this package. This includes configurations for
the dependencies of these providers as well.
""",
    )
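A quick sketch of how these datatypes compose; the adapter values are illustrative (the module path is hypothetical), everything else follows from the definitions above:

from llama_toolchain.core.datatypes import AdapterSpec, Api, remote_provider_spec

adapter = AdapterSpec(
    adapter_id="ollama",
    module="llama_toolchain.inference.adapters.ollama",  # hypothetical module path
    pip_packages=["ollama"],
)
spec = remote_provider_spec(Api.inference, adapter)

assert spec.provider_id == "remote::ollama"
assert spec.module == adapter.module        # without an adapter, falls back to the generated client module
assert spec.pip_packages == ["ollama"]
assert spec.docker_image is None            # remote providers never carry a docker image
# config_class defaults to RemoteProviderConfig since the adapter sets none.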
llama_toolchain/core/distribution.py (Normal file, +101)
@@ -0,0 +1,101 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import inspect
from typing import Dict, List

from llama_toolchain.agentic_system.api import AgenticSystem
from llama_toolchain.agentic_system.providers import available_agentic_system_providers
from llama_toolchain.inference.api import Inference
from llama_toolchain.inference.providers import available_inference_providers
from llama_toolchain.memory.api import Memory
from llama_toolchain.memory.providers import available_memory_providers
from llama_toolchain.safety.api import Safety
from llama_toolchain.safety.providers import available_safety_providers

from .datatypes import (
    Api,
    ApiEndpoint,
    DistributionSpec,
    InlineProviderSpec,
    ProviderSpec,
    remote_provider_spec,
)

# These are the dependencies needed by the distribution server.
# `llama-toolchain` is automatically installed by the installation script.
SERVER_DEPENDENCIES = [
    "fastapi",
    "uvicorn",
]


def distribution_dependencies(distribution: DistributionSpec) -> List[str]:
    # only consider InlineProviderSpecs when calculating dependencies
    return [
        dep
        for provider_spec in distribution.provider_specs.values()
        if isinstance(provider_spec, InlineProviderSpec)
        for dep in provider_spec.pip_packages
    ] + SERVER_DEPENDENCIES


def stack_apis() -> List[Api]:
    return [Api.inference, Api.safety, Api.agentic_system, Api.memory]


def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
    apis = {}

    protocols = {
        Api.inference: Inference,
        Api.safety: Safety,
        Api.agentic_system: AgenticSystem,
        Api.memory: Memory,
    }

    for api, protocol in protocols.items():
        endpoints = []
        protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)

        for name, method in protocol_methods:
            if not hasattr(method, "__webmethod__"):
                continue

            webmethod = method.__webmethod__
            route = webmethod.route

            if webmethod.method == "GET":
                method = "get"
            elif webmethod.method == "DELETE":
                method = "delete"
            else:
                method = "post"
            endpoints.append(ApiEndpoint(route=route, method=method, name=name))

        apis[api] = endpoints

    return apis


def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]:
    inference_providers_by_id = {
        a.provider_id: a for a in available_inference_providers()
    }
    safety_providers_by_id = {a.provider_id: a for a in available_safety_providers()}
    agentic_system_providers_by_id = {
        a.provider_id: a for a in available_agentic_system_providers()
    }

    ret = {
        Api.inference: inference_providers_by_id,
        Api.safety: safety_providers_by_id,
        Api.agentic_system: agentic_system_providers_by_id,
        Api.memory: {a.provider_id: a for a in available_memory_providers()},
    }
    for k, v in ret.items():
        v["remote"] = remote_provider_spec(k)
    return ret
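For orientation, a sketch of how these registries are typically consumed; exact routes depend on the @webmethod decorators on each protocol, so the output here is indicative only:

from llama_toolchain.core.datatypes import Api
from llama_toolchain.core.distribution import api_endpoints, api_providers

providers = api_providers()
# Every API gets a synthesized "remote" passthrough entry in addition to
# its registered implementations.
assert "remote" in providers[Api.inference]

for endpoint in api_endpoints()[Api.inference]:
    print(endpoint.method.upper(), endpoint.route, "->", endpoint.name)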
llama_toolchain/core/distribution_registry.py (Normal file, +69)
@@ -0,0 +1,69 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from functools import lru_cache
from typing import List, Optional

from .datatypes import *  # noqa: F403


@lru_cache()
def available_distribution_specs() -> List[DistributionSpec]:
    return [
        DistributionSpec(
            distribution_id="local",
            description="Use code from `llama_toolchain` itself to serve all llama stack APIs",
            providers={
                Api.inference: "meta-reference",
                Api.memory: "meta-reference-faiss",
                Api.safety: "meta-reference",
                Api.agentic_system: "meta-reference",
            },
        ),
        DistributionSpec(
            distribution_id="remote",
            description="Point to remote services for all llama stack APIs",
            providers={x: "remote" for x in Api},
        ),
        DistributionSpec(
            distribution_id="local-ollama",
            description="Like local, but use ollama for running LLM inference",
            providers={
                Api.inference: remote_provider_id("ollama"),
                Api.safety: "meta-reference",
                Api.agentic_system: "meta-reference",
                Api.memory: "meta-reference-faiss",
            },
        ),
        DistributionSpec(
            distribution_id="local-plus-fireworks-inference",
            description="Use Fireworks.ai for running LLM inference",
            providers={
                Api.inference: remote_provider_id("fireworks"),
                Api.safety: "meta-reference",
                Api.agentic_system: "meta-reference",
                Api.memory: "meta-reference-faiss",
            },
        ),
        DistributionSpec(
            distribution_id="local-plus-together-inference",
            description="Use Together.ai for running LLM inference",
            providers={
                Api.inference: remote_provider_id("together"),
                Api.safety: "meta-reference",
                Api.agentic_system: "meta-reference",
                Api.memory: "meta-reference-faiss",
            },
        ),
    ]


@lru_cache()
def resolve_distribution_spec(distribution_id: str) -> Optional[DistributionSpec]:
    for spec in available_distribution_specs():
        if spec.distribution_id == distribution_id:
            return spec
    return None
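Lookup is by exact distribution_id, with None for anything unregistered; a small sketch grounded in the specs above:

from llama_toolchain.core.datatypes import Api, remote_provider_id
from llama_toolchain.core.distribution_registry import resolve_distribution_spec

spec = resolve_distribution_spec("local-ollama")
assert spec is not None
assert spec.providers[Api.inference] == remote_provider_id("ollama")  # "remote::ollama"

# Unknown IDs resolve to None rather than raising.
assert resolve_distribution_spec("no-such-distribution") is None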
llama_toolchain/core/dynamic.py (Normal file, +42)
@@ -0,0 +1,42 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
import importlib
from typing import Any, Dict

from .datatypes import ProviderSpec, RemoteProviderSpec


def instantiate_class_type(fully_qualified_name):
    module_name, class_name = fully_qualified_name.rsplit(".", 1)
    module = importlib.import_module(module_name)
    return getattr(module, class_name)


# returns a class implementing the protocol corresponding to the Api
def instantiate_provider(
    provider_spec: ProviderSpec,
    provider_config: Dict[str, Any],
    deps: Dict[str, ProviderSpec],
):
    module = importlib.import_module(provider_spec.module)

    config_type = instantiate_class_type(provider_spec.config_class)
    if isinstance(provider_spec, RemoteProviderSpec):
        if provider_spec.adapter:
            method = "get_adapter_impl"
        else:
            method = "get_client_impl"
    else:
        method = "get_provider_impl"

    config = config_type(**provider_config)
    fn = getattr(module, method)
    impl = asyncio.run(fn(config, deps))
    impl.__provider_spec__ = provider_spec
    impl.__provider_config__ = config
    return impl
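instantiate_class_type is plain dotted-path resolution; a minimal sketch using a class defined in this very diff (the URL is illustrative):

from llama_toolchain.core.dynamic import instantiate_class_type

# Resolves "module.path.ClassName" to the class object itself.
config_type = instantiate_class_type(
    "llama_toolchain.core.datatypes.RemoteProviderConfig"
)
config = config_type(url="http://localhost:11434")  # hypothetical endpoint
print(config.url)  # any trailing slash would have been stripped by the validator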
llama_toolchain/core/package.py (Normal file, +149)
@@ -0,0 +1,149 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
import os
from datetime import datetime
from enum import Enum
from typing import List, Optional

import pkg_resources
import yaml
from pydantic import BaseModel

from termcolor import cprint

from llama_toolchain.common.config_dirs import BUILDS_BASE_DIR
from llama_toolchain.common.exec import run_with_pty
from llama_toolchain.common.serialize import EnumEncoder

from llama_toolchain.core.datatypes import *  # noqa: F403
from llama_toolchain.core.distribution import api_providers, SERVER_DEPENDENCIES


class BuildType(Enum):
    container = "container"
    conda_env = "conda_env"

    def descriptor(self) -> str:
        return "docker" if self == self.container else "conda"


class Dependencies(BaseModel):
    pip_packages: List[str]
    docker_image: Optional[str] = None


class ApiInput(BaseModel):
    api: Api
    provider: str


def build_package(
    api_inputs: List[ApiInput],
    build_type: BuildType,
    name: str,
    distribution_id: Optional[str] = None,
    docker_image: Optional[str] = None,
):
    if not distribution_id:
        distribution_id = "adhoc"

    build_dir = BUILDS_BASE_DIR / distribution_id / build_type.descriptor()
    os.makedirs(build_dir, exist_ok=True)

    package_name = name.replace("::", "-")
    package_file = build_dir / f"{package_name}.yaml"

    all_providers = api_providers()

    package_deps = Dependencies(
        docker_image=docker_image or "python:3.10-slim",
        pip_packages=SERVER_DEPENDENCIES,
    )

    stub_config = {}
    for api_input in api_inputs:
        api = api_input.api
        providers_for_api = all_providers[api]
        if api_input.provider not in providers_for_api:
            raise ValueError(
                f"Provider `{api_input.provider}` is not available for API `{api}`"
            )

        provider = providers_for_api[api_input.provider]
        package_deps.pip_packages.extend(provider.pip_packages)
        if provider.docker_image:
            raise ValueError("A stack's dependencies cannot have a docker image")

        stub_config[api.value] = {"provider_id": api_input.provider}

    if package_file.exists():
        cprint(
            f"Build `{package_name}` exists; will reconfigure",
            color="yellow",
        )
        c = PackageConfig(**yaml.safe_load(package_file.read_text()))
        for api_str, new_config in stub_config.items():
            if api_str not in c.providers:
                c.providers[api_str] = new_config
            else:
                existing_config = c.providers[api_str]
                if existing_config["provider_id"] != new_config["provider_id"]:
                    cprint(
                        f"Provider `{api_str}` has changed from `{existing_config}` to `{new_config}`",
                        color="yellow",
                    )
                    c.providers[api_str] = new_config
    else:
        c = PackageConfig(
            built_at=datetime.now(),
            package_name=package_name,
            providers=stub_config,
        )

    c.distribution_id = distribution_id
    c.docker_image = package_name if build_type == BuildType.container else None
    c.conda_env = package_name if build_type == BuildType.conda_env else None

    with open(package_file, "w") as f:
        to_write = json.loads(json.dumps(c.dict(), cls=EnumEncoder))
        f.write(yaml.dump(to_write, sort_keys=False))

    if build_type == BuildType.container:
        script = pkg_resources.resource_filename(
            "llama_toolchain", "core/build_container.sh"
        )
        args = [
            script,
            distribution_id,
            package_name,
            package_deps.docker_image,
            " ".join(package_deps.pip_packages),
        ]
    else:
        script = pkg_resources.resource_filename(
            "llama_toolchain", "core/build_conda_env.sh"
        )
        args = [
            script,
            distribution_id,
            package_name,
            " ".join(package_deps.pip_packages),
        ]

    return_code = run_with_pty(args)
    if return_code != 0:
        cprint(
            f"Failed to build target {package_name} with return code {return_code}",
            color="red",
        )
        return

    cprint(
        f"Target `{package_name}` built with configuration at {str(package_file)}",
        color="green",
    )
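A sketch of calling build_package directly (the same path the new CLI takes); the chosen providers must exist in api_providers(), and the names here are illustrative:

from llama_toolchain.core.datatypes import Api
from llama_toolchain.core.package import ApiInput, BuildType, build_package

build_package(
    api_inputs=[
        ApiInput(api=Api.inference, provider="meta-reference"),
        ApiInput(api=Api.safety, provider="meta-reference"),
    ],
    build_type=BuildType.conda_env,
    name="mybuild",
    distribution_id="local",
)
# Writes <BUILDS_BASE_DIR>/local/conda/mybuild.yaml, then runs
# build_conda_env.sh with the accumulated pip dependencies.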
llama_toolchain/core/server.py (Normal file, +345)
@@ -0,0 +1,345 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
import inspect
import json
import signal
import traceback
from collections.abc import (
    AsyncGenerator as AsyncGeneratorABC,
    AsyncIterator as AsyncIteratorABC,
)
from contextlib import asynccontextmanager
from ssl import SSLError
from typing import (
    Any,
    AsyncGenerator,
    AsyncIterator,
    Dict,
    get_type_hints,
    List,
    Optional,
    Set,
)

import fire
import httpx
import yaml

from fastapi import Body, FastAPI, HTTPException, Request, Response
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.routing import APIRoute
from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated

from .datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
from .distribution import api_endpoints, api_providers
from .dynamic import instantiate_provider


def is_async_iterator_type(typ):
    if hasattr(typ, "__origin__"):
        origin = typ.__origin__
        if isinstance(origin, type):
            return issubclass(
                origin,
                (AsyncIterator, AsyncGenerator, AsyncIteratorABC, AsyncGeneratorABC),
            )
        return False
    return isinstance(
        typ, (AsyncIterator, AsyncGenerator, AsyncIteratorABC, AsyncGeneratorABC)
    )


def create_sse_event(data: Any) -> str:
    if isinstance(data, BaseModel):
        data = data.json()
    else:
        data = json.dumps(data)

    return f"data: {data}\n\n"


async def global_exception_handler(request: Request, exc: Exception):
    traceback.print_exception(exc)
    http_exc = translate_exception(exc)

    return JSONResponse(
        status_code=http_exc.status_code, content={"error": {"detail": http_exc.detail}}
    )


def translate_exception(exc: Exception) -> HTTPException:
    if isinstance(exc, ValidationError):
        return RequestValidationError(exc.raw_errors)

    # Add more custom exception translations here
    return HTTPException(status_code=500, detail="Internal server error")


async def passthrough(
    request: Request,
    downstream_url: str,
    downstream_headers: Optional[Dict[str, str]] = None,
):
    headers = dict(request.headers)
    headers.pop("host", None)
    headers.update(downstream_headers or {})

    content = await request.body()

    client = httpx.AsyncClient()
    try:
        req = client.build_request(
            method=request.method,
            url=downstream_url,
            headers=headers,
            content=content,
            params=request.query_params,
        )
        response = await client.send(req, stream=True)

        async def stream_response():
            async for chunk in response.aiter_raw(chunk_size=64):
                yield chunk

            await response.aclose()
            await client.aclose()

        return StreamingResponse(
            stream_response(),
            status_code=response.status_code,
            headers=dict(response.headers),
            media_type=response.headers.get("content-type"),
        )

    except httpx.ReadTimeout:
        return Response(content="Downstream server timed out", status_code=504)
    except httpx.NetworkError as e:
        return Response(content=f"Network error: {str(e)}", status_code=502)
    except httpx.TooManyRedirects:
        return Response(content="Too many redirects", status_code=502)
    except SSLError as e:
        return Response(content=f"SSL error: {str(e)}", status_code=502)
    except httpx.HTTPStatusError as e:
        return Response(content=str(e), status_code=e.response.status_code)
    except Exception as e:
        return Response(content=f"Unexpected error: {str(e)}", status_code=500)


def handle_sigint(*args, **kwargs):
    print("SIGINT or CTRL-C detected. Exiting gracefully...")
    loop = asyncio.get_event_loop()
    for task in asyncio.all_tasks(loop):
        task.cancel()
    loop.stop()


@asynccontextmanager
async def lifespan(app: FastAPI):
    print("Starting up")
    yield
    print("Shutting down")


def create_dynamic_passthrough(
    downstream_url: str, downstream_headers: Optional[Dict[str, str]] = None
):
    async def endpoint(request: Request):
        return await passthrough(request, downstream_url, downstream_headers)

    return endpoint


def create_dynamic_typed_route(func: Any, method: str):
    hints = get_type_hints(func)
    response_model = hints["return"]

    # NOTE: I think it is better to just add a method within each Api
    # "Protocol" / adapter-impl to tell what sort of a response this request
    # is going to produce. /chat_completion can produce a streaming or
    # non-streaming response depending on if request.stream is True / False.
    is_streaming = is_async_iterator_type(response_model)

    if is_streaming:

        async def endpoint(**kwargs):
            async def sse_generator(event_gen):
                try:
                    async for item in event_gen:
                        yield create_sse_event(item)
                        await asyncio.sleep(0.01)
                except asyncio.CancelledError:
                    print("Generator cancelled")
                    await event_gen.aclose()
                except Exception as e:
                    traceback.print_exception(e)
                    yield create_sse_event(
                        {
                            "error": {
                                "message": str(translate_exception(e)),
                            },
                        }
                    )

            return StreamingResponse(
                sse_generator(func(**kwargs)), media_type="text/event-stream"
            )

    else:

        async def endpoint(**kwargs):
            try:
                return (
                    await func(**kwargs)
                    if asyncio.iscoroutinefunction(func)
                    else func(**kwargs)
                )
            except Exception as e:
                traceback.print_exception(e)
                raise translate_exception(e) from e

    sig = inspect.signature(func)
    if method == "post":
        # make sure every parameter is annotated with Body() so FASTAPI doesn't
        # do anything too intelligent and ask for some parameters in the query
        # and some in the body
        endpoint.__signature__ = sig.replace(
            parameters=[
                param.replace(
                    annotation=Annotated[param.annotation, Body(..., embed=True)]
                )
                for param in sig.parameters.values()
            ]
        )
    else:
        endpoint.__signature__ = sig

    return endpoint


def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
    by_id = {x.api: x for x in providers}

    def dfs(a: ProviderSpec, visited: Set[Api], stack: List[Api]):
        visited.add(a.api)

        for api in a.api_dependencies:
            if api not in visited:
                dfs(by_id[api], visited, stack)

        stack.append(a.api)

    visited = set()
    stack = []

    for a in providers:
        if a.api not in visited:
            dfs(a, visited, stack)

    return [by_id[x] for x in stack]


def resolve_impls(
    provider_specs: Dict[str, ProviderSpec], config: Dict[str, Any]
) -> Dict[Api, Any]:
    provider_configs = config["providers"]
    provider_specs = topological_sort(provider_specs.values())

    impls = {}
    for provider_spec in provider_specs:
        api = provider_spec.api
        if api.value not in provider_configs:
            raise ValueError(
                f"Could not find provider_spec config for {api}. Please add it to the config"
            )

        if isinstance(provider_spec, InlineProviderSpec):
            deps = {api: impls[api] for api in provider_spec.api_dependencies}
        else:
            deps = {}
        provider_config = provider_configs[api.value]
        impl = instantiate_provider(provider_spec, provider_config, deps)
        impls[api] = impl

    return impls


def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
    with open(yaml_config, "r") as fp:
        config = yaml.safe_load(fp)

    app = FastAPI()

    all_endpoints = api_endpoints()
    all_providers = api_providers()

    provider_specs = {}
    for api_str, provider_config in config["providers"].items():
        api = Api(api_str)
        providers = all_providers[api]
        provider_id = provider_config["provider_id"]
        if provider_id not in providers:
            raise ValueError(
                f"Unknown provider `{provider_id}` is not available for API `{api}`"
            )

        provider_specs[api] = providers[provider_id]

    impls = resolve_impls(provider_specs, config)

    for provider_spec in provider_specs.values():
        api = provider_spec.api
        endpoints = all_endpoints[api]
        impl = impls[api]

        if (
            isinstance(provider_spec, RemoteProviderSpec)
            and provider_spec.adapter is None
        ):
            for endpoint in endpoints:
                url = impl.__provider_config__.url.rstrip("/") + endpoint.route
                getattr(app, endpoint.method)(endpoint.route)(
                    create_dynamic_passthrough(url)
                )
        else:
            for endpoint in endpoints:
                if not hasattr(impl, endpoint.name):
                    # ideally this should be a typing violation already
                    raise ValueError(
                        f"Could not find method {endpoint.name} on {impl}!!"
                    )

                impl_method = getattr(impl, endpoint.name)
                getattr(app, endpoint.method)(endpoint.route, response_model=None)(
                    create_dynamic_typed_route(impl_method, endpoint.method)
                )

    for route in app.routes:
        if isinstance(route, APIRoute):
            cprint(
                f"Serving {next(iter(route.methods))} {route.path}",
                "white",
                attrs=["bold"],
            )

    app.exception_handler(RequestValidationError)(global_exception_handler)
    app.exception_handler(Exception)(global_exception_handler)
    signal.signal(signal.SIGINT, handle_sigint)

    import uvicorn

    # FYI this does not do hot-reloads
    listen_host = "::" if not disable_ipv6 else "0.0.0.0"
    print(f"Listening on {listen_host}:{port}")
    uvicorn.run(app, host=listen_host, port=port)


if __name__ == "__main__":
    fire.Fire(main)
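The streaming branch wraps each yielded item with create_sse_event, so the wire format is plain server-sent events; for a dict payload (the payload below is illustrative):

from llama_toolchain.core.server import create_sse_event

print(repr(create_sse_event({"delta": "Hello"})))
# 'data: {"delta": "Hello"}\n\n'
# Pydantic models are serialized via .json() instead of json.dumps.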
llama_toolchain/core/start_conda_env.sh (Executable file, +42)
@@ -0,0 +1,42 @@
#!/bin/bash

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

set -euo pipefail

RED='\033[0;31m'
NC='\033[0m' # No Color

error_handler() {
  echo "Error occurred in script at line: ${1}" >&2
  exit 1
}

trap 'error_handler ${LINENO}' ERR

if [ $# -lt 3 ]; then
  echo "Usage: $0 <build_name> <yaml_config> <port> <script_args...>"
  exit 1
fi

build_name="$1"
env_name="llamastack-$build_name"
shift

yaml_config="$1"
shift

port="$1"
shift

eval "$(conda shell.bash hook)"
conda deactivate && conda activate "$env_name"

$CONDA_PREFIX/bin/python \
  -m llama_toolchain.core.server \
  --yaml_config "$yaml_config" \
  --port "$port" "$@"
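The script ultimately just runs the server module inside the conda env; an equivalent direct invocation from Python (the config path is hypothetical), since fire maps the CLI flags onto main(yaml_config, port, disable_ipv6):

from llama_toolchain.core.server import main

# Same effect as: python -m llama_toolchain.core.server --yaml_config ... --port 5000
main(yaml_config="/path/to/mybuild.yaml", port=5000)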
llama_toolchain/core/start_container.sh (Executable file, +43)
@@ -0,0 +1,43 @@
#!/bin/bash

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

set -euo pipefail

RED='\033[0;31m'
NC='\033[0m' # No Color

error_handler() {
  echo "Error occurred in script at line: ${1}" >&2
  exit 1
}

trap 'error_handler ${LINENO}' ERR

if [ $# -lt 3 ]; then
  echo "Usage: $0 <build_name> <yaml_config> <port> <other_args...>"
  exit 1
fi

build_name="$1"
docker_image="llamastack-$build_name"
shift

yaml_config="$1"
shift

port="$1"
shift

set -x
podman run -it \
  -p $port:$port \
  -v "$yaml_config:/app/config.yaml" \
  $docker_image \
  python -m llama_toolchain.core.server \
  --yaml_config /app/config.yaml \
  --port $port "$@"