Introduce Llama stack distributions (#22)

* Add distribution CLI scaffolding

* More progress towards `llama distribution install`

* getting closer to a distro definition, distro install + configure works

* Distribution server now functioning

* read existing configuration, save enums properly

* Remove inference uvicorn server entrypoint and llama inference CLI command

* updated dependency and client model name

* Improved exception handling

* local imports for faster cli

* undo a typo, add a passthrough distribution

* implement full-passthrough in the server

* add safety adapters, configuration handling, server + clients

* cleanup, moving stuff to common, nuke utils

* Add a Path() wrapper at the earliest place

* fixes

* Bring agentic system api to toolchain

Add adapter dependencies and resolve adapters using a topological sort

* refactor to reduce size of `agentic_system`

* move straggler files and fix some important existing bugs

* ApiSurface -> Api

* refactor a method out

* Adapter -> Provider

* Make each inference provider into its own subdirectory

* installation fixes

* Rename Distribution -> DistributionSpec, simplify RemoteProviders

* dict key instead of attr

* update inference config to take model and not model_dir

* Fix passthrough streaming, send headers properly not part of body :facepalm

* update safety to use model sku ids and not model dirs

* Update cli_reference.md

* minor fixes

* add DistributionConfig, fix a bug in model download

* Make install + start scripts do proper configuration automatically

* Update CLI_reference

* Nuke fp8_requirements, fold fbgemm into common requirements

* Update README, add newline between API surface configurations

* Refactor download functionality out of the Command so can be reused

* Add `llama model download` alias for `llama download`

* Show message about checksum file so users can check themselves

* Simpler intro statements

* get ollama working

* Reduce a bunch of dependencies from toolchain

Some improvements to the distribution install script

* Avoid using `conda run` since it buffers everything

* update dependencies and rely on LLAMA_TOOLCHAIN_DIR for dev purposes

* add validation for configuration input

* resort imports

* make optional subclasses default to yes for configuration

* Remove additional_pip_packages; move deps to providers

* for inline make 8b model the default

* Add scripts to MANIFEST

* allow installing from test.pypi.org

* Fix #2 to help with testing packages

* Must install llama-models at that same version first

* fix PIP_ARGS

---------

Co-authored-by: Hardik Shah <hjshah@fb.com>
Co-authored-by: Hardik Shah <hjshah@meta.com>
This commit is contained in:
Ashwin Bharambe 2024-08-08 13:38:41 -07:00 committed by GitHub
parent da4645a27a
commit e830814399
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
115 changed files with 5839 additions and 1120 deletions

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,99 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, List, Optional
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
@json_schema_type
class Api(Enum):
inference = "inference"
safety = "safety"
agentic_system = "agentic_system"
@json_schema_type
class ApiEndpoint(BaseModel):
route: str
method: str
name: str
@json_schema_type
class ProviderSpec(BaseModel):
api: Api
provider_id: str
config_class: str = Field(
...,
description="Fully-qualified classname of the config for this provider",
)
@json_schema_type
class InlineProviderSpec(ProviderSpec):
pip_packages: List[str] = Field(
default_factory=list,
description="The pip dependencies needed for this implementation",
)
module: str = Field(
...,
description="""
Fully-qualified name of the module to import. The module is expected to have:
- `get_provider_impl(config, deps)`: returns the local implementation
""",
)
api_dependencies: List[Api] = Field(
default_factory=list,
description="Higher-level API surfaces may depend on other providers to provide their functionality",
)
class RemoteProviderConfig(BaseModel):
base_url: str = Field(..., description="The base URL for the llama stack provider")
api_key: Optional[str] = Field(
..., description="API key, if needed, for the provider"
)
@json_schema_type
class RemoteProviderSpec(ProviderSpec):
module: str = Field(
...,
description="""
Fully-qualified name of the module to import. The module is expected to have:
- `get_client_impl(base_url)`: returns a client which can be used to call the remote implementation
""",
)
config_class: str = "llama_toolchain.distribution.datatypes.RemoteProviderConfig"
@json_schema_type
class DistributionSpec(BaseModel):
spec_id: str
description: str
provider_specs: Dict[Api, ProviderSpec] = Field(
default_factory=dict,
description="Provider specifications for each of the APIs provided by this distribution",
)
@json_schema_type
class DistributionConfig(BaseModel):
"""References to a installed / configured DistributionSpec"""
name: str
spec: str
conda_env: str
providers: Dict[str, Any] = Field(
default_factory=dict,
description="Provider configurations for each of the APIs provided by this distribution",
)

View file

@ -0,0 +1,86 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import inspect
from typing import Dict, List
from llama_toolchain.agentic_system.api.endpoints import AgenticSystem
from llama_toolchain.agentic_system.providers import available_agentic_system_providers
from llama_toolchain.inference.api.endpoints import Inference
from llama_toolchain.inference.providers import available_inference_providers
from llama_toolchain.safety.api.endpoints import Safety
from llama_toolchain.safety.providers import available_safety_providers
from .datatypes import (
Api,
ApiEndpoint,
DistributionSpec,
InlineProviderSpec,
ProviderSpec,
)
# These are the dependencies needed by the distribution server.
# `llama-toolchain` is automatically installed by the installation script.
SERVER_DEPENDENCIES = [
"fastapi",
"python-dotenv",
"uvicorn",
]
def distribution_dependencies(distribution: DistributionSpec) -> List[str]:
# only consider InlineProviderSpecs when calculating dependencies
return [
dep
for provider_spec in distribution.provider_specs.values()
if isinstance(provider_spec, InlineProviderSpec)
for dep in provider_spec.pip_packages
] + SERVER_DEPENDENCIES
def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
apis = {}
protocols = {
Api.inference: Inference,
Api.safety: Safety,
Api.agentic_system: AgenticSystem,
}
for api, protocol in protocols.items():
endpoints = []
protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
for name, method in protocol_methods:
if not hasattr(method, "__webmethod__"):
continue
webmethod = method.__webmethod__
route = webmethod.route
# use `post` for all methods right now until we fix up the `webmethod` openapi
# annotation and write our own openapi generator
endpoints.append(ApiEndpoint(route=route, method="post", name=name))
apis[api] = endpoints
return apis
def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]:
inference_providers_by_id = {
a.provider_id: a for a in available_inference_providers()
}
safety_providers_by_id = {a.provider_id: a for a in available_safety_providers()}
agentic_system_providers_by_id = {
a.provider_id: a for a in available_agentic_system_providers()
}
return {
Api.inference: inference_providers_by_id,
Api.safety: safety_providers_by_id,
Api.agentic_system: agentic_system_providers_by_id,
}

