mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-29 19:34:19 +00:00
API Updates (#73)
* API Keys passed from Client instead of distro configuration * delete distribution registry * Rename the "package" word away * Introduce a "Router" layer for providers Some providers need to be factorized and considered as thin routing layers on top of other providers. Consider two examples: - The inference API should be a routing layer over inference providers, routed using the "model" key - The memory banks API is another instance where various memory bank types will be provided by independent providers (e.g., a vector store is served by Chroma while a keyvalue memory can be served by Redis or PGVector) This commit introduces a generalized routing layer for this purpose. * update `apis_to_serve` * llama_toolchain -> llama_stack * Codemod from llama_toolchain -> llama_stack - added providers/registry - cleaned up api/ subdirectories and moved impls away - restructured api/api.py - from llama_stack.apis.<api> import foo should work now - update imports to do llama_stack.apis.<api> - update many other imports - added __init__, fixed some registry imports - updated registry imports - create_agentic_system -> create_agent - AgenticSystem -> Agent * Moved some stuff out of common/; re-generated OpenAPI spec * llama-toolchain -> llama-stack (hyphens) * add control plane API * add redis adapter + sqlite provider * move core -> distribution * Some more toolchain -> stack changes * small naming shenanigans * Removing custom tool and agent utilities and moving them client side * Move control plane to distribution server for now * Remove control plane from API list * no codeshield dependency randomly plzzzzz * Add "fire" as a dependency * add back event loggers * stack configure fixes * use brave instead of bing in the example client * add init file so it gets packaged * add init files so it gets packaged * Update MANIFEST * bug fix --------- Co-authored-by: Hardik Shah <hjshah@fb.com> Co-authored-by: Xi Yan <xiyan@meta.com> Co-authored-by: Ashwin Bharambe <ashwin@meta.com>
This commit is contained in:
parent
f294eac5f5
commit
9487ad8294
213 changed files with 1725 additions and 1204 deletions
5
llama_stack/distribution/__init__.py
Normal file
5
llama_stack/distribution/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
96
llama_stack/distribution/build.py
Normal file
96
llama_stack/distribution/build.py
Normal file
|
@ -0,0 +1,96 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
|
||||
import pkg_resources
|
||||
from pydantic import BaseModel
|
||||
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.distribution.utils.exec import run_with_pty
|
||||
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
from pathlib import Path
|
||||
|
||||
from llama_stack.distribution.distribution import api_providers, SERVER_DEPENDENCIES
|
||||
|
||||
|
||||
class ImageType(Enum):
|
||||
docker = "docker"
|
||||
conda = "conda"
|
||||
|
||||
|
||||
class Dependencies(BaseModel):
|
||||
pip_packages: List[str]
|
||||
docker_image: Optional[str] = None
|
||||
|
||||
|
||||
class ApiInput(BaseModel):
|
||||
api: Api
|
||||
provider: str
|
||||
|
||||
|
||||
def build_image(build_config: BuildConfig, build_file_path: Path):
|
||||
package_deps = Dependencies(
|
||||
docker_image=build_config.distribution_spec.docker_image or "python:3.10-slim",
|
||||
pip_packages=SERVER_DEPENDENCIES,
|
||||
)
|
||||
|
||||
# extend package dependencies based on providers spec
|
||||
all_providers = api_providers()
|
||||
for (
|
||||
api_str,
|
||||
provider_or_providers,
|
||||
) in build_config.distribution_spec.providers.items():
|
||||
providers_for_api = all_providers[Api(api_str)]
|
||||
|
||||
providers = (
|
||||
provider_or_providers
|
||||
if isinstance(provider_or_providers, list)
|
||||
else [provider_or_providers]
|
||||
)
|
||||
|
||||
for provider in providers:
|
||||
if provider not in providers_for_api:
|
||||
raise ValueError(
|
||||
f"Provider `{provider}` is not available for API `{api_str}`"
|
||||
)
|
||||
|
||||
provider_spec = providers_for_api[provider]
|
||||
package_deps.pip_packages.extend(provider_spec.pip_packages)
|
||||
if provider_spec.docker_image:
|
||||
raise ValueError("A stack's dependencies cannot have a docker image")
|
||||
|
||||
if build_config.image_type == ImageType.docker.value:
|
||||
script = pkg_resources.resource_filename(
|
||||
"llama_stack", "distribution/build_container.sh"
|
||||
)
|
||||
args = [
|
||||
script,
|
||||
build_config.name,
|
||||
package_deps.docker_image,
|
||||
str(build_file_path),
|
||||
" ".join(package_deps.pip_packages),
|
||||
]
|
||||
else:
|
||||
script = pkg_resources.resource_filename(
|
||||
"llama_stack", "distribution/build_conda_env.sh"
|
||||
)
|
||||
args = [
|
||||
script,
|
||||
build_config.name,
|
||||
" ".join(package_deps.pip_packages),
|
||||
]
|
||||
|
||||
return_code = run_with_pty(args)
|
||||
if return_code != 0:
|
||||
cprint(
|
||||
f"Failed to build target {build_config.name} with return code {return_code}",
|
||||
color="red",
|
||||
)
|
||||
return
|
115
llama_stack/distribution/build_conda_env.sh
Executable file
115
llama_stack/distribution/build_conda_env.sh
Executable file
|
@ -0,0 +1,115 @@
|
|||
#!/bin/bash
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-}
|
||||
LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
|
||||
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
|
||||
|
||||
if [ -n "$LLAMA_STACK_DIR" ]; then
|
||||
echo "Using llama-stack-dir=$LLAMA_STACK_DIR"
|
||||
fi
|
||||
if [ -n "$LLAMA_MODELS_DIR" ]; then
|
||||
echo "Using llama-models-dir=$LLAMA_MODELS_DIR"
|
||||
fi
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
if [ "$#" -ne 2 ]; then
|
||||
echo "Usage: $0 <distribution_type> <build_name> <pip_dependencies>" >&2
|
||||
echo "Example: $0 <distribution_type> mybuild 'numpy pandas scipy'" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
build_name="$1"
|
||||
env_name="llamastack-$build_name"
|
||||
pip_dependencies="$2"
|
||||
|
||||
# Define color codes
|
||||
RED='\033[0;31m'
|
||||
GREEN='\033[0;32m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
# this is set if we actually create a new conda in which case we need to clean up
|
||||
ENVNAME=""
|
||||
|
||||
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
|
||||
source "$SCRIPT_DIR/common.sh"
|
||||
|
||||
ensure_conda_env_python310() {
|
||||
local env_name="$1"
|
||||
local pip_dependencies="$2"
|
||||
local python_version="3.10"
|
||||
|
||||
# Check if conda command is available
|
||||
if ! command -v conda &>/dev/null; then
|
||||
printf "${RED}Error: conda command not found. Is Conda installed and in your PATH?${NC}" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Check if the environment exists
|
||||
if conda env list | grep -q "^${env_name} "; then
|
||||
printf "Conda environment '${env_name}' exists. Checking Python version...\n"
|
||||
|
||||
# Check Python version in the environment
|
||||
current_version=$(conda run -n "${env_name}" python --version 2>&1 | cut -d' ' -f2 | cut -d'.' -f1,2)
|
||||
|
||||
if [ "$current_version" = "$python_version" ]; then
|
||||
printf "Environment '${env_name}' already has Python ${python_version}. No action needed.\n"
|
||||
else
|
||||
printf "Updating environment '${env_name}' to Python ${python_version}...\n"
|
||||
conda install -n "${env_name}" python="${python_version}" -y
|
||||
fi
|
||||
else
|
||||
printf "Conda environment '${env_name}' does not exist. Creating with Python ${python_version}...\n"
|
||||
conda create -n "${env_name}" python="${python_version}" -y
|
||||
|
||||
ENVNAME="${env_name}"
|
||||
# setup_cleanup_handlers
|
||||
fi
|
||||
|
||||
eval "$(conda shell.bash hook)"
|
||||
conda deactivate && conda activate "${env_name}"
|
||||
|
||||
if [ -n "$TEST_PYPI_VERSION" ]; then
|
||||
# these packages are damaged in test-pypi, so install them first
|
||||
pip install fastapi libcst
|
||||
pip install --extra-index-url https://test.pypi.org/simple/ llama-models==$TEST_PYPI_VERSION llama-stack==$TEST_PYPI_VERSION $pip_dependencies
|
||||
else
|
||||
# Re-installing llama-stack in the new conda environment
|
||||
if [ -n "$LLAMA_STACK_DIR" ]; then
|
||||
if [ ! -d "$LLAMA_STACK_DIR" ]; then
|
||||
printf "${RED}Warning: LLAMA_STACK_DIR is set but directory does not exist: $LLAMA_STACK_DIR${NC}\n" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
printf "Installing from LLAMA_STACK_DIR: $LLAMA_STACK_DIR\n"
|
||||
pip install --no-cache-dir -e "$LLAMA_STACK_DIR"
|
||||
else
|
||||
pip install --no-cache-dir llama-stack
|
||||
fi
|
||||
|
||||
if [ -n "$LLAMA_MODELS_DIR" ]; then
|
||||
if [ ! -d "$LLAMA_MODELS_DIR" ]; then
|
||||
printf "${RED}Warning: LLAMA_MODELS_DIR is set but directory does not exist: $LLAMA_MODELS_DIR${NC}\n" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
printf "Installing from LLAMA_MODELS_DIR: $LLAMA_MODELS_DIR\n"
|
||||
pip uninstall -y llama-models
|
||||
pip install --no-cache-dir -e "$LLAMA_MODELS_DIR"
|
||||
fi
|
||||
|
||||
# Install pip dependencies
|
||||
if [ -n "$pip_dependencies" ]; then
|
||||
printf "Installing pip dependencies: $pip_dependencies\n"
|
||||
pip install $pip_dependencies
|
||||
fi
|
||||
fi
|
||||
}
|
||||
|
||||
ensure_conda_env_python310 "$env_name" "$pip_dependencies"
|
117
llama_stack/distribution/build_container.sh
Executable file
117
llama_stack/distribution/build_container.sh
Executable file
|
@ -0,0 +1,117 @@
|
|||
#!/bin/bash
|
||||
|
||||
LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-}
|
||||
LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
|
||||
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
|
||||
|
||||
if [ "$#" -ne 4 ]; then
|
||||
echo "Usage: $0 <build_name> <docker_base> <pip_dependencies>
|
||||
echo "Example: $0 my-fastapi-app python:3.9-slim 'fastapi uvicorn'
|
||||
exit 1
|
||||
fi
|
||||
|
||||
build_name="$1"
|
||||
image_name="llamastack-$build_name"
|
||||
docker_base=$2
|
||||
build_file_path=$3
|
||||
pip_dependencies=$4
|
||||
|
||||
# Define color codes
|
||||
RED='\033[0;31m'
|
||||
GREEN='\033[0;32m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
|
||||
REPO_DIR=$(dirname $(dirname "$SCRIPT_DIR"))
|
||||
DOCKER_BINARY=${DOCKER_BINARY:-docker}
|
||||
DOCKER_OPTS=${DOCKER_OPTS:-}
|
||||
|
||||
TEMP_DIR=$(mktemp -d)
|
||||
|
||||
add_to_docker() {
|
||||
local input
|
||||
output_file="$TEMP_DIR/Dockerfile"
|
||||
if [ -t 0 ]; then
|
||||
printf '%s\n' "$1" >>"$output_file"
|
||||
else
|
||||
# If stdin is not a terminal, read from it (heredoc)
|
||||
cat >>"$output_file"
|
||||
fi
|
||||
}
|
||||
|
||||
add_to_docker <<EOF
|
||||
FROM $docker_base
|
||||
WORKDIR /app
|
||||
|
||||
RUN apt-get update && apt-get install -y \
|
||||
iputils-ping net-tools iproute2 dnsutils telnet \
|
||||
curl wget telnet \
|
||||
procps psmisc lsof \
|
||||
traceroute \
|
||||
bubblewrap \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
EOF
|
||||
|
||||
stack_mount="/app/llama-stack-source"
|
||||
models_mount="/app/llama-models-source"
|
||||
|
||||
if [ -n "$LLAMA_STACK_DIR" ]; then
|
||||
if [ ! -d "$LLAMA_STACK_DIR" ]; then
|
||||
echo "${RED}Warning: LLAMA_STACK_DIR is set but directory does not exist: $LLAMA_STACK_DIR${NC}" >&2
|
||||
exit 1
|
||||
fi
|
||||
add_to_docker "RUN pip install $stack_mount"
|
||||
else
|
||||
add_to_docker "RUN pip install llama-stack"
|
||||
fi
|
||||
|
||||
if [ -n "$LLAMA_MODELS_DIR" ]; then
|
||||
if [ ! -d "$LLAMA_MODELS_DIR" ]; then
|
||||
echo "${RED}Warning: LLAMA_MODELS_DIR is set but directory does not exist: $LLAMA_MODELS_DIR${NC}" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
add_to_docker <<EOF
|
||||
RUN pip uninstall -y llama-models
|
||||
RUN pip install $models_mount
|
||||
|
||||
EOF
|
||||
fi
|
||||
|
||||
if [ -n "$pip_dependencies" ]; then
|
||||
add_to_docker "RUN pip install $pip_dependencies"
|
||||
fi
|
||||
|
||||
add_to_docker <<EOF
|
||||
|
||||
# This would be good in production but for debugging flexibility lets not add it right now
|
||||
# We need a more solid production ready entrypoint.sh anyway
|
||||
#
|
||||
# ENTRYPOINT ["python", "-m", "llama_stack.distribution.server.server"]
|
||||
|
||||
EOF
|
||||
|
||||
add_to_docker "ADD $build_file_path ./llamastack-build.yaml"
|
||||
|
||||
printf "Dockerfile created successfully in $TEMP_DIR/Dockerfile"
|
||||
cat $TEMP_DIR/Dockerfile
|
||||
printf "\n"
|
||||
|
||||
mounts=""
|
||||
if [ -n "$LLAMA_STACK_DIR" ]; then
|
||||
mounts="$mounts -v $(readlink -f $LLAMA_STACK_DIR):$stack_mount"
|
||||
fi
|
||||
if [ -n "$LLAMA_MODELS_DIR" ]; then
|
||||
mounts="$mounts -v $(readlink -f $LLAMA_MODELS_DIR):$models_mount"
|
||||
fi
|
||||
set -x
|
||||
$DOCKER_BINARY build $DOCKER_OPTS -t $image_name -f "$TEMP_DIR/Dockerfile" "$REPO_DIR" $mounts
|
||||
set +x
|
||||
|
||||
echo "You can run it with: podman run -p 8000:8000 $image_name"
|
||||
|
||||
echo "Checking image builds..."
|
||||
podman run -it $image_name cat llamastack-build.yaml
|
40
llama_stack/distribution/common.sh
Executable file
40
llama_stack/distribution/common.sh
Executable file
|
@ -0,0 +1,40 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
cleanup() {
|
||||
envname="$1"
|
||||
|
||||
set +x
|
||||
echo "Cleaning up..."
|
||||
conda deactivate
|
||||
conda env remove --name $envname -y
|
||||
}
|
||||
|
||||
handle_int() {
|
||||
if [ -n $ENVNAME ]; then
|
||||
cleanup $ENVNAME
|
||||
fi
|
||||
exit 1
|
||||
}
|
||||
|
||||
handle_exit() {
|
||||
if [ $? -ne 0 ]; then
|
||||
echo -e "\033[1;31mABORTING.\033[0m"
|
||||
if [ -n $ENVNAME ]; then
|
||||
cleanup $ENVNAME
|
||||
fi
|
||||
fi
|
||||
}
|
||||
|
||||
setup_cleanup_handlers() {
|
||||
trap handle_int INT
|
||||
trap handle_exit EXIT
|
||||
|
||||
__conda_setup="$('conda' 'shell.bash' 'hook' 2>/dev/null)"
|
||||
eval "$__conda_setup"
|
||||
|
||||
conda deactivate
|
||||
}
|
110
llama_stack/distribution/configure.py
Normal file
110
llama_stack/distribution/configure.py
Normal file
|
@ -0,0 +1,110 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.distribution.distribution import api_providers, stack_apis
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
|
||||
from llama_stack.distribution.utils.prompt_for_config import prompt_for_config
|
||||
|
||||
|
||||
# These are hacks so we can re-use the `prompt_for_config` utility
|
||||
# This needs a bunch of work to be made very user friendly.
|
||||
class ReqApis(BaseModel):
|
||||
apis_to_serve: List[str]
|
||||
|
||||
|
||||
def make_routing_entry_type(config_class: Any):
|
||||
class BaseModelWithConfig(BaseModel):
|
||||
routing_key: str
|
||||
config: config_class
|
||||
|
||||
return BaseModelWithConfig
|
||||
|
||||
|
||||
# TODO: make sure we can deal with existing configuration values correctly
|
||||
# instead of just overwriting them
|
||||
def configure_api_providers(
|
||||
config: StackRunConfig, spec: DistributionSpec
|
||||
) -> StackRunConfig:
|
||||
cprint("Configuring APIs to serve...", "white", attrs=["bold"])
|
||||
print("Enter comma-separated list of APIs to serve:")
|
||||
|
||||
apis = config.apis_to_serve or list(spec.providers.keys())
|
||||
apis = [a for a in apis if a != "telemetry"]
|
||||
req_apis = ReqApis(
|
||||
apis_to_serve=apis,
|
||||
)
|
||||
req_apis = prompt_for_config(ReqApis, req_apis)
|
||||
config.apis_to_serve = req_apis.apis_to_serve
|
||||
print("")
|
||||
|
||||
apis = [v.value for v in stack_apis()]
|
||||
all_providers = api_providers()
|
||||
|
||||
for api_str in spec.providers.keys():
|
||||
if api_str not in apis:
|
||||
raise ValueError(f"Unknown API `{api_str}`")
|
||||
|
||||
cprint(f"Configuring API `{api_str}`...\n", "white", attrs=["bold"])
|
||||
api = Api(api_str)
|
||||
|
||||
provider_or_providers = spec.providers[api_str]
|
||||
if isinstance(provider_or_providers, list) and len(provider_or_providers) > 1:
|
||||
print(
|
||||
"You have specified multiple providers for this API. We will configure a routing table now. For each provider, provide a routing key followed by provider configuration.\n"
|
||||
)
|
||||
|
||||
routing_entries = []
|
||||
for p in provider_or_providers:
|
||||
print(f"Configuring provider `{p}`...")
|
||||
provider_spec = all_providers[api][p]
|
||||
config_type = instantiate_class_type(provider_spec.config_class)
|
||||
|
||||
# TODO: we need to validate the routing keys, and
|
||||
# perhaps it is better if we break this out into asking
|
||||
# for a routing key separately from the associated config
|
||||
wrapper_type = make_routing_entry_type(config_type)
|
||||
rt_entry = prompt_for_config(wrapper_type, None)
|
||||
|
||||
routing_entries.append(
|
||||
ProviderRoutingEntry(
|
||||
provider_id=p,
|
||||
routing_key=rt_entry.routing_key,
|
||||
config=rt_entry.config.dict(),
|
||||
)
|
||||
)
|
||||
config.provider_map[api_str] = routing_entries
|
||||
else:
|
||||
p = (
|
||||
provider_or_providers[0]
|
||||
if isinstance(provider_or_providers, list)
|
||||
else provider_or_providers
|
||||
)
|
||||
print(f"Configuring provider `{p}`...")
|
||||
provider_spec = all_providers[api][p]
|
||||
config_type = instantiate_class_type(provider_spec.config_class)
|
||||
try:
|
||||
provider_config = config.provider_map.get(api_str)
|
||||
if provider_config:
|
||||
existing = config_type(**provider_config.config)
|
||||
else:
|
||||
existing = None
|
||||
except Exception:
|
||||
existing = None
|
||||
cfg = prompt_for_config(config_type, existing)
|
||||
config.provider_map[api_str] = GenericProviderConfig(
|
||||
provider_id=p,
|
||||
config=cfg.dict(),
|
||||
)
|
||||
|
||||
return config
|
31
llama_stack/distribution/configure_container.sh
Executable file
31
llama_stack/distribution/configure_container.sh
Executable file
|
@ -0,0 +1,31 @@
|
|||
#!/bin/bash
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
error_handler() {
|
||||
echo "Error occurred in script at line: ${1}" >&2
|
||||
exit 1
|
||||
}
|
||||
|
||||
trap 'error_handler ${LINENO}' ERR
|
||||
|
||||
if [ $# -lt 2 ]; then
|
||||
echo "Usage: $0 <container name> <build file path>"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
docker_image="$1"
|
||||
host_build_dir="$2"
|
||||
container_build_dir="/app/builds"
|
||||
|
||||
set -x
|
||||
podman run -it \
|
||||
-v $host_build_dir:$container_build_dir \
|
||||
$docker_image \
|
||||
llama stack configure ./llamastack-build.yaml --output-dir $container_build_dir
|
7
llama_stack/distribution/control_plane/__init__.py
Normal file
7
llama_stack/distribution/control_plane/__init__.py
Normal file
|
@ -0,0 +1,7 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .control_plane import * # noqa: F401 F403
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
|
@ -0,0 +1,15 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import RedisImplConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: RedisImplConfig, _deps):
|
||||
from .redis import RedisControlPlaneAdapter
|
||||
|
||||
impl = RedisControlPlaneAdapter(config)
|
||||
await impl.initialize()
|
||||
return impl
|
|
@ -0,0 +1,19 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RedisImplConfig(BaseModel):
|
||||
url: str = Field(
|
||||
description="The URL for the Redis server",
|
||||
)
|
||||
namespace: Optional[str] = Field(
|
||||
default=None,
|
||||
description="All keys will be prefixed with this namespace",
|
||||
)
|
|
@ -0,0 +1,62 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from redis.asyncio import Redis
|
||||
|
||||
from llama_stack.apis.control_plane import * # noqa: F403
|
||||
|
||||
|
||||
from .config import RedisImplConfig
|
||||
|
||||
|
||||
class RedisControlPlaneAdapter(ControlPlane):
|
||||
def __init__(self, config: RedisImplConfig):
|
||||
self.config = config
|
||||
|
||||
async def initialize(self) -> None:
|
||||
self.redis = Redis.from_url(self.config.url)
|
||||
|
||||
def _namespaced_key(self, key: str) -> str:
|
||||
if not self.config.namespace:
|
||||
return key
|
||||
return f"{self.config.namespace}:{key}"
|
||||
|
||||
async def set(
|
||||
self, key: str, value: Any, expiration: Optional[datetime] = None
|
||||
) -> None:
|
||||
key = self._namespaced_key(key)
|
||||
await self.redis.set(key, value)
|
||||
if expiration:
|
||||
await self.redis.expireat(key, expiration)
|
||||
|
||||
async def get(self, key: str) -> Optional[ControlPlaneValue]:
|
||||
key = self._namespaced_key(key)
|
||||
value = await self.redis.get(key)
|
||||
if value is None:
|
||||
return None
|
||||
ttl = await self.redis.ttl(key)
|
||||
expiration = datetime.now() + timedelta(seconds=ttl) if ttl > 0 else None
|
||||
return ControlPlaneValue(key=key, value=value, expiration=expiration)
|
||||
|
||||
async def delete(self, key: str) -> None:
|
||||
key = self._namespaced_key(key)
|
||||
await self.redis.delete(key)
|
||||
|
||||
async def range(self, start_key: str, end_key: str) -> List[ControlPlaneValue]:
|
||||
start_key = self._namespaced_key(start_key)
|
||||
end_key = self._namespaced_key(end_key)
|
||||
|
||||
keys = await self.redis.keys(f"{start_key}*")
|
||||
result = []
|
||||
for key in keys:
|
||||
if key <= end_key:
|
||||
value = await self.get(key)
|
||||
if value:
|
||||
result.append(value)
|
||||
return result
|
|
@ -0,0 +1,15 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import SqliteControlPlaneConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: SqliteControlPlaneConfig, _deps):
|
||||
from .control_plane import SqliteControlPlane
|
||||
|
||||
impl = SqliteControlPlane(config)
|
||||
await impl.initialize()
|
||||
return impl
|
|
@ -0,0 +1,19 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class SqliteControlPlaneConfig(BaseModel):
|
||||
db_path: str = Field(
|
||||
description="File path for the sqlite database",
|
||||
)
|
||||
table_name: str = Field(
|
||||
default="llamastack_control_plane",
|
||||
description="Table into which all the keys will be placed",
|
||||
)
|
|
@ -0,0 +1,79 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import aiosqlite
|
||||
|
||||
from llama_stack.apis.control_plane import * # noqa: F403
|
||||
|
||||
|
||||
from .config import SqliteControlPlaneConfig
|
||||
|
||||
|
||||
class SqliteControlPlane(ControlPlane):
|
||||
def __init__(self, config: SqliteControlPlaneConfig):
|
||||
self.db_path = config.db_path
|
||||
self.table_name = config.table_name
|
||||
|
||||
async def initialize(self):
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
await db.execute(
|
||||
f"""
|
||||
CREATE TABLE IF NOT EXISTS {self.table_name} (
|
||||
key TEXT PRIMARY KEY,
|
||||
value TEXT,
|
||||
expiration TIMESTAMP
|
||||
)
|
||||
"""
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
async def set(
|
||||
self, key: str, value: Any, expiration: Optional[datetime] = None
|
||||
) -> None:
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
await db.execute(
|
||||
f"INSERT OR REPLACE INTO {self.table_name} (key, value, expiration) VALUES (?, ?, ?)",
|
||||
(key, json.dumps(value), expiration),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
async def get(self, key: str) -> Optional[ControlPlaneValue]:
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
async with db.execute(
|
||||
f"SELECT value, expiration FROM {self.table_name} WHERE key = ?", (key,)
|
||||
) as cursor:
|
||||
row = await cursor.fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
value, expiration = row
|
||||
return ControlPlaneValue(
|
||||
key=key, value=json.loads(value), expiration=expiration
|
||||
)
|
||||
|
||||
async def delete(self, key: str) -> None:
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
await db.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,))
|
||||
await db.commit()
|
||||
|
||||
async def range(self, start_key: str, end_key: str) -> List[ControlPlaneValue]:
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
async with db.execute(
|
||||
f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?",
|
||||
(start_key, end_key),
|
||||
) as cursor:
|
||||
result = []
|
||||
async for row in cursor:
|
||||
key, value, expiration = row
|
||||
result.append(
|
||||
ControlPlaneValue(
|
||||
key=key, value=json.loads(value), expiration=expiration
|
||||
)
|
||||
)
|
||||
return result
|
35
llama_stack/distribution/control_plane/api.py
Normal file
35
llama_stack/distribution/control_plane/api.py
Normal file
|
@ -0,0 +1,35 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, List, Optional, Protocol
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ControlPlaneValue(BaseModel):
|
||||
key: str
|
||||
value: Any
|
||||
expiration: Optional[datetime] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ControlPlane(Protocol):
|
||||
@webmethod(route="/control_plane/set")
|
||||
async def set(
|
||||
self, key: str, value: Any, expiration: Optional[datetime] = None
|
||||
) -> None: ...
|
||||
|
||||
@webmethod(route="/control_plane/get", method="GET")
|
||||
async def get(self, key: str) -> Optional[ControlPlaneValue]: ...
|
||||
|
||||
@webmethod(route="/control_plane/delete")
|
||||
async def delete(self, key: str) -> None: ...
|
||||
|
||||
@webmethod(route="/control_plane/range", method="GET")
|
||||
async def range(self, start_key: str, end_key: str) -> List[ControlPlaneValue]: ...
|
29
llama_stack/distribution/control_plane/registry.py
Normal file
29
llama_stack/distribution/control_plane/registry.py
Normal file
|
@ -0,0 +1,29 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import List
|
||||
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
|
||||
|
||||
def available_providers() -> List[ProviderSpec]:
|
||||
return [
|
||||
InlineProviderSpec(
|
||||
api=Api.control_plane,
|
||||
provider_id="sqlite",
|
||||
pip_packages=["aiosqlite"],
|
||||
module="llama_stack.providers.impls.sqlite.control_plane",
|
||||
config_class="llama_stack.providers.impls.sqlite.control_plane.SqliteControlPlaneConfig",
|
||||
),
|
||||
remote_provider_spec(
|
||||
Api.control_plane,
|
||||
AdapterSpec(
|
||||
adapter_id="redis",
|
||||
pip_packages=["redis"],
|
||||
module="llama_stack.providers.adapters.control_plane.redis",
|
||||
),
|
||||
),
|
||||
]
|
250
llama_stack/distribution/datatypes.py
Normal file
250
llama_stack/distribution/datatypes.py
Normal file
|
@ -0,0 +1,250 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
|
||||
from pydantic import BaseModel, Field, validator
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Api(Enum):
|
||||
inference = "inference"
|
||||
safety = "safety"
|
||||
agents = "agents"
|
||||
memory = "memory"
|
||||
telemetry = "telemetry"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ApiEndpoint(BaseModel):
|
||||
route: str
|
||||
method: str
|
||||
name: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ProviderSpec(BaseModel):
|
||||
api: Api
|
||||
provider_id: str
|
||||
config_class: str = Field(
|
||||
...,
|
||||
description="Fully-qualified classname of the config for this provider",
|
||||
)
|
||||
api_dependencies: List[Api] = Field(
|
||||
default_factory=list,
|
||||
description="Higher-level API surfaces may depend on other providers to provide their functionality",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RouterProviderSpec(ProviderSpec):
|
||||
provider_id: str = "router"
|
||||
config_class: str = ""
|
||||
|
||||
docker_image: Optional[str] = None
|
||||
|
||||
inner_specs: List[ProviderSpec]
|
||||
module: str = Field(
|
||||
...,
|
||||
description="""
|
||||
Fully-qualified name of the module to import. The module is expected to have:
|
||||
|
||||
- `get_router_impl(config, provider_specs, deps)`: returns the router implementation
|
||||
""",
|
||||
)
|
||||
|
||||
@property
|
||||
def pip_packages(self) -> List[str]:
|
||||
raise AssertionError("Should not be called on RouterProviderSpec")
|
||||
|
||||
|
||||
class GenericProviderConfig(BaseModel):
|
||||
provider_id: str
|
||||
config: Dict[str, Any]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AdapterSpec(BaseModel):
|
||||
adapter_id: str = Field(
|
||||
...,
|
||||
description="Unique identifier for this adapter",
|
||||
)
|
||||
module: str = Field(
|
||||
...,
|
||||
description="""
|
||||
Fully-qualified name of the module to import. The module is expected to have:
|
||||
|
||||
- `get_adapter_impl(config, deps)`: returns the adapter implementation
|
||||
""",
|
||||
)
|
||||
pip_packages: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="The pip dependencies needed for this implementation",
|
||||
)
|
||||
config_class: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Fully-qualified classname of the config for this provider",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class InlineProviderSpec(ProviderSpec):
|
||||
pip_packages: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="The pip dependencies needed for this implementation",
|
||||
)
|
||||
docker_image: Optional[str] = Field(
|
||||
default=None,
|
||||
description="""
|
||||
The docker image to use for this implementation. If one is provided, pip_packages will be ignored.
|
||||
If a provider depends on other providers, the dependencies MUST NOT specify a docker image.
|
||||
""",
|
||||
)
|
||||
module: str = Field(
|
||||
...,
|
||||
description="""
|
||||
Fully-qualified name of the module to import. The module is expected to have:
|
||||
|
||||
- `get_provider_impl(config, deps)`: returns the local implementation
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
class RemoteProviderConfig(BaseModel):
|
||||
url: str = Field(..., description="The URL for the provider")
|
||||
|
||||
@validator("url")
|
||||
@classmethod
|
||||
def validate_url(cls, url: str) -> str:
|
||||
if not url.startswith("http"):
|
||||
raise ValueError(f"URL must start with http: {url}")
|
||||
return url.rstrip("/")
|
||||
|
||||
|
||||
def remote_provider_id(adapter_id: str) -> str:
|
||||
return f"remote::{adapter_id}"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RemoteProviderSpec(ProviderSpec):
|
||||
adapter: Optional[AdapterSpec] = Field(
|
||||
default=None,
|
||||
description="""
|
||||
If some code is needed to convert the remote responses into Llama Stack compatible
|
||||
API responses, specify the adapter here. If not specified, it indicates the remote
|
||||
as being "Llama Stack compatible"
|
||||
""",
|
||||
)
|
||||
|
||||
@property
|
||||
def docker_image(self) -> Optional[str]:
|
||||
return None
|
||||
|
||||
@property
|
||||
def module(self) -> str:
|
||||
if self.adapter:
|
||||
return self.adapter.module
|
||||
return f"llama_stack.apis.{self.api.value}.client"
|
||||
|
||||
@property
|
||||
def pip_packages(self) -> List[str]:
|
||||
if self.adapter:
|
||||
return self.adapter.pip_packages
|
||||
return []
|
||||
|
||||
|
||||
# Can avoid this by using Pydantic computed_field
|
||||
def remote_provider_spec(
|
||||
api: Api, adapter: Optional[AdapterSpec] = None
|
||||
) -> RemoteProviderSpec:
|
||||
config_class = (
|
||||
adapter.config_class
|
||||
if adapter and adapter.config_class
|
||||
else "llama_stack.distribution.datatypes.RemoteProviderConfig"
|
||||
)
|
||||
provider_id = remote_provider_id(adapter.adapter_id) if adapter else "remote"
|
||||
|
||||
return RemoteProviderSpec(
|
||||
api=api, provider_id=provider_id, config_class=config_class, adapter=adapter
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class DistributionSpec(BaseModel):
|
||||
description: Optional[str] = Field(
|
||||
default="",
|
||||
description="Description of the distribution",
|
||||
)
|
||||
docker_image: Optional[str] = None
|
||||
providers: Dict[str, Union[str, List[str]]] = Field(
|
||||
default_factory=dict,
|
||||
description="""
|
||||
Provider Types for each of the APIs provided by this distribution. If you
|
||||
select multiple providers, you should provide an appropriate 'routing_map'
|
||||
in the runtime configuration to help route to the correct provider.""",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ProviderRoutingEntry(GenericProviderConfig):
|
||||
routing_key: str
|
||||
|
||||
|
||||
ProviderMapEntry = Union[GenericProviderConfig, List[ProviderRoutingEntry]]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class StackRunConfig(BaseModel):
|
||||
built_at: datetime
|
||||
|
||||
image_name: str = Field(
|
||||
...,
|
||||
description="""
|
||||
Reference to the distribution this package refers to. For unregistered (adhoc) packages,
|
||||
this could be just a hash
|
||||
""",
|
||||
)
|
||||
docker_image: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Reference to the docker image if this package refers to a container",
|
||||
)
|
||||
conda_env: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Reference to the conda environment if this package refers to a conda environment",
|
||||
)
|
||||
apis_to_serve: List[str] = Field(
|
||||
description="""
|
||||
The list of APIs to serve. If not specified, all APIs specified in the provider_map will be served""",
|
||||
)
|
||||
provider_map: Dict[str, ProviderMapEntry] = Field(
|
||||
description="""
|
||||
Provider configurations for each of the APIs provided by this package.
|
||||
|
||||
Given an API, you can specify a single provider or a "routing table". Each entry in the routing
|
||||
table has a (routing_key, provider_config) tuple. How the key is interpreted is API-specific.
|
||||
|
||||
As examples:
|
||||
- the "inference" API interprets the routing_key as a "model"
|
||||
- the "memory" API interprets the routing_key as the type of a "memory bank"
|
||||
|
||||
The key may support wild-cards alsothe routing_key to route to the correct provider.""",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class BuildConfig(BaseModel):
|
||||
name: str
|
||||
distribution_spec: DistributionSpec = Field(
|
||||
description="The distribution spec to build including API providers. "
|
||||
)
|
||||
image_type: str = Field(
|
||||
default="conda",
|
||||
description="Type of package to build (conda | container)",
|
||||
)
|
77
llama_stack/distribution/distribution.py
Normal file
77
llama_stack/distribution/distribution.py
Normal file
|
@ -0,0 +1,77 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
from typing import Dict, List
|
||||
|
||||
from llama_stack.apis.agents import Agents
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.memory import Memory
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.telemetry import Telemetry
|
||||
|
||||
from .datatypes import Api, ApiEndpoint, ProviderSpec, remote_provider_spec
|
||||
|
||||
# These are the dependencies needed by the distribution server.
|
||||
# `llama-stack` is automatically installed by the installation script.
|
||||
SERVER_DEPENDENCIES = [
|
||||
"fastapi",
|
||||
"fire",
|
||||
"uvicorn",
|
||||
]
|
||||
|
||||
|
||||
def stack_apis() -> List[Api]:
|
||||
return [v for v in Api]
|
||||
|
||||
|
||||
def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
|
||||
apis = {}
|
||||
|
||||
protocols = {
|
||||
Api.inference: Inference,
|
||||
Api.safety: Safety,
|
||||
Api.agents: Agents,
|
||||
Api.memory: Memory,
|
||||
Api.telemetry: Telemetry,
|
||||
}
|
||||
|
||||
for api, protocol in protocols.items():
|
||||
endpoints = []
|
||||
protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
|
||||
|
||||
for name, method in protocol_methods:
|
||||
if not hasattr(method, "__webmethod__"):
|
||||
continue
|
||||
|
||||
webmethod = method.__webmethod__
|
||||
route = webmethod.route
|
||||
|
||||
if webmethod.method == "GET":
|
||||
method = "get"
|
||||
elif webmethod.method == "DELETE":
|
||||
method = "delete"
|
||||
else:
|
||||
method = "post"
|
||||
endpoints.append(ApiEndpoint(route=route, method=method, name=name))
|
||||
|
||||
apis[api] = endpoints
|
||||
|
||||
return apis
|
||||
|
||||
|
||||
def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]:
|
||||
ret = {}
|
||||
for api in stack_apis():
|
||||
name = api.name.lower()
|
||||
module = importlib.import_module(f"llama_stack.providers.registry.{name}")
|
||||
ret[api] = {
|
||||
"remote": remote_provider_spec(api),
|
||||
**{a.provider_id: a for a in module.available_providers()},
|
||||
}
|
||||
|
||||
return ret
|
|
@ -0,0 +1,10 @@
|
|||
name: local-conda-example
|
||||
distribution_spec:
|
||||
description: Use code from `llama_stack` itself to serve all llama stack APIs
|
||||
providers:
|
||||
inference: meta-reference
|
||||
memory: meta-reference
|
||||
safety: meta-reference
|
||||
agents: meta-reference
|
||||
telemetry: meta-reference
|
||||
image_type: conda
|
|
@ -0,0 +1,10 @@
|
|||
name: local-fireworks-conda-example
|
||||
distribution_spec:
|
||||
description: Use Fireworks.ai for running LLM inference
|
||||
providers:
|
||||
inference: remote::fireworks
|
||||
memory: meta-reference
|
||||
safety: meta-reference
|
||||
agents: meta-reference
|
||||
telemetry: meta-reference
|
||||
image_type: conda
|
|
@ -0,0 +1,10 @@
|
|||
name: local-ollama-conda-example
|
||||
distribution_spec:
|
||||
description: Like local, but use ollama for running LLM inference
|
||||
providers:
|
||||
inference: remote::ollama
|
||||
memory: meta-reference
|
||||
safety: meta-reference
|
||||
agents: meta-reference
|
||||
telemetry: meta-reference
|
||||
image_type: conda
|
|
@ -0,0 +1,10 @@
|
|||
name: local-tgi-conda-example
|
||||
distribution_spec:
|
||||
description: Use TGI (local or with Hugging Face Inference Endpoints for running LLM inference. When using HF Inference Endpoints, you must provide the name of the endpoint).
|
||||
providers:
|
||||
inference: remote::tgi
|
||||
memory: meta-reference
|
||||
safety: meta-reference
|
||||
agents: meta-reference
|
||||
telemetry: meta-reference
|
||||
image_type: conda
|
|
@ -0,0 +1,10 @@
|
|||
name: local-tgi-conda-example
|
||||
distribution_spec:
|
||||
description: Use Together.ai for running LLM inference
|
||||
providers:
|
||||
inference: remote::together
|
||||
memory: meta-reference
|
||||
safety: meta-reference
|
||||
agents: meta-reference
|
||||
telemetry: meta-reference
|
||||
image_type: conda
|
|
@ -0,0 +1,10 @@
|
|||
name: local-docker-example
|
||||
distribution_spec:
|
||||
description: Use code from `llama_stack` itself to serve all llama stack APIs
|
||||
providers:
|
||||
inference: meta-reference
|
||||
memory: meta-reference
|
||||
safety: meta-reference
|
||||
agents: meta-reference
|
||||
telemetry: meta-reference
|
||||
image_type: docker
|
5
llama_stack/distribution/server/__init__.py
Normal file
5
llama_stack/distribution/server/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
392
llama_stack/distribution/server/server.py
Normal file
392
llama_stack/distribution/server/server.py
Normal file
|
@ -0,0 +1,392 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import signal
|
||||
import traceback
|
||||
|
||||
from collections.abc import (
|
||||
AsyncGenerator as AsyncGeneratorABC,
|
||||
AsyncIterator as AsyncIteratorABC,
|
||||
)
|
||||
from contextlib import asynccontextmanager
|
||||
from ssl import SSLError
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
AsyncIterator,
|
||||
Dict,
|
||||
get_type_hints,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
)
|
||||
|
||||
import fire
|
||||
import httpx
|
||||
import yaml
|
||||
|
||||
from fastapi import Body, FastAPI, HTTPException, Request, Response
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from fastapi.routing import APIRoute
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from termcolor import cprint
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.providers.utils.telemetry.tracing import (
|
||||
end_trace,
|
||||
setup_logger,
|
||||
SpanStatus,
|
||||
start_trace,
|
||||
)
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
|
||||
from llama_stack.distribution.distribution import api_endpoints, api_providers
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_provider
|
||||
|
||||
|
||||
def is_async_iterator_type(typ):
|
||||
if hasattr(typ, "__origin__"):
|
||||
origin = typ.__origin__
|
||||
if isinstance(origin, type):
|
||||
return issubclass(
|
||||
origin,
|
||||
(AsyncIterator, AsyncGenerator, AsyncIteratorABC, AsyncGeneratorABC),
|
||||
)
|
||||
return False
|
||||
return isinstance(
|
||||
typ, (AsyncIterator, AsyncGenerator, AsyncIteratorABC, AsyncGeneratorABC)
|
||||
)
|
||||
|
||||
|
||||
def create_sse_event(data: Any) -> str:
|
||||
if isinstance(data, BaseModel):
|
||||
data = data.json()
|
||||
else:
|
||||
data = json.dumps(data)
|
||||
|
||||
return f"data: {data}\n\n"
|
||||
|
||||
|
||||
async def global_exception_handler(request: Request, exc: Exception):
|
||||
traceback.print_exception(exc)
|
||||
http_exc = translate_exception(exc)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=http_exc.status_code, content={"error": {"detail": http_exc.detail}}
|
||||
)
|
||||
|
||||
|
||||
def translate_exception(exc: Exception) -> HTTPException:
|
||||
if isinstance(exc, ValidationError):
|
||||
return RequestValidationError(exc.raw_errors)
|
||||
|
||||
# Add more custom exception translations here
|
||||
return HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
async def passthrough(
|
||||
request: Request,
|
||||
downstream_url: str,
|
||||
downstream_headers: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
await start_trace(request.path, {"downstream_url": downstream_url})
|
||||
|
||||
headers = dict(request.headers)
|
||||
headers.pop("host", None)
|
||||
headers.update(downstream_headers or {})
|
||||
|
||||
content = await request.body()
|
||||
|
||||
client = httpx.AsyncClient()
|
||||
erred = False
|
||||
try:
|
||||
req = client.build_request(
|
||||
method=request.method,
|
||||
url=downstream_url,
|
||||
headers=headers,
|
||||
content=content,
|
||||
params=request.query_params,
|
||||
)
|
||||
response = await client.send(req, stream=True)
|
||||
|
||||
async def stream_response():
|
||||
async for chunk in response.aiter_raw(chunk_size=64):
|
||||
yield chunk
|
||||
|
||||
await response.aclose()
|
||||
await client.aclose()
|
||||
|
||||
return StreamingResponse(
|
||||
stream_response(),
|
||||
status_code=response.status_code,
|
||||
headers=dict(response.headers),
|
||||
media_type=response.headers.get("content-type"),
|
||||
)
|
||||
|
||||
except httpx.ReadTimeout:
|
||||
erred = True
|
||||
return Response(content="Downstream server timed out", status_code=504)
|
||||
except httpx.NetworkError as e:
|
||||
erred = True
|
||||
return Response(content=f"Network error: {str(e)}", status_code=502)
|
||||
except httpx.TooManyRedirects:
|
||||
erred = True
|
||||
return Response(content="Too many redirects", status_code=502)
|
||||
except SSLError as e:
|
||||
erred = True
|
||||
return Response(content=f"SSL error: {str(e)}", status_code=502)
|
||||
except httpx.HTTPStatusError as e:
|
||||
erred = True
|
||||
return Response(content=str(e), status_code=e.response.status_code)
|
||||
except Exception as e:
|
||||
erred = True
|
||||
return Response(content=f"Unexpected error: {str(e)}", status_code=500)
|
||||
finally:
|
||||
await end_trace(SpanStatus.OK if not erred else SpanStatus.ERROR)
|
||||
|
||||
|
||||
def handle_sigint(*args, **kwargs):
|
||||
print("SIGINT or CTRL-C detected. Exiting gracefully...")
|
||||
loop = asyncio.get_event_loop()
|
||||
for task in asyncio.all_tasks(loop):
|
||||
task.cancel()
|
||||
loop.stop()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
print("Starting up")
|
||||
yield
|
||||
print("Shutting down")
|
||||
|
||||
|
||||
def create_dynamic_passthrough(
|
||||
downstream_url: str, downstream_headers: Optional[Dict[str, str]] = None
|
||||
):
|
||||
async def endpoint(request: Request):
|
||||
return await passthrough(request, downstream_url, downstream_headers)
|
||||
|
||||
return endpoint
|
||||
|
||||
|
||||
def create_dynamic_typed_route(func: Any, method: str):
|
||||
hints = get_type_hints(func)
|
||||
response_model = hints.get("return")
|
||||
|
||||
# NOTE: I think it is better to just add a method within each Api
|
||||
# "Protocol" / adapter-impl to tell what sort of a response this request
|
||||
# is going to produce. /chat_completion can produce a streaming or
|
||||
# non-streaming response depending on if request.stream is True / False.
|
||||
is_streaming = is_async_iterator_type(response_model)
|
||||
|
||||
if is_streaming:
|
||||
|
||||
async def endpoint(**kwargs):
|
||||
await start_trace(func.__name__)
|
||||
|
||||
async def sse_generator(event_gen):
|
||||
try:
|
||||
async for item in event_gen:
|
||||
yield create_sse_event(item)
|
||||
await asyncio.sleep(0.01)
|
||||
except asyncio.CancelledError:
|
||||
print("Generator cancelled")
|
||||
await event_gen.aclose()
|
||||
except Exception as e:
|
||||
traceback.print_exception(e)
|
||||
yield create_sse_event(
|
||||
{
|
||||
"error": {
|
||||
"message": str(translate_exception(e)),
|
||||
},
|
||||
}
|
||||
)
|
||||
finally:
|
||||
await end_trace()
|
||||
|
||||
return StreamingResponse(
|
||||
sse_generator(func(**kwargs)), media_type="text/event-stream"
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
async def endpoint(**kwargs):
|
||||
await start_trace(func.__name__)
|
||||
try:
|
||||
return (
|
||||
await func(**kwargs)
|
||||
if asyncio.iscoroutinefunction(func)
|
||||
else func(**kwargs)
|
||||
)
|
||||
except Exception as e:
|
||||
traceback.print_exception(e)
|
||||
raise translate_exception(e) from e
|
||||
finally:
|
||||
await end_trace()
|
||||
|
||||
sig = inspect.signature(func)
|
||||
if method == "post":
|
||||
# make sure every parameter is annotated with Body() so FASTAPI doesn't
|
||||
# do anything too intelligent and ask for some parameters in the query
|
||||
# and some in the body
|
||||
endpoint.__signature__ = sig.replace(
|
||||
parameters=[
|
||||
param.replace(
|
||||
annotation=Annotated[param.annotation, Body(..., embed=True)]
|
||||
)
|
||||
for param in sig.parameters.values()
|
||||
]
|
||||
)
|
||||
else:
|
||||
endpoint.__signature__ = sig
|
||||
|
||||
return endpoint
|
||||
|
||||
|
||||
def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
|
||||
by_id = {x.api: x for x in providers}
|
||||
|
||||
def dfs(a: ProviderSpec, visited: Set[Api], stack: List[Api]):
|
||||
visited.add(a.api)
|
||||
|
||||
for api in a.api_dependencies:
|
||||
if api not in visited:
|
||||
dfs(by_id[api], visited, stack)
|
||||
|
||||
stack.append(a.api)
|
||||
|
||||
visited = set()
|
||||
stack = []
|
||||
|
||||
for a in providers:
|
||||
if a.api not in visited:
|
||||
dfs(a, visited, stack)
|
||||
|
||||
return [by_id[x] for x in stack]
|
||||
|
||||
|
||||
def snake_to_camel(snake_str):
|
||||
return "".join(word.capitalize() for word in snake_str.split("_"))
|
||||
|
||||
|
||||
async def resolve_impls(
|
||||
provider_map: Dict[str, ProviderMapEntry],
|
||||
) -> Dict[Api, Any]:
|
||||
"""
|
||||
Does two things:
|
||||
- flatmaps, sorts and resolves the providers in dependency order
|
||||
- for each API, produces either a (local, passthrough or router) implementation
|
||||
"""
|
||||
all_providers = api_providers()
|
||||
|
||||
specs = {}
|
||||
for api_str, item in provider_map.items():
|
||||
api = Api(api_str)
|
||||
providers = all_providers[api]
|
||||
|
||||
if isinstance(item, GenericProviderConfig):
|
||||
if item.provider_id not in providers:
|
||||
raise ValueError(
|
||||
f"Unknown provider `{provider_id}` is not available for API `{api}`"
|
||||
)
|
||||
specs[api] = providers[item.provider_id]
|
||||
else:
|
||||
assert isinstance(item, list)
|
||||
inner_specs = []
|
||||
for rt_entry in item:
|
||||
if rt_entry.provider_id not in providers:
|
||||
raise ValueError(
|
||||
f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`"
|
||||
)
|
||||
inner_specs.append(providers[rt_entry.provider_id])
|
||||
|
||||
specs[api] = RouterProviderSpec(
|
||||
api=api,
|
||||
module=f"llama_stack.providers.routers.{api.value.lower()}",
|
||||
api_dependencies=[],
|
||||
inner_specs=inner_specs,
|
||||
)
|
||||
|
||||
sorted_specs = topological_sort(specs.values())
|
||||
|
||||
impls = {}
|
||||
for spec in sorted_specs:
|
||||
api = spec.api
|
||||
|
||||
deps = {api: impls[api] for api in spec.api_dependencies}
|
||||
impl = await instantiate_provider(spec, deps, provider_map[api.value])
|
||||
impls[api] = impl
|
||||
|
||||
return impls, specs
|
||||
|
||||
|
||||
def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
|
||||
with open(yaml_config, "r") as fp:
|
||||
config = StackRunConfig(**yaml.safe_load(fp))
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
impls, specs = asyncio.run(resolve_impls(config.provider_map))
|
||||
if Api.telemetry in impls:
|
||||
setup_logger(impls[Api.telemetry])
|
||||
|
||||
all_endpoints = api_endpoints()
|
||||
|
||||
apis_to_serve = config.apis_to_serve or list(config.provider_map.keys())
|
||||
for api_str in apis_to_serve:
|
||||
api = Api(api_str)
|
||||
endpoints = all_endpoints[api]
|
||||
impl = impls[api]
|
||||
|
||||
provider_spec = specs[api]
|
||||
if (
|
||||
isinstance(provider_spec, RemoteProviderSpec)
|
||||
and provider_spec.adapter is None
|
||||
):
|
||||
for endpoint in endpoints:
|
||||
url = impl.__provider_config__.url.rstrip("/") + endpoint.route
|
||||
getattr(app, endpoint.method)(endpoint.route)(
|
||||
create_dynamic_passthrough(url)
|
||||
)
|
||||
else:
|
||||
for endpoint in endpoints:
|
||||
if not hasattr(impl, endpoint.name):
|
||||
# ideally this should be a typing violation already
|
||||
raise ValueError(
|
||||
f"Could not find method {endpoint.name} on {impl}!!"
|
||||
)
|
||||
|
||||
impl_method = getattr(impl, endpoint.name)
|
||||
getattr(app, endpoint.method)(endpoint.route, response_model=None)(
|
||||
create_dynamic_typed_route(impl_method, endpoint.method)
|
||||
)
|
||||
|
||||
for route in app.routes:
|
||||
if isinstance(route, APIRoute):
|
||||
cprint(
|
||||
f"Serving {next(iter(route.methods))} {route.path}",
|
||||
"white",
|
||||
attrs=["bold"],
|
||||
)
|
||||
|
||||
app.exception_handler(RequestValidationError)(global_exception_handler)
|
||||
app.exception_handler(Exception)(global_exception_handler)
|
||||
signal.signal(signal.SIGINT, handle_sigint)
|
||||
|
||||
import uvicorn
|
||||
|
||||
# FYI this does not do hot-reloads
|
||||
listen_host = "::" if not disable_ipv6 else "0.0.0.0"
|
||||
print(f"Listening on {listen_host}:{port}")
|
||||
uvicorn.run(app, host=listen_host, port=port)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(main)
|
42
llama_stack/distribution/start_conda_env.sh
Executable file
42
llama_stack/distribution/start_conda_env.sh
Executable file
|
@ -0,0 +1,42 @@
|
|||
#!/bin/bash
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
RED='\033[0;31m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
error_handler() {
|
||||
echo "Error occurred in script at line: ${1}" >&2
|
||||
exit 1
|
||||
}
|
||||
|
||||
trap 'error_handler ${LINENO}' ERR
|
||||
|
||||
if [ $# -lt 3 ]; then
|
||||
echo "Usage: $0 <build_name> <yaml_config> <port> <script_args...>"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
build_name="$1"
|
||||
env_name="llamastack-$build_name"
|
||||
shift
|
||||
|
||||
yaml_config="$1"
|
||||
shift
|
||||
|
||||
port="$1"
|
||||
shift
|
||||
|
||||
eval "$(conda shell.bash hook)"
|
||||
conda deactivate && conda activate "$env_name"
|
||||
|
||||
$CONDA_PREFIX/bin/python \
|
||||
-m llama_stack.distribution.server.server \
|
||||
--yaml_config "$yaml_config" \
|
||||
--port "$port" "$@"
|
43
llama_stack/distribution/start_container.sh
Executable file
43
llama_stack/distribution/start_container.sh
Executable file
|
@ -0,0 +1,43 @@
|
|||
#!/bin/bash
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
RED='\033[0;31m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
error_handler() {
|
||||
echo "Error occurred in script at line: ${1}" >&2
|
||||
exit 1
|
||||
}
|
||||
|
||||
trap 'error_handler ${LINENO}' ERR
|
||||
|
||||
if [ $# -lt 3 ]; then
|
||||
echo "Usage: $0 <build_name> <yaml_config> <port> <other_args...>"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
build_name="$1"
|
||||
docker_image="llamastack-$build_name"
|
||||
shift
|
||||
|
||||
yaml_config="$1"
|
||||
shift
|
||||
|
||||
port="$1"
|
||||
shift
|
||||
|
||||
set -x
|
||||
podman run -it \
|
||||
-p $port:$port \
|
||||
-v "$yaml_config:/app/config.yaml" \
|
||||
$docker_image \
|
||||
python -m llama_stack.distribution.server.server \
|
||||
--yaml_config /app/config.yaml \
|
||||
--port $port "$@"
|
5
llama_stack/distribution/utils/__init__.py
Normal file
5
llama_stack/distribution/utils/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
17
llama_stack/distribution/utils/config_dirs.py
Normal file
17
llama_stack/distribution/utils/config_dirs.py
Normal file
|
@ -0,0 +1,17 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
LLAMA_STACK_CONFIG_DIR = Path(os.path.expanduser("~/.llama/"))
|
||||
|
||||
DISTRIBS_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "distributions"
|
||||
|
||||
DEFAULT_CHECKPOINT_DIR = LLAMA_STACK_CONFIG_DIR / "checkpoints"
|
||||
|
||||
BUILDS_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "builds"
|
66
llama_stack/distribution/utils/dynamic.py
Normal file
66
llama_stack/distribution/utils/dynamic.py
Normal file
|
@ -0,0 +1,66 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import importlib
|
||||
from typing import Any, Dict
|
||||
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
|
||||
|
||||
def instantiate_class_type(fully_qualified_name):
|
||||
module_name, class_name = fully_qualified_name.rsplit(".", 1)
|
||||
module = importlib.import_module(module_name)
|
||||
return getattr(module, class_name)
|
||||
|
||||
|
||||
# returns a class implementing the protocol corresponding to the Api
|
||||
async def instantiate_provider(
|
||||
provider_spec: ProviderSpec,
|
||||
deps: Dict[str, Any],
|
||||
provider_config: ProviderMapEntry,
|
||||
):
|
||||
module = importlib.import_module(provider_spec.module)
|
||||
|
||||
args = []
|
||||
if isinstance(provider_spec, RemoteProviderSpec):
|
||||
if provider_spec.adapter:
|
||||
method = "get_adapter_impl"
|
||||
else:
|
||||
method = "get_client_impl"
|
||||
|
||||
assert isinstance(provider_config, GenericProviderConfig)
|
||||
config_type = instantiate_class_type(provider_spec.config_class)
|
||||
config = config_type(**provider_config.config)
|
||||
args = [config, deps]
|
||||
elif isinstance(provider_spec, RouterProviderSpec):
|
||||
method = "get_router_impl"
|
||||
|
||||
assert isinstance(provider_config, list)
|
||||
inner_specs = {x.provider_id: x for x in provider_spec.inner_specs}
|
||||
inner_impls = []
|
||||
for routing_entry in provider_config:
|
||||
impl = await instantiate_provider(
|
||||
inner_specs[routing_entry.provider_id],
|
||||
deps,
|
||||
routing_entry,
|
||||
)
|
||||
inner_impls.append((routing_entry.routing_key, impl))
|
||||
|
||||
config = None
|
||||
args = [inner_impls, deps]
|
||||
else:
|
||||
method = "get_provider_impl"
|
||||
|
||||
assert isinstance(provider_config, GenericProviderConfig)
|
||||
config_type = instantiate_class_type(provider_spec.config_class)
|
||||
config = config_type(**provider_config.config)
|
||||
args = [config, deps]
|
||||
|
||||
fn = getattr(module, method)
|
||||
impl = await fn(*args)
|
||||
impl.__provider_spec__ = provider_spec
|
||||
impl.__provider_config__ = config
|
||||
return impl
|
105
llama_stack/distribution/utils/exec.py
Normal file
105
llama_stack/distribution/utils/exec.py
Normal file
|
@ -0,0 +1,105 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import errno
|
||||
import os
|
||||
import pty
|
||||
import select
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import termios
|
||||
|
||||
from termcolor import cprint
|
||||
|
||||
|
||||
# run a command in a pseudo-terminal, with interrupt handling,
|
||||
# useful when you want to run interactive things
|
||||
def run_with_pty(command):
|
||||
master, slave = pty.openpty()
|
||||
|
||||
old_settings = termios.tcgetattr(sys.stdin)
|
||||
original_sigint = signal.getsignal(signal.SIGINT)
|
||||
|
||||
ctrl_c_pressed = False
|
||||
|
||||
def sigint_handler(signum, frame):
|
||||
nonlocal ctrl_c_pressed
|
||||
ctrl_c_pressed = True
|
||||
cprint("\nCtrl-C detected. Aborting...", "white", attrs=["bold"])
|
||||
|
||||
try:
|
||||
# Set up the signal handler
|
||||
signal.signal(signal.SIGINT, sigint_handler)
|
||||
|
||||
new_settings = termios.tcgetattr(sys.stdin)
|
||||
new_settings[3] = new_settings[3] & ~termios.ECHO # Disable echo
|
||||
new_settings[3] = new_settings[3] & ~termios.ICANON # Disable canonical mode
|
||||
termios.tcsetattr(sys.stdin, termios.TCSADRAIN, new_settings)
|
||||
|
||||
process = subprocess.Popen(
|
||||
command,
|
||||
stdin=slave,
|
||||
stdout=slave,
|
||||
stderr=slave,
|
||||
universal_newlines=True,
|
||||
preexec_fn=os.setsid,
|
||||
)
|
||||
|
||||
# Close the slave file descriptor as it's now owned by the subprocess
|
||||
os.close(slave)
|
||||
|
||||
def handle_io():
|
||||
while not ctrl_c_pressed:
|
||||
try:
|
||||
rlist, _, _ = select.select([sys.stdin, master], [], [], 0.1)
|
||||
|
||||
if sys.stdin in rlist:
|
||||
data = os.read(sys.stdin.fileno(), 1024)
|
||||
if not data:
|
||||
break
|
||||
os.write(master, data)
|
||||
|
||||
if master in rlist:
|
||||
data = os.read(master, 1024)
|
||||
if not data:
|
||||
break
|
||||
sys.stdout.buffer.write(data)
|
||||
sys.stdout.flush()
|
||||
|
||||
except KeyboardInterrupt:
|
||||
# This will be raised when Ctrl+C is pressed
|
||||
break
|
||||
|
||||
if process.poll() is not None:
|
||||
break
|
||||
|
||||
handle_io()
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
pass
|
||||
except OSError as e:
|
||||
if e.errno != errno.EIO:
|
||||
raise
|
||||
finally:
|
||||
# Clean up
|
||||
termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_settings)
|
||||
signal.signal(signal.SIGINT, original_sigint)
|
||||
|
||||
os.close(master)
|
||||
if process.poll() is None:
|
||||
process.terminate()
|
||||
process.wait()
|
||||
|
||||
return process.returncode
|
||||
|
||||
|
||||
def run_command(command):
|
||||
process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
output, error = process.communicate()
|
||||
if process.returncode != 0:
|
||||
print(f"Error: {error.decode('utf-8')}")
|
||||
sys.exit(1)
|
||||
return output.decode("utf-8")
|
13
llama_stack/distribution/utils/model_utils.py
Normal file
13
llama_stack/distribution/utils/model_utils.py
Normal file
|
@ -0,0 +1,13 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
|
||||
from .config_dirs import DEFAULT_CHECKPOINT_DIR
|
||||
|
||||
|
||||
def model_local_dir(descriptor: str) -> str:
|
||||
return os.path.join(DEFAULT_CHECKPOINT_DIR, descriptor)
|
309
llama_stack/distribution/utils/prompt_for_config.py
Normal file
309
llama_stack/distribution/utils/prompt_for_config.py
Normal file
|
@ -0,0 +1,309 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import inspect
|
||||
import json
|
||||
from enum import Enum
|
||||
|
||||
from typing import Any, get_args, get_origin, List, Literal, Optional, Type, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic.fields import FieldInfo
|
||||
from pydantic_core import PydanticUndefinedType
|
||||
|
||||
from typing_extensions import Annotated
|
||||
|
||||
|
||||
def is_list_of_primitives(field_type):
|
||||
"""Check if a field type is a List of primitive types."""
|
||||
origin = get_origin(field_type)
|
||||
if origin is List or origin is list:
|
||||
args = get_args(field_type)
|
||||
if len(args) == 1 and args[0] in (int, float, str, bool):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def is_basemodel_without_fields(typ):
|
||||
return (
|
||||
inspect.isclass(typ) and issubclass(typ, BaseModel) and len(typ.__fields__) == 0
|
||||
)
|
||||
|
||||
|
||||
def can_recurse(typ):
|
||||
return (
|
||||
inspect.isclass(typ) and issubclass(typ, BaseModel) and len(typ.__fields__) > 0
|
||||
)
|
||||
|
||||
|
||||
def get_literal_values(field):
|
||||
"""Extract literal values from a field if it's a Literal type."""
|
||||
if get_origin(field.annotation) is Literal:
|
||||
return get_args(field.annotation)
|
||||
return None
|
||||
|
||||
|
||||
def is_optional(field_type):
|
||||
"""Check if a field type is Optional."""
|
||||
return get_origin(field_type) is Union and type(None) in get_args(field_type)
|
||||
|
||||
|
||||
def get_non_none_type(field_type):
|
||||
"""Get the non-None type from an Optional type."""
|
||||
return next(arg for arg in get_args(field_type) if arg is not type(None))
|
||||
|
||||
|
||||
def manually_validate_field(model: Type[BaseModel], field_name: str, value: Any):
|
||||
validators = model.__pydantic_decorators__.field_validators
|
||||
for _name, validator in validators.items():
|
||||
if field_name in validator.info.fields:
|
||||
validator.func(value)
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def is_discriminated_union(typ) -> bool:
|
||||
if isinstance(typ, FieldInfo):
|
||||
return typ.discriminator
|
||||
else:
|
||||
if not (get_origin(typ) is Annotated):
|
||||
return False
|
||||
args = get_args(typ)
|
||||
return len(args) >= 2 and args[1].discriminator
|
||||
|
||||
|
||||
def prompt_for_discriminated_union(
|
||||
field_name,
|
||||
typ,
|
||||
existing_value,
|
||||
):
|
||||
if isinstance(typ, FieldInfo):
|
||||
inner_type = typ.annotation
|
||||
discriminator = typ.discriminator
|
||||
else:
|
||||
args = get_args(typ)
|
||||
inner_type = args[0]
|
||||
discriminator = args[1].discriminator
|
||||
|
||||
union_types = get_args(inner_type)
|
||||
# Find the discriminator field in each union type
|
||||
type_map = {}
|
||||
for t in union_types:
|
||||
disc_field = t.__fields__[discriminator]
|
||||
literal_values = get_literal_values(disc_field)
|
||||
if literal_values:
|
||||
for value in literal_values:
|
||||
type_map[value] = t
|
||||
|
||||
while True:
|
||||
discriminator_value = input(
|
||||
f"Enter `{discriminator}` for {field_name} (options: {', '.join(type_map.keys())}): "
|
||||
)
|
||||
if discriminator_value in type_map:
|
||||
chosen_type = type_map[discriminator_value]
|
||||
print(f"\nConfiguring {chosen_type.__name__}:")
|
||||
|
||||
if existing_value and (
|
||||
getattr(existing_value, discriminator) != discriminator_value
|
||||
):
|
||||
existing_value = None
|
||||
|
||||
sub_config = prompt_for_config(chosen_type, existing_value)
|
||||
# Set the discriminator field in the sub-config
|
||||
setattr(sub_config, discriminator, discriminator_value)
|
||||
return sub_config
|
||||
else:
|
||||
print(f"Invalid {discriminator}. Please try again.")
|
||||
|
||||
|
||||
# This is somewhat elaborate, but does not purport to be comprehensive in any way.
|
||||
# We should add handling for the most common cases to tide us over.
|
||||
#
|
||||
# doesn't support List[nested_class] yet or Dicts of any kind. needs a bunch of
|
||||
# unit tests for coverage.
|
||||
def prompt_for_config(
|
||||
config_type: type[BaseModel], existing_config: Optional[BaseModel] = None
|
||||
) -> BaseModel:
|
||||
"""
|
||||
Recursively prompt the user for configuration values based on a Pydantic BaseModel.
|
||||
|
||||
Args:
|
||||
config_type: A Pydantic BaseModel class representing the configuration structure.
|
||||
|
||||
Returns:
|
||||
An instance of the config_type with user-provided values.
|
||||
"""
|
||||
config_data = {}
|
||||
|
||||
for field_name, field in config_type.__fields__.items():
|
||||
field_type = field.annotation
|
||||
existing_value = (
|
||||
getattr(existing_config, field_name) if existing_config else None
|
||||
)
|
||||
if existing_value:
|
||||
default_value = existing_value
|
||||
else:
|
||||
default_value = (
|
||||
field.default
|
||||
if not isinstance(field.default, PydanticUndefinedType)
|
||||
else None
|
||||
)
|
||||
is_required = field.is_required
|
||||
|
||||
# Skip fields with Literal type
|
||||
if get_origin(field_type) is Literal:
|
||||
continue
|
||||
|
||||
# Skip fields with no type annotations
|
||||
if is_basemodel_without_fields(field_type):
|
||||
config_data[field_name] = field_type()
|
||||
continue
|
||||
|
||||
if inspect.isclass(field_type) and issubclass(field_type, Enum):
|
||||
prompt = f"Choose {field_name} (options: {', '.join(e.name for e in field_type)}):"
|
||||
while True:
|
||||
# this branch does not handle existing and default values yet
|
||||
user_input = input(prompt + " ")
|
||||
try:
|
||||
value = field_type[user_input]
|
||||
validated_value = manually_validate_field(config_type, field, value)
|
||||
config_data[field_name] = validated_value
|
||||
break
|
||||
except KeyError:
|
||||
print(
|
||||
f"Invalid choice. Please choose from: {', '.join(e.name for e in field_type)}"
|
||||
)
|
||||
continue
|
||||
|
||||
if is_discriminated_union(field):
|
||||
config_data[field_name] = prompt_for_discriminated_union(
|
||||
field_name, field, existing_value
|
||||
)
|
||||
continue
|
||||
|
||||
if is_optional(field_type) and can_recurse(get_non_none_type(field_type)):
|
||||
prompt = f"Do you want to configure {field_name}? (y/n): "
|
||||
if input(prompt).lower() == "n":
|
||||
config_data[field_name] = None
|
||||
continue
|
||||
nested_type = get_non_none_type(field_type)
|
||||
print(f"Entering sub-configuration for {field_name}:")
|
||||
config_data[field_name] = prompt_for_config(nested_type, existing_value)
|
||||
elif is_optional(field_type) and is_discriminated_union(
|
||||
get_non_none_type(field_type)
|
||||
):
|
||||
prompt = f"Do you want to configure {field_name}? (y/n): "
|
||||
if input(prompt).lower() == "n":
|
||||
config_data[field_name] = None
|
||||
continue
|
||||
nested_type = get_non_none_type(field_type)
|
||||
config_data[field_name] = prompt_for_discriminated_union(
|
||||
field_name,
|
||||
nested_type,
|
||||
existing_value,
|
||||
)
|
||||
elif can_recurse(field_type):
|
||||
print(f"\nEntering sub-configuration for {field_name}:")
|
||||
config_data[field_name] = prompt_for_config(
|
||||
field_type,
|
||||
existing_value,
|
||||
)
|
||||
else:
|
||||
prompt = f"Enter value for {field_name}"
|
||||
if existing_value is not None:
|
||||
prompt += f" (existing: {existing_value})"
|
||||
elif default_value is not None:
|
||||
prompt += f" (default: {default_value})"
|
||||
if is_optional(field_type):
|
||||
prompt += " (optional)"
|
||||
elif is_required:
|
||||
prompt += " (required)"
|
||||
prompt += ": "
|
||||
|
||||
while True:
|
||||
user_input = input(prompt)
|
||||
if user_input == "":
|
||||
if default_value is not None:
|
||||
config_data[field_name] = default_value
|
||||
break
|
||||
elif is_optional(field_type) or not is_required:
|
||||
config_data[field_name] = None
|
||||
break
|
||||
else:
|
||||
print("This field is required. Please provide a value.")
|
||||
continue
|
||||
else:
|
||||
try:
|
||||
# Handle Optional types
|
||||
if is_optional(field_type):
|
||||
if user_input.lower() == "none":
|
||||
value = None
|
||||
else:
|
||||
field_type = get_non_none_type(field_type)
|
||||
value = user_input
|
||||
|
||||
# Handle List of primitives
|
||||
elif is_list_of_primitives(field_type):
|
||||
try:
|
||||
value = json.loads(user_input)
|
||||
if not isinstance(value, list):
|
||||
raise ValueError(
|
||||
"Input must be a JSON-encoded list"
|
||||
)
|
||||
element_type = get_args(field_type)[0]
|
||||
value = [element_type(item) for item in value]
|
||||
|
||||
except json.JSONDecodeError:
|
||||
print(
|
||||
"Invalid JSON. Please enter a valid JSON-encoded list."
|
||||
)
|
||||
continue
|
||||
except ValueError as e:
|
||||
print(f"{str(e)}")
|
||||
continue
|
||||
|
||||
elif get_origin(field_type) is dict:
|
||||
try:
|
||||
value = json.loads(user_input)
|
||||
if not isinstance(value, dict):
|
||||
raise ValueError(
|
||||
"Input must be a JSON-encoded dictionary"
|
||||
)
|
||||
|
||||
except json.JSONDecodeError:
|
||||
print(
|
||||
"Invalid JSON. Please enter a valid JSON-encoded dict."
|
||||
)
|
||||
continue
|
||||
|
||||
# Convert the input to the correct type
|
||||
elif inspect.isclass(field_type) and issubclass(
|
||||
field_type, BaseModel
|
||||
):
|
||||
# For nested BaseModels, we assume a dictionary-like string input
|
||||
import ast
|
||||
|
||||
value = field_type(**ast.literal_eval(user_input))
|
||||
else:
|
||||
value = field_type(user_input)
|
||||
|
||||
except ValueError:
|
||||
print(
|
||||
f"Invalid input. Expected type: {getattr(field_type, '__name__', str(field_type))}"
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
# Validate the field using our manual validation function
|
||||
validated_value = manually_validate_field(
|
||||
config_type, field_name, value
|
||||
)
|
||||
config_data[field_name] = validated_value
|
||||
break
|
||||
except ValueError as e:
|
||||
print(f"Validation error: {str(e)}")
|
||||
|
||||
return config_type(**config_data)
|
18
llama_stack/distribution/utils/serialize.py
Normal file
18
llama_stack/distribution/utils/serialize.py
Normal file
|
@ -0,0 +1,18 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class EnumEncoder(json.JSONEncoder):
|
||||
def default(self, obj):
|
||||
if isinstance(obj, Enum):
|
||||
return obj.value
|
||||
elif isinstance(obj, datetime):
|
||||
return obj.isoformat()
|
||||
return super().default(obj)
|
Loading…
Add table
Add a link
Reference in a new issue