API Updates (#73)

* API Keys passed from Client instead of distro configuration

* delete distribution registry

* Rename the "package" word away

* Introduce a "Router" layer for providers

Some providers need to be factorized and considered as thin routing
layers on top of other providers. Consider two examples:

- The inference API should be a routing layer over inference providers,
  routed using the "model" key
- The memory banks API is another instance where various memory bank
  types will be provided by independent providers (e.g., a vector store
  is served by Chroma while a keyvalue memory can be served by Redis or
  PGVector)

This commit introduces a generalized routing layer for this purpose.

* update `apis_to_serve`

* llama_toolchain -> llama_stack

* Codemod from llama_toolchain -> llama_stack

- added providers/registry
- cleaned up api/ subdirectories and moved impls away
- restructured api/api.py
- from llama_stack.apis.<api> import foo should work now
- update imports to do llama_stack.apis.<api>
- update many other imports
- added __init__, fixed some registry imports
- updated registry imports
- create_agentic_system -> create_agent
- AgenticSystem -> Agent

* Moved some stuff out of common/; re-generated OpenAPI spec

* llama-toolchain -> llama-stack (hyphens)

* add control plane API

* add redis adapter + sqlite provider

* move core -> distribution

* Some more toolchain -> stack changes

* small naming shenanigans

* Removing custom tool and agent utilities and moving them client side

* Move control plane to distribution server for now

* Remove control plane from API list

* no codeshield dependency randomly plzzzzz

* Add "fire" as a dependency

* add back event loggers

* stack configure fixes

* use brave instead of bing in the example client

* add init file so it gets packaged

* add init files so it gets packaged

* Update MANIFEST

* bug fix

---------

Co-authored-by: Hardik Shah <hjshah@fb.com>
Co-authored-by: Xi Yan <xiyan@meta.com>
Co-authored-by: Ashwin Bharambe <ashwin@meta.com>
This commit is contained in:
Ashwin Bharambe 2024-09-17 19:51:35 -07:00 committed by GitHub
parent f294eac5f5
commit 9487ad8294
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
213 changed files with 1725 additions and 1204 deletions

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,96 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import List, Optional
import pkg_resources
from pydantic import BaseModel
from termcolor import cprint
from llama_stack.distribution.utils.exec import run_with_pty
from llama_stack.distribution.datatypes import * # noqa: F403
from pathlib import Path
from llama_stack.distribution.distribution import api_providers, SERVER_DEPENDENCIES
class ImageType(Enum):
docker = "docker"
conda = "conda"
class Dependencies(BaseModel):
pip_packages: List[str]
docker_image: Optional[str] = None
class ApiInput(BaseModel):
api: Api
provider: str
def build_image(build_config: BuildConfig, build_file_path: Path):
package_deps = Dependencies(
docker_image=build_config.distribution_spec.docker_image or "python:3.10-slim",
pip_packages=SERVER_DEPENDENCIES,
)
# extend package dependencies based on providers spec
all_providers = api_providers()
for (
api_str,
provider_or_providers,
) in build_config.distribution_spec.providers.items():
providers_for_api = all_providers[Api(api_str)]
providers = (
provider_or_providers
if isinstance(provider_or_providers, list)
else [provider_or_providers]
)
for provider in providers:
if provider not in providers_for_api:
raise ValueError(
f"Provider `{provider}` is not available for API `{api_str}`"
)
provider_spec = providers_for_api[provider]
package_deps.pip_packages.extend(provider_spec.pip_packages)
if provider_spec.docker_image:
raise ValueError("A stack's dependencies cannot have a docker image")
if build_config.image_type == ImageType.docker.value:
script = pkg_resources.resource_filename(
"llama_stack", "distribution/build_container.sh"
)
args = [
script,
build_config.name,
package_deps.docker_image,
str(build_file_path),
" ".join(package_deps.pip_packages),
]
else:
script = pkg_resources.resource_filename(
"llama_stack", "distribution/build_conda_env.sh"
)
args = [
script,
build_config.name,
" ".join(package_deps.pip_packages),
]
return_code = run_with_pty(args)
if return_code != 0:
cprint(
f"Failed to build target {build_config.name} with return code {return_code}",
color="red",
)
return

View file

@ -0,0 +1,115 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-}
LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
if [ -n "$LLAMA_STACK_DIR" ]; then
echo "Using llama-stack-dir=$LLAMA_STACK_DIR"
fi
if [ -n "$LLAMA_MODELS_DIR" ]; then
echo "Using llama-models-dir=$LLAMA_MODELS_DIR"
fi
set -euo pipefail
if [ "$#" -ne 2 ]; then
echo "Usage: $0 <distribution_type> <build_name> <pip_dependencies>" >&2
echo "Example: $0 <distribution_type> mybuild 'numpy pandas scipy'" >&2
exit 1
fi
build_name="$1"
env_name="llamastack-$build_name"
pip_dependencies="$2"
# Define color codes
RED='\033[0;31m'
GREEN='\033[0;32m'
NC='\033[0m' # No Color
# this is set if we actually create a new conda in which case we need to clean up
ENVNAME=""
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
source "$SCRIPT_DIR/common.sh"
ensure_conda_env_python310() {
local env_name="$1"
local pip_dependencies="$2"
local python_version="3.10"
# Check if conda command is available
if ! command -v conda &>/dev/null; then
printf "${RED}Error: conda command not found. Is Conda installed and in your PATH?${NC}" >&2
exit 1
fi
# Check if the environment exists
if conda env list | grep -q "^${env_name} "; then
printf "Conda environment '${env_name}' exists. Checking Python version...\n"
# Check Python version in the environment
current_version=$(conda run -n "${env_name}" python --version 2>&1 | cut -d' ' -f2 | cut -d'.' -f1,2)
if [ "$current_version" = "$python_version" ]; then
printf "Environment '${env_name}' already has Python ${python_version}. No action needed.\n"
else
printf "Updating environment '${env_name}' to Python ${python_version}...\n"
conda install -n "${env_name}" python="${python_version}" -y
fi
else
printf "Conda environment '${env_name}' does not exist. Creating with Python ${python_version}...\n"
conda create -n "${env_name}" python="${python_version}" -y
ENVNAME="${env_name}"
# setup_cleanup_handlers
fi
eval "$(conda shell.bash hook)"
conda deactivate && conda activate "${env_name}"
if [ -n "$TEST_PYPI_VERSION" ]; then
# these packages are damaged in test-pypi, so install them first
pip install fastapi libcst
pip install --extra-index-url https://test.pypi.org/simple/ llama-models==$TEST_PYPI_VERSION llama-stack==$TEST_PYPI_VERSION $pip_dependencies
else
# Re-installing llama-stack in the new conda environment
if [ -n "$LLAMA_STACK_DIR" ]; then
if [ ! -d "$LLAMA_STACK_DIR" ]; then
printf "${RED}Warning: LLAMA_STACK_DIR is set but directory does not exist: $LLAMA_STACK_DIR${NC}\n" >&2
exit 1
fi
printf "Installing from LLAMA_STACK_DIR: $LLAMA_STACK_DIR\n"
pip install --no-cache-dir -e "$LLAMA_STACK_DIR"
else
pip install --no-cache-dir llama-stack
fi
if [ -n "$LLAMA_MODELS_DIR" ]; then
if [ ! -d "$LLAMA_MODELS_DIR" ]; then
printf "${RED}Warning: LLAMA_MODELS_DIR is set but directory does not exist: $LLAMA_MODELS_DIR${NC}\n" >&2
exit 1
fi
printf "Installing from LLAMA_MODELS_DIR: $LLAMA_MODELS_DIR\n"
pip uninstall -y llama-models
pip install --no-cache-dir -e "$LLAMA_MODELS_DIR"
fi
# Install pip dependencies
if [ -n "$pip_dependencies" ]; then
printf "Installing pip dependencies: $pip_dependencies\n"
pip install $pip_dependencies
fi
fi
}
ensure_conda_env_python310 "$env_name" "$pip_dependencies"