View file

@ -0,0 +1,36 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import importlib
from typing import Any, Dict
from .datatypes import InlineProviderSpec, ProviderSpec, RemoteProviderSpec
def instantiate_class_type(fully_qualified_name):
module_name, class_name = fully_qualified_name.rsplit(".", 1)
module = importlib.import_module(module_name)
return getattr(module, class_name)
# returns a class implementing the protocol corresponding to the Api
def instantiate_provider(
provider_spec: InlineProviderSpec,
provider_config: Dict[str, Any],
deps: Dict[str, ProviderSpec],
):
module = importlib.import_module(provider_spec.module)
config_type = instantiate_class_type(provider_spec.config_class)
config = config_type(**provider_config)
return asyncio.run(module.get_provider_impl(config, deps))
def instantiate_client(provider_spec: RemoteProviderSpec, base_url: str):
module = importlib.import_module(provider_spec.module)
return asyncio.run(module.get_client_impl(base_url))

View file

@ -0,0 +1,112 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-}
LLAMA_TOOLCHAIN_DIR=${LLAMA_TOOLCHAIN_DIR:-}
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
set -euo pipefail
# Define color codes
RED='\033[0;31m'
GREEN='\033[0;32m'
NC='\033[0m' # No Color
error_handler() {
echo "Error occurred in script at line: ${1}" >&2
exit 1
}
# Set up the error trap
trap 'error_handler ${LINENO}' ERR
ensure_conda_env_python310() {
local env_name="$1"
local pip_dependencies="$2"
local python_version="3.10"
# Check if conda command is available
if ! command -v conda &>/dev/null; then
echo -e "${RED}Error: conda command not found. Is Conda installed and in your PATH?${NC}" >&2
exit 1
fi
# Check if the environment exists
if conda env list | grep -q "^${env_name} "; then
echo "Conda environment '${env_name}' exists. Checking Python version..."
# Check Python version in the environment
current_version=$(conda run -n "${env_name}" python --version 2>&1 | cut -d' ' -f2 | cut -d'.' -f1,2)
if [ "$current_version" = "$python_version" ]; then
echo "Environment '${env_name}' already has Python ${python_version}. No action needed."
else
echo "Updating environment '${env_name}' to Python ${python_version}..."
conda install -n "${env_name}" python="${python_version}" -y
fi
else
echo "Conda environment '${env_name}' does not exist. Creating with Python ${python_version}..."
conda create -n "${env_name}" python="${python_version}" -y
fi
eval "$(conda shell.bash hook)"
conda deactivate && conda activate "${env_name}"
if [ -n "$TEST_PYPI_VERSION" ]; then
# these packages are damaged in test-pypi, so install them first
pip install fastapi libcst
pip install --extra-index-url https://test.pypi.org/simple/ llama-models==$TEST_PYPI_VERSION llama-toolchain==$TEST_PYPI_VERSION $pip_dependencies
else
# Re-installing llama-toolchain in the new conda environment
if [ -n "$LLAMA_TOOLCHAIN_DIR" ]; then
if [ ! -d "$LLAMA_TOOLCHAIN_DIR" ]; then
echo -e "${RED}Warning: LLAMA_TOOLCHAIN_DIR is set but directory does not exist: $LLAMA_TOOLCHAIN_DIR${NC}" >&2
exit 1
fi
echo "Installing from LLAMA_TOOLCHAIN_DIR: $LLAMA_TOOLCHAIN_DIR"
pip install -e "$LLAMA_TOOLCHAIN_DIR"
else
pip install llama-toolchain
fi
if [ -n "$LLAMA_MODELS_DIR" ]; then
if [ ! -d "$LLAMA_MODELS_DIR" ]; then
echo -e "${RED}Warning: LLAMA_MODELS_DIR is set but directory does not exist: $LLAMA_MODELS_DIR${NC}" >&2
exit 1
fi
echo "Installing from LLAMA_MODELS_DIR: $LLAMA_MODELS_DIR"
pip uninstall -y llama-models
pip install -e "$LLAMA_MODELS_DIR"
fi
# Install pip dependencies
if [ -n "$pip_dependencies" ]; then
echo "Installing pip dependencies: $pip_dependencies"
pip install $pip_dependencies
fi
fi
}
if [ "$#" -ne 3 ]; then
echo "Usage: $0 <environment_name> <distribution_name> <pip_dependencies>" >&2
echo "Example: $0 my_env local-inline 'numpy pandas scipy'" >&2
exit 1
fi
env_name="$1"
distribution_name="$2"
pip_dependencies="$3"
ensure_conda_env_python310 "$env_name" "$pip_dependencies"
echo -e "${GREEN}Successfully setup distribution environment. Configuring...${NC}"
python_interp=$(conda run -n "$env_name" which python)
$python_interp -m llama_toolchain.cli.llama distribution configure --name "$distribution_name"

View file

@ -0,0 +1,61 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from functools import lru_cache
from typing import List, Optional
from .datatypes import Api, DistributionSpec, RemoteProviderSpec
from .distribution import api_providers
def client_module(api: Api) -> str:
return f"llama_toolchain.{api.value}.client"
def remote_spec(api: Api) -> RemoteProviderSpec:
return RemoteProviderSpec(
api=api,
provider_id=f"{api.value}-remote",
module=client_module(api),
)
@lru_cache()
def available_distribution_specs() -> List[DistributionSpec]:
providers = api_providers()
return [
DistributionSpec(
spec_id="inline",
description="Use code from `llama_toolchain` itself to serve all llama stack APIs",
provider_specs={
Api.inference: providers[Api.inference]["meta-reference"],
Api.safety: providers[Api.safety]["meta-reference"],
Api.agentic_system: providers[Api.agentic_system]["meta-reference"],
},
),
DistributionSpec(
spec_id="remote",
description="Point to remote services for all llama stack APIs",
provider_specs={x: remote_spec(x) for x in providers},
),
DistributionSpec(
spec_id="ollama-inline",
description="Like local-source, but use ollama for running LLM inference",
provider_specs={
Api.inference: providers[Api.inference]["meta-ollama"],
Api.safety: providers[Api.safety]["meta-reference"],
Api.agentic_system: providers[Api.agentic_system]["meta-reference"],
},
),
]
@lru_cache()
def resolve_distribution_spec(spec_id: str) -> Optional[DistributionSpec]:
for spec in available_distribution_specs():
if spec.spec_id == spec_id:
return spec
return None