View file

@ -0,0 +1,117 @@
#!/bin/bash
LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-}
LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
if [ "$#" -ne 4 ]; then
echo "Usage: $0 <build_name> <docker_base> <pip_dependencies>
echo "Example: $0 my-fastapi-app python:3.9-slim 'fastapi uvicorn'
exit 1
fi
build_name="$1"
image_name="llamastack-$build_name"
docker_base=$2
build_file_path=$3
pip_dependencies=$4
# Define color codes
RED='\033[0;31m'
GREEN='\033[0;32m'
NC='\033[0m' # No Color
set -euo pipefail
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
REPO_DIR=$(dirname $(dirname "$SCRIPT_DIR"))
DOCKER_BINARY=${DOCKER_BINARY:-docker}
DOCKER_OPTS=${DOCKER_OPTS:-}
TEMP_DIR=$(mktemp -d)
add_to_docker() {
local input
output_file="$TEMP_DIR/Dockerfile"
if [ -t 0 ]; then
printf '%s\n' "$1" >>"$output_file"
else
# If stdin is not a terminal, read from it (heredoc)
cat >>"$output_file"
fi
}
add_to_docker <<EOF
FROM $docker_base
WORKDIR /app
RUN apt-get update && apt-get install -y \
iputils-ping net-tools iproute2 dnsutils telnet \
curl wget telnet \
procps psmisc lsof \
traceroute \
bubblewrap \
&& rm -rf /var/lib/apt/lists/*
EOF
stack_mount="/app/llama-stack-source"
models_mount="/app/llama-models-source"
if [ -n "$LLAMA_STACK_DIR" ]; then
if [ ! -d "$LLAMA_STACK_DIR" ]; then
echo "${RED}Warning: LLAMA_STACK_DIR is set but directory does not exist: $LLAMA_STACK_DIR${NC}" >&2
exit 1
fi
add_to_docker "RUN pip install $stack_mount"
else
add_to_docker "RUN pip install llama-stack"
fi
if [ -n "$LLAMA_MODELS_DIR" ]; then
if [ ! -d "$LLAMA_MODELS_DIR" ]; then
echo "${RED}Warning: LLAMA_MODELS_DIR is set but directory does not exist: $LLAMA_MODELS_DIR${NC}" >&2
exit 1
fi
add_to_docker <<EOF
RUN pip uninstall -y llama-models
RUN pip install $models_mount
EOF
fi
if [ -n "$pip_dependencies" ]; then
add_to_docker "RUN pip install $pip_dependencies"
fi
add_to_docker <<EOF
# This would be good in production but for debugging flexibility lets not add it right now
# We need a more solid production ready entrypoint.sh anyway
#
# ENTRYPOINT ["python", "-m", "llama_stack.distribution.server.server"]
EOF
add_to_docker "ADD $build_file_path ./llamastack-build.yaml"
printf "Dockerfile created successfully in $TEMP_DIR/Dockerfile"
cat $TEMP_DIR/Dockerfile
printf "\n"
mounts=""
if [ -n "$LLAMA_STACK_DIR" ]; then
mounts="$mounts -v $(readlink -f $LLAMA_STACK_DIR):$stack_mount"
fi
if [ -n "$LLAMA_MODELS_DIR" ]; then
mounts="$mounts -v $(readlink -f $LLAMA_MODELS_DIR):$models_mount"
fi
set -x
$DOCKER_BINARY build $DOCKER_OPTS -t $image_name -f "$TEMP_DIR/Dockerfile" "$REPO_DIR" $mounts
set +x
echo "You can run it with: podman run -p 8000:8000 $image_name"
echo "Checking image builds..."
podman run -it $image_name cat llamastack-build.yaml

View file

@ -0,0 +1,40 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
cleanup() {
envname="$1"
set +x
echo "Cleaning up..."
conda deactivate
conda env remove --name $envname -y
}
handle_int() {
if [ -n $ENVNAME ]; then
cleanup $ENVNAME
fi
exit 1
}
handle_exit() {
if [ $? -ne 0 ]; then
echo -e "\033[1;31mABORTING.\033[0m"
if [ -n $ENVNAME ]; then
cleanup $ENVNAME
fi
fi
}
setup_cleanup_handlers() {
trap handle_int INT
trap handle_exit EXIT
__conda_setup="$('conda' 'shell.bash' 'hook' 2>/dev/null)"
eval "$__conda_setup"
conda deactivate
}

View file

@ -0,0 +1,110 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import BaseModel
from llama_stack.distribution.datatypes import * # noqa: F403
from termcolor import cprint
from llama_stack.distribution.distribution import api_providers, stack_apis
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.distribution.utils.prompt_for_config import prompt_for_config
# These are hacks so we can re-use the `prompt_for_config` utility
# This needs a bunch of work to be made very user friendly.
class ReqApis(BaseModel):
apis_to_serve: List[str]
def make_routing_entry_type(config_class: Any):
class BaseModelWithConfig(BaseModel):
routing_key: str
config: config_class
return BaseModelWithConfig
# TODO: make sure we can deal with existing configuration values correctly
# instead of just overwriting them
def configure_api_providers(
config: StackRunConfig, spec: DistributionSpec
) -> StackRunConfig:
cprint("Configuring APIs to serve...", "white", attrs=["bold"])
print("Enter comma-separated list of APIs to serve:")
apis = config.apis_to_serve or list(spec.providers.keys())
apis = [a for a in apis if a != "telemetry"]
req_apis = ReqApis(
apis_to_serve=apis,
)
req_apis = prompt_for_config(ReqApis, req_apis)
config.apis_to_serve = req_apis.apis_to_serve
print("")
apis = [v.value for v in stack_apis()]
all_providers = api_providers()
for api_str in spec.providers.keys():
if api_str not in apis:
raise ValueError(f"Unknown API `{api_str}`")
cprint(f"Configuring API `{api_str}`...\n", "white", attrs=["bold"])
api = Api(api_str)
provider_or_providers = spec.providers[api_str]
if isinstance(provider_or_providers, list) and len(provider_or_providers) > 1:
print(
"You have specified multiple providers for this API. We will configure a routing table now. For each provider, provide a routing key followed by provider configuration.\n"
)
routing_entries = []
for p in provider_or_providers:
print(f"Configuring provider `{p}`...")
provider_spec = all_providers[api][p]
config_type = instantiate_class_type(provider_spec.config_class)
# TODO: we need to validate the routing keys, and
# perhaps it is better if we break this out into asking
# for a routing key separately from the associated config
wrapper_type = make_routing_entry_type(config_type)
rt_entry = prompt_for_config(wrapper_type, None)
routing_entries.append(
ProviderRoutingEntry(
provider_id=p,
routing_key=rt_entry.routing_key,
config=rt_entry.config.dict(),
)
)
config.provider_map[api_str] = routing_entries
else:
p = (
provider_or_providers[0]
if isinstance(provider_or_providers, list)
else provider_or_providers
)
print(f"Configuring provider `{p}`...")
provider_spec = all_providers[api][p]
config_type = instantiate_class_type(provider_spec.config_class)
try:
provider_config = config.provider_map.get(api_str)
if provider_config:
existing = config_type(**provider_config.config)
else:
existing = None
except Exception:
existing = None
cfg = prompt_for_config(config_type, existing)
config.provider_map[api_str] = GenericProviderConfig(
provider_id=p,
config=cfg.dict(),
)
return config

View file

@ -0,0 +1,31 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
set -euo pipefail
error_handler() {
echo "Error occurred in script at line: ${1}" >&2
exit 1
}
trap 'error_handler ${LINENO}' ERR
if [ $# -lt 2 ]; then
echo "Usage: $0 <container name> <build file path>"
exit 1
fi
docker_image="$1"
host_build_dir="$2"
container_build_dir="/app/builds"
set -x
podman run -it \
-v $host_build_dir:$container_build_dir \
$docker_image \
llama stack configure ./llamastack-build.yaml --output-dir $container_build_dir

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .control_plane import * # noqa: F401 F403

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import RedisImplConfig
async def get_adapter_impl(config: RedisImplConfig, _deps):
from .redis import RedisControlPlaneAdapter
impl = RedisControlPlaneAdapter(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
@json_schema_type
class RedisImplConfig(BaseModel):
url: str = Field(
description="The URL for the Redis server",
)
namespace: Optional[str] = Field(
default=None,
description="All keys will be prefixed with this namespace",
)

View file

@ -0,0 +1,62 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime, timedelta
from typing import Any, List, Optional
from redis.asyncio import Redis
from llama_stack.apis.control_plane import * # noqa: F403
from .config import RedisImplConfig
class RedisControlPlaneAdapter(ControlPlane):
def __init__(self, config: RedisImplConfig):
self.config = config
async def initialize(self) -> None:
self.redis = Redis.from_url(self.config.url)
def _namespaced_key(self, key: str) -> str:
if not self.config.namespace:
return key
return f"{self.config.namespace}:{key}"
async def set(
self, key: str, value: Any, expiration: Optional[datetime] = None
) -> None:
key = self._namespaced_key(key)
await self.redis.set(key, value)
if expiration:
await self.redis.expireat(key, expiration)
async def get(self, key: str) -> Optional[ControlPlaneValue]:
key = self._namespaced_key(key)
value = await self.redis.get(key)
if value is None:
return None
ttl = await self.redis.ttl(key)
expiration = datetime.now() + timedelta(seconds=ttl) if ttl > 0 else None
return ControlPlaneValue(key=key, value=value, expiration=expiration)
async def delete(self, key: str) -> None:
key = self._namespaced_key(key)
await self.redis.delete(key)
async def range(self, start_key: str, end_key: str) -> List[ControlPlaneValue]:
start_key = self._namespaced_key(start_key)
end_key = self._namespaced_key(end_key)
keys = await self.redis.keys(f"{start_key}*")
result = []
for key in keys:
if key <= end_key:
value = await self.get(key)
if value:
result.append(value)
return result

View file

@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import SqliteControlPlaneConfig
async def get_provider_impl(config: SqliteControlPlaneConfig, _deps):
from .control_plane import SqliteControlPlane
impl = SqliteControlPlane(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
@json_schema_type
class SqliteControlPlaneConfig(BaseModel):
db_path: str = Field(
description="File path for the sqlite database",
)
table_name: str = Field(
default="llamastack_control_plane",
description="Table into which all the keys will be placed",
)

View file

@ -0,0 +1,79 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from datetime import datetime
from typing import Any, List, Optional
import aiosqlite
from llama_stack.apis.control_plane import * # noqa: F403
from .config import SqliteControlPlaneConfig
class SqliteControlPlane(ControlPlane):
def __init__(self, config: SqliteControlPlaneConfig):
self.db_path = config.db_path
self.table_name = config.table_name
async def initialize(self):
async with aiosqlite.connect(self.db_path) as db:
await db.execute(
f"""
CREATE TABLE IF NOT EXISTS {self.table_name} (
key TEXT PRIMARY KEY,
value TEXT,
expiration TIMESTAMP
)
"""
)
await db.commit()
async def set(
self, key: str, value: Any, expiration: Optional[datetime] = None
) -> None:
async with aiosqlite.connect(self.db_path) as db:
await db.execute(
f"INSERT OR REPLACE INTO {self.table_name} (key, value, expiration) VALUES (?, ?, ?)",
(key, json.dumps(value), expiration),
)
await db.commit()
async def get(self, key: str) -> Optional[ControlPlaneValue]:
async with aiosqlite.connect(self.db_path) as db:
async with db.execute(
f"SELECT value, expiration FROM {self.table_name} WHERE key = ?", (key,)
) as cursor:
row = await cursor.fetchone()
if row is None:
return None
value, expiration = row
return ControlPlaneValue(
key=key, value=json.loads(value), expiration=expiration
)
async def delete(self, key: str) -> None:
async with aiosqlite.connect(self.db_path) as db:
await db.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,))
await db.commit()
async def range(self, start_key: str, end_key: str) -> List[ControlPlaneValue]:
async with aiosqlite.connect(self.db_path) as db:
async with db.execute(
f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?",
(start_key, end_key),
) as cursor:
result = []
async for row in cursor:
key, value, expiration = row
result.append(
ControlPlaneValue(
key=key, value=json.loads(value), expiration=expiration
)
)
return result

View file

@ -0,0 +1,35 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime
from typing import Any, List, Optional, Protocol
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel
@json_schema_type
class ControlPlaneValue(BaseModel):
key: str
value: Any
expiration: Optional[datetime] = None
@json_schema_type
class ControlPlane(Protocol):
@webmethod(route="/control_plane/set")
async def set(
self, key: str, value: Any, expiration: Optional[datetime] = None
) -> None: ...
@webmethod(route="/control_plane/get", method="GET")
async def get(self, key: str) -> Optional[ControlPlaneValue]: ...
@webmethod(route="/control_plane/delete")
async def delete(self, key: str) -> None: ...
@webmethod(route="/control_plane/range", method="GET")
async def range(self, start_key: str, end_key: str) -> List[ControlPlaneValue]: ...

View file

@ -0,0 +1,29 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from llama_stack.distribution.datatypes import * # noqa: F403
def available_providers() -> List[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.control_plane,
provider_id="sqlite",
pip_packages=["aiosqlite"],
module="llama_stack.providers.impls.sqlite.control_plane",
config_class="llama_stack.providers.impls.sqlite.control_plane.SqliteControlPlaneConfig",
),
remote_provider_spec(
Api.control_plane,
AdapterSpec(
adapter_id="redis",
pip_packages=["redis"],
module="llama_stack.providers.adapters.control_plane.redis",
),
),
]

View file

@ -0,0 +1,250 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional, Union
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field, validator
@json_schema_type
class Api(Enum):
inference = "inference"
safety = "safety"
agents = "agents"
memory = "memory"
telemetry = "telemetry"
@json_schema_type
class ApiEndpoint(BaseModel):
route: str
method: str
name: str
@json_schema_type
class ProviderSpec(BaseModel):
api: Api
provider_id: str
config_class: str = Field(
...,
description="Fully-qualified classname of the config for this provider",
)
api_dependencies: List[Api] = Field(
default_factory=list,
description="Higher-level API surfaces may depend on other providers to provide their functionality",
)
@json_schema_type
class RouterProviderSpec(ProviderSpec):
provider_id: str = "router"
config_class: str = ""
docker_image: Optional[str] = None
inner_specs: List[ProviderSpec]
module: str = Field(
...,
description="""
Fully-qualified name of the module to import. The module is expected to have:
- `get_router_impl(config, provider_specs, deps)`: returns the router implementation
""",
)
@property
def pip_packages(self) -> List[str]:
raise AssertionError("Should not be called on RouterProviderSpec")
class GenericProviderConfig(BaseModel):
provider_id: str
config: Dict[str, Any]
@json_schema_type
class AdapterSpec(BaseModel):
adapter_id: str = Field(
...,
description="Unique identifier for this adapter",
)
module: str = Field(
...,
description="""
Fully-qualified name of the module to import. The module is expected to have:
- `get_adapter_impl(config, deps)`: returns the adapter implementation
""",
)
pip_packages: List[str] = Field(
default_factory=list,
description="The pip dependencies needed for this implementation",
)
config_class: Optional[str] = Field(
default=None,
description="Fully-qualified classname of the config for this provider",
)
@json_schema_type
class InlineProviderSpec(ProviderSpec):
pip_packages: List[str] = Field(
default_factory=list,
description="The pip dependencies needed for this implementation",
)
docker_image: Optional[str] = Field(
default=None,
description="""
The docker image to use for this implementation. If one is provided, pip_packages will be ignored.
If a provider depends on other providers, the dependencies MUST NOT specify a docker image.
""",
)
module: str = Field(
...,
description="""
Fully-qualified name of the module to import. The module is expected to have:
- `get_provider_impl(config, deps)`: returns the local implementation
""",
)
class RemoteProviderConfig(BaseModel):
url: str = Field(..., description="The URL for the provider")
@validator("url")
@classmethod
def validate_url(cls, url: str) -> str:
if not url.startswith("http"):
raise ValueError(f"URL must start with http: {url}")
return url.rstrip("/")
def remote_provider_id(adapter_id: str) -> str:
return f"remote::{adapter_id}"
@json_schema_type
class RemoteProviderSpec(ProviderSpec):
adapter: Optional[AdapterSpec] = Field(
default=None,
description="""
If some code is needed to convert the remote responses into Llama Stack compatible
API responses, specify the adapter here. If not specified, it indicates the remote
as being "Llama Stack compatible"
""",
)
@property
def docker_image(self) -> Optional[str]:
return None
@property
def module(self) -> str:
if self.adapter:
return self.adapter.module
return f"llama_stack.apis.{self.api.value}.client"
@property
def pip_packages(self) -> List[str]:
if self.adapter:
return self.adapter.pip_packages
return []
# Can avoid this by using Pydantic computed_field
def remote_provider_spec(
api: Api, adapter: Optional[AdapterSpec] = None
) -> RemoteProviderSpec:
config_class = (
adapter.config_class
if adapter and adapter.config_class
else "llama_stack.distribution.datatypes.RemoteProviderConfig"
)
provider_id = remote_provider_id(adapter.adapter_id) if adapter else "remote"
return RemoteProviderSpec(
api=api, provider_id=provider_id, config_class=config_class, adapter=adapter
)
@json_schema_type
class DistributionSpec(BaseModel):
description: Optional[str] = Field(
default="",
description="Description of the distribution",
)
docker_image: Optional[str] = None
providers: Dict[str, Union[str, List[str]]] = Field(
default_factory=dict,
description="""
Provider Types for each of the APIs provided by this distribution. If you
select multiple providers, you should provide an appropriate 'routing_map'
in the runtime configuration to help route to the correct provider.""",
)
@json_schema_type
class ProviderRoutingEntry(GenericProviderConfig):
routing_key: str
ProviderMapEntry = Union[GenericProviderConfig, List[ProviderRoutingEntry]]
@json_schema_type
class StackRunConfig(BaseModel):
built_at: datetime
image_name: str = Field(
...,
description="""
Reference to the distribution this package refers to. For unregistered (adhoc) packages,
this could be just a hash
""",
)
docker_image: Optional[str] = Field(
default=None,
description="Reference to the docker image if this package refers to a container",
)
conda_env: Optional[str] = Field(
default=None,
description="Reference to the conda environment if this package refers to a conda environment",
)
apis_to_serve: List[str] = Field(
description="""
The list of APIs to serve. If not specified, all APIs specified in the provider_map will be served""",
)
provider_map: Dict[str, ProviderMapEntry] = Field(
description="""
Provider configurations for each of the APIs provided by this package.
Given an API, you can specify a single provider or a "routing table". Each entry in the routing
table has a (routing_key, provider_config) tuple. How the key is interpreted is API-specific.
As examples:
- the "inference" API interprets the routing_key as a "model"
- the "memory" API interprets the routing_key as the type of a "memory bank"
The key may support wild-cards alsothe routing_key to route to the correct provider.""",
)
@json_schema_type
class BuildConfig(BaseModel):
name: str
distribution_spec: DistributionSpec = Field(
description="The distribution spec to build including API providers. "
)
image_type: str = Field(
default="conda",
description="Type of package to build (conda | container)",
)

View file

@ -0,0 +1,77 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import importlib
import inspect
from typing import Dict, List
from llama_stack.apis.agents import Agents
from llama_stack.apis.inference import Inference
from llama_stack.apis.memory import Memory
from llama_stack.apis.safety import Safety
from llama_stack.apis.telemetry import Telemetry
from .datatypes import Api, ApiEndpoint, ProviderSpec, remote_provider_spec
# These are the dependencies needed by the distribution server.
# `llama-stack` is automatically installed by the installation script.
SERVER_DEPENDENCIES = [
"fastapi",
"fire",
"uvicorn",
]
def stack_apis() -> List[Api]:
return [v for v in Api]
def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
apis = {}
protocols = {
Api.inference: Inference,
Api.safety: Safety,
Api.agents: Agents,
Api.memory: Memory,
Api.telemetry: Telemetry,
}
for api, protocol in protocols.items():
endpoints = []
protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
for name, method in protocol_methods:
if not hasattr(method, "__webmethod__"):
continue
webmethod = method.__webmethod__
route = webmethod.route
if webmethod.method == "GET":
method = "get"
elif webmethod.method == "DELETE":
method = "delete"
else:
method = "post"
endpoints.append(ApiEndpoint(route=route, method=method, name=name))
apis[api] = endpoints
return apis
def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]:
ret = {}
for api in stack_apis():
name = api.name.lower()
module = importlib.import_module(f"llama_stack.providers.registry.{name}")
ret[api] = {
"remote": remote_provider_spec(api),
**{a.provider_id: a for a in module.available_providers()},
}
return ret

View file

@ -0,0 +1,10 @@
name: local-conda-example
distribution_spec:
description: Use code from `llama_stack` itself to serve all llama stack APIs
providers:
inference: meta-reference
memory: meta-reference
safety: meta-reference
agents: meta-reference
telemetry: meta-reference
image_type: conda

View file

@ -0,0 +1,10 @@
name: local-fireworks-conda-example
distribution_spec:
description: Use Fireworks.ai for running LLM inference
providers:
inference: remote::fireworks
memory: meta-reference
safety: meta-reference
agents: meta-reference
telemetry: meta-reference
image_type: conda

View file

@ -0,0 +1,10 @@
name: local-ollama-conda-example
distribution_spec:
description: Like local, but use ollama for running LLM inference
providers:
inference: remote::ollama
memory: meta-reference
safety: meta-reference
agents: meta-reference
telemetry: meta-reference
image_type: conda

View file

@ -0,0 +1,10 @@
name: local-tgi-conda-example
distribution_spec:
description: Use TGI (local or with Hugging Face Inference Endpoints for running LLM inference. When using HF Inference Endpoints, you must provide the name of the endpoint).
providers:
inference: remote::tgi
memory: meta-reference
safety: meta-reference
agents: meta-reference
telemetry: meta-reference
image_type: conda

View file

@ -0,0 +1,10 @@
name: local-tgi-conda-example
distribution_spec:
description: Use Together.ai for running LLM inference
providers:
inference: remote::together
memory: meta-reference
safety: meta-reference
agents: meta-reference
telemetry: meta-reference
image_type: conda

View file

@ -0,0 +1,10 @@
name: local-docker-example
distribution_spec:
description: Use code from `llama_stack` itself to serve all llama stack APIs
providers:
inference: meta-reference
memory: meta-reference
safety: meta-reference
agents: meta-reference
telemetry: meta-reference
image_type: docker

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,392 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import inspect
import json
import signal
import traceback
from collections.abc import (
AsyncGenerator as AsyncGeneratorABC,
AsyncIterator as AsyncIteratorABC,
)
from contextlib import asynccontextmanager
from ssl import SSLError
from typing import (
Any,
AsyncGenerator,
AsyncIterator,
Dict,
get_type_hints,
List,
Optional,
Set,
)
import fire
import httpx
import yaml
from fastapi import Body, FastAPI, HTTPException, Request, Response
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.routing import APIRoute
from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated
from llama_stack.providers.utils.telemetry.tracing import (
end_trace,
setup_logger,
SpanStatus,
start_trace,
)
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.distribution import api_endpoints, api_providers
from llama_stack.distribution.utils.dynamic import instantiate_provider
def is_async_iterator_type(typ):
if hasattr(typ, "__origin__"):
origin = typ.__origin__
if isinstance(origin, type):
return issubclass(
origin,
(AsyncIterator, AsyncGenerator, AsyncIteratorABC, AsyncGeneratorABC),
)
return False
return isinstance(
typ, (AsyncIterator, AsyncGenerator, AsyncIteratorABC, AsyncGeneratorABC)
)
def create_sse_event(data: Any) -> str:
if isinstance(data, BaseModel):
data = data.json()
else:
data = json.dumps(data)
return f"data: {data}\n\n"
async def global_exception_handler(request: Request, exc: Exception):
traceback.print_exception(exc)
http_exc = translate_exception(exc)
return JSONResponse(
status_code=http_exc.status_code, content={"error": {"detail": http_exc.detail}}
)
def translate_exception(exc: Exception) -> HTTPException:
if isinstance(exc, ValidationError):
return RequestValidationError(exc.raw_errors)
# Add more custom exception translations here
return HTTPException(status_code=500, detail="Internal server error")
async def passthrough(
request: Request,
downstream_url: str,
downstream_headers: Optional[Dict[str, str]] = None,
):
await start_trace(request.path, {"downstream_url": downstream_url})
headers = dict(request.headers)
headers.pop("host", None)
headers.update(downstream_headers or {})
content = await request.body()
client = httpx.AsyncClient()
erred = False
try:
req = client.build_request(
method=request.method,
url=downstream_url,
headers=headers,
content=content,
params=request.query_params,
)
response = await client.send(req, stream=True)
async def stream_response():
async for chunk in response.aiter_raw(chunk_size=64):
yield chunk
await response.aclose()
await client.aclose()
return StreamingResponse(
stream_response(),
status_code=response.status_code,
headers=dict(response.headers),
media_type=response.headers.get("content-type"),
)
except httpx.ReadTimeout:
erred = True
return Response(content="Downstream server timed out", status_code=504)
except httpx.NetworkError as e:
erred = True
return Response(content=f"Network error: {str(e)}", status_code=502)
except httpx.TooManyRedirects:
erred = True
return Response(content="Too many redirects", status_code=502)
except SSLError as e:
erred = True
return Response(content=f"SSL error: {str(e)}", status_code=502)
except httpx.HTTPStatusError as e:
erred = True
return Response(content=str(e), status_code=e.response.status_code)
except Exception as e:
erred = True
return Response(content=f"Unexpected error: {str(e)}", status_code=500)
finally:
await end_trace(SpanStatus.OK if not erred else SpanStatus.ERROR)
def handle_sigint(*args, **kwargs):
print("SIGINT or CTRL-C detected. Exiting gracefully...")
loop = asyncio.get_event_loop()
for task in asyncio.all_tasks(loop):
task.cancel()
loop.stop()
@asynccontextmanager
async def lifespan(app: FastAPI):
print("Starting up")
yield
print("Shutting down")
def create_dynamic_passthrough(
downstream_url: str, downstream_headers: Optional[Dict[str, str]] = None
):
async def endpoint(request: Request):
return await passthrough(request, downstream_url, downstream_headers)
return endpoint
def create_dynamic_typed_route(func: Any, method: str):
hints = get_type_hints(func)
response_model = hints.get("return")
# NOTE: I think it is better to just add a method within each Api
# "Protocol" / adapter-impl to tell what sort of a response this request
# is going to produce. /chat_completion can produce a streaming or
# non-streaming response depending on if request.stream is True / False.
is_streaming = is_async_iterator_type(response_model)
if is_streaming:
async def endpoint(**kwargs):
await start_trace(func.__name__)
async def sse_generator(event_gen):
try:
async for item in event_gen:
yield create_sse_event(item)
await asyncio.sleep(0.01)
except asyncio.CancelledError:
print("Generator cancelled")
await event_gen.aclose()
except Exception as e:
traceback.print_exception(e)
yield create_sse_event(
{
"error": {
"message": str(translate_exception(e)),
},
}
)
finally:
await end_trace()
return StreamingResponse(
sse_generator(func(**kwargs)), media_type="text/event-stream"
)
else:
async def endpoint(**kwargs):
await start_trace(func.__name__)
try:
return (
await func(**kwargs)
if asyncio.iscoroutinefunction(func)
else func(**kwargs)
)
except Exception as e:
traceback.print_exception(e)
raise translate_exception(e) from e
finally:
await end_trace()
sig = inspect.signature(func)
if method == "post":
# make sure every parameter is annotated with Body() so FASTAPI doesn't
# do anything too intelligent and ask for some parameters in the query
# and some in the body
endpoint.__signature__ = sig.replace(
parameters=[
param.replace(
annotation=Annotated[param.annotation, Body(..., embed=True)]
)
for param in sig.parameters.values()
]
)
else:
endpoint.__signature__ = sig
return endpoint
def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
by_id = {x.api: x for x in providers}
def dfs(a: ProviderSpec, visited: Set[Api], stack: List[Api]):
visited.add(a.api)
for api in a.api_dependencies:
if api not in visited:
dfs(by_id[api], visited, stack)
stack.append(a.api)
visited = set()
stack = []
for a in providers:
if a.api not in visited:
dfs(a, visited, stack)
return [by_id[x] for x in stack]
def snake_to_camel(snake_str):
return "".join(word.capitalize() for word in snake_str.split("_"))
async def resolve_impls(
provider_map: Dict[str, ProviderMapEntry],
) -> Dict[Api, Any]:
"""
Does two things:
- flatmaps, sorts and resolves the providers in dependency order
- for each API, produces either a (local, passthrough or router) implementation
"""
all_providers = api_providers()
specs = {}
for api_str, item in provider_map.items():
api = Api(api_str)
providers = all_providers[api]
if isinstance(item, GenericProviderConfig):
if item.provider_id not in providers:
raise ValueError(
f"Unknown provider `{provider_id}` is not available for API `{api}`"
)
specs[api] = providers[item.provider_id]
else:
assert isinstance(item, list)
inner_specs = []
for rt_entry in item:
if rt_entry.provider_id not in providers:
raise ValueError(
f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`"
)
inner_specs.append(providers[rt_entry.provider_id])
specs[api] = RouterProviderSpec(
api=api,
module=f"llama_stack.providers.routers.{api.value.lower()}",
api_dependencies=[],
inner_specs=inner_specs,
)
sorted_specs = topological_sort(specs.values())
impls = {}
for spec in sorted_specs:
api = spec.api
deps = {api: impls[api] for api in spec.api_dependencies}
impl = await instantiate_provider(spec, deps, provider_map[api.value])
impls[api] = impl
return impls, specs
def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
with open(yaml_config, "r") as fp:
config = StackRunConfig(**yaml.safe_load(fp))
app = FastAPI()
impls, specs = asyncio.run(resolve_impls(config.provider_map))
if Api.telemetry in impls:
setup_logger(impls[Api.telemetry])
all_endpoints = api_endpoints()
apis_to_serve = config.apis_to_serve or list(config.provider_map.keys())
for api_str in apis_to_serve:
api = Api(api_str)
endpoints = all_endpoints[api]
impl = impls[api]
provider_spec = specs[api]
if (
isinstance(provider_spec, RemoteProviderSpec)
and provider_spec.adapter is None
):
for endpoint in endpoints:
url = impl.__provider_config__.url.rstrip("/") + endpoint.route
getattr(app, endpoint.method)(endpoint.route)(
create_dynamic_passthrough(url)
)
else:
for endpoint in endpoints:
if not hasattr(impl, endpoint.name):
# ideally this should be a typing violation already
raise ValueError(
f"Could not find method {endpoint.name} on {impl}!!"
)
impl_method = getattr(impl, endpoint.name)
getattr(app, endpoint.method)(endpoint.route, response_model=None)(
create_dynamic_typed_route(impl_method, endpoint.method)
)
for route in app.routes:
if isinstance(route, APIRoute):
cprint(
f"Serving {next(iter(route.methods))} {route.path}",
"white",
attrs=["bold"],
)
app.exception_handler(RequestValidationError)(global_exception_handler)
app.exception_handler(Exception)(global_exception_handler)
signal.signal(signal.SIGINT, handle_sigint)
import uvicorn
# FYI this does not do hot-reloads
listen_host = "::" if not disable_ipv6 else "0.0.0.0"
print(f"Listening on {listen_host}:{port}")
uvicorn.run(app, host=listen_host, port=port)
if __name__ == "__main__":
fire.Fire(main)

View file

@ -0,0 +1,42 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
set -euo pipefail
RED='\033[0;31m'
NC='\033[0m' # No Color
error_handler() {
echo "Error occurred in script at line: ${1}" >&2
exit 1
}
trap 'error_handler ${LINENO}' ERR
if [ $# -lt 3 ]; then
echo "Usage: $0 <build_name> <yaml_config> <port> <script_args...>"
exit 1
fi
build_name="$1"
env_name="llamastack-$build_name"
shift
yaml_config="$1"
shift
port="$1"
shift
eval "$(conda shell.bash hook)"
conda deactivate && conda activate "$env_name"
$CONDA_PREFIX/bin/python \
-m llama_stack.distribution.server.server \
--yaml_config "$yaml_config" \
--port "$port" "$@"

View file

@ -0,0 +1,43 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
set -euo pipefail
RED='\033[0;31m'
NC='\033[0m' # No Color
error_handler() {
echo "Error occurred in script at line: ${1}" >&2
exit 1
}
trap 'error_handler ${LINENO}' ERR
if [ $# -lt 3 ]; then
echo "Usage: $0 <build_name> <yaml_config> <port> <other_args...>"
exit 1
fi
build_name="$1"
docker_image="llamastack-$build_name"
shift
yaml_config="$1"
shift
port="$1"
shift
set -x
podman run -it \
-p $port:$port \
-v "$yaml_config:/app/config.yaml" \
$docker_image \
python -m llama_stack.distribution.server.server \
--yaml_config /app/config.yaml \
--port $port "$@"

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from pathlib import Path
LLAMA_STACK_CONFIG_DIR = Path(os.path.expanduser("~/.llama/"))
DISTRIBS_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "distributions"
DEFAULT_CHECKPOINT_DIR = LLAMA_STACK_CONFIG_DIR / "checkpoints"
BUILDS_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "builds"

View file

@ -0,0 +1,66 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import importlib
from typing import Any, Dict
from llama_stack.distribution.datatypes import * # noqa: F403
def instantiate_class_type(fully_qualified_name):
module_name, class_name = fully_qualified_name.rsplit(".", 1)
module = importlib.import_module(module_name)
return getattr(module, class_name)
# returns a class implementing the protocol corresponding to the Api
async def instantiate_provider(
provider_spec: ProviderSpec,
deps: Dict[str, Any],
provider_config: ProviderMapEntry,
):
module = importlib.import_module(provider_spec.module)
args = []
if isinstance(provider_spec, RemoteProviderSpec):
if provider_spec.adapter:
method = "get_adapter_impl"
else:
method = "get_client_impl"
assert isinstance(provider_config, GenericProviderConfig)
config_type = instantiate_class_type(provider_spec.config_class)
config = config_type(**provider_config.config)
args = [config, deps]
elif isinstance(provider_spec, RouterProviderSpec):
method = "get_router_impl"
assert isinstance(provider_config, list)
inner_specs = {x.provider_id: x for x in provider_spec.inner_specs}
inner_impls = []
for routing_entry in provider_config:
impl = await instantiate_provider(
inner_specs[routing_entry.provider_id],
deps,
routing_entry,
)
inner_impls.append((routing_entry.routing_key, impl))
config = None
args = [inner_impls, deps]
else:
method = "get_provider_impl"
assert isinstance(provider_config, GenericProviderConfig)
config_type = instantiate_class_type(provider_spec.config_class)
config = config_type(**provider_config.config)
args = [config, deps]
fn = getattr(module, method)
impl = await fn(*args)
impl.__provider_spec__ = provider_spec
impl.__provider_config__ = config
return impl

View file

@ -0,0 +1,105 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import errno
import os
import pty
import select
import signal
import subprocess
import sys
import termios
from termcolor import cprint
# run a command in a pseudo-terminal, with interrupt handling,
# useful when you want to run interactive things
def run_with_pty(command):
master, slave = pty.openpty()
old_settings = termios.tcgetattr(sys.stdin)
original_sigint = signal.getsignal(signal.SIGINT)
ctrl_c_pressed = False
def sigint_handler(signum, frame):
nonlocal ctrl_c_pressed
ctrl_c_pressed = True
cprint("\nCtrl-C detected. Aborting...", "white", attrs=["bold"])
try:
# Set up the signal handler
signal.signal(signal.SIGINT, sigint_handler)
new_settings = termios.tcgetattr(sys.stdin)
new_settings[3] = new_settings[3] & ~termios.ECHO # Disable echo
new_settings[3] = new_settings[3] & ~termios.ICANON # Disable canonical mode
termios.tcsetattr(sys.stdin, termios.TCSADRAIN, new_settings)
process = subprocess.Popen(
command,
stdin=slave,
stdout=slave,
stderr=slave,
universal_newlines=True,
preexec_fn=os.setsid,
)
# Close the slave file descriptor as it's now owned by the subprocess
os.close(slave)
def handle_io():
while not ctrl_c_pressed:
try:
rlist, _, _ = select.select([sys.stdin, master], [], [], 0.1)
if sys.stdin in rlist:
data = os.read(sys.stdin.fileno(), 1024)
if not data:
break
os.write(master, data)
if master in rlist:
data = os.read(master, 1024)
if not data:
break
sys.stdout.buffer.write(data)
sys.stdout.flush()
except KeyboardInterrupt:
# This will be raised when Ctrl+C is pressed
break
if process.poll() is not None:
break
handle_io()
except (EOFError, KeyboardInterrupt):
pass
except OSError as e:
if e.errno != errno.EIO:
raise
finally:
# Clean up
termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_settings)
signal.signal(signal.SIGINT, original_sigint)
os.close(master)
if process.poll() is None:
process.terminate()
process.wait()
return process.returncode
def run_command(command):
process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
output, error = process.communicate()
if process.returncode != 0:
print(f"Error: {error.decode('utf-8')}")
sys.exit(1)
return output.decode("utf-8")

View file

@ -0,0 +1,13 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from .config_dirs import DEFAULT_CHECKPOINT_DIR
def model_local_dir(descriptor: str) -> str:
return os.path.join(DEFAULT_CHECKPOINT_DIR, descriptor)

View file

@ -0,0 +1,309 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import inspect
import json
from enum import Enum
from typing import Any, get_args, get_origin, List, Literal, Optional, Type, Union
from pydantic import BaseModel
from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefinedType
from typing_extensions import Annotated
def is_list_of_primitives(field_type):
"""Check if a field type is a List of primitive types."""
origin = get_origin(field_type)
if origin is List or origin is list:
args = get_args(field_type)
if len(args) == 1 and args[0] in (int, float, str, bool):
return True
return False
def is_basemodel_without_fields(typ):
return (
inspect.isclass(typ) and issubclass(typ, BaseModel) and len(typ.__fields__) == 0
)
def can_recurse(typ):
return (
inspect.isclass(typ) and issubclass(typ, BaseModel) and len(typ.__fields__) > 0
)
def get_literal_values(field):
"""Extract literal values from a field if it's a Literal type."""
if get_origin(field.annotation) is Literal:
return get_args(field.annotation)
return None
def is_optional(field_type):
"""Check if a field type is Optional."""
return get_origin(field_type) is Union and type(None) in get_args(field_type)
def get_non_none_type(field_type):
"""Get the non-None type from an Optional type."""
return next(arg for arg in get_args(field_type) if arg is not type(None))
def manually_validate_field(model: Type[BaseModel], field_name: str, value: Any):
validators = model.__pydantic_decorators__.field_validators
for _name, validator in validators.items():
if field_name in validator.info.fields:
validator.func(value)
return value
def is_discriminated_union(typ) -> bool:
if isinstance(typ, FieldInfo):
return typ.discriminator
else:
if not (get_origin(typ) is Annotated):
return False
args = get_args(typ)
return len(args) >= 2 and args[1].discriminator
def prompt_for_discriminated_union(
field_name,
typ,
existing_value,
):
if isinstance(typ, FieldInfo):
inner_type = typ.annotation
discriminator = typ.discriminator
else:
args = get_args(typ)
inner_type = args[0]
discriminator = args[1].discriminator
union_types = get_args(inner_type)
# Find the discriminator field in each union type
type_map = {}
for t in union_types:
disc_field = t.__fields__[discriminator]
literal_values = get_literal_values(disc_field)
if literal_values:
for value in literal_values:
type_map[value] = t
while True:
discriminator_value = input(
f"Enter `{discriminator}` for {field_name} (options: {', '.join(type_map.keys())}): "
)
if discriminator_value in type_map:
chosen_type = type_map[discriminator_value]
print(f"\nConfiguring {chosen_type.__name__}:")
if existing_value and (
getattr(existing_value, discriminator) != discriminator_value
):
existing_value = None
sub_config = prompt_for_config(chosen_type, existing_value)
# Set the discriminator field in the sub-config
setattr(sub_config, discriminator, discriminator_value)
return sub_config
else:
print(f"Invalid {discriminator}. Please try again.")
# This is somewhat elaborate, but does not purport to be comprehensive in any way.
# We should add handling for the most common cases to tide us over.
#
# doesn't support List[nested_class] yet or Dicts of any kind. needs a bunch of
# unit tests for coverage.
def prompt_for_config(
config_type: type[BaseModel], existing_config: Optional[BaseModel] = None
) -> BaseModel:
"""
Recursively prompt the user for configuration values based on a Pydantic BaseModel.
Args:
config_type: A Pydantic BaseModel class representing the configuration structure.
Returns:
An instance of the config_type with user-provided values.
"""
config_data = {}
for field_name, field in config_type.__fields__.items():
field_type = field.annotation
existing_value = (
getattr(existing_config, field_name) if existing_config else None
)
if existing_value:
default_value = existing_value
else:
default_value = (
field.default
if not isinstance(field.default, PydanticUndefinedType)
else None
)
is_required = field.is_required
# Skip fields with Literal type
if get_origin(field_type) is Literal:
continue
# Skip fields with no type annotations
if is_basemodel_without_fields(field_type):
config_data[field_name] = field_type()
continue
if inspect.isclass(field_type) and issubclass(field_type, Enum):
prompt = f"Choose {field_name} (options: {', '.join(e.name for e in field_type)}):"
while True:
# this branch does not handle existing and default values yet
user_input = input(prompt + " ")
try:
value = field_type[user_input]
validated_value = manually_validate_field(config_type, field, value)
config_data[field_name] = validated_value
break
except KeyError:
print(
f"Invalid choice. Please choose from: {', '.join(e.name for e in field_type)}"
)
continue
if is_discriminated_union(field):
config_data[field_name] = prompt_for_discriminated_union(
field_name, field, existing_value
)
continue
if is_optional(field_type) and can_recurse(get_non_none_type(field_type)):
prompt = f"Do you want to configure {field_name}? (y/n): "
if input(prompt).lower() == "n":
config_data[field_name] = None
continue
nested_type = get_non_none_type(field_type)
print(f"Entering sub-configuration for {field_name}:")
config_data[field_name] = prompt_for_config(nested_type, existing_value)
elif is_optional(field_type) and is_discriminated_union(
get_non_none_type(field_type)
):
prompt = f"Do you want to configure {field_name}? (y/n): "
if input(prompt).lower() == "n":
config_data[field_name] = None
continue
nested_type = get_non_none_type(field_type)
config_data[field_name] = prompt_for_discriminated_union(
field_name,
nested_type,
existing_value,
)
elif can_recurse(field_type):
print(f"\nEntering sub-configuration for {field_name}:")
config_data[field_name] = prompt_for_config(
field_type,
existing_value,
)
else:
prompt = f"Enter value for {field_name}"
if existing_value is not None:
prompt += f" (existing: {existing_value})"
elif default_value is not None:
prompt += f" (default: {default_value})"
if is_optional(field_type):
prompt += " (optional)"
elif is_required:
prompt += " (required)"
prompt += ": "
while True:
user_input = input(prompt)
if user_input == "":
if default_value is not None:
config_data[field_name] = default_value
break
elif is_optional(field_type) or not is_required:
config_data[field_name] = None
break
else:
print("This field is required. Please provide a value.")
continue
else:
try:
# Handle Optional types
if is_optional(field_type):
if user_input.lower() == "none":
value = None
else:
field_type = get_non_none_type(field_type)
value = user_input
# Handle List of primitives
elif is_list_of_primitives(field_type):
try:
value = json.loads(user_input)
if not isinstance(value, list):
raise ValueError(
"Input must be a JSON-encoded list"
)
element_type = get_args(field_type)[0]
value = [element_type(item) for item in value]
except json.JSONDecodeError:
print(
"Invalid JSON. Please enter a valid JSON-encoded list."
)
continue
except ValueError as e:
print(f"{str(e)}")
continue
elif get_origin(field_type) is dict:
try:
value = json.loads(user_input)
if not isinstance(value, dict):
raise ValueError(
"Input must be a JSON-encoded dictionary"
)
except json.JSONDecodeError:
print(
"Invalid JSON. Please enter a valid JSON-encoded dict."
)
continue
# Convert the input to the correct type
elif inspect.isclass(field_type) and issubclass(
field_type, BaseModel
):
# For nested BaseModels, we assume a dictionary-like string input
import ast
value = field_type(**ast.literal_eval(user_input))
else:
value = field_type(user_input)
except ValueError:
print(
f"Invalid input. Expected type: {getattr(field_type, '__name__', str(field_type))}"
)
continue
try:
# Validate the field using our manual validation function
validated_value = manually_validate_field(
config_type, field_name, value
)
config_data[field_name] = validated_value
break
except ValueError as e:
print(f"Validation error: {str(e)}")
return config_type(**config_data)

View file

@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from datetime import datetime
from enum import Enum
class EnumEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, Enum):
return obj.value
elif isinstance(obj, datetime):
return obj.isoformat()
return super().default(obj)