View file

@ -0,0 +1,326 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import json
import signal
from collections.abc import (
AsyncGenerator as AsyncGeneratorABC,
AsyncIterator as AsyncIteratorABC,
)
from contextlib import asynccontextmanager
from ssl import SSLError
from typing import (
Any,
AsyncGenerator,
AsyncIterator,
Dict,
get_type_hints,
List,
Optional,
Set,
)
import fire
import httpx
import yaml
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Request, Response
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.routing import APIRoute
from pydantic import BaseModel, ValidationError
from termcolor import cprint
from .datatypes import Api, DistributionSpec, ProviderSpec, RemoteProviderSpec
from .distribution import api_endpoints
from .dynamic import instantiate_client, instantiate_provider
from .registry import resolve_distribution_spec
load_dotenv()
def is_async_iterator_type(typ):
if hasattr(typ, "__origin__"):
origin = typ.__origin__
if isinstance(origin, type):
return issubclass(
origin,
(AsyncIterator, AsyncGenerator, AsyncIteratorABC, AsyncGeneratorABC),
)
return False
return isinstance(
typ, (AsyncIterator, AsyncGenerator, AsyncIteratorABC, AsyncGeneratorABC)
)
def create_sse_event(data: Any) -> str:
if isinstance(data, BaseModel):
data = data.json()
else:
data = json.dumps(data)
return f"data: {data}\n\n"
async def global_exception_handler(request: Request, exc: Exception):
http_exc = translate_exception(exc)
return JSONResponse(
status_code=http_exc.status_code, content={"error": {"detail": http_exc.detail}}
)
def translate_exception(exc: Exception) -> HTTPException:
if isinstance(exc, ValidationError):
return RequestValidationError(exc.raw_errors)
# Add more custom exception translations here
return HTTPException(status_code=500, detail="Internal server error")
async def passthrough(
request: Request,
downstream_url: str,
downstream_headers: Optional[Dict[str, str]] = None,
):
headers = dict(request.headers)
headers.pop("host", None)
headers.update(downstream_headers or {})
content = await request.body()
client = httpx.AsyncClient()
try:
req = client.build_request(
method=request.method,
url=downstream_url,
headers=headers,
content=content,
params=request.query_params,
)
response = await client.send(req, stream=True)
async def stream_response():
async for chunk in response.aiter_raw(chunk_size=64):
yield chunk
await response.aclose()
await client.aclose()
return StreamingResponse(
stream_response(),
status_code=response.status_code,
headers=dict(response.headers),
media_type=response.headers.get("content-type"),
)
except httpx.ReadTimeout:
return Response(content="Downstream server timed out", status_code=504)
except httpx.NetworkError as e:
return Response(content=f"Network error: {str(e)}", status_code=502)
except httpx.TooManyRedirects:
return Response(content="Too many redirects", status_code=502)
except SSLError as e:
return Response(content=f"SSL error: {str(e)}", status_code=502)
except httpx.HTTPStatusError as e:
return Response(content=str(e), status_code=e.response.status_code)
except Exception as e:
return Response(content=f"Unexpected error: {str(e)}", status_code=500)
def handle_sigint(*args, **kwargs):
print("SIGINT or CTRL-C detected. Exiting gracefully...")
loop = asyncio.get_event_loop()
for task in asyncio.all_tasks(loop):
task.cancel()
loop.stop()
@asynccontextmanager
async def lifespan(app: FastAPI):
print("Starting up")
yield
print("Shutting down")
def create_dynamic_passthrough(
downstream_url: str, downstream_headers: Optional[Dict[str, str]] = None
):
async def endpoint(request: Request):
return await passthrough(request, downstream_url, downstream_headers)
return endpoint
def create_dynamic_typed_route(func: Any):
hints = get_type_hints(func)
request_model = next(iter(hints.values()))
response_model = hints["return"]
# NOTE: I think it is better to just add a method within each Api
# "Protocol" / adapter-impl to tell what sort of a response this request
# is going to produce. /chat_completion can produce a streaming or
# non-streaming response depending on if request.stream is True / False.
is_streaming = is_async_iterator_type(response_model)
if is_streaming:
async def endpoint(request: request_model):
async def sse_generator(event_gen):
try:
async for item in event_gen:
yield create_sse_event(item)
await asyncio.sleep(0.01)
except asyncio.CancelledError:
print("Generator cancelled")
await event_gen.aclose()
except Exception as e:
print(e)
import traceback
traceback.print_exc()
yield create_sse_event(
{
"error": {
"message": str(translate_exception(e)),
},
}
)
return StreamingResponse(
sse_generator(func(request)), media_type="text/event-stream"
)
else:
async def endpoint(request: request_model):
try:
return (
await func(request)
if asyncio.iscoroutinefunction(func)
else func(request)
)
except Exception as e:
print(e)
import traceback
traceback.print_exc()
raise translate_exception(e) from e
return endpoint
def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
by_id = {x.api: x for x in providers}
def dfs(a: ProviderSpec, visited: Set[Api], stack: List[Api]):
visited.add(a.api)
if not isinstance(a, RemoteProviderSpec):
for api in a.api_dependencies:
if api not in visited:
dfs(by_id[api], visited, stack)
stack.append(a.api)
visited = set()
stack = []
for a in providers:
if a.api not in visited:
dfs(a, visited, stack)
return [by_id[x] for x in stack]
def resolve_impls(dist: DistributionSpec, config: Dict[str, Any]) -> Dict[Api, Any]:
provider_configs = config["providers"]
provider_specs = topological_sort(dist.provider_specs.values())
impls = {}
for provider_spec in provider_specs:
api = provider_spec.api
if api.value not in provider_configs:
raise ValueError(
f"Could not find provider_spec config for {api}. Please add it to the config"
)
provider_config = provider_configs[api.value]
if isinstance(provider_spec, RemoteProviderSpec):
impls[api] = instantiate_client(
provider_spec, provider_config["base_url"].rstrip("/")
)
else:
deps = {api: impls[api] for api in provider_spec.api_dependencies}
impl = instantiate_provider(provider_spec, provider_config, deps)
impls[api] = impl
return impls
def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
with open(yaml_config, "r") as fp:
config = yaml.safe_load(fp)
spec = config["spec"]
dist = resolve_distribution_spec(spec)
if dist is None:
raise ValueError(f"Could not find distribution specification `{spec}`")
app = FastAPI()
all_endpoints = api_endpoints()
impls = resolve_impls(dist, config)
for provider_spec in dist.provider_specs.values():
api = provider_spec.api
endpoints = all_endpoints[api]
impl = impls[api]
if isinstance(provider_spec, RemoteProviderSpec):
for endpoint in endpoints:
url = impl.base_url + endpoint.route
getattr(app, endpoint.method)(endpoint.route)(
create_dynamic_passthrough(url)
)
else:
for endpoint in endpoints:
if not hasattr(impl, endpoint.name):
# ideally this should be a typing violation already
raise ValueError(
f"Could not find method {endpoint.name} on {impl}!!"
)
impl_method = getattr(impl, endpoint.name)
getattr(app, endpoint.method)(endpoint.route, response_model=None)(
create_dynamic_typed_route(impl_method)
)
for route in app.routes:
if isinstance(route, APIRoute):
cprint(
f"Serving {next(iter(route.methods))} {route.path}",
"white",
attrs=["bold"],
)
app.exception_handler(Exception)(global_exception_handler)
signal.signal(signal.SIGINT, handle_sigint)
import uvicorn
# FYI this does not do hot-reloads
listen_host = "::" if not disable_ipv6 else "0.0.0.0"
print(f"Listening on {listen_host}:{port}")
uvicorn.run(app, host=listen_host, port=port)
if __name__ == "__main__":
fire.Fire(main)

View file

@ -0,0 +1,36 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
set -euo pipefail
# Define color codes
RED='\033[0;31m'
NC='\033[0m' # No Color
error_handler() {
echo "Error occurred in script at line: ${1}" >&2
exit 1
}
# Set up the error trap
trap 'error_handler ${LINENO}' ERR
if [ $# -lt 2 ]; then
echo "Usage: $0 <environment_name> <script_args...>"
exit 1
fi
env_name="$1"
shift
eval "$(conda shell.bash hook)"
conda deactivate && conda activate "$env_name"
python_interp=$(conda run -n "$env_name" which python)
$python_interp -m llama_toolchain.distribution.server "$@"